Skip to content

Commit

Permalink
chore: validate predicates in nw.when one level higher
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 7, 2025
1 parent a6d76e1 commit f57f894
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 18 deletions.
7 changes: 1 addition & 6 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,7 @@ def when(
*predicates: IntoArrowExpr,
) -> ArrowWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return ArrowWhen(condition, self._backend_version, version=self._version)

def concat_str(
Expand Down
7 changes: 1 addition & 6 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,7 @@ def when(
*predicates: IntoDaskExpr,
) -> DaskWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return DaskWhen(
condition, self._backend_version, returns_scalar=False, version=self._version
)
Expand Down
7 changes: 1 addition & 6 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,7 @@ def when(
plx = self.__class__(
self._implementation, self._backend_version, version=self._version
)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return PandasWhen(
condition, self._implementation, self._backend_version, version=self._version
)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7643,6 +7643,9 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
class When:
def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
self._predicates = flatten([predicates])
if not self._predicates:
msg = "At least one predicate needs to be provided to `narwhals.when`."
raise TypeError(msg)

def _extract_predicates(self, plx: Any) -> Any:
return [extract_compliant(plx, v) for v in self._predicates]
Expand Down

0 comments on commit f57f894

Please sign in to comment.