-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: implement when/then/otherwise for DuckDB #1759
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,6 +7,7 @@ | |||||||||||||
from typing import Any | ||||||||||||||
from typing import Literal | ||||||||||||||
from typing import Sequence | ||||||||||||||
from typing import cast | ||||||||||||||
|
||||||||||||||
from narwhals._duckdb.expr import DuckDBExpr | ||||||||||||||
from narwhals._duckdb.utils import narwhals_to_native_dtype | ||||||||||||||
|
@@ -157,6 +158,16 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: | |||||||||||||
kwargs={"exprs": exprs}, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def when( | ||||||||||||||
self, | ||||||||||||||
*predicates: IntoDuckDBExpr, | ||||||||||||||
) -> DuckDBWhen: | ||||||||||||||
plx = self.__class__(backend_version=self._backend_version, version=self._version) | ||||||||||||||
condition = plx.all_horizontal(*predicates) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The other backends have this check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah i moved it up cause i was tired of rewriting it everywhere π #1756 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice! |
||||||||||||||
return DuckDBWhen( | ||||||||||||||
condition, self._backend_version, returns_scalar=False, version=self._version | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def col(self, *column_names: str) -> DuckDBExpr: | ||||||||||||||
return DuckDBExpr.from_column_names( | ||||||||||||||
*column_names, backend_version=self._backend_version, version=self._version | ||||||||||||||
|
@@ -203,3 +214,101 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: | |||||||||||||
version=self._version, | ||||||||||||||
kwargs={}, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class DuckDBWhen: | ||||||||||||||
def __init__( | ||||||||||||||
self, | ||||||||||||||
condition: DuckDBExpr, | ||||||||||||||
backend_version: tuple[int, ...], | ||||||||||||||
then_value: Any = None, | ||||||||||||||
otherwise_value: Any = None, | ||||||||||||||
*, | ||||||||||||||
returns_scalar: bool, | ||||||||||||||
version: Version, | ||||||||||||||
) -> None: | ||||||||||||||
self._backend_version = backend_version | ||||||||||||||
self._condition = condition | ||||||||||||||
self._then_value = then_value | ||||||||||||||
self._otherwise_value = otherwise_value | ||||||||||||||
self._returns_scalar = returns_scalar | ||||||||||||||
self._version = version | ||||||||||||||
|
||||||||||||||
def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: | ||||||||||||||
from duckdb import CaseExpression | ||||||||||||||
from duckdb import ConstantExpression | ||||||||||||||
|
||||||||||||||
from narwhals._expression_parsing import parse_into_expr | ||||||||||||||
|
||||||||||||||
plx = df.__narwhals_namespace__() | ||||||||||||||
condition = parse_into_expr(self._condition, namespace=plx)(df)[0] | ||||||||||||||
condition = cast("duckdb.Expression", condition) | ||||||||||||||
|
||||||||||||||
try: | ||||||||||||||
value = parse_into_expr(self._then_value, namespace=plx)(df)[0] | ||||||||||||||
except TypeError: | ||||||||||||||
# `self._otherwise_value` is a scalar and can't be converted to an expression | ||||||||||||||
value = ConstantExpression(self._then_value) | ||||||||||||||
value = cast("duckdb.Expression", value) | ||||||||||||||
|
||||||||||||||
if self._otherwise_value is None: | ||||||||||||||
return [CaseExpression(condition=condition, value=value)] | ||||||||||||||
try: | ||||||||||||||
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) | ||||||||||||||
except TypeError: | ||||||||||||||
# `self._otherwise_value` is a scalar and can't be converted to an expression | ||||||||||||||
return [ | ||||||||||||||
CaseExpression(condition=condition, value=value).otherwise( | ||||||||||||||
ConstantExpression(self._otherwise_value) | ||||||||||||||
) | ||||||||||||||
] | ||||||||||||||
otherwise = otherwise_expr(df)[0] | ||||||||||||||
return [CaseExpression(condition=condition, value=value).otherwise(otherwise)] | ||||||||||||||
|
||||||||||||||
def then(self, value: DuckDBExpr | Any) -> DuckDBThen: | ||||||||||||||
self._then_value = value | ||||||||||||||
|
||||||||||||||
return DuckDBThen( | ||||||||||||||
self, | ||||||||||||||
depth=0, | ||||||||||||||
function_name="whenthen", | ||||||||||||||
root_names=None, | ||||||||||||||
output_names=None, | ||||||||||||||
returns_scalar=self._returns_scalar, | ||||||||||||||
backend_version=self._backend_version, | ||||||||||||||
version=self._version, | ||||||||||||||
kwargs={"value": value}, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class DuckDBThen(DuckDBExpr): | ||||||||||||||
def __init__( | ||||||||||||||
self, | ||||||||||||||
call: DuckDBWhen, | ||||||||||||||
*, | ||||||||||||||
depth: int, | ||||||||||||||
function_name: str, | ||||||||||||||
root_names: list[str] | None, | ||||||||||||||
output_names: list[str] | None, | ||||||||||||||
returns_scalar: bool, | ||||||||||||||
backend_version: tuple[int, ...], | ||||||||||||||
version: Version, | ||||||||||||||
kwargs: dict[str, Any], | ||||||||||||||
) -> None: | ||||||||||||||
self._backend_version = backend_version | ||||||||||||||
self._version = version | ||||||||||||||
self._call = call | ||||||||||||||
self._depth = depth | ||||||||||||||
self._function_name = function_name | ||||||||||||||
self._root_names = root_names | ||||||||||||||
self._output_names = output_names | ||||||||||||||
self._returns_scalar = returns_scalar | ||||||||||||||
self._kwargs = kwargs | ||||||||||||||
|
||||||||||||||
def otherwise(self, value: DuckDBExpr | Any) -> DuckDBExpr: | ||||||||||||||
# type ignore because we are setting the `_call` attribute to a | ||||||||||||||
# callable object of type `DuckDBWhen`, base class has the attribute as | ||||||||||||||
# only a `Callable` | ||||||||||||||
self._call._otherwise_value = value # type: ignore[attr-defined] | ||||||||||||||
self._function_name = "whenotherwise" | ||||||||||||||
return self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
trailing comma π± (just joking. I love the other PRs π )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok this made me laugh out loud π