diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 99f043ebd..b02ad32ee 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -359,12 +359,7 @@ def when( *predicates: IntoArrowExpr, ) -> ArrowWhen: plx = self.__class__(backend_version=self._backend_version, version=self._version) - if predicates: - condition = plx.all_horizontal(*predicates) - else: - msg = "at least one predicate needs to be provided" - raise TypeError(msg) - + condition = plx.all_horizontal(*predicates) return ArrowWhen(condition, self._backend_version, version=self._version) def concat_str( diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index d9a1a8ac6..9a16d7f13 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -310,12 +310,7 @@ def when( *predicates: IntoDaskExpr, ) -> DaskWhen: plx = self.__class__(backend_version=self._backend_version, version=self._version) - if predicates: - condition = plx.all_horizontal(*predicates) - else: - msg = "at least one predicate needs to be provided" - raise TypeError(msg) - + condition = plx.all_horizontal(*predicates) return DaskWhen( condition, self._backend_version, returns_scalar=False, version=self._version ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 7885d7de0..212c9c938 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -371,12 +371,7 @@ def when( plx = self.__class__( self._implementation, self._backend_version, version=self._version ) - if predicates: - condition = plx.all_horizontal(*predicates) - else: - msg = "at least one predicate needs to be provided" - raise TypeError(msg) - + condition = plx.all_horizontal(*predicates) return PandasWhen( condition, self._implementation, self._backend_version, version=self._version ) diff --git a/narwhals/expr.py b/narwhals/expr.py index 809f76e77..653300da8 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -7643,6 +7643,9 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicates = flatten([predicates]) + if not self._predicates: + msg = "At least one predicate needs to be provided to `narwhals.when`." + raise TypeError(msg) def _extract_predicates(self, plx: Any) -> Any: return [extract_compliant(plx, v) for v in self._predicates]