Skip to content

Commit

Permalink
feat: support columns and select for InterchangeFrame if _df is…
Browse files Browse the repository at this point in the history
… present (#1283)
  • Loading branch information
MarcoGorelli authored Oct 29, 2024
1 parent f349cb2 commit 3531646
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 23 deletions.
52 changes: 37 additions & 15 deletions narwhals/_interchange/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def map_interchange_dtype_to_narwhals_dtype(

class InterchangeFrame:
def __init__(self, df: Any, dtypes: DTypes) -> None:
# Keep the object exactly as the caller handed it over.
self._native_frame = df
# Interchange-protocol view, obtained through the object's `__dataframe__` hook.
self._interchange_frame = df.__dataframe__()
# Stored unchanged; the dtype mapping and the series/frames built by this class receive it.
self._dtypes = dtypes

Expand All @@ -97,21 +96,11 @@ def __getitem__(self, item: str) -> InterchangeSeries:
self._interchange_frame.get_column_by_name(item), dtypes=self._dtypes
)

@property
def schema(self) -> dict[str, DType]:
    """Map each column name of the interchange frame to its Narwhals dtype.

    Columns are visited in the order reported by ``column_names()``; each
    column's interchange dtype is translated with
    ``map_interchange_dtype_to_narwhals_dtype``.
    """
    mapping = {}
    for name in self._interchange_frame.column_names():
        column = self._interchange_frame.get_column_by_name(name)
        mapping[name] = map_interchange_dtype_to_narwhals_dtype(
            column.dtype,
            self._dtypes,
        )
    return mapping

def to_pandas(self: Self) -> pd.DataFrame:
import pandas as pd # ignore-banned-import()

if parse_version(pd.__version__) >= parse_version("1.5.0"):
return pd.api.interchange.from_dataframe(self._native_frame)
return pd.api.interchange.from_dataframe(self._interchange_frame)
else: # pragma: no cover
msg = (
"Conversion to pandas is achieved via interchange protocol which requires"
Expand All @@ -122,9 +111,19 @@ def to_pandas(self: Self) -> pd.DataFrame:
def to_arrow(self: Self) -> pa.Table:
from pyarrow.interchange import from_dataframe # ignore-banned-import()

return from_dataframe(self._native_frame)

def __getattr__(self, attr: str) -> NoReturn:
return from_dataframe(self._interchange_frame)

def __getattr__(self, attr: str) -> Any:
if attr == "schema":
return {
column_name: map_interchange_dtype_to_narwhals_dtype(
self._interchange_frame.get_column_by_name(column_name).dtype,
self._dtypes,
)
for column_name in self._interchange_frame.column_names()
}
elif attr == "columns":
return list(self._interchange_frame.column_names())
msg = (
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
"Hint: you probably called `nw.from_native` on an object which isn't fully "
Expand All @@ -133,3 +132,26 @@ def __getattr__(self, attr: str) -> NoReturn:
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)

def select(
    self: Self,
    *exprs: Any,
    **named_exprs: Any,
) -> Self:
    """Return a new frame restricted to the columns named in ``exprs``.

    Only selection by column name is possible at the interchange-only level.
    The subset is produced by the protocol object's
    ``select_columns_by_name``; the result must expose a ``_df`` attribute so
    that the original (non-protocol) object can be recovered and re-wrapped.

    Args:
        *exprs: Column names; every item must be a ``str``.
        **named_exprs: Not supported; must be empty.

    Returns:
        A new instance of this class wrapping ``frame._df`` with the same
        ``dtypes``.

    Raises:
        NotImplementedError: If anything other than plain column names is
            passed, or if the selected interchange object has no ``_df``.
    """
    if named_exprs or not all(isinstance(x, str) for x in exprs):  # pragma: no cover
        msg = (
            "`select`-ing not by name is not supported for interchange-only level.\n\n"
            "If you would like to see this kind of object better supported in "
            "Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)

    frame = self._interchange_frame.select_columns_by_name(exprs)
    if not hasattr(frame, "_df"):  # pragma: no cover
        msg = (
            "Expected interchange object to implement `_df` property to allow for recovering original object.\n"
            "See https://github.com/data-apis/dataframe-api/issues/360."
        )
        # Raise the explanatory message, not the frame object itself.
        raise NotImplementedError(msg)
    return self.__class__(frame._df, dtypes=self._dtypes)
2 changes: 1 addition & 1 deletion tests/frame/interchange_schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_invalid() -> None:
with pytest.raises(
NotImplementedError, match="is not supported for metadata-only dataframes"
):
nw.from_native(df, eager_or_interchange_only=True).select("a")
nw.from_native(df, eager_or_interchange_only=True).filter([True, False, True])
with pytest.raises(TypeError, match="Cannot only use `series_only=True`"):
nw.from_native(df, eager_only=True)
with pytest.raises(ValueError, match="Invalid parameter combination"):
Expand Down
38 changes: 31 additions & 7 deletions tests/frame/interchange_select_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import duckdb
import polars as pl
import pytest
Expand All @@ -9,14 +11,36 @@
data = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.1], "z": ["x", "y", "z"]}


class InterchangeDataFrame:
    """Minimal interchange-protocol object used as a test double.

    It wraps a frame object that keeps its columns in a ``_data`` mapping and
    exposes that frame as ``_df``, the attribute ``InterchangeFrame.select``
    checks for when recovering the original object.
    """

    def __init__(self, df: CustomDataFrame) -> None:
        self._df = df

    def __dataframe__(self) -> InterchangeDataFrame:  # pragma: no cover
        """Return ``self``; this object already is the protocol view."""
        return self

    def column_names(self) -> list[str]:
        """Return the column names in storage order."""
        return list(self._df._data)

    def select_columns_by_name(self, columns: list[str]) -> InterchangeDataFrame:
        """Return a new protocol object limited to ``columns``.

        Columns appear in the order requested. Names that are not present in
        the wrapped frame are ignored. The subset is wrapped in a new frame of
        the same type as the one this object holds.
        """
        source = self._df._data
        # Iterate the requested names, not the stored ones, so the result
        # follows the caller's order and membership tests hit the dict.
        subset = {name: source[name] for name in columns if name in source}
        return InterchangeDataFrame(type(self._df)(subset))


class CustomDataFrame:
    """Bare-bones frame that only knows how to enter the interchange protocol."""

    def __init__(self, data: dict[str, Any]) -> None:
        # Column name -> values; kept by reference, not copied.
        self._data = data

    def __dataframe__(self, *, allow_copy: bool = True) -> InterchangeDataFrame:
        """Wrap this frame in its protocol object; ``allow_copy`` is accepted but unused."""
        protocol_view = InterchangeDataFrame(self)
        return protocol_view


def test_interchange() -> None:
# NOTE(review): the two scenarios below look like the before/after sides of a
# diff hunk that ended up in the same body — confirm which one should remain.
# Scenario 1: wrap the interchange object produced by a Polars frame.
df_pl = pl.DataFrame(data)
df = nw.from_native(df_pl.__dataframe__(), eager_or_interchange_only=True)
# Selecting on it is expected to fail with the metadata-only message.
with pytest.raises(
NotImplementedError,
match="Attribute select is not supported for metadata-only dataframes",
):
df.select("a", "z")
# Scenario 2: a CustomDataFrame goes through `select` by name and the
# resulting frame reports exactly the requested columns.
df = CustomDataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "z": [1, 4, 2]})
result = nw.from_native(df, eager_or_interchange_only=True).select("a", "z")
assert result.columns == ["a", "z"]


def test_interchange_ibis(
Expand Down

0 comments on commit 3531646

Please sign in to comment.