From f16577464b5a4a039e6ad67e5dee4afe36684cfd Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Fri, 3 May 2024 11:41:20 -0400 Subject: [PATCH 01/13] Load the parquet metadata using a specified backend --- .../hipscat/abstract_catalog_loader.py | 11 +++--- .../hipscat/association_catalog_loader.py | 3 +- .../loaders/hipscat/hipscat_loading_config.py | 3 ++ src/lsdb/loaders/hipscat/read_hipscat.py | 5 ++- tests/conftest.py | 34 ++++++++++++------- .../lsdb/catalog/test_association_catalog.py | 2 +- tests/lsdb/catalog/test_catalog.py | 6 ++-- tests/lsdb/catalog/test_margin_catalog.py | 2 +- .../lsdb/loaders/hipscat/test_read_hipscat.py | 14 +++++++- 9 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py index 1fa25bc2..3b582007 100644 --- a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py +++ b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py @@ -7,7 +7,7 @@ import dask.dataframe as dd import hipscat as hc import numpy as np -import pyarrow +import pandas as pd from hipscat.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset from hipscat.io.file_io import file_io from hipscat.pixel_math import HealpixPixel @@ -80,8 +80,7 @@ def _get_paths_from_pixels( def _load_df_from_paths( self, catalog: HCHealpixDataset, paths: List[hc.io.FilePointer], divisions: Tuple[int, ...] | None ) -> dd.core.DataFrame: - metadata_schema = self._load_parquet_metadata_schema(catalog) - dask_meta_schema = metadata_schema.empty_table().to_pandas() + dask_meta_schema = self._load_metadata_schema(catalog) if self.config.columns: dask_meta_schema = dask_meta_schema[self.config.columns] ddf = dd.io.from_map( @@ -91,11 +90,13 @@ def _load_df_from_paths( storage_options=self.storage_options, meta=dask_meta_schema, columns=self.config.columns, + dtype_backend=self.config.dtype_backend, **self.config.get_kwargs_dict(), ) return ddf - def _load_parquet_metadata_schema(self, catalog: HCHealpixDataset) -> pyarrow.Schema: + def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame: metadata_pointer = hc.io.paths.get_common_metadata_pointer(catalog.catalog_base_dir) metadata = file_io.read_parquet_metadata(metadata_pointer, storage_options=self.storage_options) - return metadata.schema.to_arrow_schema() + types_mapper = pd.ArrowDtype if self.config.dtype_backend == "pyarrow" else None + return metadata.schema.to_arrow_schema().empty_table().to_pandas(types_mapper=types_mapper) diff --git a/src/lsdb/loaders/hipscat/association_catalog_loader.py b/src/lsdb/loaders/hipscat/association_catalog_loader.py index e0a1bb87..13264859 100644 --- a/src/lsdb/loaders/hipscat/association_catalog_loader.py +++ b/src/lsdb/loaders/hipscat/association_catalog_loader.py @@ -22,7 +22,6 @@ def load_catalog(self) -> AssociationCatalog: return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog) def _load_empty_dask_df_and_map(self, hc_catalog): - metadata_schema = self._load_parquet_metadata_schema(hc_catalog) - dask_meta_schema = metadata_schema.empty_table().to_pandas() + dask_meta_schema = self._load_metadata_schema(hc_catalog) ddf = dd.io.from_pandas(dask_meta_schema, npartitions=0) return ddf, {} diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py index 944cc3f2..00de5a09 100644 --- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py +++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py @@ -23,6 +23,9 @@ 
class HipscatLoadingConfig: margin_cache: MarginCatalog | None = None """Margin cache for the catalog. By default, it is None""" + dtype_backend: str = "pyarrow" + """Whether the data should be backed by numpy or pyarrow. It is either 'numpy_nullable' or 'pyarrow'""" + kwargs: dict | None = None """Extra kwargs""" diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py index 20d9cb37..8aa24c92 100644 --- a/src/lsdb/loaders/hipscat/read_hipscat.py +++ b/src/lsdb/loaders/hipscat/read_hipscat.py @@ -28,9 +28,10 @@ def read_hipscat( path: str, catalog_type: Type[Dataset] | None = None, search_filter: AbstractSearch | None = None, - storage_options: dict | None = None, columns: List[str] | None = None, margin_cache: MarginCatalog | None = None, + dtype_backend: str = "pyarrow", + storage_options: dict | None = None, **kwargs, ) -> Dataset: """Load a catalog from a HiPSCat formatted catalog. @@ -58,6 +59,8 @@ def read_hipscat( storage_options (dict): Dictionary that contains abstract filesystem credentials columns (List[str]): Default `None`. The set of columns to filter the catalog on. margin_cache (MarginCatalog): The margin cache for the main catalog + dtype_backend (str): Whether the data should be backed by numpy or pyarrow. It is either + "numpy_nullable" or "pyarrow". Defaults to "pyarrow". **kwargs: Arguments to pass to the pandas parquet file reader Returns: diff --git a/tests/conftest.py b/tests/conftest.py index c8694d21..cc63e5b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,10 @@ TEST_DIR = os.path.dirname(__file__) +def read_csv_with_pyarrow(*args, **kwargs): + return pd.read_csv(*args, **kwargs, dtype_backend="pyarrow") + + @pytest.fixture def test_data_dir(): return os.path.join(TEST_DIR, DATA_DIR_NAME) @@ -183,12 +187,14 @@ def small_sky_to_o1source_soft_catalog(small_sky_to_order1_source_soft_dir): @pytest.fixture def small_sky_order1_df(small_sky_order1_dir): - return pd.read_csv(os.path.join(small_sky_order1_dir, SMALL_SKY_ORDER1_CSV)) + return read_csv_with_pyarrow(os.path.join(small_sky_order1_dir, SMALL_SKY_ORDER1_CSV)) @pytest.fixture def small_sky_source_df(test_data_dir): - return pd.read_csv(os.path.join(test_data_dir, "raw", "small_sky_source", "small_sky_source.csv")) + return read_csv_with_pyarrow( + os.path.join(test_data_dir, "raw", "small_sky_source", "small_sky_source.csv") + ) @pytest.fixture @@ -208,42 +214,42 @@ def xmatch_expected_dir(test_data_dir): @pytest.fixture def xmatch_correct(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_FILE)) @pytest.fixture def xmatch_correct_005(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_005_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_005_FILE)) @pytest.fixture def xmatch_correct_002_005(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_002_005_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_002_005_FILE)) @pytest.fixture def xmatch_correct_05_2_3n_margin(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_05_2_3N_MARGIN_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_05_2_3N_MARGIN_FILE)) @pytest.fixture def xmatch_correct_3n_2t(xmatch_expected_dir): - return 
pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_FILE)) @pytest.fixture def xmatch_correct_3n_2t_no_margin(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NO_MARGIN_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NO_MARGIN_FILE)) @pytest.fixture def xmatch_correct_3n_2t_negative(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NEGATIVE_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NEGATIVE_FILE)) @pytest.fixture def xmatch_mock(xmatch_expected_dir): - return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_MOCK_FILE)) + return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_MOCK_FILE)) @pytest.fixture @@ -253,12 +259,16 @@ def cone_search_expected_dir(test_data_dir): @pytest.fixture def cone_search_expected(cone_search_expected_dir): - return pd.read_csv(os.path.join(cone_search_expected_dir, "catalog.csv"), index_col=HIPSCAT_ID_COLUMN) + return read_csv_with_pyarrow( + os.path.join(cone_search_expected_dir, "catalog.csv"), index_col=HIPSCAT_ID_COLUMN + ) @pytest.fixture def cone_search_margin_expected(cone_search_expected_dir): - return pd.read_csv(os.path.join(cone_search_expected_dir, "margin.csv"), index_col=HIPSCAT_ID_COLUMN) + return read_csv_with_pyarrow( + os.path.join(cone_search_expected_dir, "margin.csv"), index_col=HIPSCAT_ID_COLUMN + ) @pytest.fixture diff --git a/tests/lsdb/catalog/test_association_catalog.py b/tests/lsdb/catalog/test_association_catalog.py index b5c284d1..b57ba826 100644 --- a/tests/lsdb/catalog/test_association_catalog.py +++ b/tests/lsdb/catalog/test_association_catalog.py @@ -19,7 +19,7 @@ def test_load_association(small_sky_to_xmatch_dir): pixel_number=hp_pixel, ) partition = small_sky_to_xmatch.get_partition(hp_order, hp_pixel) - data = pd.read_parquet(path) + data = pd.read_parquet(path, dtype_backend="pyarrow") pd.testing.assert_frame_equal(partition.compute(), data) diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index 0837f7a3..c20df898 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -411,14 +411,14 @@ def test_square_bracket_columns(small_sky_order1_catalog): assert all(column_subset.columns == columns) assert isinstance(column_subset, Catalog) pd.testing.assert_frame_equal(column_subset.compute(), small_sky_order1_catalog.compute()[columns]) - assert np.all(column_subset.compute().index.values == small_sky_order1_catalog.compute().index.values) + assert column_subset.compute().index.values.equals(small_sky_order1_catalog.compute().index.values) def test_square_bracket_column(small_sky_order1_catalog): column_name = "ra" column = small_sky_order1_catalog[column_name] pd.testing.assert_series_equal(column.compute(), small_sky_order1_catalog.compute()[column_name]) - assert np.all(column.compute().index.values == small_sky_order1_catalog.compute().index.values) + assert column.compute().index.values.equals(small_sky_order1_catalog.compute().index.values) assert isinstance(column, dd.core.Series) @@ -427,7 +427,7 @@ def test_square_bracket_filter(small_sky_order1_catalog): assert isinstance(filtered_id, Catalog) ss_computed = small_sky_order1_catalog.compute() pd.testing.assert_frame_equal(filtered_id.compute(), ss_computed[ss_computed["id"] > 750]) - 
assert np.all(filtered_id.compute().index.values == ss_computed[ss_computed["id"] > 750].index.values) + assert filtered_id.compute().index.values.equals(ss_computed[ss_computed["id"] > 750].index.values) def test_map_partitions(small_sky_order1_catalog): diff --git a/tests/lsdb/catalog/test_margin_catalog.py b/tests/lsdb/catalog/test_margin_catalog.py index f0aa8b41..8b9716ca 100644 --- a/tests/lsdb/catalog/test_margin_catalog.py +++ b/tests/lsdb/catalog/test_margin_catalog.py @@ -30,7 +30,7 @@ def test_margin_catalog_partitions_correct(small_sky_xmatch_margin_dir): pixel_number=hp_pixel, ) partition = margin.get_partition(hp_order, hp_pixel) - data = pd.read_parquet(path) + data = pd.read_parquet(path, dtype_backend="pyarrow") pd.testing.assert_frame_equal(partition.compute(), data) diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py index a1ae7f82..65317692 100644 --- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py +++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py @@ -53,7 +53,7 @@ def test_parquet_data_in_partitions_match_files(small_sky_order1_dir, small_sky_ parquet_path = hc.io.paths.pixel_catalog_file( small_sky_order1_hipscat_catalog.catalog_base_dir, hp_order, hp_pixel ) - loaded_df = pd.read_parquet(parquet_path) + loaded_df = pd.read_parquet(parquet_path, dtype_backend="pyarrow") pd.testing.assert_frame_equal(partition_df, loaded_df) @@ -148,3 +148,15 @@ def test_read_hipscat_subset_no_partitions(small_sky_order1_dir, small_sky_order catalog_index = IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir) index_search = IndexSearch([900], catalog_index) lsdb.read_hipscat(small_sky_order1_dir, search_filter=index_search) + + +def test_read_hipscat_with_backend(small_sky_dir): + # By default, the schema is backed by pyarrow + default_catalog = lsdb.read_hipscat(small_sky_dir) + assert all(isinstance(col_type, pd.ArrowDtype) for col_type in default_catalog.dtypes) + # We can also pass it explicitly as an argument + catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="pyarrow") + assert catalog.dtypes.equals(default_catalog.dtypes) + # The other option is to use numpy-backed data types + catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy_nullable") + assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) From 34eb303fdfdbb6bdd2a17ca4b02a783515207f84 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Fri, 3 May 2024 18:43:52 -0400 Subject: [PATCH 02/13] Convert types to pyarrow on from_dataframe --- .../abstract_crossmatch_algorithm.py | 4 +- src/lsdb/core/crossmatch/kdtree_match.py | 7 ++- .../dataframe/dataframe_catalog_loader.py | 8 ++- src/lsdb/loaders/dataframe/from_dataframe.py | 4 ++ .../loaders/dataframe/from_dataframe_utils.py | 35 ++++++++++-- src/lsdb/loaders/hipscat/read_hipscat.py | 2 +- tests/conftest.py | 34 +++++------- tests/lsdb/catalog/test_cone_search.py | 8 ++- tests/lsdb/catalog/test_crossmatch.py | 54 +++++++++---------- tests/lsdb/catalog/test_index_search.py | 9 +--- .../loaders/dataframe/test_from_dataframe.py | 4 +- 11 files changed, 98 insertions(+), 71 deletions(-) diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py index 4119c46a..1306097e 100644 --- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py +++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py @@ -105,10 +105,10 @@ def _append_extra_columns(cls, dataframe: pd.DataFrame, extra_columns: 
pd.DataFr raise ValueError(f"Provided extra column '{col}' not found in definition") # Update columns according to crossmatch algorithm specification columns_to_update = [] - for col, col_type in cls.extra_columns.items(): + for col, col_type in cls.extra_columns.dtypes.items(): if col not in extra_columns: raise ValueError(f"Missing extra column '{col} of type {col_type}'") - if col_type.dtype != extra_columns[col].dtype: + if col_type != extra_columns[col].dtype: raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'") columns_to_update.append(col) for col in columns_to_update: diff --git a/src/lsdb/core/crossmatch/kdtree_match.py b/src/lsdb/core/crossmatch/kdtree_match.py index f32d7ea2..fbf24f5e 100644 --- a/src/lsdb/core/crossmatch/kdtree_match.py +++ b/src/lsdb/core/crossmatch/kdtree_match.py @@ -5,6 +5,7 @@ import numpy as np import numpy.typing as npt import pandas as pd +import pyarrow as pa from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN from hipscat.pixel_math.validators import validate_radius @@ -15,7 +16,7 @@ class KdTreeCrossmatch(AbstractCrossmatchAlgorithm): """Nearest neighbor crossmatch using a 3D k-D tree""" - extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))}) + extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))}) def validate( self, @@ -102,6 +103,8 @@ def _create_crossmatch_df( axis=1, ) out.set_index(HIPSCAT_ID_COLUMN, inplace=True) - extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(arc_distances, index=out.index)}) + extra_columns = pd.DataFrame( + {"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()), index=out.index)} + ) self._append_extra_columns(out, extra_columns) return out diff --git a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py index b74ccdab..89b7b27c 100644 --- a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py +++ b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py @@ -36,6 +36,7 @@ def __init__( highest_order: int = 5, partition_size: int | None = None, threshold: int | None = None, + dtype_backend: str = "pyarrow", **kwargs, ) -> None: """Initializes a DataframeCatalogLoader @@ -46,12 +47,15 @@ def __init__( highest_order (int): The highest partition order partition_size (int): The desired partition size, in number of rows threshold (int): The maximum number of data points per pixel + dtype_backend (str): Whether the data should be backed by numpy or pyarrow. + It is either "numpy_nullable" or "pyarrow". Defaults to "pyarrow". 
**kwargs: Arguments to pass to the creation of the catalog info """ self.dataframe = dataframe self.lowest_order = lowest_order self.highest_order = highest_order self.threshold = self._calculate_threshold(partition_size, threshold) + self.dtype_backend = dtype_backend self.catalog_info = self._create_catalog_info(**kwargs) def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int: @@ -169,9 +173,9 @@ def _generate_dask_df_and_map( # Obtain Dataframe for current HEALPix pixel pixel_dfs.append(self._get_dataframe_for_healpix(hp_pixel, pixels)) - # Generate Dask Dataframe with original schema + # Generate Dask Dataframe with the original schema and desired backend pixel_list = list(ddf_pixel_map.keys()) - ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list) + ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list, dtype_backend=self.dtype_backend) return ddf, ddf_pixel_map, total_rows def _get_dataframe_for_healpix(self, hp_pixel: HealpixPixel, pixels: List[int]) -> pd.DataFrame: diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index 51075489..11ec5294 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -15,6 +15,7 @@ def from_dataframe( threshold: int | None = None, margin_order: int | None = -1, margin_threshold: float = 5.0, + dtype_backend: str = "pyarrow", **kwargs, ) -> Catalog: """Load a catalog from a Pandas Dataframe in CSV format. @@ -27,6 +28,8 @@ def from_dataframe( threshold (int): The maximum number of data points per pixel margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds + dtype_backend (str): Whether the data should be backed by numpy or pyarrow. + It is either "numpy_nullable" or "pyarrow". Defaults to "pyarrow". **kwargs: Arguments to pass to the creation of the catalog info Returns: @@ -38,6 +41,7 @@ def from_dataframe( highest_order, partition_size, threshold, + dtype_backend, **kwargs, ).load_catalog() if margin_threshold: diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index d0a935ad..3412a6f7 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -3,6 +3,7 @@ import dask.dataframe as dd import numpy as np import pandas as pd +import pyarrow as pa from dask import delayed from hipscat.catalog import PartitionInfo from hipscat.pixel_math import HealpixPixel @@ -12,22 +13,50 @@ def _generate_dask_dataframe( - pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel] + pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel], dtype_backend: str = "pyarrow" ) -> Tuple[dd.core.DataFrame, int]: """Create the Dask Dataframe from the list of HEALPix pixel Dataframes Args: pixel_dfs (List[pd.DataFrame]): The list of HEALPix pixel Dataframes pixels (List[HealpixPixel]): The list of HEALPix pixels in the catalog + dtype_backend (str): The backend that handles data types Returns: The catalog's Dask Dataframe and its total number of rows. 
""" schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else [] - divisions = get_pixels_divisions(pixels) delayed_dfs = [delayed(df) for df in pixel_dfs] + divisions = get_pixels_divisions(pixels) ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions) - return ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame(), len(ddf) + ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame() + ddf = _convert_ddf_types_to_pyarrow(ddf) if dtype_backend == "pyarrow" else ddf + return ddf, len(ddf) + + +# pylint: disable=protected-access +def _convert_ddf_types_to_pyarrow(ddf: dd.DataFrame) -> dd.DataFrame: + # Convert schema types according to the backend + pyarrow_meta = _convert_dtypes_to_pyarrow(ddf._meta) + # Apply the new schema to the dask dataframe + ddf = ddf.astype(pyarrow_meta.dtypes) + # Update index data type as well, which is not handled automatically + ddf.index = ddf.index.astype(pd.ArrowDtype(pa.uint64())) + # Finally, set the new schema as the dataframe meta + ddf._meta = pyarrow_meta + return ddf + + +def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: + new_series = {} + for column in df.columns: + try: + pa_array = pa.array(df[column], from_pandas=True) + except Exception as exc: + raise ValueError(f"Could not convert column {column} to a pyarrow type") from exc + series = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df.index) + new_series[column] = series + return pd.DataFrame(new_series, index=df.index, copy=False) def _append_partition_information_to_dataframe(dataframe: pd.DataFrame, pixel: HealpixPixel) -> pd.DataFrame: diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py index 8aa24c92..f6736af8 100644 --- a/src/lsdb/loaders/hipscat/read_hipscat.py +++ b/src/lsdb/loaders/hipscat/read_hipscat.py @@ -56,11 +56,11 @@ def read_hipscat( type for type checking, the type of the catalog can be specified here. Use by specifying the lsdb class for that catalog. search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied. - storage_options (dict): Dictionary that contains abstract filesystem credentials columns (List[str]): Default `None`. The set of columns to filter the catalog on. margin_cache (MarginCatalog): The margin cache for the main catalog dtype_backend (str): Whether the data should be backed by numpy or pyarrow. It is either "numpy_nullable" or "pyarrow". Defaults to "pyarrow". 
+ storage_options (dict): Dictionary that contains abstract filesystem credentials **kwargs: Arguments to pass to the pandas parquet file reader Returns: diff --git a/tests/conftest.py b/tests/conftest.py index cc63e5b1..c8694d21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,10 +36,6 @@ TEST_DIR = os.path.dirname(__file__) -def read_csv_with_pyarrow(*args, **kwargs): - return pd.read_csv(*args, **kwargs, dtype_backend="pyarrow") - - @pytest.fixture def test_data_dir(): return os.path.join(TEST_DIR, DATA_DIR_NAME) @@ -187,14 +183,12 @@ def small_sky_to_o1source_soft_catalog(small_sky_to_order1_source_soft_dir): @pytest.fixture def small_sky_order1_df(small_sky_order1_dir): - return read_csv_with_pyarrow(os.path.join(small_sky_order1_dir, SMALL_SKY_ORDER1_CSV)) + return pd.read_csv(os.path.join(small_sky_order1_dir, SMALL_SKY_ORDER1_CSV)) @pytest.fixture def small_sky_source_df(test_data_dir): - return read_csv_with_pyarrow( - os.path.join(test_data_dir, "raw", "small_sky_source", "small_sky_source.csv") - ) + return pd.read_csv(os.path.join(test_data_dir, "raw", "small_sky_source", "small_sky_source.csv")) @pytest.fixture @@ -214,42 +208,42 @@ def xmatch_expected_dir(test_data_dir): @pytest.fixture def xmatch_correct(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_FILE)) @pytest.fixture def xmatch_correct_005(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_005_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_005_FILE)) @pytest.fixture def xmatch_correct_002_005(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_002_005_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_002_005_FILE)) @pytest.fixture def xmatch_correct_05_2_3n_margin(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_05_2_3N_MARGIN_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_05_2_3N_MARGIN_FILE)) @pytest.fixture def xmatch_correct_3n_2t(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_FILE)) @pytest.fixture def xmatch_correct_3n_2t_no_margin(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NO_MARGIN_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NO_MARGIN_FILE)) @pytest.fixture def xmatch_correct_3n_2t_negative(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NEGATIVE_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_CORRECT_3N_2T_NEGATIVE_FILE)) @pytest.fixture def xmatch_mock(xmatch_expected_dir): - return read_csv_with_pyarrow(os.path.join(xmatch_expected_dir, XMATCH_MOCK_FILE)) + return pd.read_csv(os.path.join(xmatch_expected_dir, XMATCH_MOCK_FILE)) @pytest.fixture @@ -259,16 +253,12 @@ def cone_search_expected_dir(test_data_dir): @pytest.fixture def cone_search_expected(cone_search_expected_dir): - return read_csv_with_pyarrow( - os.path.join(cone_search_expected_dir, "catalog.csv"), index_col=HIPSCAT_ID_COLUMN - ) + return pd.read_csv(os.path.join(cone_search_expected_dir, "catalog.csv"), index_col=HIPSCAT_ID_COLUMN) 
@pytest.fixture def cone_search_margin_expected(cone_search_expected_dir): - return read_csv_with_pyarrow( - os.path.join(cone_search_expected_dir, "margin.csv"), index_col=HIPSCAT_ID_COLUMN - ) + return pd.read_csv(os.path.join(cone_search_expected_dir, "margin.csv"), index_col=HIPSCAT_ID_COLUMN) @pytest.fixture diff --git a/tests/lsdb/catalog/test_cone_search.py b/tests/lsdb/catalog/test_cone_search.py index 38d57fc3..8351359d 100644 --- a/tests/lsdb/catalog/test_cone_search.py +++ b/tests/lsdb/catalog/test_cone_search.py @@ -36,9 +36,13 @@ def test_cone_search_filters_correct_points_margin( cone_search_catalog = small_sky_order1_source_with_margin.cone_search(ra, dec, radius) assert cone_search_catalog.margin is not None cone_search_df = cone_search_catalog.compute() - pd.testing.assert_frame_equal(cone_search_df, cone_search_expected, check_dtype=False) + pd.testing.assert_frame_equal( + cone_search_df, cone_search_expected, check_index_type=False, check_dtype=False + ) cone_search_margin_df = cone_search_catalog.margin.compute() - pd.testing.assert_frame_equal(cone_search_margin_df, cone_search_margin_expected, check_dtype=False) + pd.testing.assert_frame_equal( + cone_search_margin_df, cone_search_margin_expected, check_index_type=False, check_dtype=False + ) assert_divisions_are_correct(cone_search_catalog) assert_divisions_are_correct(cone_search_catalog.margin) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index ca60f7f1..45e5222a 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -22,10 +22,10 @@ def test_kdtree_crossmatch(algo, small_sky_catalog, small_sky_xmatch_catalog, xm ).compute() assert len(xmatched) == len(xmatch_correct) for _, correct_row in xmatch_correct.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"] + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_kdtree_crossmatch_thresh(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_005): @@ -38,10 +38,10 @@ def test_kdtree_crossmatch_thresh(algo, small_sky_catalog, small_sky_xmatch_cata ).compute() assert len(xmatched) == len(xmatch_correct_005) for _, correct_row in xmatch_correct_005.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"] + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_kdtree_crossmatch_multiple_neighbors( @@ -57,13 +57,13 @@ def test_kdtree_crossmatch_multiple_neighbors( ).compute() assert len(xmatched) == len(xmatch_correct_3n_2t_no_margin) for _, correct_row in xmatch_correct_3n_2t_no_margin.iterrows(): - assert correct_row["ss_id"] in 
xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[ (xmatched["id_small_sky"] == correct_row["ss_id"]) & (xmatched["id_small_sky_xmatch"] == correct_row["xmatch_id"]) ] assert len(xmatch_row) == 1 - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_kdtree_crossmatch_multiple_neighbors_margin( @@ -77,13 +77,13 @@ def test_kdtree_crossmatch_multiple_neighbors_margin( ).compute() assert len(xmatched) == len(xmatch_correct_3n_2t) for _, correct_row in xmatch_correct_3n_2t.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[ (xmatched["id_small_sky"] == correct_row["ss_id"]) & (xmatched["id_small_sky_xmatch"] == correct_row["xmatch_id"]) ] assert len(xmatch_row) == 1 - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_crossmatch_negative_margin( @@ -101,13 +101,13 @@ def test_crossmatch_negative_margin( ).compute() assert len(xmatched) == len(xmatch_correct_3n_2t_negative) for _, correct_row in xmatch_correct_3n_2t_negative.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky_left_xmatch"].values + assert correct_row["ss_id"] in xmatched["id_small_sky_left_xmatch"].to_numpy() xmatch_row = xmatched[ (xmatched["id_small_sky_left_xmatch"] == correct_row["ss_id"]) & (xmatched["id_small_sky_xmatch"] == correct_row["xmatch_id"]) ] assert len(xmatch_row) == 1 - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog): @@ -131,10 +131,10 @@ def test_kdtree_crossmatch_min_thresh( ).compute() assert len(xmatched) == len(xmatch_correct_002_005) for _, correct_row in xmatch_correct_002_005.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"] + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_kdtree_crossmatch_min_thresh_multiple_neighbors_margin( @@ -157,13 +157,13 @@ def test_kdtree_crossmatch_min_thresh_multiple_neighbors_margin( ).compute() assert len(xmatched) == len(xmatch_correct_05_2_3n_margin) for _, correct_row in xmatch_correct_05_2_3n_margin.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[ (xmatched["id_small_sky"] == correct_row["ss_id"]) & (xmatched["id_small_sky_xmatch"] == correct_row["xmatch_id"]) ] assert len(xmatch_row) == 1 - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) 
@staticmethod def test_kdtree_crossmatch_no_close_neighbors( @@ -181,10 +181,10 @@ def test_kdtree_crossmatch_no_close_neighbors( ).compute() assert len(xmatched) == len(xmatch_correct_005) for _, correct_row in xmatch_correct_005.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_dist_arcsec"].values == pytest.approx(correct_row["dist"] * 3600) + assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"] + assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod def test_crossmatch_more_neighbors_than_points_available( @@ -224,7 +224,7 @@ def test_self_crossmatch(algo, small_sky_catalog, small_sky_dir): class MockCrossmatchAlgorithm(AbstractCrossmatchAlgorithm): """Mock class used to test a crossmatch algorithm""" - extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.dtype("float64"))}) + extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)}) # We must have the same signature as the crossmatch method def validate(self, mock_results: pd.DataFrame = None): # pylint: disable=unused-argument @@ -235,15 +235,15 @@ def crossmatch(self, mock_results: pd.DataFrame = None): right_reset = self.right.reset_index(drop=True) self._rename_columns_with_suffix(self.left, self.suffixes[0]) self._rename_columns_with_suffix(self.right, self.suffixes[1]) - mock_results = mock_results[mock_results["ss_id"].isin(left_reset["id"].values)] + mock_results = mock_results[mock_results["ss_id"].isin(left_reset["id"].to_numpy())] left_indexes = mock_results.apply( lambda row: left_reset[left_reset["id"] == row["ss_id"]].index[0], axis=1 ) right_indexes = mock_results.apply( lambda row: right_reset[right_reset["id"] == row["xmatch_id"]].index[0], axis=1 ) - left_join_part = self.left.iloc[left_indexes.values].reset_index() - right_join_part = self.right.iloc[right_indexes.values].reset_index(drop=True) + left_join_part = self.left.iloc[left_indexes.to_numpy()].reset_index() + right_join_part = self.right.iloc[right_indexes.to_numpy()].reset_index(drop=True) out = pd.concat( [ left_join_part, # select the rows of the left table @@ -265,10 +265,10 @@ def test_custom_crossmatch_algorithm(small_sky_catalog, small_sky_xmatch_catalog ).compute() assert len(xmatched) == len(xmatch_mock) for _, correct_row in xmatch_mock.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values + assert correct_row["ss_id"] in xmatched["id_small_sky"].to_numpy() xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) + assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"] + assert xmatch_row["_DIST"].to_numpy() == pytest.approx(correct_row["dist"]) def test_append_extra_columns(small_sky_xmatch_catalog): diff --git a/tests/lsdb/catalog/test_index_search.py b/tests/lsdb/catalog/test_index_search.py index efc8026a..1fddafe5 100644 --- a/tests/lsdb/catalog/test_index_search.py +++ b/tests/lsdb/catalog/test_index_search.py @@ -3,17 +3,12 @@ def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, assert_divisions_are_correct): catalog_index = 
IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir) - + # Searching for an object that does not exist index_search_catalog = small_sky_order1_catalog.index_search([900], catalog_index) index_search_df = index_search_catalog.compute() assert len(index_search_df) == 0 assert_divisions_are_correct(index_search_catalog) - - index_search_catalog = small_sky_order1_catalog.index_search(["700"], catalog_index) - index_search_df = index_search_catalog.compute() - assert len(index_search_df) == 0 - assert_divisions_are_correct(index_search_catalog) - + # Searching for an object that exists index_search_catalog = small_sky_order1_catalog.index_search([700], catalog_index) index_search_df = index_search_catalog.compute() assert len(index_search_df) == 1 diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py index 8e65a675..13c11467 100644 --- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py +++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py @@ -39,9 +39,7 @@ def test_from_dataframe(small_sky_order1_df, small_sky_order1_catalog, assert_di assert catalog._ddf.index.name == HIPSCAT_ID_COLUMN # Dataframes have the same data (column data types may differ) pd.testing.assert_frame_equal( - catalog.compute().sort_index(), - small_sky_order1_catalog.compute().sort_index(), - check_dtype=False, + catalog.compute().sort_index(), small_sky_order1_catalog.compute().sort_index() ) # Divisions belong to the respective HEALPix pixels assert_divisions_are_correct(catalog) From 1ab9cf78638d5c60b89c0865475bd3596e977877 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Mon, 6 May 2024 12:09:57 -0400 Subject: [PATCH 03/13] Perform some cleaning and add more tests --- .../dataframe/dataframe_catalog_loader.py | 8 +++--- src/lsdb/loaders/dataframe/from_dataframe.py | 4 +-- .../loaders/dataframe/from_dataframe_utils.py | 27 ++++++++++++++++--- .../dataframe/margin_catalog_generator.py | 5 +++- .../hipscat/abstract_catalog_loader.py | 9 ++++--- .../loaders/hipscat/hipscat_loading_config.py | 18 ++++++++++++- src/lsdb/loaders/hipscat/read_hipscat.py | 4 +-- .../loaders/dataframe/test_from_dataframe.py | 22 +++++++++++++++ .../lsdb/loaders/hipscat/test_read_hipscat.py | 2 +- 9 files changed, 81 insertions(+), 18 deletions(-) diff --git a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py index 89b7b27c..c3e0b122 100644 --- a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py +++ b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py @@ -47,15 +47,15 @@ def __init__( highest_order (int): The highest partition order partition_size (int): The desired partition size, in number of rows threshold (int): The maximum number of data points per pixel - dtype_backend (str): Whether the data should be backed by numpy or pyarrow. - It is either "numpy_nullable" or "pyarrow". Defaults to "pyarrow". + dtype_backend (str): Whether the data should be backed by pyarrow or numpy. + It is either "pyarrow" or "numpy". Defaults to "pyarrow". 
**kwargs: Arguments to pass to the creation of the catalog info """ self.dataframe = dataframe self.lowest_order = lowest_order self.highest_order = highest_order self.threshold = self._calculate_threshold(partition_size, threshold) - self.dtype_backend = dtype_backend + self.use_pyarrow_types = dtype_backend == "pyarrow" self.catalog_info = self._create_catalog_info(**kwargs) def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int: @@ -175,7 +175,7 @@ def _generate_dask_df_and_map( # Generate Dask Dataframe with the original schema and desired backend pixel_list = list(ddf_pixel_map.keys()) - ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list, dtype_backend=self.dtype_backend) + ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list, self.use_pyarrow_types) return ddf, ddf_pixel_map, total_rows def _get_dataframe_for_healpix(self, hp_pixel: HealpixPixel, pixels: List[int]) -> pd.DataFrame: diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index 11ec5294..02a25452 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -28,8 +28,8 @@ def from_dataframe( threshold (int): The maximum number of data points per pixel margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds - dtype_backend (str): Whether the data should be backed by numpy or pyarrow. - It is either "numpy_nullable" or "pyarrow". Defaults to "pyarrow". + dtype_backend (str): Whether the data should be backed by pyarrow or numpy. + It is either "pyarrow" or "numpy". Defaults to "pyarrow". **kwargs: Arguments to pass to the creation of the catalog info Returns: diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index 3412a6f7..fc0131eb 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -13,14 +13,14 @@ def _generate_dask_dataframe( - pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel], dtype_backend: str = "pyarrow" + pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel], use_pyarrow_types: bool = True ) -> Tuple[dd.core.DataFrame, int]: """Create the Dask Dataframe from the list of HEALPix pixel Dataframes Args: pixel_dfs (List[pd.DataFrame]): The list of HEALPix pixel Dataframes pixels (List[HealpixPixel]): The list of HEALPix pixels in the catalog - dtype_backend (str): The backend that handles data types + use_pyarrow_types (bool): If True, use pyarrow types. Defaults to "True". Returns: The catalog's Dask Dataframe and its total number of rows. @@ -30,17 +30,26 @@ def _generate_dask_dataframe( divisions = get_pixels_divisions(pixels) ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions) ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame() - ddf = _convert_ddf_types_to_pyarrow(ddf) if dtype_backend == "pyarrow" else ddf + ddf = _convert_ddf_types_to_pyarrow(ddf) if use_pyarrow_types else ddf return ddf, len(ddf) # pylint: disable=protected-access def _convert_ddf_types_to_pyarrow(ddf: dd.DataFrame) -> dd.DataFrame: + """Convert a Dask DataFrame to pyarrow types. + + Args: + ddf (dd.DataFrame): A Dask DataFrame + + Returns: + A new dask DataFrame, where columns, index and schema have been + converted to use pyarrow types. 
+ """ # Convert schema types according to the backend pyarrow_meta = _convert_dtypes_to_pyarrow(ddf._meta) # Apply the new schema to the dask dataframe ddf = ddf.astype(pyarrow_meta.dtypes) - # Update index data type as well, which is not handled automatically + # Update the hipscat index data type, which is not handled automatically ddf.index = ddf.index.astype(pd.ArrowDtype(pa.uint64())) # Finally, set the new schema as the dataframe meta ddf._meta = pyarrow_meta @@ -48,6 +57,16 @@ def _convert_ddf_types_to_pyarrow(ddf: dd.DataFrame) -> dd.DataFrame: def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: + """Transform the columns of a Pandas DataFrame to pyarrow types. It + does not update the type of the DataFrame index. + + Args: + df (pd.DataFrame): A Pandas DataFrame + + Returns: + A new DataFrame, with columns of pyarrow types. The return value is a + shallow copy of the initial DataFrame to avoid copying the data. + """ new_series = {} for column in df.columns: try: diff --git a/src/lsdb/loaders/dataframe/margin_catalog_generator.py b/src/lsdb/loaders/dataframe/margin_catalog_generator.py index 90d512a1..71c4351f 100644 --- a/src/lsdb/loaders/dataframe/margin_catalog_generator.py +++ b/src/lsdb/loaders/dataframe/margin_catalog_generator.py @@ -29,6 +29,7 @@ def __init__( catalog: Catalog, margin_order: int | None = -1, margin_threshold: float = 5.0, + use_pyarrow_types: bool = True, ) -> None: """Initialize a MarginCatalogGenerator @@ -36,11 +37,13 @@ def __init__( catalog (Catalog): The LSDB catalog to generate margins for margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds + use_pyarrow_types (bool): If True, use pyarrow types. Defaults to "True". """ self.dataframe = catalog.compute().copy() self.hc_structure = catalog.hc_structure self.margin_threshold = margin_threshold self.margin_order = self._set_margin_order(margin_order) + self.use_pyarrow_types = use_pyarrow_types def _set_margin_order(self, margin_order: int | None) -> int: """Calculate the order of the margin cache to be generated. 
If not provided @@ -102,7 +105,7 @@ def _generate_dask_df_and_map(self) -> Tuple[dd.core.DataFrame, Dict[HealpixPixe ddf_pixel_map = {pixel: index for index, pixel in enumerate(ordered_pixels)} # Generate the dask dataframe with the pixels and partitions - ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels) + ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels, self.use_pyarrow_types) return ddf, ddf_pixel_map, total_rows def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame: diff --git a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py index 3b582007..42277f72 100644 --- a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py +++ b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py @@ -90,7 +90,7 @@ def _load_df_from_paths( storage_options=self.storage_options, meta=dask_meta_schema, columns=self.config.columns, - dtype_backend=self.config.dtype_backend, + dtype_backend=self.config.get_dtype_backend(), **self.config.get_kwargs_dict(), ) return ddf @@ -98,5 +98,8 @@ def _load_df_from_paths( def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame: metadata_pointer = hc.io.paths.get_common_metadata_pointer(catalog.catalog_base_dir) metadata = file_io.read_parquet_metadata(metadata_pointer, storage_options=self.storage_options) - types_mapper = pd.ArrowDtype if self.config.dtype_backend == "pyarrow" else None - return metadata.schema.to_arrow_schema().empty_table().to_pandas(types_mapper=types_mapper) + return ( + metadata.schema.to_arrow_schema() + .empty_table() + .to_pandas(types_mapper=self.config.get_pyarrow_dtype_mapper()) + ) diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py index 00de5a09..9921d25d 100644 --- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py +++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py @@ -3,6 +3,9 @@ from dataclasses import dataclass from typing import List +import pandas as pd +from pandas._libs import lib + from lsdb.catalog.margin_catalog import MarginCatalog from lsdb.core.search.abstract_search import AbstractSearch @@ -24,11 +27,24 @@ class HipscatLoadingConfig: """Margin cache for the catalog. By default, it is None""" dtype_backend: str = "pyarrow" - """Whether the data should be backed by numpy or pyarrow. It is either 'numpy_nullable' or 'pyarrow'""" + """Whether the data should be backed by pyarrow or numpy. It is either 'pyarrow' or 'numpy'""" kwargs: dict | None = None """Extra kwargs""" + def __post_init__(self): + if self.dtype_backend not in ["pyarrow", "numpy"]: + raise ValueError("The data type must be either 'pyarrow' or 'numpy'") + def get_kwargs_dict(self) -> dict: """Returns a dictionary with the extra kwargs""" return self.kwargs if self.kwargs is not None else {} + + def get_dtype_backend(self) -> str: + """Returns the data type backend. 
It is either "pyarrow" or ,
+        in case we want to keep numpy-backed types."""
+        return self.dtype_backend if self.dtype_backend == "pyarrow" else lib.no_default
+
+    def get_pyarrow_dtype_mapper(self) -> pd.ArrowDtype | None:
+        """Returns a types mapper for pyarrow"""
+        return pd.ArrowDtype if self.dtype_backend == "pyarrow" else None
diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py
index f6736af8..d341ad43 100644
--- a/src/lsdb/loaders/hipscat/read_hipscat.py
+++ b/src/lsdb/loaders/hipscat/read_hipscat.py
@@ -58,8 +58,8 @@ def read_hipscat(
         search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
         columns (List[str]): Default `None`. The set of columns to filter the catalog on.
         margin_cache (MarginCatalog): The margin cache for the main catalog
-        dtype_backend (str): Whether the data should be backed by numpy or pyarrow. It is either
-            "numpy_nullable" or "pyarrow". Defaults to "pyarrow".
+        dtype_backend (str): Whether the data should be backed by pyarrow or numpy.
+            It is either "pyarrow" or "numpy". Defaults to "pyarrow".
         storage_options (dict): Dictionary that contains abstract filesystem credentials
         **kwargs: Arguments to pass to the pandas parquet file reader

diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py
index 13c11467..baf6fe8c 100644
--- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py
+++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py
@@ -211,3 +211,25 @@ def test_from_dataframe_margin_is_empty(small_sky_order1_df):
         threshold=100,
     )
     assert catalog.margin is None
+
+
+def test_from_dataframe_with_backend(small_sky_order1_df, small_sky_order1_dir):
+    """Tests that we can initialize a catalog from a Pandas Dataframe with the desired backend"""
+    # Read the catalog from hipscat format using pyarrow, import it from a CSV using
+    # the same backend and assert that we obtain the same catalog
+    expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, dtype_backend="pyarrow")
+    kwargs = get_catalog_kwargs(expected_catalog)
+    catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs, dtype_backend="pyarrow")
+    assert all(isinstance(col_type, pd.ArrowDtype) for col_type in catalog.dtypes)
+    pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index())
+
+    # By default, from_dataframe uses the pyarrow backend
+    default_catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs)
+    pd.testing.assert_frame_equal(catalog.compute().sort_index(), default_catalog.compute().sort_index())
+
+    # Test that we can also specify a numpy backend
+    expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, dtype_backend="numpy")
+    kwargs = get_catalog_kwargs(expected_catalog)
+    catalog = lsdb.from_dataframe(small_sky_order1_df, dtype_backend="numpy", **kwargs)
+    assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes)
+    pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index())
diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py
index 65317692..edbf7375 100644
--- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py
+++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py
@@ -158,5 +158,5 @@ def test_read_hipscat_with_backend(small_sky_dir):
     catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="pyarrow")
     assert catalog.dtypes.equals(default_catalog.dtypes)
     # The other option is to use numpy-backed data types
-    catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy_nullable")
+    catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy")
     assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes)

From 6bfe75f498df484c63fef2ad14c4f923874d77d4 Mon Sep 17 00:00:00 2001
From: Sandro Campos
Date: Mon, 6 May 2024 12:25:34 -0400
Subject: [PATCH 04/13] Add tests for invalid backend types

---
 src/lsdb/loaders/dataframe/dataframe_catalog_loader.py | 6 +++++-
 src/lsdb/loaders/hipscat/hipscat_loading_config.py     | 2 +-
 tests/lsdb/loaders/dataframe/test_from_dataframe.py    | 5 +++++
 tests/lsdb/loaders/hipscat/test_read_hipscat.py        | 5 +++++
 4 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
index c3e0b122..c2590d53 100644
--- a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
+++ b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
@@ -55,9 +55,13 @@ def __init__(
         self.lowest_order = lowest_order
         self.highest_order = highest_order
         self.threshold = self._calculate_threshold(partition_size, threshold)
-        self.use_pyarrow_types = dtype_backend == "pyarrow"
         self.catalog_info = self._create_catalog_info(**kwargs)

+        if dtype_backend not in ["pyarrow", "numpy"]:
+            raise ValueError("The data type backend must be either 'pyarrow' or 'numpy'")
+
+        self.use_pyarrow_types = dtype_backend == "pyarrow"
+
     def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int:
         """Calculates the number of pixels per HEALPix pixel (threshold) for the
         desired partition size.
diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py
index 9921d25d..80203ef7 100644
--- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py
+++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py
@@ -34,7 +34,7 @@ class HipscatLoadingConfig:

     def __post_init__(self):
         if self.dtype_backend not in ["pyarrow", "numpy"]:
-            raise ValueError("The data type must be either 'pyarrow' or 'numpy'")
+            raise ValueError("The data type backend must be either 'pyarrow' or 'numpy'")

     def get_kwargs_dict(self) -> dict:
         """Returns a dictionary with the extra kwargs"""
diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py
index baf6fe8c..10119b6e 100644
--- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py
+++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py
@@ -233,3 +233,8 @@ def test_from_dataframe_with_backend(small_sky_order1_df, small_sky_order1_dir):
     catalog = lsdb.from_dataframe(small_sky_order1_df, dtype_backend="numpy", **kwargs)
     assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes)
     pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index())
+
+
+def test_from_dataframe_with_invalid_backend(small_sky_order1_dir):
+    with pytest.raises(ValueError, match="data type backend must be either"):
+        lsdb.from_dataframe(small_sky_order1_dir, dtype_backend="abc")
diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py
index edbf7375..4d7c6f67 100644
--- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py
+++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py
@@ -160,3 +160,8 @@ def test_read_hipscat_with_backend(small_sky_dir):
     # The other option is to use numpy-backed data types
     catalog = 
lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy") assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) + + +def test_read_hipscat_with_invalid_backend(small_sky_dir): + with pytest.raises(ValueError, match="data type backend must be either"): + lsdb.read_hipscat(small_sky_dir, dtype_backend="abc") From a4d9e4c0e63edcb8b8fd5300326e5ef898549c78 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Mon, 6 May 2024 13:50:23 -0400 Subject: [PATCH 05/13] Use bool flag to use pyarrow dtypes instead of str --- .../dataframe/dataframe_catalog_loader.py | 12 ++++-------- src/lsdb/loaders/dataframe/from_dataframe.py | 8 ++++---- .../loaders/hipscat/hipscat_loading_config.py | 16 ++++++---------- src/lsdb/loaders/hipscat/read_hipscat.py | 6 +++--- .../loaders/dataframe/test_from_dataframe.py | 13 ++++--------- tests/lsdb/loaders/hipscat/test_read_hipscat.py | 11 +++-------- 6 files changed, 24 insertions(+), 42 deletions(-) diff --git a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py index c2590d53..ce762b68 100644 --- a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py +++ b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py @@ -36,7 +36,7 @@ def __init__( highest_order: int = 5, partition_size: int | None = None, threshold: int | None = None, - dtype_backend: str = "pyarrow", + use_pyarrow_types: bool = True, **kwargs, ) -> None: """Initializes a DataframeCatalogLoader @@ -47,8 +47,8 @@ def __init__( highest_order (int): The highest partition order partition_size (int): The desired partition size, in number of rows threshold (int): The maximum number of data points per pixel - dtype_backend (str): Whether the data should be backed by pyarrow or numpy. - It is either "pyarrow" or "numpy". Defaults to "pyarrow". + use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the + original data types. Defaults to "pyarrow". **kwargs: Arguments to pass to the creation of the catalog info """ self.dataframe = dataframe @@ -56,11 +56,7 @@ def __init__( self.highest_order = highest_order self.threshold = self._calculate_threshold(partition_size, threshold) self.catalog_info = self._create_catalog_info(**kwargs) - - if dtype_backend not in ["pyarrow", "numpy"]: - raise ValueError("The data type backend must be either 'pyarrow' or 'numpy'") - - self.use_pyarrow_types = dtype_backend == "pyarrow" + self.use_pyarrow_types = use_pyarrow_types def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int: """Calculates the number of pixels per HEALPix pixel (threshold) for the diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index 02a25452..ff0457cc 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -15,7 +15,7 @@ def from_dataframe( threshold: int | None = None, margin_order: int | None = -1, margin_threshold: float = 5.0, - dtype_backend: str = "pyarrow", + use_pyarrow_types: bool = True, **kwargs, ) -> Catalog: """Load a catalog from a Pandas Dataframe in CSV format. @@ -28,8 +28,8 @@ def from_dataframe( threshold (int): The maximum number of data points per pixel margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds - dtype_backend (str): Whether the data should be backed by pyarrow or numpy. - It is either "pyarrow" or "numpy". 
Defaults to "pyarrow". + use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the + original data types. Defaults to "pyarrow". **kwargs: Arguments to pass to the creation of the catalog info Returns: @@ -41,7 +41,7 @@ def from_dataframe( highest_order, partition_size, threshold, - dtype_backend, + use_pyarrow_types, **kwargs, ).load_catalog() if margin_threshold: diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py index 80203ef7..03dee9f0 100644 --- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py +++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py @@ -26,25 +26,21 @@ class HipscatLoadingConfig: margin_cache: MarginCatalog | None = None """Margin cache for the catalog. By default, it is None""" - dtype_backend: str = "pyarrow" - """Whether the data should be backed by pyarrow or numpy. It is either 'pyarrow' or 'numpy'""" + use_pyarrow_types: bool = True + """Whether the data should be backed by pyarrow or not. Defaults to "pyarrow".""" kwargs: dict | None = None """Extra kwargs""" - def __post_init__(self): - if self.dtype_backend not in ["pyarrow", "numpy"]: - raise ValueError("The data type backend must be either 'pyarrow' or 'numpy'") - def get_kwargs_dict(self) -> dict: """Returns a dictionary with the extra kwargs""" return self.kwargs if self.kwargs is not None else {} def get_dtype_backend(self) -> str: """Returns the data type backend. It is either "pyarrow" or , - in case we want to keep numpy-backed types.""" - return self.dtype_backend if self.dtype_backend == "pyarrow" else lib.no_default + in case we want to keep the original types.""" + return "pyarrow" if self.use_pyarrow_types else lib.no_default def get_pyarrow_dtype_mapper(self) -> pd.ArrowDtype | None: - """Returns a types mapper for pyarrow""" - return pd.ArrowDtype if self.dtype_backend == "pyarrow" else None + """Returns a mapper for pyarrow types""" + return pd.ArrowDtype if self.use_pyarrow_types else None diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py index d341ad43..5f425a31 100644 --- a/src/lsdb/loaders/hipscat/read_hipscat.py +++ b/src/lsdb/loaders/hipscat/read_hipscat.py @@ -30,7 +30,7 @@ def read_hipscat( search_filter: AbstractSearch | None = None, columns: List[str] | None = None, margin_cache: MarginCatalog | None = None, - dtype_backend: str = "pyarrow", + use_pyarrow_types: bool = True, storage_options: dict | None = None, **kwargs, ) -> Dataset: @@ -58,8 +58,8 @@ def read_hipscat( search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied. columns (List[str]): Default `None`. The set of columns to filter the catalog on. margin_cache (MarginCatalog): The margin cache for the main catalog - dtype_backend (str): Whether the data should be backed by pyarrow or numpy. - It is either "pyarrow" or "numpy". Defaults to "pyarrow". + use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the + original data types. Defaults to "pyarrow". 
storage_options (dict): Dictionary that contains abstract filesystem credentials **kwargs: Arguments to pass to the pandas parquet file reader diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py index 10119b6e..5611adba 100644 --- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py +++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py @@ -217,9 +217,9 @@ def test_from_dataframe_with_backend(small_sky_order1_df, small_sky_order1_dir): """Tests that we can initialize a catalog from a Pandas Dataframe with the desired backend""" # Read the catalog from hipscat format using pyarrow, import it from a CSV using # the same backend and assert that we obtain the same catalog - expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, dtype_backend="pyarrow") + expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, use_pyarrow_types=True) kwargs = get_catalog_kwargs(expected_catalog) - catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs, dtype_backend="pyarrow") + catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs, use_pyarrow_types=True) assert all(isinstance(col_type, pd.ArrowDtype) for col_type in catalog.dtypes) pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index()) @@ -228,13 +228,8 @@ def test_from_dataframe_with_backend(small_sky_order1_df, small_sky_order1_dir): pd.testing.assert_frame_equal(catalog.compute().sort_index(), default_catalog.compute().sort_index()) # Test that we can also specify a numpy backend - expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, dtype_backend="numpy") + expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, use_pyarrow_types=False) kwargs = get_catalog_kwargs(expected_catalog) - catalog = lsdb.from_dataframe(small_sky_order1_df, dtype_backend="numpy", **kwargs) + catalog = lsdb.from_dataframe(small_sky_order1_df, use_pyarrow_types=False, **kwargs) assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index()) - - -def test_from_dataframe_with_invalid_backend(small_sky_order1_dir): - with pytest.raises(ValueError, match="data type backend must be either"): - lsdb.from_dataframe(small_sky_order1_dir, dtype_backend="abc") diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py index 4d7c6f67..156faf49 100644 --- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py +++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py @@ -155,13 +155,8 @@ def test_read_hipscat_with_backend(small_sky_dir): default_catalog = lsdb.read_hipscat(small_sky_dir) assert all(isinstance(col_type, pd.ArrowDtype) for col_type in default_catalog.dtypes) # We can also pass it explicitly as an argument - catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="pyarrow") + catalog = lsdb.read_hipscat(small_sky_dir, use_pyarrow_types=True) assert catalog.dtypes.equals(default_catalog.dtypes) - # The other option is to use numpy-backed data types - catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy") + # The other option is to keep the original types. In this case they are numpy-backed. 
+ catalog = lsdb.read_hipscat(small_sky_dir, use_pyarrow_types=False) assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) - - -def test_read_hipscat_with_invalid_backend(small_sky_dir): - with pytest.raises(ValueError, match="data type backend must be either"): - lsdb.read_hipscat(small_sky_dir, dtype_backend="abc") From 213ceb625b3e0dca1341fa812b75f653580bc45e Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Mon, 6 May 2024 17:43:09 -0400 Subject: [PATCH 06/13] Fix band column being cast to LargeString --- src/lsdb/loaders/dataframe/from_dataframe.py | 1 + src/lsdb/loaders/dataframe/from_dataframe_utils.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index ff0457cc..9bbad919 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -49,5 +49,6 @@ def from_dataframe( catalog, margin_order, margin_threshold, + use_pyarrow_types, ).create_catalog() return catalog diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index fc0131eb..c34ebb8b 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -73,7 +73,9 @@ def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: pa_array = pa.array(df[column], from_pandas=True) except Exception as exc: raise ValueError(f"Could not convert column {column} to a pyarrow type") from exc - series = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df.index) + # The LargeString type is not recommended. Prevent any strings from being cast to this type. + pyarrow_dtype = pa.string() if isinstance(pa_array, pa.LargeStringArray) else pa_array.type + series = pd.Series(pa_array, dtype=pd.ArrowDtype(pyarrow_dtype), copy=False, index=df.index) new_series[column] = series return pd.DataFrame(new_series, index=df.index, copy=False) From 407e6026248e4c02a2836cb474ac4c5f1c96b585 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Tue, 7 May 2024 09:43:26 -0400 Subject: [PATCH 07/13] Remove piece of untestable code --- src/lsdb/loaders/dataframe/from_dataframe_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index c34ebb8b..46c48f93 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -69,10 +69,7 @@ def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: """ new_series = {} for column in df.columns: - try: - pa_array = pa.array(df[column], from_pandas=True) - except Exception as exc: - raise ValueError(f"Could not convert column {column} to a pyarrow type") from exc + pa_array = pa.array(df[column], from_pandas=True) # The LargeString type is not recommended. Prevent any strings from being cast to this type. 
pyarrow_dtype = pa.string() if isinstance(pa_array, pa.LargeStringArray) else pa_array.type series = pd.Series(pa_array, dtype=pd.ArrowDtype(pyarrow_dtype), copy=False, index=df.index) From eef3fc80b8fcc0d857fee17139d63249a1e557bd Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Wed, 8 May 2024 14:21:29 -0400 Subject: [PATCH 08/13] Revert to using a dtype_backend flag in read_hipscat --- .../hipscat/abstract_catalog_loader.py | 2 +- .../loaders/hipscat/hipscat_loading_config.py | 29 +++++++++++++------ src/lsdb/loaders/hipscat/read_hipscat.py | 6 ++-- .../loaders/dataframe/test_from_dataframe.py | 12 +++----- .../lsdb/loaders/hipscat/test_read_hipscat.py | 8 +++-- 5 files changed, 34 insertions(+), 23 deletions(-) diff --git a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py index 42277f72..78411e71 100644 --- a/src/lsdb/loaders/hipscat/abstract_catalog_loader.py +++ b/src/lsdb/loaders/hipscat/abstract_catalog_loader.py @@ -101,5 +101,5 @@ def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame: return ( metadata.schema.to_arrow_schema() .empty_table() - .to_pandas(types_mapper=self.config.get_pyarrow_dtype_mapper()) + .to_pandas(types_mapper=self.config.get_dtype_mapper()) ) diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py index 03dee9f0..fc01cf7b 100644 --- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py +++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py @@ -5,6 +5,7 @@ import pandas as pd from pandas._libs import lib +from pandas.io._util import _arrow_dtype_mapping from lsdb.catalog.margin_catalog import MarginCatalog from lsdb.core.search.abstract_search import AbstractSearch @@ -26,21 +27,31 @@ class HipscatLoadingConfig: margin_cache: MarginCatalog | None = None """Margin cache for the catalog. By default, it is None""" - use_pyarrow_types: bool = True - """Whether the data should be backed by pyarrow or not. Defaults to "pyarrow".""" + dtype_backend: str | None = "pyarrow" + """The backend data type to apply to the catalog. It defaults to "pyarrow" and + if it is None no type conversion is performed.""" kwargs: dict | None = None """Extra kwargs""" + def __post_init__(self): + if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]: + raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'") + def get_kwargs_dict(self) -> dict: """Returns a dictionary with the extra kwargs""" return self.kwargs if self.kwargs is not None else {} def get_dtype_backend(self) -> str: - """Returns the data type backend. It is either "pyarrow" or , - in case we want to keep the original types.""" - return "pyarrow" if self.use_pyarrow_types else lib.no_default - - def get_pyarrow_dtype_mapper(self) -> pd.ArrowDtype | None: - """Returns a mapper for pyarrow types""" - return pd.ArrowDtype if self.use_pyarrow_types else None + """Returns the data type backend. 
It is either "pyarrow", "numpy_nullable", + or , in case we want to keep the original types.""" + return lib.no_default if self.dtype_backend is None else self.dtype_backend + + def get_dtype_mapper(self) -> pd.ArrowDtype | None: + """Returns a mapper for pyarrow or numpy extension types""" + mapper = None + if self.dtype_backend == "pyarrow": + mapper = pd.ArrowDtype + elif self.dtype_backend == "numpy_nullable": + mapper = _arrow_dtype_mapping().get + return mapper diff --git a/src/lsdb/loaders/hipscat/read_hipscat.py b/src/lsdb/loaders/hipscat/read_hipscat.py index 5f425a31..a4dc8249 100644 --- a/src/lsdb/loaders/hipscat/read_hipscat.py +++ b/src/lsdb/loaders/hipscat/read_hipscat.py @@ -30,7 +30,7 @@ def read_hipscat( search_filter: AbstractSearch | None = None, columns: List[str] | None = None, margin_cache: MarginCatalog | None = None, - use_pyarrow_types: bool = True, + dtype_backend: str | None = "pyarrow", storage_options: dict | None = None, **kwargs, ) -> Dataset: @@ -58,8 +58,8 @@ def read_hipscat( search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied. columns (List[str]): Default `None`. The set of columns to filter the catalog on. margin_cache (MarginCatalog): The margin cache for the main catalog - use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the - original data types. Defaults to "pyarrow". + dtype_backend (str): Backend data type to apply to the catalog. + Defaults to "pyarrow". If None, no type conversion is performed. storage_options (dict): Dictionary that contains abstract filesystem credentials **kwargs: Arguments to pass to the pandas parquet file reader diff --git a/tests/lsdb/loaders/dataframe/test_from_dataframe.py b/tests/lsdb/loaders/dataframe/test_from_dataframe.py index 5611adba..6139ddc2 100644 --- a/tests/lsdb/loaders/dataframe/test_from_dataframe.py +++ b/tests/lsdb/loaders/dataframe/test_from_dataframe.py @@ -217,18 +217,14 @@ def test_from_dataframe_with_backend(small_sky_order1_df, small_sky_order1_dir): """Tests that we can initialize a catalog from a Pandas Dataframe with the desired backend""" # Read the catalog from hipscat format using pyarrow, import it from a CSV using # the same backend and assert that we obtain the same catalog - expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, use_pyarrow_types=True) + expected_catalog = lsdb.read_hipscat(small_sky_order1_dir) kwargs = get_catalog_kwargs(expected_catalog) - catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs, use_pyarrow_types=True) + catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs) assert all(isinstance(col_type, pd.ArrowDtype) for col_type in catalog.dtypes) pd.testing.assert_frame_equal(catalog.compute().sort_index(), expected_catalog.compute().sort_index()) - # By default, from_dataframe uses the pyarrow backend - default_catalog = lsdb.from_dataframe(small_sky_order1_df, **kwargs) - pd.testing.assert_frame_equal(catalog.compute().sort_index(), default_catalog.compute().sort_index()) - - # Test that we can also specify a numpy backend - expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, use_pyarrow_types=False) + # Test that we can also keep the original types if desired + expected_catalog = lsdb.read_hipscat(small_sky_order1_dir, dtype_backend=None) kwargs = get_catalog_kwargs(expected_catalog) catalog = lsdb.from_dataframe(small_sky_order1_df, use_pyarrow_types=False, **kwargs) assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) diff --git 
a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py index 156faf49..921ecd3d 100644 --- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py +++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py @@ -4,6 +4,7 @@ import pandas as pd import pytest from hipscat.catalog.index.index_catalog import IndexCatalog +from pandas.core.dtypes.base import ExtensionDtype import lsdb from lsdb.core.search import BoxSearch, ConeSearch, IndexSearch, OrderSearch, PolygonSearch @@ -155,8 +156,11 @@ def test_read_hipscat_with_backend(small_sky_dir): default_catalog = lsdb.read_hipscat(small_sky_dir) assert all(isinstance(col_type, pd.ArrowDtype) for col_type in default_catalog.dtypes) # We can also pass it explicitly as an argument - catalog = lsdb.read_hipscat(small_sky_dir, use_pyarrow_types=True) + catalog = lsdb.read_hipscat(small_sky_dir) assert catalog.dtypes.equals(default_catalog.dtypes) + # Load data using a numpy-nullable types. + catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy_nullable") + assert all(isinstance(col_type, ExtensionDtype) for col_type in catalog.dtypes) # The other option is to keep the original types. In this case they are numpy-backed. - catalog = lsdb.read_hipscat(small_sky_dir, use_pyarrow_types=False) + catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend=None) assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) From 513e5df3055016b932a2a4972d9fc76569b90e7f Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Wed, 8 May 2024 15:41:05 -0400 Subject: [PATCH 09/13] Improve code coverage --- src/lsdb/loaders/hipscat/hipscat_loading_config.py | 8 ++++---- tests/lsdb/loaders/hipscat/test_read_hipscat.py | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/lsdb/loaders/hipscat/hipscat_loading_config.py b/src/lsdb/loaders/hipscat/hipscat_loading_config.py index fc01cf7b..7a0653e8 100644 --- a/src/lsdb/loaders/hipscat/hipscat_loading_config.py +++ b/src/lsdb/loaders/hipscat/hipscat_loading_config.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List +from typing import Callable, List import pandas as pd from pandas._libs import lib @@ -44,11 +44,11 @@ def get_kwargs_dict(self) -> dict: def get_dtype_backend(self) -> str: """Returns the data type backend. It is either "pyarrow", "numpy_nullable", - or , in case we want to keep the original types.""" + or "", which allows us to keep the original data types.""" return lib.no_default if self.dtype_backend is None else self.dtype_backend - def get_dtype_mapper(self) -> pd.ArrowDtype | None: - """Returns a mapper for pyarrow or numpy extension types""" + def get_dtype_mapper(self) -> Callable | None: + """Returns a mapper for pyarrow or numpy types, mirroring Pandas behaviour.""" mapper = None if self.dtype_backend == "pyarrow": mapper = pd.ArrowDtype diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py index 2d64673c..d0a42cac 100644 --- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py +++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py @@ -164,3 +164,8 @@ def test_read_hipscat_with_backend(small_sky_dir): # The other option is to keep the original types. In this case they are numpy-backed. 
catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend=None) assert all(isinstance(col_type, np.dtype) for col_type in catalog.dtypes) + + +def test_read_hipscat_with_invalid_backend(small_sky_dir): + with pytest.raises(ValueError, match="data type backend must be either"): + lsdb.read_hipscat(small_sky_dir, dtype_backend="abc") From af305e6cc0869d9995934315744c75a17ed3debc Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Wed, 8 May 2024 17:18:34 -0400 Subject: [PATCH 10/13] Make type conversion before creating ddf --- .../loaders/dataframe/from_dataframe_utils.py | 31 +++---------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index 05470049..014079c8 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -25,37 +25,15 @@ def _generate_dask_dataframe( Returns: The catalog's Dask Dataframe and its total number of rows. """ + pixel_dfs = [_convert_dtypes_to_pyarrow(df) for df in pixel_dfs] if use_pyarrow_types else pixel_dfs schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else [] delayed_dfs = [delayed(df) for df in pixel_dfs] divisions = get_pixels_divisions(pixels) ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions) ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame() - ddf = _convert_ddf_types_to_pyarrow(ddf) if use_pyarrow_types else ddf return ddf, len(ddf) -# pylint: disable=protected-access -def _convert_ddf_types_to_pyarrow(ddf: dd.DataFrame) -> dd.DataFrame: - """Convert a Dask DataFrame to pyarrow types. - - Args: - ddf (dd.DataFrame): A Dask DataFrame - - Returns: - A new dask DataFrame, where columns, index and schema have been - converted to use pyarrow types. - """ - # Convert schema types according to the backend - pyarrow_meta = _convert_dtypes_to_pyarrow(ddf._meta) - # Apply the new schema to the dask dataframe - ddf = ddf.astype(pyarrow_meta.dtypes) - # Update the hipscat index data type, which is not handled automatically - ddf.index = ddf.index.astype(pd.ArrowDtype(pa.uint64())) - # Finally, set the new schema as the dataframe meta - ddf._meta = pyarrow_meta - return ddf - - def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: """Transform the columns of a Pandas DataFrame to pyarrow types. It does not update the type of the DataFrame index. @@ -68,13 +46,12 @@ def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: shallow copy of the initial DataFrame to avoid copying the data. """ new_series = {} + df_index = df.index.astype(pd.ArrowDtype(pa.uint64())) for column in df.columns: pa_array = pa.array(df[column], from_pandas=True) - # The LargeString type is not recommended. Prevent any strings from being cast to this type. 
- pyarrow_dtype = pa.string() if isinstance(pa_array, pa.LargeStringArray) else pa_array.type - series = pd.Series(pa_array, dtype=pd.ArrowDtype(pyarrow_dtype), copy=False, index=df.index) + series = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df_index) new_series[column] = series - return pd.DataFrame(new_series, index=df.index, copy=False) + return pd.DataFrame(new_series, index=df_index, copy=False) def _append_partition_information_to_dataframe(dataframe: pd.DataFrame, pixel: HealpixPixel) -> pd.DataFrame: From 106e8c3f7d3668771f9642076aff46436c4a3cf0 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Wed, 8 May 2024 17:22:05 -0400 Subject: [PATCH 11/13] Update docstring --- src/lsdb/loaders/dataframe/from_dataframe_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index 014079c8..0bbf5b8c 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -35,8 +35,7 @@ def _generate_dask_dataframe( def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame: - """Transform the columns of a Pandas DataFrame to pyarrow types. It - does not update the type of the DataFrame index. + """Transform the columns (and index) of a Pandas DataFrame to pyarrow types. Args: df (pd.DataFrame): A Pandas DataFrame From 906ec292667891b3b854c0cb83f9d7e3cbaf5260 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Thu, 9 May 2024 13:30:14 -0400 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: Melissa DeLucchi <113376043+delucchi-cmu@users.noreply.github.com> --- src/lsdb/loaders/dataframe/dataframe_catalog_loader.py | 2 +- src/lsdb/loaders/dataframe/from_dataframe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py index f8f33207..4ec3ef54 100644 --- a/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py +++ b/src/lsdb/loaders/dataframe/dataframe_catalog_loader.py @@ -48,7 +48,7 @@ def __init__( partition_size (int): The desired partition size, in number of rows threshold (int): The maximum number of data points per pixel use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the - original data types. Defaults to "pyarrow". + original data types. Defaults to True. **kwargs: Arguments to pass to the creation of the catalog info """ self.dataframe = dataframe diff --git a/src/lsdb/loaders/dataframe/from_dataframe.py b/src/lsdb/loaders/dataframe/from_dataframe.py index 9bbad919..ae452ad8 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe.py +++ b/src/lsdb/loaders/dataframe/from_dataframe.py @@ -29,7 +29,7 @@ def from_dataframe( margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the - original data types. Defaults to "pyarrow". + original data types. Defaults to True. 
**kwargs: Arguments to pass to the creation of the catalog info Returns: From 5cd240f343775849d322c21ebbddf89e3dc4b180 Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Thu, 9 May 2024 14:46:34 -0400 Subject: [PATCH 13/13] Add missing argument in test --- src/lsdb/loaders/dataframe/from_dataframe_utils.py | 2 +- src/lsdb/loaders/dataframe/margin_catalog_generator.py | 2 +- tests/lsdb/loaders/hipscat/test_read_hipscat.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lsdb/loaders/dataframe/from_dataframe_utils.py b/src/lsdb/loaders/dataframe/from_dataframe_utils.py index 0bbf5b8c..c58a5733 100644 --- a/src/lsdb/loaders/dataframe/from_dataframe_utils.py +++ b/src/lsdb/loaders/dataframe/from_dataframe_utils.py @@ -20,7 +20,7 @@ def _generate_dask_dataframe( Args: pixel_dfs (List[pd.DataFrame]): The list of HEALPix pixel Dataframes pixels (List[HealpixPixel]): The list of HEALPix pixels in the catalog - use_pyarrow_types (bool): If True, use pyarrow types. Defaults to "True". + use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True. Returns: The catalog's Dask Dataframe and its total number of rows. diff --git a/src/lsdb/loaders/dataframe/margin_catalog_generator.py b/src/lsdb/loaders/dataframe/margin_catalog_generator.py index 77f5299e..6a9a6af0 100644 --- a/src/lsdb/loaders/dataframe/margin_catalog_generator.py +++ b/src/lsdb/loaders/dataframe/margin_catalog_generator.py @@ -37,7 +37,7 @@ def __init__( catalog (Catalog): The LSDB catalog to generate margins for margin_order (int): The order at which to generate the margin cache margin_threshold (float): The size of the margin cache boundary, in arcseconds - use_pyarrow_types (bool): If True, use pyarrow types. Defaults to "True". + use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True. """ self.dataframe = catalog.compute().copy() self.hc_structure = catalog.hc_structure diff --git a/tests/lsdb/loaders/hipscat/test_read_hipscat.py b/tests/lsdb/loaders/hipscat/test_read_hipscat.py index d0a42cac..a573fad9 100644 --- a/tests/lsdb/loaders/hipscat/test_read_hipscat.py +++ b/tests/lsdb/loaders/hipscat/test_read_hipscat.py @@ -156,7 +156,7 @@ def test_read_hipscat_with_backend(small_sky_dir): default_catalog = lsdb.read_hipscat(small_sky_dir) assert all(isinstance(col_type, pd.ArrowDtype) for col_type in default_catalog.dtypes) # We can also pass it explicitly as an argument - catalog = lsdb.read_hipscat(small_sky_dir) + catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="pyarrow") assert catalog.dtypes.equals(default_catalog.dtypes) # Load data using a numpy-nullable types. catalog = lsdb.read_hipscat(small_sky_dir, dtype_backend="numpy_nullable")
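
Taken together, the series leaves two user-facing knobs for column data types: read_hipscat takes a dtype_backend of "pyarrow" (the default), "numpy_nullable", or None to keep the on-disk types, with any other value raising a ValueError, while from_dataframe takes a boolean use_pyarrow_types flag. Below is a minimal usage sketch of those options as they stand after the last patch; the catalog path, the CSV file, and the catalog_name keyword are illustrative assumptions, not values taken from the patches.

import numpy as np
import pandas as pd

import lsdb

catalog_path = "data/small_sky_order1"  # assumed path to a local HiPSCat catalog

# Default behaviour: columns come back as pyarrow-backed extension types.
catalog = lsdb.read_hipscat(catalog_path)
assert all(isinstance(dtype, pd.ArrowDtype) for dtype in catalog.dtypes)

# Pandas nullable extension types instead of pyarrow.
nullable_catalog = lsdb.read_hipscat(catalog_path, dtype_backend="numpy_nullable")

# None skips the conversion and keeps the original (numpy-backed) parquet types.
plain_catalog = lsdb.read_hipscat(catalog_path, dtype_backend=None)
assert all(isinstance(dtype, np.dtype) for dtype in plain_catalog.dtypes)

# Importing from an in-memory DataFrame uses the boolean flag instead;
# catalog_name is an assumed catalog-info kwarg and the CSV is assumed to have ra/dec columns.
df = pd.read_csv("data/small_sky_order1.csv")
from_df_catalog = lsdb.from_dataframe(df, catalog_name="small_sky_order1", use_pyarrow_types=False)

Keeping read_hipscat on a string-or-None dtype_backend mirrors the pandas parquet reader, while from_dataframe only needs the boolean because the input frame already carries its own dtypes.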