Skip to content

Commit

Permalink
feat: add all, any and null_count Spark Expressions (#1724)
Browse files Browse the repository at this point in the history
* test: add logical tests, import ConstructorEager type

* feat: add any_horizontal method

* test: add any_horizontal test, update any_all reference

* dev: correct bool_any to bool_or

* feat: add null_count expr

* test: add tests for null_count expr

* tests: update constructor to pyspark_constructor

* tests: remove eager tests

* feat: initial draft of replace_strict method

* feat: initial draft of replace_strict method

* test: add lazy tests for replace_strict method

* Update expr.py

Co-authored-by: Edoardo Abati <[email protected]>

* Update expr.py

Co-authored-by: Edoardo Abati <[email protected]>

* remove replace_strict method

* remove replace_strict tests

* remove any_h references

* remove any_h references

* pyspark test

---------

Co-authored-by: lucas-nelson-uiuc <[email protected]>
Co-authored-by: Edoardo Abati <[email protected]>
Co-authored-by: FBruzzesi <[email protected]>
  • Loading branch information
4 people authored Jan 10, 2025
1 parent 20eb53b commit 0c98b60
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
18 changes: 18 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]:
kwargs={**self._kwargs, "name": name},
)

def all(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.bool_and, "all", returns_scalar=True)

def any(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.bool_or, "any", returns_scalar=True)

def count(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

Expand Down Expand Up @@ -306,6 +316,14 @@ def min(self) -> Self:

return self._from_call(F.min, "min", returns_scalar=True)

def null_count(self) -> Self:
def _null_count(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812

return F.count_if(F.isnull(_input))

return self._from_call(_null_count, "null_count", returns_scalar=True)

def sum(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

Expand Down
9 changes: 2 additions & 7 deletions tests/expr_and_series/any_all_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_any_all(constructor: Constructor) -> None:
df = nw.from_native(
constructor(
{
Expand All @@ -24,7 +19,7 @@ def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> No
result = df.select(nw.col("a", "b", "c").all())
expected = {"a": [False], "b": [True], "c": [False]}
assert_equal_data(result, expected)
result = df.select(nw.all().any())
result = df.select(nw.col("a", "b", "c").any())
expected = {"a": [True], "b": [True], "c": [False]}
assert_equal_data(result, expected)

Expand Down
4 changes: 2 additions & 2 deletions tests/expr_and_series/null_count_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
def test_null_count_expr(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.all().null_count())
result = df.select(nw.col("a", "b").null_count())
expected = {
"a": [2],
"b": [1],
Expand Down

0 comments on commit 0c98b60

Please sign in to comment.