Skip to content

Commit

Permalink
Merge pull request #205 from astronomy-commons/issue/200/prune-empty-…
Browse files Browse the repository at this point in the history
…partitions

Prune empty partitions from catalog
  • Loading branch information
camposandro authored Mar 8, 2024
2 parents 2fbc79c + 17abe64 commit 1637c08
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
55 changes: 54 additions & 1 deletion src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any, Callable, Dict, List, cast
import warnings
from typing import Any, Callable, Dict, List, Tuple, cast

import dask
import dask.dataframe as dd
Expand Down Expand Up @@ -131,12 +132,64 @@ def _perform_search(self, filtered_pixels: List[HealpixPixel], search: AbstractS
partitions = self._ddf.to_delayed()
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = [search.search_points(partition) for partition in targeted_partitions]
return self._construct_search_ddf(filtered_pixels, filtered_partitions)

def _construct_search_ddf(
self, filtered_pixels: List[HealpixPixel], filtered_partitions: List[Delayed]
) -> Tuple[dict, dd.DataFrame]:
"""Constructs a search catalog pixel map and respective Dask Dataframe
Args:
filtered_pixels (List[HealpixPixel]): The list of pixels in the search
filtered_partitions (List[Delayed]): The list of delayed partitions
Returns:
The catalog pixel map and the respective Dask DataFrame
"""
divisions = get_pixels_divisions(filtered_pixels)
search_ddf = dd.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions)
search_ddf = cast(dd.DataFrame, search_ddf)
ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)}
return ddf_partition_map, search_ddf

def prune_empty_partitions(self, persist: bool = False) -> Self:
"""Prunes the catalog of its empty partitions
Args:
persist (bool): If True previous computations are saved. Defaults to False.
Returns:
A new catalog containing only its non-empty partitions
"""
warnings.warn("Pruning empty partitions is expensive. It may run slow!", RuntimeWarning)
if persist:
self._ddf.persist()
non_empty_pixels, non_empty_partitions = self._get_non_empty_partitions()
ddf_partition_map, search_ddf = self._construct_search_ddf(non_empty_pixels, non_empty_partitions)
filtered_hc_structure = self.hc_structure.filter_from_pixel_list(non_empty_pixels)
return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure)

def _get_non_empty_partitions(self) -> Tuple[List[HealpixPixel], List[Delayed]]:
"""Determines which pixels and partitions of a catalog are not empty
Returns:
A tuple with the non-empty pixels and respective partitions
"""
partitions = self._ddf.to_delayed()

# Compute partition lengths (expensive operation)
partition_sizes = self._ddf.map_partitions(len).compute()
empty_partition_indices = np.argwhere(partition_sizes == 0).flatten()

# Extract the non-empty pixels and respective partitions
non_empty_pixels, non_empty_partitions = [], []
for pixel, partition_index in self._ddf_pixel_map.items():
if partition_index not in empty_partition_indices:
non_empty_pixels.append(pixel)
non_empty_partitions.append(partitions[partition_index])

return non_empty_pixels, non_empty_partitions

def skymap_data(
self, func: Callable[[pd.DataFrame, HealpixPixel], Any], **kwargs
) -> Dict[HealpixPixel, Delayed]:
Expand Down
43 changes: 43 additions & 0 deletions tests/lsdb/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,49 @@ def test_save_catalog_with_some_empty_partitions(small_sky_order1_catalog, tmp_p
assert list(catalog._ddf_pixel_map.keys()) == non_empty_pixels


def test_prune_empty_partitions(small_sky_order1_catalog):
# Perform a query that forces the existence of some empty partitions
catalog = small_sky_order1_catalog.query("ra > 350 and dec < -50")
_, non_empty_partitions = catalog._get_non_empty_partitions()
assert catalog._ddf.npartitions - len(non_empty_partitions) > 0

with pytest.warns(RuntimeWarning, match="slow"):
pruned_catalog = catalog.prune_empty_partitions()

# The empty partitions were removed and the computed content is the same
_, non_empty_partitions = pruned_catalog._get_non_empty_partitions()
assert pruned_catalog._ddf.npartitions - len(non_empty_partitions) == 0
pd.testing.assert_frame_equal(catalog.compute(), pruned_catalog.compute())


def test_prune_empty_partitions_with_none_to_remove(small_sky_order1_catalog):
# The catalog has no empty partitions to be removed
_, non_empty_partitions = small_sky_order1_catalog._get_non_empty_partitions()
assert small_sky_order1_catalog._ddf.npartitions == len(non_empty_partitions)

with pytest.warns(RuntimeWarning, match="slow"):
pruned_catalog = small_sky_order1_catalog.prune_empty_partitions()

# The number of partitions and the computed content are the same
_, non_empty_partitions = pruned_catalog._get_non_empty_partitions()
assert small_sky_order1_catalog._ddf.npartitions == pruned_catalog._ddf.npartitions
pd.testing.assert_frame_equal(small_sky_order1_catalog.compute(), pruned_catalog.compute())


def test_prune_empty_partitions_all_are_removed(small_sky_order1_catalog):
# Perform a query that forces the existence of an empty catalog
catalog = small_sky_order1_catalog.query("ra > 350 and ra < 350")
_, non_empty_partitions = catalog._get_non_empty_partitions()
assert len(non_empty_partitions) == 0

with pytest.warns(RuntimeWarning, match="slow"):
pruned_catalog = catalog.prune_empty_partitions()

# The pruned catalog is also empty
_, non_empty_partitions = pruned_catalog._get_non_empty_partitions()
assert len(non_empty_partitions) == 0


def test_skymap_data(small_sky_order1_catalog):
def func(df, healpix):
return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True)
Expand Down

0 comments on commit 1637c08

Please sign in to comment.