Load data using pyarrow types #306

Merged: 15 commits, May 9, 2024
4 changes: 2 additions & 2 deletions src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
@@ -105,10 +105,10 @@ def _append_extra_columns(cls, dataframe: pd.DataFrame, extra_columns: pd.DataFrame
raise ValueError(f"Provided extra column '{col}' not found in definition")
# Update columns according to crossmatch algorithm specification
columns_to_update = []
for col, col_type in cls.extra_columns.items():
for col, col_type in cls.extra_columns.dtypes.items():
if col not in extra_columns:
raise ValueError(f"Missing extra column '{col} of type {col_type}'")
if col_type.dtype != extra_columns[col].dtype:
if col_type != extra_columns[col].dtype:
raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
columns_to_update.append(col)
for col in columns_to_update:
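For context, the check above now iterates over the dtypes of the `extra_columns` template and compares each against the dtype of the column actually produced by the crossmatch, which works for numpy- and pyarrow-backed columns alike. A minimal standalone sketch of that comparison (the column name and values are illustrative, not part of this diff):

import pandas as pd
import pyarrow as pa

# Template declaring the expected extra columns, mirroring the class-level definition.
template = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

# Candidate frame produced by a crossmatch, using the same arrow-backed dtype.
candidate = pd.DataFrame(
    {"_dist_arcsec": pd.Series([0.1, 0.5, 1.2], dtype=pd.ArrowDtype(pa.float64()))}
)

for col, col_type in template.dtypes.items():
    if col not in candidate:
        raise ValueError(f"Missing extra column '{col}' of type '{col_type}'")
    if col_type != candidate[col].dtype:
        raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
print("extra columns validated")  # reached only when names and dtypes match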
7 changes: 5 additions & 2 deletions src/lsdb/core/crossmatch/kdtree_match.py
@@ -5,6 +5,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
from hipscat.pixel_math.validators import validate_radius

@@ -15,7 +16,7 @@
class KdTreeCrossmatch(AbstractCrossmatchAlgorithm):
"""Nearest neighbor crossmatch using a 3D k-D tree"""

extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))})
extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

def validate(
self,
@@ -102,6 +103,8 @@ def _create_crossmatch_df(
axis=1,
)
out.set_index(HIPSCAT_ID_COLUMN, inplace=True)
extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(arc_distances, index=out.index)})
extra_columns = pd.DataFrame(
{"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()), index=out.index)}
)
self._append_extra_columns(out, extra_columns)
return out
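For illustration, the new arrow-backed distance column on its own looks roughly like this; the values are made up and the index name simply mirrors `HIPSCAT_ID_COLUMN`:

import numpy as np
import pandas as pd
import pyarrow as pa

arc_distances = np.array([0.8, 2.3, 4.1])  # arcseconds, illustrative values
index = pd.Index([101, 102, 103], name="_hipscat_index")

dist = pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()), index=index)
print(dist.dtype)  # double[pyarrow]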
8 changes: 6 additions & 2 deletions src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
@@ -36,6 +36,7 @@ def __init__(
highest_order: int = 5,
partition_size: int | None = None,
threshold: int | None = None,
use_pyarrow_types: bool = True,
**kwargs,
) -> None:
"""Initializes a DataframeCatalogLoader
@@ -46,13 +47,16 @@
highest_order (int): The highest partition order
partition_size (int): The desired partition size, in number of rows
threshold (int): The maximum number of data points per pixel
use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the
original data types. Defaults to True.
**kwargs: Arguments to pass to the creation of the catalog info
"""
self.dataframe = dataframe
self.lowest_order = lowest_order
self.highest_order = highest_order
self.threshold = self._calculate_threshold(partition_size, threshold)
self.catalog_info = self._create_catalog_info(**kwargs)
self.use_pyarrow_types = use_pyarrow_types

def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int:
"""Calculates the number of pixels per HEALPix pixel (threshold) for the
@@ -169,9 +173,9 @@ def _generate_dask_df_and_map(
# Obtain Dataframe for current HEALPix pixel
pixel_dfs.append(self._get_dataframe_for_healpix(hp_pixel, pixels))

# Generate Dask Dataframe with original schema
# Generate Dask Dataframe with the original schema and desired backend
pixel_list = list(ddf_pixel_map.keys())
ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list)
ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list, self.use_pyarrow_types)
return ddf, ddf_pixel_map, total_rows

def _get_dataframe_for_healpix(self, hp_pixel: HealpixPixel, pixels: List[int]) -> pd.DataFrame:
5 changes: 5 additions & 0 deletions src/lsdb/loaders/dataframe/from_dataframe.py
@@ -15,6 +15,7 @@ def from_dataframe(
threshold: int | None = None,
margin_order: int | None = -1,
margin_threshold: float = 5.0,
use_pyarrow_types: bool = True,
**kwargs,
) -> Catalog:
"""Load a catalog from a Pandas Dataframe in CSV format.
@@ -27,6 +28,8 @@
threshold (int): The maximum number of data points per pixel
margin_order (int): The order at which to generate the margin cache
margin_threshold (float): The size of the margin cache boundary, in arcseconds
use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the
original data types. Defaults to True.
**kwargs: Arguments to pass to the creation of the catalog info

Returns:
@@ -38,12 +41,14 @@
highest_order,
partition_size,
threshold,
use_pyarrow_types,
**kwargs,
).load_catalog()
if margin_threshold:
catalog.margin = MarginCatalogGenerator(
catalog,
margin_order,
margin_threshold,
use_pyarrow_types,
).create_catalog()
return catalog
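A hedged usage sketch of the new flag on the public `from_dataframe` entry point; the toy dataframe and the catalog kwargs are illustrative rather than taken from the repository:

import pandas as pd
import lsdb

df = pd.DataFrame({"id": [1, 2, 3], "ra": [10.0, 10.5, 11.0], "dec": [-45.0, -44.5, -44.0]})

# Default behaviour: columns end up with pyarrow-backed dtypes.
# margin_threshold=None skips margin cache generation for this tiny example.
catalog = lsdb.from_dataframe(df, catalog_name="toy", catalog_type="object", margin_threshold=None)
print(catalog.compute().dtypes)  # e.g. int64[pyarrow], double[pyarrow], ...

# Opt out to keep the original numpy dtypes.
catalog_np = lsdb.from_dataframe(
    df, catalog_name="toy", catalog_type="object", margin_threshold=None, use_pyarrow_types=False
)
print(catalog_np.compute().dtypes)  # int64, float64, ...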
29 changes: 26 additions & 3 deletions src/lsdb/loaders/dataframe/from_dataframe_utils.py
@@ -3,6 +3,7 @@
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pyarrow as pa
from dask import delayed
from hipscat.catalog import PartitionInfo
from hipscat.pixel_math import HealpixPixel
@@ -12,22 +13,44 @@


def _generate_dask_dataframe(
pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel]
pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel], use_pyarrow_types: bool = True
) -> Tuple[dd.core.DataFrame, int]:
"""Create the Dask Dataframe from the list of HEALPix pixel Dataframes

Args:
pixel_dfs (List[pd.DataFrame]): The list of HEALPix pixel Dataframes
pixels (List[HealpixPixel]): The list of HEALPix pixels in the catalog
use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True.

Returns:
The catalog's Dask Dataframe and its total number of rows.
"""
pixel_dfs = [_convert_dtypes_to_pyarrow(df) for df in pixel_dfs] if use_pyarrow_types else pixel_dfs
schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else []
divisions = get_pixels_divisions(pixels)
delayed_dfs = [delayed(df) for df in pixel_dfs]
divisions = get_pixels_divisions(pixels)
ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions)
return ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame(), len(ddf)
ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame()
return ddf, len(ddf)


def _convert_dtypes_to_pyarrow(df: pd.DataFrame) -> pd.DataFrame:
"""Transform the columns (and index) of a Pandas DataFrame to pyarrow types.

Args:
df (pd.DataFrame): A Pandas DataFrame

Returns:
A new DataFrame, with columns of pyarrow types. The return value is a
shallow copy of the initial DataFrame to avoid copying the data.
"""
new_series = {}
df_index = df.index.astype(pd.ArrowDtype(pa.uint64()))
for column in df.columns:
pa_array = pa.array(df[column], from_pandas=True)
series = pd.Series(pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=df_index)
new_series[column] = series
return pd.DataFrame(new_series, index=df_index, copy=False)


def _append_partition_information_to_dataframe(dataframe: pd.DataFrame, pixel: HealpixPixel) -> pd.DataFrame:
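A standalone sketch of the conversion `_convert_dtypes_to_pyarrow` performs, showing the resulting dtypes on a toy frame (the frame itself is illustrative, not part of this diff):

import pandas as pd
import pyarrow as pa

df = pd.DataFrame(
    {"ra": [10.0, 10.5], "id": [1, 2], "band": ["g", "r"]},
    index=pd.Index([100, 200], name="_hipscat_index"),
)

# Same recipe as the helper above: arrow-backed columns and index, built as a
# shallow copy so the underlying buffers are not duplicated.
arrow_index = df.index.astype(pd.ArrowDtype(pa.uint64()))
arrow_series = {}
for column in df.columns:
    pa_array = pa.array(df[column], from_pandas=True)
    arrow_series[column] = pd.Series(
        pa_array, dtype=pd.ArrowDtype(pa_array.type), copy=False, index=arrow_index
    )
converted = pd.DataFrame(arrow_series, index=arrow_index, copy=False)

print(converted.dtypes)       # ra: double[pyarrow], id: int64[pyarrow], band: string[pyarrow]
print(converted.index.dtype)  # uint64[pyarrow]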
5 changes: 4 additions & 1 deletion src/lsdb/loaders/dataframe/margin_catalog_generator.py
@@ -29,18 +29,21 @@ def __init__(
catalog: Catalog,
margin_order: int | None = -1,
margin_threshold: float = 5.0,
use_pyarrow_types: bool = True,
) -> None:
"""Initialize a MarginCatalogGenerator

Args:
catalog (Catalog): The LSDB catalog to generate margins for
margin_order (int): The order at which to generate the margin cache
margin_threshold (float): The size of the margin cache boundary, in arcseconds
use_pyarrow_types (bool): If True, use pyarrow types. Defaults to True.
"""
self.dataframe = catalog.compute().copy()
self.hc_structure = catalog.hc_structure
self.margin_threshold = margin_threshold
self.margin_order = self._set_margin_order(margin_order)
self.use_pyarrow_types = use_pyarrow_types

def _set_margin_order(self, margin_order: int | None) -> int:
"""Calculate the order of the margin cache to be generated. If not provided
@@ -102,7 +105,7 @@ def _generate_dask_df_and_map(self) -> Tuple[dd.core.DataFrame, Dict[HealpixPixel
ddf_pixel_map = {pixel: index for index, pixel in enumerate(ordered_pixels)}

# Generate the dask dataframe with the pixels and partitions
ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels)
ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels, self.use_pyarrow_types)
return ddf, ddf_pixel_map, total_rows

def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame:
14 changes: 9 additions & 5 deletions src/lsdb/loaders/hipscat/abstract_catalog_loader.py
@@ -7,7 +7,7 @@
import dask.dataframe as dd
import hipscat as hc
import numpy as np
import pyarrow
import pandas as pd
from hipscat.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
from hipscat.io.file_io import file_io
from hipscat.pixel_math import HealpixPixel
@@ -80,8 +80,7 @@ def _get_paths_from_pixels(
def _load_df_from_paths(
self, catalog: HCHealpixDataset, paths: List[hc.io.FilePointer], divisions: Tuple[int, ...] | None
) -> dd.core.DataFrame:
metadata_schema = self._load_parquet_metadata_schema(catalog)
dask_meta_schema = metadata_schema.empty_table().to_pandas()
dask_meta_schema = self._load_metadata_schema(catalog)
if self.config.columns:
dask_meta_schema = dask_meta_schema[self.config.columns]
ddf = dd.io.from_map(
@@ -91,11 +90,16 @@
storage_options=self.storage_options,
meta=dask_meta_schema,
columns=self.config.columns,
dtype_backend=self.config.get_dtype_backend(),
**self.config.get_kwargs_dict(),
)
return ddf

def _load_parquet_metadata_schema(self, catalog: HCHealpixDataset) -> pyarrow.Schema:
def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame:
metadata_pointer = hc.io.paths.get_common_metadata_pointer(catalog.catalog_base_dir)
metadata = file_io.read_parquet_metadata(metadata_pointer, storage_options=self.storage_options)
return metadata.schema.to_arrow_schema()
return (
metadata.schema.to_arrow_schema()
.empty_table()
.to_pandas(types_mapper=self.config.get_dtype_mapper())
)
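For reference, the `types_mapper` hook used above belongs to pyarrow's `Table.to_pandas`; a small sketch with a hand-built schema (the column names are illustrative) shows the difference it makes for the Dask meta frame:

import pandas as pd
import pyarrow as pa

schema = pa.schema([("ra", pa.float64()), ("dec", pa.float64()), ("id", pa.int64())])

# Empty meta frame with numpy dtypes (previous behaviour).
meta_numpy = schema.empty_table().to_pandas()
print(meta_numpy.dtypes)  # float64, float64, int64

# Empty meta frame with pyarrow-backed dtypes (dtype_backend="pyarrow").
meta_arrow = schema.empty_table().to_pandas(types_mapper=pd.ArrowDtype)
print(meta_arrow.dtypes)  # double[pyarrow], double[pyarrow], int64[pyarrow]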
3 changes: 1 addition & 2 deletions src/lsdb/loaders/hipscat/association_catalog_loader.py
@@ -22,7 +22,6 @@ def load_catalog(self) -> AssociationCatalog:
return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog)

def _load_empty_dask_df_and_map(self, hc_catalog):
metadata_schema = self._load_parquet_metadata_schema(hc_catalog)
dask_meta_schema = metadata_schema.empty_table().to_pandas()
dask_meta_schema = self._load_metadata_schema(hc_catalog)
ddf = dd.io.from_pandas(dask_meta_schema, npartitions=0)
return ddf, {}
28 changes: 27 additions & 1 deletion src/lsdb/loaders/hipscat/hipscat_loading_config.py
@@ -1,7 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List
from typing import Callable, List

import pandas as pd
from pandas._libs import lib
from pandas.io._util import _arrow_dtype_mapping

from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.core.search.abstract_search import AbstractSearch
@@ -23,9 +27,31 @@ class HipscatLoadingConfig:
margin_cache: MarginCatalog | None = None
"""Margin cache for the catalog. By default, it is None"""

dtype_backend: str | None = "pyarrow"
"""The backend data type to apply to the catalog. It defaults to "pyarrow" and
if it is None no type conversion is performed."""

kwargs: dict | None = None
"""Extra kwargs"""

def __post_init__(self):
if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]:
raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'")

def get_kwargs_dict(self) -> dict:
"""Returns a dictionary with the extra kwargs"""
return self.kwargs if self.kwargs is not None else {}

def get_dtype_backend(self) -> str:
"""Returns the data type backend. It is either "pyarrow", "numpy_nullable",
or "<no_default>", which allows us to keep the original data types."""
return lib.no_default if self.dtype_backend is None else self.dtype_backend

def get_dtype_mapper(self) -> Callable | None:
"""Returns a mapper for pyarrow or numpy types, mirroring Pandas behaviour."""
mapper = None
if self.dtype_backend == "pyarrow":
mapper = pd.ArrowDtype
elif self.dtype_backend == "numpy_nullable":
mapper = _arrow_dtype_mapping().get
return mapper
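A small sketch of what the two non-default mappers return; `_arrow_dtype_mapping` is a private pandas helper (assumed pandas >= 2.0), and a `dtype_backend` of None maps to `lib.no_default`, which leaves pandas' default behaviour untouched:

import pandas as pd
import pyarrow as pa
from pandas.io._util import _arrow_dtype_mapping

# dtype_backend="pyarrow": arrow types become ArrowDtype columns.
print(pd.ArrowDtype(pa.float64()))               # double[pyarrow]
print(pd.ArrowDtype(pa.string()))                # string[pyarrow]

# dtype_backend="numpy_nullable": arrow types map to pandas nullable dtypes.
print(_arrow_dtype_mapping().get(pa.float64()))  # Float64
print(_arrow_dtype_mapping().get(pa.int64()))    # Int64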
7 changes: 5 additions & 2 deletions src/lsdb/loaders/hipscat/read_hipscat.py
@@ -28,9 +28,10 @@ def read_hipscat(
path: str,
catalog_type: Type[Dataset] | None = None,
search_filter: AbstractSearch | None = None,
storage_options: dict | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
) -> Dataset:
"""Load a catalog from a HiPSCat formatted catalog.
@@ -55,9 +56,11 @@
type for type checking, the type of the catalog can be specified here. Use by specifying
the lsdb class for that catalog.
search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
storage_options (dict): Dictionary that contains abstract filesystem credentials
columns (List[str]): Default `None`. The set of columns to filter the catalog on.
margin_cache (MarginCatalog): The margin cache for the main catalog
dtype_backend (str): Backend data type to apply to the catalog.
Defaults to "pyarrow". If None, no type conversion is performed.
storage_options (dict): Dictionary that contains abstract filesystem credentials
**kwargs: Arguments to pass to the pandas parquet file reader

Returns:
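A hedged usage sketch of the new `dtype_backend` parameter; the catalog path is a placeholder:

import lsdb

# Default: columns are loaded with pyarrow-backed dtypes.
catalog = lsdb.read_hipscat("path/to/catalog")

# Pandas nullable dtypes instead of arrow-backed ones.
catalog_nullable = lsdb.read_hipscat("path/to/catalog", dtype_backend="numpy_nullable")

# Keep whatever dtypes pandas infers from the parquet files.
catalog_original = lsdb.read_hipscat("path/to/catalog", dtype_backend=None)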
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_association_catalog.py
@@ -19,7 +19,7 @@ def test_load_association(small_sky_to_xmatch_dir):
pixel_number=hp_pixel,
)
partition = small_sky_to_xmatch.get_partition(hp_order, hp_pixel)
data = pd.read_parquet(path)
data = pd.read_parquet(path, dtype_backend="pyarrow")
pd.testing.assert_frame_equal(partition.compute(), data)


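The expected data in this test is now read with pandas' dtype_backend="pyarrow" so both sides of the comparison carry arrow-backed dtypes; a minimal sketch (the parquet path is a placeholder):

import pandas as pd

expected_numpy = pd.read_parquet("partition.parquet")                           # numpy-backed dtypes
expected_arrow = pd.read_parquet("partition.parquet", dtype_backend="pyarrow")  # ArrowDtype columns

print(expected_numpy.dtypes)
print(expected_arrow.dtypes)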
8 changes: 6 additions & 2 deletions tests/lsdb/catalog/test_cone_search.py
@@ -36,9 +36,13 @@ def test_cone_search_filters_correct_points_margin(
cone_search_catalog = small_sky_order1_source_with_margin.cone_search(ra, dec, radius)
assert cone_search_catalog.margin is not None
cone_search_df = cone_search_catalog.compute()
pd.testing.assert_frame_equal(cone_search_df, cone_search_expected, check_dtype=False)
pd.testing.assert_frame_equal(
cone_search_df, cone_search_expected, check_index_type=False, check_dtype=False
)
cone_search_margin_df = cone_search_catalog.margin.compute()
pd.testing.assert_frame_equal(cone_search_margin_df, cone_search_margin_expected, check_dtype=False)
pd.testing.assert_frame_equal(
cone_search_margin_df, cone_search_margin_expected, check_index_type=False, check_dtype=False
)
assert_divisions_are_correct(cone_search_catalog)
assert_divisions_are_correct(cone_search_catalog.margin)

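The added `check_index_type=False` relaxes the index comparison, presumably because one side of the comparison now carries an arrow-backed index; a toy illustration of the dtype difference (values are made up):

import numpy as np
import pandas as pd
import pyarrow as pa

numpy_index = pd.Index([100, 101], dtype=np.uint64)
arrow_index = numpy_index.astype(pd.ArrowDtype(pa.uint64()))

print(numpy_index.dtype)  # uint64
print(arrow_index.dtype)  # uint64[pyarrow]
# With check_index_type=False, assert_frame_equal no longer requires the index
# class, dtype and inferred type to be identical on both sides.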
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_crossmatch.py
@@ -216,7 +216,7 @@ def test_self_crossmatch(algo, small_sky_catalog, small_sky_dir):
class MockCrossmatchAlgorithm(AbstractCrossmatchAlgorithm):
"""Mock class used to test a crossmatch algorithm"""

extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.dtype("float64"))})
extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)})

# We must have the same signature as the crossmatch method
def validate(self, mock_results: pd.DataFrame = None): # pylint: disable=unused-argument
9 changes: 2 additions & 7 deletions tests/lsdb/catalog/test_index_search.py
@@ -3,17 +3,12 @@

def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, assert_divisions_are_correct):
catalog_index = IndexCatalog.read_from_hipscat(small_sky_order1_id_index_dir)

# Searching for an object that does not exist
index_search_catalog = small_sky_order1_catalog.index_search([900], catalog_index)
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 0
assert_divisions_are_correct(index_search_catalog)

index_search_catalog = small_sky_order1_catalog.index_search(["700"], catalog_index)
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 0
assert_divisions_are_correct(index_search_catalog)

# Searching for an object that exists
index_search_catalog = small_sky_order1_catalog.index_search([700], catalog_index)
index_search_df = index_search_catalog.compute()
assert len(index_search_df) == 1
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_margin_catalog.py
@@ -30,7 +30,7 @@ def test_margin_catalog_partitions_correct(small_sky_xmatch_margin_dir):
pixel_number=hp_pixel,
)
partition = margin.get_partition(hp_order, hp_pixel)
data = pd.read_parquet(path)
data = pd.read_parquet(path, dtype_backend="pyarrow")
pd.testing.assert_frame_equal(partition.compute(), data)

