From 1f0c7183048c8568cb714b39d8efe8684b5c66f4 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Wed, 8 Jan 2025 13:08:46 +0000 Subject: [PATCH] feat: implement when/then/otherwise for DuckDB (#1759) --- narwhals/_duckdb/expr.py | 2 +- narwhals/_duckdb/namespace.py | 109 +++++++++++++++++++++++++++++ narwhals/_duckdb/utils.py | 3 +- tests/expr_and_series/when_test.py | 26 ++----- tpch/execute.py | 9 +++ 5 files changed, 126 insertions(+), 23 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 0f33ff846..4515cbba1 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -31,7 +31,7 @@ class DuckDBExpr(CompliantExpr["duckdb.Expression"]): def __init__( self, - call: Callable[[DuckDBLazyFrame], list[duckdb.Expression]], + call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], *, depth: int, function_name: str, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index bcd7eff6d..c91d11d3f 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -7,6 +7,7 @@ from typing import Any from typing import Literal from typing import Sequence +from typing import cast from narwhals._duckdb.expr import DuckDBExpr from narwhals._duckdb.utils import narwhals_to_native_dtype @@ -157,6 +158,16 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: kwargs={"exprs": exprs}, ) + def when( + self, + *predicates: IntoDuckDBExpr, + ) -> DuckDBWhen: + plx = self.__class__(backend_version=self._backend_version, version=self._version) + condition = plx.all_horizontal(*predicates) + return DuckDBWhen( + condition, self._backend_version, returns_scalar=False, version=self._version + ) + def col(self, *column_names: str) -> DuckDBExpr: return DuckDBExpr.from_column_names( *column_names, backend_version=self._backend_version, version=self._version @@ -203,3 +214,101 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: version=self._version, kwargs={}, ) + + +class 
DuckDBWhen: + def __init__( + self, + condition: DuckDBExpr, + backend_version: tuple[int, ...], + then_value: Any = None, + otherwise_value: Any = None, + *, + returns_scalar: bool, + version: Version, + ) -> None: + self._backend_version = backend_version + self._condition = condition + self._then_value = then_value + self._otherwise_value = otherwise_value + self._returns_scalar = returns_scalar + self._version = version + + def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: + from duckdb import CaseExpression + from duckdb import ConstantExpression + + from narwhals._expression_parsing import parse_into_expr + + plx = df.__narwhals_namespace__() + condition = parse_into_expr(self._condition, namespace=plx)(df)[0] + condition = cast("duckdb.Expression", condition) + + try: + value = parse_into_expr(self._then_value, namespace=plx)(df)[0] + except TypeError: + # `self._then_value` is a scalar and can't be converted to an expression + value = ConstantExpression(self._then_value) + value = cast("duckdb.Expression", value) + + if self._otherwise_value is None: + return [CaseExpression(condition=condition, value=value)] + try: + otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + return [ + CaseExpression(condition=condition, value=value).otherwise( + ConstantExpression(self._otherwise_value) + ) + ] + otherwise = otherwise_expr(df)[0] + return [CaseExpression(condition=condition, value=value).otherwise(otherwise)] + + def then(self, value: DuckDBExpr | Any) -> DuckDBThen: + self._then_value = value + + return DuckDBThen( + self, + depth=0, + function_name="whenthen", + root_names=None, + output_names=None, + returns_scalar=self._returns_scalar, + backend_version=self._backend_version, + version=self._version, + kwargs={"value": value}, + ) + + +class DuckDBThen(DuckDBExpr): + def __init__( + self, + call: DuckDBWhen, + *, 
+ depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + returns_scalar: bool, + backend_version: tuple[int, ...], + version: Version, + kwargs: dict[str, Any], + ) -> None: + self._backend_version = backend_version + self._version = version + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + self._returns_scalar = returns_scalar + self._kwargs = kwargs + + def otherwise(self, value: DuckDBExpr | Any) -> DuckDBExpr: + # type ignore because we are setting the `_call` attribute to a + # callable object of type `DuckDBWhen`, base class has the attribute as + # only a `Callable` + self._call._otherwise_value = value # type: ignore[attr-defined] + self._function_name = "whenotherwise" + return self diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index abac2e158..62f126db9 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -4,6 +4,7 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from narwhals.dtypes import DType from narwhals.exceptions import InvalidIntoExprError @@ -76,7 +77,7 @@ def parse_exprs_and_named_exprs( def _columns_from_expr( df: DuckDBLazyFrame, expr: IntoDuckDBExpr -) -> list[duckdb.Expression]: +) -> Sequence[duckdb.Expression]: if isinstance(expr, str): # pragma: no cover from duckdb import ColumnExpression diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 739b00e2d..94e37aaa3 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -17,9 +17,7 @@ } -def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_when(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = 
df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { @@ -28,9 +26,7 @@ def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None: assert_equal_data(result, expected) -def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_when_otherwise(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { @@ -39,11 +35,7 @@ def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest assert_equal_data(result, expected) -def test_multiple_conditions( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_multiple_conditions(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") @@ -85,11 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_value_expression( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_value_expression(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when")) expected = { @@ -122,11 +110,7 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_otherwise_expression( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_otherwise_expression(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result 
= df.select( nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when") diff --git a/tpch/execute.py b/tpch/execute.py index e19b51dfb..1f3823ced 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -5,6 +5,7 @@ from pathlib import Path import dask.dataframe as dd +import duckdb import pandas as pd import polars as pl import pyarrow as pa @@ -29,14 +30,18 @@ "pandas[pyarrow]": (pd, {"engine": "pyarrow", "dtype_backend": "pyarrow"}), "polars[lazy]": (pl, {}), "pyarrow": (pa, {}), + "duckdb": (duckdb, {}), "dask": (dd, {"engine": "pyarrow", "dtype_backend": "pyarrow"}), } BACKEND_COLLECT_FUNC_MAP = { "polars[lazy]": lambda x: x.collect(), + "duckdb": lambda x: x.pl(), "dask": lambda x: x.compute(), } +DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"] + QUERY_DATA_PATH_MAP = { "q1": (LINEITEM_PATH,), "q2": (REGION_PATH, NATION_PATH, SUPPLIER_PATH, PART_PATH, PARTSUPP_PATH), @@ -90,6 +95,10 @@ def execute_query(query_id: str) -> None: data_paths = QUERY_DATA_PATH_MAP[query_id] for backend, (native_namespace, kwargs) in BACKEND_NAMESPACE_KWARGS_MAP.items(): + if backend == "duckdb" and query_id in DUCKDB_XFAILS: + print(f"\nSkipping {query_id} for DuckDB") # noqa: T201 + continue + print(f"\nRunning {query_id} with {backend=}") # noqa: T201 result = query_module.query( *(