Skip to content

Commit

Permalink
perf
Browse files (browse the repository at this point in the history)
  • Loading branch information
MarcoGorelli committed Mar 18, 2024
1 parent 2a39a86 commit ba0cf14
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 87 deletions.
3 changes: 1 addition & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from narwhals.group_by import GroupBy
from narwhals.series import Series
from narwhals.typing import IntoExpr
from narwhals.typing import T


class BaseFrame:
Expand Down Expand Up @@ -208,7 +207,7 @@ def to_dict(self, *, as_series: bool = True) -> dict[str, Any]:
class LazyFrame(BaseFrame):
def __init__(
self,
df: T,
df: Any,
*,
implementation: str | None = None,
) -> None:
Expand Down
29 changes: 15 additions & 14 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from narwhals.pandas_like.utils import evaluate_into_exprs
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import reset_index
from narwhals.pandas_like.utils import translate_dtype
from narwhals.pandas_like.utils import validate_dataframe_comparand
from narwhals.utils import flatten_str
Expand All @@ -33,17 +32,18 @@ def __init__(
implementation: str,
) -> None:
self._validate_columns(dataframe.columns)
self._dataframe = reset_index(dataframe)
self._dataframe = dataframe
self._implementation = implementation

def _validate_columns(self, columns: Sequence[str]) -> None:
counter = collections.Counter(columns)
for col, count in counter.items():
if count > 1:
msg = f"Expected unique column names, got {col!r} {count} time(s)"
raise ValueError(
msg,
)
if len(columns) != len(set(columns)):
counter = collections.Counter(columns)
for col, count in counter.items():
if count > 1:
msg = f"Expected unique column names, got {col!r} {count} time(s)"
raise ValueError(
msg,
)

