Skip to content

Commit

Permalink
Allow margin path argument in read_hipscat (#328)
Browse files Browse the repository at this point in the history
* Accept margin path in read_hipscat

* Use single argument for margin cache
  • Loading branch information
camposandro authored May 23, 2024
1 parent d8096e0 commit c322de8
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 25 deletions.
3 changes: 2 additions & 1 deletion src/lsdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._version import __version__
from .catalog import Catalog, MarginCatalog
from .loaders import from_dataframe, read_hipscat
from .loaders.dataframe.from_dataframe import from_dataframe
from .loaders.hipscat.read_hipscat import read_hipscat
2 changes: 1 addition & 1 deletion src/lsdb/loaders/hipscat/abstract_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _load_df_from_paths(
dask_meta_schema = self._load_metadata_schema(catalog)
if self.config.columns:
dask_meta_schema = dask_meta_schema[self.config.columns]
kwargs = self.config.get_kwargs_dict()
kwargs = dict(self.config.kwargs)
if self.config.dtype_backend is not None:
kwargs["dtype_backend"] = self.config.dtype_backend
return dd.io.from_map(
Expand Down
25 changes: 23 additions & 2 deletions src/lsdb/loaders/hipscat/hipscat_catalog_loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import dataclasses

import hipscat as hc

from lsdb.catalog.catalog import Catalog
import lsdb
from lsdb.catalog.catalog import Catalog, MarginCatalog
from lsdb.loaders.hipscat.abstract_catalog_loader import AbstractCatalogLoader


Expand All @@ -18,7 +21,7 @@ def load_catalog(self) -> Catalog:
hc_catalog = self._load_hipscat_catalog(hc.catalog.Catalog)
filtered_hc_catalog = self._filter_hipscat_catalog(hc_catalog)
dask_df, dask_df_pixel_map = self._load_dask_df_and_map(filtered_hc_catalog)
return Catalog(dask_df, dask_df_pixel_map, filtered_hc_catalog, self.config.margin_cache)
return Catalog(dask_df, dask_df_pixel_map, filtered_hc_catalog, self._load_margin_catalog())

def _filter_hipscat_catalog(self, hc_catalog: hc.catalog.Catalog) -> hc.catalog.Catalog:
"""Filter the catalog pixels according to the spatial filter provided at loading time.
Expand All @@ -34,3 +37,21 @@ def _filter_hipscat_catalog(self, hc_catalog: hc.catalog.Catalog) -> hc.catalog.
return hc.catalog.Catalog(
catalog_info, pixels_to_load, self.path, hc_catalog.moc, self.storage_options
)

def _load_margin_catalog(self) -> MarginCatalog | None:
"""Load the margin catalog. It can be provided using a margin catalog
instance or a path to the catalog on disk."""
margin_catalog = None
if isinstance(self.config.margin_cache, MarginCatalog):
margin_catalog = self.config.margin_cache
elif isinstance(self.config.margin_cache, str):
margin_catalog = lsdb.read_hipscat(
path=self.config.margin_cache,
catalog_type=MarginCatalog,
search_filter=self.config.search_filter,
margin_cache=None,
dtype_backend=self.config.dtype_backend,
storage_options=self.storage_options,
**self.config.kwargs,
)
return margin_catalog
17 changes: 8 additions & 9 deletions src/lsdb/loaders/hipscat/hipscat_loading_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Callable, List

import pandas as pd
Expand All @@ -23,24 +23,23 @@ class HipscatLoadingConfig:
columns: List[str] | None = None
"""Columns to load from the catalog. If not specified, all columns are loaded"""

margin_cache: MarginCatalog | None = None
"""Margin cache for the catalog. By default, it is None"""
margin_cache: MarginCatalog | str | None = None
"""Margin cache for the catalog. It can be provided as a path for the margin on disk,
or as a margin object instance. By default, it is None."""

dtype_backend: str | None = "pyarrow"
"""The backend data type to apply to the catalog. It defaults to "pyarrow" and
if it is None no type conversion is performed."""

kwargs: dict | None = None
"""Extra kwargs"""
kwargs: dict = field(default_factory=dict)
"""Extra kwargs for the pandas parquet file reader"""

def __post_init__(self):
if self.margin_cache is not None and not isinstance(self.margin_cache, (MarginCatalog, str)):
raise ValueError("`margin_cache` must be of type 'MarginCatalog' or 'str'")
if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]:
raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'")

def get_kwargs_dict(self) -> dict:
"""Returns a dictionary with the extra kwargs"""
return self.kwargs if self.kwargs is not None else {}

def get_dtype_mapper(self) -> Callable | None:
"""Returns a mapper for pyarrow or numpy types, mirroring Pandas behaviour."""
mapper = None
Expand Down
12 changes: 7 additions & 5 deletions src/lsdb/loaders/hipscat/read_hipscat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lsdb.catalog.dataset.dataset import Dataset
from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.loaders.hipscat.abstract_catalog_loader import CatalogTypeVar
from lsdb.loaders.hipscat.hipscat_loader_factory import get_loader_for_type
from lsdb.loaders.hipscat.hipscat_loading_config import HipscatLoadingConfig

Expand All @@ -26,14 +27,14 @@
# pylint: disable=unused-argument
def read_hipscat(
path: str,
catalog_type: Type[Dataset] | None = None,
catalog_type: Type[CatalogTypeVar] | None = None,
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | None = None,
margin_cache: MarginCatalog | str | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
) -> Dataset | None:
) -> CatalogTypeVar | None:
"""Load a catalog from a HiPSCat formatted catalog.
Typical usage example, where we load a catalog with a subset of columns:
Expand All @@ -45,7 +46,7 @@ def read_hipscat(
path="./my_catalog_dir",
catalog_type=lsdb.Catalog,
columns=["ra","dec"],
filter=lsdb.core.search.ConeSearch(ra, dec, radius_arcsec),
search_filter=lsdb.core.search.ConeSearch(ra, dec, radius_arcsec),
)
Args:
Expand All @@ -57,7 +58,8 @@ def read_hipscat(
the lsdb class for that catalog.
search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
columns (List[str]): Default `None`. The set of columns to filter the catalog on.
margin_cache (MarginCatalog): The margin cache for the main catalog
margin_cache (MarginCatalog | str): The margin cache for the main catalog, provided as a path
on disk or as an instance of the MarginCatalog object. Defaults to None.
dtype_backend (str): Backend data type to apply to the catalog.
Defaults to "pyarrow". If None, no type conversion is performed.
storage_options (dict): Dictionary that contains abstract filesystem credentials
Expand Down
15 changes: 9 additions & 6 deletions src/lsdb/loaders/hipscat/read_hipscat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,20 @@ from lsdb.loaders.hipscat.abstract_catalog_loader import CatalogTypeVar
def read_hipscat(
path: str,
search_filter: AbstractSearch | None = None,
storage_options: dict | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | None = None,
) -> Dataset: ...
margin_cache: MarginCatalog | str | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
) -> Dataset | None: ...
@overload
def read_hipscat(
path: str,
catalog_type: Type[CatalogTypeVar],
search_filter: AbstractSearch | None = None,
storage_options: dict | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | None = None,
margin_cache: MarginCatalog | str | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
) -> CatalogTypeVar: ...
) -> CatalogTypeVar | None: ...
17 changes: 16 additions & 1 deletion tests/lsdb/loaders/hipscat/test_read_hipscat.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,20 @@ def test_read_hipscat_specify_wrong_catalog_type(small_sky_dir):
lsdb.read_hipscat(small_sky_dir, catalog_type=int)


def test_catalog_with_margin(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog):
def test_catalog_with_margin(
small_sky_xmatch_dir, small_sky_xmatch_margin_catalog, small_sky_xmatch_margin_dir
):
# Provide the margin cache catalog object
catalog = lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog)
assert isinstance(catalog, lsdb.Catalog)
assert catalog.margin is small_sky_xmatch_margin_catalog
# Provide the margin cache catalog path
catalog_2 = lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir)
assert isinstance(catalog_2, lsdb.Catalog)
# The catalogs obtained are identical
assert catalog.margin.hc_structure.catalog_info == catalog_2.margin.hc_structure.catalog_info
assert catalog.margin.get_healpix_pixels() == catalog_2.margin.get_healpix_pixels()
pd.testing.assert_frame_equal(catalog.margin.compute(), catalog_2.margin.compute())


def test_catalog_without_margin_is_none(small_sky_xmatch_dir):
Expand All @@ -84,6 +94,11 @@ def test_catalog_without_margin_is_none(small_sky_xmatch_dir):
assert catalog.margin is None


def test_catalog_with_wrong_margin_args(small_sky_xmatch_dir):
with pytest.raises(ValueError, match="must be of type"):
lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=1)


def test_read_hipscat_subset_with_cone_search(small_sky_order1_dir, small_sky_order1_catalog):
cone_search = ConeSearch(ra=0, dec=-80, radius_arcsec=20 * 3600)
# Filtering using catalog's cone_search
Expand Down

0 comments on commit c322de8

Please sign in to comment.