Skip to content

Commit

Permalink
feat: implement when-then-otherwise for duckdb
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 7, 2025
1 parent 761b178 commit 89c8b1e
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 37 deletions.
2 changes: 1 addition & 1 deletion narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DuckDBExpr(CompliantExpr["duckdb.Expression"]):

def __init__(
self,
call: Callable[[DuckDBLazyFrame], list[duckdb.Expression]],
call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]],
*,
depth: int,
function_name: str,
Expand Down
41 changes: 27 additions & 14 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import functools
import operator
from functools import reduce
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.utils import narwhals_to_native_dtype
Expand Down Expand Up @@ -157,6 +158,16 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={"exprs": exprs},
)

def when(
self,
*predicates: IntoDuckDBExpr,
) -> DuckDBWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
condition = plx.all_horizontal(*predicates)
return DuckDBWhen(
condition, self._backend_version, returns_scalar=False, version=self._version
)

def col(self, *column_names: str) -> DuckDBExpr:
return DuckDBExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
Expand Down Expand Up @@ -204,6 +215,7 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={},
)


class DuckDBWhen:
def __init__(
self,
Expand All @@ -223,34 +235,35 @@ def __init__(
self._version = version

def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
from duckdb import CaseExpression
from duckdb import ConstantExpression

from narwhals._expression_parsing import parse_into_expr
from duckdb import FunctionExpression, ColumnExpression, ConstantExpression, CaseExpression

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("duckdb.Expression", condition)

breakpoint()

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
value = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = ConstantExpression(self._then_value)
value_series = cast("duckdb.Expression", value_series)
value = ConstantExpression(self._then_value)
value = cast("duckdb.Expression", value)

if self._otherwise_value is None:
return [value_series.where(condition)]
return [CaseExpression(condition=condition, value=value)]
try:
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [value_series.where(condition, self._otherwise_value)]
otherwise_series = otherwise_expr(df)[0]

if otherwise_expr._returns_scalar: # type: ignore[attr-defined]
return [value_series.where(condition, otherwise_series[0])]
return [value_series.where(condition, otherwise_series)]
return [
CaseExpression(condition=condition, value=value).otherwise(
ConstantExpression(self._otherwise_value)
)
]
otherwise = otherwise_expr(df)[0]
return [CaseExpression(condition=condition, value=value).otherwise(otherwise)]

def then(self, value: DuckDBExpr | Any) -> DuckDBThen:
self._then_value = value
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals.dtypes import DType
from narwhals.exceptions import InvalidIntoExprError
Expand Down Expand Up @@ -76,7 +77,7 @@ def parse_exprs_and_named_exprs(

def _columns_from_expr(
df: DuckDBLazyFrame, expr: IntoDuckDBExpr
) -> list[duckdb.Expression]:
) -> Sequence[duckdb.Expression]:
if isinstance(expr, str): # pragma: no cover
from duckdb import ColumnExpression

Expand Down
26 changes: 5 additions & 21 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
}


def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
Expand All @@ -28,9 +26,7 @@ def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
assert_equal_data(result, expected)


def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when_otherwise(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
expected = {
Expand All @@ -39,11 +35,7 @@ def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest
assert_equal_data(result, expected)


def test_multiple_conditions(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_multiple_conditions(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
Expand Down Expand Up @@ -85,11 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_value_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_value_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when"))
expected = {
Expand Down Expand Up @@ -122,11 +110,7 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_otherwise_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_otherwise_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when")
Expand Down

0 comments on commit 89c8b1e

Please sign in to comment.