Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

DM-48141: Add a simple way to represent a pandas index in astropy table parquet metadata. #1136

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-48141.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `lsst.daf.butler.formatters.parquet.add_pandas_index_to_astropy()` function which stores special metadata that will be used to create a pandas DataFrame index if the table is read as a DataFrame.
61 changes: 57 additions & 4 deletions python/lsst/daf/butler/formatters/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"pandas_to_astropy",
"astropy_to_arrow",
"astropy_to_pandas",
"add_pandas_index_to_astropy",
"numpy_to_arrow",
"numpy_to_astropy",
"numpy_dict_to_arrow",
Expand Down Expand Up @@ -78,6 +79,7 @@
AbstractFileSystem = type

TARGET_ROW_GROUP_BYTES = 1_000_000_000
ASTROPY_PANDAS_INDEX_KEY = "lsst::arrow::astropy_pandas_index"


class ParquetFormatter(FormatterV2):
Expand Down Expand Up @@ -226,7 +228,20 @@ def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
dataframe : `pandas.DataFrame`
Converted pandas dataframe.
"""
return arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)
dataframe = arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)

metadata = arrow_table.schema.metadata if arrow_table.schema.metadata is not None else {}
if (key := ASTROPY_PANDAS_INDEX_KEY.encode()) in metadata:
pandas_index = metadata[key].decode("UTF8")
if pandas_index in arrow_table.schema.names:
dataframe.set_index(pandas_index, inplace=True)
else:
log.warning(
"Index column ``%s`` not available for arrow table conversion to DataFrame",
pandas_index,
)

return dataframe


def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
Expand All @@ -250,6 +265,10 @@ def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:

_apply_astropy_metadata(astropy_table, arrow_table.schema)

if (key := ASTROPY_PANDAS_INDEX_KEY) in astropy_table.meta:
if astropy_table.meta[key] not in astropy_table.columns:
astropy_table.meta.pop(key)

return astropy_table


Expand Down Expand Up @@ -487,6 +506,9 @@ def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
md = {}
md[b"lsst::arrow::rowcount"] = str(len(astropy_table))

if (key := ASTROPY_PANDAS_INDEX_KEY) in astropy_table.meta:
md[key.encode()] = astropy_table.meta[key]

for name in astropy_table.dtype.names:
_append_numpy_string_metadata(md, name, astropy_table.dtype[name])
_append_numpy_multidim_metadata(md, name, astropy_table.dtype[name])
Expand Down Expand Up @@ -543,16 +565,47 @@ def astropy_to_pandas(astropy_table: atable.Table, index: str | None = None) ->
dataframe : `pandas.DataFrame`
Output pandas dataframe.
"""
index_requested = False
if (key := ASTROPY_PANDAS_INDEX_KEY) in astropy_table.meta:
_index = astropy_table.meta[key]
if _index not in astropy_table.columns:
log.warning(
"Index column ``%s`` not available for astropy table conversion to DataFrame",
_index,
)
_index = None
else:
index_requested = True
_index = index

dataframe = arrow_to_pandas(astropy_to_arrow(astropy_table))

if isinstance(index, str):
dataframe = dataframe.set_index(index)
elif index:
# Set the index if we have a valid index name, and either the
# index was requested in the call to the function or the dataframe
# was not previously indexed with the call to arrow_to_pandas.
if isinstance(_index, str) and (index_requested or dataframe.index.name is None):
dataframe.set_index(_index, inplace=True)
elif _index and index_requested:
raise RuntimeError("index must be a string or None.")

return dataframe


def add_pandas_index_to_astropy(astropy_table: atable.Table, index: str) -> None:
    """Add special metadata to an astropy table to indicate a pandas index.

    The column name is stored under ``ASTROPY_PANDAS_INDEX_KEY`` in the
    table metadata, where it is picked up on conversion to a DataFrame.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table; modified in place.
    index : `str`
        Name of column for pandas to set as index, if read as DataFrame.

    Raises
    ------
    ValueError
        Raised if ``index`` is not a column of ``astropy_table``.
    """
    if index not in astropy_table.columns:
        # ValueError does not do logging-style lazy %-interpolation; the
        # original two-argument form left the column name out of the message
        # (the exception carried a (fmt, index) args tuple instead). Build
        # the message explicitly.
        raise ValueError(f"Column ``{index}`` not in astropy table columns to use as pandas index.")
    astropy_table.meta[ASTROPY_PANDAS_INDEX_KEY] = index


