Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement when/then/otherwise for DuckDB #1759

Merged
merged 6 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DuckDBExpr(CompliantExpr["duckdb.Expression"]):

def __init__(
self,
call: Callable[[DuckDBLazyFrame], list[duckdb.Expression]],
call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]],
*,
depth: int,
function_name: str,
Expand Down
109 changes: 109 additions & 0 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.utils import narwhals_to_native_dtype
Expand Down Expand Up @@ -157,6 +158,16 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={"exprs": exprs},
)

def when(
self,
*predicates: IntoDuckDBExpr,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*predicates: IntoDuckDBExpr,
*predicates: IntoDuckDBExpr

trailing comma 😱 (just joking. I love the other PRs πŸ˜…)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok this made me laugh out loud πŸ˜†

) -> DuckDBWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
condition = plx.all_horizontal(*predicates)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
condition = plx.all_horizontal(*predicates)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

The other backends have this check

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i moved it up cause i was tired of rewriting it everywhere πŸ˜„ #1756

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

return DuckDBWhen(
condition, self._backend_version, returns_scalar=False, version=self._version
)

def col(self, *column_names: str) -> DuckDBExpr:
return DuckDBExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
Expand Down Expand Up @@ -203,3 +214,101 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
version=self._version,
kwargs={},
)


class DuckDBWhen:
def __init__(
self,
condition: DuckDBExpr,
backend_version: tuple[int, ...],
then_value: Any = None,
otherwise_value: Any = None,
*,
returns_scalar: bool,
version: Version,
) -> None:
self._backend_version = backend_version
self._condition = condition
self._then_value = then_value
self._otherwise_value = otherwise_value
self._returns_scalar = returns_scalar
self._version = version

def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
from duckdb import CaseExpression
from duckdb import ConstantExpression

from narwhals._expression_parsing import parse_into_expr

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("duckdb.Expression", condition)

try:
value = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value = ConstantExpression(self._then_value)
value = cast("duckdb.Expression", value)

if self._otherwise_value is None:
return [CaseExpression(condition=condition, value=value)]
try:
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [
CaseExpression(condition=condition, value=value).otherwise(
ConstantExpression(self._otherwise_value)
)
]
otherwise = otherwise_expr(df)[0]
return [CaseExpression(condition=condition, value=value).otherwise(otherwise)]

def then(self, value: DuckDBExpr | Any) -> DuckDBThen:
self._then_value = value

return DuckDBThen(
self,
depth=0,
function_name="whenthen",
root_names=None,
output_names=None,
returns_scalar=self._returns_scalar,
backend_version=self._backend_version,
version=self._version,
kwargs={"value": value},
)


class DuckDBThen(DuckDBExpr):
def __init__(
self,
call: DuckDBWhen,
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
returns_scalar: bool,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
) -> None:
self._backend_version = backend_version
self._version = version
self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names
self._returns_scalar = returns_scalar
self._kwargs = kwargs

def otherwise(self, value: DuckDBExpr | Any) -> DuckDBExpr:
# type ignore because we are setting the `_call` attribute to a
# callable object of type `DuckDBWhen`, base class has the attribute as
# only a `Callable`
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "whenotherwise"
return self
3 changes: 2 additions & 1 deletion narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals.dtypes import DType
from narwhals.exceptions import InvalidIntoExprError
Expand Down Expand Up @@ -76,7 +77,7 @@ def parse_exprs_and_named_exprs(

def _columns_from_expr(
df: DuckDBLazyFrame, expr: IntoDuckDBExpr
) -> list[duckdb.Expression]:
) -> Sequence[duckdb.Expression]:
if isinstance(expr, str): # pragma: no cover
from duckdb import ColumnExpression

Expand Down
26 changes: 5 additions & 21 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
}


def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
Expand All @@ -28,9 +26,7 @@ def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
assert_equal_data(result, expected)


def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when_otherwise(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
expected = {
Expand All @@ -39,11 +35,7 @@ def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest
assert_equal_data(result, expected)


def test_multiple_conditions(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_multiple_conditions(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
Expand Down Expand Up @@ -85,11 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_value_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_value_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when"))
expected = {
Expand Down Expand Up @@ -122,11 +110,7 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_otherwise_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_otherwise_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when")
Expand Down
9 changes: 9 additions & 0 deletions tpch/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import dask.dataframe as dd
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
Expand All @@ -29,14 +30,18 @@
"pandas[pyarrow]": (pd, {"engine": "pyarrow", "dtype_backend": "pyarrow"}),
"polars[lazy]": (pl, {}),
"pyarrow": (pa, {}),
"duckdb": (duckdb, {}),
"dask": (dd, {"engine": "pyarrow", "dtype_backend": "pyarrow"}),
}

BACKEND_COLLECT_FUNC_MAP = {
"polars[lazy]": lambda x: x.collect(),
"duckdb": lambda x: x.pl(),
"dask": lambda x: x.compute(),
}

DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"]

QUERY_DATA_PATH_MAP = {
"q1": (LINEITEM_PATH,),
"q2": (REGION_PATH, NATION_PATH, SUPPLIER_PATH, PART_PATH, PARTSUPP_PATH),
Expand Down Expand Up @@ -90,6 +95,10 @@ def execute_query(query_id: str) -> None:
data_paths = QUERY_DATA_PATH_MAP[query_id]

for backend, (native_namespace, kwargs) in BACKEND_NAMESPACE_KWARGS_MAP.items():
if backend == "duckdb" and query_id in DUCKDB_XFAILS:
print(f"\nSkipping {query_id} for DuckDB") # noqa: T201
continue

print(f"\nRunning {query_id} with {backend=}") # noqa: T201
result = query_module.query(
*(
Expand Down
Loading