Load data using pyarrow types (#306)
* Load the parquet metadata using a specified backend

* Convert types to pyarrow on from_dataframe

* Add tests for invalid backend types

---------

Co-authored-by: Melissa DeLucchi <[email protected]>
camposandro and delucchi-cmu authored May 9, 2024
1 parent 5482cf5 commit 7c49a59
Showing 17 changed files with 142 additions and 36 deletions.
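For orientation, the user-facing effect of this commit is two new knobs: `use_pyarrow_types` on `from_dataframe` and `dtype_backend` on `read_hipscat`. A minimal sketch of both, assuming the usual top-level re-exports and placeholder data and paths (catalog-info kwargs omitted for brevity):

```python
import pandas as pd
import lsdb

# Placeholder input frame; real catalogs need proper ra/dec columns.
df = pd.DataFrame({"id": [700, 701], "ra": [282.5, 283.5], "dec": [-58.5, -61.5]})

# New default: columns (and the index) are converted to pyarrow-backed dtypes.
catalog = lsdb.from_dataframe(df, use_pyarrow_types=True)

# Reading from disk now goes through a configurable dtype backend as well.
catalog = lsdb.read_hipscat("path/to/catalog", dtype_backend="pyarrow")
```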
4 changes: 2 additions & 2 deletions src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
@@ -105,10 +105,10 @@ def _append_extra_columns(cls, dataframe: pd.DataFrame, extra_columns: pd.DataFr
raise ValueError(f"Provided extra column '{col}' not found in definition")
# Update columns according to crossmatch algorithm specification
columns_to_update = []
for col, col_type in cls.extra_columns.items():
for col, col_type in cls.extra_columns.dtypes.items():
if col not in extra_columns:
raise ValueError(f"Missing extra column '{col} of type {col_type}'")
if col_type.dtype != extra_columns[col].dtype:
if col_type != extra_columns[col].dtype:
raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
columns_to_update.append(col)
for col in columns_to_update:
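The validation above now reads the expected dtypes off the `extra_columns` template via `.dtypes` and compares them to the provided columns directly. A self-contained sketch of that pattern (names are illustrative, not taken from the repository):

```python
import pandas as pd
import pyarrow as pa

# Template frame: one empty Series per required extra column, carrying its expected dtype.
expected = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

# Columns produced by a (hypothetical) crossmatch implementation.
provided = pd.DataFrame({"_dist_arcsec": pd.Series([1.2, 3.4], dtype=pd.ArrowDtype(pa.float64()))})

for col, col_type in expected.dtypes.items():
    if col not in provided:
        raise ValueError(f"Missing extra column '{col}' of type '{col_type}'")
    if col_type != provided[col].dtype:
        raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
```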
7 changes: 5 additions & 2 deletions src/lsdb/core/crossmatch/kdtree_match.py
@@ -5,6 +5,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
from hipscat.pixel_math.validators import validate_radius

@@ -15,7 +16,7 @@
class KdTreeCrossmatch(AbstractCrossmatchAlgorithm):
"""Nearest neighbor crossmatch using a 3D k-D tree"""

extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))})
extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

def validate(
self,
@@ -102,6 +103,8 @@ def _create_crossmatch_df(
axis=1,
)
out.set_index(HIPSCAT_ID_COLUMN, inplace=True)
extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(arc_distances, index=out.index)})
extra_columns = pd.DataFrame(
{"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()), index=out.index)}
)
self._append_extra_columns(out, extra_columns)
return out
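The practical difference in the `_dist_arcsec` column is only its dtype; a quick before/after sketch, outside of any catalog:

```python
import numpy as np
import pandas as pd
import pyarrow as pa

arc_distances = np.array([0.12, 3.4])

before = pd.Series(arc_distances, dtype=np.dtype("float64"))
after = pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()))

print(before.dtype)  # float64
print(after.dtype)   # double[pyarrow]
```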
8 changes: 6 additions & 2 deletions src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
@@ -36,6 +36,7 @@ def __init__(
highest_order: int = 5,
partition_size: int | None = None,
threshold: int | None = None,
use_pyarrow_types: bool = True,
**kwargs,
) -> None:
"""Initializes a DataframeCatalogLoader
@@ -46,13 +47,16 @@
highest_order (int): The highest partition order
partition_size (int): The desired partition size, in number of rows
threshold (int): The maximum number of data points per pixel
use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the
original data types. Defaults to True.
**kwargs: Arguments to pass to the creation of the catalog info
"""
self.dataframe = dataframe
self.lowest_order = lowest_order
self.highest_order = highest_order
self.threshold = self._calculate_threshold(partition_size, threshold)
self.catalog_info = self._create_catalog_info(**kwargs)
self.use_pyarrow_types = use_pyarrow_types

def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int:
"""Calculates the number of pixels per HEALPix pixel (threshold) for the
@@ -169,9 +173,9 @@ def _generate_dask_df_and_map(
# Obtain Dataframe for current HEALPix pixel
pixel_dfs.append(self._get_dataframe_for_healpix(hp_pixel, pixels))

# Generate Dask Dataframe with original schema
# Generate Dask Dataframe with the original schema and desired backend
pixel_list = list(ddf_pixel_map.keys())
ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list)
ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list, self.use_pyarrow_types)
return ddf, ddf_pixel_map, total_rows

def _get_dataframe_for_healpix(self, hp_pixel: HealpixPixel, pixels: List[int]) -> pd.DataFrame:
5 changes: 5 additions & 0 deletions src/lsdb/loaders/dataframe/from_dataframe.py
@@ -15,6 +15,7 @@ def from_dataframe(
threshold: int | None = None,
margin_order: int | None = -1,
margin_threshold: float = 5.0,
use_pyarrow_types: bool = True,
**kwargs,
) -> Catalog:
"""Load a catalog from a Pandas Dataframe in CSV format.
@@ -27,6 +28,8 @@
threshold (int): The maximum number of data points per pixel
margin_order (int): The order at which to generate the margin cache
margin_threshold (float): The size of the margin cache boundary, in arcseconds
use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the
original data types. Defaults to True.
**kwargs: Arguments to pass to the creation of the catalog info
Returns:
@@ -38,12 +41,14 @@
highest_order,
partition_size,
threshold,
use_pyarrow_types,
**kwargs,
).load_catalog()
if margin_threshold:
catalog.margin = MarginCatalogGenerator(
catalog,
margin_order,
margin_threshold,
use_pyarrow_types,
).create_catalog()
return catalog
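The flag is threaded through to both the catalog loader and the margin generator, so opting out affects the margin cache too. A short sketch of the opt-out path (margin generation is skipped here by passing a falsy `margin_threshold`, relying on the `if margin_threshold:` check above; the input frame is a placeholder):

```python
import pandas as pd
import lsdb

df = pd.DataFrame({"id": [700, 701], "ra": [282.5, 283.5], "dec": [-58.5, -61.5]})

# Default behaviour: pyarrow-backed dtypes.
arrow_cat = lsdb.from_dataframe(df, margin_threshold=None)
print(arrow_cat.compute().dtypes)   # e.g. double[pyarrow] for ra/dec

# Opt out and keep the original numpy dtypes.
numpy_cat = lsdb.from_dataframe(df, margin_threshold=None, use_pyarrow_types=False)
print(numpy_cat.compute().dtypes)   # float64 for ra/dec
```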
29 changes: 26 additions & 3 deletions src/lsdb/loaders/dataframe/from_dataframe_utils.py
@@ -3,6 +3,7 @@
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pyarrow as pa
from dask import delayed
from hipscat.catalog import PartitionInfo
from hipscat.pixel_math import HealpixPixel
@@ -12,22 +13,44 @@


def _generate_dask_dataframe(
pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel]
pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel], use_pyarrow_types: bool = True
) -> Tuple[dd.core.DataFrame, int]:
"""Create the Dask Dataframe from the list of HEALPix pixel Dataframes
Args:
pixel_dfs (List[pd.DataFrame]): The list of HEALPix pixel Dataframes
pixels (List[HealpixPixel]): The list of HEALPix pixels in the catalog
use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True.
Returns:
The catalog's Dask Dataframe and its total number of rows.
"""
pixel_dfs = [_convert_dtypes_to_pyarrow(df) for df in pixel_dfs] if use_pyarrow_types else pixel_dfs
schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else []
divisions = get_pixels_divisions(pixels)
delayed_dfs = [delayed(df) for df in pixel_dfs]
divisions = get_pixels_divisions(pixels)
ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions)
return ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame(), len(ddf)
ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame()
return ddf, len(ddf)


def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame:
"""Transform the columns (and index) of a Pandas DataFrame to pyarrow types.
Args:
df (pd.DataFrame): A Pandas DataFrame
Returns:
A new DataFrame, with columns of pyarrow types. The return value is a
shallow copy of the initial DataFrame to avoid copying the data.
"""
new_series = {}
df_index = df.index.astype(pd.ArrowDtype(pa.uint64()))
for column in df.columns:
pa_array = pa.array(df[column], from_pandas=True)
series = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df_index)
new_series[column] = series
return pd.DataFrame(new_series, index=df_index, copy=False)


def _append_partition_information_to_dataframe(dataframe: pd.DataFrame, pixel: HealpixPixel) -> pd.DataFrame:
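The new `_convert_dtypes_to_pyarrow` helper converts one column at a time through pyarrow and shares an arrow-backed `uint64` index, avoiding a deep copy of the data. The core of that conversion, reduced to a standalone sketch:

```python
import numpy as np
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"ra": [282.5, np.nan]})

# Column-by-column conversion, as in _convert_dtypes_to_pyarrow above.
pa_array = pa.array(df["ra"], from_pandas=True)  # from_pandas=True maps NaN to null
ra_arrow = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df.index)

print(ra_arrow.dtype)   # double[pyarrow]
print(ra_arrow.isna())  # the NaN comes through as a proper null
```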
5 changes: 4 additions & 1 deletion src/lsdb/loaders/dataframe/margin_catalog_generator.py
@@ -29,18 +29,21 @@ def __init__(
catalog: Catalog,
margin_order: int | None = -1,
margin_threshold: float = 5.0,
use_pyarrow_types: bool = True,
) -> None:
"""Initialize a MarginCatalogGenerator
Args:
catalog (Catalog): The LSDB catalog to generate margins for
margin_order (int): The order at which to generate the margin cache
margin_threshold (float): The size of the margin cache boundary, in arcseconds
use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True.
"""
self.dataframe = catalog.compute().copy()
self.hc_structure = catalog.hc_structure
self.margin_threshold = margin_threshold
self.margin_order = self._set_margin_order(margin_order)
self.use_pyarrow_types = use_pyarrow_types

def _set_margin_order(self, margin_order: int | None) -> int:
"""Calculate the order of the margin cache to be generated. If not provided
@@ -102,7 +105,7 @@ def _generate_dask_df_and_map(self) -> Tuple[dd.core.DataFrame, Dict[HealpixPixe
ddf_pixel_map = {pixel: index for index, pixel in enumerate(ordered_pixels)}

# Generate the dask dataframe with the pixels and partitions
ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels)
ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels, self.use_pyarrow_types)
return ddf, ddf_pixel_map, total_rows

def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame:
14 changes: 9 additions & 5 deletions src/lsdb/loaders/hipscat/abstract_catalog_loader.py
@@ -7,7 +7,7 @@
import dask.dataframe as dd
import hipscat as hc
import numpy as np
import pyarrow
import pandas as pd
from hipscat.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
from hipscat.io.file_io import file_io
from hipscat.pixel_math import HealpixPixel
@@ -80,8 +80,7 @@ def _get_paths_from_pixels(
def _load_df_from_paths(
self, catalog: HCHealpixDataset, paths: List[hc.io.FilePointer], divisions: Tuple[int, ...] | None
) -> dd.core.DataFrame:
metadata_schema = self._load_parquet_metadata_schema(catalog)
dask_meta_schema = metadata_schema.empty_table().to_pandas()
dask_meta_schema = self._load_metadata_schema(catalog)
if self.config.columns:
dask_meta_schema = dask_meta_schema[self.config.columns]
ddf = dd.io.from_map(
@@ -91,11 +90,16 @@
storage_options=self.storage_options,
meta=dask_meta_schema,
columns=self.config.columns,
dtype_backend=self.config.get_dtype_backend(),
**self.config.get_kwargs_dict(),
)
return ddf

def _load_parquet_metadata_schema(self, catalog: HCHealpixDataset) -> pyarrow.Schema:
def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame:
metadata_pointer = hc.io.paths.get_common_metadata_pointer(catalog.catalog_base_dir)
metadata = file_io.read_parquet_metadata(metadata_pointer, storage_options=self.storage_options)
return metadata.schema.to_arrow_schema()
return (
metadata.schema.to_arrow_schema()
.empty_table()
.to_pandas(types_mapper=self.config.get_dtype_mapper())
)
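`_load_metadata_schema` now turns the catalog's `_common_metadata` schema straight into an empty pandas frame, applying whichever dtype mapper the loading config supplies. A sketch of that round trip, with a hand-built schema standing in for the real one:

```python
import pandas as pd
import pyarrow as pa

schema = pa.schema([("id", pa.int64()), ("ra", pa.float64()), ("dec", pa.float64())])

# dtype_backend="pyarrow" -> types_mapper=pd.ArrowDtype
meta_arrow = schema.empty_table().to_pandas(types_mapper=pd.ArrowDtype)
print(meta_arrow.dtypes)  # int64[pyarrow], double[pyarrow], double[pyarrow]

# dtype_backend=None -> no mapper, plain numpy dtypes
meta_numpy = schema.empty_table().to_pandas()
print(meta_numpy.dtypes)  # int64, float64, float64
```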
3 changes: 1 addition & 2 deletions src/lsdb/loaders/hipscat/association_catalog_loader.py
@@ -22,7 +22,6 @@ def load_catalog(self) -> AssociationCatalog:
return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog)

def _load_empty_dask_df_and_map(self, hc_catalog):
metadata_schema = self._load_parquet_metadata_schema(hc_catalog)
dask_meta_schema = metadata_schema.empty_table().to_pandas()
dask_meta_schema = self._load_metadata_schema(hc_catalog)
ddf = dd.io.from_pandas(dask_meta_schema, npartitions=0)
return ddf, {}
28 changes: 27 additions & 1 deletion src/lsdb/loaders/hipscat/hipscat_loading_config.py
@@ -1,7 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List
from typing import Callable, List

import pandas as pd
from pandas._libs import lib
from pandas.io._util import _arrow_dtype_mapping

from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.core.search.abstract_search import AbstractSearch
@@ -23,9 +27,31 @@ class HipscatLoadingConfig:
margin_cache: MarginCatalog | None = None
"""Margin cache for the catalog. By default, it is None"""

dtype_backend: str | None = "pyarrow"
"""The backend data type to apply to the catalog. It defaults to "pyarrow" and
if it is None no type conversion is performed."""

kwargs: dict | None = None
"""Extra kwargs"""

def __post_init__(self):
if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]:
raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'")

def get_kwargs_dict(self) -> dict:
"""Returns a dictionary with the extra kwargs"""
return self.kwargs if self.kwargs is not None else {}

def get_dtype_backend(self) -> str:
"""Returns the data type backend. It is either "pyarrow", "numpy_nullable",
or "<no_default>", which allows us to keep the original data types."""
return lib.no_default if self.dtype_backend is None else self.dtype_backend

def get_dtype_mapper(self) -> Callable | None:
"""Returns a mapper for pyarrow or numpy types, mirroring Pandas behaviour."""
mapper = None
if self.dtype_backend == "pyarrow":
mapper = pd.ArrowDtype
elif self.dtype_backend == "numpy_nullable":
mapper = _arrow_dtype_mapping().get
return mapper
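A sketch of how the new config fields behave when the dataclass is built directly (something `read_hipscat` normally does for you, and assuming the remaining fields keep their defaults):

```python
from lsdb.loaders.hipscat.hipscat_loading_config import HipscatLoadingConfig

config = HipscatLoadingConfig(dtype_backend="pyarrow")
print(config.get_dtype_backend())  # "pyarrow"
print(config.get_dtype_mapper())   # pandas.ArrowDtype

config = HipscatLoadingConfig(dtype_backend=None)
print(config.get_dtype_backend())  # pandas' internal no_default sentinel
print(config.get_dtype_mapper())   # None

HipscatLoadingConfig(dtype_backend="polars")  # raises ValueError in __post_init__
```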
7 changes: 5 additions & 2 deletions src/lsdb/loaders/hipscat/read_hipscat.py
@@ -28,9 +28,10 @@ def read_hipscat(
path: str,
catalog_type: Type[Dataset] | None = None,
search_filter: AbstractSearch | None = None,
storage_options: dict | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
) -> Dataset:
"""Load a catalog from a HiPSCat formatted catalog.
@@ -55,9 +56,11 @@
type for type checking, the type of the catalog can be specified here. Use by specifying
the lsdb class for that catalog.
search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
storage_options (dict): Dictionary that contains abstract filesystem credentials
columns (List[str]): Default `None`. The set of columns to filter the catalog on.
margin_cache (MarginCatalog): The margin cache for the main catalog
dtype_backend (str): Backend data type to apply to the catalog.
Defaults to "pyarrow". If None, no type conversion is performed.
storage_options (dict): Dictionary that contains abstract filesystem credentials
**kwargs: Arguments to pass to the pandas parquet file reader
Returns:
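End to end, the three accepted `dtype_backend` values on `read_hipscat` look like this (paths are placeholders):

```python
import lsdb

catalog = lsdb.read_hipscat("path/to/catalog")                                  # pyarrow dtypes (default)
catalog = lsdb.read_hipscat("path/to/catalog", dtype_backend="numpy_nullable")  # nullable numpy extension dtypes
catalog = lsdb.read_hipscat("path/to/catalog", dtype_backend=None)              # keep the reader's original dtypes

lsdb.read_hipscat("path/to/catalog", dtype_backend="polars")  # rejected with a ValueError
```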
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_association_catalog.py
@@ -19,7 +19,7 @@ def test_load_association(small_sky_to_xmatch_dir):
pixel_number=hp_pixel,
)
partition = small_sky_to_xmatch.get_partition(hp_order, hp_pixel)
data = pd.read_parquet(path)
data = pd.read_parquet(path, dtype_backend="pyarrow")
pd.testing.assert_frame_equal(partition.compute(), data)


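The expected side of these comparisons changes as well: the fixture partitions are re-read with pandas' own `dtype_backend="pyarrow"` so their dtypes match what the loader now produces. A standalone illustration with a throwaway file:

```python
import pandas as pd

pd.DataFrame({"id": [700, 701], "ra": [282.5, 283.5]}).to_parquet("partition.parquet")

default_df = pd.read_parquet("partition.parquet")
arrow_df = pd.read_parquet("partition.parquet", dtype_backend="pyarrow")

print(default_df.dtypes)  # id: int64, ra: float64
print(arrow_df.dtypes)    # id: int64[pyarrow], ra: double[pyarrow]
```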
8 changes: 6 additions & 2 deletions tests/lsdb/catalog/test_cone_search.py
@@ -36,9 +36,13 @@ def test_cone_search_filters_correct_points_margin(
cone_search_catalog = small_sky_order1_source_with_margin.cone_search(ra, dec, radius)
assert cone_search_catalog.margin is not None
cone_search_df = cone_search_catalog.compute()
pd.testing.assert_frame_equal(cone_search_df, cone_search_expected, check_dtype=False)
pd.testing.assert_frame_equal(
cone_search_df, cone_search_expected, check_index_type=False, check_dtype=False
)
cone_search_margin_df = cone_search_catalog.margin.compute()
pd.testing.assert_frame_equal(cone_search_margin_df, cone_search_margin_expected, check_dtype=False)
pd.testing.assert_frame_equal(
cone_search_margin_df, cone_search_margin_expected, check_index_type=False, check_dtype=False
)
assert_divisions_are_correct(cone_search_catalog)
assert_divisions_are_correct(cone_search_catalog.margin)

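`check_index_type=False` is relaxed here because frames computed from a pyarrow-backed catalog carry an arrow `uint64` index (see `_convert_dtypes_to_pyarrow` above) while the expected fixtures keep a numpy one; the values match, only the index dtype differs. A sketch of that mismatch:

```python
import pandas as pd
import pyarrow as pa

computed_index = pd.Index([700, 701], dtype=pd.ArrowDtype(pa.uint64()))
expected_index = pd.Index([700, 701], dtype="uint64")

print(computed_index.dtype)                           # uint64[pyarrow]
print(expected_index.dtype)                           # uint64
print(list(computed_index) == list(expected_index))   # True: same values, different backing
```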
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_crossmatch.py
@@ -216,7 +216,7 @@ def test_self_crossmatch(algo, small_sky_catalog, small_sky_dir):
class MockCrossmatchAlgorithm(AbstractCrossmatchAlgorithm):
"""Mock class used to test a crossmatch algorithm"""

extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.dtype("float64"))})
extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)})

# We must have the same signature as the crossmatch method
def validate(self, mock_results: pd.DataFrame = None): # pylint: disable=unused-argument
9 changes: 2 additions & 7 deletions tests/lsdb/catalog/test_index_search.py
@@ -3,17 +3,12 @@

def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, assert_divisions_are_correct):
catalog_index = IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir)

# Searching for an object that does not exist
index_search_catalog = small_sky_order1_catalog.index_search([900], catalog_index)
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 0
assert_divisions_are_correct(index_search_catalog)

index_search_catalog = small_sky_order1_catalog.index_search(["700"], catalog_index)
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 0
assert_divisions_are_correct(index_search_catalog)

# Searching for an object that exists
index_search_catalog = small_sky_order1_catalog.index_search([700], catalog_index)
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 1
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_margin_catalog.py
@@ -30,7 +30,7 @@ def test_margin_catalog_partitions_correct(small_sky_xmatch_margin_dir):
pixel_number=hp_pixel,
)
partition = margin.get_partition(hp_order, hp_pixel)
data = pd.read_parquet(path)
data = pd.read_parquet(path, dtype_backend="pyarrow")
pd.testing.assert_frame_equal(partition.compute(), data)

