feat: add missing dunder methods in SparkLikeExpr and `SparkLikeNamespace.lit` (#1708)
EdAbati authored Jan 9, 2025
1 parent deee14c commit 0f38521
Showing 4 changed files with 230 additions and 8 deletions.
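
For orientation, here is a short usage sketch (not part of the diff) of what the newly added dunder methods and `SparkLikeNamespace.lit` enable through the public narwhals API on a PySpark DataFrame. It assumes a local `SparkSession` and mirrors the values used in the tests added below.

```python
# Hedged sketch: exercise the new SparkLikeExpr dunders and SparkLikeNamespace.lit
# through the public narwhals API.  Assumes a local SparkSession is available.
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native_df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

df = nw.from_native(native_df)
result = df.select(
    (nw.col("a") + 1).alias("add"),        # SparkLikeExpr.__add__
    (nw.col("a") // 2).alias("floordiv"),  # SparkLikeExpr.__floordiv__
    (nw.col("a") >= 2).alias("ge"),        # SparkLikeExpr.__ge__
    nw.lit(2).alias("lit"),                # SparkLikeNamespace.lit
)
result.to_native().show()
```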
2 changes: 1 addition & 1 deletion narwhals/_spark_like/dataframe.py
@@ -50,7 +50,7 @@ def __native_namespace__(self) -> Any:  # pragma: no cover
     def __narwhals_namespace__(self) -> SparkLikeNamespace:
         from narwhals._spark_like.namespace import SparkLikeNamespace
 
-        return SparkLikeNamespace(  # type: ignore[abstract]
+        return SparkLikeNamespace(
             backend_version=self._backend_version, version=self._version
         )
 
103 changes: 96 additions & 7 deletions narwhals/_spark_like/expr.py
@@ -59,7 +59,7 @@ def __narwhals_namespace__(self) -> SparkLikeNamespace:  # pragma: no cover
         # Unused, just for compatibility with PandasLikeExpr
         from narwhals._spark_like.namespace import SparkLikeNamespace
 
-        return SparkLikeNamespace(  # type: ignore[abstract]
+        return SparkLikeNamespace(
             backend_version=self._backend_version, version=self._version
         )

@@ -123,32 +123,82 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
 
     def __add__(self, other: SparkLikeExpr) -> Self:
         return self._from_call(
-            lambda _input, other: _input + other,
+            lambda _input, other: _input.__add__(other),
             "__add__",
             other=other,
             returns_scalar=False,
         )
 
     def __sub__(self, other: SparkLikeExpr) -> Self:
         return self._from_call(
-            lambda _input, other: _input - other,
+            lambda _input, other: _input.__sub__(other),
             "__sub__",
             other=other,
             returns_scalar=False,
         )
 
     def __mul__(self, other: SparkLikeExpr) -> Self:
         return self._from_call(
-            lambda _input, other: _input * other,
+            lambda _input, other: _input.__mul__(other),
             "__mul__",
             other=other,
             returns_scalar=False,
         )
 
-    def __lt__(self, other: SparkLikeExpr) -> Self:
+    def __truediv__(self, other: SparkLikeExpr) -> Self:
         return self._from_call(
-            lambda _input, other: _input < other,
-            "__lt__",
+            lambda _input, other: _input.__truediv__(other),
+            "__truediv__",
             other=other,
             returns_scalar=False,
         )
 
+    def __floordiv__(self, other: SparkLikeExpr) -> Self:
+        def _floordiv(_input: Column, other: Column) -> Column:
+            from pyspark.sql import functions as F  # noqa: N812
+
+            return F.floor(_input / other)
+
+        return self._from_call(
+            _floordiv, "__floordiv__", other=other, returns_scalar=False
+        )
+
+    def __pow__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__pow__(other),
+            "__pow__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __mod__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__mod__(other),
+            "__mod__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __eq__(self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
+        return self._from_call(
+            lambda _input, other: _input.__eq__(other),
+            "__eq__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __ne__(self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
+        return self._from_call(
+            lambda _input, other: _input.__ne__(other),
+            "__ne__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __ge__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__ge__(other),
+            "__ge__",
+            other=other,
+            returns_scalar=False,
+        )
@@ -161,6 +211,45 @@ def __gt__(self, other: SparkLikeExpr) -> Self:
             returns_scalar=False,
         )
 
+    def __le__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__le__(other),
+            "__le__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __lt__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__lt__(other),
+            "__lt__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __and__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__and__(other),
+            "__and__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __or__(self, other: SparkLikeExpr) -> Self:
+        return self._from_call(
+            lambda _input, other: _input.__or__(other),
+            "__or__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __invert__(self) -> Self:
+        return self._from_call(
+            lambda _input: _input.__invert__(),
+            "__invert__",
+            returns_scalar=self._returns_scalar,
+        )
+
     def abs(self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812
 
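Side note on `__floordiv__` above: unlike the other operators it does not delegate to a `Column` dunder, because Spark's `/` performs true (floating-point) division, so the result is wrapped in `F.floor`. A minimal standalone sketch of the same idea, assuming a local SparkSession (values mirror the `("__floordiv__", 2, [0, 1, 1])` test case added below):

```python
# Minimal sketch (outside the narwhals codebase) of the floor-division approach:
# PySpark's `/` is true division, so floor division is floor(_input / other).
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])
df.select(F.floor(F.col("a") / F.lit(2)).alias("a_floordiv_2")).show()
# expected rows: 0, 1, 1
```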
23 changes: 23 additions & 0 deletions narwhals/_spark_like/namespace.py
@@ -16,6 +16,7 @@
 
     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
     from narwhals._spark_like.typing import IntoSparkLikeExpr
+    from narwhals.dtypes import DType
     from narwhals.utils import Version
 

@@ -67,6 +68,28 @@ def col(self, *column_names: str) -> SparkLikeExpr:
             *column_names, backend_version=self._backend_version, version=self._version
         )
 
+    def lit(self, value: object, dtype: DType | None) -> SparkLikeExpr:
+        if dtype is not None:
+            msg = "todo"
+            raise NotImplementedError(msg)
+
+        def _lit(_: SparkLikeLazyFrame) -> list[Column]:
+            import pyspark.sql.functions as F  # noqa: N812
+
+            return [F.lit(value).alias("literal")]
+
+        return SparkLikeExpr(  # type: ignore[abstract]
+            call=_lit,
+            depth=0,
+            function_name="lit",
+            root_names=None,
+            output_names=["literal"],
+            returns_scalar=True,
+            backend_version=self._backend_version,
+            version=self._version,
+            kwargs={},
+        )
+
     def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
         parsed_exprs = parse_into_exprs(*exprs, namespace=self)
 
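For context, a hedged usage sketch of the new `lit` through the public API (it mirrors `test_lit` below; passing a `dtype`, e.g. `nw.lit(2, nw.Float32)`, currently raises `NotImplementedError` on this backend, as the `todo` above indicates):

```python
# Hedged usage sketch of SparkLikeNamespace.lit via the public narwhals API.
# Assumes a local SparkSession; dtype=None is the only case supported here.
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native_df = spark.createDataFrame([(1,), (3,), (2,)], ["a"])

df = nw.from_native(native_df).lazy()
result = df.with_columns(nw.lit(2).alias("lit"))  # broadcast literal column
print(result.to_native().columns)  # ['a', 'lit']
```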
110 changes: 110 additions & 0 deletions tests/spark_like_test.py
@@ -21,6 +21,7 @@
 if TYPE_CHECKING:
     from pyspark.sql import SparkSession
 
+    from narwhals.dtypes import DType
     from narwhals.typing import IntoFrame
     from tests.utils import Constructor
 
@@ -954,6 +955,53 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None:
     assert_equal_data(result, expected)
 
 
+# copied from tests/expr_and_series/arithmetic_test.py
+@pytest.mark.parametrize(
+    ("attr", "rhs", "expected"),
+    [
+        ("__add__", 1, [2, 3, 4]),
+        ("__sub__", 1, [0, 1, 2]),
+        ("__mul__", 2, [2, 4, 6]),
+        ("__truediv__", 2.0, [0.5, 1.0, 1.5]),
+        ("__truediv__", 1, [1, 2, 3]),
+        ("__floordiv__", 2, [0, 1, 1]),
+        ("__mod__", 2, [1, 0, 1]),
+        ("__pow__", 2, [1, 4, 9]),
+    ],
+)
+def test_arithmetic_expr(
+    attr: str, rhs: Any, expected: list[Any], pyspark_constructor: Constructor
+) -> None:
+    data = {"a": [1.0, 2, 3]}
+    df = nw.from_native(pyspark_constructor(data))
+    result = df.select(getattr(nw.col("a"), attr)(rhs))
+    assert_equal_data(result, {"a": expected})
+
+
+@pytest.mark.parametrize(
+    ("attr", "rhs", "expected"),
+    [
+        ("__radd__", 1, [2, 3, 4]),
+        ("__rsub__", 1, [0, -1, -2]),
+        ("__rmul__", 2, [2, 4, 6]),
+        ("__rtruediv__", 2.0, [2, 1, 2 / 3]),
+        ("__rfloordiv__", 2, [2, 1, 0]),
+        ("__rmod__", 2, [0, 0, 2]),
+        ("__rpow__", 2, [2, 4, 8]),
+    ],
+)
+def test_right_arithmetic_expr(
+    attr: str,
+    rhs: Any,
+    expected: list[Any],
+    pyspark_constructor: Constructor,
+) -> None:
+    data = {"a": [1, 2, 3]}
+    df = nw.from_native(pyspark_constructor(data))
+    result = df.select(getattr(nw.col("a"), attr)(rhs))
+    assert_equal_data(result, {"literal": expected})
+
+
 # Copied from tests/expr_and_series/median_test.py
 def test_median(pyspark_constructor: Constructor) -> None:
     data = {"a": [3, 8, 2, None], "b": [5, 5, None, 7], "z": [7.0, 8, 9, None]}
@@ -1099,3 +1147,65 @@ def test_skew(
     df = nw.from_native(pyspark_constructor({"a": data}))
     result = df.select(skew=nw.col("a").skew())
     assert_equal_data(result, {"skew": [expected]})
+
+
+# copied from tests/expr_and_series/lit_test.py
+@pytest.mark.parametrize(
+    ("dtype", "expected_lit"),
+    [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])],
+)
+def test_lit(
+    pyspark_constructor: Constructor,
+    dtype: DType | None,
+    expected_lit: list[Any],
+    request: pytest.FixtureRequest,
+) -> None:
+    if dtype is not None:
+        request.applymarker(pytest.mark.xfail)
+    data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
+    df_raw = pyspark_constructor(data)
+    df = nw.from_native(df_raw).lazy()
+    result = df.with_columns(nw.lit(2, dtype).alias("lit"))
+    expected = {
+        "a": [1, 3, 2],
+        "b": [4, 4, 6],
+        "z": [7.0, 8.0, 9.0],
+        "lit": expected_lit,
+    }
+    assert_equal_data(result, expected)
+
+
+@pytest.mark.parametrize(
+    ("col_name", "expr", "expected_result"),
+    [
+        ("left_lit", nw.lit(1) + nw.col("a"), [2, 4, 3]),
+        ("right_lit", nw.col("a") + nw.lit(1), [2, 4, 3]),
+        ("left_lit_with_agg", nw.lit(1) + nw.col("a").mean(), [3]),
+        ("right_lit_with_agg", nw.col("a").mean() - nw.lit(1), [1]),
+        ("left_scalar", 1 + nw.col("a"), [2, 4, 3]),
+        ("right_scalar", nw.col("a") + 1, [2, 4, 3]),
+        ("left_scalar_with_agg", 1 + nw.col("a").mean(), [3]),
+        ("right_scalar_with_agg", nw.col("a").mean() - 1, [1]),
+    ],
+)
+def test_lit_operation(
+    pyspark_constructor: Constructor,
+    col_name: str,
+    expr: nw.Expr,
+    expected_result: list[int],
+    request: pytest.FixtureRequest,
+) -> None:
+    if col_name in (
+        "left_scalar_with_agg",
+        "left_lit_with_agg",
+        "right_lit",
+        "right_lit_with_agg",
+    ):
+        request.applymarker(pytest.mark.xfail)
+
+    data = {"a": [1, 3, 2]}
+    df_raw = pyspark_constructor(data)
+    df = nw.from_native(df_raw).lazy()
+    result = df.select(expr.alias(col_name))
+    expected = {col_name: expected_result}
+    assert_equal_data(result, expected)
