From 0c98b60cf32e13e61613f78f18dd94e51e9d0a6d Mon Sep 17 00:00:00 2001
From: Lucas Nelson <lucas.nelson.contacts@gmail.com>
Date: Fri, 10 Jan 2025 01:45:22 -0600
Subject: [PATCH] feat: add `all`, `any` and `null_count` Spark Expressions
 (#1724)

* test: add logical tests, import ConstructorEager type

* feat: add any_horizontal method

* test: add any_horizontal test, update any_all reference

* dev: correct bool_any to bool_or

* feat: add null_count expr

* test: add tests for null_count expr

* tests: update constructor to pyspark_constructor

* tests: remove eager tests

* feat: initial draft of replace_strict method

* feat: initial draft of replace_strict method

* test: add lazy tests for replace_strict method

* Update expr.py

Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>

* Update expr.py

Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>

* remove replace_strict method

* remove replace_strict tests

* remove any_h references

* remove any_h references

* pyspark test

---------

Co-authored-by: lucas-nelson-uiuc <lucas.nelson.uiuc@gmail.com>
Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
Co-authored-by: FBruzzesi <francesco.bruzzesi.93@gmail.com>
---
 narwhals/_spark_like/expr.py             | 18 ++++++++++++++++++
 tests/expr_and_series/any_all_test.py    |  9 ++-------
 tests/expr_and_series/null_count_test.py |  4 ++--
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index a8cafccfd..32139cf01 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -273,6 +273,16 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]:
             kwargs={**self._kwargs, "name": name},
         )
 
+    def all(self) -> Self:
+        from pyspark.sql import functions as F  # noqa: N812
+        # Aggregate with Spark's `bool_and`; the result is flagged as a scalar.
+        return self._from_call(F.bool_and, "all", returns_scalar=True)
+
+    def any(self) -> Self:
+        from pyspark.sql import functions as F  # noqa: N812
+        # Aggregate with Spark's `bool_or`; the result is flagged as a scalar.
+        return self._from_call(F.bool_or, "any", returns_scalar=True)
+
     def count(self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812
 
@@ -306,6 +316,14 @@ def min(self) -> Self:
 
         return self._from_call(F.min, "min", returns_scalar=True)
 
+    def null_count(self) -> Self:
+        def _null_count(_input: Column) -> Column:
+            from pyspark.sql import functions as F  # noqa: N812
+            # Count the rows for which `isnull` holds on the input column.
+            return F.count_if(F.isnull(_input))
+
+        return self._from_call(_null_count, "null_count", returns_scalar=True)
+
     def sum(self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812
 
diff --git a/tests/expr_and_series/any_all_test.py b/tests/expr_and_series/any_all_test.py
index 7fd81f04d..e8554316e 100644
--- a/tests/expr_and_series/any_all_test.py
+++ b/tests/expr_and_series/any_all_test.py
@@ -1,17 +1,12 @@
 from __future__ import annotations
 
-import pytest
-
 import narwhals.stable.v1 as nw
 from tests.utils import Constructor
 from tests.utils import ConstructorEager
 from tests.utils import assert_equal_data
 
 
-def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
+def test_any_all(constructor: Constructor) -> None:
     df = nw.from_native(
         constructor(
             {
@@ -24,7 +19,7 @@ def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> No
     result = df.select(nw.col("a", "b", "c").all())
     expected = {"a": [False], "b": [True], "c": [False]}
     assert_equal_data(result, expected)
-    result = df.select(nw.all().any())
+    result = df.select(nw.col("a", "b", "c").any())
     expected = {"a": [True], "b": [True], "c": [False]}
     assert_equal_data(result, expected)
 
diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py
index 3bd15c66c..a49fd79c8 100644
--- a/tests/expr_and_series/null_count_test.py
+++ b/tests/expr_and_series/null_count_test.py
@@ -16,10 +16,10 @@
 def test_null_count_expr(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data))
-    result = df.select(nw.all().null_count())
+    result = df.select(nw.col("a", "b").null_count())
     expected = {
         "a": [2],
         "b": [1],