From 72e809adcf11558c0b6e2bc4ddf6e27453ad09b1 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Tue, 19 Nov 2024 13:11:21 -0500 Subject: [PATCH 1/8] Remove the `lonlat` argument from `ang2vec` (#504) * Remove lonlat argument from ang2vec * Install hats from margin branch * Run tests and pre-commit in margin branch --- .github/workflows/pre-commit-ci.yml | 2 +- .github/workflows/testing-and-coverage.yml | 2 +- docs/requirements.txt | 4 ++-- requirements.txt | 2 +- src/lsdb/core/search/polygon_search.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pre-commit-ci.yml b/.github/workflows/pre-commit-ci.yml index f31b3ee6..642e7b9d 100644 --- a/.github/workflows/pre-commit-ci.yml +++ b/.github/workflows/pre-commit-ci.yml @@ -7,7 +7,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: [ main, margin ] jobs: pre-commit-ci: diff --git a/.github/workflows/testing-and-coverage.yml b/.github/workflows/testing-and-coverage.yml index 11c13d05..b318df50 100644 --- a/.github/workflows/testing-and-coverage.yml +++ b/.github/workflows/testing-and-coverage.yml @@ -7,7 +7,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: [ main, margin ] jobs: build: diff --git a/docs/requirements.txt b/docs/requirements.txt index 6a734805..8cd6da74 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,5 +10,5 @@ sphinx-autoapi sphinx-copybutton sphinx-book-theme sphinx-design -git+https://github.com/astronomy-commons/hats.git@main -git+https://github.com/astronomy-commons/hats-import.git@main +git+https://github.com/astronomy-commons/hats.git@margin +git+https://github.com/astronomy-commons/hats-import.git@margin diff --git a/requirements.txt b/requirements.txt index 54865347..f2b394e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -git+https://github.com/astronomy-commons/hats.git@main +git+https://github.com/astronomy-commons/hats.git@margin git+https://github.com/lincc-frameworks/nested-pandas.git@main git+https://github.com/lincc-frameworks/nested-dask.git@main \ No newline at end of file diff --git a/src/lsdb/core/search/polygon_search.py b/src/lsdb/core/search/polygon_search.py index 1f29d8c6..fe7e3207 100644 --- a/src/lsdb/core/search/polygon_search.py +++ b/src/lsdb/core/search/polygon_search.py @@ -63,6 +63,6 @@ def get_cartesian_polygon(vertices: list[tuple[float, float]]) -> ConvexPolygon: Returns: The convex polygon object. """ - vertices_xyz = hp.ang2vec(*np.array(vertices).T, lonlat=True) + vertices_xyz = hp.ang2vec(*np.array(vertices).T) edge_vectors = [UnitVector3d(x, y, z) for x, y, z in vertices_xyz] return ConvexPolygon(edge_vectors) From 60c9421ff15694b32ade9c052935b5eec94deaac Mon Sep 17 00:00:00 2001 From: Sean McGuire <123987820+smcguire-cmu@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:03:38 -0500 Subject: [PATCH 2/8] Add MOC Filter (#503) * add moc_filter * add search to healpix_dataset * add moc search to catalog * add unit tests * fix isort * refactor search --- src/lsdb/catalog/catalog.py | 21 ++++++++++--- src/lsdb/catalog/dataset/healpix_dataset.py | 16 ++++++++++ src/lsdb/catalog/margin_catalog.py | 17 ----------- src/lsdb/core/search/moc_search.py | 33 +++++++++++++++++++++ tests/lsdb/catalog/test_moc_search.py | 31 +++++++++++++++++++ 5 files changed, 97 insertions(+), 21 deletions(-) create mode 100644 src/lsdb/core/search/moc_search.py create mode 100644 tests/lsdb/catalog/test_moc_search.py diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 585b32f0..d407481a 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -7,6 +7,7 @@ import nested_pandas as npd import pandas as pd from hats.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog +from mocpy import MOC from pandas._libs import lib from pandas._typing import AnyAll, Axis, IndexLabel from pandas.api.extensions import no_default @@ -18,6 +19,7 @@ from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm from lsdb.core.search import BoxSearch, ConeSearch, IndexSearch, OrderSearch, PolygonSearch from lsdb.core.search.abstract_search import AbstractSearch +from lsdb.core.search.moc_search import MOCSearch from lsdb.core.search.pixel_search import PixelSearch from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data from lsdb.dask.join_catalog_data import ( @@ -324,6 +326,18 @@ def pixel_search(self, pixels: List[Tuple[int, int]]) -> Catalog: """ return self.search(PixelSearch(pixels)) + def moc_search(self, moc: MOC, fine: bool = True) -> Catalog: + """Finds all catalog points that are contained within a moc. + + Args: + moc (mocpy.MOC): The moc that defines the region for the search. + fine (bool): True if points are to be filtered, False if only partitions. Defaults to True. + + Returns: + A new Catalog containing only the points that are within the moc. + """ + return self.search(MOCSearch(moc, fine=fine)) + def search(self, search: AbstractSearch): """Find rows by reusable search algorithm. @@ -336,10 +350,9 @@ def search(self, search: AbstractSearch): Returns: A new Catalog containing the points filtered to those matching the search parameters. """ - filtered_hc_structure = search.filter_hc_catalog(self.hc_structure) - ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search) - margin = self.margin.search(search) if self.margin is not None else None - return Catalog(search_ndf, ddf_partition_map, filtered_hc_structure, margin=margin) + cat = super().search(search) + cat.margin = self.margin.search(search) if self.margin is not None else None + return cat def merge( self, diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index 9aa5f15e..0f9ea771 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -194,6 +194,22 @@ def _perform_search( ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)} return ddf_partition_map, filtered_partitions_ddf + def search(self, search: AbstractSearch): + """Find rows by reusable search algorithm. + + Filters partitions in the catalog to those that match some rough criteria. + Filters to points that match some finer criteria. + + Args: + search (AbstractSearch): Instance of AbstractSearch. + + Returns: + A new Catalog containing the points filtered to those matching the search parameters. + """ + filtered_hc_structure = search.filter_hc_catalog(self.hc_structure) + ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search) + return self.__class__(search_ndf, ddf_partition_map, filtered_hc_structure) + def map_partitions( self, func: Callable[..., npd.NestedFrame], diff --git a/src/lsdb/catalog/margin_catalog.py b/src/lsdb/catalog/margin_catalog.py index f7332a50..db3abe30 100644 --- a/src/lsdb/catalog/margin_catalog.py +++ b/src/lsdb/catalog/margin_catalog.py @@ -2,7 +2,6 @@ import nested_dask as nd from lsdb.catalog.dataset.healpix_dataset import HealpixDataset -from lsdb.core.search.abstract_search import AbstractSearch from lsdb.types import DaskDFPixelMap @@ -24,19 +23,3 @@ def __init__( hc_structure: hc.catalog.MarginCatalog, ): super().__init__(ddf, ddf_pixel_map, hc_structure) - - def search(self, search: AbstractSearch): - """Find rows by reusable search algorithm. - - Filters partitions in the catalog to those that match some rough criteria and their neighbors. - Filters to points that match some finer criteria. - - Args: - search (AbstractSearch): Instance of AbstractSearch. - - Returns: - A new Catalog containing the points filtered to those matching the search parameters. - """ - filtered_hc_structure = search.filter_hc_catalog(self.hc_structure) - ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search) - return self.__class__(search_ndf, ddf_partition_map, filtered_hc_structure) diff --git a/src/lsdb/core/search/moc_search.py b/src/lsdb/core/search/moc_search.py new file mode 100644 index 00000000..ccaec1f2 --- /dev/null +++ b/src/lsdb/core/search/moc_search.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import astropy.units as u +import nested_pandas as npd +from hats.catalog import TableProperties +from mocpy import MOC + +from lsdb.core.search.abstract_search import AbstractSearch + +if TYPE_CHECKING: + from lsdb.types import HCCatalogTypeVar + + +class MOCSearch(AbstractSearch): + """Filter the catalog by a MOC. + + Filters partitions in the catalog to those that are in a specified moc. + """ + + def __init__(self, moc: MOC, fine: bool = True): + super().__init__(fine) + self.moc = moc + + def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar: + return hc_structure.filter_by_moc(self.moc) + + def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame: + df_ras = frame[metadata.ra_column].to_numpy() + df_decs = frame[metadata.dec_column].to_numpy() + mask = self.moc.contains_lonlat(df_ras * u.deg, df_decs * u.deg) + return frame.iloc[mask] diff --git a/tests/lsdb/catalog/test_moc_search.py b/tests/lsdb/catalog/test_moc_search.py new file mode 100644 index 00000000..16aa397d --- /dev/null +++ b/tests/lsdb/catalog/test_moc_search.py @@ -0,0 +1,31 @@ +import astropy.units as u +import numpy as np +import pandas as pd +from hats.pixel_math import HealpixPixel +from mocpy import MOC + + +def test_moc_search_filters_correct_points(small_sky_order1_catalog): + search_moc = MOC.from_healpix_cells(ipix=np.array([176, 177]), depth=np.array([2, 2]), max_depth=2) + filtered_cat = small_sky_order1_catalog.moc_search(search_moc) + assert filtered_cat.get_healpix_pixels() == [HealpixPixel(1, 44)] + filtered_cat_comp = filtered_cat.compute() + cat_comp = small_sky_order1_catalog.compute() + assert np.all( + search_moc.contains_lonlat( + filtered_cat_comp["ra"].to_numpy() * u.deg, filtered_cat_comp["dec"].to_numpy() * u.deg + ) + ) + assert np.sum( + search_moc.contains_lonlat(cat_comp["ra"].to_numpy() * u.deg, cat_comp["dec"].to_numpy() * u.deg) + ) == len(filtered_cat_comp) + + +def test_moc_search_non_fine(small_sky_order1_catalog): + search_moc = MOC.from_healpix_cells(ipix=np.array([176, 180]), depth=np.array([2, 2]), max_depth=2) + filtered_cat = small_sky_order1_catalog.moc_search(search_moc, fine=False) + assert filtered_cat.get_healpix_pixels() == [HealpixPixel(1, 44), HealpixPixel(1, 45)] + pd.testing.assert_frame_equal( + filtered_cat.compute(), + small_sky_order1_catalog.pixel_search([HealpixPixel(1, 44), HealpixPixel(1, 45)]).compute(), + ) From 5e4ded520920d13854f6c138d1897679234fafba Mon Sep 17 00:00:00 2001 From: Sean McGuire <123987820+smcguire-cmu@users.noreply.github.com> Date: Fri, 22 Nov 2024 15:16:22 -0500 Subject: [PATCH 3/8] Update to hats healpix math (#509) * update to use latest healpix math functions * update fits reading * fix mypy --- src/lsdb/dask/merge_catalog_functions.py | 2 +- src/lsdb/loaders/dataframe/from_dataframe.py | 2 +- .../dataframe/margin_catalog_generator.py | 8 ++-- tests/lsdb/catalog/test_catalog.py | 37 +++++++------------ .../loaders/dataframe/test_from_dataframe.py | 2 +- 5 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index 774af97c..8ae79242 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -86,7 +86,7 @@ def align_catalogs(left: Catalog, right: Catalog, add_right_margin: bool = True) else right.hc_structure.pixel_tree.to_moc() ) if right_added_radius is not None: - right_moc_depth_resol = hp.nside2resol(hp.order2nside(right_moc.max_order), arcmin=True) * 60 + right_moc_depth_resol = hp.order2resol(right_moc.max_order, arcmin=True) * 60 if right_added_radius < right_moc_depth_resol: right_moc = copy_moc(right_moc).add_neighbours() else: diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index 1601c95a..48fd10e0 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -19,7 +19,7 @@ def from_dataframe( drop_empty_siblings: bool = False, partition_size: int | None = None, threshold: int | None = None, - margin_order: int | None = -1, + margin_order: int = -1, margin_threshold: float | None = 5.0, should_generate_moc: bool = True, moc_max_order: int = 10, diff --git a/src/lsdb/loaders/dataframe/margin_catalog_generator.py b/src/lsdb/loaders/dataframe/margin_catalog_generator.py index a2d2ba74..c9e88af0 100644 --- a/src/lsdb/loaders/dataframe/margin_catalog_generator.py +++ b/src/lsdb/loaders/dataframe/margin_catalog_generator.py @@ -27,7 +27,7 @@ class MarginCatalogGenerator: def __init__( self, catalog: Catalog, - margin_order: int | None = -1, + margin_order: int = -1, margin_threshold: float = 5.0, use_pyarrow_types: bool = True, **kwargs, @@ -169,12 +169,10 @@ def _create_margins(self, margin_pairs_df: pd.DataFrame) -> Dict[HealpixPixel, p A dictionary mapping each margin pixel to the respective DataFrame. """ margin_pixel_df_map: Dict[HealpixPixel, npd.NestedFrame] = {} - self.dataframe["margin_pixel"] = hp.ang2pix( - 2**self.margin_order, + self.dataframe["margin_pixel"] = hp.radec2pix( + self.margin_order, self.dataframe[self.hc_structure.catalog_info.ra_column].to_numpy(), self.dataframe[self.hc_structure.catalog_info.dec_column].to_numpy(), - lonlat=True, - nest=True, ) constrained_data = self.dataframe.reset_index().merge(margin_pairs_df, on="margin_pixel") if len(constrained_data): diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index c2648a2e..c72527df 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -10,6 +10,7 @@ import numpy.testing as npt import pandas as pd import pytest +from hats.io.file_io import read_fits_image from hats.pixel_math import HealpixPixel, spatial_index_to_healpix import lsdb @@ -208,21 +209,13 @@ def test_save_catalog_point_map(small_sky_order1_catalog, tmp_path): point_map_path = base_catalog_path / "point_map.fits" assert hc.io.file_io.does_file_or_directory_exist(point_map_path) - map_fits_image = hp.read_map(point_map_path, nest=True, h=True) - histogram, header_dict = map_fits_image[0], dict(map_fits_image[1]) + histogram = read_fits_image(point_map_path) # The histogram and the sky map histogram match assert len(small_sky_order1_catalog) == np.sum(histogram) expected_histogram = small_sky_order1_catalog.skymap_histogram(lambda df, _: len(df), order=8) npt.assert_array_equal(expected_histogram, histogram) - # Check the fits file metadata - assert header_dict["PIXTYPE"] == "HEALPIX" - assert header_dict["ORDERING"] == "NESTED" - assert header_dict["INDXSCHM"] == "IMPLICIT" - assert header_dict["OBJECT"] == "FULLSKY" - assert header_dict["NSIDE"] == 256 - def test_save_catalog_overwrite(small_sky_catalog, tmp_path): base_catalog_path = tmp_path / "small_sky" @@ -324,7 +317,7 @@ def test_prune_empty_partitions_all_are_removed(small_sky_order1_catalog): def test_skymap_data(small_sky_order1_catalog): def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) skymap = small_sky_order1_catalog.skymap_data(func) for pixel in skymap.keys(): @@ -335,7 +328,7 @@ def func(df, healpix): def test_skymap_data_order(small_sky_order1_catalog): def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) order = 3 @@ -357,7 +350,7 @@ def func(df, healpix): def test_skymap_data_wrong_order(small_sky_order1_catalog): def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) order = 0 @@ -367,7 +360,7 @@ def func(df, healpix): def test_skymap_histogram(small_sky_order1_catalog): def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) pixel_map = small_sky_order1_catalog.skymap_data(func) pixel_map = {pixel: value.compute() for pixel, value in pixel_map.items()} @@ -384,7 +377,7 @@ def func(df, healpix): def test_skymap_histogram_empty(small_sky_order1_catalog): def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) expected_img = np.full(12, 1) img = small_sky_order1_catalog.cone_search(0, 0, 1).skymap_histogram(func, default_value=1) @@ -396,7 +389,7 @@ def test_skymap_histogram_order_default(small_sky_order1_catalog): default = -1.0 def func(df, _): - return len(df) / hp.nside2pixarea(hp.order2nside(order), degrees=True) + return len(df) / hp.order2pixarea(order, degrees=True) computed_catalog = small_sky_order1_catalog.compute() order_3_pixels = spatial_index_to_healpix(computed_catalog.index.to_numpy(), order) @@ -412,7 +405,7 @@ def test_skymap_histogram_null_values_order_default(small_sky_order1_catalog): default = -1.0 def func(df, healpix): - density = len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + density = len(df) / hp.order2pixarea(healpix.order, degrees=True) return density if healpix.pixel % 2 == 0 else None pixels = list(small_sky_order1_catalog._ddf_pixel_map.keys()) @@ -438,7 +431,7 @@ def test_skymap_histogram_null_values_order(small_sky_order1_catalog): default = -1.0 def func(df, healpix): - density = len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + density = len(df) / hp.order2pixarea(healpix.order, degrees=True) return density if healpix.pixel % 2 == 0 else None pixels = list(small_sky_order1_catalog._ddf_pixel_map.keys()) @@ -462,7 +455,7 @@ def test_skymap_histogram_order_empty(small_sky_order1_catalog): order = 3 def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) catalog = small_sky_order1_catalog.cone_search(0, 0, 1) _, non_empty_partitions = catalog._get_non_empty_partitions() @@ -477,7 +470,7 @@ def test_skymap_histogram_order_some_partitions_empty(small_sky_order1_catalog): order = 3 def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) catalog = small_sky_order1_catalog.query("ra > 350 and dec < -50") _, non_empty_partitions = catalog._get_non_empty_partitions() @@ -502,7 +495,7 @@ def test_skymap_plot(small_sky_order1_catalog, mocker): mocker.patch("lsdb.catalog.dataset.healpix_dataset.plot_healpix_map") def func(df, healpix): - return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + return len(df) / hp.order2pixarea(healpix.order, degrees=True) small_sky_order1_catalog.skymap(func) pixel_map = small_sky_order1_catalog.skymap_data(func) @@ -600,9 +593,7 @@ def add_col(df, pixel): assert isinstance(mapped, Catalog) assert "pix" in mapped.columns mapcomp = mapped.compute() - pix_col = hp.ang2pix( - hp.order2nside(1), mapcomp["ra"].to_numpy(), mapcomp["dec"].to_numpy(), lonlat=True, nest=True - ) + pix_col = hp.radec2pix(1, mapcomp["ra"].to_numpy(), mapcomp["dec"].to_numpy()) assert np.all(mapcomp["pix"] == pix_col) diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py index 33777e30..eefd63f5 100644 --- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py +++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py @@ -95,7 +95,7 @@ def test_partitions_on_map_equal_partitions_in_df(small_sky_order1_df, small_sky partition_df = catalog._ddf.partitions[partition_index].compute() assert isinstance(partition_df, pd.DataFrame) for _, row in partition_df.iterrows(): - ipix = hp.ang2pix(2**hp_pixel.order, row["ra"], row["dec"], nest=True, lonlat=True) + ipix = hp.radec2pix(hp_pixel.order, row["ra"], row["dec"]) assert ipix == hp_pixel.pixel From d007a6dc71314c95a4c1264b959c1c798d0d6e8a Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Fri, 22 Nov 2024 15:52:11 -0500 Subject: [PATCH 4/8] Update box search for migration (#507) * Adapt box filtering for migration to mocpy * Apply changes to docs notebook * Allow for coinciding values for ra range * Test box edge case with dec of 90 deg * Comment hats_import installation inside notebook --- docs/tutorials/filtering_large_catalogs.ipynb | 78 +----------------- docs/tutorials/import_catalogs.ipynb | 8 +- src/lsdb/catalog/catalog.py | 10 +-- src/lsdb/core/search/box_search.py | 42 ++++------ tests/lsdb/catalog/test_box_search.py | 79 +++++++++---------- tests/lsdb/catalog/test_catalog.py | 2 +- tests/lsdb/catalog/test_polygon_search.py | 11 +-- tests/lsdb/loaders/hats/test_read_hats.py | 4 +- 8 files changed, 70 insertions(+), 164 deletions(-) diff --git a/docs/tutorials/filtering_large_catalogs.ipynb b/docs/tutorials/filtering_large_catalogs.ipynb index 10e6ab19..2d1cc34f 100644 --- a/docs/tutorials/filtering_large_catalogs.ipynb +++ b/docs/tutorials/filtering_large_catalogs.ipynb @@ -217,79 +217,7 @@ "source": [ "### Box search\n", "\n", - "A box search can be defined by:\n", - "\n", - "- Right ascension band `(ra1, ra2)`\n", - "- Declination band `(dec1, dec2)`\n", - "- Both right ascension and declination bands `[(ra1, ra2), (dec1, dec2)]`" - ] - }, - { - "cell_type": "markdown", - "id": "ca8c3815-0165-4123-8173-ccb6dfaa7eb5", - "metadata": {}, - "source": [ - "#### Right ascension band" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e5e6ce24-34a1-42f5-99d6-be8ca0772113", - "metadata": {}, - "outputs": [], - "source": [ - "ztf_object_box_ra = ztf_object.box_search(ra=[-65, -60])\n", - "ztf_object_box_ra" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12edf29c-da39-47e5-9507-ebe4258ba789", - "metadata": {}, - "outputs": [], - "source": [ - "ztf_object_box_ra.plot_pixels(plot_title=\"ZTF_DR14 - RA band pixel map\")" - ] - }, - { - "cell_type": "markdown", - "id": "83c465ba-4548-4108-b9ec-5396ac760f64", - "metadata": {}, - "source": [ - "#### Declination band" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "03966efa-4e65-4aa8-ad1f-3e54bbf6a281", - "metadata": {}, - "outputs": [], - "source": [ - "ztf_object_box_dec = ztf_object.box_search(dec=[12, 15])\n", - "ztf_object_box_dec" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e5de5b4a-78f6-4295-8c38-fc80a9a6f2a9", - "metadata": {}, - "outputs": [], - "source": [ - "ztf_object_box_dec.plot_pixels(plot_title=\"ZTF_DR14 - DEC band pixel map\")" - ] - }, - { - "cell_type": "markdown", - "id": "87952874-26fb-433f-adb7-28ac13c7cee3", - "metadata": { - "tags": [] - }, - "source": [ - "#### Right ascension and declination bands" + "A box search can be defined by right ascension and declination bands `[(ra1, ra2), (dec1, dec2)]`." ] }, { @@ -384,7 +312,7 @@ ], "metadata": { "kernelspec": { - "display_name": "demo", + "display_name": "lsdb", "language": "python", "name": "python3" }, @@ -398,7 +326,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/docs/tutorials/import_catalogs.ipynb b/docs/tutorials/import_catalogs.ipynb index 1c666ddf..0cb4d3f7 100644 --- a/docs/tutorials/import_catalogs.ipynb +++ b/docs/tutorials/import_catalogs.ipynb @@ -117,7 +117,7 @@ "id": "3842520c", "metadata": {}, "source": [ - "Let's install the latest release of hats-import:" + "Please uncomment the next line to install the latest release of hats-import:" ] }, { @@ -132,7 +132,7 @@ }, "outputs": [], "source": [ - "!pip install git+https://github.com/astronomy-commons/hats-import.git@main --quiet" + "#!pip install git+https://github.com/astronomy-commons/hats-import.git@main --quiet" ] }, { @@ -267,7 +267,7 @@ ], "metadata": { "kernelspec": { - "display_name": "demo", + "display_name": "lsdb", "language": "python", "name": "python3" }, @@ -281,7 +281,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index d407481a..fa6f71bc 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -245,13 +245,9 @@ def cone_search(self, ra: float, dec: float, radius_arcsec: float, fine: bool = """ return self.search(ConeSearch(ra, dec, radius_arcsec, fine)) - def box_search( - self, - ra: Tuple[float, float] | None = None, - dec: Tuple[float, float] | None = None, - fine: bool = True, - ) -> Catalog: - """Performs filtering according to right ascension and declination ranges. + def box_search(self, ra: Tuple[float, float], dec: Tuple[float, float], fine: bool = True) -> Catalog: + """Performs filtering according to right ascension and declination ranges. The right ascension + edges follow great arc circles and the declination edges follow small arc circles. Filters to points within the region specified in degrees. Filters partitions in the catalog to those that have some overlap with the region. diff --git a/src/lsdb/core/search/box_search.py b/src/lsdb/core/search/box_search.py index 747873ed..2a82ec89 100644 --- a/src/lsdb/core/search/box_search.py +++ b/src/lsdb/core/search/box_search.py @@ -13,18 +13,14 @@ class BoxSearch(AbstractSearch): """Perform a box search to filter the catalog. This type of search is used for a - range of ra or dec (one or the other). If both, a polygonal search should be used. + range of right ascension or declination, where the right ascension edges follow + great arc circles and the declination edges follow small arc circles. Filters to points within the ra / dec region, specified in degrees. Filters partitions in the catalog to those that have some overlap with the region. """ - def __init__( - self, - ra: tuple[float, float] | None = None, - dec: tuple[float, float] | None = None, - fine: bool = True, - ): + def __init__(self, ra: tuple[float, float], dec: tuple[float, float], fine: bool = True): super().__init__(fine) ra = tuple(wrap_ra_angles(ra)) if ra else None validate_box(ra, dec) @@ -41,8 +37,8 @@ def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> np def box_filter( data_frame: npd.NestedFrame, - ra: tuple[float, float] | None, - dec: tuple[float, float] | None, + ra: tuple[float, float], + dec: tuple[float, float], metadata: TableProperties, ) -> npd.NestedFrame: """Filters a dataframe to only include points within the specified box region. @@ -56,28 +52,22 @@ def box_filter( Returns: A new DataFrame with the rows from `data_frame` filtered to only the points inside the box region. """ - mask = np.ones(len(data_frame), dtype=bool) - if ra is not None: - ra_values = data_frame[metadata.ra_column] - wrapped_ra = np.asarray(wrap_ra_angles(ra_values)) - mask_ra = _create_ra_mask(ra, wrapped_ra) - mask = np.logical_and(mask, mask_ra) - if dec is not None: - dec_values = data_frame[metadata.dec_column].to_numpy() - mask_dec = np.logical_and(dec[0] <= dec_values, dec_values <= dec[1]) - mask = np.logical_and(mask, mask_dec) - data_frame = data_frame.iloc[mask] + ra_values = data_frame[metadata.ra_column].to_numpy() + dec_values = data_frame[metadata.dec_column].to_numpy() + wrapped_ra = wrap_ra_angles(ra_values) + mask_ra = _create_ra_mask(ra, wrapped_ra) + mask_dec = (dec[0] <= dec_values) & (dec_values <= dec[1]) + data_frame = data_frame.iloc[mask_ra & mask_dec] return data_frame def _create_ra_mask(ra: tuple[float, float], values: np.ndarray) -> np.ndarray: """Creates the mask to filter right ascension values. If this range crosses the discontinuity line (0 degrees), we have a branched logical operation.""" - if ra[0] <= ra[1]: - mask = np.logical_and(ra[0] <= values, values <= ra[1]) + if ra[0] == ra[1]: + return np.ones(len(values), dtype=bool) + if ra[0] < ra[1]: + mask = (values >= ra[0]) & (values <= ra[1]) else: - mask = np.logical_or( - np.logical_and(ra[0] <= values, values <= 360), - np.logical_and(0 <= values, values <= ra[1]), - ) + mask = ((values >= ra[0]) & (values <= 360)) | ((values >= 0) & (values <= ra[1])) return mask diff --git a/tests/lsdb/catalog/test_box_search.py b/tests/lsdb/catalog/test_box_search.py index 5dee4241..bfcea3e8 100644 --- a/tests/lsdb/catalog/test_box_search.py +++ b/tests/lsdb/catalog/test_box_search.py @@ -1,48 +1,51 @@ -import nested_dask as nd -import nested_pandas as npd import numpy as np import pytest from hats.pixel_math.validators import ValidatorsErrors -def test_box_search_ra_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct): - ra_search_catalog = small_sky_order1_catalog.box_search(ra=(280, 300)) - assert isinstance(ra_search_catalog._ddf, nd.NestedFrame) - ra_search_df = ra_search_catalog.compute() - assert isinstance(ra_search_df, npd.NestedFrame) - ra_values = ra_search_df[small_sky_order1_catalog.hc_structure.catalog_info.ra_column] - assert len(ra_search_df) < len(small_sky_order1_catalog.compute()) +def test_box_search_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct): + search_catalog = small_sky_order1_catalog.box_search(ra=(280, 300), dec=(-40, -30)) + search_df = search_catalog.compute() + ra_values = search_df[small_sky_order1_catalog.hc_structure.catalog_info.ra_column] + dec_values = search_df[small_sky_order1_catalog.hc_structure.catalog_info.dec_column] + assert len(search_df) < len(small_sky_order1_catalog.compute()) assert all(280 <= ra <= 300 for ra in ra_values) - assert_divisions_are_correct(ra_search_catalog) + assert all(-40 <= dec <= -30 for dec in dec_values) + assert_divisions_are_correct(search_catalog) -def test_box_search_ra_filters_correct_points_margin( +def test_box_search_filters_correct_points_margin( small_sky_order1_source_with_margin, assert_divisions_are_correct ): - ra_search_catalog = small_sky_order1_source_with_margin.box_search(ra=(280, 300)) + ra_search_catalog = small_sky_order1_source_with_margin.box_search(ra=(280, 300), dec=(-90, 30)) ra_search_df = ra_search_catalog.compute() ra_values = ra_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column] + dec_values = ra_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.dec_column] assert len(ra_search_df) < len(small_sky_order1_source_with_margin.compute()) assert all(280 <= ra <= 300 for ra in ra_values) + assert all(-90 <= dec <= 30 for dec in dec_values) assert_divisions_are_correct(ra_search_catalog) assert ra_search_catalog.margin is not None ra_margin_search_df = ra_search_catalog.margin.compute() ra_values = ra_margin_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column] + dec_values = ra_margin_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.dec_column] assert len(ra_margin_search_df) < len(small_sky_order1_source_with_margin.margin.compute()) assert all(280 <= ra <= 300 for ra in ra_values) + assert all(-90 <= dec <= 30 for dec in dec_values) assert_divisions_are_correct(ra_search_catalog.margin) def test_box_search_ra_complement(small_sky_order1_catalog): + dec = (-90, 90) ra_column = small_sky_order1_catalog.hc_structure.catalog_info.ra_column - ra_search_catalog = small_sky_order1_catalog.box_search(ra=(280, 300)) + ra_search_catalog = small_sky_order1_catalog.box_search(ra=(280, 300), dec=dec) filtered_ra_values = ra_search_catalog.compute()[ra_column] assert len(filtered_ra_values) == 34 # The complement search contains the remaining catalog points - complement_search_catalog = small_sky_order1_catalog.box_search(ra=(300, 280)) + complement_search_catalog = small_sky_order1_catalog.box_search(ra=(300, 280), dec=dec) complement_search_ra_values = complement_search_catalog.compute()[ra_column] assert len(complement_search_ra_values) == 97 @@ -52,35 +55,36 @@ def test_box_search_ra_complement(small_sky_order1_catalog): def test_box_search_ra_wrapped_filters_correct_points(small_sky_order1_catalog): + dec = (-90, 90) ra_column = small_sky_order1_catalog.hc_structure.catalog_info.ra_column - ra_search_catalog = small_sky_order1_catalog.box_search(ra=(330, 30)) + ra_search_catalog = small_sky_order1_catalog.box_search(ra=(330, 30), dec=dec) filtered_ra_values = ra_search_catalog.compute()[ra_column] # Some other options with values that need to be wrapped for ra_range in [(-30, 30), (330, 390), (330, -330)]: - catalog = small_sky_order1_catalog.box_search(ra=ra_range) + catalog = small_sky_order1_catalog.box_search(ra=ra_range, dec=dec) ra_values = catalog.compute()[ra_column] assert all((0 <= ra <= 30 or 330 <= ra <= 360) for ra in ra_values) assert np.array_equal(ra_values, filtered_ra_values) -def test_box_search_dec_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct): - dec_search_catalog = small_sky_order1_catalog.box_search(dec=(0, 30)) - dec_search_df = dec_search_catalog.compute() - dec_values = dec_search_df[small_sky_order1_catalog.hc_structure.catalog_info.dec_column] - assert len(dec_search_df) < len(small_sky_order1_catalog.compute()) - assert all(0 <= dec <= 30 for dec in dec_values) - assert_divisions_are_correct(dec_search_catalog) +def test_box_search_ra_boundary(small_sky_order1_catalog): + dec = (-90, 90) + ra_column = small_sky_order1_catalog.hc_structure.catalog_info.ra_column + dec_column = small_sky_order1_catalog.hc_structure.catalog_info.dec_column + + ra_search_catalog = small_sky_order1_catalog.box_search(ra=(0, 0), dec=dec) + ra_search_df = ra_search_catalog.compute() + ra_values = ra_search_df[ra_column] + dec_values = ra_search_df[dec_column] + assert len(ra_search_df) > 0 + assert all((0 <= ra <= 360) for ra in ra_values) + assert all((-90 <= dec <= 90) for dec in dec_values) -def test_box_search_ra_and_dec_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct): - search_catalog = small_sky_order1_catalog.box_search(ra=(280, 300), dec=(-40, -30)) - search_df = search_catalog.compute() - ra_values = search_df[small_sky_order1_catalog.hc_structure.catalog_info.ra_column] - dec_values = search_df[small_sky_order1_catalog.hc_structure.catalog_info.dec_column] - assert len(search_df) < len(small_sky_order1_catalog.compute()) - assert all(280 <= ra <= 300 for ra in ra_values) - assert all(-40 <= dec <= -30 for dec in dec_values) - assert_divisions_are_correct(search_catalog) + for ra_range in [(0, 360), (360, 0)]: + catalog_df = small_sky_order1_catalog.box_search(ra=ra_range, dec=dec).compute() + assert np.array_equal(catalog_df[ra_column], ra_values) + assert np.array_equal(catalog_df[dec_column], dec_values) def test_box_search_filters_partitions(small_sky_order1_catalog): @@ -110,19 +114,14 @@ def test_box_search_invalid_args(small_sky_order1_catalog): small_sky_order1_catalog.box_search(ra=(0, 30), dec=(-100, -70)) # Declination values should be in ascending order with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): - small_sky_order1_catalog.box_search(dec=(0, -10)) + small_sky_order1_catalog.box_search(ra=(0, 30), dec=(0, -10)) # One or more ranges are defined with more than 2 values with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): - small_sky_order1_catalog.box_search(ra=(0, 30), dec=(-30, -40, 10)) + small_sky_order1_catalog.box_search(ra=(0, 30), dec=(-40, -30, 10)) with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): small_sky_order1_catalog.box_search(ra=(0, 30, 40), dec=(-40, 10)) - # The range values coincide (for ra, values are wrapped) - with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): - small_sky_order1_catalog.box_search(ra=(100, 100)) - with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): - small_sky_order1_catalog.box_search(ra=(0, 360)) with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): - small_sky_order1_catalog.box_search(dec=(50, 50)) + small_sky_order1_catalog.box_search(ra=(0, 50), dec=(50, 50)) # No range values were provided with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE): small_sky_order1_catalog.box_search(ra=None, dec=None) diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index c72527df..2c0cbae5 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -685,7 +685,7 @@ def test_filtered_catalog_has_undetermined_len(small_sky_order1_catalog, small_s vertices = [(300, -50), (300, -55), (272, -55), (272, -50)] len(small_sky_order1_catalog.polygon_search(vertices)) with pytest.raises(ValueError, match="undetermined"): - len(small_sky_order1_catalog.box_search(ra=(280, 300))) + len(small_sky_order1_catalog.box_search(ra=(280, 300), dec=(0, 30))) with pytest.raises(ValueError, match="undetermined"): len(small_sky_order1_catalog.order_search(max_order=2)) with pytest.raises(ValueError, match="undetermined"): diff --git a/tests/lsdb/catalog/test_polygon_search.py b/tests/lsdb/catalog/test_polygon_search.py index cd64444d..18e14fec 100644 --- a/tests/lsdb/catalog/test_polygon_search.py +++ b/tests/lsdb/catalog/test_polygon_search.py @@ -73,16 +73,9 @@ def test_polygon_search_coarse_versus_fine(small_sky_order1_catalog): def test_polygon_search_invalid_dec(small_sky_order1_catalog): - # Some declination values are out of the [-90,90] bounds - vertices = [(-20, 100), (-20, -1), (20, -1), (20, 100)] + # Some declination values are out of the [-90,90[ bounds with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_DEC): - small_sky_order1_catalog.polygon_search(vertices) - - -def test_polygon_search_invalid_shape(small_sky_order1_catalog): - """The polygon is not convex, so the shape is invalid""" - with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_CONCAVE_SHAPE): - vertices = [(45, 30), (60, 60), (90, 45), (60, 50)] + vertices = [(-20, 100), (-20, -1), (20, -1), (20, 100)] small_sky_order1_catalog.polygon_search(vertices) diff --git a/tests/lsdb/loaders/hats/test_read_hats.py b/tests/lsdb/loaders/hats/test_read_hats.py index 4a979b68..a8fda97d 100644 --- a/tests/lsdb/loaders/hats/test_read_hats.py +++ b/tests/lsdb/loaders/hats/test_read_hats.py @@ -174,9 +174,9 @@ def test_read_hats_subset_with_cone_search(small_sky_order1_dir, small_sky_order def test_read_hats_subset_with_box_search(small_sky_order1_dir, small_sky_order1_catalog): - box_search = BoxSearch(ra=(0, 10), dec=(-20, 10)) + box_search = BoxSearch(ra=(300, 320), dec=(-40, -10)) # Filtering using catalog's box_search - box_search_catalog = small_sky_order1_catalog.box_search(ra=(0, 10), dec=(-20, 10)) + box_search_catalog = small_sky_order1_catalog.box_search(ra=(300, 320), dec=(-40, -10)) # Filtering when calling `read_hats` box_search_catalog_2 = lsdb.read_hats(small_sky_order1_dir, search_filter=box_search) assert isinstance(box_search_catalog_2, lsdb.Catalog) From d899327bd7001ce4072579d70a27d4f3e5311947 Mon Sep 17 00:00:00 2001 From: Sean McGuire <123987820+smcguire-cmu@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:04:57 -0500 Subject: [PATCH 5/8] Add plot points method (#510) * add plot points method * label colorbar * unit test plot_points * docstrings * types * update pylint * Apply suggestions from code review Co-authored-by: Sandro Campos --------- Co-authored-by: Sandro Campos --- src/.pylintrc | 2 +- src/lsdb/catalog/dataset/healpix_dataset.py | 112 +++++++++++++++++++- src/lsdb/core/plotting/plot_points.py | 99 +++++++++++++++++ tests/lsdb/catalog/test_catalog.py | 66 ++++++++++++ 4 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 src/lsdb/core/plotting/plot_points.py diff --git a/src/.pylintrc b/src/.pylintrc index 62ab5a7b..ee4884ee 100644 --- a/src/.pylintrc +++ b/src/.pylintrc @@ -278,7 +278,7 @@ exclude-too-few-public-methods= ignored-parents= # Maximum number of arguments for function / method. -max-args=10 +max-args=12 # Maximum number of positional arguments. max-positional-arguments=15 diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index 0f9ea771..678db29f 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -2,8 +2,9 @@ import warnings from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, cast +from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, cast +import astropy import dask import dask.dataframe as dd import hats as hc @@ -11,10 +12,13 @@ import nested_pandas as npd import numpy as np import pandas as pd +from astropy.coordinates import SkyCoord +from astropy.units import Quantity from astropy.visualization.wcsaxes import WCSAxes +from astropy.visualization.wcsaxes.frame import BaseFrame from dask.delayed import Delayed, delayed from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset -from hats.inspection.visualize_catalog import plot_healpix_map +from hats.inspection.visualize_catalog import get_fov_moc_from_wcs, initialize_wcs_axes, plot_healpix_map from hats.pixel_math import HealpixPixel from hats.pixel_math.healpix_pixel_function import get_pixel_argsort from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN @@ -27,8 +31,10 @@ from lsdb import io from lsdb.catalog.dataset.dataset import Dataset +from lsdb.core.plotting.plot_points import plot_points from lsdb.core.plotting.skymap import compute_skymap, perform_inner_skymap from lsdb.core.search.abstract_search import AbstractSearch +from lsdb.core.search.moc_search import MOCSearch from lsdb.dask.merge_catalog_functions import concat_metas from lsdb.io.schema import get_arrow_schema from lsdb.types import DaskDFPixelMap @@ -691,3 +697,105 @@ def reduce_part(df): hc_catalog = self._create_modified_hc_structure(**hc_updates) hc_catalog.schema = get_arrow_schema(ndf) return self.__class__(ndf, self._ddf_pixel_map, hc_catalog) + + def plot_points( + self, + *, + ra_column: str | None = None, + dec_column: str | None = None, + color_col: str | None = None, + projection: str = "MOL", + title: str | None = None, + fov: Quantity | Tuple[Quantity, Quantity] | None = None, + center: SkyCoord | None = None, + wcs: astropy.wcs.WCS | None = None, + frame_class: Type[BaseFrame] | None = None, + ax: WCSAxes | None = None, + fig: Figure | None = None, + **kwargs, + ): + """Plots the points in the catalog as a scatter plot + + Performs a scatter plot on a WCSAxes after computing the points of the catalog. + This will perform compute on the catalog, and so may be slow/resource intensive. + If the fov or wcs args are set, only the partitions in the catalog visible to the plot will be + computed. + The scatter points can be colored by a column of the catalog by using the `color_col` kwarg + + Args: + ra_column (str | None): The column to use as the RA of the points to plot. Defaults to the + catalog's default RA column. Useful for plotting joined or cross-matched points + dec_column (str | None): The column to use as the Declination of the points to plot. Defaults to + the catalog's default Declination column. Useful for plotting joined or cross-matched points + color_col (str | None): The column to use as the color array for the scatter plot. Allows coloring + of the points by the values of a given column. + projection (str): The projection to use in the WCS. Available projections listed at + https://docs.astropy.org/en/stable/wcs/supported_projections.html + title (str): The title of the plot + fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an + astropy Quantity with an angular unit, or a tuple of quantities for different longitude and \ + latitude FOVs (Default covers the full sky) + center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0)) + wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters + are ignored and the parameters from the WCS object is used. + frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized + with. if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for + full sky projection. If FOV is set, RectangularFrame is used) + ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be + used. If specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set + with the WCS object used in the axes. (Default: None) + fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created, + unless ax is specified (Default: None) + **kwargs: Additional kwargs to pass to creating the matplotlib `scatter` function. These include + `c` for color, `s` for the size of hte points, `marker` for the maker type, `cmap` and `norm` + if `color_col` is used + + Returns: + Tuple[Figure, WCSAxes] - The figure and axes used for the plot + """ + fig, ax, wcs = initialize_wcs_axes( + projection=projection, + fov=fov, + center=center, + wcs=wcs, + frame_class=frame_class, + ax=ax, + fig=fig, + figsize=(9, 5), + ) + + fov_moc = get_fov_moc_from_wcs(wcs) + + computed_catalog = ( + self.search(MOCSearch(fov_moc)).compute() if fov_moc is not None else self.compute() + ) + + if ra_column is None: + ra_column = self.hc_structure.catalog_info.ra_column + if dec_column is None: + dec_column = self.hc_structure.catalog_info.dec_column + + if ra_column is None: + raise ValueError("Catalog has no RA Column") + + if dec_column is None: + raise ValueError("Catalog has no DEC Column") + + if title is None: + title = f"Points in the {self.name} catalog" + + return plot_points( + computed_catalog, + ra_column, + dec_column, + color_col=color_col, + projection=projection, + title=title, + fov=fov, + center=center, + wcs=wcs, + frame_class=frame_class, + ax=ax, + fig=fig, + **kwargs, + ) diff --git a/src/lsdb/core/plotting/plot_points.py b/src/lsdb/core/plotting/plot_points.py new file mode 100644 index 00000000..964ee91c --- /dev/null +++ b/src/lsdb/core/plotting/plot_points.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import Tuple, Type + +import astropy +import matplotlib.pyplot as plt +import pandas as pd +from astropy.coordinates import SkyCoord +from astropy.units import Quantity +from astropy.visualization.wcsaxes import WCSAxes +from astropy.visualization.wcsaxes.frame import BaseFrame +from hats.inspection.visualize_catalog import initialize_wcs_axes +from matplotlib.figure import Figure +from mocpy.moc.plot.utils import _set_wcs + + +def plot_points( + df: pd.DataFrame, + ra_column: str, + dec_column: str, + *, + color_col: str | None = None, + projection: str = "MOL", + title: str = "", + fov: Quantity | Tuple[Quantity, Quantity] | None = None, + center: SkyCoord | None = None, + wcs: astropy.wcs.WCS | None = None, + frame_class: Type[BaseFrame] | None = None, + ax: WCSAxes | None = None, + fig: Figure | None = None, + **kwargs, +): + """Plots the points in a given dataframe as a scatter plot + + Performs a scatter plot on a WCSAxes with the points in a dataframe. + The scatter points can be colored by a column of the catalog by using the `color_col` kwarg + + Args: + ra_column (str | None): The column to use as the RA of the points to plot. Defaults to the + catalog's default RA column. Useful for plotting joined or cross-matched points + dec_column (str | None): The column to use as the Declination of the points to plot. Defaults to + the catalog's default Declination column. Useful for plotting joined or cross-matched points + color_col (str | None): The column to use as the color array for the scatter plot. Allows coloring + of the points by the values of a given column. + projection (str): The projection to use in the WCS. Available projections listed at + https://docs.astropy.org/en/stable/wcs/supported_projections.html + title (str): The title of the plot + fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an + astropy Quantity with an angular unit, or a tuple of quantities for different longitude and \ + latitude FOVs (Default covers the full sky) + center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0)) + wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters + are ignored and the parameters from the WCS object is used. + frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized with. + if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for full + sky projection. If FOV is set, RectangularFrame is used) + ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be used. If + specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set with the WCS + object used in the axes. (Default: None) + fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created, unless + ax is specified (Default: None) + **kwargs: Additional kwargs to pass to creating the matplotlib `scatter` function. These include + `c` for color, `s` for the size of hte points, `marker` for the maker type, `cmap` and `norm` + if `color_col` is used + + Returns: + Tuple[Figure, WCSAxes] - The figure and axes used for the plot + """ + fig, ax, wcs = initialize_wcs_axes( + projection=projection, + fov=fov, + center=center, + wcs=wcs, + frame_class=frame_class, + ax=ax, + fig=fig, + figsize=(9, 5), + ) + + ra = df[ra_column].to_numpy() + dec = df[dec_column].to_numpy() + if color_col is not None: + kwargs["c"] = df[color_col].to_numpy() + collection = None + if len(ra) > 0: + collection = ax.scatter(ra, dec, transform=ax.get_transform("icrs"), **kwargs) + + # Set projection + _set_wcs(ax, wcs) + + ax.coords[0].set_format_unit("deg") + + plt.grid() + plt.ylabel("Dec") + plt.xlabel("RA") + plt.title(title) + if color_col is not None and collection is not None: + plt.colorbar(collection, label=color_col) + return fig, ax diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index 2c0cbae5..be1ef43a 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -1,23 +1,36 @@ from pathlib import Path +import astropy.units as u import dask.array as da import dask.dataframe as dd import hats as hc import hats.pixel_math.healpix_shim as hp +import matplotlib.pyplot as plt import nested_dask as nd import nested_pandas as npd import numpy as np import numpy.testing as npt import pandas as pd import pytest +from astropy.coordinates import SkyCoord +from astropy.visualization.wcsaxes import WCSAxes +from hats.inspection.visualize_catalog import get_fov_moc_from_wcs from hats.io.file_io import read_fits_image from hats.pixel_math import HealpixPixel, spatial_index_to_healpix +from mocpy import WCS import lsdb from lsdb import Catalog +from lsdb.core.search.moc_search import MOCSearch from lsdb.dask.merge_catalog_functions import filter_by_spatial_index_to_pixel +@pytest.fixture(autouse=True) +def reset_matplotlib(): + yield + plt.close("all") + + def test_catalog_pixels_equals_hc_catalog_pixels(small_sky_order1_catalog, small_sky_order1_hats_catalog): assert small_sky_order1_catalog.get_healpix_pixels() == small_sky_order1_hats_catalog.get_healpix_pixels() @@ -745,3 +758,56 @@ def test_modified_hc_structure_is_a_deep_copy(small_sky_order1_catalog): # The rows of the new structure are invalidated assert modified_hc_structure.catalog_info.total_rows == 0 + + +def test_plot_points(small_sky_order1_catalog, mocker): + mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter") + _, ax = small_sky_order1_catalog.plot_points() + comp_cat = small_sky_order1_catalog.compute() + WCSAxes.scatter.assert_called_once() + npt.assert_array_equal(WCSAxes.scatter.call_args[0][0], comp_cat["ra"]) + npt.assert_array_equal(WCSAxes.scatter.call_args[0][1], comp_cat["dec"]) + assert WCSAxes.scatter.call_args.kwargs["transform"] == ax.get_transform("icrs") + + +def test_plot_points_fov(small_sky_order1_catalog, mocker): + mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter") + fig = plt.figure(figsize=(10, 6)) + center = SkyCoord(350, -80, unit="deg") + fov = 10 * u.deg + wcs = WCS(fig=fig, fov=fov, center=center, projection="MOL").w + wcs_moc = get_fov_moc_from_wcs(wcs) + _, ax = small_sky_order1_catalog.plot_points(fov=fov, center=center) + comp_cat = small_sky_order1_catalog.search(MOCSearch(wcs_moc)).compute() + WCSAxes.scatter.assert_called_once() + npt.assert_array_equal(WCSAxes.scatter.call_args[0][0], comp_cat["ra"]) + npt.assert_array_equal(WCSAxes.scatter.call_args[0][1], comp_cat["dec"]) + assert WCSAxes.scatter.call_args.kwargs["transform"] == ax.get_transform("icrs") + + +def test_plot_points_wcs(small_sky_order1_catalog, mocker): + mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter") + fig = plt.figure(figsize=(10, 6)) + center = SkyCoord(350, -80, unit="deg") + fov = 10 * u.deg + wcs = WCS(fig=fig, fov=fov, center=center).w + wcs_moc = get_fov_moc_from_wcs(wcs) + _, ax = small_sky_order1_catalog.plot_points(wcs=wcs) + comp_cat = small_sky_order1_catalog.search(MOCSearch(wcs_moc)).compute() + WCSAxes.scatter.assert_called_once() + npt.assert_array_equal(WCSAxes.scatter.call_args[0][0], comp_cat["ra"]) + npt.assert_array_equal(WCSAxes.scatter.call_args[0][1], comp_cat["dec"]) + assert WCSAxes.scatter.call_args.kwargs["transform"] == ax.get_transform("icrs") + + +def test_plot_points_colorcol(small_sky_order1_catalog, mocker): + mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter") + mocker.patch("matplotlib.pyplot.colorbar") + _, ax = small_sky_order1_catalog.plot_points(color_col="id") + comp_cat = small_sky_order1_catalog.compute() + WCSAxes.scatter.assert_called_once() + npt.assert_array_equal(WCSAxes.scatter.call_args[0][0], comp_cat["ra"]) + npt.assert_array_equal(WCSAxes.scatter.call_args[0][1], comp_cat["dec"]) + npt.assert_array_equal(WCSAxes.scatter.call_args.kwargs["c"], comp_cat["id"]) + assert WCSAxes.scatter.call_args.kwargs["transform"] == ax.get_transform("icrs") + plt.colorbar.assert_called_once() From 8269182f5bf7564740cc052d00a7447106f57ca2 Mon Sep 17 00:00:00 2001 From: Sean McGuire <123987820+smcguire-cmu@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:24:41 -0500 Subject: [PATCH 6/8] Update margin docs notebook to use new plotting functions (#513) * use updated plotting and remove healpy * formatting * remove wrong import * alpha * update fov * remove import * add legend and labels --- docs/tutorials/margins.ipynb | 200 +++++++++++------------------------ 1 file changed, 61 insertions(+), 139 deletions(-) diff --git a/docs/tutorials/margins.ipynb b/docs/tutorials/margins.ipynb index 1c465c87..2dc4ec2e 100644 --- a/docs/tutorials/margins.ipynb +++ b/docs/tutorials/margins.ipynb @@ -37,12 +37,17 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:18:45.846049Z", - "start_time": "2024-10-22T18:18:45.844102Z" + "end_time": "2024-11-25T20:08:43.247545Z", + "start_time": "2024-11-25T20:08:43.119690Z" } }, "outputs": [], "source": [ + "from astropy.coordinates import SkyCoord\n", + "import astropy.units as u\n", + "from hats.pixel_math import HealpixPixel\n", + "import matplotlib.pyplot as plt\n", + "\n", "import lsdb\n", "\n", "surveys_path = \"https://data.lsdb.io/hats/\"" @@ -53,15 +58,15 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:18:56.393295Z", - "start_time": "2024-10-22T18:18:45.865119Z" + "end_time": "2024-11-25T20:08:48.953611Z", + "start_time": "2024-11-25T20:08:43.999437Z" } }, "outputs": [], "source": [ "from lsdb import BoxSearch\n", "\n", - "box = BoxSearch(ra=(179.5, 180.1), dec=(9.4, 10))\n", + "box = BoxSearch(ra=(179.5, 180.001), dec=(9.4, 10))\n", "\n", "ztf_object_path = f\"{surveys_path}/ztf_dr14/ztf_object\"\n", "ztf_margin_path = f\"{surveys_path}/ztf_dr14/ztf_object_10arcs\"\n", @@ -88,71 +93,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:18:56.553194Z", - "start_time": "2024-10-22T18:18:56.548744Z" - } - }, - "outputs": [], - "source": [ - "# Defining a function to plot the points in a pixel and the pixel boundary\n", - "\n", - "import numpy as np\n", - "from matplotlib.patches import Polygon\n", - "from matplotlib import pyplot as plt\n", - "import healpy as hp\n", - "\n", - "\n", - "def plot_points(\n", - " pixel_dfs, order, pixel, colors, ra_columns, dec_columns, xlim=None, ylim=None, markers=None, alpha=1\n", - "):\n", - " ax = plt.subplot()\n", - "\n", - " # Plot hp pixel bounds\n", - " nsides = hp.order2nside(order)\n", - " pix0_bounds = hp.vec2dir(hp.boundaries(nsides, pixel, step=100, nest=True), lonlat=True)\n", - " lon = pix0_bounds[0]\n", - " lat = pix0_bounds[1]\n", - " vertices = np.vstack([lon.ravel(), lat.ravel()]).transpose()\n", - " p = Polygon(vertices, closed=True, edgecolor=\"#3b81db\", facecolor=\"none\")\n", - " ax.add_patch(p)\n", - "\n", - " if markers is None:\n", - " markers = [\"+\"] * len(pixel_dfs)\n", - "\n", - " # plot the points\n", - " for pixel_df, color, ra_column, dec_column, marker in zip(\n", - " pixel_dfs, colors, ra_columns, dec_columns, markers\n", - " ):\n", - " ax.scatter(\n", - " pixel_df[ra_column].to_numpy(),\n", - " pixel_df[dec_column].to_numpy(),\n", - " c=color,\n", - " marker=marker,\n", - " linewidths=1,\n", - " alpha=alpha,\n", - " )\n", - "\n", - " # plotting configuration\n", - " VIEW_MARGIN = 2\n", - " xlim_low = np.min(lon) - VIEW_MARGIN if xlim is None else xlim[0]\n", - " xlim_high = np.max(lon) + VIEW_MARGIN if xlim is None else xlim[1]\n", - " ylim_low = np.min(lat) - VIEW_MARGIN if ylim is None else ylim[0]\n", - " ylim_high = np.max(lat) + VIEW_MARGIN if ylim is None else ylim[1]\n", - "\n", - " plt.xlim(xlim_low, xlim_high)\n", - " plt.ylim(ylim_low, ylim_high)\n", - " plt.xlabel(\"ra\")\n", - " plt.ylabel(\"dec\")\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-22T18:19:17.288078Z", - "start_time": "2024-10-22T18:18:56.559558Z" + "end_time": "2024-11-25T20:09:21.409413Z", + "start_time": "2024-11-25T20:09:18.049993Z" } }, "outputs": [], @@ -161,20 +103,19 @@ "order = 3\n", "pixel = 434\n", "\n", + "ztf_pixel = ztf_object.pixel_search([HealpixPixel(order, pixel)])\n", + "\n", "# Plot the points from the specified ztf pixel in green, and from the pixel's margin cache in red\n", - "plot_points(\n", - " [\n", - " ztf_object.get_partition(order, pixel).compute(),\n", - " ztf_object.margin.get_partition(order, pixel).compute(),\n", - " ],\n", - " order,\n", - " pixel,\n", - " [\"green\", \"red\"],\n", - " [\"ra\", \"ra\"],\n", - " [\"dec\", \"dec\"],\n", - " xlim=[179.5, 180.1],\n", - " ylim=[9.4, 10.0],\n", - ")" + "ztf_pixel.plot_pixels(\n", + " color_by_order=False,\n", + " fc=\"#00000000\",\n", + " ec=\"grey\",\n", + " center=SkyCoord(179.8, 9.7, unit=\"deg\"),\n", + " fov=(0.5 * u.deg, 0.5 * u.deg),\n", + ")\n", + "ztf_pixel.plot_points(c=\"green\", marker=\"+\", label=\"partition points\")\n", + "ztf_pixel.margin.plot_points(c=\"red\", marker=\"+\", label=\"margin points\")\n", + "plt.legend()" ] }, { @@ -195,13 +136,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:19:27.786625Z", - "start_time": "2024-10-22T18:19:17.304719Z" + "end_time": "2024-11-25T18:43:54.231870Z", + "start_time": "2024-11-25T18:43:50.753170Z" } }, "outputs": [], "source": [ - "gaia = lsdb.read_hats(f\"{surveys_path}/gaia_dr3/gaia/\", columns=[\"ra\", \"dec\"])\n", + "gaia = lsdb.read_hats(f\"{surveys_path}/gaia_dr3/gaia/\", columns=[\"ra\", \"dec\"], search_filter=box)\n", "gaia" ] }, @@ -219,22 +160,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:19:27.902933Z", - "start_time": "2024-10-22T18:19:27.899123Z" - } - }, - "outputs": [], - "source": [ - "gaia._ddf.index.name" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-22T18:19:28.516785Z", - "start_time": "2024-10-22T18:19:27.915282Z" + "end_time": "2024-11-25T18:19:09.709583Z", + "start_time": "2024-11-25T18:19:08.039861Z" } }, "outputs": [], @@ -256,8 +183,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:19:29.033307Z", - "start_time": "2024-10-22T18:19:28.589757Z" + "end_time": "2024-11-25T18:19:09.800325Z", + "start_time": "2024-11-25T18:19:09.779567Z" } }, "outputs": [], @@ -282,8 +209,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:21:39.092341Z", - "start_time": "2024-10-22T18:21:16.067095Z" + "end_time": "2024-11-25T20:10:28.243086Z", + "start_time": "2024-11-25T20:10:15.749569Z" } }, "outputs": [], @@ -291,23 +218,20 @@ "order = 3\n", "pixel = 434\n", "\n", - "crossmatch_result = gaia.crossmatch(ztf_object).get_partition(order, pixel).compute()\n", + "crossmatch_cat = gaia.crossmatch(ztf_pixel)\n", "\n", - "plot_points(\n", - " [\n", - " crossmatch_result,\n", - " crossmatch_result,\n", - " ],\n", - " order,\n", - " pixel,\n", - " [\"green\", \"red\"],\n", - " [\"ra_gaia\", \"ra_ztf_dr14\"],\n", - " [\"dec_gaia\", \"dec_ztf_dr14\"],\n", - " xlim=[179.5, 180.1],\n", - " ylim=[9.4, 10.0],\n", - " markers=[\"+\", \"x\"],\n", - " alpha=0.8,\n", - ")" + "ztf_pixel.plot_pixels(\n", + " color_by_order=False,\n", + " fc=\"#00000000\",\n", + " ec=\"grey\",\n", + " center=SkyCoord(179.8, 9.7, unit=\"deg\"),\n", + " fov=(0.5 * u.deg, 0.5 * u.deg),\n", + ")\n", + "crossmatch_cat.plot_points(c=\"green\", marker=\"+\", alpha=0.8, label=\"Gaia points\")\n", + "crossmatch_cat.plot_points(\n", + " ra_column=\"ra_ztf_dr14\", dec_column=\"dec_ztf_dr14\", c=\"red\", marker=\"x\", alpha=0.8, label=\"ZTF points\"\n", + ")\n", + "plt.legend()" ] }, { @@ -326,28 +250,26 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T18:21:59.564234Z", - "start_time": "2024-10-22T18:21:46.665531Z" + "end_time": "2024-11-25T20:10:42.272733Z", + "start_time": "2024-11-25T20:10:37.683953Z" } }, "outputs": [], "source": [ - "small_sky_box_filter = ztf_object.box_search(ra=[179.9, 180], dec=[9.5, 9.7])\n", + "small_sky_box_filter = ztf_pixel.box_search(ra=[179.9, 180], dec=[9.5, 9.7])\n", + "\n", "\n", - "# Plot the points from the filtered ztf pixel in green, and from the pixel's margin cache in red\n", - "plot_points(\n", - " [\n", - " small_sky_box_filter.get_partition(order, pixel).compute(),\n", - " small_sky_box_filter.margin.get_partition(order, pixel).compute(),\n", - " ],\n", - " order,\n", - " pixel,\n", - " [\"green\", \"red\"],\n", - " [\"ra\", \"ra\"],\n", - " [\"dec\", \"dec\"],\n", - " xlim=[179.5, 180.1],\n", - " ylim=[9.4, 10.0],\n", - ")" + "# Plot the points from the specified ztf pixel in green, and from the pixel's margin cache in red\n", + "small_sky_box_filter.plot_pixels(\n", + " color_by_order=False,\n", + " fc=\"#00000000\",\n", + " ec=\"grey\",\n", + " center=SkyCoord(179.8, 9.7, unit=\"deg\"),\n", + " fov=(0.5 * u.deg, 0.5 * u.deg),\n", + ")\n", + "small_sky_box_filter.plot_points(c=\"green\", marker=\"+\", label=\"partition points\")\n", + "small_sky_box_filter.margin.plot_points(c=\"red\", marker=\"+\", label=\"margin points\")\n", + "plt.legend()" ] }, { From 2a36166b7c713f770e799a57319c8ad696cc8e28 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Tue, 26 Nov 2024 14:59:25 -0500 Subject: [PATCH 7/8] Fix empty margin catalogs in `from_dataframe` (#508) * Create empty catalog for margins * Fix margin catalog validation * Calculate threshold when order is specified --- src/lsdb/catalog/margin_catalog.py | 21 +++++++ src/lsdb/loaders/dataframe/from_dataframe.py | 19 +++--- .../dataframe/margin_catalog_generator.py | 49 +++++++++++---- src/lsdb/loaders/hats/read_hats.py | 14 +---- .../loaders/dataframe/test_from_dataframe.py | 60 +++++++++++++++++-- 5 files changed, 123 insertions(+), 40 deletions(-) diff --git a/src/lsdb/catalog/margin_catalog.py b/src/lsdb/catalog/margin_catalog.py index db3abe30..5987eb96 100644 --- a/src/lsdb/catalog/margin_catalog.py +++ b/src/lsdb/catalog/margin_catalog.py @@ -1,5 +1,7 @@ import hats as hc import nested_dask as nd +import pyarrow as pa +from hats.io import paths from lsdb.catalog.dataset.healpix_dataset import HealpixDataset from lsdb.types import DaskDFPixelMap @@ -23,3 +25,22 @@ def __init__( hc_structure: hc.catalog.MarginCatalog, ): super().__init__(ddf, ddf_pixel_map, hc_structure) + + +def _validate_margin_catalog(margin_hc_catalog, hc_catalog): + """Validate that the margin and main catalogs have compatible schemas. The order of + the pyarrow fields should not matter.""" + expected_margin_schema = _create_margin_schema(hc_catalog.schema) + # Compare the fields for the schemas (allowing duplicates). They should match. + margin_catalog_fields = sorted((f.name, f.type) for f in margin_hc_catalog.schema) + expected_margin_fields = sorted((f.name, f.type) for f in expected_margin_schema) + if margin_catalog_fields != expected_margin_fields: + raise ValueError("The margin catalog and the main catalog must have the same schema.") + + +def _create_margin_schema(main_catalog_schema: pa.Schema) -> pa.Schema: + """Create a pyarrow schema for the margin catalog from the main catalog schema.""" + order_field = pa.field(f"margin_{paths.PARTITION_ORDER}", pa.uint8()) + dir_field = pa.field(f"margin_{paths.PARTITION_DIR}", pa.uint64()) + pixel_field = pa.field(f"margin_{paths.PARTITION_PIXEL}", pa.uint64()) + return main_catalog_schema.append(order_field).append(dir_field).append(pixel_field) diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index 48fd10e0..ea7c1b27 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -44,8 +44,8 @@ def from_dataframe( partition_size (int): The desired partition size, in number of bytes in-memory. threshold (int): The maximum number of data points per pixel. margin_order (int): The order at which to generate the margin cache. - margin_threshold (float): The size of the margin cache boundary, in arcseconds. If None, - the margin cache is not generated. Defaults to 5 arcseconds. + margin_threshold (float): The size of the margin cache boundary, in arcseconds. If None, and + margin order is not specified, the margin cache is not generated. Defaults to 5 arcseconds. should_generate_moc (bool): should we generate a MOC (multi-order coverage map) of the data. can improve performance when joining/crossmatching to other hats-sharded datasets. @@ -74,12 +74,11 @@ def from_dataframe( schema=schema, **kwargs, ).load_catalog() - if margin_threshold: - catalog.margin = MarginCatalogGenerator( - catalog, - margin_order, - margin_threshold, - use_pyarrow_types, - **kwargs, - ).create_catalog() + catalog.margin = MarginCatalogGenerator( + catalog, + margin_order, + margin_threshold, + use_pyarrow_types, + **kwargs, + ).create_catalog() return catalog diff --git a/src/lsdb/loaders/dataframe/margin_catalog_generator.py b/src/lsdb/loaders/dataframe/margin_catalog_generator.py index c9e88af0..32fbcd19 100644 --- a/src/lsdb/loaders/dataframe/margin_catalog_generator.py +++ b/src/lsdb/loaders/dataframe/margin_catalog_generator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import Dict, List, Tuple import hats as hc @@ -13,7 +14,7 @@ from hats.pixel_math.healpix_pixel_function import get_pixel_argsort from lsdb import Catalog -from lsdb.catalog.margin_catalog import MarginCatalog +from lsdb.catalog.margin_catalog import MarginCatalog, _create_margin_schema from lsdb.loaders.dataframe.from_dataframe_utils import ( _extra_property_dict, _format_margin_partition_dataframe, @@ -28,7 +29,7 @@ def __init__( self, catalog: Catalog, margin_order: int = -1, - margin_threshold: float = 5.0, + margin_threshold: float | None = 5.0, use_pyarrow_types: bool = True, **kwargs, ) -> None: @@ -45,9 +46,9 @@ def __init__( self.hc_structure = catalog.hc_structure self.margin_threshold = margin_threshold self.margin_order = margin_order - self._resolve_margin_order() self.use_pyarrow_types = use_pyarrow_types - self.catalog_info = self._create_catalog_info(**kwargs) + self.catalog_info_kwargs = kwargs + self.margin_schema = _create_margin_schema(catalog.hc_structure.schema) def _resolve_margin_order(self): """Calculate the order of the margin cache to be generated. If not provided @@ -61,6 +62,9 @@ def _resolve_margin_order(self): if self.margin_order < 0: self.margin_order = hp.margin2order(margin_thr_arcmin=self.margin_threshold / 60.0) + else: + self.margin_threshold = hp.order2mindist(self.margin_order) * 60.0 + warnings.warn("Ignoring margin_threshold because margin_order was specified.", RuntimeWarning) if self.margin_order < highest_order + 1: raise ValueError( @@ -72,22 +76,44 @@ def _resolve_margin_order(self): raise ValueError("margin pixels must be larger than margin_threshold") def create_catalog(self) -> MarginCatalog | None: - """Create a margin catalog for another pre-computed catalog + """Create a margin catalog for another pre-computed catalog. + + Only one of margin order / threshold can be specified. If the margin order + is not specified: if the threshold is zero the margin is an empty catalog; + if the threshold is None, the margin is not generated (it is None). Returns: - Margin catalog object, or None if the margin is empty. + Margin catalog object or None if the margin is not generated. """ + if self.margin_order < 0: + if self.margin_threshold is None: + return None + if self.margin_threshold < 0: + raise ValueError("margin_threshold must be positive.") + if self.margin_threshold == 0: + return self._create_empty_catalog() + return self._create_catalog() + + def _create_catalog(self) -> MarginCatalog: + """Create a non-empty margin catalog""" + self._resolve_margin_order() pixels, partitions = self._get_margins() if len(pixels) == 0: - return None + return self._create_empty_catalog() ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map(pixels, partitions) - self.catalog_info.total_rows = total_rows + catalog_info = self._create_catalog_info(**self.catalog_info_kwargs, total_rows=total_rows) margin_pixels = list(ddf_pixel_map.keys()) - margin_structure = hc.catalog.MarginCatalog( - self.catalog_info, margin_pixels, schema=self.hc_structure.schema - ) + margin_structure = hc.catalog.MarginCatalog(catalog_info, margin_pixels, schema=self.margin_schema) return MarginCatalog(ddf, ddf_pixel_map, margin_structure) + def _create_empty_catalog(self) -> MarginCatalog: + """Create an empty margin catalog""" + dask_meta_schema = self.margin_schema.empty_table().to_pandas() + ddf = nd.NestedFrame.from_pandas(dask_meta_schema, npartitions=1) + catalog_info = self._create_catalog_info(**self.catalog_info_kwargs, total_rows=0) + margin_structure = hc.catalog.MarginCatalog(catalog_info, [], schema=self.margin_schema) + return MarginCatalog(ddf, {}, margin_structure) + def _get_margins(self) -> Tuple[List[HealpixPixel], List[npd.NestedFrame]]: """Generates the list of pixels that have margin data, and the dataframes with the margin data for each partition @@ -206,7 +232,6 @@ def _create_catalog_info(self, catalog_name: str | None = None, **kwargs) -> Tab catalog_type=CatalogType.MARGIN, ra_column=self.hc_structure.catalog_info.ra_column, dec_column=self.hc_structure.catalog_info.dec_column, - total_rows=self.hc_structure.catalog_info.total_rows, primary_catalog=catalog_name, margin_threshold=self.margin_threshold, **kwargs, diff --git a/src/lsdb/loaders/hats/read_hats.py b/src/lsdb/loaders/hats/read_hats.py index 2bf64b6a..d3d5d904 100644 --- a/src/lsdb/loaders/hats/read_hats.py +++ b/src/lsdb/loaders/hats/read_hats.py @@ -10,7 +10,6 @@ import pyarrow as pa from hats.catalog import CatalogType from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset -from hats.io import paths from hats.io.file_io import file_io from hats.pixel_math import HealpixPixel from hats.pixel_math.healpix_pixel_function import get_pixel_argsort @@ -19,6 +18,7 @@ from lsdb.catalog.association_catalog import AssociationCatalog from lsdb.catalog.catalog import Catalog, DaskDFPixelMap, MarginCatalog +from lsdb.catalog.margin_catalog import _validate_margin_catalog from lsdb.core.search.abstract_search import AbstractSearch from lsdb.dask.divisions import get_pixels_divisions from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig @@ -154,18 +154,6 @@ def _load_object_catalog(hc_catalog, config): return catalog -def _validate_margin_catalog(margin_hc_catalog, hc_catalog): - """Validate that the margin catalog and the main catalog are compatible""" - pixel_columns = [paths.PARTITION_ORDER, paths.PARTITION_DIR, paths.PARTITION_PIXEL] - margin_pixel_columns = pixel_columns + ["margin_" + column for column in pixel_columns] - catalog_schema = pa.schema([field for field in hc_catalog.schema if field.name not in pixel_columns]) - margin_schema = pa.schema( - [field for field in margin_hc_catalog.schema if field.name not in margin_pixel_columns] - ) - if not catalog_schema.equals(margin_schema): - raise ValueError("The margin catalog and the main catalog must have the same schema") - - def _create_dask_meta_schema(schema: pa.Schema, config) -> npd.NestedFrame: """Creates the Dask meta DataFrame from the HATS catalog schema.""" dask_meta_schema = schema.empty_table().to_pandas(types_mapper=config.get_dtype_mapper()) diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py index eefd63f5..eea9ec15 100644 --- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py +++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py @@ -15,7 +15,7 @@ from mocpy import MOC import lsdb -from lsdb.catalog.margin_catalog import MarginCatalog +from lsdb.catalog.margin_catalog import MarginCatalog, _validate_margin_catalog def get_catalog_kwargs(catalog, **kwargs): @@ -258,11 +258,41 @@ def test_from_dataframe_small_sky_source_with_margins(small_sky_source_df, small assert catalog.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb") assert margin.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb") - # The margin and main catalog's schemas are the same - assert margin.hc_structure.schema is catalog.hc_structure.schema + # The margin and main catalog's schemas are valid + _validate_margin_catalog(margin.hc_structure, catalog.hc_structure) -def test_from_dataframe_invalid_margin_order(small_sky_source_df): + +def test_from_dataframe_margin_threshold_from_order(small_sky_source_df): + # By default, the threshold is set to 5 arcsec, triggering a warning + with pytest.warns(RuntimeWarning, match="Ignoring margin_threshold"): + catalog = lsdb.from_dataframe( + small_sky_source_df, + ra_column="source_ra", + dec_column="source_dec", + lowest_order=0, + highest_order=2, + threshold=3000, + margin_order=3, + ) + assert len(catalog.margin.get_healpix_pixels()) == 17 + margin_threshold_order3 = hp.order2mindist(3) * 60.0 + assert catalog.margin.hc_structure.catalog_info.margin_threshold == margin_threshold_order3 + assert catalog.margin._ddf.index.name == catalog._ddf.index.name + _validate_margin_catalog(catalog.margin.hc_structure, catalog.hc_structure) + + +def test_from_dataframe_invalid_margin_args(small_sky_source_df): + # The provided margin threshold is negative + with pytest.raises(ValueError, match="positive"): + lsdb.from_dataframe( + small_sky_source_df, + ra_column="source_ra", + dec_column="source_dec", + lowest_order=2, + margin_threshold=-1, + ) + # Margin order is inferior to the main catalog's highest order with pytest.raises(ValueError, match="margin_order"): lsdb.from_dataframe( small_sky_source_df, @@ -281,7 +311,27 @@ def test_from_dataframe_margin_is_empty(small_sky_order1_df): highest_order=5, threshold=100, ) - assert catalog.margin is None + assert len(catalog.margin.get_healpix_pixels()) == 0 + assert catalog.margin._ddf_pixel_map == {} + assert catalog.margin._ddf.index.name == catalog._ddf.index.name + assert catalog.margin.hc_structure.catalog_info.margin_threshold == 5.0 + _validate_margin_catalog(catalog.margin.hc_structure, catalog.hc_structure) + + +def test_from_dataframe_margin_threshold_zero(small_sky_order1_df): + catalog = lsdb.from_dataframe( + small_sky_order1_df, + catalog_name="small_sky_order1", + catalog_type="object", + highest_order=5, + threshold=100, + margin_threshold=0, + ) + assert len(catalog.margin.get_healpix_pixels()) == 0 + assert catalog.margin._ddf_pixel_map == {} + assert catalog.margin._ddf.index.name == catalog._ddf.index.name + assert catalog.margin.hc_structure.catalog_info.margin_threshold == 0 + _validate_margin_catalog(catalog.margin.hc_structure, catalog.hc_structure) def test_from_dataframe_moc(small_sky_order1_catalog): From 0032f1572cd97865fee3a1c9dcdc8153734d9d46 Mon Sep 17 00:00:00 2001 From: Melissa DeLucchi Date: Tue, 26 Nov 2024 16:02:02 -0500 Subject: [PATCH 8/8] Update dev dependencies. --- .github/workflows/pre-commit-ci.yml | 2 +- .github/workflows/testing-and-coverage.yml | 2 +- docs/requirements.txt | 4 ++-- requirements.txt | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pre-commit-ci.yml b/.github/workflows/pre-commit-ci.yml index 642e7b9d..f31b3ee6 100644 --- a/.github/workflows/pre-commit-ci.yml +++ b/.github/workflows/pre-commit-ci.yml @@ -7,7 +7,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main, margin ] + branches: [ main ] jobs: pre-commit-ci: diff --git a/.github/workflows/testing-and-coverage.yml b/.github/workflows/testing-and-coverage.yml index b318df50..11c13d05 100644 --- a/.github/workflows/testing-and-coverage.yml +++ b/.github/workflows/testing-and-coverage.yml @@ -7,7 +7,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main, margin ] + branches: [ main ] jobs: build: diff --git a/docs/requirements.txt b/docs/requirements.txt index 8cd6da74..6a734805 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,5 +10,5 @@ sphinx-autoapi sphinx-copybutton sphinx-book-theme sphinx-design -git+https://github.com/astronomy-commons/hats.git@margin -git+https://github.com/astronomy-commons/hats-import.git@margin +git+https://github.com/astronomy-commons/hats.git@main +git+https://github.com/astronomy-commons/hats-import.git@main diff --git a/requirements.txt b/requirements.txt index f2b394e4..54865347 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -git+https://github.com/astronomy-commons/hats.git@margin +git+https://github.com/astronomy-commons/hats.git@main git+https://github.com/lincc-frameworks/nested-pandas.git@main git+https://github.com/lincc-frameworks/nested-dask.git@main \ No newline at end of file