From 9f7237d83f35603dd3c4e82dafd69cd1a0cc208d Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 10 Jul 2024 12:16:48 -0400 Subject: [PATCH] Fix bug in ChipClassificationSource.__getitem__() when bbox is specified. (#2193) --- .../chip_classification_label_source.py | 58 ++++++++++++++----- .../test_chip_classification_label_source.py | 54 +++++++++++++---- 2 files changed, 84 insertions(+), 28 deletions(-) diff --git a/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py b/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py index 8037f5914..6c2245c8a 100644 --- a/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py +++ b/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py @@ -173,25 +173,45 @@ def infer_cells(self, cells: Iterable[Box] | None = None ) -> ChipClassificationLabels: """Infer labels for a list of cells. - Only cells whose labels are not already known are inferred. + Cells are assumed to be in ``bbox`` coords as opposed to global coords + and are converted to global coords before inference. The returned + labels are in global coords. Only cells whose + labels are not already known are inferred. Args: - cells: Cells whose labels are to be inferred. Defaults to ``None``. + cells: Cells (in ``bbox`` coords) whose labels are to be inferred. + If ``None``, cells are assumed to be sliding windows of size + and stride ``cell_sz`` (specified in + :class:`.ChipClassificationLabelSourceConfig`). + Defaults to ``None``. Returns: - ChipClassificationLabels: labels + Labels (in global coords). """ - cfg = self.cfg if cells is None: - if cfg.cell_sz is None: + cell_sz = self.cfg.cell_sz + if cell_sz is None: raise ValueError('cell_sz is not set.') - cells = self.extent.get_windows(cfg.cell_sz, cfg.cell_sz) - else: - cells = [cell.to_global_coords(self.bbox) for cell in cells] + cells = self.extent.get_windows(cell_sz, cell_sz) + cells = [cell.to_global_coords(self.bbox) for cell in cells] + labels = self._infer_cells(cells) + return labels + + def _infer_cells(self, cells: Iterable[Box]) -> ChipClassificationLabels: + """Infer labels for a list of cells. + + Cells are assumed to be in global coords as opposed to ``bbox`` coords. + Only cells whose labels are not already known are inferred. + + Args: + cells: Cells (in global coords) whose labels are to be inferred. + Returns: + Labels (in global coords). + """ + cfg = self.cfg known_cells = [c for c in cells if c in self.labels] unknown_cells = [c for c in cells if c not in self.labels] - labels = infer_cells( cells=unknown_cells, labels_df=self.labels_df, @@ -199,28 +219,34 @@ def infer_cells(self, cells: Iterable[Box] | None = None use_intersection_over_cell=cfg.use_intersection_over_cell, pick_min_class_id=cfg.pick_min_class_id, background_class_id=cfg.background_class_id) - for cell in known_cells: class_id = self.labels.get_cell_class_id(cell) labels.set_cell(cell, class_id) - return labels def get_labels(self, window: Box | None = None) -> ChipClassificationLabels: + """Return label for a window, inferring it if not already known. + + If window is ``None``, returns all labels. + """ if window is None: return self.labels window = window.to_global_coords(self.bbox) - return self.labels.get_singleton_labels(window) + if window not in self.labels: + self.labels += self._infer_cells(cells=[window]) + labels = self.labels.get_singleton_labels(window) + return labels def __getitem__(self, key: Any) -> int: - """Return label for a window, inferring it if it is not already known. - """ + """Return class ID for a window, inferring it if not already known.""" if isinstance(key, Box): window = key + window = window.to_global_coords(self.bbox) if window not in self.labels: - self.labels += self.infer_cells(cells=[window]) - return self.labels[window].class_id + self.labels += self._infer_cells(cells=[window]) + class_id = self.labels[window].class_id + return class_id else: return super().__getitem__(key) diff --git a/tests/core/data/label_source/test_chip_classification_label_source.py b/tests/core/data/label_source/test_chip_classification_label_source.py index 6c49ff7fc..c326bff4e 100644 --- a/tests/core/data/label_source/test_chip_classification_label_source.py +++ b/tests/core/data/label_source/test_chip_classification_label_source.py @@ -1,16 +1,20 @@ +from collections.abc import Callable import unittest -import os +from os.path import join import geopandas as gpd +import numpy as np from rastervision.pipeline.file_system import json_to_file, get_tmp_dir from rastervision.core.box import Box from rastervision.core.data import ( - ClassConfig, ChipClassificationLabelSourceConfig, - GeoJSONVectorSourceConfig, ClassInferenceTransformerConfig, - BufferTransformerConfig) + BufferTransformerConfig, ChipClassificationLabelSource, + ChipClassificationLabelSourceConfig, ClassConfig, + ClassInferenceTransformerConfig, GeoJSONVectorSource, + GeoJSONVectorSourceConfig, IdentityCRSTransformer) from rastervision.core.data.label_source.chip_classification_label_source \ import infer_cells +from rastervision.core.data.label_store.utils import boxes_to_geojson from tests import data_file_path from tests.core.data.mock_crs_transformer import DoubleCRSTransformer @@ -30,6 +34,12 @@ def test_ensure_required_transformers(self): class TestChipClassificationLabelSource(unittest.TestCase): + def assertNoError(self, fn: Callable, msg: str = ''): + try: + fn() + except Exception: + self.fail(msg) + def setUp(self): self.crs_transformer = DoubleCRSTransformer() self.geojson = { @@ -76,7 +86,7 @@ def setUp(self): self.file_name = 'labels.json' self.tmp_dir = get_tmp_dir() - self.uri = os.path.join(self.tmp_dir.name, self.file_name) + self.uri = join(self.tmp_dir.name, self.file_name) json_to_file(self.geojson, self.uri) def tearDown(self): @@ -292,17 +302,37 @@ def test_get_labels(self): def test_getitem(self): # Extent contains both boxes. extent = Box.make_square(0, 0, 8) - config = ChipClassificationLabelSourceConfig( vector_source=GeoJSONVectorSourceConfig(uris=self.uri)) - source = config.build(self.class_config, self.crs_transformer, extent, - self.tmp_dir.name) - labels = source.get_labels() - + label_source = config.build(self.class_config, self.crs_transformer, + extent, self.tmp_dir.name) + labels = label_source.get_labels() cells = labels.get_cells() self.assertEqual(len(cells), 2) - self.assertEqual(source[cells[0]], self.class_id1) - self.assertEqual(source[cells[1]], self.class_id2) + self.assertEqual(label_source[cells[0]], self.class_id1) + self.assertEqual(label_source[cells[1]], self.class_id2) + + def test_getitem_and_get_labels_with_bbox(self): + extent = Box(0, 0, 100, 100) + boxes = extent.get_windows(10, 10) + class_config = ClassConfig(names=['a', 'b', 'c'], null_class='c') + class_ids = np.random.randint( + 0, len(class_config), size=len(boxes)).tolist() + crs_tf = IdentityCRSTransformer() + geojson = boxes_to_geojson(boxes, class_ids, crs_tf, class_config) + + ls_cfg = ChipClassificationLabelSourceConfig( + background_class_id=class_config.null_class_id, infer_cells=True) + bbox = Box(25, 25, 50, 50) + with get_tmp_dir() as tmp_dir: + labels_uri = join(tmp_dir, 'labels.json') + json_to_file(geojson, labels_uri) + vs = GeoJSONVectorSource(labels_uri, crs_tf) + ls = ChipClassificationLabelSource( + ls_cfg, vs, bbox=bbox, lazy=True) + self.assertNoError(lambda: ls[:10, :10]) + labels = ls.get_labels(Box(0, 0, 11, 11)) + self.assertListEqual(labels.get_cells(), [Box(25, 25, 36, 36)]) if __name__ == '__main__':