From 0c98b60cf32e13e61613f78f18dd94e51e9d0a6d Mon Sep 17 00:00:00 2001 From: Lucas Nelson <lucas.nelson.contacts@gmail.com> Date: Fri, 10 Jan 2025 01:45:22 -0600 Subject: [PATCH] feat: add `all`, `any` and `null_count` Spark Expressions (#1724) * test: add logical tests, import ConstructorEager type * feat: add any_horizontal method * test: add any_horizontal test, update any_all reference * dev: correct bool_any to bool_or * feat: add null_count expr * test: add tests for null_count expr * tests: update constructor to pyspark_constructor * tests: remove eager tests * feat: initial draft of replace_strict method * feat: initial draft of replace_strict method * test: add lazy tests for replace_strict method * Update expr.py Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> * Update expr.py Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> * remove replace_strict method * remove replace_strict tests * remove any_h references * remove any_h references * pyspark test --------- Co-authored-by: lucas-nelson-uiuc <lucas.nelson.uiuc@gmail.com> Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Co-authored-by: FBruzzesi <francesco.bruzzesi.93@gmail.com> --- narwhals/_spark_like/expr.py | 18 ++++++++++++++++++ tests/expr_and_series/any_all_test.py | 9 ++------- tests/expr_and_series/null_count_test.py | 4 ++-- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index a8cafccfd..32139cf01 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -273,6 +273,16 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: kwargs={**self._kwargs, "name": name}, ) + def all(self) -> Self: + from pyspark.sql import functions as F # noqa: N812 + + return self._from_call(F.bool_and, "all", returns_scalar=True) + + def any(self) -> Self: + from pyspark.sql import functions as F # noqa: N812 + + return self._from_call(F.bool_or, "any", returns_scalar=True) + def count(self) -> Self: from pyspark.sql import functions as F # noqa: N812 @@ -306,6 +316,14 @@ def min(self) -> Self: return self._from_call(F.min, "min", returns_scalar=True) + def null_count(self) -> Self: + def _null_count(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.count_if(F.isnull(_input)) + + return self._from_call(_null_count, "null_count", returns_scalar=True) + def sum(self) -> Self: from pyspark.sql import functions as F # noqa: N812 diff --git a/tests/expr_and_series/any_all_test.py b/tests/expr_and_series/any_all_test.py index 7fd81f04d..e8554316e 100644 --- a/tests/expr_and_series/any_all_test.py +++ b/tests/expr_and_series/any_all_test.py @@ -1,17 +1,12 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_any_all(constructor: Constructor) -> None: df = nw.from_native( constructor( { @@ -24,7 +19,7 @@ def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> No result = df.select(nw.col("a", "b", "c").all()) expected = {"a": [False], "b": [True], "c": [False]} assert_equal_data(result, expected) - result = df.select(nw.all().any()) + result = df.select(nw.col("a", "b", "c").any()) expected = {"a": [True], "b": [True], "c": [False]} assert_equal_data(result, expected) diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py index 3bd15c66c..a49fd79c8 100644 --- a/tests/expr_and_series/null_count_test.py +++ b/tests/expr_and_series/null_count_test.py @@ -16,10 +16,10 @@ def test_null_count_expr( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(nw.all().null_count()) + result = df.select(nw.col("a", "b").null_count()) expected = { "a": [2], "b": [1],