Skip to content

Commit

Permalink
Merge pull request #2139 from AdeelH/backport-0.30.1
Browse files Browse the repository at this point in the history
[BACKPORT] Backport changes to 0.30 branch for v0.30.1 release
  • Loading branch information
AdeelH authored May 6, 2024
2 parents db46a7e + 33fd2fb commit 8f5de9a
Show file tree
Hide file tree
Showing 18 changed files with 209 additions and 133 deletions.
2 changes: 2 additions & 0 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ python:
path: rastervision_pytorch_learner/
- method: pip
path: rastervision_pytorch_backend/
- method: pip
path: rastervision_aws_sagemaker/

# https://docs.readthedocs.io/en/stable/config-file/v2.html#search
search:
Expand Down
4 changes: 2 additions & 2 deletions docs/framework/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ The ``--tensorboard`` option should be used if running locally and you would lik
export PROCESSED_URI="/opt/data/examples/spacenet/rio/processed-data"
export ROOT_URI="/opt/data/examples/spacenet/rio/local-output"
rastervision run local rastervision.examples.chip_classification.spacenet_rio \
rastervision run local rastervision.pytorch_backend.examples.chip_classification.spacenet_rio \
-a raw_uri $RAW_URI -a processed_uri $PROCESSED_URI -a root_uri $ROOT_URI \
-a test True --splits 2
Expand All @@ -104,7 +104,7 @@ To run the full experiment on GPUs using AWS Batch, use something like the follo
export PROCESSED_URI="s3://mybucket/examples/spacenet/rio/processed-data"
export ROOT_URI="s3://mybucket/examples/spacenet/rio/remote-output"
rastervision run batch rastervision.examples.chip_classification.spacenet_rio \
rastervision run batch rastervision.pytorch_backend.examples.chip_classification.spacenet_rio \
-a raw_uri $RAW_URI -a processed_uri $PROCESSED_URI -a root_uri $ROOT_URI \
-a test False --splits 8
Expand Down
105 changes: 60 additions & 45 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Any, Iterator, Tuple
import io
import os
import subprocess
Expand All @@ -16,41 +16,38 @@


# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
def get_matching_s3_objects(bucket, prefix='', suffix='',
request_payer='None'):
"""
Generate objects in an S3 bucket.
:param bucket: Name of the S3 bucket.
:param prefix: Only fetch objects whose key starts with
this prefix (optional).
:param suffix: Only fetch objects whose keys end with
this suffix (optional).
def get_matching_s3_objects(
bucket: str,
prefix: str = '',
suffix: str = '',
delimiter: str = '/',
request_payer: str = 'None') -> Iterator[tuple[str, Any]]:
"""Generate objects in an S3 bucket.
Args:
bucket: Name of the S3 bucket.
prefix: Only fetch objects whose key starts with this prefix.
suffix: Only fetch objects whose keys end with this suffix.
"""
s3 = S3FileSystem.get_client()
kwargs = {'Bucket': bucket, 'RequestPayer': request_payer}

# If the prefix is a single string (not a tuple of strings), we can
# do the filtering directly in the S3 API.
if isinstance(prefix, str):
kwargs['Prefix'] = prefix

kwargs = dict(
Bucket=bucket,
RequestPayer=request_payer,
Delimiter=delimiter,
Prefix=prefix,
)
while True:

# The S3 API response is a large blob of metadata.
# 'Contents' contains information about the listed objects.
resp = s3.list_objects_v2(**kwargs)

try:
contents = resp['Contents']
except KeyError:
return

for obj in contents:
resp: dict = s3.list_objects_v2(**kwargs)
dirs: list[dict] = resp.get('CommonPrefixes', {})
files: list[dict] = resp.get('Contents', {})
for obj in dirs:
key = obj['Prefix']
if key.startswith(prefix) and key.endswith(suffix):
yield key, obj
for obj in files:
key = obj['Key']
if key.startswith(prefix) and key.endswith(suffix):
yield obj

yield key, obj
# The S3 API is paginated, returning up to 1000 keys at a time.
# Pass the continuation token into the next response, until we
# reach the final page (when this field is missing).
Expand All @@ -60,16 +57,26 @@ def get_matching_s3_objects(bucket, prefix='', suffix='',
break


def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'):
"""
Generate the keys in an S3 bucket.
def get_matching_s3_keys(bucket: str,
prefix: str = '',
suffix: str = '',
delimiter: str = '/',
request_payer: str = 'None') -> Iterator[str]:
"""Generate the keys in an S3 bucket.
:param bucket: Name of the S3 bucket.
:param prefix: Only fetch keys that start with this prefix (optional).
:param suffix: Only fetch keys that end with this suffix (optional).
Args:
bucket: Name of the S3 bucket.
prefix: Only fetch keys that start with this prefix.
suffix: Only fetch keys that end with this suffix.
"""
for obj in get_matching_s3_objects(bucket, prefix, suffix, request_payer):
yield obj['Key']
obj_iterator = get_matching_s3_objects(
bucket,
prefix=prefix,
suffix=suffix,
delimiter=delimiter,
request_payer=request_payer)
out = (key for key, _ in obj_iterator)
return out


