Skip to content

Commit

Permalink
Provide arrow schema on HiPSCat catalog creation (#383)
Browse files Browse the repository at this point in the history
* Include the arrow schema in catalog operations

* Resolve circular import issue

* Allow schema to be provided in from_dataframe

* Remove filtering of columns from arrow schema
  • Loading branch information
camposandro authored Jul 31, 2024
1 parent 28375b7 commit efdd685
Show file tree
Hide file tree
Showing 15 changed files with 168 additions and 32 deletions.
9 changes: 6 additions & 3 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data
from lsdb.dask.join_catalog_data import join_catalog_data_on, join_catalog_data_through
from lsdb.dask.partition_indexer import PartitionIndexer
from lsdb.io.schema import get_arrow_schema
from lsdb.types import DaskDFPixelMap


Expand Down Expand Up @@ -199,7 +200,7 @@ def crossmatch(
ra_column=self.hc_structure.catalog_info.ra_column + suffixes[0],
dec_column=self.hc_structure.catalog_info.dec_column + suffixes[0],
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf))
return Catalog(ddf, ddf_map, hc_catalog)

def cone_search(self, ra: float, dec: float, radius_arcsec: float, fine: bool = True) -> Catalog:
Expand Down Expand Up @@ -418,7 +419,9 @@ def join(
ra_column=self.hc_structure.catalog_info.ra_column + suffixes[0],
dec_column=self.hc_structure.catalog_info.dec_column + suffixes[0],
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
hc_catalog = hc.catalog.Catalog(
new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf)
)
return Catalog(ddf, ddf_map, hc_catalog)
if left_on is None or right_on is None:
raise ValueError("Either both of left_on and right_on, or through must be set")
Expand All @@ -439,5 +442,5 @@ def join(
ra_column=self.hc_structure.catalog_info.ra_column + suffixes[0],
dec_column=self.hc_structure.catalog_info.dec_column + suffixes[0],
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf))
return Catalog(ddf, ddf_map, hc_catalog)
17 changes: 17 additions & 0 deletions src/lsdb/io/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import dask.dataframe as dd
import pyarrow as pa


def get_arrow_schema(ddf: dd.DataFrame) -> pa.Schema:
"""Constructs the pyarrow schema from the meta of a Dask DataFrame.
Args:
ddf (dd.DataFrame): A Dask DataFrame.
Returns:
The arrow schema for the provided Dask DataFrame.
"""
# pylint: disable=protected-access
return pa.Schema.from_pandas(ddf._meta)
10 changes: 9 additions & 1 deletion src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hipscat as hc
import numpy as np
import pandas as pd
import pyarrow as pa
from hipscat.catalog import CatalogType
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.pixel_math import HealpixPixel, generate_histogram
Expand All @@ -18,6 +19,7 @@
from mocpy import MOC

