From 0f385212dc43aad770b3c75740a0dceaaf67d38f Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Thu, 9 Jan 2025 17:11:53 +0100 Subject: [PATCH] feat: add missing dunder methods in `SparkLikeExpr` and `SparkLikeNamespace.lit` (#1708) --- narwhals/_spark_like/dataframe.py | 2 +- narwhals/_spark_like/expr.py | 103 ++++++++++++++++++++++++++-- narwhals/_spark_like/namespace.py | 23 +++++++ tests/spark_like_test.py | 110 ++++++++++++++++++++++++++++++ 4 files changed, 230 insertions(+), 8 deletions(-) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index e04da7f57..e54a05997 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -50,7 +50,7 @@ def __native_namespace__(self) -> Any: # pragma: no cover def __narwhals_namespace__(self) -> SparkLikeNamespace: from narwhals._spark_like.namespace import SparkLikeNamespace - return SparkLikeNamespace( # type: ignore[abstract] + return SparkLikeNamespace( backend_version=self._backend_version, version=self._version ) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 03529ca96..10fb76227 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -59,7 +59,7 @@ def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover # Unused, just for compatibility with PandasLikeExpr from narwhals._spark_like.namespace import SparkLikeNamespace - return SparkLikeNamespace( # type: ignore[abstract] + return SparkLikeNamespace( backend_version=self._backend_version, version=self._version ) @@ -123,7 +123,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def __add__(self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: _input + other, + lambda _input, other: _input.__add__(other), "__add__", other=other, returns_scalar=False, @@ -131,7 +131,7 @@ def __add__(self, other: SparkLikeExpr) -> Self: def __sub__(self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: _input - other, + lambda _input, other: _input.__sub__(other), "__sub__", other=other, returns_scalar=False, @@ -139,16 +139,66 @@ def __sub__(self, other: SparkLikeExpr) -> Self: def __mul__(self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: _input * other, + lambda _input, other: _input.__mul__(other), "__mul__", other=other, returns_scalar=False, ) - def __lt__(self, other: SparkLikeExpr) -> Self: + def __truediv__(self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: _input < other, - "__lt__", + lambda _input, other: _input.__truediv__(other), + "__truediv__", + other=other, + returns_scalar=False, + ) + + def __floordiv__(self, other: SparkLikeExpr) -> Self: + def _floordiv(_input: Column, other: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.floor(_input / other) + + return self._from_call( + _floordiv, "__floordiv__", other=other, returns_scalar=False + ) + + def __pow__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__pow__(other), + "__pow__", + other=other, + returns_scalar=False, + ) + + def __mod__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__mod__(other), + "__mod__", + other=other, + returns_scalar=False, + ) + + def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override] + return self._from_call( + lambda _input, other: _input.__eq__(other), + "__eq__", + other=other, + returns_scalar=False, + ) + + def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override] + return self._from_call( + lambda _input, other: _input.__ne__(other), + "__ne__", + other=other, + returns_scalar=False, + ) + + def __ge__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__ge__(other), + "__ge__", other=other, returns_scalar=False, ) @@ -161,6 +211,45 @@ def __gt__(self, other: SparkLikeExpr) -> Self: returns_scalar=False, ) + def __le__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__le__(other), + "__le__", + other=other, + returns_scalar=False, + ) + + def __lt__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__lt__(other), + "__lt__", + other=other, + returns_scalar=False, + ) + + def __and__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__and__(other), + "__and__", + other=other, + returns_scalar=False, + ) + + def __or__(self, other: SparkLikeExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__or__(other), + "__or__", + other=other, + returns_scalar=False, + ) + + def __invert__(self) -> Self: + return self._from_call( + lambda _input: _input.__invert__(), + "__invert__", + returns_scalar=self._returns_scalar, + ) + def abs(self) -> Self: from pyspark.sql import functions as F # noqa: N812 diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index d34867b00..56cc4d271 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -16,6 +16,7 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.typing import IntoSparkLikeExpr + from narwhals.dtypes import DType from narwhals.utils import Version @@ -67,6 +68,28 @@ def col(self, *column_names: str) -> SparkLikeExpr: *column_names, backend_version=self._backend_version, version=self._version ) + def lit(self, value: object, dtype: DType | None) -> SparkLikeExpr: + if dtype is not None: + msg = "todo" + raise NotImplementedError(msg) + + def _lit(_: SparkLikeLazyFrame) -> list[Column]: + import pyspark.sql.functions as F # noqa: N812 + + return [F.lit(value).alias("literal")] + + return SparkLikeExpr( # type: ignore[abstract] + call=_lit, + depth=0, + function_name="lit", + root_names=None, + output_names=["literal"], + returns_scalar=True, + backend_version=self._backend_version, + version=self._version, + kwargs={}, + ) + def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 30610be45..f7cd9e6a9 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from pyspark.sql import SparkSession + from narwhals.dtypes import DType from narwhals.typing import IntoFrame from tests.utils import Constructor @@ -954,6 +955,53 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +# copied from tests/expr_and_series/arithmetic_test.py +@pytest.mark.parametrize( + ("attr", "rhs", "expected"), + [ + ("__add__", 1, [2, 3, 4]), + ("__sub__", 1, [0, 1, 2]), + ("__mul__", 2, [2, 4, 6]), + ("__truediv__", 2.0, [0.5, 1.0, 1.5]), + ("__truediv__", 1, [1, 2, 3]), + ("__floordiv__", 2, [0, 1, 1]), + ("__mod__", 2, [1, 0, 1]), + ("__pow__", 2, [1, 4, 9]), + ], +) +def test_arithmetic_expr( + attr: str, rhs: Any, expected: list[Any], pyspark_constructor: Constructor +) -> None: + data = {"a": [1.0, 2, 3]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(getattr(nw.col("a"), attr)(rhs)) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("attr", "rhs", "expected"), + [ + ("__radd__", 1, [2, 3, 4]), + ("__rsub__", 1, [0, -1, -2]), + ("__rmul__", 2, [2, 4, 6]), + ("__rtruediv__", 2.0, [2, 1, 2 / 3]), + ("__rfloordiv__", 2, [2, 1, 0]), + ("__rmod__", 2, [0, 0, 2]), + ("__rpow__", 2, [2, 4, 8]), + ], +) +def test_right_arithmetic_expr( + attr: str, + rhs: Any, + expected: list[Any], + pyspark_constructor: Constructor, +) -> None: + data = {"a": [1, 2, 3]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(getattr(nw.col("a"), attr)(rhs)) + assert_equal_data(result, {"literal": expected}) + + # Copied from tests/expr_and_series/median_test.py def test_median(pyspark_constructor: Constructor) -> None: data = {"a": [3, 8, 2, None], "b": [5, 5, None, 7], "z": [7.0, 8, 9, None]} @@ -1099,3 +1147,65 @@ def test_skew( df = nw.from_native(pyspark_constructor({"a": data})) result = df.select(skew=nw.col("a").skew()) assert_equal_data(result, {"skew": [expected]}) + + +# copied from tests/expr_and_series/list_test.py +@pytest.mark.parametrize( + ("dtype", "expected_lit"), + [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])], +) +def test_lit( + pyspark_constructor: Constructor, + dtype: DType | None, + expected_lit: list[Any], + request: pytest.FixtureRequest, +) -> None: + if dtype is not None: + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df_raw = pyspark_constructor(data) + df = nw.from_native(df_raw).lazy() + result = df.with_columns(nw.lit(2, dtype).alias("lit")) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "lit": expected_lit, + } + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("col_name", "expr", "expected_result"), + [ + ("left_lit", nw.lit(1) + nw.col("a"), [2, 4, 3]), + ("right_lit", nw.col("a") + nw.lit(1), [2, 4, 3]), + ("left_lit_with_agg", nw.lit(1) + nw.col("a").mean(), [3]), + ("right_lit_with_agg", nw.col("a").mean() - nw.lit(1), [1]), + ("left_scalar", 1 + nw.col("a"), [2, 4, 3]), + ("right_scalar", nw.col("a") + 1, [2, 4, 3]), + ("left_scalar_with_agg", 1 + nw.col("a").mean(), [3]), + ("right_scalar_with_agg", nw.col("a").mean() - 1, [1]), + ], +) +def test_lit_operation( + pyspark_constructor: Constructor, + col_name: str, + expr: nw.Expr, + expected_result: list[int], + request: pytest.FixtureRequest, +) -> None: + if col_name in ( + "left_scalar_with_agg", + "left_lit_with_agg", + "right_lit", + "right_lit_with_agg", + ): + request.applymarker(pytest.mark.xfail) + + data = {"a": [1, 3, 2]} + df_raw = pyspark_constructor(data) + df = nw.from_native(df_raw).lazy() + result = df.select(expr.alias(col_name)) + expected = {col_name: expected_result} + assert_equal_data(result, expected)