Skip to content

Commit

Permalink
Merge pull request #14 from raisadz/increase-coverage
Browse files Browse the repository at this point in the history
replace binary operations for pandas-like dataframes with their natives, add tests for binary expressions operations
  • Loading branch information
MarcoGorelli authored Mar 17, 2024
2 parents 76806b5 + f1dc9e9 commit d7af728
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 38 deletions.
3 changes: 0 additions & 3 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import collections
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Iterable
from typing import Literal

Expand All @@ -26,8 +25,6 @@


class PandasDataFrame:
_features: ClassVar[set[str]] = {"eager", "lazy"}

# --- not in the spec ---
def __init__(
self,
Expand Down
10 changes: 5 additions & 5 deletions narwhals/pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,31 @@ def __mul__(self, other: PandasExpr | Any) -> Self:
return register_expression_call(self, "__mul__", other)

def __rmul__(self, other: Any) -> Self:
return self.__mul__(other)
return register_expression_call(self, "__rmul__", other)

def __truediv__(self, other: PandasExpr | Any) -> Self:
return register_expression_call(self, "__truediv__", other)

def __rtruediv__(self, other: Any) -> Self:
raise NotImplementedError
return register_expression_call(self, "__rtruediv__", other)

def __floordiv__(self, other: PandasExpr | Any) -> Self:
return register_expression_call(self, "__floordiv__", other)

def __rfloordiv__(self, other: Any) -> Self:
raise NotImplementedError
return register_expression_call(self, "__rfloordiv__", other)

def __pow__(self, other: PandasExpr | Any) -> Self:
return register_expression_call(self, "__pow__", other)

def __rpow__(self, other: Any) -> Self: # pragma: no cover
raise NotImplementedError
return register_expression_call(self, "__rpow__", other)

def __mod__(self, other: PandasExpr | Any) -> Self:
return register_expression_call(self, "__mod__", other)

def __rmod__(self, other: Any) -> Self: # pragma: no cover
raise NotImplementedError
return register_expression_call(self, "__rmod__", other)

# Unary

Expand Down
78 changes: 48 additions & 30 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,106 +92,124 @@ def is_in(self, other: Any) -> PandasSeries:
# Binary comparisons

def __eq__(self, other: object) -> PandasSeries: # type: ignore[override]
other = validate_column_comparand(other)
ser = self._series
return self._from_series((ser == other).rename(ser.name, copy=False))
other = validate_column_comparand(other)
return self._from_series((ser.__eq__(other)).rename(ser.name, copy=False))

def __ne__(self, other: object) -> PandasSeries: # type: ignore[override]
other = validate_column_comparand(other)
ser = self._series
return self._from_series((ser != other).rename(ser.name, copy=False))
other = validate_column_comparand(other)
return self._from_series((ser.__ne__(other)).rename(ser.name, copy=False))

def __ge__(self, other: Any) -> PandasSeries:
other = validate_column_comparand(other)
ser = self._series
return self._from_series((ser >= other).rename(ser.name, copy=False))
other = validate_column_comparand(other)
return self._from_series((ser.__ge__(other)).rename(ser.name, copy=False))

def __gt__(self, other: Any) -> PandasSeries:
other = validate_column_comparand(other)
ser = self._series
return self._from_series((ser > other).rename(ser.name, copy=False))
other = validate_column_comparand(other)
return self._from_series((ser.__gt__(other)).rename(ser.name, copy=False))

def __le__(self, other: Any) -> PandasSeries:
other = validate_column_comparand(other)
ser = self._series
return self._from_series((ser <= other).rename(ser.name, copy=False))
other = validate_column_comparand(other)
return self._from_series((ser.__le__(other)).rename(ser.name, copy=False))

def __lt__(self, other: Any) -> PandasSeries:
other = validate_column_comparand(other)
ser = self._series
return self._from_series((ser < other).rename(ser.name, copy=False))
other = validate_column_comparand(other)
return self._from_series((ser.__lt__(other)).rename(ser.name, copy=False))

def __and__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser & other).rename(ser.name, copy=False))
return self._from_series((ser.__and__(other)).rename(ser.name, copy=False))

