From c322de85c2377f14f6f10cb31d89a881e1bf761a Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Thu, 23 May 2024 10:19:19 -0400 Subject: [PATCH] Allow margin path argument in `read_hipscat` (#328) * Accept margin path in read_hipscat * Use single argument for margin cache --- src/lsdb/__init__.py | 3 ++- .../hipscat/abstract_catalog_loader.py | 2 +- .../loaders/hipscat/hipscat_catalog_loader.py | 25 +++++++++++++++++-- .../loaders/hipscat/hipscat_loading_config.py | 17 ++++++------- src/lsdb/loaders/hipscat/read_hipscat.py | 12 +++++---- src/lsdb/loaders/hipscat/read_hipscat.pyi | 15 ++++++----- .../lsdb/loaders/hipscat/test_read_hipscat.py | 17 ++++++++++++- 7 files changed, 66 insertions(+), 25 deletions(-) diff --git a/src/lsdb/__init__.py b/src/lsdb/__init__.py index 27b64f99..2364508c 100644 --- a/src/lsdb/__init__.py +++ b/src/lsdb/__init__.py @@ -1,3 +1,4 @@ from ._version import __version__ from .catalog import Catalog, MarginCatalog -from .loaders import from_dataframe, read_hipscat +from .loaders.dataframe.from_dataframe import from_dataframe +from .loaders.hipscat.read_hipscat import read_hipscat diff --git a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py index 28756b22..c825e361 100644 --- a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py +++ b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py @@ -74,7 +74,7 @@ def _load_df_from_paths( dask_meta_schema = self._load_metadata_schema(catalog) if self.config.columns: dask_meta_schema = dask_meta_schema[self.config.columns] - kwargs = self.config.get_kwargs_dict() + kwargs = dict(self.config.kwargs) if self.config.dtype_backend is not None: kwargs["dtype_backend"] = self.config.dtype_backend return dd.io.from_map( diff --git a/src/lsdb/loaders/hipscat/hipscat_catalog_loader.py b/src/lsdb/loaders/hipscat/hipscat_catalog_loader.py index ae3734b0..9c4518d4 100644 --- a/src/lsdb/loaders/hipscat/hipscat_catalog_loader.py +++ b/src/lsdb/loaders/hipscat/hipscat_catalog_loader.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import dataclasses import hipscat as hc -from lsdb.catalog.catalog import Catalog +import lsdb +from lsdb.catalog.catalog import Catalog, MarginCatalog from lsdb.loaders.hipscat.abstract_catalog_loader import AbstractCatalogLoader @@ -18,7 +21,7 @@ def load_catalog(self) -> Catalog: hc_catalog = self._load_hipscat_catalog(hc.catalog.Catalog) filtered_hc_catalog = self._filter_hipscat_catalog(hc_catalog) dask_df, dask_df_pixel_map = self._load_dask_df_and_map(filtered_hc_catalog) - return Catalog(dask_df, dask_df_pixel_map, filtered_hc_catalog, self.config.margin_cache) + return Catalog(dask_df, dask_df_pixel_map, filtered_hc_catalog, self._load_margin_catalog()) def _filter_hipscat_catalog(self, hc_catalog: hc.catalog.Catalog) -> hc.catalog.Catalog: """Filter the catalog pixels according to the spatial filter provided at loading time. @@ -34,3 +37,21 @@ def _filter_hipscat_catalog(self, hc_catalog: hc.catalog.Catalog) -> hc.catalog. return hc.catalog.Catalog( catalog_info, pixels_to_load, self.path, hc_catalog.moc, self.storage_options ) + + def _load_margin_catalog(self) -> MarginCatalog | None: + """Load the margin catalog. It can be provided using a margin catalog + instance or a path to the catalog on disk.""" + margin_catalog = None + if isinstance(self.config.margin_cache, MarginCatalog): + margin_catalog = self.config.margin_cache + elif isinstance(self.config.margin_cache, str): + margin_catalog = lsdb.read_hipscat( + path=self.config.margin_cache, + catalog_type=MarginCatalog, + search_filter=self.config.search_filter, + margin_cache=None, + dtype_backend=self.config.dtype_backend, + storage_options=self.storage_options, + **self.config.kwargs, + ) + return margin_catalog diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py index a9da2a24..e74c4c5f 100644 --- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py +++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable, List import pandas as pd @@ -23,24 +23,23 @@ class HipscatLoadingConfig: columns: List[str] | None = None """Columns to load from the catalog. If not specified, all columns are loaded""" - margin_cache: MarginCatalog | None = None - """Margin cache for the catalog. By default, it is None""" + margin_cache: MarginCatalog | str | None = None + """Margin cache for the catalog. It can be provided as a path for the margin on disk, + or as a margin object instance. By default, it is None.""" dtype_backend: str | None = "pyarrow" """The backend data type to apply to the catalog. It defaults to "pyarrow" and if it is None no type conversion is performed.""" - kwargs: dict | None = None - """Extra kwargs""" + kwargs: dict = field(default_factory=dict) + """Extra kwargs for the pandas parquet file reader""" def __post_init__(self): + if self.margin_cache is not None and not isinstance(self.margin_cache, (MarginCatalog, str)): + raise ValueError("`margin_cache` must be of type 'MarginCatalog' or 'str'") if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]: raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'") - def get_kwargs_dict(self) -> dict: - """Returns a dictionary with the extra kwargs""" - return self.kwargs if self.kwargs is not None else {} - def get_dtype_mapper(self) -> Callable | None: """Returns a mapper for pyarrow or numpy types, mirroring Pandas behaviour.""" mapper = None diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py index 5ea98beb..9e4c6d69 100644 --- a/src/lsdb/loaders/hipscat/read_hipscat.py +++ b/src/lsdb/loaders/hipscat/read_hipscat.py @@ -12,6 +12,7 @@ from lsdb.catalog.dataset.dataset import Dataset from lsdb.catalog.margin_catalog import MarginCatalog from lsdb.core.search.abstract_search import AbstractSearch +from lsdb.loaders.hipscat.abstract_catalog_loader import CatalogTypeVar from lsdb.loaders.hipscat.hipscat_loader_factory import get_loader_for_type from lsdb.loaders.hipscat.hipscat_loading_config import HipscatLoadingConfig @@ -26,14 +27,14 @@ # pylint: disable=unused-argument def read_hipscat( path: str, - catalog_type: Type[Dataset] | None = None, + catalog_type: Type[CatalogTypeVar] | None = None, search_filter: AbstractSearch | None = None, columns: List[str] | None = None, - margin_cache: MarginCatalog | None = None, + margin_cache: MarginCatalog | str | None = None, dtype_backend: str | None = "pyarrow", storage_options: dict | None = None, **kwargs, -) -> Dataset | None: +) -> CatalogTypeVar | None: """Load a catalog from a HiPSCat formatted catalog. Typical usage example, where we load a catalog with a subset of columns: @@ -45,7 +46,7 @@ def read_hipscat( path="./my_catalog_dir", catalog_type=lsdb.Catalog, columns=["ra","dec"], - filter=lsdb.core.search.ConeSearch(ra, dec, radius_arcsec), + search_filter=lsdb.core.search.ConeSearch(ra, dec, radius_arcsec), ) Args: @@ -57,7 +58,8 @@ def read_hipscat( the lsdb class for that catalog. search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied. columns (List[str]): Default `None`. The set of columns to filter the catalog on. - margin_cache (MarginCatalog): The margin cache for the main catalog + margin_cache (MarginCatalog | str): The margin cache for the main catalog, provided as a path + on disk or as an instance of the MarginCatalog object. Defaults to None. dtype_backend (str): Backend data type to apply to the catalog. Defaults to "pyarrow". If None, no type conversion is performed. storage_options (dict): Dictionary that contains abstract filesystem credentials diff --git a/src/lsdb/loaders/hipscat/read_hipscat.pyi b/src/lsdb/loaders/hipscat/read_hipscat.pyi index ad555ed6..6dd574c4 100644 --- a/src/lsdb/loaders/hipscat/read_hipscat.pyi +++ b/src/lsdb/loaders/hipscat/read_hipscat.pyi @@ -24,17 +24,20 @@ from lsdb.loaders.hipscat.abstract_catalog_loader import CatalogTypeVar def read_hipscat( path: str, search_filter: AbstractSearch | None = None, - storage_options: dict | None = None, columns: List[str] | None = None, - margin_cache: MarginCatalog | None = None, -) -> Dataset: ... + margin_cache: MarginCatalog | str | None = None, + dtype_backend: str | None = "pyarrow", + storage_options: dict | None = None, + **kwargs, +) -> Dataset | None: ... @overload def read_hipscat( path: str, catalog_type: Type[CatalogTypeVar], search_filter: AbstractSearch | None = None, - storage_options: dict | None = None, columns: List[str] | None = None, - margin_cache: MarginCatalog | None = None, + margin_cache: MarginCatalog | str | None = None, + dtype_backend: str | None = "pyarrow", + storage_options: dict | None = None, **kwargs, -) -> CatalogTypeVar: ... +) -> CatalogTypeVar | None: ... diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py index b714251c..9ca3e953 100644 --- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py +++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py @@ -72,10 +72,20 @@ def test_read_hipscat_specify_wrong_catalog_type(small_sky_dir): lsdb.read_hipscat(small_sky_dir, catalog_type=int) -def test_catalog_with_margin(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog): +def test_catalog_with_margin( + small_sky_xmatch_dir, small_sky_xmatch_margin_catalog, small_sky_xmatch_margin_dir +): + # Provide the margin cache catalog object catalog = lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog) assert isinstance(catalog, lsdb.Catalog) assert catalog.margin is small_sky_xmatch_margin_catalog + # Provide the margin cache catalog path + catalog_2 = lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir) + assert isinstance(catalog_2, lsdb.Catalog) + # The catalogs obtained are identical + assert catalog.margin.hc_structure.catalog_info == catalog_2.margin.hc_structure.catalog_info + assert catalog.margin.get_healpix_pixels() == catalog_2.margin.get_healpix_pixels() + pd.testing.assert_frame_equal(catalog.margin.compute(), catalog_2.margin.compute()) def test_catalog_without_margin_is_none(small_sky_xmatch_dir): @@ -84,6 +94,11 @@ def test_catalog_without_margin_is_none(small_sky_xmatch_dir): assert catalog.margin is None +def test_catalog_with_wrong_margin_args(small_sky_xmatch_dir): + with pytest.raises(ValueError, match="must be of type"): + lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=1) + + def test_read_hipscat_subset_with_cone_search(small_sky_order1_dir, small_sky_order1_catalog): cone_search = ConeSearch(ra=0, dec=-80, radius_arcsec=20 * 3600) # Filtering using catalog's cone_search