feat: add when chaining #669

Open
wants to merge 76 commits into
base: main
Changes from 75 commits
d7446f4
add simple when
aivanoved Jul 17, 2024
6ebc78b
delete unnecessary file
aivanoved Jul 17, 2024
a3fdcc5
lint with ruff
aivanoved Jul 17, 2024
1ad1c94
use lambda expression
aivanoved Jul 17, 2024
55f394f
Merge branch 'main' into add-where-expression
aivanoved Jul 18, 2024
93e7121
remove deleted file
aivanoved Jul 18, 2024
f3770b7
Fix errors from the migration
aivanoved Jul 22, 2024
cf92f80
Merge branch 'main' into add-where-expression
aivanoved Jul 22, 2024
a7f442a
remove unnecessary changes
aivanoved Jul 22, 2024
7f23f05
add back the change in version
aivanoved Jul 22, 2024
7cc3aad
fix rename change
aivanoved Jul 22, 2024
ab85e40
rename test file
aivanoved Jul 22, 2024
4a8ac56
fix forgotten memeber change
aivanoved Jul 22, 2024
8283f24
make api identical
aivanoved Jul 22, 2024
f1c667e
remove unnecessary diff
aivanoved Jul 22, 2024
74937ea
add when documentation
aivanoved Jul 23, 2024
5b030d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
7390e1a
address mypy issues
aivanoved Jul 23, 2024
279e3ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
63048ee
address ruff type-ignore blanket issue
aivanoved Jul 23, 2024
e96af89
support `Iterable[Expr]` in the pandas api
aivanoved Jul 23, 2024
d4f0e9c
move when test file to a better location
aivanoved Jul 23, 2024
99d9899
make when test filename similar to other tests
aivanoved Jul 23, 2024
71e542d
add simple when
aivanoved Jul 17, 2024
8b1355a
lint with ruff
aivanoved Jul 17, 2024
eb36164
use lambda expression
aivanoved Jul 17, 2024
c9b09bf
Fix errors from the migration
aivanoved Jul 22, 2024
2b1eabc
remove unnecessary changes
aivanoved Jul 22, 2024
2ef564e
remove unnecessary diff
aivanoved Jul 22, 2024
0e4773d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
151fe14
fix rebase error
aivanoved Jul 23, 2024
add7b89
remove files left from wrong rebase
aivanoved Jul 23, 2024
504c4ea
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 23, 2024
fd21c78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
2f03bd0
chore: remove all wrong rebase leftover code
aivanoved Jul 23, 2024
43e2670
feat: add chaining for polars
aivanoved Jul 23, 2024
071ec9f
chore: remove unused fields
aivanoved Jul 23, 2024
9a49db6
bug: fix bug in chaining
aivanoved Jul 25, 2024
47a8cdd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
b71847b
bug: add chaing from chained then
aivanoved Jul 25, 2024
b5d7df8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
2ddd993
docs: add when to api reference
aivanoved Jul 25, 2024
489c463
bug: allow constraints to be passed to pandas implementation
aivanoved Jul 25, 2024
bb3847e
misc: fix typo
aivanoved Jul 25, 2024
37cc634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
0454ac4
misc: keep api the same
aivanoved Jul 25, 2024
4ad28b7
test: add test for multiple predicates
aivanoved Jul 25, 2024
0ded393
misc: make when stable
aivanoved Jul 29, 2024
3280a3c
bug: make stable v1 `Then` a stable expr `Expr`
aivanoved Jul 29, 2024
5c6deed
bug: fix when constraints pandas implementation
aivanoved Jul 29, 2024
27d17d9
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 29, 2024
8688491
test: stabalise all paths and test error on no arg
aivanoved Jul 29, 2024
81039bf
misc: add when to main api
aivanoved Jul 29, 2024
c5eac26
Merge branch 'add-where-expression' into add-when-chaining
aivanoved Jul 29, 2024
af83c50
fix: fix when chain
aivanoved Jul 29, 2024
1196fab
misc: remove constraints
aivanoved Jul 30, 2024
ba702b2
Merge branch 'add-where-expression' into add-when-chaining
aivanoved Jul 30, 2024
8991103
docs: remove wrong import
aivanoved Jul 30, 2024
beba175
docs: remove wrong import in stable
aivanoved Jul 30, 2024
b07fc9f
Merge branch 'add-where-expression' into add-when-chaining
aivanoved Jul 30, 2024
606feed
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 30, 2024
45684d4
docs: remove wrong import in main docstring
aivanoved Jul 30, 2024
c0d2f7a
Merge branch 'add-where-expression' into add-when-chaining
aivanoved Jul 30, 2024
10245c4
misc: make when the chaining stable
aivanoved Jul 30, 2024
c2bbc66
Merge branch 'main' into add-when-chaining
aivanoved Aug 26, 2024
ad5a50a
Merge branch 'main' into add-when-chaining
aivanoved Sep 11, 2024
09bea00
feat: add when then chaining back
aivanoved Sep 11, 2024
2c37729
misc: fix typo
aivanoved Sep 11, 2024
64cde0b
misc: remove unnecessary file
aivanoved Sep 11, 2024
53de101
misc: remove unused function
aivanoved Sep 11, 2024
05d2c5f
misc: add some stability
aivanoved Sep 11, 2024
fbb6c26
misc: stabilify when-then
aivanoved Sep 11, 2024
4c872b9
Merge branch 'main' into add-when-chaining
aivanoved Sep 16, 2024
e0819ce
tests: increase the test coverage for when-the-otherwise
aivanoved Sep 16, 2024
8b32e5d
tests: actually increase the thest coverage of when-the-otherwise
aivanoved Sep 16, 2024
9427b33
Update narwhals/_pandas_like/namespace.py
aivanoved Sep 16, 2024
160 changes: 159 additions & 1 deletion narwhals/_pandas_like/namespace.py
@@ -280,6 +280,8 @@ def __init__(
self._then_value = then_value
self._otherwise_value = otherwise_value

self._already_set = self._condition

def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace
@@ -361,10 +363,166 @@ def __init__(
self._root_names = root_names
self._output_names = output_names

def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr:
def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
# type ignore because we are setting the `_call` attribute to a
# callable object of type `PandasWhen`, base class has the attribute as
# only a `Callable`
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "whenotherwise"
return self

def when(self, *predicates: IntoPandasLikeExpr) -> PandasChainedWhen:
plx = PandasLikeNamespace(self._implementation, self._backend_version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)
return PandasChainedWhen(
self, condition, self._depth + 1, self._implementation, self._backend_version
)


class PandasChainedWhen:
def __init__(
self,
above_then: PandasThen | PandasChainedThen,
condition: PandasLikeExpr,
depth: int,
implementation: Implementation,
backend_version: tuple[int, ...],
then_value: Any = None,
otherwise_value: Any = None,
) -> None:
self._implementation = implementation
self._depth = depth
self._backend_version = backend_version
self._condition = condition
self._above_then = above_then
self._then_value = then_value
self._otherwise_value = otherwise_value

# TODO @aivanoved: this is slow; evaluating the chain currently takes
# quadratic time and should be improved to linear time
self._above_already_set = self._above_then._call._already_set # type: ignore[attr-defined]
self._already_set = self._above_already_set | self._condition

def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace

plx = PandasLikeNamespace(
implementation=self._implementation, backend_version=self._backend_version
)

condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._then_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable( # type: ignore[call-arg]
[self._then_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
)
value_series = cast(PandasLikeSeries, value_series)

set_then = condition
set_then_native = set_then._native_series
above_already_set = parse_into_expr(self._above_already_set, namespace=plx)._call(
df # type: ignore[arg-type]
)[0]

value_series_native = value_series._native_series

above_result = self._above_then._call(df)[0]
above_result_native = above_result._native_series
above_already_set_native = above_already_set._native_series
if self._otherwise_value is None:
return [
above_result._from_native_series(
value_series_native.where(
~above_already_set_native & set_then_native, above_result_native
)
)
]

try:
otherwise_series = parse_into_expr(
self._otherwise_value, namespace=plx
)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
otherwise_series = condition.__class__._from_iterable( # type: ignore[call-arg]
[self._otherwise_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
)
otherwise_series = cast(PandasLikeSeries, otherwise_series)
return [
above_result.zip_with(
above_already_set, value_series.zip_with(set_then, otherwise_series)
)
]

def then(self, value: Any) -> PandasChainedThen:
self._then_value = value
return PandasChainedThen(
self,
depth=self._depth,
implementation=self._implementation,
function_name="chainedwhen",
root_names=None,
output_names=None,
backend_version=self._backend_version,
)


class PandasChainedThen(PandasLikeExpr):
def __init__(
self,
call: PandasChainedWhen,
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> None:
self._implementation = implementation
self._backend_version = backend_version

self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names

def when(
self,
*predicates: IntoPandasLikeExpr,
) -> PandasChainedWhen:
plx = PandasLikeNamespace(self._implementation, self._backend_version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)
return PandasChainedWhen(
self,
condition,
depth=self._depth + 1,
implementation=self._implementation,
backend_version=self._backend_version,
)

def otherwise(self, value: Any) -> PandasChainedThen:
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "chainedwhenotherwise"
return self
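The chained classes above resolve each row by "first matching branch wins": a later `then` value is only applied where no earlier condition has already matched (the `~above_already_set & set_then` mask). The following is a minimal pure-Python sketch of that rule, not the PR's pandas implementation. The function name `chained_when` and its list-based arguments are hypothetical, chosen only for illustration; it carries one running mask forward so each branch is visited once.

```python
from __future__ import annotations

from typing import Any, Sequence


def chained_when(
    masks: Sequence[Sequence[bool]],
    values: Sequence[Any],
    otherwise: Any = None,
) -> list[Any]:
    """Resolve a when/then chain over boolean masks (hypothetical helper).

    masks[k][i] says whether branch k's condition holds for row i;
    values[k] is the value branch k assigns.
    """
    n = len(masks[0])
    result = [otherwise] * n
    already_set = [False] * n  # rows claimed by an earlier branch
    for mask, value in zip(masks, values):
        for i in range(n):
            # Same idea as `~above_already_set & set_then` in the diff:
            # only rows not claimed by an earlier branch can be set.
            if mask[i] and not already_set[i]:
                result[i] = value
                already_set[i] = True
    return result


a = [1, 2, 3, 4]
labels = chained_when(
    masks=[[x > 3 for x in a], [x > 1 for x in a]],
    values=["big", "medium"],
    otherwise="small",
)
```

Here row `4` satisfies both conditions but keeps `"big"`, because the first branch claimed it. With no `otherwise`, unmatched rows stay `None` in this sketch.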
43 changes: 39 additions & 4 deletions narwhals/expr.py
@@ -3994,16 +3994,17 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


def _extract_predicates(plx: Any, predicates: IntoExpr | Iterable[IntoExpr]) -> Any:
return [extract_compliant(plx, v) for v in flatten([predicates])]


class When:
def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
self._predicates = flatten([predicates])

def _extract_predicates(self, plx: Any) -> Any:
return [extract_compliant(plx, v) for v in self._predicates]

def then(self, value: Any) -> Then:
return Then(
lambda plx: plx.when(*self._extract_predicates(plx)).then(
lambda plx: plx.when(*_extract_predicates(plx, self._predicates)).then(
extract_compliant(plx, value)
)
)
@@ -4016,6 +4017,40 @@ def __init__(self, call: Callable[[Any], Any]) -> None:
def otherwise(self, value: Any) -> Expr:
return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))

def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
return ChainedWhen(self, *predicates)


class ChainedWhen:
def __init__(
self,
above_then: Then | ChainedThen,
*predicates: IntoExpr | Iterable[IntoExpr],
) -> None:
self._above_then = above_then
self._predicates = flatten([predicates])

def then(self, value: Any) -> ChainedThen:
return ChainedThen(
lambda plx: self._above_then._call(plx)
.when(*_extract_predicates(plx, self._predicates))
.then(extract_compliant(plx, value))
)


class ChainedThen(Expr):
def __init__(self, call: Callable[[Any], Any]) -> None:
self._call = call

def when(
self,
*predicates: IntoExpr | Iterable[IntoExpr],
) -> ChainedWhen:
return ChainedWhen(self, *predicates)

def otherwise(self, value: Any) -> Expr:
return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))
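The `When` / `Then` / `ChainedWhen` / `ChainedThen` classes above build the chain lazily: each link stores a closure that defers to the link above it, and nothing is evaluated until a backend namespace is supplied. A toy sketch of that builder pattern follows. It is an illustration under simplifying assumptions, not the narwhals code: `ToyWhen` and `ToyThen` are hypothetical names, the four classes are collapsed into two, and "evaluation" is a plain function over one row dict rather than a backend expression. Multiple predicates are AND-ed, as `all_horizontal` does in the diff.

```python
from __future__ import annotations

from typing import Any, Callable, Optional

Row = dict  # hypothetical stand-in for a dataframe row
Predicate = Callable[[Row], bool]


class ToyWhen:
    def __init__(self, *predicates: Predicate, above: Optional[ToyThen] = None) -> None:
        if not predicates:
            msg = "at least one predicate needs to be provided"
            raise TypeError(msg)
        self._predicates = predicates
        self._above = above  # the link this `when` chains from, if any

    def then(self, value: Any) -> ToyThen:
        def call(row: Row) -> tuple[bool, Any]:
            # Earlier links take priority over this one.
            if self._above is not None:
                matched, result = self._above._call(row)
                if matched:
                    return True, result
            if all(p(row) for p in self._predicates):
                return True, value
            return False, None

        return ToyThen(call)


class ToyThen:
    def __init__(self, call: Callable[[Row], tuple[bool, Any]]) -> None:
        self._call = call

    def when(self, *predicates: Predicate) -> ToyWhen:
        return ToyWhen(*predicates, above=self)

    def otherwise(self, value: Any) -> Callable[[Row], Any]:
        def evaluate(row: Row) -> Any:
            matched, result = self._call(row)
            return result if matched else value

        return evaluate


label = (
    ToyWhen(lambda r: r["a"] > 3)
    .then("big")
    .when(lambda r: r["a"] > 1, lambda r: r["b"])
    .then("medium")
    .otherwise("small")
)
rows = [
    {"a": 1, "b": True},
    {"a": 2, "b": True},
    {"a": 2, "b": False},
    {"a": 5, "b": False},
]
results = [label(r) for r in rows]
```

The third row fails the second branch because both of its predicates must hold, so it falls through to `otherwise`.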


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
"""
68 changes: 62 additions & 6 deletions narwhals/stable/v1.py
@@ -33,6 +33,8 @@
from narwhals.dtypes import UInt32
from narwhals.dtypes import UInt64
from narwhals.dtypes import Unknown
from narwhals.expr import ChainedThen as NwChainedThen
from narwhals.expr import ChainedWhen as NwChainedWhen
from narwhals.expr import Expr as NwExpr
from narwhals.expr import Then as NwThen
from narwhals.expr import When as NwWhen
@@ -491,12 +493,34 @@ def _stableify(obj: NwSeries) -> Series: ...
@overload
def _stableify(obj: NwExpr) -> Expr: ...
@overload
def _stableify(obj: NwWhen) -> When: ...
@overload
def _stableify(obj: NwChainedWhen) -> ChainedWhen: ...
@overload
def _stableify(obj: Any) -> Any: ...


def _stableify(
obj: NwDataFrame[IntoFrameT] | NwLazyFrame[IntoFrameT] | NwSeries | NwExpr | Any,
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series | Expr | Any:
obj: NwDataFrame[IntoFrameT]
| NwLazyFrame[IntoFrameT]
| NwSeries
| NwExpr
| NwWhen
| NwChainedWhen
| NwThen
| NwChainedThen
| Any,
) -> (
DataFrame[IntoFrameT]
| LazyFrame[IntoFrameT]
| Series
| Expr
| When
| ChainedWhen
| Then
| ChainedThen
| Any
):
if isinstance(obj, NwDataFrame):
return DataFrame(
obj._compliant_frame,
@@ -512,6 +536,14 @@ def _stableify(
obj._compliant_series,
level=obj._level,
)
elif isinstance(obj, NwChainedWhen):
return ChainedWhen.from_base(obj)
if isinstance(obj, NwWhen):
return When.from_base(obj)
elif isinstance(obj, NwChainedThen):
return ChainedThen.from_base(obj)
elif isinstance(obj, NwThen):
return Then.from_base(obj)
if isinstance(obj, NwExpr):
return Expr(obj._call)
return obj
@@ -1692,21 +1724,45 @@ def get_level(

class When(NwWhen):
@classmethod
def from_when(cls, when: NwWhen) -> Self:
def from_base(cls, when: NwWhen) -> Self:
return cls(*when._predicates)

def then(self, value: Any) -> Then:
return Then.from_then(super().then(value))
return Then.from_base(super().then(value))


class Then(NwThen, Expr):
@classmethod
def from_then(cls, then: NwThen) -> Self:
def from_base(cls, then: NwThen) -> Self:
return cls(then._call)

def otherwise(self, value: Any) -> Expr:
return _stableify(super().otherwise(value))

def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
return _stableify(super().when(*predicates))


class ChainedWhen(NwChainedWhen):
@classmethod
def from_base(cls, chained_when: NwChainedWhen) -> Self:
return cls(_stableify(chained_when._above_then), *chained_when._predicates) # type: ignore[arg-type]

def then(self, value: Any) -> ChainedThen:
return _stableify(super().then(value)) # type: ignore[return-value]


class ChainedThen(NwChainedThen, Expr):
@classmethod
def from_base(cls, chained_then: NwChainedThen) -> Self:
return cls(chained_then._call)

def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
return _stableify(super().when(*predicates))

def otherwise(self, value: Any) -> Expr:
return _stableify(super().otherwise(value))


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
"""
@@ -1753,7 +1809,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
│ 3   ┆ 15  ┆ 6      │
└─────┴─────┴────────┘
"""
return When.from_when(nw_when(*predicates))
return _stableify(nw_when(*predicates))
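The stable wrappers above depend on two details: each stable class rebuilds itself from its base counterpart through `from_base`, and `_stableify` tests the more specific types before `NwExpr`, since a `Then` is also an `Expr`. A small self-contained sketch of that dispatch order is given below. All class names here (`BaseExpr`, `BaseThen`, `StableExpr`, `StableThen`) and the `stableify` function are hypothetical stand-ins, not the narwhals classes.

```python
from __future__ import annotations

from typing import Any, Callable


class BaseExpr:
    def __init__(self, call: Callable[[], Any]) -> None:
        self._call = call


class BaseThen(BaseExpr):
    """Stand-in for the main-namespace `Then`, which is also an expression."""


class StableExpr(BaseExpr):
    """Stand-in for the stable-namespace `Expr`."""


class StableThen(BaseThen, StableExpr):
    @classmethod
    def from_base(cls, then: BaseThen) -> StableThen:
        # Rebuild the stable object around the same deferred call.
        return cls(then._call)


def stableify(obj: Any) -> Any:
    # Check the subclass first: every BaseThen is also a BaseExpr, so the
    # reverse order would wrap it as a plain StableExpr and lose `.when`.
    if isinstance(obj, BaseThen):
        return StableThen.from_base(obj)
    if isinstance(obj, BaseExpr):
        return StableExpr(obj._call)
    return obj


then_obj = stableify(BaseThen(lambda: "then"))
expr_obj = stableify(BaseExpr(lambda: "expr"))
```

Because `StableThen` inherits from both `BaseThen` and `StableExpr`, the wrapped object still passes `isinstance` checks against the stable expression type.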


def new_series(