Skip to content

Commit

Permalink
feat: add support for SparkLikeNamespace.when (#1805)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jan 13, 2025
1 parent e25e5e6 commit f769897
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 23 deletions.
102 changes: 102 additions & 0 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import operator
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal

from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_expr
from narwhals._expression_parsing import parse_into_exprs
from narwhals._expression_parsing import reduce_output_names
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
Expand Down Expand Up @@ -334,3 +336,103 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
"ignore_nulls": ignore_nulls,
},
)

def when(self, *predicates: IntoSparkLikeExpr) -> SparkLikeWhen:
    """Start a when/then/otherwise chain from one or more predicates.

    All ``predicates`` are combined into a single condition with
    ``all_horizontal`` on a new namespace object created with this
    namespace's backend version and version.

    Returns:
        A ``SparkLikeWhen`` holding the combined condition, created with
        ``returns_scalar=False``.
    """
    namespace = self.__class__(
        backend_version=self._backend_version, version=self._version
    )
    return SparkLikeWhen(
        namespace.all_horizontal(*predicates),
        self._backend_version,
        returns_scalar=False,
        version=self._version,
    )


class SparkLikeWhen:
    """Callable that builds a PySpark ``when(...).otherwise(...)`` column.

    An instance stores a condition plus the optional "then" and "otherwise"
    values. Calling it with a lazy frame evaluates all three against that
    frame and returns the resulting column in a one-element list.
    """

    def __init__(
        self,
        condition: SparkLikeExpr,
        backend_version: tuple[int, ...],
        then_value: Any | None = None,
        otherwise_value: Any | None = None,
        *,
        returns_scalar: bool,
        version: Version,
    ) -> None:
        """Store the condition, the branch values and the backend metadata.

        Args:
            condition: Expression deciding which branch applies.
            backend_version: Backend version tuple, forwarded to the
                expression created by ``then``.
            then_value: Value used where the condition holds; an expression,
                something convertible to one, or a scalar.
            otherwise_value: Value used elsewhere; same accepted forms as
                ``then_value``.
            returns_scalar: Forwarded to the expression created by ``then``.
            version: Forwarded to the expression created by ``then``.
        """
        self._backend_version = backend_version
        self._condition = condition
        self._then_value = then_value
        self._otherwise_value = otherwise_value
        self._returns_scalar = returns_scalar
        self._version = version

    def __call__(self, df: SparkLikeLazyFrame) -> list[Column]:
        """Evaluate the stored condition and branch values against ``df``.

        Returns:
            A single-element list with the ``when(...).otherwise(...)``
            column. It is aliased to the column name of the "then" value, or
            to ``"literal"`` when the "then" value is a scalar.
        """
        from pyspark.sql import functions as F  # noqa: N812

        plx = df.__narwhals_namespace__()
        condition = parse_into_expr(self._condition, namespace=plx)(df)[0]

        # Only the conversion is guarded: a TypeError raised while evaluating
        # a valid expression must propagate instead of being mistaken for
        # "this value is a scalar".
        try:
            then_expr = parse_into_expr(self._then_value, namespace=plx)
        except TypeError:
            # `self._then_value` is a scalar and can't be converted to an expression
            value_ = F.lit(self._then_value)
            col_name = "literal"
        else:
            value_ = then_expr(df)[0]
            col_name = get_column_name(df, value_)

        try:
            otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
        except TypeError:
            # `self._otherwise_value` is a scalar and can't be converted to an expression
            other_ = F.lit(self._otherwise_value)
        else:
            other_ = otherwise_expr(df)[0]

        return [
            F.when(condition=condition, value=value_)
            .otherwise(value=other_)
            .alias(col_name)
        ]

    def then(self, value: SparkLikeExpr | Any) -> SparkLikeThen:
        """Return an expression yielding ``value`` where the condition holds.

        The value is stored on a fresh ``SparkLikeWhen`` rather than on
        ``self``, so calling ``then`` several times on the same object gives
        independent expressions that do not overwrite each other.
        """
        when = self.__class__(
            self._condition,
            self._backend_version,
            then_value=value,
            otherwise_value=self._otherwise_value,
            returns_scalar=self._returns_scalar,
            version=self._version,
        )

        return SparkLikeThen(  # type: ignore[abstract]
            when,
            depth=0,
            function_name="whenthen",
            root_names=None,
            output_names=None,
            returns_scalar=self._returns_scalar,
            backend_version=self._backend_version,
            version=self._version,
            kwargs={"value": value},
        )


class SparkLikeThen(SparkLikeExpr):
    """Expression produced by ``SparkLikeWhen.then``.

    Its ``_call`` attribute is the ``SparkLikeWhen`` object itself, which
    ``otherwise`` updates in place.
    """

    def __init__(
        self,
        call: SparkLikeWhen,
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
        returns_scalar: bool,
        backend_version: tuple[int, ...],
        version: Version,
        kwargs: dict[str, Any],
    ) -> None:
        """Record the when-callable together with the expression metadata."""
        # The callable that produces the column(s) for a given frame.
        self._call = call

        # Expression bookkeeping.
        self._depth = depth
        self._function_name = function_name
        self._root_names = root_names
        self._output_names = output_names
        self._returns_scalar = returns_scalar
        self._kwargs = kwargs

        # Backend metadata.
        self._backend_version = backend_version
        self._version = version

    def otherwise(self, value: SparkLikeExpr | Any) -> SparkLikeExpr:
        """Set the fallback value on the underlying when-object.

        Updates this expression in place (including renaming its function
        to ``"whenotherwise"``) and returns it.
        """
        # The base class only declares `_call` as a `Callable`, while here it
        # is a `SparkLikeWhen` instance, hence the type ignore on the
        # attribute assignment.
        self._call._otherwise_value = value  # type: ignore[attr-defined]
        self._function_name = "whenotherwise"
        return self
30 changes: 7 additions & 23 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
}


def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
Expand All @@ -28,9 +26,7 @@ def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
assert_equal_data(result, expected)


def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when_otherwise(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
expected = {
Expand All @@ -39,11 +35,7 @@ def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest
assert_equal_data(result, expected)


def test_multiple_conditions(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_multiple_conditions(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
Expand Down Expand Up @@ -85,11 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_value_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_value_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when"))
expected = {
Expand Down Expand Up @@ -122,11 +110,7 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_otherwise_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_otherwise_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when")
Expand All @@ -140,7 +124,7 @@ def test_otherwise_expression(
def test_when_then_otherwise_into_expr(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e"))
Expand All @@ -151,7 +135,7 @@ def test_when_then_otherwise_into_expr(
def test_when_then_otherwise_lit_str(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z")))
Expand Down

0 comments on commit f769897

Please sign in to comment.