Skip to content

Commit

Permalink
stricter
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 14, 2025
1 parent 9e2143f commit 168f478
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 70 deletions.
8 changes: 8 additions & 0 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from narwhals.typing import CompliantNamespace
from narwhals.typing import CompliantSeries
from narwhals.typing import CompliantSeriesT_co
from narwhals.typing import IntoExpr

IntoCompliantExpr: TypeAlias = (
CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co
Expand Down Expand Up @@ -334,3 +335,10 @@ def extract_compliant(
if isinstance(other, Series):
return other._compliant_series
return other


def operation_is_order_dependent(*args: IntoExpr | Any) -> bool:
# If `rhs` is a Expr, we look at `_is_order_dependent`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a') + 1), and so we default
# to `False`.
return any(getattr(x, "_is_order_dependent", False) for x in args)
4 changes: 3 additions & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -64,6 +65,7 @@ def _from_compliant_dataframe(self, df: Any) -> Self:
level=self._level,
)

@abstractmethod
def _extract_compliant(self, arg: Any) -> Any:
raise NotImplementedError

Expand Down Expand Up @@ -3629,7 +3631,7 @@ def _extract_compliant(self, arg: Any) -> Any:

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series):
if isinstance(arg, Series): # pragma: no cover
msg = "Mixing Series with LazyFrame is not supported."
raise TypeError(msg)
if isinstance(arg, Expr):
Expand Down
71 changes: 35 additions & 36 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Sequence

from narwhals._expression_parsing import extract_compliant
from narwhals._expression_parsing import operation_is_order_dependent
from narwhals.dtypes import _validate_dtype
from narwhals.expr_cat import ExprCatNamespace
from narwhals.expr_dt import ExprDateTimeNamespace
Expand All @@ -27,13 +28,6 @@
from narwhals.typing import IntoExpr


def binary_operation_is_order_dependent(lhs: Expr, rhs: Expr | Any) -> bool:
# If `rhs` is a Expr, we look at `_is_order_dependent`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a') + 1), and so we default
# to `False`.
return lhs._is_order_dependent or getattr(rhs, "_is_order_dependent", False)


class Expr:
def __init__(
self,
Expand Down Expand Up @@ -242,23 +236,23 @@ def __eq__(self, other: object) -> Self: # type: ignore[override]
lambda plx: self._to_compliant_expr(plx).__eq__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __ne__(self, other: object) -> Self: # type: ignore[override]
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__ne__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __and__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__and__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rand__(self, other: Any) -> Self:
Expand All @@ -269,15 +263,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __or__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__or__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __ror__(self, other: Any) -> Self:
Expand All @@ -288,15 +282,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __add__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__add__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __radd__(self, other: Any) -> Self:
Expand All @@ -307,15 +301,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __sub__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__sub__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rsub__(self, other: Any) -> Self:
Expand All @@ -326,15 +320,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __truediv__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__truediv__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rtruediv__(self, other: Any) -> Self:
Expand All @@ -345,15 +339,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __mul__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__mul__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rmul__(self, other: Any) -> Self:
Expand All @@ -364,47 +358,47 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __le__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__le__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __lt__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__lt__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __gt__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__gt__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __ge__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__ge__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __pow__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__pow__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rpow__(self, other: Any) -> Self:
Expand All @@ -415,15 +409,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __floordiv__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__floordiv__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rfloordiv__(self, other: Any) -> Self:
Expand All @@ -434,15 +428,15 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __mod__(self, other: Any) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__mod__(
extract_compliant(plx, other)
),
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

def __rmod__(self, other: Any) -> Self:
Expand All @@ -453,7 +447,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:

return self.__class__(
func,
is_order_dependent=binary_operation_is_order_dependent(self, other),
is_order_dependent=operation_is_order_dependent(self, other),
)

# --- unary ---
Expand Down Expand Up @@ -1986,7 +1980,9 @@ def is_between(
extract_compliant(plx, upper_bound),
closed,
),
self._is_order_dependent,
is_order_dependent=operation_is_order_dependent(
self, lower_bound, upper_bound
),
)

def is_in(self, other: Any) -> Self:
Expand Down Expand Up @@ -2117,11 +2113,12 @@ def filter(self, *predicates: Any) -> Self:
a: [[5,6,7]]
b: [[10,11,12]]
"""
flat_predicates = flatten(predicates)
return self.__class__(
lambda plx: self._to_compliant_expr(plx).filter(
*[extract_compliant(plx, pred) for pred in flatten(predicates)],
*[extract_compliant(plx, pred) for pred in flat_predicates],
),
is_order_dependent=True,
is_order_dependent=operation_is_order_dependent(*flat_predicates),
)

def is_null(self) -> Self:
Expand Down Expand Up @@ -3512,7 +3509,9 @@ def clip(
extract_compliant(plx, lower_bound),
extract_compliant(plx, upper_bound),
),
self._is_order_dependent,
is_order_dependent=operation_is_order_dependent(
self, lower_bound, upper_bound
),
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -3568,7 +3567,7 @@ def mode(self: Self) -> Self:
a: [[1]]
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).mode(), is_order_dependent=True
lambda plx: self._to_compliant_expr(plx).mode(), is_order_dependent=False
)

def is_finite(self: Self) -> Self:
Expand Down
Loading

0 comments on commit 168f478

Please sign in to comment.