diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 0f33ff846..4515cbba1 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -31,7 +31,7 @@ class DuckDBExpr(CompliantExpr["duckdb.Expression"]): def __init__( self, - call: Callable[[DuckDBLazyFrame], list[duckdb.Expression]], + call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], *, depth: int, function_name: str, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 27514b711..c91d11d3f 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -3,10 +3,11 @@ import functools import operator from functools import reduce -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from typing import Any from typing import Literal from typing import Sequence +from typing import cast from narwhals._duckdb.expr import DuckDBExpr from narwhals._duckdb.utils import narwhals_to_native_dtype @@ -157,6 +158,16 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: kwargs={"exprs": exprs}, ) + def when( + self, + *predicates: IntoDuckDBExpr, + ) -> DuckDBWhen: + plx = self.__class__(backend_version=self._backend_version, version=self._version) + condition = plx.all_horizontal(*predicates) + return DuckDBWhen( + condition, self._backend_version, returns_scalar=False, version=self._version + ) + def col(self, *column_names: str) -> DuckDBExpr: return DuckDBExpr.from_column_names( *column_names, backend_version=self._backend_version, version=self._version @@ -204,6 +215,7 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: kwargs={}, ) + class DuckDBWhen: def __init__( self, @@ -223,34 +235,35 @@ def __init__( self._version = version def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: + from duckdb import CaseExpression + from duckdb import ConstantExpression + from narwhals._expression_parsing import parse_into_expr - from duckdb import FunctionExpression, ColumnExpression, ConstantExpression, CaseExpression plx = df.__narwhals_namespace__() condition = parse_into_expr(self._condition, namespace=plx)(df)[0] condition = cast("duckdb.Expression", condition) - breakpoint() - try: - value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] + value = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression - value_series = ConstantExpression(self._then_value) - value_series = cast("duckdb.Expression", value_series) + value = ConstantExpression(self._then_value) + value = cast("duckdb.Expression", value) if self._otherwise_value is None: - return [value_series.where(condition)] + return [CaseExpression(condition=condition, value=value)] try: otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression - return [value_series.where(condition, self._otherwise_value)] - otherwise_series = otherwise_expr(df)[0] - - if otherwise_expr._returns_scalar: # type: ignore[attr-defined] - return [value_series.where(condition, otherwise_series[0])] - return [value_series.where(condition, otherwise_series)] + return [ + CaseExpression(condition=condition, value=value).otherwise( + ConstantExpression(self._otherwise_value) + ) + ] + otherwise = otherwise_expr(df)[0] + return [CaseExpression(condition=condition, value=value).otherwise(otherwise)] def then(self, value: DuckDBExpr | Any) -> DuckDBThen: self._then_value = value diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index abac2e158..62f126db9 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -4,6 +4,7 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from narwhals.dtypes import DType from narwhals.exceptions import InvalidIntoExprError @@ -76,7 +77,7 @@ def parse_exprs_and_named_exprs( def _columns_from_expr( df: DuckDBLazyFrame, expr: IntoDuckDBExpr -) -> list[duckdb.Expression]: +) -> Sequence[duckdb.Expression]: if isinstance(expr, str): # pragma: no cover from duckdb import ColumnExpression diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 739b00e2d..94e37aaa3 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -17,9 +17,7 @@ } -def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_when(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { @@ -28,9 +26,7 @@ def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None: assert_equal_data(result, expected) -def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_when_otherwise(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { @@ -39,11 +35,7 @@ def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest assert_equal_data(result, expected) -def test_multiple_conditions( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_multiple_conditions(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") @@ -85,11 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_value_expression( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_value_expression(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when")) expected = { @@ -122,11 +110,7 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_otherwise_expression( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_otherwise_expression(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when")