def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
"""Convert an astropy table to an arrow table.

Expand Down
44 changes: 44 additions & 0 deletions tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@

try:
from lsst.daf.butler.formatters.parquet import (
ASTROPY_PANDAS_INDEX_KEY,
ArrowAstropySchema,
ArrowNumpySchema,
DataFrameSchema,
Expand All @@ -105,6 +106,7 @@
_numpy_dtype_to_arrow_types,
_numpy_style_arrays_to_arrow_arrays,
_numpy_to_numpy_dict,
add_pandas_index_to_astropy,
arrow_to_astropy,
arrow_to_numpy,
arrow_to_numpy_dict,
Expand Down Expand Up @@ -1192,6 +1194,44 @@ def testBadAstropyColumnParquet(self):
with self.assertRaises(RuntimeError):
self.butler.put(bad_tab, self.datasetType, dataId={})

@unittest.skipUnless(pd is not None, "Cannot test ParquetFormatterDataFrame without pandas.")
def testWriteAstropyTableWithPandasIndexHint(self, testStrip=True):
    """Round-trip an astropy table that carries a pandas-index hint.

    Checks that the hint survives a put/get as astropy metadata, that it
    drives index creation on DataFrame reads, that a warning is logged
    when the hinted column is excluded from a columns-subset read, and
    (optionally) that the hint is stripped from an astropy read that
    excludes the hinted column.

    Parameters
    ----------
    testStrip : `bool`, optional
        If `True`, also verify the hint-stripping path for astropy reads
        with a column subset. Subclasses whose datastore does not go
        through the parquet formatter pass `False` to skip it.
    """
    tab1 = _makeSimpleAstropyTable()

    # Record "index" as the column pandas should use as the index.
    add_pandas_index_to_astropy(tab1, "index")

    self.butler.put(tab1, self.datasetType, dataId={})

    # Read in as an astropy table and ensure index hint is still there.
    tab2 = self.butler.get(self.datasetType, dataId={})

    self.assertIn(ASTROPY_PANDAS_INDEX_KEY, tab2.meta)
    self.assertEqual(tab2.meta[ASTROPY_PANDAS_INDEX_KEY], "index")

    # Read as a dataframe and ensure index is set.
    df3 = self.butler.get(self.datasetType, dataId={}, storageClass="DataFrame")

    self.assertEqual(df3.index.name, "index")

    # Read as a dataframe without naming the index column.
    # The hinted column is absent from the subset, so the conversion
    # should fall back to a default index and warn rather than fail.
    with self.assertLogs(level="WARNING") as cm:
        _ = self.butler.get(
            self.datasetType,
            dataId={},
            storageClass="DataFrame",
            parameters={"columns": ["a", "b"]},
        )
    self.assertIn("Index column ``index``", cm.output[0])

    if testStrip:
        # Read as an astropy table without naming the index column.
        # A stale hint pointing at a missing column must be removed.
        tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "b"]})

        self.assertNotIn(ASTROPY_PANDAS_INDEX_KEY, tab5.meta)

    # Hinting a nonexistent column is rejected up front.
    with self.assertRaises(ValueError):
        add_pandas_index_to_astropy(tab1, "not_a_column")


@unittest.skipUnless(atable is not None, "Cannot test InMemoryDatastore with AstropyTable without astropy.")
class InMemoryArrowAstropyDelegateTestCase(ParquetFormatterArrowAstropyTestCase):
Expand Down Expand Up @@ -1222,6 +1262,10 @@ def testBadInput(self):
with self.assertRaises(AttributeError):
delegate.getComponent(composite=tab1, componentName="nothing")

@unittest.skipUnless(pd is not None, "Cannot test ParquetFormatterDataFrame without pandas.")
def testWriteAstropyTableWithPandasIndexHint(self):
    """Run the pandas-index-hint round-trip without the strip check.

    NOTE(review): presumably the in-memory datastore returns the stored
    table without the parquet column-subset path, so hint stripping does
    not apply here — confirm against the delegate implementation.
    """
    super().testWriteAstropyTableWithPandasIndexHint(testStrip=False)


@unittest.skipUnless(np is not None, "Cannot test ParquetFormatterArrowNumpy without numpy.")
@unittest.skipUnless(pa is not None, "Cannot test ParquetFormatterArrowNumpy without pyarrow.")
Expand Down
Loading