diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 80d0740dc..9af2b37a3 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -68,6 +68,12 @@ def __narwhals_namespace__(self) -> ArrowNamespace: def cum_sum(self) -> Self: return reuse_series_implementation(self, "cum_sum") # type: ignore[type-var] + def any(self) -> Self: + return reuse_series_implementation(self, "any", returns_scalar=True) # type: ignore[type-var] + + def all(self) -> Self: + return reuse_series_implementation(self, "all", returns_scalar=True) # type: ignore[type-var] + @property def dt(self) -> ArrowExprDateTimeNamespace: return ArrowExprDateTimeNamespace(self) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index bc7182b77..e2bfa9daa 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,10 +1,12 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Any from typing import Iterable from narwhals import dtypes from narwhals._arrow.expr import ArrowExpr +from narwhals._arrow.series import ArrowSeries from narwhals.utils import flatten if TYPE_CHECKING: @@ -12,7 +14,6 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.series import ArrowSeries class ArrowNamespace: @@ -66,6 +67,12 @@ def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr: output_names=None, ) + def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries: + return ArrowSeries.from_iterable( + [value], + name=series.name, + ) + # --- not in spec --- def __init__(self) -> None: ... diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 0044e3dae..4fdafb971 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -2,8 +2,10 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from narwhals._arrow.utils import translate_dtype +from narwhals._pandas_like.utils import native_series_from_iterable from narwhals.dependencies import get_pyarrow_compute if TYPE_CHECKING: @@ -29,6 +31,15 @@ def _from_series(self, series: Any) -> Self: name=self._name, ) + @classmethod + def from_iterable(cls: type[Self], data: Iterable[Any], name: str) -> Self: + return cls( + native_series_from_iterable( + data, name=name, index=None, implementation="arrow" + ), + name=name, + ) + def __len__(self) -> int: return len(self._series) @@ -62,6 +73,14 @@ def cum_sum(self) -> Self: pc = get_pyarrow_compute() return self._from_series(pc.cumulative_sum(self._series)) + def any(self) -> bool: + pc = get_pyarrow_compute() + return pc.any(self._series) # type: ignore[no-any-return] + + def all(self) -> bool: + pc = get_pyarrow_compute() + return pc.all(self._series) # type: ignore[no-any-return] + @property def shape(self) -> tuple[int]: return (len(self._series),) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 1ff880e57..dc1d68930 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -10,6 +10,7 @@ from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas +from narwhals.dependencies import get_pyarrow from narwhals.utils import flatten from narwhals.utils import isinstance_or_issubclass from narwhals.utils import parse_version @@ -365,6 +366,9 @@ def native_series_from_iterable( mpd = get_modin() return mpd.Series(data, name=name, index=index) + if implementation == "arrow": + pa = get_pyarrow() + return pa.chunked_array([data]) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover diff --git a/tests/expr/any_all_test.py b/tests/expr/any_all_test.py index 05aef9728..c5bb88cb5 100644 --- a/tests/expr/any_all_test.py +++ b/tests/expr/any_all_test.py @@ -1,12 +1,20 @@ from typing import Any +import pyarrow as pa +import pytest + import narwhals as nw +from narwhals.utils import parse_version from tests.utils import compare_dicts -def test_any_all(constructor: Any) -> None: +def test_any_all(constructor_with_pyarrow: Any, request: Any) -> None: + if "table" in str(constructor_with_pyarrow) and parse_version( + pa.__version__ + ) < parse_version("12.0.0"): # pragma: no cover + request.applymarker(pytest.mark.xfail) df = nw.from_native( - constructor( + constructor_with_pyarrow( { "a": [True, False, True], "b": [True, True, True], @@ -14,9 +22,9 @@ def test_any_all(constructor: Any) -> None: } ) ) - result = nw.to_native(df.select(nw.all().all())) + result = df.select(nw.all().all()) expected = {"a": [False], "b": [True], "c": [False]} compare_dicts(result, expected) - result = nw.to_native(df.select(nw.all().any())) + result = df.select(nw.all().any()) expected = {"a": [True], "b": [True], "c": [False]} compare_dicts(result, expected)