From 35316464ad605be4af903963e1f348848c028aff Mon Sep 17 00:00:00 2001
From: Marco Edward Gorelli
Date: Tue, 29 Oct 2024 14:46:11 +0000
Subject: [PATCH] feat: support `columns` and `select` for InterchangeFrame if `_df` is present (#1283)

---
Review notes (free-form section, ignored when the patch is applied):
- `InterchangeFrame.select` now raises NotImplementedError with the prepared
  message when the selected interchange object lacks `_df`; previously the
  interchange frame object itself was passed as the exception argument and
  the message was discarded.
- Added a docstring to `select`; the hunk header and diffstat counts for
  narwhals/_interchange/dataframe.py were updated accordingly.

 narwhals/_interchange/dataframe.py     | 62 ++++++++++++++++++++------
 tests/frame/interchange_schema_test.py |  2 +-
 tests/frame/interchange_select_test.py | 38 +++++++++++++---
 3 files changed, 79 insertions(+), 23 deletions(-)

diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py
index 4e8e542e7..dc2af3bad 100644
--- a/narwhals/_interchange/dataframe.py
+++ b/narwhals/_interchange/dataframe.py
@@ -75,7 +75,6 @@ def map_interchange_dtype_to_narwhals_dtype(
 
 class InterchangeFrame:
     def __init__(self, df: Any, dtypes: DTypes) -> None:
-        self._native_frame = df
         self._interchange_frame = df.__dataframe__()
         self._dtypes = dtypes
 
@@ -97,21 +96,11 @@ def __getitem__(self, item: str) -> InterchangeSeries:
             self._interchange_frame.get_column_by_name(item), dtypes=self._dtypes
         )
 
-    @property
-    def schema(self) -> dict[str, DType]:
-        return {
-            column_name: map_interchange_dtype_to_narwhals_dtype(
-                self._interchange_frame.get_column_by_name(column_name).dtype,
-                self._dtypes,
-            )
-            for column_name in self._interchange_frame.column_names()
-        }
-
     def to_pandas(self: Self) -> pd.DataFrame:
         import pandas as pd  # ignore-banned-import()
 
         if parse_version(pd.__version__) >= parse_version("1.5.0"):
-            return pd.api.interchange.from_dataframe(self._native_frame)
+            return pd.api.interchange.from_dataframe(self._interchange_frame)
         else:  # pragma: no cover
             msg = (
                 "Conversion to pandas is achieved via interchange protocol which requires"
@@ -122,9 +111,19 @@ def to_pandas(self: Self) -> pd.DataFrame:
     def to_arrow(self: Self) -> pa.Table:
         from pyarrow.interchange import from_dataframe  # ignore-banned-import()
 
-        return from_dataframe(self._native_frame)
-
-    def __getattr__(self, attr: str) -> NoReturn:
+        return from_dataframe(self._interchange_frame)
+
+    def __getattr__(self, attr: str) -> Any:
+        if attr == "schema":
+            return {
+                column_name: map_interchange_dtype_to_narwhals_dtype(
+                    self._interchange_frame.get_column_by_name(column_name).dtype,
+                    self._dtypes,
+                )
+                for column_name in self._interchange_frame.column_names()
+            }
+        elif attr == "columns":
+            return list(self._interchange_frame.column_names())
         msg = (
             f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
             "Hint: you probably called `nw.from_native` on an object which isn't fully "
@@ -133,3 +132,36 @@ def __getattr__(self, attr: str) -> NoReturn:
             "at https://github.com/narwhals-dev/narwhals/issues."
         )
         raise NotImplementedError(msg)
+
+    def select(
+        self: Self,
+        *exprs: Any,
+        **named_exprs: Any,
+    ) -> Self:
+        """Select columns by name, returning a new interchange-level frame.
+
+        Only positional string column names are accepted. The interchange
+        object returned by `select_columns_by_name` must expose the original
+        dataframe via `_df`, which is then re-wrapped in this class.
+
+        Raises:
+            NotImplementedError: If anything other than column names is
+                passed, or if the selected interchange object has no `_df`.
+        """
+        if named_exprs or not all(isinstance(x, str) for x in exprs):  # pragma: no cover
+            msg = (
+                "`select`-ing not by name is not supported for interchange-only level.\n\n"
+                "If you would like to see this kind of object better supported in "
+                "Narwhals, please open a feature request "
+                "at https://github.com/narwhals-dev/narwhals/issues."
+            )
+            raise NotImplementedError(msg)
+
+        frame = self._interchange_frame.select_columns_by_name(exprs)
+        if not hasattr(frame, "_df"):  # pragma: no cover
+            msg = (
+                "Expected interchange object to implement `_df` property to allow for recovering original object.\n"
+                "See https://github.com/data-apis/dataframe-api/issues/360."
+            )
+            raise NotImplementedError(msg)
+        return self.__class__(frame._df, dtypes=self._dtypes)
diff --git a/tests/frame/interchange_schema_test.py b/tests/frame/interchange_schema_test.py
index 35de7d74a..5a612db18 100644
--- a/tests/frame/interchange_schema_test.py
+++ b/tests/frame/interchange_schema_test.py
@@ -227,7 +227,7 @@ def test_invalid() -> None:
     with pytest.raises(
         NotImplementedError, match="is not supported for metadata-only dataframes"
     ):
-        nw.from_native(df, eager_or_interchange_only=True).select("a")
+        nw.from_native(df, eager_or_interchange_only=True).filter([True, False, True])
     with pytest.raises(TypeError, match="Cannot only use `series_only=True`"):
         nw.from_native(df, eager_only=True)
     with pytest.raises(ValueError, match="Invalid parameter combination"):
diff --git a/tests/frame/interchange_select_test.py b/tests/frame/interchange_select_test.py
index e124735f7..b553af751 100644
--- a/tests/frame/interchange_select_test.py
+++ b/tests/frame/interchange_select_test.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Any
+
 import duckdb
 import polars as pl
 import pytest
@@ -9,14 +11,36 @@
 data = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.1], "z": ["x", "y", "z"]}
 
 
+class InterchangeDataFrame:
+    def __init__(self, df: CustomDataFrame) -> None:
+        self._df = df
+
+    def __dataframe__(self) -> InterchangeDataFrame:  # pragma: no cover
+        return self
+
+    def column_names(self) -> list[str]:
+        return list(self._df._data.keys())
+
+    def select_columns_by_name(self, columns: list[str]) -> InterchangeDataFrame:
+        return InterchangeDataFrame(
+            CustomDataFrame(
+                {key: value for key, value in self._df._data.items() if key in columns}
+            )
+        )
+
+
+class CustomDataFrame:
+    def __init__(self, data: dict[str, Any]) -> None:
+        self._data = data
+
+    def __dataframe__(self, *, allow_copy: bool = True) -> InterchangeDataFrame:
+        return InterchangeDataFrame(self)
+
+
 def test_interchange() -> None:
-    df_pl = pl.DataFrame(data)
-    df = nw.from_native(df_pl.__dataframe__(), eager_or_interchange_only=True)
-    with pytest.raises(
-        NotImplementedError,
-        match="Attribute select is not supported for metadata-only dataframes",
-    ):
-        df.select("a", "z")
+    df = CustomDataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "z": [1, 4, 2]})
+    result = nw.from_native(df, eager_or_interchange_only=True).select("a", "z")
+    assert result.columns == ["a", "z"]
 
 
 def test_interchange_ibis(