feat: SparkLikeNamespace methods (#1779)
---------

Co-authored-by: Marco Edward Gorelli <[email protected]>
FBruzzesi and MarcoGorelli authored Jan 10, 2025
1 parent 50b3a40 commit 8229282
Showing 16 changed files with 280 additions and 87 deletions.
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
@@ -322,8 +322,8 @@ def concat_str(
self,
exprs: Iterable[IntoDaskExpr],
*more_exprs: IntoDaskExpr,
-separator: str = "",
-ignore_nulls: bool = False,
+separator: str,
+ignore_nulls: bool,
) -> DaskExpr:
parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/namespace.py
@@ -380,8 +380,8 @@ def concat_str(
self,
exprs: Iterable[IntoPandasLikeExpr],
*more_exprs: IntoPandasLikeExpr,
-separator: str = "",
-ignore_nulls: bool = False,
+separator: str,
+ignore_nulls: bool,
) -> PandasLikeExpr:
parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
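
A note on the two signature changes above (Dask and pandas-like namespaces): the keyword defaults for separator and ignore_nulls are dropped at the compliant (backend) level, so callers must now pass them explicitly; the user-facing defaults presumably remain on the public narwhals.concat_str, which forwards them. A minimal usage sketch, assuming a pandas-backed frame:

```python
import pandas as pd

import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": ["dogs", "cats"], "b": ["play", None]}))
result = df.select(
    nw.concat_str(
        [nw.col("a"), nw.col("b")],
        separator=" ",
        ignore_nulls=False,  # with False, a null in any input nulls that row
    ).alias("sentence")
)
print(result.to_native())
```
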
14 changes: 14 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -481,6 +481,20 @@ def skew(self) -> Self:

return self._from_call(F.skewness, "skew", returns_scalar=True)

def n_unique(self: Self) -> Self:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))

return self._from_call(_n_unique, "n_unique", returns_scalar=True)

def is_null(self: Self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar)

@property
def str(self: Self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)
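
The n_unique addition works around a semantic difference: Spark's count_distinct ignores nulls, while Polars' n_unique counts the null group as one extra unique value, so max(isnull(col)) adds 1 exactly when at least one null is present. A small sketch of the idea, assuming a local SparkSession:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (1,), (2,), (None,)], schema="a int")

df.select(
    F.count_distinct("a").alias("distinct_non_null"),  # 2: nulls are ignored
    (
        F.count_distinct("a") + F.max(F.isnull("a").cast(IntegerType()))
    ).alias("n_unique"),  # 3: the null group counts as one extra value
).show()
```
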
6 changes: 1 addition & 5 deletions narwhals/_spark_like/group_by.py
@@ -128,11 +128,7 @@ def agg_pyspark(
if expr._output_names is None: # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

-function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
-expr._function_name, expr._function_name
-)
-agg_func = get_spark_function(function_name, **expr._kwargs)
+agg_func = get_spark_function(expr._function_name, **expr._kwargs)
simple_aggregations.update(
{output_name: agg_func(keys[0]) for output_name in expr._output_names}
)
260 changes: 239 additions & 21 deletions narwhals/_spark_like/namespace.py
@@ -3,18 +3,21 @@
import operator
from functools import reduce
from typing import TYPE_CHECKING
from typing import Iterable
from typing import Literal

from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_exprs
from narwhals._expression_parsing import reduce_output_names
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.utils import get_column_name
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import DataFrame

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import IntoSparkLikeExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
@@ -43,26 +43,6 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

-def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
-parsed_exprs = parse_into_exprs(*exprs, namespace=self)
-
-def func(df: SparkLikeLazyFrame) -> list[Column]:
-cols = [c for _expr in parsed_exprs for c in _expr(df)]
-col_name = get_column_name(df, cols[0])
-return [reduce(operator.and_, cols).alias(col_name)]
-
-return SparkLikeExpr( # type: ignore[abstract]
-call=func,
-depth=max(x._depth for x in parsed_exprs) + 1,
-function_name="all_horizontal",
-root_names=combine_root_names(parsed_exprs),
-output_names=reduce_output_names(parsed_exprs),
-returns_scalar=False,
-backend_version=self._backend_version,
-version=self._version,
-kwargs={"exprs": exprs},
-)

def col(self, *column_names: str) -> SparkLikeExpr:
return SparkLikeExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
@@ -90,6 +73,64 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

def len(self) -> SparkLikeExpr:
def func(_: SparkLikeLazyFrame) -> list[Column]:
import pyspark.sql.functions as F # noqa: N812

return [F.count("*").alias("len")]

return SparkLikeExpr( # type: ignore[abstract]
func,
depth=0,
function_name="len",
root_names=None,
output_names=["len"],
returns_scalar=True,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
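
The len() expression maps onto F.count("*"), which counts rows regardless of nulls in any individual column. A tiny sketch, assuming a local SparkSession:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, None), (None, None)], schema="a int, b int")

df.select(F.count("*").alias("len")).show()  # len == 2, nulls do not reduce the count
```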

def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.and_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="all_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.or_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="any_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)
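
all_horizontal and any_horizontal differ only in the operator folded over the input columns (operator.and_ versus operator.or_); in the implementation above the result is aliased to the first column's name. A minimal sketch of the fold, assuming a local SparkSession (the aliases here are illustrative):

```python
import operator
from functools import reduce

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(False, False), (False, True), (True, True)], schema="a boolean, b boolean"
)

cols = [df["a"], df["b"]]
df.select(
    reduce(operator.and_, cols).alias("all"),  # False, False, True
    reduce(operator.or_, cols).alias("any"),   # False, True,  True
).show()
```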

def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

@@ -116,3 +157,180 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
version=self._version,
kwargs={"exprs": exprs},
)

def mean_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import IntegerType

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [
(
reduce(operator.add, (F.coalesce(col, F.lit(0)) for col in cols))
/ reduce(
operator.add,
(col.isNotNull().cast(IntegerType()) for col in cols),
)
).alias(col_name)
]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="mean_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)
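
mean_horizontal builds a null-aware row mean: nulls contribute 0 to the numerator (via coalesce) and are excluded from the per-row count in the denominator. A worked sketch, assuming a local SparkSession:

```python
import operator
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, 4), (3, None)], schema="a int, b int")

cols = [df["a"], df["b"]]
numerator = reduce(operator.add, (F.coalesce(c, F.lit(0)) for c in cols))
denominator = reduce(operator.add, (c.isNotNull().cast(IntegerType()) for c in cols))

df.select((numerator / denominator).alias("mean")).show()  # 2.5, then 3.0
```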

def max_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [F.greatest(*cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="max_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def min_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [F.least(*cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="min_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)
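
max_horizontal and min_horizontal map directly onto Spark's greatest and least, which skip nulls and return null only when every input is null. A quick sketch, assuming a local SparkSession:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, 8), (7, None)], schema="a int, b int")

df.select(
    F.greatest("a", "b").alias("max"),  # 8, then 7 (the null is skipped)
    F.least("a", "b").alias("min"),     # 1, then 7
).show()
```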

def concat(
self,
items: Iterable[SparkLikeLazyFrame],
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> SparkLikeLazyFrame:
dfs: list[DataFrame] = [item._native_frame for item in items]
if how == "horizontal":
msg = (
"Horizontal concatenation is not supported for LazyFrame backed by "
"a PySpark DataFrame."
)
raise NotImplementedError(msg)

if how == "vertical":
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)

return SparkLikeLazyFrame(
native_dataframe=reduce(lambda x, y: x.union(y), dfs),
backend_version=self._backend_version,
version=self._version,
)

if how == "diagonal":
return SparkLikeLazyFrame(
native_dataframe=reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
),
backend_version=self._backend_version,
version=self._version,
)
raise NotImplementedError
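
In concat, "vertical" delegates to union (position-based, after checking that column names match exactly), "diagonal" delegates to unionByName with allowMissingColumns=True so frames are aligned by name and absent columns are filled with nulls, and "horizontal" is rejected. A short sketch of the diagonal case, assuming a local SparkSession:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df1 = spark.createDataFrame([(1, "x")], schema="a int, b string")
df2 = spark.createDataFrame([(2,)], schema="a int")

# Diagonal concatenation: columns are aligned by name, missing ones become null.
df1.unionByName(df2, allowMissingColumns=True).show()
```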

def concat_str(
self,
exprs: Iterable[IntoSparkLikeExpr],
*more_exprs: IntoSparkLikeExpr,
separator: str,
ignore_nulls: bool,
) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import StringType

parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
*parse_into_exprs(*more_exprs, namespace=self),
]

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = (s.cast(StringType()) for _expr in parsed_exprs for s in _expr(df))
null_mask = [F.isnull(s) for _expr in parsed_exprs for s in _expr(df)]

if not ignore_nulls:
null_mask_result = reduce(lambda x, y: x | y, null_mask)
result = F.when(
~null_mask_result,
reduce(lambda x, y: F.format_string(f"%s{separator}%s", x, y), cols),
).otherwise(F.lit(None))
else:
init_value, *values = [
F.when(~nm, col).otherwise(F.lit(""))
for col, nm in zip(cols, null_mask)
]

separators = (
F.when(nm, F.lit("")).otherwise(F.lit(separator))
for nm in null_mask[:-1]
)
result = reduce(
lambda x, y: F.format_string("%s%s", x, y),
(F.format_string("%s%s", s, v) for s, v in zip(separators, values)),
init_value,
)

return [result]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="concat_str",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={
"exprs": exprs,
"more_exprs": more_exprs,
"separator": separator,
"ignore_nulls": ignore_nulls,
},
)
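
The concat_str implementation casts every input to string and builds a per-column null mask: with ignore_nulls=False any null input nulls the entire result, while with ignore_nulls=True null inputs and their separators are masked out before the fold over format_string. A behavioural sketch of the two modes, assuming a local SparkSession (concat_ws is shown only as a reference point for null-skipping, not as the implementation used above):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("dogs", "play"), ("cats", None)], schema="a string, b string"
)

df.select(
    # ignore_nulls=False: any null input makes the whole result null
    F.when(
        ~(F.isnull("a") | F.isnull("b")),
        F.format_string("%s %s", F.col("a"), F.col("b")),
    ).alias("strict"),
    # ignore_nulls=True: nulls are skipped; concat_ws shows the same null-skipping behaviour
    F.concat_ws(" ", F.col("a"), F.col("b")).alias("lenient"),
).show()
# strict: ["dogs play", null]      lenient: ["dogs play", "cats"]
```
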
10 changes: 2 additions & 8 deletions tests/expr_and_series/any_horizontal_test.py
@@ -11,11 +11,7 @@

@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
-def test_anyh(
-request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
-) -> None:
-if "pyspark" in str(constructor):
-request.applymarker(pytest.mark.xfail)
+def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
@@ -27,9 +23,7 @@ def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
assert_equal_data(result, expected)


-def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
-if "pyspark" in str(constructor):
-request.applymarker(pytest.mark.xfail)
+def test_anyh_all(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
2 changes: 1 addition & 1 deletion tests/expr_and_series/concat_str_test.py
@@ -27,7 +27,7 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (