Skip to content

Commit

Permalink
Make search method public and move fine argument to search object (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
camposandro authored Jun 5, 2024
1 parent dd4b472 commit 2c44fbe
Show file tree
Hide file tree
Showing 17 changed files with 67 additions and 60 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/margins.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@
},
"outputs": [],
"source": [
"small_sky_box_filter = ztf_object.box(ra=[179.9, 180], dec=[9.5, 9.7])\n",
"small_sky_box_filter = ztf_object.box_search(ra=[179.9, 180], dec=[9.5, 9.7])\n",
"\n",
"# Plot the points from the filtered ztf pixel in green, and from the pixel's margin cache in red\n",
"plot_points(\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/pre_executed/des-gaia.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4209,9 +4209,9 @@
"ra_range = [0.0, 0.1]\n",
"dec_range = [2.45, 2.55]\n",
"\n",
"des_box = des_catalog.box(ra=ra_range, dec=dec_range).compute()\n",
"gaia_box = gaia_catalog.box(ra=ra_range, dec=dec_range).compute()\n",
"xmatch_box = xmatched.box(ra=ra_range, dec=dec_range).compute()\n",
"des_box = des_catalog.box_search(ra=ra_range, dec=dec_range).compute()\n",
"gaia_box = gaia_catalog.box_search(ra=ra_range, dec=dec_range).compute()\n",
"xmatch_box = xmatched.box_search(ra=ra_range, dec=dec_range).compute()\n",
"\n",
"ra_des = np.where(des_box[\"RA\"] > 180, des_box[\"RA\"] - 360, des_box[\"RA\"])\n",
"ra_gaia = np.where(gaia_box[\"ra\"] > 180, gaia_box[\"ra\"] - 360, gaia_box[\"ra\"])\n",
Expand Down
10 changes: 5 additions & 5 deletions docs/tutorials/working_with_large_catalogs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
"The `fine` parameter allows us to specify whether or not we desire to run the fine stage, for each search. It brings some overhead, so if your intention is to get a rough estimate of the data points for a region, you may disable it. It is always executed by default.\n",
"\n",
"```\n",
"catalog.box(..., fine=False)\n",
"catalog.box_search(..., fine=False)\n",
"catalog.cone_search(..., fine=False)\n",
"catalog.polygon_search(..., fine=False)\n",
"```"
Expand Down Expand Up @@ -259,7 +259,7 @@
"metadata": {},
"outputs": [],
"source": [
"ztf_object_box_ra = ztf_object.box(ra=[-65, -60])\n",
"ztf_object_box_ra = ztf_object.box_search(ra=[-65, -60])\n",
"ztf_object_box_ra"
]
},
Expand Down Expand Up @@ -288,7 +288,7 @@
"metadata": {},
"outputs": [],
"source": [
"ztf_object_box_dec = ztf_object.box(dec=[12, 15])\n",
"ztf_object_box_dec = ztf_object.box_search(dec=[12, 15])\n",
"ztf_object_box_dec"
]
},
Expand Down Expand Up @@ -319,7 +319,7 @@
"metadata": {},
"outputs": [],
"source": [
"ztf_object_box = ztf_object.box(ra=[-65, -60], dec=[12, 15])\n",
"ztf_object_box = ztf_object.box_search(ra=[-65, -60], dec=[12, 15])\n",
"ztf_object_box"
]
},
Expand All @@ -338,7 +338,7 @@
"id": "9a887b31",
"metadata": {},
"source": [
"We can stack a several number of filters, which are applied in sequence. For example, `catalog.box().polygon_search()` should result in a perfectly valid HiPSCat catalog containing the objects that match both filters."
"We can stack several filters, which are applied in sequence. For example, `catalog.box_search().polygon_search()` should result in a perfectly valid HiPSCat catalog containing the objects that match both filters."
]
},
{
Expand Down
21 changes: 10 additions & 11 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def cone_search(self, ra: float, dec: float, radius_arcsec: float, fine: bool =
A new Catalog containing the points filtered to those within the cone, and the partitions that
overlap the cone.
"""
return self._search(ConeSearch(ra, dec, radius_arcsec), fine)
return self.search(ConeSearch(ra, dec, radius_arcsec, fine))

def box(
def box_search(
self,
ra: Tuple[float, float] | None = None,
dec: Tuple[float, float] | None = None,
Expand All @@ -235,7 +235,7 @@ def box(
A new catalog containing the points filtered to those within the region, and the
partitions that have some overlap with it.
"""
return self._search(BoxSearch(ra=ra, dec=dec), fine)
return self.search(BoxSearch(ra, dec, fine))

def polygon_search(self, vertices: List[SphericalCoordinates], fine: bool = True) -> Catalog:
"""Perform a polygonal search to filter the catalog.
Expand All @@ -252,7 +252,7 @@ def polygon_search(self, vertices: List[SphericalCoordinates], fine: bool = True
A new catalog containing the points filtered to those within the
polygonal region, and the partitions that have some overlap with it.
"""
return self._search(PolygonSearch(vertices), fine)
return self.search(PolygonSearch(vertices, fine))

def index_search(self, ids, catalog_index: HCIndexCatalog, fine: bool = True) -> Catalog:
"""Find rows by ids (or other value indexed by a catalog index).
Expand All @@ -270,7 +270,7 @@ def index_search(self, ids, catalog_index: HCIndexCatalog, fine: bool = True) ->
Returns:
A new Catalog containing the points filtered to those matching the ids.
"""
return self._search(IndexSearch(ids, catalog_index), fine)
return self.search(IndexSearch(ids, catalog_index, fine))

def order_search(self, min_order: int = 0, max_order: int | None = None) -> Catalog:
"""Filter catalog by order of HEALPix.
Expand All @@ -282,7 +282,7 @@ def order_search(self, min_order: int = 0, max_order: int | None = None) -> Cata
Returns:
A new Catalog containing only the pixels of orders specified (inclusive).
"""
return self._search(OrderSearch(min_order, max_order), fine=False)
return self.search(OrderSearch(min_order, max_order))

def pixel_search(self, pixels: List[Tuple[int, int]]) -> Catalog:
"""Finds all catalog pixels that overlap with the requested pixel set.
Expand All @@ -294,26 +294,25 @@ def pixel_search(self, pixels: List[Tuple[int, int]]) -> Catalog:
Returns:
A new Catalog containing only the pixels that overlap with the requested pixel set.
"""
return self._search(PixelSearch(pixels), fine=False)
return self.search(PixelSearch(pixels))

def _search(self, search: AbstractSearch, fine: bool = True):
def search(self, search: AbstractSearch):
"""Find rows by reusable search algorithm.
Filters partitions in the catalog to those that match some rough criteria.
Filters to points that match some finer criteria.
Args:
search (AbstractSearch): Instance of AbstractSearch.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new Catalog containing the points filtered to those matching the search parameters.
"""
filtered_hc_structure = search.filter_hc_catalog(self.hc_structure)
ddf_partition_map, search_ddf = self._perform_search(
filtered_hc_structure, filtered_hc_structure.get_healpix_pixels(), search, fine
filtered_hc_structure, filtered_hc_structure.get_healpix_pixels(), search
)
margin = self.margin._search(filtered_hc_structure, search, fine) if self.margin is not None else None
margin = self.margin.search(filtered_hc_structure, search) if self.margin is not None else None
return Catalog(search_ddf, ddf_partition_map, filtered_hc_structure, margin=margin)

def merge(
Expand Down
4 changes: 1 addition & 3 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,13 @@ def _perform_search(
metadata: hc.catalog.Catalog,
filtered_pixels: List[HealpixPixel],
search: AbstractSearch,
fine: bool = True,
):
"""Performs a search on the catalog from a list of pixels to search in
Args:
metadata (hc.catalog.Catalog): The metadata of the hipscat catalog.
filtered_pixels (List[HealpixPixel]): List of pixels in the catalog to be searched.
search (AbstractSearch): Instance of AbstractSearch.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A tuple containing a dictionary mapping pixel to partition index and a dask dataframe
Expand All @@ -151,7 +149,7 @@ def _perform_search(
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = (
[search.search_points(partition, metadata.catalog_info) for partition in targeted_partitions]
if fine
if search.fine
else targeted_partitions
)
return self._construct_search_ddf(filtered_pixels, filtered_partitions)
Expand Down
8 changes: 3 additions & 5 deletions src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,23 @@ def __init__(
):
super().__init__(ddf, ddf_pixel_map, hc_structure)

def _search(self, metadata: hc.catalog.Catalog, search: AbstractSearch, fine: bool = True):
def search(self, metadata: hc.catalog.Catalog, search: AbstractSearch):
"""Find rows by reusable search algorithm.
Filters partitions in the catalog to those that match some rough criteria and their neighbors.
Filters to points that match some finer criteria.
Args:
metadata (hc.catalog.Catalog): The metadata of the hipscat catalog corresponding to the margin.
search (AbstractSearch): Instance of AbstractSearch.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
Returns:
A new Catalog containing the points filtered to those matching the search parameters.
"""

# if the margin size is greater than the size of a pixel, this is an invalid search
margin_search_moc = metadata.pixel_tree.to_moc()

filtered_hc_structure = self.hc_structure.filter_by_moc(margin_search_moc)
ddf_partition_map, search_ddf = self._perform_search(
metadata, filtered_hc_structure.get_healpix_pixels(), search, fine
metadata, filtered_hc_structure.get_healpix_pixels(), search
)
return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure)
3 changes: 3 additions & 0 deletions src/lsdb/core/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class AbstractSearch(ABC):
individual rows matching the query terms.
"""

def __init__(self, fine: bool = True):
self.fine = fine

def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
"""Filters the hipscat catalog object to the partitions included in the search"""
max_order = hc_structure.get_max_coverage_order()
Expand Down
8 changes: 7 additions & 1 deletion src/lsdb/core/search/box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ class BoxSearch(AbstractSearch):
Filters partitions in the catalog to those that have some overlap with the region.
"""

def __init__(self, ra: Tuple[float, float] | None = None, dec: Tuple[float, float] | None = None):
def __init__(
self,
ra: Tuple[float, float] | None = None,
dec: Tuple[float, float] | None = None,
fine: bool = True,
):
super().__init__(fine)
ra = tuple(wrap_ra_angles(ra)) if ra else None
validate_box_search(ra, dec)
self.ra, self.dec = ra, dec
Expand Down
4 changes: 2 additions & 2 deletions src/lsdb/core/search/cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class ConeSearch(AbstractSearch):
Filters partitions in the catalog to those that have some overlap with the cone.
"""

def __init__(self, ra, dec, radius_arcsec):
def __init__(self, ra: float, dec: float, radius_arcsec: float, fine: bool = True):
super().__init__(fine)
validate_radius(radius_arcsec)
validate_declination_values(dec)

self.ra = ra
self.dec = dec
self.radius_arcsec = radius_arcsec
Expand Down
3 changes: 2 additions & 1 deletion src/lsdb/core/search/index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class IndexSearch(AbstractSearch):
NB: This requires a previously-computed catalog index table.
"""

def __init__(self, ids, catalog_index: IndexCatalog):
def __init__(self, ids, catalog_index: IndexCatalog, fine: bool = True):
super().__init__(fine)
self.ids = ids
self.catalog_index = catalog_index

Expand Down
1 change: 1 addition & 0 deletions src/lsdb/core/search/order_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class OrderSearch(AbstractSearch):
"""

def __init__(self, min_order: int = 0, max_order: int | None = None):
super().__init__(fine=False)
if max_order and min_order > max_order:
raise ValueError("The minimum order should be lower than or equal to the maximum order")
self.min_order = min_order
Expand Down
1 change: 1 addition & 0 deletions src/lsdb/core/search/pixel_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class PixelSearch(AbstractSearch):
"""

def __init__(self, pixels: List[Tuple[int, int]]):
super().__init__(fine=False)
self.pixels = [HealpixPixel(o, p) for o, p in set(pixels)]

def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
Expand Down
3 changes: 2 additions & 1 deletion src/lsdb/core/search/polygon_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class PolygonSearch(AbstractSearch):
Filters partitions in the catalog to those that have some overlap with the region.
"""

def __init__(self, vertices: List[SphericalCoordinates]):
def __init__(self, vertices: List[SphericalCoordinates], fine: bool = True):
super().__init__(fine)
_, dec = np.array(vertices).T
validate_declination_values(dec)
self.vertices = np.array(vertices)
Expand Down
9 changes: 4 additions & 5 deletions src/lsdb/dask/partition_indexer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

from typing import List
from typing import List, Tuple

import numpy as np
from hipscat.pixel_math import HealpixPixel

from lsdb.core.search.pixel_search import PixelSearch

Expand All @@ -18,7 +17,7 @@ def __init__(self, catalog):
def __getitem__(self, item):
indices = self._parse_partition_indices(item)
pixels = self._get_pixels_from_partition_indices(indices)
return self.catalog._search(PixelSearch(pixels), fine=False)
return self.catalog.search(PixelSearch(pixels))

def _parse_partition_indices(self, item: int | List[int]) -> List[int]:
"""Parses the partition indices provided in the square brackets accessor.
Expand All @@ -28,9 +27,9 @@ def _parse_partition_indices(self, item: int | List[int]) -> List[int]:
indices = np.arange(len(self.catalog._ddf_pixel_map), dtype=object)[item].tolist()
return indices

def _get_pixels_from_partition_indices(self, indices: List[int]) -> List[HealpixPixel]:
def _get_pixels_from_partition_indices(self, indices: List[int]) -> List[Tuple[int, int]]:
"""Performs a reverse-lookup in the catalog pixel-to-partition map and returns the
pixels for the specified partition `indices`."""
inverted_pixel_map = {i: pixel for pixel, i in self.catalog._ddf_pixel_map.items()}
filtered_pixels = [inverted_pixel_map[key] for key in indices]
return filtered_pixels
return [(p.order, p.pixel) for p in filtered_pixels]
2 changes: 1 addition & 1 deletion src/lsdb/loaders/hipscat/hipscat_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _load_margin_catalog(self, metadata: hc.catalog.Catalog) -> MarginCatalog |
margin_catalog = self.config.margin_cache
if self.config.search_filter is not None:
# pylint: disable=protected-access
margin_catalog = margin_catalog._search(metadata, self.config.search_filter, fine=False)
margin_catalog = margin_catalog.search(metadata, self.config.search_filter)
elif isinstance(self.config.margin_cache, str):
margin_catalog = lsdb.read_hipscat(
path=self.config.margin_cache,
Expand Down
Loading

0 comments on commit 2c44fbe

Please sign in to comment.