Skip to content

Commit

Permalink
Merge pull request #209 from astronomy-commons/issue/201/coarse-fine-…
Browse files Browse the repository at this point in the history
…filtering

Add option for coarse spatial filtering
  • Loading branch information
camposandro authored Mar 8, 2024
2 parents d7a3133 + d14f0d8 commit 49e53b5
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 38 deletions.
44 changes: 28 additions & 16 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dask.dataframe as dd
import hipscat as hc
import pandas as pd
from hipscat.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog
from hipscat.pixel_math.polygon_filter import SphericalCoordinates

from lsdb.catalog.association_catalog import AssociationCatalog
Expand All @@ -14,6 +15,7 @@
from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm
from lsdb.core.search import ConeSearch, IndexSearch, PolygonSearch
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.core.search.box_search import BoxSearch
from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data
from lsdb.dask.join_catalog_data import join_catalog_data_on, join_catalog_data_through
Expand Down Expand Up @@ -185,7 +187,7 @@ def crossmatch(
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
return Catalog(ddf, ddf_map, hc_catalog)

def cone_search(self, ra: float, dec: float, radius_arcsec: float):
def cone_search(self, ra: float, dec: float, radius_arcsec: float, fine: bool = True) -> Catalog:
"""Perform a cone search to filter the catalog
Filters to points within radius great circle distance to the point specified by ra and dec in degrees.
Expand All @@ -195,30 +197,37 @@ def cone_search(self, ra: float, dec: float, radius_arcsec: float):
ra (float): Right Ascension of the center of the cone in degrees
dec (float): Declination of the center of the cone in degrees
radius_arcsec (float): Radius of the cone in arcseconds
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new Catalog containing the points filtered to those within the cone, and the partitions that
overlap the cone.
"""
return self._search(ConeSearch(ra, dec, radius_arcsec, self.hc_structure))
return self._search(ConeSearch(ra, dec, radius_arcsec, self.hc_structure), fine)

def box(self, ra: Tuple[float, float] | None = None, dec: Tuple[float, float] | None = None) -> Catalog:
def box(
self,
ra: Tuple[float, float] | None = None,
dec: Tuple[float, float] | None = None,
fine: bool = True,
) -> Catalog:
"""Performs filtering according to right ascension and declination ranges.
Filters to points within the region specified in degrees.
Filters partitions in the catalog to those that have some overlap with the region.
Args:
ra (Tuple[float, float]): The right ascension minimum and maximum values
dec (Tuple[float, float]): The declination minimum and maximum values
ra (Tuple[float, float]): The right ascension minimum and maximum values.
dec (Tuple[float, float]): The declination minimum and maximum values.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new catalog containing the points filtered to those within the region, and the
partitions that have some overlap with it.
"""
return self._search(BoxSearch(self.hc_structure, ra=ra, dec=dec))
return self._search(BoxSearch(self.hc_structure, ra=ra, dec=dec), fine)

def polygon_search(self, vertices: List[SphericalCoordinates]) -> Catalog:
def polygon_search(self, vertices: List[SphericalCoordinates], fine: bool = True) -> Catalog:
"""Perform a polygonal search to filter the catalog.
Filters to points within the polygonal region specified in ra and dec, in degrees.
Expand All @@ -227,14 +236,15 @@ def polygon_search(self, vertices: List[SphericalCoordinates]) -> Catalog:
Args:
vertices (List[Tuple[float, float]): The list of vertices of the polygon to
filter pixels with, as a list of (ra,dec) coordinates, in degrees.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new catalog containing the points filtered to those within the
polygonal region, and the partitions that have some overlap with it.
"""
return self._search(PolygonSearch(vertices, self.hc_structure))
return self._search(PolygonSearch(vertices, self.hc_structure), fine)

def index_search(self, ids, catalog_index: hc.catalog.index.index_catalog.IndexCatalog):
def index_search(self, ids, catalog_index: HCIndexCatalog, fine: bool = True) -> Catalog:
"""Find rows by ids (or other value indexed by a catalog index).
Filters partitions in the catalog to those that could contain the ids requested.
Expand All @@ -243,30 +253,32 @@ def index_search(self, ids, catalog_index: hc.catalog.index.index_catalog.IndexC
NB: This requires a previously-computed catalog index table.
Args:
ids: values to search for
catalog_index: a pre-computed hipscat catalog index
ids: Values to search for.
catalog_index (HCIndexCatalog): A pre-computed hipscat index catalog.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new Catalog containing the points filtered to those matching the ids.
"""
return self._search(IndexSearch(ids, catalog_index))
return self._search(IndexSearch(ids, catalog_index), fine)

def _search(self, search):
def _search(self, search: AbstractSearch, fine: bool = True):
"""Find rows by reusable search algorithm.
Filters partitions in the catalog to those that match some rough criteria.
Filters to points that match some finer criteria.
Args:
search: instance of AbstractSearch
search (AbstractSearch): Instance of AbstractSearch.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new Catalog containing the points filtered to those matching the search parameters.
"""
filtered_pixels = search.search_partitions(self.hc_structure.get_healpix_pixels())
filtered_hc_structure = self.hc_structure.filter_from_pixel_list(filtered_pixels)
ddf_partition_map, search_ddf = self._perform_search(filtered_pixels, search)
margin = self.margin._search(search) if self.margin is not None else None
ddf_partition_map, search_ddf = self._perform_search(filtered_pixels, search, fine)
margin = self.margin._search(search, fine) if self.margin is not None else None
return Catalog(search_ddf, ddf_partition_map, filtered_hc_structure, margin=margin)

def merge(
Expand Down
13 changes: 9 additions & 4 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,25 @@ def query(self, expr: str) -> Self:
ddf = self._ddf.query(expr)
return self.__class__(ddf, self._ddf_pixel_map, self.hc_structure)

def _perform_search(self, filtered_pixels: List[HealpixPixel], search: AbstractSearch):
def _perform_search(self, filtered_pixels: List[HealpixPixel], search: AbstractSearch, fine: bool = True):
"""Performs a search on the catalog from a list of pixels to search in
Args:
filtered_pixels (List[HealpixPixel]): List of pixels in the catalog to be searched
search (AbstractSearch): The search object to perform the search with
filtered_pixels (List[HealpixPixel]): List of pixels in the catalog to be searched.
search (AbstractSearch): Instance of AbstractSearch.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A tuple containing a dictionary mapping pixel to partition index and a dask dataframe
containing the search results
"""
partitions = self._ddf.to_delayed()
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = [search.search_points(partition) for partition in targeted_partitions]
filtered_partitions = (
[search.search_points(partition) for partition in targeted_partitions]
if fine
else targeted_partitions
)
return self._construct_search_ddf(filtered_pixels, filtered_partitions)

def _construct_search_ddf(
Expand Down
8 changes: 5 additions & 3 deletions src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from hipscat.pixel_tree.pixel_tree_builder import PixelTreeBuilder

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import DaskDFPixelMap


Expand All @@ -29,14 +30,15 @@ def __init__(
):
super().__init__(ddf, ddf_pixel_map, hc_structure)

def _search(self, search):
def _search(self, search: AbstractSearch, fine: bool = True):
"""Find rows by reusable search algorithm.
Filters partitions in the catalog to those that match some rough criteria and their neighbors.
Filters to points that match some finer criteria.
Args:
search: instance of AbstractSearch
search (AbstractSearch): Instance of AbstractSearch.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new Catalog containing the points filtered to those matching the search parameters.
Expand Down Expand Up @@ -72,5 +74,5 @@ def _search(self, search):
filtered_pixels = list(set(filtered_search_pixels + filtered_margin_pixels))

filtered_hc_structure = self.hc_structure.filter_from_pixel_list(filtered_pixels)
ddf_partition_map, search_ddf = self._perform_search(filtered_pixels, search)
ddf_partition_map, search_ddf = self._perform_search(filtered_pixels, search, fine)
return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure)
41 changes: 28 additions & 13 deletions tests/lsdb/catalog/test_box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hipscat.pixel_math.validators import ValidatorsErrors


def test_box_search_ra(small_sky_order1_catalog, assert_divisions_are_correct):
def test_box_search_ra_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
ra_search_catalog = small_sky_order1_catalog.box(ra=(280, 300))
ra_search_df = ra_search_catalog.compute()
ra_values = ra_search_df[small_sky_order1_catalog.hc_structure.catalog_info.ra_column]
Expand All @@ -12,7 +12,9 @@ def test_box_search_ra(small_sky_order1_catalog, assert_divisions_are_correct):
assert_divisions_are_correct(ra_search_catalog)


def test_box_search_ra_margin(small_sky_order1_source_with_margin, assert_divisions_are_correct):
def test_box_search_ra_filters_correct_points_margin(
small_sky_order1_source_with_margin, assert_divisions_are_correct
):
ra_search_catalog = small_sky_order1_source_with_margin.box(ra=(280, 300))
ra_search_df = ra_search_catalog.compute()
ra_values = ra_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
Expand Down Expand Up @@ -45,12 +47,10 @@ def test_box_search_ra_complement(small_sky_order1_catalog):
assert np.array_equal(np.sort(joined_values), np.sort(all_catalog_values))


def test_box_search_ra_wrapped_values(small_sky_order1_catalog):
def test_box_search_ra_wrapped_filters_correct_points(small_sky_order1_catalog):
ra_column = small_sky_order1_catalog.hc_structure.catalog_info.ra_column

ra_search_catalog = small_sky_order1_catalog.box(ra=(330, 30))
filtered_ra_values = ra_search_catalog.compute()[ra_column]

# Some other options with values that need to be wrapped
for ra_range in [(-30, 30), (330, 390), (330, -330)]:
catalog = small_sky_order1_catalog.box(ra=ra_range)
Expand All @@ -59,7 +59,7 @@ def test_box_search_ra_wrapped_values(small_sky_order1_catalog):
assert np.array_equal(ra_values, filtered_ra_values)


def test_box_search_dec(small_sky_order1_catalog, assert_divisions_are_correct):
def test_box_search_dec_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
dec_search_catalog = small_sky_order1_catalog.box(dec=(0, 30))
dec_search_df = dec_search_catalog.compute()
dec_values = dec_search_df[small_sky_order1_catalog.hc_structure.catalog_info.dec_column]
Expand All @@ -68,42 +68,57 @@ def test_box_search_dec(small_sky_order1_catalog, assert_divisions_are_correct):
assert_divisions_are_correct(dec_search_catalog)


def test_box_search_ra_and_dec(small_sky_order1_catalog, assert_divisions_are_correct):
def test_box_search_ra_and_dec_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
search_catalog = small_sky_order1_catalog.box(ra=(280, 300), dec=(-40, -30))

search_df = search_catalog.compute()
ra_values = search_df[small_sky_order1_catalog.hc_structure.catalog_info.ra_column]
dec_values = search_df[small_sky_order1_catalog.hc_structure.catalog_info.dec_column]

assert len(search_df) < len(small_sky_order1_catalog.compute())
assert all(280 <= ra <= 300 for ra in ra_values)
assert all(-40 <= dec <= -30 for dec in dec_values)
assert_divisions_are_correct(search_catalog)


def test_box_search_filters_partitions(small_sky_order1_catalog):
ra = (280, 300)
dec = (-40, -30)
hc_box_search = small_sky_order1_catalog.hc_structure.filter_by_box(ra, dec)
box_search_catalog = small_sky_order1_catalog.box(ra, dec, fine=False)
assert len(hc_box_search.get_healpix_pixels()) == len(box_search_catalog.get_healpix_pixels())
assert len(hc_box_search.get_healpix_pixels()) == box_search_catalog._ddf.npartitions
for pixel in hc_box_search.get_healpix_pixels():
assert pixel in box_search_catalog._ddf_pixel_map


def test_box_search_coarse_versus_fine(small_sky_order1_catalog):
ra = (280, 300)
dec = (-40, -30)
coarse_box_search = small_sky_order1_catalog.box(ra, dec, fine=False)
fine_box_search = small_sky_order1_catalog.box(ra, dec)
assert coarse_box_search.get_healpix_pixels() == fine_box_search.get_healpix_pixels()
assert coarse_box_search._ddf.npartitions == fine_box_search._ddf.npartitions
assert len(coarse_box_search.compute()) > len(fine_box_search.compute())


def test_box_search_invalid_args(small_sky_order1_catalog):
# Some declination values are out of the [-90,90] bounds
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_DEC):
small_sky_order1_catalog.box(ra=(0, 30), dec=(-100, -70))

# Declination values should be in ascending order
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(dec=(0, -10))

# One or more ranges are defined with more than 2 values
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(ra=(0, 30), dec=(-30, -40, 10))
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(ra=(0, 30, 40), dec=(-40, 10))

# The range values coincide (for ra, values are wrapped)
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(ra=(100, 100))
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(ra=(0, 360))
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(dec=(50, 50))

# No range values were provided
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box(ra=None, dec=None)
13 changes: 12 additions & 1 deletion tests/lsdb/catalog/test_cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_cone_search_filters_partitions(small_sky_order1_catalog):
dec = -80
radius = 20 * 3600
hc_conesearch = small_sky_order1_catalog.hc_structure.filter_by_cone(ra, dec, radius)
consearch_catalog = small_sky_order1_catalog.cone_search(ra, dec, radius)
consearch_catalog = small_sky_order1_catalog.cone_search(ra, dec, radius, fine=False)
assert len(hc_conesearch.get_healpix_pixels()) == len(consearch_catalog.get_healpix_pixels())
assert len(hc_conesearch.get_healpix_pixels()) == consearch_catalog._ddf.npartitions
for pixel in hc_conesearch.get_healpix_pixels():
Expand Down Expand Up @@ -100,6 +100,17 @@ def test_cone_search_wrapped_ra(small_sky_order1_catalog):
small_sky_order1_catalog.cone_search(-100.1, 0, 1.5)


def test_cone_search_coarse_versus_fine(small_sky_order1_catalog):
ra = 0
dec = -80
radius = 20 * 3600 # 20 degrees
coarse_cone_search = small_sky_order1_catalog.cone_search(ra, dec, radius, fine=False)
fine_cone_search = small_sky_order1_catalog.cone_search(ra, dec, radius)
assert coarse_cone_search.get_healpix_pixels() == fine_cone_search.get_healpix_pixels()
assert coarse_cone_search._ddf.npartitions == fine_cone_search._ddf.npartitions
assert len(coarse_cone_search.compute()) > len(fine_cone_search.compute())


def test_invalid_dec_and_negative_radius(small_sky_order1_catalog):
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_DEC):
small_sky_order1_catalog.cone_search(0, -100.3, 1.2)
Expand Down
9 changes: 9 additions & 0 deletions tests/lsdb/catalog/test_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, a
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 1
assert_divisions_are_correct(index_search_catalog)


def test_index_search_coarse_versus_fine(small_sky_order1_catalog, small_sky_order1_id_index_dir):
catalog_index = IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir)
coarse_index_search = small_sky_order1_catalog.index_search([700], catalog_index, fine=False)
fine_index_search = small_sky_order1_catalog.index_search([700], catalog_index)
assert coarse_index_search.get_healpix_pixels() == fine_index_search.get_healpix_pixels()
assert coarse_index_search._ddf.npartitions == fine_index_search._ddf.npartitions
assert len(coarse_index_search.compute()) > len(fine_index_search.compute())
11 changes: 10 additions & 1 deletion tests/lsdb/catalog/test_polygon_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_polygon_search_filters_partitions(small_sky_order1_catalog):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
_, vertices_xyz = get_cartesian_polygon(vertices)
hc_polygon_search = small_sky_order1_catalog.hc_structure.filter_by_polygon(vertices_xyz)
polygon_search_catalog = small_sky_order1_catalog.polygon_search(vertices)
polygon_search_catalog = small_sky_order1_catalog.polygon_search(vertices, fine=False)
assert len(hc_polygon_search.get_healpix_pixels()) == len(polygon_search_catalog.get_healpix_pixels())
assert len(hc_polygon_search.get_healpix_pixels()) == polygon_search_catalog._ddf.npartitions
for pixel in hc_polygon_search.get_healpix_pixels():
Expand All @@ -67,6 +67,15 @@ def test_polygon_search_empty(small_sky_order1_catalog):
assert len(polygon_search_catalog.hc_structure.pixel_tree) == 1


def test_polygon_search_coarse_versus_fine(small_sky_order1_catalog):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
coarse_polygon_search = small_sky_order1_catalog.polygon_search(vertices, fine=False)
fine_polygon_search = small_sky_order1_catalog.polygon_search(vertices)
assert coarse_polygon_search.get_healpix_pixels() == fine_polygon_search.get_healpix_pixels()
assert coarse_polygon_search._ddf.npartitions == fine_polygon_search._ddf.npartitions
assert len(coarse_polygon_search.compute()) > len(fine_polygon_search.compute())


def test_polygon_search_invalid_dec(small_sky_order1_catalog):
# Some declination values are out of the [-90,90] bounds
vertices = [(-20, 100), (-20, -1), (20, -1), (20, 100)]
Expand Down

0 comments on commit 49e53b5

Please sign in to comment.