def _validate_booleanness(self) -> None:
if not (
Expand Down Expand Up @@ -102,7 +102,7 @@ def filter(
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
_mask = validate_dataframe_comparand(mask)
_mask = validate_dataframe_comparand(self._dataframe.index, mask)
return self._from_dataframe(self._dataframe.loc[_mask])

def with_columns(
Expand All @@ -112,7 +112,10 @@ def with_columns(
) -> Self:
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
df = self._dataframe.assign(
**{series.name: validate_dataframe_comparand(series) for series in new_series}
**{
series.name: validate_dataframe_comparand(self._dataframe.index, series)
for series in new_series
}
)
return self._from_dataframe(df)

Expand All @@ -137,9 +140,7 @@ def sort(
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
return self._from_dataframe(
df.sort_values(flat_keys, ascending=ascending),
)
return self._from_dataframe(df.sort_values(flat_keys, ascending=ascending))

# --- convert ---
def collect(self) -> PandasDataFrame:
Expand Down
59 changes: 30 additions & 29 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pandas.api.types import is_extension_array_dtype

from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import reset_index
from narwhals.pandas_like.utils import reverse_translate_dtype
from narwhals.pandas_like.utils import translate_dtype
from narwhals.pandas_like.utils import validate_column_comparand
Expand All @@ -32,7 +31,7 @@ def __init__(
"""

self._name = str(series.name) if series.name is not None else ""
self._series = reset_index(series)
self._series = series
self._implementation = implementation

def _from_series(self, series: Any) -> Self:
Expand Down Expand Up @@ -70,7 +69,9 @@ def cast(

def filter(self, mask: Self) -> Self:
ser = self._series
return self._from_series(ser.loc[validate_column_comparand(mask)])
return self._from_series(
ser.loc[validate_column_comparand(self._series.index, mask)]
)

def item(self) -> Any:
return item(self._series)
Expand All @@ -93,122 +94,122 @@ def is_in(self, other: Any) -> PandasSeries:

def __eq__(self, other: object) -> PandasSeries: # type: ignore[override]
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__eq__(other)).rename(ser.name, copy=False))

def __ne__(self, other: object) -> PandasSeries: # type: ignore[override]
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__ne__(other)).rename(ser.name, copy=False))

def __ge__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__ge__(other)).rename(ser.name, copy=False))

def __gt__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__gt__(other)).rename(ser.name, copy=False))

def __le__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__le__(other)).rename(ser.name, copy=False))

def __lt__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__lt__(other)).rename(ser.name, copy=False))

def __and__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__and__(other)).rename(ser.name, copy=False))

def __rand__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rand__(other)).rename(ser.name, copy=False))

def __or__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__or__(other)).rename(ser.name, copy=False))

def __ror__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__ror__(other)).rename(ser.name, copy=False))

def __add__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__add__(other)).rename(ser.name, copy=False))

def __radd__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__radd__(other)).rename(ser.name, copy=False))

def __sub__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__sub__(other)).rename(ser.name, copy=False))

def __rsub__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rsub__(other)).rename(ser.name, copy=False))

def __mul__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__mul__(other)).rename(ser.name, copy=False))

def __rmul__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rmul__(other)).rename(ser.name, copy=False))

def __truediv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__truediv__(other)).rename(ser.name, copy=False))

def __rtruediv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rtruediv__(other)).rename(ser.name, copy=False))

def __floordiv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__floordiv__(other)).rename(ser.name, copy=False))

def __rfloordiv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rfloordiv__(other)).rename(ser.name, copy=False))

def __pow__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__pow__(other)).rename(ser.name, copy=False))

def __rpow__(self, other: Any) -> PandasSeries: # pragma: no cover
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rpow__(other)).rename(ser.name, copy=False))

def __mod__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__mod__(other)).rename(ser.name, copy=False))

def __rmod__(self, other: Any) -> PandasSeries: # pragma: no cover
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rmod__(other)).rename(ser.name, copy=False))

# Unary
Expand Down Expand Up @@ -285,8 +286,8 @@ def n_unique(self) -> int:
return ser.nunique() # type: ignore[no-any-return]

def zip_with(self, mask: PandasSeries, other: PandasSeries) -> PandasSeries:
mask = validate_column_comparand(mask)
other = validate_column_comparand(other)
mask = validate_column_comparand(self._series.index, mask)
other = validate_column_comparand(self._series.index, other)
ser = self._series
return self._from_series(ser.where(mask, other))

Expand Down
42 changes: 20 additions & 22 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from narwhals.pandas_like.typing import IntoPandasExpr


def validate_column_comparand(other: Any) -> Any:
def validate_column_comparand(index: Any, other: Any) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -47,11 +47,17 @@ def validate_column_comparand(other: Any) -> Any:
if other.len() == 1:
# broadcast
return other.item()
if other._series.index is not index and not (other._series.index == index).all():
msg = (
"Narwhals does not support automated index alignment. "
"Please reset the index of the Series or DataFrame."
)
raise ValueError(msg)
return other._series
return other


def validate_dataframe_comparand(other: Any) -> Any:
def validate_dataframe_comparand(index: Any, other: Any) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -60,19 +66,25 @@ def validate_dataframe_comparand(other: Any) -> Any:
from narwhals.pandas_like.dataframe import PandasDataFrame
from narwhals.pandas_like.series import PandasSeries

if isinstance(other, list) and len(other) > 1:
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions are not supported in this context"
raise ValueError(msg)
if isinstance(other, list):
other = other[0]
if isinstance(other, PandasDataFrame):
return NotImplemented
if isinstance(other, PandasSeries):
if other.len() == 1:
# broadcast
return item(other._series)
if other._series.index is not index and not (other._series.index == index).all():
msg = (
"Narwhals does not support automated index alignment. "
"Please reset the index of the Series or DataFrame."
)
raise ValueError(msg)
return other._series
if isinstance(other, list) and len(other) > 1:
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions are not supported in this context"
raise ValueError(msg)
if isinstance(other, list):
other = other[0]
return other


Expand Down Expand Up @@ -368,17 +380,3 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
return "bool"
msg = f"Unknown dtype: {dtype}"
raise TypeError(msg)


def reset_index(obj: Any) -> Any:
index = obj.index
if (
hasattr(index, "start")
and hasattr(index, "stop")
and hasattr(index, "step")
and index.start == 0
and index.stop == len(obj)
and index.step == 1
):
return obj
return obj.reset_index(drop=True)
1 change: 1 addition & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_accepted_dataframes() -> None:


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd])
@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning")
def test_convert_pandas(df_raw: Any) -> None:
result = nw.DataFrame(df_raw).to_pandas()
expected = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
Expand Down
Loading

0 comments on commit ba0cf14

Please sign in to comment.