def progressbar(total_size: int, desc: str):
Expand Down Expand Up @@ -180,8 +187,9 @@ def read_bytes(uri: str) -> bytes:
bucket, key = S3FileSystem.parse_uri(uri)
with io.BytesIO() as file_buffer:
try:
file_size = s3.head_object(
Bucket=bucket, Key=key)['ContentLength']
obj = s3.head_object(
Bucket=bucket, Key=key, RequestPayer=request_payer)
file_size = obj['ContentLength']
with progressbar(file_size, desc='Downloading') as bar:
s3.download_fileobj(
Bucket=bucket,
Expand Down Expand Up @@ -256,7 +264,9 @@ def copy_from(src_uri: str, dst_path: str) -> None:
request_payer = S3FileSystem.get_request_payer()
bucket, key = S3FileSystem.parse_uri(src_uri)
try:
file_size = s3.head_object(Bucket=bucket, Key=key)['ContentLength']
obj = s3.head_object(
Bucket=bucket, Key=key, RequestPayer=request_payer)
file_size = obj['ContentLength']
with progressbar(file_size, desc=f'Downloading') as bar:
s3.download_file(
Bucket=bucket,
Expand Down Expand Up @@ -284,11 +294,16 @@ def last_modified(uri: str) -> datetime:
return head_data['LastModified']

@staticmethod
def list_paths(uri, ext=''):
def list_paths(uri: str, ext: str = '', delimiter: str = '/') -> list[str]:
request_payer = S3FileSystem.get_request_payer()
parsed_uri = urlparse(uri)
bucket = parsed_uri.netloc
prefix = os.path.join(parsed_uri.path[1:])
keys = get_matching_s3_keys(
bucket, prefix, suffix=ext, request_payer=request_payer)
return [os.path.join('s3://', bucket, key) for key in keys]
bucket,
prefix,
suffix=ext,
delimiter=delimiter,
request_payer=request_payer)
paths = [os.path.join('s3://', bucket, key) for key in keys]
return paths
9 changes: 9 additions & 0 deletions rastervision_core/rastervision/core/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,12 @@ def get_split_config(self, split_ind, num_splits):
@property
def all_scenes(self) -> List[SceneConfig]:
return self.train_scenes + self.validation_scenes + self.test_scenes

def __repr__(self):
num_train = len(self.train_scenes)
num_val = len(self.validation_scenes)
num_test = len(self.test_scenes)
out = (f'DatasetConfig(train_scenes=<{num_train} scenes>, '
f'validation_scenes=<{num_val} scenes>, '
f'test_scenes=<{num_test} scenes>)')
return out
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,20 @@ def get_chip(self,

return chip

def get_chip_by_map_window(
self,
window_map_coords: 'Box',
out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray':
"""Same as get_chip(), but input is a window in map coords. """
def get_chip_by_map_window(self, window_map_coords: 'Box', *args,
**kwargs) -> 'np.ndarray':
"""Same as get_chip(), but input is a window in map coords."""
window_pixel_coords = self.crs_transformer.map_to_pixel(
window_map_coords, bbox=self.bbox).normalize()
chip = self.get_chip(window_pixel_coords, out_shape=out_shape)
chip = self.get_chip(window_pixel_coords, *args, **kwargs)
return chip

def _get_chip_by_map_window(
self,
window_map_coords: 'Box',
out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray':
"""Same as _get_chip(), but input is a window in map coords. """
def _get_chip_by_map_window(self, window_map_coords: 'Box', *args,
**kwargs) -> 'np.ndarray':
"""Same as _get_chip(), but input is a window in map coords."""
window_pixel_coords = self.crs_transformer.map_to_pixel(
window_map_coords, bbox=self.bbox)
chip = self._get_chip(window_pixel_coords, out_shape=out_shape)
chip = self._get_chip(window_pixel_coords, *args, **kwargs)
return chip

def get_raw_chip(self,
Expand Down
2 changes: 1 addition & 1 deletion rastervision_core/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ rastervision_pipeline==0.30.0
shapely==2.0.2
geopandas==0.14.3
numpy==1.26.3
pillow==10.2.0
pillow==10.3.0
pyproj==3.6.1
rasterio==1.3.9
pystac==1.9.0
Expand Down
82 changes: 50 additions & 32 deletions rastervision_pipeline/rastervision/pipeline/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import TYPE_CHECKING
import sys
import os
import logging
import importlib
import importlib.util
from typing import List, Dict, Optional, Tuple

import click

from rastervision.pipeline import (registry_ as registry, rv_config_ as
rv_config)
from rastervision.pipeline.file_system import (file_to_json, get_tmp_dir)
from rastervision.pipeline.config import build_config, save_pipeline_config
from rastervision.pipeline.config import (build_config, Config,
save_pipeline_config)
from rastervision.pipeline.pipeline_config import PipelineConfig

if TYPE_CHECKING:
from rastervision.pipeline.runner import Runner

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -40,8 +42,9 @@ def convert_bool_args(args: dict) -> dict:
return new_args


def get_configs(cfg_module_path: str, runner: str,
args: Dict[str, any]) -> List[PipelineConfig]:
def get_configs(cfg_module_path: str,
runner: str | None = None,
args: dict[str, any] | None = None) -> list[PipelineConfig]:
"""Get PipelineConfigs from a module.
Calls a get_config(s) function with some arguments from the CLI
Expand All @@ -55,6 +58,26 @@ def get_configs(cfg_module_path: str, runner: str,
args: CLI args to pass to the get_config(s) function that comes from
the --args option
"""
if cfg_module_path.endswith('.json'):
cfgs_json = file_to_json(cfg_module_path)
if not isinstance(cfgs_json, list):
cfgs_json = [cfgs_json]
cfgs = [Config.deserialize(json) for json in cfgs_json]
else:
cfgs = get_configs_from_module(cfg_module_path, runner, args)

for cfg in cfgs:
if not issubclass(type(cfg), PipelineConfig):
raise TypeError('All objects returned by get_configs in '
f'{cfg_module_path} must be PipelineConfigs.')
return cfgs


def get_configs_from_module(cfg_module_path: str, runner: str,
args: dict[str, any]) -> list[PipelineConfig]:
import importlib
import importlib.util

if cfg_module_path.endswith('.py'):
# From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa
spec = importlib.util.spec_from_file_location('rastervision.pipeline',
Expand All @@ -65,20 +88,14 @@ def get_configs(cfg_module_path: str, runner: str,
cfg_module = importlib.import_module(cfg_module_path)

_get_config = getattr(cfg_module, 'get_config', None)
_get_configs = _get_config
if _get_config is None:
_get_configs = getattr(cfg_module, 'get_configs', None)
_get_configs = getattr(cfg_module, 'get_configs', _get_config)
if _get_configs is None:
raise Exception('There must be a get_config or get_configs function '
f'in {cfg_module_path}.')
raise ImportError('There must be a get_config() or get_configs() '
f'function in {cfg_module_path}.')

cfgs = _get_configs(runner, **args)
if not isinstance(cfgs, list):
cfgs = [cfgs]

for cfg in cfgs:
if not issubclass(type(cfg), PipelineConfig):
raise Exception('All objects returned by get_configs in '
f'{cfg_module_path} must be PipelineConfigs.')
return cfgs


Expand All @@ -89,8 +106,7 @@ def get_configs(cfg_module_path: str, runner: str,
@click.option(
'-v', '--verbose', help='Increment the verbosity level.', count=True)
@click.option('--tmpdir', help='Root of temporary directories to use.')
def main(ctx: click.Context, profile: Optional[str], verbose: int,
tmpdir: str):
def main(ctx: click.Context, profile: str | None, verbose: int, tmpdir: str):
"""The main click command.
Sets the profile, verbosity, and tmp_dir in RVConfig.
Expand All @@ -103,20 +119,22 @@ def main(ctx: click.Context, profile: Optional[str], verbose: int,
rv_config.set_everett_config(profile=profile)


def _run_pipeline(cfg,
runner,
tmp_dir,
splits=1,
commands=None,
def _run_pipeline(cfg: PipelineConfig,
runner: 'Runner',
tmp_dir: str,
splits: int = 1,
commands: list[str] | None = None,
pipeline_run_name: str = 'raster-vision'):
cfg.update()
cfg.recursive_validate_config()
# This is to run the validation again to check any fields that may have changed
# after the Config was constructed, possibly by the update method.

# This is to run the validation again to check any fields that may have
# changed after the Config was constructed, possibly by the update method.
build_config(cfg.dict())
cfg_json_uri = cfg.get_config_uri()
save_pipeline_config(cfg, cfg_json_uri)
pipeline = cfg.build(tmp_dir)

if not commands:
commands = pipeline.commands

Expand Down Expand Up @@ -150,8 +168,8 @@ def _run_pipeline(cfg,
'--pipeline-run-name',
default='raster-vision',
help='The name for this run of the pipeline.')
def run(runner: str, cfg_module: str, commands: List[str],
arg: List[Tuple[str, str]], splits: int, pipeline_run_name: str):
def run(runner: str, cfg_module: str, commands: list[str],
arg: list[tuple[str, str]], splits: int, pipeline_run_name: str):
"""Run COMMANDS within pipelines in CFG_MODULE using RUNNER.
RUNNER: name of the Runner to use
Expand All @@ -178,9 +196,9 @@ def run(runner: str, cfg_module: str, commands: List[str],

def _run_command(cfg_json_uri: str,
command: str,
split_ind: Optional[int] = None,
num_splits: Optional[int] = None,
runner: Optional[str] = None):
split_ind: int | None = None,
num_splits: int | None = None,
runner: str | None = None):
"""Run a single command using a serialized PipelineConfig.
Args:
Expand Down Expand Up @@ -229,8 +247,8 @@ def _run_command(cfg_json_uri: str,
help='The number of processes to use for running splittable commands')
@click.option(
'--runner', type=str, help='Name of runner to use', default='inprocess')
def run_command(cfg_json_uri: str, command: str, split_ind: Optional[int],
num_splits: Optional[int], runner: str):
def run_command(cfg_json_uri: str, command: str, split_ind: int | None,
num_splits: int | None, runner: str):
"""Run a single COMMAND using a serialized PipelineConfig in CFG_JSON_URI."""
_run_command(
cfg_json_uri,
Expand Down
Loading

0 comments on commit 8f5de9a

Please sign in to comment.