Skip to content

Commit

Permalink
Merge branch 'tests/pyspark-to-main' of https://github.com/narwhals-d…
Browse files Browse the repository at this point in the history
…ev/narwhals into tests/pyspark-to-main
  • Loading branch information
FBruzzesi committed Jan 8, 2025
2 parents 2934687 + 16162af commit 582081e
Show file tree
Hide file tree
Showing 16 changed files with 71 additions and 244 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/downstream_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ jobs:
run: |
cd tea-tasting
pdm remove narwhals
pdm add ./..
pdm add ./..[dev]
- name: show-deps
run: |
cd tea-tasting
Expand Down
50 changes: 11 additions & 39 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,14 @@ def row(self: Self, index: int) -> tuple[Any, ...]:
return tuple(col[index] for col in self._native_frame)

@overload
def rows(
self: Self,
*,
named: Literal[True],
) -> list[dict[str, Any]]: ...
def rows(self: Self, *, named: Literal[True]) -> list[dict[str, Any]]: ...

@overload
def rows(
self: Self,
*,
named: Literal[False],
) -> list[tuple[Any, ...]]: ...
def rows(self: Self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...

@overload
def rows(
self: Self,
*,
named: bool,
self: Self, *, named: bool
) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

def rows(self: Self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
Expand All @@ -126,10 +117,7 @@ def rows(self: Self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, A
return self._native_frame.to_pylist() # type: ignore[no-any-return]

def iter_rows(
self: Self,
*,
named: bool,
buffer_size: int,
self: Self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
df = self._native_frame
num_rows = df.num_rows
Expand Down Expand Up @@ -263,9 +251,7 @@ def __getitem__(
)
start = item.start or 0
stop = item.stop if item.stop is not None else len(self._native_frame)
return self._from_native_frame(
self._native_frame.slice(start, stop - start),
)
return self._from_native_frame(self._native_frame.slice(start, stop - start))

elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1):
if (
Expand Down Expand Up @@ -301,28 +287,19 @@ def estimated_size(self: Self, unit: SizeUnit) -> int | float:
def columns(self: Self) -> list[str]:
return self._native_frame.schema.names # type: ignore[no-any-return]

def select(
self: Self,
*exprs: IntoArrowExpr,
**named_exprs: IntoArrowExpr,
) -> Self:
def select(self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr) -> Self:
import pyarrow as pa

new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._from_native_frame(self._native_frame.__class__.from_arrays([]))
names = [s.name for s in new_series]
df = pa.Table.from_arrays(
broadcast_series(new_series),
names=names,
)
df = pa.Table.from_arrays(broadcast_series(new_series), names=names)
return self._from_native_frame(df)

def with_columns(
self: Self,
*exprs: IntoArrowExpr,
**named_exprs: IntoArrowExpr,
self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr
) -> Self:
native_frame = self._native_frame
new_columns = evaluate_into_exprs(self, *exprs, **named_exprs)
Expand All @@ -334,9 +311,7 @@ def with_columns(
col_name = col_value.name

column = validate_dataframe_comparand(
length=length,
other=col_value,
backend_version=self._backend_version,
length=length, other=col_value, backend_version=self._backend_version
)

native_frame = (
Expand Down Expand Up @@ -611,12 +586,9 @@ def is_duplicated(self: Self) -> ArrowSeries:
columns = self.columns
index_token = generate_temporary_column_name(n_bytes=8, columns=columns)
col_token = generate_temporary_column_name(
n_bytes=8,
columns=[*columns, index_token],
n_bytes=8, columns=[*columns, index_token]
)

df = self.with_row_index(index_token)._native_frame

row_count = (
df.append_column(col_token, pa.repeat(pa.scalar(1), len(self)))
.group_by(columns)
Expand Down
56 changes: 11 additions & 45 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns,
available_columns=df.columns,
missing_columns=missing_columns, available_columns=df.columns
) from e

return cls(
Expand Down Expand Up @@ -564,9 +563,7 @@ def __init__(self: Self, expr: ArrowExpr) -> None:

def get_categories(self: Self) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"cat",
"get_categories",
self._compliant_expr, "cat", "get_categories"
)


Expand Down Expand Up @@ -676,12 +673,7 @@ def len_chars(self: Self) -> ArrowExpr:
)

def replace(
self: Self,
pattern: str,
value: str,
*,
literal: bool,
n: int,
self: Self, pattern: str, value: str, *, literal: bool, n: int
) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
Expand All @@ -693,13 +685,7 @@ def replace(
n=n,
)

def replace_all(
self: Self,
pattern: str,
value: str,
*,
literal: bool,
) -> ArrowExpr:
def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
Expand All @@ -711,26 +697,17 @@ def replace_all(

def strip_chars(self: Self, characters: str | None) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"strip_chars",
characters=characters,
self._compliant_expr, "str", "strip_chars", characters=characters
)

def starts_with(self: Self, prefix: str) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"starts_with",
prefix=prefix,
self._compliant_expr, "str", "starts_with", prefix=prefix
)

def ends_with(self: Self, suffix: str) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"ends_with",
suffix=suffix,
self._compliant_expr, "str", "ends_with", suffix=suffix
)

def contains(self, pattern: str, *, literal: bool) -> ArrowExpr:
Expand All @@ -745,24 +722,17 @@ def slice(self: Self, offset: int, length: int | None) -> ArrowExpr:

def to_datetime(self: Self, format: str | None) -> ArrowExpr: # noqa: A002
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"to_datetime",
format=format,
self._compliant_expr, "str", "to_datetime", format=format
)

def to_uppercase(self: Self) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"to_uppercase",
self._compliant_expr, "str", "to_uppercase"
)

def to_lowercase(self: Self) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"to_lowercase",
self._compliant_expr, "str", "to_lowercase"
)


Expand Down Expand Up @@ -931,8 +901,4 @@ def __init__(self: Self, expr: ArrowExpr) -> None:
self._expr = expr

def len(self: Self) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._expr,
"list",
"len",
)
return reuse_series_namespace_implementation(self._expr, "list", "len")
2 changes: 1 addition & 1 deletion narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def agg_arrow(
function_name = remove_prefix(expr._function_name, "col->")

if function_name in {"std", "var"}:
option = pc.VarianceOptions(ddof=expr._kwargs.get("ddof", 1))
option = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
option = pc.CountOptions(mode="all")
elif function_name == "count":
Expand Down
7 changes: 1 addition & 6 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,7 @@ def when(
*predicates: IntoArrowExpr,
) -> ArrowWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return ArrowWhen(condition, self._backend_version, version=self._version)

def concat_str(
Expand Down
Loading

0 comments on commit 582081e

Please sign in to comment.