Skip to content

Commit

Permalink
feat: add SparkLikeStrNamespace methods (#1781)
Browse files Browse the repository at this point in the history
* feat: SparkLikeStrNamespace

* pyproject

---------

Co-authored-by: Marco Edward Gorelli <[email protected]>
  • Loading branch information
FBruzzesi and MarcoGorelli authored Jan 10, 2025
1 parent 07402c6 commit 50b3a40
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 66 deletions.
125 changes: 125 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,128 @@ def skew(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.skewness, "skew", returns_scalar=True)

@property
def str(self: Self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)


class SparkLikeExprStringNamespace:
def __init__(self: Self, expr: SparkLikeExpr) -> None:
self._compliant_expr = expr

def len_chars(self: Self) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
F.char_length,
"len",
returns_scalar=self._compliant_expr._returns_scalar,
)

def replace_all(
self: Self, pattern: str, value: str, *, literal: bool = False
) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

def func(_input: Column, pattern: str, value: str, *, literal: bool) -> Column:
replace_all_func = F.replace if literal else F.regexp_replace
return replace_all_func(_input, F.lit(pattern), F.lit(value))

return self._compliant_expr._from_call(
func,
"replace",
pattern=pattern,
value=value,
literal=literal,
returns_scalar=self._compliant_expr._returns_scalar,
)

def strip_chars(self: Self, characters: str | None) -> SparkLikeExpr:
import string

from pyspark.sql import functions as F # noqa: N812

def func(_input: Column, characters: str | None) -> Column:
to_remove = characters if characters is not None else string.whitespace
return F.btrim(_input, F.lit(to_remove))

return self._compliant_expr._from_call(
func,
"strip",
characters=characters,
returns_scalar=self._compliant_expr._returns_scalar,
)

def starts_with(self: Self, prefix: str) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
lambda _input, prefix: F.startswith(_input, F.lit(prefix)),
"starts_with",
prefix=prefix,
returns_scalar=self._compliant_expr._returns_scalar,
)

def ends_with(self: Self, suffix: str) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
lambda _input, suffix: F.endswith(_input, F.lit(suffix)),
"ends_with",
suffix=suffix,
returns_scalar=self._compliant_expr._returns_scalar,
)

def contains(self: Self, pattern: str, *, literal: bool) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

def func(_input: Column, pattern: str, *, literal: bool) -> Column:
contains_func = F.contains if literal else F.regexp
return contains_func(_input, F.lit(pattern))

return self._compliant_expr._from_call(
func,
"contains",
pattern=pattern,
literal=literal,
returns_scalar=self._compliant_expr._returns_scalar,
)

def slice(self: Self, offset: int, length: int | None = None) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

# From the docs: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.substring.html
# The position is not zero based, but 1 based index.
def func(_input: Column, offset: int, length: int | None) -> Column:
col_length = F.char_length(_input)

_offset = col_length + F.lit(offset + 1) if offset < 0 else F.lit(offset + 1)
_length = F.lit(length) if length is not None else col_length
return _input.substr(_offset, _length)

return self._compliant_expr._from_call(
func,
"slice",
offset=offset,
length=length,
returns_scalar=self._compliant_expr._returns_scalar,
)

def to_uppercase(self: Self) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
F.upper,
"to_uppercase",
returns_scalar=self._compliant_expr._returns_scalar,
)

def to_lowercase(self: Self) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
F.lower,
"to_lowercase",
returns_scalar=self._compliant_expr._returns_scalar,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ filterwarnings = [
'ignore: unclosed <socket.socket',
'ignore:.*The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:pyspark',
'ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning:pyspark',
'ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning:pyspark',

]
xfail_strict = true
Expand Down
21 changes: 7 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,13 @@ def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]: # pragma: no cove
register(session.stop)

def _constructor(obj: Any) -> IntoFrame:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
r".*is_datetime64tz_dtype is deprecated and will be removed in a future version.*",
module="pyspark",
category=DeprecationWarning,
)
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)

return _constructor

Expand Down
16 changes: 3 additions & 13 deletions tests/expr_and_series/str/contains_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def test_contains_case_insensitive(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "cudf" in str(constructor) or "pyspark" in str(constructor):
if "cudf" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -40,12 +40,7 @@ def test_contains_series_case_insensitive(
assert_equal_data(result, expected)


def test_contains_case_sensitive(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_contains_case_sensitive(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("pets").str.contains("parrot|Dove").alias("default_match"))
expected = {
Expand All @@ -63,12 +58,7 @@ def test_contains_series_case_sensitive(constructor_eager: ConstructorEager) ->
assert_equal_data(result, expected)


def test_contains_literal(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_contains_literal(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.col("pets").str.contains("Parrot|dove").alias("default_match"),
Expand Down
7 changes: 1 addition & 6 deletions tests/expr_and_series/str/head_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand All @@ -10,10 +8,7 @@
data = {"a": ["foo", "bars"]}


def test_str_head(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_str_head(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").str.head(3))
expected = {
Expand Down
2 changes: 1 addition & 1 deletion tests/expr_and_series/str/len_chars_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def test_str_len_chars(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").str.len_chars())
Expand Down
4 changes: 1 addition & 3 deletions tests/expr_and_series/str/replace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def test_str_replace_all_expr(
literal: bool, # noqa: FBT001
expected: dict[str, list[str]],
) -> None:
if ("pyspark" in str(constructor)) or (
"duckdb" in str(constructor) and literal is False
):
if "duckdb" in str(constructor) and literal is False:
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(
Expand Down
4 changes: 0 additions & 4 deletions tests/expr_and_series/str/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@
[(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})],
)
def test_str_slice(
request: pytest.FixtureRequest,
constructor: Constructor,
offset: int,
length: int | None,
expected: Any,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result_frame = df.select(nw.col("a").str.slice(offset, length))
assert_equal_data(result_frame, expected)
Expand Down
12 changes: 2 additions & 10 deletions tests/expr_and_series/str/starts_with_ends_with_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand All @@ -13,10 +11,7 @@
data = {"a": ["fdas", "edfas"]}


def test_ends_with(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_ends_with(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").str.ends_with("das"))
expected = {
Expand All @@ -34,10 +29,7 @@ def test_ends_with_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_starts_with(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_starts_with(constructor: Constructor) -> None:
df = nw.from_native(constructor(data)).lazy()
result = df.select(nw.col("a").str.starts_with("fda"))
expected = {
Expand Down
3 changes: 0 additions & 3 deletions tests/expr_and_series/str/strip_chars_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@
],
)
def test_str_strip_chars(
request: pytest.FixtureRequest,
constructor: Constructor,
characters: str | None,
expected: Any,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result_frame = df.select(nw.col("a").str.strip_chars(characters))
assert_equal_data(result_frame, expected)
Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/str/tail_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand All @@ -10,9 +8,7 @@
data = {"a": ["foo", "bars"]}


def test_str_tail(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_str_tail(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
expected = {"a": ["foo", "ars"]}

Expand Down
6 changes: 0 additions & 6 deletions tests/expr_and_series/str/to_uppercase_to_lowercase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def test_str_to_uppercase(
expected: dict[str, list[str]],
request: pytest.FixtureRequest,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

if any("ß" in s for value in data.values() for s in value) & (
constructor.__name__
in (
Expand Down Expand Up @@ -113,13 +110,10 @@ def test_str_to_uppercase_series(
],
)
def test_str_to_lowercase(
request: pytest.FixtureRequest,
constructor: Constructor,
data: dict[str, list[str]],
expected: dict[str, list[str]],
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result_frame = df.select(nw.col("a").str.to_lowercase())
assert_equal_data(result_frame, expected)
Expand Down
4 changes: 3 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def _sort_dict_by_key(
data_dict: dict[str, list[Any]], key: str
) -> dict[str, list[Any]]: # pragma: no cover
sort_list = data_dict[key]
sorted_indices = sorted(range(len(sort_list)), key=lambda i: sort_list[i])
sorted_indices = sorted(
range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i])
)
return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()}


Expand Down

0 comments on commit 50b3a40

Please sign in to comment.