diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md
index 676f64076..b251b2a50 100644
--- a/docs/api-reference/dataframe.md
+++ b/docs/api-reference/dataframe.md
@@ -4,6 +4,7 @@
     handler: python
     options:
       members:
+        - __arrow_c_stream__
         - __getitem__
         - clone
         - collect_schema
diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md
index 7b7f62b8a..f9cc2e6bb 100644
--- a/docs/api-reference/series.md
+++ b/docs/api-reference/series.md
@@ -4,6 +4,8 @@
     handler: python
     options:
       members:
+        - __arrow_c_stream__
+        - __getitem__
         - abs
         - alias
         - all
diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py
index 039496744..dfcdce87b 100644
--- a/narwhals/dataframe.py
+++ b/narwhals/dataframe.py
@@ -15,6 +15,7 @@
 from narwhals.dependencies import is_numpy_array
 from narwhals.schema import Schema
 from narwhals.utils import flatten
+from narwhals.utils import parse_version
 
 if TYPE_CHECKING:
     from io import BytesIO
@@ -249,6 +250,32 @@ def __repr__(self) -> str:  # pragma: no cover
             + "┘"
         )
 
+    def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
+        """
+        Export a DataFrame via the Arrow PyCapsule Interface.
+
+        Narwhals doesn't implement anything itself here:
+
+        - if the underlying dataframe implements the interface, it'll return that
+        - else, it'll call `to_arrow` and then defer to PyArrow's implementation
+
+        See [PyCapsule Interface](https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html)
+        for more.
+        """
+        native_frame = self._compliant_frame._native_frame
+        if hasattr(native_frame, "__arrow_c_stream__"):
+            return native_frame.__arrow_c_stream__(requested_schema=requested_schema)
+        try:
+            import pyarrow as pa  # ignore-banned-import
+        except ModuleNotFoundError as exc:  # pragma: no cover
+            msg = f"PyArrow>=14.0.0 is required for `DataFrame.__arrow_c_stream__` for object of type {type(native_frame)}"
+            raise ModuleNotFoundError(msg) from exc
+        if parse_version(pa.__version__) < (14, 0):  # pragma: no cover
+            msg = f"PyArrow>=14.0.0 is required for `DataFrame.__arrow_c_stream__` for object of type {type(native_frame)}"
+            raise ModuleNotFoundError(msg) from None
+        pa_table = self.to_arrow()
+        return pa_table.__arrow_c_stream__(requested_schema=requested_schema)
+
     def lazy(self) -> LazyFrame[Any]:
         """
         Lazify the DataFrame (if possible).
diff --git a/narwhals/series.py b/narwhals/series.py
index a826e5dd4..3c79024c8 100644
--- a/narwhals/series.py
+++ b/narwhals/series.py
@@ -7,6 +7,8 @@
 from typing import Sequence
 from typing import overload
 
+from narwhals.utils import parse_version
+
 if TYPE_CHECKING:
     import numpy as np
     import pandas as pd
@@ -57,6 +59,32 @@ def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self:
     def __native_namespace__(self) -> Any:
         return self._compliant_series.__native_namespace__()
 
+    def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
+        """
+        Export a Series via the Arrow PyCapsule Interface.
+
+        Narwhals doesn't implement anything itself here:
+
+        - if the underlying series implements the interface, it'll return that
+        - else, it'll call `to_arrow` and then defer to PyArrow's implementation
+
+        See [PyCapsule Interface](https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html)
+        for more.
+        """
+        native_series = self._compliant_series._native_series
+        if hasattr(native_series, "__arrow_c_stream__"):
+            return native_series.__arrow_c_stream__(requested_schema=requested_schema)
+        try:
+            import pyarrow as pa  # ignore-banned-import
+        except ModuleNotFoundError as exc:  # pragma: no cover
+            msg = f"PyArrow>=16.0.0 is required for `Series.__arrow_c_stream__` for object of type {type(native_series)}"
+            raise ModuleNotFoundError(msg) from exc
+        if parse_version(pa.__version__) < (16, 0):  # pragma: no cover
+            msg = f"PyArrow>=16.0.0 is required for `Series.__arrow_c_stream__` for object of type {type(native_series)}"
+            raise ModuleNotFoundError(msg) from None
+        ca = pa.chunked_array([self.to_arrow()])
+        return ca.__arrow_c_stream__(requested_schema=requested_schema)
+
     @property
     def shape(self) -> tuple[int]:
         """
diff --git a/tests/frame/arrow_c_stream_test.py b/tests/frame/arrow_c_stream_test.py
new file mode 100644
index 000000000..7a3403f69
--- /dev/null
+++ b/tests/frame/arrow_c_stream_test.py
@@ -0,0 +1,51 @@
+import polars as pl
+import pyarrow as pa
+import pyarrow.compute as pc
+import pytest
+
+import narwhals.stable.v1 as nw
+from narwhals.utils import parse_version
+
+
+@pytest.mark.skipif(
+    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
+)
+@pytest.mark.skipif(
+    parse_version(pa.__version__) < (14, 0), reason="too old for pycapsule in PyArrow"
+)
+def test_arrow_c_stream_test() -> None:
+    df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
+    result = pa.table(df)
+    expected = pa.table({"a": [1, 2, 3]})
+    assert pc.all(pc.equal(result["a"], expected["a"])).as_py()
+
+
+@pytest.mark.skipif(
+    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
+)
+@pytest.mark.skipif(
+    parse_version(pa.__version__) < (14, 0), reason="too old for pycapsule in PyArrow"
+)
+def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
+    # "poison" the dunder method to make sure it actually got called above
+    monkeypatch.setattr(
+        "narwhals.dataframe.DataFrame.__arrow_c_stream__", lambda *_: 1 / 0
+    )
+    df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
+    with pytest.raises(ZeroDivisionError, match="division by zero"):
+        pa.table(df)
+
+
+@pytest.mark.skipif(
+    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
+)
+@pytest.mark.skipif(
+    parse_version(pa.__version__) < (14, 0), reason="too old for pycapsule in PyArrow"
+)
+def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
+    # Check that fallback to PyArrow works
+    monkeypatch.delattr("polars.DataFrame.__arrow_c_stream__")
+    df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
+    result = pa.table(df)
+    expected = pa.table({"a": [1, 2, 3]})
+    assert pc.all(pc.equal(result["a"], expected["a"])).as_py()
diff --git a/tests/series_only/arrow_c_stream_test.py b/tests/series_only/arrow_c_stream_test.py
new file mode 100644
index 000000000..9964d7408
--- /dev/null
+++ b/tests/series_only/arrow_c_stream_test.py
@@ -0,0 +1,49 @@
+import polars as pl
+import pyarrow as pa
+import pyarrow.compute as pc
+import pytest
+
+import narwhals.stable.v1 as nw
+from narwhals.utils import parse_version
+
+
+@pytest.mark.skipif(
+    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
+)
+@pytest.mark.skipif(
+    parse_version(pa.__version__) < (16, 0), reason="too old for pycapsule in PyArrow"
+)
+def test_arrow_c_stream_test() -> None:
+    s = nw.from_native(pl.Series([1, 2, 3]), series_only=True)
+    result = pa.chunked_array(s)
+    expected = pa.chunked_array([[1, 2, 3]])
+    assert pc.all(pc.equal(result, expected)).as_py()
+
+
+@pytest.mark.skipif(
+    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
+)
+@pytest.mark.skipif(
+    parse_version(pa.__version__) < (16, 0), reason="too old for pycapsule in PyArrow"
+)
+def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
+    # "poison" the dunder method to make sure it actually got called above
+    monkeypatch.setattr("narwhals.series.Series.__arrow_c_stream__", lambda *_: 1 / 0)
+    s = nw.from_native(pl.Series([1, 2, 3]), series_only=True)
+    with pytest.raises(ZeroDivisionError, match="division by zero"):
+        pa.chunked_array(s)
+
+
+@pytest.mark.skipif(
+    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
+)
+@pytest.mark.skipif(
+    parse_version(pa.__version__) < (16, 0), reason="too old for pycapsule in PyArrow"
+)
+def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
+    # Check that fallback to PyArrow works
+    monkeypatch.delattr("polars.Series.__arrow_c_stream__")
+    s = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)["a"]
+    result = pa.chunked_array(s)
+    expected = pa.chunked_array([[1, 2, 3]])
+    assert pc.all(pc.equal(result, expected)).as_py()
diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py
index 68c980086..f6e5303c4 100644
--- a/utils/check_api_reference.py
+++ b/utils/check_api_reference.py
@@ -45,13 +45,13 @@
 documented = [
     remove_prefix(i, "        - ")
     for i in content.splitlines()
-    if i.startswith("        - ")
+    if i.startswith("        - ") and not i.startswith("        - _")
 ]
 if missing := set(top_level_functions).difference(documented):
     print("DataFrame: not documented")  # noqa: T201
     print(missing)  # noqa: T201
     ret = 1
-if extra := set(documented).difference(top_level_functions).difference({"__getitem__"}):
+if extra := set(documented).difference(top_level_functions):
     print("DataFrame: outdated")  # noqa: T201
     print(extra)  # noqa: T201
     ret = 1
@@ -87,7 +87,7 @@
 documented = [
     remove_prefix(i, "        - ")
     for i in content.splitlines()
-    if i.startswith("        - ")
+    if i.startswith("        - ") and not i.startswith("        - _")
 ]
 if (
     missing := set(top_level_functions)