Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Disallow order-dependent expressions from being passed to nw.LazyFrame #1806

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
68 changes: 0 additions & 68 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,69 +350,6 @@ def var(self, ddof: int) -> Self:
def skew(self: Self) -> Self:
return self._from_call(lambda _input: _input.skew(), "skew", returns_scalar=True)

def shift(self, n: int) -> Self:
return self._from_call(
lambda _input, n: _input.shift(n),
"shift",
n=n,
returns_scalar=self._returns_scalar,
)

def cum_sum(self: Self, *, reverse: bool) -> Self:
if reverse:
msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

return self._from_call(
lambda _input: _input.cumsum(),
"cum_sum",
returns_scalar=self._returns_scalar,
)

def cum_count(self: Self, *, reverse: bool) -> Self:
if reverse:
msg = "`cum_count(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

return self._from_call(
lambda _input: (~_input.isna()).astype(int).cumsum(),
"cum_count",
returns_scalar=self._returns_scalar,
)

def cum_min(self: Self, *, reverse: bool) -> Self:
if reverse:
msg = "`cum_min(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

return self._from_call(
lambda _input: _input.cummin(),
"cum_min",
returns_scalar=self._returns_scalar,
)

def cum_max(self: Self, *, reverse: bool) -> Self:
if reverse:
msg = "`cum_max(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

return self._from_call(
lambda _input: _input.cummax(),
"cum_max",
returns_scalar=self._returns_scalar,
)

def cum_prod(self: Self, *, reverse: bool) -> Self:
if reverse:
msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

return self._from_call(
lambda _input: _input.cumprod(),
"cum_prod",
returns_scalar=self._returns_scalar,
)

def is_between(
self,
lower_bound: Self | Any,
Expand Down Expand Up @@ -554,11 +491,6 @@ def clip(
returns_scalar=self._returns_scalar,
)

def diff(self: Self) -> Self:
return self._from_call(
lambda _input: _input.diff(), "diff", returns_scalar=self._returns_scalar
)

def n_unique(self: Self) -> Self:
return self._from_call(
lambda _input: _input.nunique(dropna=False), "n_unique", returns_scalar=True
Expand Down
75 changes: 55 additions & 20 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,32 +64,15 @@ def _from_compliant_dataframe(self, df: Any) -> Self:
level=self._level,
)

def _extract_compliant(self, arg: Any) -> Any:
raise NotImplementedError

def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any:
"""Process `args` and `kwargs`, extracting underlying objects as we go."""
args = [self._extract_compliant(v) for v in flatten(args)] # type: ignore[assignment]
kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()}
return args, kwargs

def _extract_compliant(self, arg: Any) -> Any:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of defining this in BaseFrame, we define it in DataFrame and LazyFrame so they can have different behaviours

from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series):
return arg._compliant_series
if isinstance(arg, Expr):
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)):
msg = (
f"Expected Narwhals object, got: {type(arg)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)
return arg

@property
def schema(self) -> Schema:
return Schema(self._compliant_frame.schema.items())
Expand Down Expand Up @@ -360,6 +343,26 @@ class DataFrame(BaseFrame[DataFrameT]):
```
"""

def _extract_compliant(self, arg: Any) -> Any:
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series):
return arg._compliant_series
if isinstance(arg, Expr):
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)):
msg = (
f"Expected Narwhals object, got: {type(arg)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)
return arg

@property
def _series(self) -> type[Series[Any]]:
from narwhals.series import Series
Expand Down Expand Up @@ -3620,6 +3623,38 @@ class LazyFrame(BaseFrame[FrameT]):
```
"""

def _extract_compliant(self, arg: Any) -> Any:
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series):
msg = "Mixing Series with LazyFrame is not supported."
raise TypeError(msg)
if isinstance(arg, Expr):
if arg._is_order_dependent:
msg = (
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
"Hints:\n"
"- Instead of `lf.select(nw.col('a').sort())`, use `lf.select('a').sort()\n"
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
"- `Expr.cum_sum`, and other such expressions, are not currently supported.\n"
" In a future version of Narwhals, a `order_by` argument will be added and \n"
" they will be supported."
)
raise TypeError(msg)
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)):
msg = (
f"Expected Narwhals object, got: {type(arg)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)
return arg

@property
def _dataframe(self) -> type[DataFrame[Any]]:
return DataFrame
Expand Down
Loading
Loading