From f45bc564ee4e3127b976339cba5a5c41e57a884b Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Tue, 5 Mar 2024 10:14:38 -0500 Subject: [PATCH 1/3] Prune empty partitions from catalog --- src/lsdb/catalog/dataset/healpix_dataset.py | 34 +++++++++++++++- tests/lsdb/catalog/test_catalog.py | 43 +++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index bc612dcf..7aa78205 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -1,7 +1,9 @@ -from typing import List, cast +import warnings +from typing import List, Tuple, cast import dask.dataframe as dd import numpy as np +from dask.delayed import Delayed from hipscat.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset from hipscat.pixel_math import HealpixPixel from hipscat.pixel_math.healpix_pixel_function import get_pixel_argsort @@ -124,8 +126,38 @@ def _perform_search(self, filtered_pixels: List[HealpixPixel], search: AbstractS partitions = self._ddf.to_delayed() targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels] filtered_partitions = [search.search_points(partition) for partition in targeted_partitions] + return self._construct_search_ddf(filtered_pixels, filtered_partitions) + + def _construct_search_ddf(self, filtered_pixels, filtered_partitions): + """Constructs the search Dask DataFrame and the respective pixel map""" divisions = get_pixels_divisions(filtered_pixels) search_ddf = dd.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions) search_ddf = cast(dd.DataFrame, search_ddf) ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)} return ddf_partition_map, search_ddf + + def prune_empty_partitions(self) -> Self: + """Removes empty partitions from a catalog + + Returns: + A new catalog containing only the non-empty partitions + """ + warnings.warn("Pruning empty partitions is expensive. It may run slow!", RuntimeWarning) + non_empty_pixels, non_empty_partitions = self._get_non_empty_partitions() + ddf_partition_map, search_ddf = self._construct_search_ddf(non_empty_pixels, non_empty_partitions) + filtered_hc_structure = self.hc_structure.filter_from_pixel_list(non_empty_pixels) + return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure) + + def _get_non_empty_partitions(self) -> Tuple[List[HealpixPixel], List[Delayed]]: + """Computes the partition lengths and returns the indices of those that are empty""" + partitions = self._ddf.to_delayed() + # Compute partition lengths (expensive operation) + partition_sizes = self._ddf.map_partitions(len).compute() + empty_partition_indices = np.argwhere(partition_sizes == 0).flatten() + # Extract the non-empty pixels and respective partitions + non_empty_pixels, non_empty_partitions = [], [] + for pixel, partition_index in self._ddf_pixel_map.items(): + if partition_index not in empty_partition_indices: + non_empty_pixels.append(pixel) + non_empty_partitions.append(partitions[partition_index]) + return non_empty_pixels, non_empty_partitions diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index 6ebcfb3c..8e3db4b4 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -231,3 +231,46 @@ def test_save_catalog_with_some_empty_partitions(small_sky_order1_catalog, tmp_p assert catalog._ddf.npartitions == 1 assert len(catalog._ddf.partitions[0]) > 0 assert list(catalog._ddf_pixel_map.keys()) == non_empty_pixels + + +def test_prune_empty_partitions(small_sky_order1_catalog): + # Perform a query that forces the existence of some empty partitions + catalog = small_sky_order1_catalog.query("ra > 350 and dec < -50") + _, non_empty_partitions = catalog._get_non_empty_partitions() + assert catalog._ddf.npartitions - len(non_empty_partitions) > 0 + + with pytest.warns(RuntimeWarning, match="slow"): + pruned_catalog = catalog.prune_empty_partitions() + + # The empty partitions were removed and the computed content is the same + _, non_empty_partitions = pruned_catalog._get_non_empty_partitions() + assert pruned_catalog._ddf.npartitions - len(non_empty_partitions) == 0 + pd.testing.assert_frame_equal(catalog.compute(), pruned_catalog.compute()) + + +def test_prune_empty_partitions_with_none_to_remove(small_sky_order1_catalog): + # The catalog has no empty partitions to be removed + _, non_empty_partitions = small_sky_order1_catalog._get_non_empty_partitions() + assert small_sky_order1_catalog._ddf.npartitions == len(non_empty_partitions) + + with pytest.warns(RuntimeWarning, match="slow"): + pruned_catalog = small_sky_order1_catalog.prune_empty_partitions() + + # The number of partitions and the computed content are the same + _, non_empty_partitions = pruned_catalog._get_non_empty_partitions() + assert small_sky_order1_catalog._ddf.npartitions == pruned_catalog._ddf.npartitions + pd.testing.assert_frame_equal(small_sky_order1_catalog.compute(), pruned_catalog.compute()) + + +def test_prune_empty_partitions_all_are_removed(small_sky_order1_catalog): + # Perform a query that forces the existence of an empty catalog + catalog = small_sky_order1_catalog.query("ra > 350 and ra < 350") + _, non_empty_partitions = catalog._get_non_empty_partitions() + assert len(non_empty_partitions) == 0 + + with pytest.warns(RuntimeWarning, match="slow"): + pruned_catalog = catalog.prune_empty_partitions() + + # The pruned catalog is also empty + _, non_empty_partitions = pruned_catalog._get_non_empty_partitions() + assert len(non_empty_partitions) == 0 From c4535fda517058a73f2838b6a2b083f593445e7b Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Tue, 5 Mar 2024 10:34:19 -0500 Subject: [PATCH 2/3] Improve docstrings --- src/lsdb/catalog/dataset/healpix_dataset.py | 27 +++++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index 7aa78205..a21acaad 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -128,8 +128,18 @@ def _perform_search(self, filtered_pixels: List[HealpixPixel], search: AbstractS filtered_partitions = [search.search_points(partition) for partition in targeted_partitions] return self._construct_search_ddf(filtered_pixels, filtered_partitions) - def _construct_search_ddf(self, filtered_pixels, filtered_partitions): - """Constructs the search Dask DataFrame and the respective pixel map""" + def _construct_search_ddf( + self, filtered_pixels: List[HealpixPixel], filtered_partitions: List[Delayed] + ) -> Tuple[dict, dd.DataFrame]: + """Constructs a search catalog pixel map and respective Dask Dataframe + + Args: + filtered_pixels (List[HealpixPixel]): The list of pixels in the search + filtered_partitions (List[Delayed]): The list of delayed partitions + + Returns: + The catalog pixel map and the respective Dask DataFrame + """ divisions = get_pixels_divisions(filtered_pixels) search_ddf = dd.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions) search_ddf = cast(dd.DataFrame, search_ddf) @@ -137,10 +147,10 @@ def _construct_search_ddf(self, filtered_pixels, filtered_partitions): return ddf_partition_map, search_ddf def prune_empty_partitions(self) -> Self: - """Removes empty partitions from a catalog + """Prunes the catalog of its empty partitions Returns: - A new catalog containing only the non-empty partitions + A new catalog containing only its non-empty partitions """ warnings.warn("Pruning empty partitions is expensive. It may run slow!", RuntimeWarning) non_empty_pixels, non_empty_partitions = self._get_non_empty_partitions() @@ -149,15 +159,22 @@ def prune_empty_partitions(self) -> Self: return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure) def _get_non_empty_partitions(self) -> Tuple[List[HealpixPixel], List[Delayed]]: - """Computes the partition lengths and returns the indices of those that are empty""" + """Determines which pixels and partitions of a catalog are not empty + + Returns: + A tuple with the non-empty pixels and respective partitions + """ partitions = self._ddf.to_delayed() + # Compute partition lengths (expensive operation) partition_sizes = self._ddf.map_partitions(len).compute() empty_partition_indices = np.argwhere(partition_sizes == 0).flatten() + # Extract the non-empty pixels and respective partitions non_empty_pixels, non_empty_partitions = [], [] for pixel, partition_index in self._ddf_pixel_map.items(): if partition_index not in empty_partition_indices: non_empty_pixels.append(pixel) non_empty_partitions.append(partitions[partition_index]) + return non_empty_pixels, non_empty_partitions From 9d4f9a67560d0793d699a9e9ade123a89fc471d8 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Wed, 6 Mar 2024 15:12:31 -0500 Subject: [PATCH 3/3] Add optional persist call --- src/lsdb/catalog/dataset/healpix_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index a21acaad..aeb4a3cd 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -146,13 +146,18 @@ def _construct_search_ddf( ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)} return ddf_partition_map, search_ddf - def prune_empty_partitions(self) -> Self: + def prune_empty_partitions(self, persist: bool = False) -> Self: """Prunes the catalog of its empty partitions + Args: + persist (bool): If True previous computations are saved. Defaults to False. + Returns: A new catalog containing only its non-empty partitions """ warnings.warn("Pruning empty partitions is expensive. It may run slow!", RuntimeWarning) + if persist: + self._ddf.persist() non_empty_pixels, non_empty_partitions = self._get_non_empty_partitions() ddf_partition_map, search_ddf = self._construct_search_ddf(non_empty_pixels, non_empty_partitions) filtered_hc_structure = self.hc_structure.filter_from_pixel_list(non_empty_pixels)