from lsdb.catalog.catalog import Catalog
from lsdb.io.schema import get_arrow_schema
from lsdb.loaders.dataframe.from_dataframe_utils import (
_append_partition_information_to_dataframe,
_generate_dask_dataframe,
Expand All @@ -30,6 +32,7 @@
class DataframeCatalogLoader:
"""Creates a HiPSCat formatted Catalog from a Pandas Dataframe"""

# pylint: disable=too-many-arguments
def __init__(
self,
dataframe: pd.DataFrame,
Expand All @@ -41,6 +44,7 @@ def __init__(
should_generate_moc: bool = True,
moc_max_order: int = 10,
use_pyarrow_types: bool = True,
schema: pa.Schema | None = None,
**kwargs,
) -> None:
"""Initializes a DataframeCatalogLoader
Expand All @@ -59,6 +63,8 @@ def __init__(
moc_max_order (int): if generating a MOC, what to use as the max order. Defaults to 10.
use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the
original data types. Defaults to True.
schema (pa.Schema): the arrow schema to create the catalog with. If None, the schema is
automatically inferred from the provided DataFrame using `pa.Schema.from_pandas`.
**kwargs: Arguments to pass to the creation of the catalog info.
"""
self.dataframe = dataframe
Expand All @@ -70,6 +76,7 @@ def __init__(
self.should_generate_moc = should_generate_moc
self.moc_max_order = moc_max_order
self.use_pyarrow_types = use_pyarrow_types
self.schema = schema

def _calculate_threshold(self, partition_size: int | None = None, threshold: int | None = None) -> int:
"""Calculates the number of pixels per HEALPix pixel (threshold) for the
Expand Down Expand Up @@ -130,7 +137,8 @@ def load_catalog(self) -> Catalog:
ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map(pixel_list)
self.catalog_info = dataclasses.replace(self.catalog_info, total_rows=total_rows)
moc = self._generate_moc() if self.should_generate_moc else None
hc_structure = hc.catalog.Catalog(self.catalog_info, pixel_list, moc=moc)
schema = self.schema if self.schema is not None else get_arrow_schema(ddf)
hc_structure = hc.catalog.Catalog(self.catalog_info, pixel_list, moc=moc, schema=schema)
return Catalog(ddf, ddf_pixel_map, hc_structure)

def _set_hipscat_index(self):
Expand Down
5 changes: 5 additions & 0 deletions src/lsdb/loaders/dataframe/from_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pandas as pd
import pyarrow as pa

from lsdb.catalog import Catalog
from lsdb.loaders.dataframe.dataframe_catalog_loader import DataframeCatalogLoader
Expand All @@ -21,6 +22,7 @@ def from_dataframe(
should_generate_moc: bool = True,
moc_max_order: int = 10,
use_pyarrow_types: bool = True,
schema: pa.Schema | None = None,
**kwargs,
) -> Catalog:
"""Load a catalog from a Pandas Dataframe in CSV format.
Expand All @@ -46,6 +48,8 @@ def from_dataframe(
moc_max_order (int): if generating a MOC, what to use as the max order. Defaults to 10.
use_pyarrow_types (bool): If True, the data is backed by pyarrow, otherwise we keep the
original data types. Defaults to True.
schema (pa.Schema): the arrow schema to create the catalog with. If None, the schema is
automatically inferred from the provided DataFrame using `pa.Schema.from_pandas`.
**kwargs: Arguments to pass to the creation of the catalog info.
Returns:
Expand All @@ -61,6 +65,7 @@ def from_dataframe(
should_generate_moc=should_generate_moc,
moc_max_order=moc_max_order,
use_pyarrow_types=use_pyarrow_types,
schema=schema,
**kwargs,
).load_catalog()
if margin_threshold:
Expand Down
3 changes: 2 additions & 1 deletion src/lsdb/loaders/dataframe/margin_catalog_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
self.margin_threshold = margin_threshold
self.margin_order = self._set_margin_order(margin_order)
self.use_pyarrow_types = use_pyarrow_types
self.schema = catalog.hc_structure.schema

def _set_margin_order(self, margin_order: int | None) -> int:
"""Calculate the order of the margin cache to be generated. If not provided
Expand Down Expand Up @@ -79,7 +80,7 @@ def create_catalog(self) -> MarginCatalog | None:
ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map(pixels, partitions)
margin_pixels = list(ddf_pixel_map.keys())
margin_catalog_info = self._create_catalog_info(total_rows)
margin_structure = hc.catalog.MarginCatalog(margin_catalog_info, margin_pixels)
margin_structure = hc.catalog.MarginCatalog(margin_catalog_info, margin_pixels, schema=self.schema)
return MarginCatalog(ddf, ddf_pixel_map, margin_structure)

def _get_margins(self) -> Tuple[List[HealpixPixel], List[pd.DataFrame]]:
Expand Down
40 changes: 24 additions & 16 deletions src/lsdb/loaders/hipscat/abstract_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import hipscat as hc
import numpy as np
import pandas as pd
import pyarrow as pa
from hipscat.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
from hipscat.io.file_io import file_io
from hipscat.pixel_math import HealpixPixel
Expand Down Expand Up @@ -45,7 +46,13 @@ def load_catalog(self) -> CatalogTypeVar | None:

def _load_hipscat_catalog(self, catalog_type: Type[HCCatalogTypeVar]) -> HCCatalogTypeVar:
"""Load `hipscat` library catalog object with catalog metadata and partition data"""
return catalog_type.read_from_hipscat(self.path, storage_options=self.storage_options)
hc_catalog = catalog_type.read_from_hipscat(self.path, storage_options=self.storage_options)
if hc_catalog.schema is None:
raise ValueError(
"The catalog schema could not be loaded from metadata."
" Ensure your catalog has _common_metadata or _metadata files"
)
return hc_catalog

def _load_dask_df_and_map(self, catalog: HCHealpixDataset) -> Tuple[dd.DataFrame, DaskDFPixelMap]:
"""Load Dask DF from parquet files and make dict of HEALPix pixel to partition index"""
Expand All @@ -71,29 +78,30 @@ def _get_paths_from_pixels(
def _load_df_from_paths(
self, catalog: HCHealpixDataset, paths: List[hc.io.FilePointer], divisions: Tuple[int, ...] | None
) -> dd.DataFrame:
dask_meta_schema = self._load_metadata_schema(catalog)
if self.config.columns:
dask_meta_schema = dask_meta_schema[self.config.columns]
kwargs = dict(self.config.kwargs)
if self.config.dtype_backend is not None:
kwargs["dtype_backend"] = self.config.dtype_backend
dask_meta_schema = self._create_dask_meta_schema(catalog.schema)
if len(paths) > 0:
return dd.from_map(
file_io.read_parquet_file_to_pandas,
paths,
columns=self.config.columns,
divisions=divisions,
meta=dask_meta_schema,
schema=catalog.schema,
storage_options=self.storage_options,
**kwargs,
**self._get_kwargs(),
)
return dd.from_pandas(dask_meta_schema, npartitions=1)

def _load_metadata_schema(self, catalog: HCHealpixDataset) -> pd.DataFrame:
metadata_pointer = hc.io.paths.get_common_metadata_pointer(catalog.catalog_base_dir)
metadata = file_io.read_parquet_metadata(metadata_pointer, storage_options=self.storage_options)
return (
metadata.schema.to_arrow_schema()
.empty_table()
.to_pandas(types_mapper=self.config.get_dtype_mapper())
)
def _create_dask_meta_schema(self, schema: pa.Schema) -> pd.DataFrame:
"""Creates the Dask meta DataFrame from the HiPSCat catalog schema."""
dask_meta_schema = schema.empty_table().to_pandas(types_mapper=self.config.get_dtype_mapper())
if self.config.columns is not None:
dask_meta_schema = dask_meta_schema[self.config.columns]
return dask_meta_schema

def _get_kwargs(self) -> dict:
"""Constructs additional arguments for the `read_parquet` call"""
kwargs = dict(self.config.kwargs)
if self.config.dtype_backend is not None:
kwargs["dtype_backend"] = self.config.dtype_backend
return kwargs
2 changes: 1 addition & 1 deletion src/lsdb/loaders/hipscat/association_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ def load_catalog(self) -> AssociationCatalog:
return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog)

def _load_empty_dask_df_and_map(self, hc_catalog):
dask_meta_schema = self._load_metadata_schema(hc_catalog)
dask_meta_schema = self._create_dask_meta_schema(hc_catalog.schema)
ddf = dd.from_pandas(dask_meta_schema, npartitions=1)
return ddf, {}
22 changes: 13 additions & 9 deletions src/lsdb/loaders/hipscat/hipscat_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import hipscat as hc

import lsdb
from lsdb.catalog.catalog import Catalog, MarginCatalog
from lsdb.loaders.hipscat.abstract_catalog_loader import AbstractCatalogLoader
from lsdb.loaders.hipscat.hipscat_loading_config import HipscatLoadingConfig
from lsdb.loaders.hipscat.margin_catalog_loader import MarginCatalogLoader


class HipscatCatalogLoader(AbstractCatalogLoader[Catalog]):
Expand Down Expand Up @@ -40,6 +41,7 @@ def _filter_hipscat_catalog(self, hc_catalog: hc.catalog.Catalog) -> hc.catalog.
filtered_catalog.pixel_tree,
catalog_path=hc_catalog.catalog_path,
moc=filtered_catalog.moc,
schema=filtered_catalog.schema,
storage_options=hc_catalog.storage_options,
)

Expand All @@ -53,13 +55,15 @@ def _load_margin_catalog(self) -> MarginCatalog | None:
# pylint: disable=protected-access
margin_catalog = margin_catalog.search(self.config.search_filter)
elif self.config.margin_cache is not None:
margin_catalog = lsdb.read_hipscat(
path=self.config.margin_cache,
catalog_type=MarginCatalog,
search_filter=self.config.search_filter,
margin_cache=None,
dtype_backend=self.config.dtype_backend,
margin_catalog = MarginCatalogLoader(
str(self.config.margin_cache),
HipscatLoadingConfig(
search_filter=self.config.search_filter,
columns=self.config.columns,
margin_cache=None,
dtype_backend=self.config.dtype_backend,
**self.config.kwargs,
),
storage_options=self.storage_options,
**self.config.kwargs,
)
).load_catalog()
return margin_catalog
1 change: 1 addition & 0 deletions src/lsdb/loaders/hipscat/margin_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ def _filter_hipscat_catalog(self, hc_catalog: hc.catalog.MarginCatalog) -> hc.ca
filtered_catalog.catalog_info,
filtered_catalog.pixel_tree,
catalog_path=hc_catalog.catalog_path,
schema=filtered_catalog.schema,
storage_options=hc_catalog.storage_options,
)
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SMALL_SKY_TO_ORDER1_SOURCE_NAME = "small_sky_to_o1source"
SMALL_SKY_TO_ORDER1_SOURCE_SOFT_NAME = "small_sky_to_o1source_soft"
SMALL_SKY_ORDER1_CSV = "small_sky_order1.csv"
SMALL_SKY_NO_METADATA = "small_sky_no_metadata"
XMATCH_CORRECT_FILE = "xmatch_correct.csv"
XMATCH_CORRECT_005_FILE = "xmatch_correct_0_005.csv"
XMATCH_CORRECT_002_005_FILE = "xmatch_correct_002_005.csv"
Expand Down Expand Up @@ -200,6 +201,11 @@ def small_sky_order3_source_margin_catalog(test_data_dir):
return lsdb.read_hipscat(test_data_dir / SMALL_SKY_ORDER3_SOURCE_MARGIN_NAME)


@pytest.fixture
def small_sky_no_metadata_dir(test_data_dir):
return test_data_dir / "raw" / SMALL_SKY_NO_METADATA


@pytest.fixture
def xmatch_expected_dir(test_data_dir):
return test_data_dir / "raw" / "xmatch_expected"
Expand Down
8 changes: 8 additions & 0 deletions tests/data/raw/small_sky_no_metadata/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "small_sky",
"catalog_type": "object",
"total_rows": 131,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
2 changes: 2 additions & 0 deletions tests/data/raw/small_sky_no_metadata/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Norder,Npix,Dir
0,11,0
53 changes: 53 additions & 0 deletions tests/data/raw/small_sky_no_metadata/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"catalog_name": "small_sky",
"catalog_type": "object",
"total_rows": 131,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.2.7.dev15+g85ec4a0",
"generation_date": "2024.03.06",
"tool_args": {
"tool_name": "hipscat_import",
"version": "0.2.5.dev5+g0733afb",
"runtime_args": {
"catalog_name": "small_sky",
"output_path": ".",
"output_artifact_name": "small_sky",
"tmp_dir": "/tmp/user/11115/tmphywoxno9",
"overwrite": true,
"dask_tmp": "",
"dask_n_workers": 1,
"dask_threads_per_worker": 1,
"catalog_path": "./small_sky",
"tmp_path": "/tmp/user/11115/tmphywoxno9/small_sky/intermediate",
"epoch": "J2000",
"catalog_type": "object",
"input_path": null,
"input_paths": [
"small_sky_order1/small_sky_order1.csv"
],
"input_file_list": [
"small_sky_order1/small_sky_order1.csv"
],
"ra_column": "ra",
"dec_column": "dec",
"use_hipscat_index": false,
"sort_columns": null,
"constant_healpix_order": -1,
"highest_healpix_order": 7,
"pixel_threshold": 1000000,
"mapping_healpix_order": 7,
"debug_stats_only": false,
"file_reader_info": {
"input_reader_type": "CsvReader",
"chunksize": 500000,
"header": "infer",
"schema_file": null,
"separator": ",",
"column_names": null,
"type_map": {}
}
}
}
}
Loading

0 comments on commit efdd685

Please sign in to comment.