diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
index b803c997..0c17ff7c 100644
--- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
+++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
@@ -105,10 +105,10 @@ def _append_extra_columns(cls, dataframe: pd.DataFrame, extra_columns: pd.DataFr
                 raise ValueError(f"Provided extra column '{col}' not found in definition")
         # Update columns according to crossmatch algorithm specification
         columns_to_update = []
-        for col, col_type in cls.extra_columns.items():
+        for col, col_type in cls.extra_columns.dtypes.items():
            if col not in extra_columns:
                raise ValueError(f"Missing extra column '{col} of type {col_type}'")
-            if col_type.dtype != extra_columns[col].dtype:
+            if col_type != extra_columns[col].dtype:
                raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
            columns_to_update.append(col)
        for col in columns_to_update:
diff --git a/src/lsdb/core/crossmatch/kdtree_match.py b/src/lsdb/core/crossmatch/kdtree_match.py
index 32b4329a..c25e2056 100644
--- a/src/lsdb/core/crossmatch/kdtree_match.py
+++ b/src/lsdb/core/crossmatch/kdtree_match.py
@@ -5,6 +5,7 @@
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import pyarrow as pa
 from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
 from hipscat.pixel_math.validators import validate_radius
@@ -15,7 +16,7 @@
 class KdTreeCrossmatch(AbstractCrossmatchAlgorithm):
     """Nearest neighbor crossmatch using a 3D k-D tree"""

-    extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))})
+    extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

     def validate(
         self,
@@ -102,6 +103,8 @@ def _create_crossmatch_df(
             axis=1,
         )
         out.set_index(HIPSCAT_ID_COLUMN, inplace=True)
-        extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(arc_distances, index=out.index)})
+        extra_columns = pd.DataFrame(
+            {"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()), index=out.index)}
+        )
         self._append_extra_columns(out, extra_columns)
         return out
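Note: with the k-D tree change, `_dist_arcsec` is Arrow-backed end to end, and `_append_extra_columns` now compares dtype objects with plain equality. A minimal sketch of the contract this enforces (toy values, outside any catalog):

```python
import pandas as pd
import pyarrow as pa

# The algorithm declares its extra columns with Arrow-backed dtypes
declared = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

# The frame produced at crossmatch time must carry exactly matching dtypes,
# since the validation compares dtypes directly rather than via .dtype
produced = pd.DataFrame(
    {"_dist_arcsec": pd.Series([0.8, 2.1], dtype=pd.ArrowDtype(pa.float64()))}
)
assert declared.dtypes["_dist_arcsec"] == produced["_dist_arcsec"].dtype
```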
diff --git a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
index a65278ae..4ec3ef54 100644
--- a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
+++ b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
@@ -36,6 +36,7 @@ def __init__(
         highest_order: int = 5,
         partition_size: int | None = None,
         threshold: int | None = None,
+        use_pyarrow_types: bool = True,
         **kwargs,
     ) -> None:
         """Initializes a DataframeCatalogLoader
@@ -46,6 +47,8 @@ def __init__(
             highest_order (int): The highest partition order
             partition_size (int): The desired partition size, in number of rows
             threshold (int): The maximum number of data points per pixel
+            use_pyarrow_types (bool): If True, the data is backed by pyarrow; otherwise, the
+                original data types are kept. Defaults to True.
             **kwargs: Arguments to pass to the creation of the catalog info
         """
         self.dataframe = dataframe
@@ -53,6 +56,7 @@ def __init__(
         self.highest_order = highest_order
         self.threshold = self._calculate_threshold(partition_size, threshold)
         self.catalog_info = self._create_catalog_info(**kwargs)
+        self.use_pyarrow_types = use_pyarrow_types

     def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int:
         """Calculates the number of pixels per HEALPix pixel (threshold) for the
@@ -169,9 +173,9 @@ def _generate_dask_df_and_map(
             # Obtain Dataframe for current HEALPix pixel
             pixel_dfs.append(self._get_dataframe_for_healpix(hp_pixel, pixels))

-        # Generate Dask Dataframe with original schema
+        # Generate Dask Dataframe with the original schema and desired backend
         pixel_list = list(ddf_pixel_map.keys())
-        ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list)
+        ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list, self.use_pyarrow_types)
         return ddf, ddf_pixel_map, total_rows

     def _get_dataframe_for_healpix(self, hp_pixel: HealpixPixel, pixels: List[int]) -> pd.DataFrame:
diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py
index 51075489..ae452ad8 100644
--- a/src/lsdb/loaders/dataframe/from_dataframe.py
+++ b/src/lsdb/loaders/dataframe/from_dataframe.py
@@ -15,6 +15,7 @@ def from_dataframe(
     threshold: int | None = None,
     margin_order: int | None = -1,
     margin_threshold: float = 5.0,
+    use_pyarrow_types: bool = True,
     **kwargs,
 ) -> Catalog:
     """Load a catalog from a Pandas Dataframe in CSV format.
@@ -27,6 +28,8 @@ def from_dataframe(
         threshold (int): The maximum number of data points per pixel
         margin_order (int): The order at which to generate the margin cache
         margin_threshold (float): The size of the margin cache boundary, in arcseconds
+        use_pyarrow_types (bool): If True, the data is backed by pyarrow; otherwise, the
+            original data types are kept. Defaults to True.
         **kwargs: Arguments to pass to the creation of the catalog info

     Returns:
@@ -38,6 +41,7 @@ def from_dataframe(
         highest_order,
         partition_size,
         threshold,
+        use_pyarrow_types,
         **kwargs,
     ).load_catalog()
     if margin_threshold:
@@ -45,5 +49,6 @@ def from_dataframe(
         catalog,
         margin_order,
         margin_threshold,
+        use_pyarrow_types,
     ).create_catalog()
     return catalog
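For callers, the new flag threads straight through `from_dataframe` into both the catalog and its margin cache. A hedged usage sketch (the input table and column names are made up, and extra catalog-info kwargs may be required depending on defaults):

```python
import pandas as pd
import lsdb

# Hypothetical input table with positional columns
df = pd.DataFrame({"id": [700, 701], "ra": [282.5, 283.5], "dec": [-58.5, -61.5]})

# Default: catalog (and margin) columns are backed by pyarrow types
catalog = lsdb.from_dataframe(df)

# Opt out to keep the input frame's original numpy-backed dtypes
catalog_numpy = lsdb.from_dataframe(df, use_pyarrow_types=False)
```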
""" + pixel_dfs = [_convert_dtypes_to_pyarrow(df) for df in pixel_dfs] if use_pyarrow_types else pixel_dfs schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else [] - divisions = get_pixels_divisions(pixels) delayed_dfs = [delayed(df) for df in pixel_dfs] + divisions = get_pixels_divisions(pixels) ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions) - return ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame(), len(ddf) + ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame() + return ddf, len(ddf) + + +def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: + """Transform the columns (and index) of a Pandas DataFrame to pyarrow types. + + Args: + df (pd.DataFrame): A Pandas DataFrame + + Returns: + A new DataFrame, with columns of pyarrow types. The return value is a + shallow copy of the initial DataFrame to avoid copying the data. + """ + new_series = {} + df_index = df.index.astype(pd.ArrowDtype(pa.uint64())) + for column in df.columns: + pa_array = pa.array(df[column], from_pandas=True) + series = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df_index) + new_series[column] = series + return pd.DataFrame(new_series, index=df_index, copy=False) def _append_partition_information_to_dataframe(dataframe: pd.DataFrame, pixel: HealpixPixel) -> pd.DataFrame: diff --git a/src/lsdb/loaders/dataframe/margin_catalog_generator.py b/src/lsdb/loaders/dataframe/margin_catalog_generator.py index 8abdb2c5..6a9a6af0 100644 --- a/src/lsdb/loaders/dataframe/margin_catalog_generator.py +++ b/src/lsdb/loaders/dataframe/margin_catalog_generator.py @@ -29,6 +29,7 @@ def __init__( catalog: Catalog, margin_order: int | None = -1, margin_threshold: float = 5.0, + use_pyarrow_types: bool = True, ) -> None: """Initialize a MarginCatalogGenerator @@ -36,11 +37,13 @@ def __init__( catalog (Catalog): The LSDB catalog to generate margins for margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds + use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True. """ self.dataframe = catalog.compute().copy() self.hc_structure = catalog.hc_structure self.margin_threshold = margin_threshold self.margin_order = self._set_margin_order(margin_order) + self.use_pyarrow_types = use_pyarrow_types def _set_margin_order(self, margin_order: int | None) -> int: """Calculate the order of the margin cache to be generated. 
If not provided
@@ -102,7 +105,7 @@ def _generate_dask_df_and_map(self) -> Tuple[dd.core.DataFrame, Dict[HealpixPixe
         ddf_pixel_map = {pixel: index for index, pixel in enumerate(ordered_pixels)}

         # Generate the dask dataframe with the pixels and partitions
-        ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels)
+        ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels, self.use_pyarrow_types)
         return ddf, ddf_pixel_map, total_rows

     def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame:
diff --git a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py
index 1fa25bc2..78411e71 100644
--- a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py
+++ b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py
@@ -7,7 +7,7 @@
 import dask.dataframe as dd
 import hipscat as hc
 import numpy as np
-import pyarrow
+import pandas as pd
 from hipscat.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
 from hipscat.io.file_io import file_io
 from hipscat.pixel_math import HealpixPixel
@@ -80,8 +80,7 @@ def _get_paths_from_pixels(
     def _load_df_from_paths(
         self, catalog: HCHealpixDataset, paths: List[hc.io.FilePointer], divisions: Tuple[int, ...] | None
     ) -> dd.core.DataFrame:
-        metadata_schema = self._load_parquet_metadata_schema(catalog)
-        dask_meta_schema = metadata_schema.empty_table().to_pandas()
+        dask_meta_schema = self._load_metadata_schema(catalog)
         if self.config.columns:
             dask_meta_schema = dask_meta_schema[self.config.columns]
         ddf = dd.io.from_map(
@@ -91,11 +90,16 @@
             storage_options=self.storage_options,
             meta=dask_meta_schema,
             columns=self.config.columns,
+            dtype_backend=self.config.get_dtype_backend(),
             **self.config.get_kwargs_dict(),
         )
         return ddf

-    def _load_parquet_metadata_schema(self, catalog: HCHealpixDataset) -> pyarrow.Schema:
+    def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame:
         metadata_pointer = hc.io.paths.get_common_metadata_pointer(catalog.catalog_base_dir)
         metadata = file_io.read_parquet_metadata(metadata_pointer, storage_options=self.storage_options)
-        return metadata.schema.to_arrow_schema()
+        return (
+            metadata.schema.to_arrow_schema()
+            .empty_table()
+            .to_pandas(types_mapper=self.config.get_dtype_mapper())
+        )
diff --git a/src/lsdb/loaders/hipscat/association_catalog_loader.py b/src/lsdb/loaders/hipscat/association_catalog_loader.py
index e0a1bb87..13264859 100644
--- a/src/lsdb/loaders/hipscat/association_catalog_loader.py
+++ b/src/lsdb/loaders/hipscat/association_catalog_loader.py
@@ -22,7 +22,6 @@ def load_catalog(self) -> AssociationCatalog:
         return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog)

     def _load_empty_dask_df_and_map(self, hc_catalog):
-        metadata_schema = self._load_parquet_metadata_schema(hc_catalog)
-        dask_meta_schema = metadata_schema.empty_table().to_pandas()
+        dask_meta_schema = self._load_metadata_schema(hc_catalog)
         ddf = dd.io.from_pandas(dask_meta_schema, npartitions=0)
         return ddf, {}
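`_load_metadata_schema` now returns the Dask meta frame directly, with the backend decided by the config's mapper. The underlying pyarrow pattern, sketched on a hand-built schema (field names invented for illustration):

```python
import pandas as pd
import pyarrow as pa

# Stand-in for the schema read from the catalog's _common_metadata file
schema = pa.schema([("id", pa.int64()), ("ra", pa.float64()), ("dec", pa.float64())])

# empty_table() yields a zero-row table; types_mapper picks the pandas backend
meta_arrow = schema.empty_table().to_pandas(types_mapper=pd.ArrowDtype)
meta_numpy = schema.empty_table().to_pandas()  # no mapper: plain numpy dtypes

print(meta_arrow.dtypes)  # int64[pyarrow], double[pyarrow], double[pyarrow]
print(meta_numpy.dtypes)  # int64, float64, float64
```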
diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py
index 944cc3f2..7a0653e8 100644
--- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py
+++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py
@@ -1,7 +1,11 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import List
+from typing import Callable, List
+
+import pandas as pd
+from pandas._libs import lib
+from pandas.io._util import _arrow_dtype_mapping

 from lsdb.catalog.margin_catalog import MarginCatalog
 from lsdb.core.search.abstract_search import AbstractSearch
@@ -23,9 +27,31 @@ class HipscatLoadingConfig:
     margin_cache: MarginCatalog | None = None
     """Margin cache for the catalog. By default, it is None"""

+    dtype_backend: str | None = "pyarrow"
+    """The backend data type to apply to the catalog. It defaults to "pyarrow";
+    if it is None, no type conversion is performed."""
+
     kwargs: dict | None = None
     """Extra kwargs"""

+    def __post_init__(self):
+        if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]:
+            raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'")
+
     def get_kwargs_dict(self) -> dict:
         """Returns a dictionary with the extra kwargs"""
         return self.kwargs if self.kwargs is not None else {}
+
+    def get_dtype_backend(self) -> str:
+        """Returns the data type backend, either "pyarrow" or "numpy_nullable".
+        If None was provided, returns `lib.no_default`, keeping the original types."""
+        return lib.no_default if self.dtype_backend is None else self.dtype_backend
+
+    def get_dtype_mapper(self) -> Callable | None:
+        """Returns a mapper for pyarrow or numpy types, mirroring Pandas behaviour."""
+        mapper = None
+        if self.dtype_backend == "pyarrow":
+            mapper = pd.ArrowDtype
+        elif self.dtype_backend == "numpy_nullable":
+            mapper = _arrow_dtype_mapping().get
+        return mapper
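How the three `dtype_backend` settings map an Arrow type, shown outside the config class (`_arrow_dtype_mapping` is a private pandas helper, so this mirrors rather than defines public behaviour):

```python
import pandas as pd
import pyarrow as pa
from pandas.io._util import _arrow_dtype_mapping

arrow_type = pa.int64()

# dtype_backend="pyarrow": wrap the Arrow type directly
print(pd.ArrowDtype(arrow_type))               # int64[pyarrow]

# dtype_backend="numpy_nullable": look up the nullable numpy extension dtype
print(_arrow_dtype_mapping().get(arrow_type))  # Int64

# dtype_backend=None: no mapper is applied, so the original dtypes are kept
```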
diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py
index 20d9cb37..a4dc8249 100644
--- a/src/lsdb/loaders/hipscat/read_hipscat.py
+++ b/src/lsdb/loaders/hipscat/read_hipscat.py
@@ -28,9 +28,10 @@ def read_hipscat(
     path: str,
     catalog_type: Type[Dataset] | None = None,
     search_filter: AbstractSearch | None = None,
-    storage_options: dict | None = None,
     columns: List[str] | None = None,
     margin_cache: MarginCatalog | None = None,
+    dtype_backend: str | None = "pyarrow",
+    storage_options: dict | None = None,
     **kwargs,
 ) -> Dataset:
     """Load a catalog from a HiPSCat formatted catalog.
@@ -55,9 +56,11 @@ def read_hipscat(
             type for type checking, the type of the catalog can be specified here.
             Use by specifying the lsdb class for that catalog.
         search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
-        storage_options (dict): Dictionary that contains abstract filesystem credentials
         columns (List[str]): Default `None`. The set of columns to filter the catalog on.
         margin_cache (MarginCatalog): The margin cache for the main catalog
+        dtype_backend (str): Backend data type to apply to the catalog.
+            Defaults to "pyarrow". If None, no type conversion is performed.
+        storage_options (dict): Dictionary that contains abstract filesystem credentials
         **kwargs: Arguments to pass to the pandas parquet file reader

     Returns:
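Reader-side usage then looks like this (paths are placeholders):

```python
import lsdb

# Default: Arrow-backed dtypes
catalog = lsdb.read_hipscat("path/to/catalog")

# Nullable numpy extension dtypes instead
catalog_nullable = lsdb.read_hipscat("path/to/catalog", dtype_backend="numpy_nullable")

# Keep whatever dtypes the parquet files produce natively
catalog_original = lsdb.read_hipscat("path/to/catalog", dtype_backend=None)
```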
diff --git a/tests/lsdb/catalog/test_association_catalog.py b/tests/lsdb/catalog/test_association_catalog.py
index b5c284d1..b57ba826 100644
--- a/tests/lsdb/catalog/test_association_catalog.py
+++ b/tests/lsdb/catalog/test_association_catalog.py
@@ -19,7 +19,7 @@ def test_load_association(small_sky_to_xmatch_dir):
             pixel_number=hp_pixel,
         )
         partition = small_sky_to_xmatch.get_partition(hp_order, hp_pixel)
-        data = pd.read_parquet(path)
+        data = pd.read_parquet(path, dtype_backend="pyarrow")
         pd.testing.assert_frame_equal(partition.compute(), data)
diff --git a/tests/lsdb/catalog/test_cone_search.py b/tests/lsdb/catalog/test_cone_search.py
index 38d57fc3..8351359d 100644
--- a/tests/lsdb/catalog/test_cone_search.py
+++ b/tests/lsdb/catalog/test_cone_search.py
@@ -36,9 +36,13 @@ def test_cone_search_filters_correct_points_margin(
     cone_search_catalog = small_sky_order1_source_with_margin.cone_search(ra, dec, radius)
     assert cone_search_catalog.margin is not None
     cone_search_df = cone_search_catalog.compute()
-    pd.testing.assert_frame_equal(cone_search_df, cone_search_expected, check_dtype=False)
+    pd.testing.assert_frame_equal(
+        cone_search_df, cone_search_expected, check_index_type=False, check_dtype=False
+    )
     cone_search_margin_df = cone_search_catalog.margin.compute()
-    pd.testing.assert_frame_equal(cone_search_margin_df, cone_search_margin_expected, check_dtype=False)
+    pd.testing.assert_frame_equal(
+        cone_search_margin_df, cone_search_margin_expected, check_index_type=False, check_dtype=False
+    )
     assert_divisions_are_correct(cone_search_catalog)
     assert_divisions_are_correct(cone_search_catalog.margin)
diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py
index d56dec4f..d652452d 100644
--- a/tests/lsdb/catalog/test_crossmatch.py
+++ b/tests/lsdb/catalog/test_crossmatch.py
@@ -216,7 +216,7 @@ def test_self_crossmatch(algo, small_sky_catalog, small_sky_dir):
 class MockCrossmatchAlgorithm(AbstractCrossmatchAlgorithm):
     """Mock class used to test a crossmatch algorithm"""

-    extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.dtype("float64"))})
+    extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)})

     # We must have the same signature as the crossmatch method
     def validate(self, mock_results: pd.DataFrame = None):  # pylint: disable=unused-argument
diff --git a/tests/lsdb/catalog/test_index_search.py b/tests/lsdb/catalog/test_index_search.py
index efc8026a..1fddafe5 100644
--- a/tests/lsdb/catalog/test_index_search.py
+++ b/tests/lsdb/catalog/test_index_search.py
@@ -3,17 +3,12 @@
 def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, assert_divisions_are_correct):
     catalog_index = IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir)
-
+    # Searching for an object that does not exist
     index_search_catalog = small_sky_order1_catalog.index_search([900], catalog_index)
     index_search_df = index_search_catalog.compute()
     assert len(index_search_df) == 0
     assert_divisions_are_correct(index_search_catalog)
-
-    index_search_catalog = small_sky_order1_catalog.index_search(["700"], catalog_index)
-    index_search_df = index_search_catalog.compute()
-    assert len(index_search_df) == 0
-    assert_divisions_are_correct(index_search_catalog)
-
+    # Searching for an object that exists
     index_search_catalog = small_sky_order1_catalog.index_search([700], catalog_index)
     index_search_df = index_search_catalog.compute()
     assert len(index_search_df) == 1
diff --git a/tests/lsdb/catalog/test_margin_catalog.py b/tests/lsdb/catalog/test_margin_catalog.py
index f0aa8b41..8b9716ca 100644
--- a/tests/lsdb/catalog/test_margin_catalog.py
+++ b/tests/lsdb/catalog/test_margin_catalog.py
@@ -30,7 +30,7 @@ def test_margin_catalog_partitions_correct(small_sky_xmatch_margin_dir):
             pixel_number=hp_pixel,
         )
         partition = margin.get_partition(hp_order, hp_pixel)
-        data = pd.read_parquet(path)
+        data = pd.read_parquet(path, dtype_backend="pyarrow")
         pd.testing.assert_frame_equal(partition.compute(), data)
diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py
index 8e65a675..6139ddc2 100644
--- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py
+++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py
@@ -39,9 +39,7 @@ def test_from_dataframe(small_sky_order1_df, small_sky_order1_catalog, assert_di
     assert catalog._ddf.index.name == HIPSCAT_ID_COLUMN
     # Dataframes have the same data (column data types may differ)
     pd.testing.assert_frame_equal(
-        catalog.compute().sort_index(),
-        small_sky_order1_catalog.compute().sort_index(),
-        check_dtype=False,
+        catalog.compute().sort_index(), small_sky_order1_catalog.compute().sort_index()
     )
     # Divisions belong to the respective HEALPix pixels
     assert_divisions_are_correct(catalog)
@@ -213,3 +211,21 @@ def test_from_dataframe_margin_is_empty(small_sky_order1_df):
         threshold=100,
     )
     assert catalog.margin is None
+
+
+def test_from_dataframe_with_backend(small_sky_order1_df, small_sky_order1_dir):
+    """Tests that we can initialize a catalog from a Pandas Dataframe with the desired backend"""
+    # Read the catalog from hipscat format using pyarrow, import it from a CSV using
+    # the same backend, and assert that we obtain the same catalog
+    expected_catalog = lsdb.read_hipscat(small_sky_order1_dir)
+    kwargs = get_catalog_kwargs(expected_catalog)
+    catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs)
+    assert all(isinstance(col_type, pd.ArrowDtype) for col_type in catalog.dtypes)
+    pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index())
+
+    # Test that we can also keep the original types if desired
+    expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, dtype_backend=None)
+    kwargs = get_catalog_kwargs(expected_catalog)
+    catalog = lsdb.from_dataframe(small_sky_order1_df, use_pyarrow_types=False, **kwargs)
+    assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes)
+    pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index())
diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py
index 24caa52b..a573fad9 100644
--- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py
+++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import pytest
 from hipscat.catalog.index.index_catalog import IndexCatalog
+from pandas.core.dtypes.base import ExtensionDtype

 import lsdb
 from lsdb.core.search import BoxSearch, ConeSearch, IndexSearch, OrderSearch, PolygonSearch
@@ -53,7 +54,7 @@ def test_parquet_data_in_partitions_match_files(small_sky_order1_dir, small_sky_
         parquet_path = hc.io.paths.pixel_catalog_file(
             small_sky_order1_hipscat_catalog.catalog_base_dir, hp_order, hp_pixel
         )
-        loaded_df = pd.read_parquet(parquet_path)
+        loaded_df = pd.read_parquet(parquet_path, dtype_backend="pyarrow")
         pd.testing.assert_frame_equal(partition_df, loaded_df)
@@ -148,3 +149,23 @@ def test_read_hipscat_subset_no_partitions(small_sky_order1_dir, small_sky_order
     catalog_index = IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir)
     index_search = IndexSearch([900], catalog_index)
     lsdb.read_hipscat(small_sky_order1_dir, search_filter=index_search)
+
+
+def test_read_hipscat_with_backend(small_sky_dir):
+    # By default, the schema is backed by pyarrow
+    default_catalog = lsdb.read_hipscat(small_sky_dir)
+    assert all(isinstance(col_type, pd.ArrowDtype) for col_type in default_catalog.dtypes)
+    # We can also pass it explicitly as an argument
+    catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="pyarrow")
+    assert catalog.dtypes.equals(default_catalog.dtypes)
+    # Load the data using numpy-nullable types
+    catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy_nullable")
+    assert all(isinstance(col_type, ExtensionDtype) for col_type in catalog.dtypes)
+    # The other option is to keep the original types; in this case they are numpy-backed
+    catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend=None)
+    assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes)
+
+
+def test_read_hipscat_with_invalid_backend(small_sky_dir):
+    with pytest.raises(ValueError, match="data type backend must be either"):
+        lsdb.read_hipscat(small_sky_dir, dtype_backend="abc")