def __rand__(self, other: Any) -> PandasSeries:
return self.__and__(other)
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rand__(other)).rename(ser.name, copy=False))

def __or__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser | other).rename(ser.name, copy=False))
return self._from_series((ser.__or__(other)).rename(ser.name, copy=False))

def __ror__(self, other: Any) -> PandasSeries:
return self.__or__(other)
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__ror__(other)).rename(ser.name, copy=False))

def __add__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser + other).rename(ser.name, copy=False))
return self._from_series((ser.__add__(other)).rename(ser.name, copy=False))

def __radd__(self, other: Any) -> PandasSeries:
return self.__add__(other)
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__radd__(other)).rename(ser.name, copy=False))

def __sub__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser - other).rename(ser.name, copy=False))
return self._from_series((ser.__sub__(other)).rename(ser.name, copy=False))

def __rsub__(self, other: Any) -> PandasSeries:
return -1 * self.__sub__(other)
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rsub__(other)).rename(ser.name, copy=False))

def __mul__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser * other).rename(ser.name, copy=False))
return self._from_series((ser.__mul__(other)).rename(ser.name, copy=False))

def __rmul__(self, other: Any) -> PandasSeries:
return self.__mul__(other)
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rmul__(other)).rename(ser.name, copy=False))

def __truediv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser / other).rename(ser.name, copy=False))
return self._from_series((ser.__truediv__(other)).rename(ser.name, copy=False))

def __rtruediv__(self, other: Any) -> PandasSeries:
raise NotImplementedError
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rtruediv__(other)).rename(ser.name, copy=False))

def __floordiv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser // other).rename(ser.name, copy=False))
return self._from_series((ser.__floordiv__(other)).rename(ser.name, copy=False))

def __rfloordiv__(self, other: Any) -> PandasSeries:
raise NotImplementedError
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rfloordiv__(other)).rename(ser.name, copy=False))

def __pow__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser**other).rename(ser.name, copy=False))
return self._from_series((ser.__pow__(other)).rename(ser.name, copy=False))

def __rpow__(self, other: Any) -> PandasSeries: # pragma: no cover
raise NotImplementedError
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rpow__(other)).rename(ser.name, copy=False))

def __mod__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser % other).rename(ser.name, copy=False))
return self._from_series((ser.__mod__(other)).rename(ser.name, copy=False))

def __rmod__(self, other: Any) -> PandasSeries: # pragma: no cover
raise NotImplementedError
ser = self._series
other = validate_column_comparand(other)
return self._from_series((ser.__rmod__(other)).rename(ser.name, copy=False))

# Unary

Expand Down
27 changes: 27 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,30 @@ def test_shape(df_raw: Any) -> None:
result = nw.DataFrame(df_raw).shape
expected = (3, 3)
assert result == expected


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy])
def test_expr(df_raw: Any) -> None:
result = nw.LazyFrame(df_raw).with_columns(
a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")),
b=nw.col("z") / (2 - nw.col("b")),
c=nw.col("a") + nw.col("b") / 2,
d=nw.col("a") - nw.col("b"),
e=((nw.col("a") > nw.col("b")) & (nw.col("a") >= nw.col("z"))).cast(nw.Int64),
f=(
(nw.col("a") < nw.col("b"))
| (nw.col("a") <= nw.col("z"))
| (nw.col("a") == 1)
).cast(nw.Int64),
)
result_native = nw.to_native(result)
expected = {
"a": [4, 3.333333, 3.5],
"b": [-3.5, -4.0, -2.25],
"z": [7.0, 8.0, 9.0],
"c": [3, 5, 5],
"d": [-3, -1, -4],
"e": [0, 0, 0],
"f": [1, 1, 1],
}
compare_dicts(result_native, expected)

0 comments on commit d7af728

Please sign in to comment.