From d7446f4de63cf98dc055e3e6caa85736cc4bfbd0 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 2024 12:38:19 +0300 Subject: [PATCH 01/64] add simple when --- narwhals/_pandas_like/expr.py | 62 ++++++++++++++++++++++++++++++++ narwhals/expression.py | 29 +++++++++++++++ narwhals/expressions/whenthen.py | 0 tests/test_common.py | 10 ++++++ 4 files changed, 101 insertions(+) create mode 100644 narwhals/expressions/whenthen.py diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index afe57e780..bc65d2d73 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -4,8 +4,10 @@ from typing import Any from typing import Callable from typing import Literal +from typing import Iterable from narwhals._pandas_like.series import PandasSeries +from narwhals._pandas_like.typing import IntoPandasExpr from narwhals._pandas_like.utils import reuse_series_implementation from narwhals._pandas_like.utils import reuse_series_namespace_implementation @@ -296,6 +298,14 @@ def str(self) -> PandasExprStringNamespace: def dt(self) -> PandasExprDateTimeNamespace: return PandasExprDateTimeNamespace(self) + def when(self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any) -> PandasWhen: + # TODO: Support conditions + from narwhals._pandas_like.namespace import PandasNamespace + + plx = PandasNamespace(self._implementation) + condition = plx.all_horizontal(*predicates) + return PandasWhen(self, condition) + class PandasExprStringNamespace: def __init__(self, expr: PandasExpr) -> None: @@ -380,3 +390,55 @@ def total_nanoseconds(self) -> PandasExpr: return reuse_series_namespace_implementation( self._expr, "dt", "total_nanoseconds" ) + +class PandasWhen: + def __init__(self, condition: PandasExpr) -> None: + self._condition = condition + + def then(self, value: Any) -> PandasThen: + return PandasThen(self, value=value, implementation=self._condition._implementation) + +class PandasThen(PandasExpr): + def __init__(self, 
when: PandasWhen, *, value: Any, implementation: str) -> None: + self._when = when + self._then_value = value + self._implementation = implementation + + def func(df: PandasDataFrame) -> list[PandasSeries]: + from narwhals._pandas_like.namespace import PandasNamespace + + plx = PandasNamespace(implementation=self._implementation) + + condition = self._when._condition._call(df)[0] + + value_series = plx._create_series_from_scalar(self._then_value, condition) + none_series = plx._create_series_from_scalar(None, condition) + return [ + value_series.zip_with(condition, none_series) + ] + + self._call = func + self._depth = 0 + self._function_name = "whenthen" + self._root_names = None + self._output_names = None + + def otherwise(self, value: Any) -> PandasExpr: + def func(df: PandasDataFrame) -> list[PandasSeries]: + from narwhals._pandas_like.namespace import PandasNamespace + plx = PandasNamespace(implementation=self._implementation) + condition = self._when._condition._call(df)[0] + value_series = plx._create_series_from_scalar(self._then_value, condition) + otherwise_series = plx._create_series_from_scalar(value, condition) + return [ + value_series.zip_with(condition, otherwise_series) + ] + + return PandasExpr( + func, + depth=0, + function_name="whenthenotherwise", + root_names=None, + output_names=None, + implementation=self._implementation, + ) diff --git a/narwhals/expression.py b/narwhals/expression.py index b4c1ed242..93ddac3b0 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -11,6 +11,9 @@ from narwhals.utils import flatten from narwhals.utils import parse_version + +from functools import reduce + if TYPE_CHECKING: from narwhals.typing import IntoExpr @@ -2633,6 +2636,32 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: lambda plx: plx.sum_horizontal([extract_native(plx, v) for v in flatten(exprs)]) ) +class When: + def __init__(self, condition: Expr) -> None: + self._condition = condition + + def then(self, value: 
Any) -> Then: + return Then(self, value=value) + +class Then(Expr): + def __init__(self, when: When, *, value: Any) -> None: + self._when = when + self._then_value = value + + def func(plx): + return plx.when(self._when._condition._call(plx)).then(self._then_value) + + self._call = func + + def otherwise(self, value: Any) -> Expr: + def func(plx): + return plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value) + + return Expr(func) + +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: + return When(reduce(lambda a, b: a & b, flatten([predicates]))) + __all__ = [ "Expr", diff --git a/narwhals/expressions/whenthen.py b/narwhals/expressions/whenthen.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_common.py b/tests/test_common.py index 90b5c21af..4162d7b24 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -706,3 +706,13 @@ def test_quantile( df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) ) compare_dicts(result, expected) + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_when(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.with_columns( + a=nw.when(nw.col("a") > 2, 1).otherwise(0), + b=nw.when(nw.col("a") > 2, 1).when(nw.col("a") < 1, -1).otherwise(0), + ) + expected = {"a": [0, 1, 0], "b": [0, 1, 0], "z": [7.0, 8.0, 9.0]} + compare_dicts(result, expected) From 6ebc78bcbe1b3edcf3ff3d814ea8719968f34e46 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 2024 12:43:10 +0300 Subject: [PATCH 02/64] delete unnecessary file --- narwhals/expressions/whenthen.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 narwhals/expressions/whenthen.py diff --git a/narwhals/expressions/whenthen.py b/narwhals/expressions/whenthen.py deleted file mode 100644 index e69de29bb..000000000 From a3fdcc5e23f9f45873df4a4f6c5564ee858b59c7 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 
2024 12:44:10 +0300 Subject: [PATCH 03/64] lint with ruff --- narwhals/_pandas_like/expr.py | 3 +-- narwhals/expression.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index bc65d2d73..b2c6ae1e3 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -3,11 +3,10 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Literal from typing import Iterable +from typing import Literal from narwhals._pandas_like.series import PandasSeries -from narwhals._pandas_like.typing import IntoPandasExpr from narwhals._pandas_like.utils import reuse_series_implementation from narwhals._pandas_like.utils import reuse_series_namespace_implementation diff --git a/narwhals/expression.py b/narwhals/expression.py index 93ddac3b0..eeeeae961 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -11,9 +12,6 @@ from narwhals.utils import flatten from narwhals.utils import parse_version - -from functools import reduce - if TYPE_CHECKING: from narwhals.typing import IntoExpr From 1ad1c94d91852cc0a069bdaadf74c5fa3dfae6cd Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 2024 17:06:18 +0300 Subject: [PATCH 04/64] use lambda expression --- narwhals/expression.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/narwhals/expression.py b/narwhals/expression.py index eeeeae961..090aa0cec 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -2646,18 +2646,12 @@ def __init__(self, when: When, *, value: Any) -> None: self._when = when self._then_value = value - def func(plx): - return plx.when(self._when._condition._call(plx)).then(self._then_value) - - self._call = func + self._call = lambda plx: 
plx.when(self._when._condition._call(plx)).then(self._then_value) def otherwise(self, value: Any) -> Expr: - def func(plx): - return plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value) - - return Expr(func) + return Expr(lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 return When(reduce(lambda a, b: a & b, flatten([predicates]))) From 93e712193deb3018b0157cc5af1a8dcebca54d9d Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 18 Jul 2024 12:46:47 +0300 Subject: [PATCH 05/64] remove deleted file --- tests/test_common.py | 718 ------------------------------------------- 1 file changed, 718 deletions(-) delete mode 100644 tests/test_common.py diff --git a/tests/test_common.py b/tests/test_common.py deleted file mode 100644 index 4162d7b24..000000000 --- a/tests/test_common.py +++ /dev/null @@ -1,718 +0,0 @@ -from __future__ import annotations - -import os -import warnings -from typing import Any -from typing import Literal - -import numpy as np -import pandas as pd -import polars as pl -import pytest -from pandas.testing import assert_series_equal as pd_assert_series_equal -from polars.testing import assert_series_equal as pl_assert_series_equal - -import narwhals as nw -from narwhals.utils import parse_version -from tests.utils import compare_dicts - -df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) -if parse_version(pd.__version__) >= parse_version("1.5.0"): - df_pandas_pyarrow = pd.DataFrame( - {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - ).astype( - { - "a": "Int64[pyarrow]", - "b": "Int64[pyarrow]", - "z": "Float64[pyarrow]", - } - ) - df_pandas_nullable = pd.DataFrame( - {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - ).astype( - { - "a": "Int64", - "b": 
"Int64", - "z": "Float64", - } - ) -else: # pragma: no cover - df_pandas_pyarrow = df_pandas - df_pandas_nullable = df_pandas -df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) -df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) -df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) -df_lazy_na = pl.LazyFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) -df_right_pandas = pd.DataFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) -df_right_lazy = pl.LazyFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) - -if os.environ.get("CI", None): - try: - import modin.pandas as mpd - except ImportError: # pragma: no cover - df_mpd = df_pandas.copy() - else: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - df_mpd = mpd.DataFrame( - pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) - ) -else: # pragma: no cover - df_mpd = df_pandas.copy() - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_polars, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_sort(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.sort("a", "b") - result_native = nw.to_native(result) - expected = { - "a": [1, 2, 3], - "b": [4, 6, 4], - "z": [7.0, 9.0, 8.0], - } - compare_dicts(result_native, expected) - result = df.sort("a", "b", descending=[True, False]) - result_native = nw.to_native(result) - expected = { - "a": [3, 2, 1], - "b": [4, 6, 4], - "z": [8.0, 9.0, 7.0], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_filter(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.filter(nw.col("a") > 1) - result_native = nw.to_native(result) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_polars], -) -def 
test_filter_series(df_raw: Any) -> None: - df = nw.DataFrame(df_raw).with_columns(mask=nw.col("a") > 1) - result = df.filter(df["mask"]).drop("mask") - result_native = nw.to_native(result) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_add(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns( - c=nw.col("a") + nw.col("b"), - d=nw.col("a") - nw.col("a").mean(), - e=nw.col("a") - nw.col("a").std(), - ) - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "c": [5, 7, 8], - "d": [-1.0, 1.0, 0.0], - "e": [0.0, 2.0, 1.0], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_std(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select( - nw.col("a").std().alias("a_ddof_default"), - nw.col("a").std(ddof=1).alias("a_ddof_1"), - nw.col("a").std(ddof=0).alias("a_ddof_0"), - nw.col("b").std(ddof=2).alias("b_ddof_2"), - nw.col("z").std(ddof=0).alias("z_ddof_0"), - ) - result_native = nw.to_native(result) - expected = { - "a_ddof_default": [1.0], - "a_ddof_1": [1.0], - "a_ddof_0": [0.816497], - "b_ddof_2": [1.632993], - "z_ddof_0": [0.816497], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_double(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns(nw.all() * 2) - result_native = nw.to_native(result) - expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} - compare_dicts(result_native, expected) - result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) - result_native = nw.to_native(result) - expected = {"o": [1, 3, 2], "a": [2, 6, 4], "b": 
[8, 8, 12], "z": [14.0, 16.0, 18.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_select(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select("a") - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy, df_pandas_nullable]) -def test_sumh(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns(horizonal_sum=nw.sum_horizontal(nw.col("a"), nw.col("b"))) - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "horizonal_sum": [5, 7, 8], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_sumh_literal(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns(horizonal_sum=nw.sum_horizontal("a", nw.col("b"))) - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "horizonal_sum": [5, 7, 8], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_sum_all(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select(nw.all().sum()) - result_native = nw.to_native(result) - expected = {"a": [6], "b": [14], "z": [24.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_double_selected(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select(nw.col("a", "b") * 2) - result_native = nw.to_native(result) - expected = {"a": [2, 6, 4], "b": [8, 8, 12]} - compare_dicts(result_native, expected) - result = df.select("z", nw.col("a", "b") * 2) - 
result_native = nw.to_native(result) - expected = {"z": [7, 8, 9], "a": [2, 6, 4], "b": [8, 8, 12]} - compare_dicts(result_native, expected) - result = df.select("a").select(nw.col("a") + nw.all()) - result_native = nw.to_native(result) - expected = {"a": [2, 6, 4]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_rename(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.rename({"a": "x", "b": "y"}) - result_native = nw.to_native(result) - expected = {"x": [1, 3, 2], "y": [4, 4, 6], "z": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_join(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - df_right = df - result = df.join(df_right, left_on=["a", "b"], right_on=["a", "b"], how="inner") - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "z_right": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - with pytest.raises(NotImplementedError): - result = df.join(df_right, left_on="a", right_on="a", how="left") # type: ignore[arg-type] - - result = df.collect().join(df_right.collect(), left_on="a", right_on="a", how="inner") # type: ignore[assignment] - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "b_right": [4, 4, 6], - "z": [7.0, 8, 9], - "z_right": [7.0, 8, 9], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_schema(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw).schema - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - assert result == expected - result = nw.LazyFrame(df_raw).collect().schema - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - assert result == expected - result = 
nw.LazyFrame(df_raw).columns # type: ignore[assignment] - expected = ["a", "b", "z"] # type: ignore[assignment] - assert result == expected - result = nw.LazyFrame(df_raw).collect().columns # type: ignore[assignment] - expected = ["a", "b", "z"] # type: ignore[assignment] - assert result == expected - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_columns(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.columns - expected = ["a", "b", "z"] - assert result == expected - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def test_lazy_instantiation(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw) - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_lazy]) -def test_lazy_instantiation_error(df_raw: Any) -> None: - with pytest.raises( - TypeError, match="Can't instantiate DataFrame from Polars LazyFrame." 
- ): - _ = nw.DataFrame(df_raw).shape - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) -def test_eager_instantiation(df_raw: Any) -> None: - result = nw.DataFrame(df_raw) - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - -def test_accepted_dataframes() -> None: - array = np.array([[0, 4.0], [2, 5]]) - with pytest.raises( - TypeError, - match="Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: ", - ): - nw.DataFrame(array) - with pytest.raises( - TypeError, - match="Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: ", - ): - nw.LazyFrame(array) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) -@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") -def test_convert_pandas(df_raw: Any) -> None: - result = nw.from_native(df_raw).to_pandas() # type: ignore[union-attr] - expected = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) - pd.testing.assert_frame_equal(result, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_polars, df_pandas, df_mpd, df_pandas_nullable, df_pandas_pyarrow] -) -@pytest.mark.filterwarnings( - r"ignore:np\.find_common_type is deprecated\.:DeprecationWarning" -) -def test_convert_numpy(df_raw: Any) -> None: - result = nw.DataFrame(df_raw).to_numpy() - expected = np.array([[1, 3, 2], [4, 4, 6], [7.0, 8, 9]]).T - np.testing.assert_array_equal(result, expected) - assert result.dtype == "float64" - result = nw.DataFrame(df_raw).__array__() - np.testing.assert_array_equal(result, expected) - assert result.dtype == "float64" - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) -def test_shape(df_raw: Any) -> None: - result = nw.DataFrame(df_raw).shape - expected = (3, 3) - assert result == expected - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def 
test_expr_binary(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw).with_columns( - a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")), - b=nw.col("z") / (2 - nw.col("b")), - c=nw.col("a") + nw.col("b") / 2, - d=nw.col("a") - nw.col("b"), - e=((nw.col("a") > nw.col("b")) & (nw.col("a") >= nw.col("z"))).cast(nw.Int64), - f=( - (nw.col("a") < nw.col("b")) - | (nw.col("a") <= nw.col("z")) - | (nw.col("a") == 1) - ).cast(nw.Int64), - g=nw.col("a") != 1, - h=(False & (nw.col("a") != 1)), - i=(False | (nw.col("a") != 1)), - j=2 ** nw.col("a"), - k=2 // nw.col("a"), - l=nw.col("a") // 2, - m=nw.col("a") ** 2, - n=nw.col("a") % 2, - o=2 % nw.col("a"), - ) - result_native = nw.to_native(result) - expected = { - "a": [4, 3.333333, 3.5], - "b": [-3.5, -4.0, -2.25], - "z": [7.0, 8.0, 9.0], - "c": [3, 5, 5], - "d": [-3, -1, -4], - "e": [0, 0, 0], - "f": [1, 1, 1], - "g": [False, True, True], - "h": [False, False, False], - "i": [False, True, True], - "j": [2, 8, 4], - "k": [2, 0, 1], - "l": [0, 1, 1], - "m": [1, 9, 4], - "n": [1, 1, 0], - "o": [0, 2, 0], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_lazy]) -def test_expr_unary(df_raw: Any) -> None: - result = ( - nw.from_native(df_raw) - .with_columns( - a_mean=nw.col("a").mean(), - a_sum=nw.col("a").sum(), - b_nunique=nw.col("b").n_unique(), - z_min=nw.col("z").min(), - z_max=nw.col("z").max(), - ) - .select(nw.col("a_mean", "a_sum", "b_nunique", "z_min", "z_max").unique()) - ) - result_native = nw.to_native(result) - expected = {"a_mean": [2], "a_sum": [6], "b_nunique": [2], "z_min": [7], "z_max": [9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def test_expr_transform(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw).with_columns( - a=nw.col("a").is_between(-1, 1), b=nw.col("b").is_in([4, 5]) - ) - result_native = nw.to_native(result) - expected = {"a": [True, False, False], "b": 
[True, True, False], "z": [7, 8, 9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_lazy]) -def test_expr_min_max(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result_min = nw.to_native(df.select(nw.min("a", "b", "z"))) - result_max = nw.to_native(df.select(nw.max("a", "b", "z"))) - expected_min = {"a": [1], "b": [4], "z": [7]} - expected_max = {"a": [3], "b": [6], "z": [9]} - compare_dicts(result_min, expected_min) - compare_dicts(result_max, expected_max) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def test_expr_sample(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result_shape = nw.to_native(df.select(nw.col("a").sample(n=2)).collect()).shape - expected = (2, 1) - assert result_shape == expected - result_shape = nw.to_native(df.collect()["a"].sample(n=2)).shape - expected = (2,) # type: ignore[assignment] - assert result_shape == expected - - -@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) -def test_expr_na(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result_nna = nw.to_native( - df.filter((~nw.col("a").is_null()) & (~df.collect()["z"].is_null())) - ) - expected = {"a": [2], "b": [6], "z": [9]} - compare_dicts(result_nna, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_head(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.head(2)) - expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} - compare_dicts(result, expected) - result = nw.to_native(df.collect().head(2)) - expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_unique(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.unique("b").sort("b")) - expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]} 
- compare_dicts(result, expected) - result = nw.to_native(df.collect().unique("b").sort("b")) - expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) -def test_drop_nulls(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.select(nw.col("a").drop_nulls())) - expected = {"a": [3, 2]} - compare_dicts(result, expected) - result = nw.to_native(df.select(df.collect()["a"].drop_nulls())) - expected = {"a": [3, 2]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize( - ("df_raw", "df_raw_right"), [(df_pandas, df_right_pandas), (df_lazy, df_right_lazy)] -) -def test_concat_horizontal(df_raw: Any, df_raw_right: Any) -> None: - df_left = nw.LazyFrame(df_raw) - df_right = nw.LazyFrame(df_raw_right) - result = nw.concat([df_left, df_right], how="horizontal") - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8, 9], - "c": [6, 12, -1], - "d": [0, -4, 2], - } - compare_dicts(result_native, expected) - - with pytest.raises(ValueError, match="No items"): - nw.concat([]) - - -@pytest.mark.parametrize( - ("df_raw", "df_raw_right"), [(df_pandas, df_right_pandas), (df_lazy, df_right_lazy)] -) -def test_concat_vertical(df_raw: Any, df_raw_right: Any) -> None: - df_left = nw.LazyFrame(df_raw).collect().rename({"a": "c", "b": "d"}).lazy().drop("z") - df_right = nw.LazyFrame(df_raw_right) - result = nw.concat([df_left, df_right], how="vertical") - result_native = nw.to_native(result) - expected = {"c": [1, 3, 2, 6, 12, -1], "d": [4, 4, 6, 0, -4, 2]} - compare_dicts(result_native, expected) - with pytest.raises(ValueError, match="No items"): - nw.concat([], how="vertical") - with pytest.raises(Exception, match="unable to vstack"): - nw.concat([df_left, df_right.rename({"d": "i"})], how="vertical").collect() # type: ignore[union-attr] - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def 
test_lazy(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.lazy() - assert isinstance(result, nw.LazyFrame) - - -def test_to_dict() -> None: - df = nw.DataFrame(df_pandas) - result = df.to_dict(as_series=True) - expected = { - "a": pd.Series([1, 3, 2], name="a"), - "b": pd.Series([4, 4, 6], name="b"), - "z": pd.Series([7.0, 8, 9], name="z"), - } - for key in expected: - pd_assert_series_equal(nw.to_native(result[key]), expected[key]) - - df = nw.DataFrame(df_polars) - result = df.to_dict(as_series=True) - expected = { - "a": pl.Series("a", [1, 3, 2]), - "b": pl.Series("b", [4, 4, 6]), - "z": pl.Series("z", [7.0, 8, 9]), - } - for key in expected: - pl_assert_series_equal(nw.to_native(result[key]), expected[key]) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_any_all(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.select((nw.all() > 1).all())) - expected = {"a": [False], "b": [True], "z": [True]} - compare_dicts(result, expected) - result = nw.to_native(df.select((nw.all() > 1).any())) - expected = {"a": [True], "b": [True], "z": [True]} - compare_dicts(result, expected) - - -def test_invalid() -> None: - df = nw.LazyFrame(df_pandas) - with pytest.raises(ValueError, match="Multi-output"): - df.select(nw.all() + nw.all()) - with pytest.raises(TypeError, match="Perhaps you:"): - df.select([pl.col("a")]) # type: ignore[list-item] - with pytest.raises(TypeError, match="Perhaps you:"): - df.select([nw.col("a").cast(pl.Int64)]) - - -@pytest.mark.parametrize("df_raw", [df_pandas]) -def test_reindex(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.select("b", df["a"].sort(descending=True)) - expected = {"b": [4, 4, 6], "a": [3, 2, 1]} - compare_dicts(result, expected) - result = df.select("b", nw.col("a").sort(descending=True)) - compare_dicts(result, expected) - - s = df["a"] - result_s = s > s.sort() - assert not result_s[0] - assert result_s[1] 
- assert not result_s[2] - result = df.with_columns(s.sort()) - expected = {"a": [1, 2, 3], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} # type: ignore[list-item] - compare_dicts(result, expected) - with pytest.raises(ValueError, match="Multi-output expressions are not supported"): - nw.to_native(df.with_columns(nw.all() + nw.all())) - - -@pytest.mark.parametrize( - ("df_raw", "df_raw_right"), - [(df_pandas, df_polars), (df_polars, df_pandas)], -) -def test_library(df_raw: Any, df_raw_right: Any) -> None: - df_left = nw.LazyFrame(df_raw) - df_right = nw.LazyFrame(df_raw_right) - with pytest.raises( - NotImplementedError, match="Cross-library comparisons aren't supported" - ): - nw.concat([df_left, df_right], how="horizontal") - with pytest.raises( - NotImplementedError, match="Cross-library comparisons aren't supported" - ): - nw.concat([df_left, df_right], how="vertical") - with pytest.raises( - NotImplementedError, match="Cross-library comparisons aren't supported" - ): - df_left.join(df_right, left_on=["a"], right_on=["a"], how="inner") - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def test_is_duplicated(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = nw.concat([df, df.head(1)]).is_duplicated() # type: ignore [union-attr] - expected = np.array([True, False, False, True]) - assert (result.to_numpy() == expected).all() - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -@pytest.mark.parametrize(("threshold", "expected"), [(0, False), (10, True)]) -def test_is_empty(df_raw: Any, threshold: Any, expected: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.filter(nw.col("a") > threshold).is_empty() - assert result == expected - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def test_is_unique(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = nw.concat([df, df.head(1)]).is_unique() # type: ignore [union-attr] - expected = np.array([False, True, True, False]) - assert (result.to_numpy() == 
expected).all() - - -@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na.collect()]) -def test_null_count(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = nw.to_native(df.null_count()) - expected = {"a": [1], "b": [0], "z": [1]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -@pytest.mark.parametrize( - ("interpolation", "expected"), - [ - ("lower", {"a": [1.0], "b": [4.0], "z": [7.0]}), - ("higher", {"a": [2.0], "b": [4.0], "z": [8.0]}), - ("midpoint", {"a": [1.5], "b": [4.0], "z": [7.5]}), - ("linear", {"a": [1.6], "b": [4.0], "z": [7.6]}), - ("nearest", {"a": [2.0], "b": [4.0], "z": [8.0]}), - ], -) -def test_quantile( - df_raw: Any, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], - expected: dict[str, list[float]], -) -> None: - q = 0.3 - - df = nw.from_native(df_raw) - result = nw.to_native( - df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) - ) - compare_dicts(result, expected) - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def test_when(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.with_columns( - a=nw.when(nw.col("a") > 2, 1).otherwise(0), - b=nw.when(nw.col("a") > 2, 1).when(nw.col("a") < 1, -1).otherwise(0), - ) - expected = {"a": [0, 1, 0], "b": [0, 1, 0], "z": [7.0, 8.0, 9.0]} - compare_dicts(result, expected) From f3770b75d31836532b6708ceddb5724d569d766e Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 16:38:38 +0300 Subject: [PATCH 06/64] Fix errors from the migration --- narwhals/_pandas_like/expr.py | 63 ----------------------- narwhals/_pandas_like/namespace.py | 80 ++++++++++++++++++++++++++++++ narwhals/expression.py | 15 +++--- tests/test_where.py | 47 ++++++++++++++++++ 4 files changed, 134 insertions(+), 71 deletions(-) create mode 100644 tests/test_where.py diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 50e67a70b..aaaa550a1 100644 
--- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Iterable from typing import Literal from narwhals._expression_parsing import reuse_series_implementation @@ -337,14 +336,6 @@ def dt(self) -> PandasLikeExprDateTimeNamespace: def cat(self) -> PandasLikeExprCatNamespace: return PandasLikeExprCatNamespace(self) - def when(self, *predicates: PandasLikeExpr | Iterable[PandasExpr], **conditions: Any) -> PandasWhen: - # TODO: Support conditions - from narwhals._pandas_like.namespace import PandasLikeNamespace - - plx = PandasLikeNamespace(self._implementation) - condition = plx.all_horizontal(*predicates) - return PandasWhen(self, condition) - class PandasLikeExprCatNamespace: def __init__(self, expr: PandasLikeExpr) -> None: @@ -474,57 +465,3 @@ def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 return reuse_series_namespace_implementation( self._expr, "dt", "to_string", format ) - -class PandasWhen: - def __init__(self, condition: PandasLikeExpr) -> None: - self._condition = condition - - def then(self, value: Any) -> PandasThen: - return PandasThen(self, value=value, implementation=self._condition._implementation) - -class PandasThen(PandasLikeExpr): - def __init__(self, when: PandasWhen, *, value: Any, implementation: Implementation, backend_version: tuple[int, ...]) -> None: - self._when = when - self._then_value = value - self._implementation = implementation - self.backend_version = backend_version - - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - from narwhals._pandas_like.namespace import PandasLikeNamespace - - plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self.backend_version) - - condition = self._when._condition._call(df)[0] - - value_series = plx._create_series_from_scalar(self._then_value, condition) - none_series = plx._create_series_from_scalar(None, 
condition) - return [ - value_series.zip_with(condition, none_series) - ] - - self._call = func - self._depth = 0 - self._function_name = "whenthen" - self._root_names = None - self._output_names = None - - def otherwise(self, value: Any) -> PandasLikeExpr: - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - from narwhals._pandas_like.namespace import PandasLikeNamespace - plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self.backend_version) - condition = self._when._condition._call(df)[0] - value_series = plx._create_series_from_scalar(self._then_value, condition) - otherwise_series = plx._create_series_from_scalar(value, condition) - return [ - value_series.zip_with(condition, otherwise_series) - ] - - return PandasLikeExpr( - func, - depth=0, - function_name="whenthenotherwise", - root_names=None, - output_names=None, - implementation=self._implementation, - backend_version=self.backend_version, - ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index dbae3bbb7..b62327e86 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -86,6 +86,17 @@ def _create_series_from_scalar( backend_version=self._backend_version, ) + def _create_broadcast_series_from_scalar( + self, value: Any, series: PandasLikeSeries + ) -> PandasLikeSeries: + return PandasLikeSeries._from_iterable( + [value] * len(series._native_series), + name=series._native_series.name, + index=series._native_series.index, + implementation=self._implementation, + backend_version=self._backend_version, + ) + def _create_expr_from_series(self, series: PandasLikeSeries) -> PandasLikeExpr: return PandasLikeExpr( lambda _df: [series], @@ -246,3 +257,72 @@ def concat( backend_version=self._backend_version, ) raise NotImplementedError + + def when(self, *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], **conditions: Any) -> PandasWhen: # noqa: ARG002 + plx = 
self.__class__(self._implementation, self._backend_version) + condition = plx.all_horizontal(*predicates) + return PandasWhen(condition) + +class InnerPandasWhen: + def __init__(self, implementation: Implementation, backend_version: tuple[int, ...], condition: PandasLikeExpr, value: Any, otherise_value: Any = None) -> None: + self._implementation = implementation + self._backend_version = backend_version + self._condition = condition + self._value = value + self._otherwise_value = otherise_value + + def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + from narwhals._pandas_like.namespace import PandasLikeNamespace + + plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self._backend_version) + + condition = self._condition._call(df)[0] + + value_series = plx._create_broadcast_series_from_scalar(self._value, condition) + none_series = plx._create_broadcast_series_from_scalar(self._otherwise_value, condition) + return [ + value_series.zip_with(condition, none_series) + ] + +class PandasWhen: + def __init__(self, condition: PandasLikeExpr) -> None: + self._condition = condition + + def then(self, value: Any) -> PandasThen: + + return PandasThen( + InnerPandasWhen(self._condition._implementation, self._condition._backend_version, self._condition, value), + depth=0, + function_name="whenthen", + root_names=None, + output_names=None, + implementation=self._condition._implementation, + backend_version=self._condition._backend_version, + ) + +class PandasThen(PandasLikeExpr): + + def __init__( + self, + call: InnerPandasWhen, + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + implementation: Implementation, + backend_version: tuple[int, ...], + ) -> None: + self._implementation = implementation + self._backend_version = backend_version + + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = 
output_names + + def otherwise(self, value: Any) -> PandasLikeExpr: + self._call._otherwise_value = value + self._function_name = "whenotherwise" + return self diff --git a/narwhals/expression.py b/narwhals/expression.py index 904d164a9..4d1786f67 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -2968,7 +2968,7 @@ def to_string(self, format: str) -> Expr: # noqa: A002 of trailing zeros. Nonetheless, this is probably consistent enough for most applications. - If you have an application where this is not enough, please open an issue + If you have an application here this is not enough, please open an issue and let us know. Examples: @@ -3352,19 +3352,18 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: def __init__(self, condition: Expr) -> None: self._condition = condition + self._then_value = None + self._otehrwise_value = None def then(self, value: Any) -> Then: - return Then(self, value=value) + return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) class Then(Expr): - def __init__(self, when: When, *, value: Any) -> None: - self._when = when - self._then_value = value - - self._call = lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value) + def __init__(self, call) -> None: # noqa: ANN001 + self._call = call def otherwise(self, value: Any) -> Expr: - return Expr(lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value)) + return Expr(lambda plx: self._call(plx).otherwise(value)) def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 return When(reduce(lambda a, b: a & b, flatten([predicates]))) diff --git a/tests/test_where.py b/tests/test_where.py new file mode 100644 index 000000000..3661db0c5 --- /dev/null +++ b/tests/test_where.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import narwhals.stable.v1 as nw +from narwhals.expression 
import when +from tests.utils import compare_dicts + +data = { + "a": [1, 1, 2], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], +} + + +def test_when(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) + expected = { + "a": [1, 1, 2], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, 3, None], + } + compare_dicts(result, expected) + +def test_when_otherwise(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) + expected = { + "a": [1, 1, 2], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, 3, 6], + } + compare_dicts(result, expected) From a7f442ae497feef623f5e79409d754711a08b231 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 16:49:41 +0300 Subject: [PATCH 07/64] remove unnecessary changes --- narwhals/_pandas_like/expr.py | 5 ++--- narwhals/expr.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 31120a88f..d5ec89c32 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -470,8 +470,8 @@ def total_nanoseconds(self) -> PandasLikeExpr: def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 return reuse_series_namespace_implementation( - self._expr, "dt", "to_string", format - ) + self._expr, "dt", "to_string", format + ) class PandasLikeExprNameNamespace: def __init__(self: Self, expr: PandasLikeExpr) -> None: @@ -625,4 +625,3 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._expr._implementation, 
backend_version=self._expr._backend_version, ) ->>>>>>> main diff --git a/narwhals/expr.py b/narwhals/expr.py index a489fc0af..3dff0c676 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3016,7 +3016,7 @@ def to_string(self, format: str) -> Expr: # noqa: A002 of trailing zeros. Nonetheless, this is probably consistent enough for most applications. - If you have an application here this is not enough, please open an issue + If you have an application where this is not enough, please open an issue and let us know. Examples: From 7f23f051d9aecdb7dc6d3e84d9a85d36d8e86a5e Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 16:51:44 +0300 Subject: [PATCH 08/64] add back the change in version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 884d4a680..d00b10e89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "narwhals" -version = "1.1.1" +version = "1.1.3" authors = [ { name="Marco Gorelli", email="33491632+MarcoGorelli@users.noreply.github.com" }, ] From 7cc3aad3926d0e448770eeebdf9bf485b42f6a5c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 17:02:44 +0300 Subject: [PATCH 09/64] fix rename change --- tests/test_where.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_where.py b/tests/test_where.py index 3661db0c5..cc95cc347 100644 --- a/tests/test_where.py +++ b/tests/test_where.py @@ -5,7 +5,7 @@ import pytest import narwhals.stable.v1 as nw -from narwhals.expression import when +from narwhals.expr import when from tests.utils import compare_dicts data = { From ab85e406c69f7ece65da31190900f7fa969f36dd Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 17:03:52 +0300 Subject: [PATCH 10/64] rename test file --- tests/{test_where.py => test_when.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_where.py => 
test_when.py} (100%) diff --git a/tests/test_where.py b/tests/test_when.py similarity index 100% rename from tests/test_where.py rename to tests/test_when.py From 4a8ac56f9cb340cc8c34fec5db6199f3cb527f2e Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 19:02:14 +0300 Subject: [PATCH 11/64] fix forgotten memeber change --- narwhals/_pandas_like/namespace.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index b62327e86..10a31b19e 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -258,17 +258,17 @@ def concat( ) raise NotImplementedError - def when(self, *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], **conditions: Any) -> PandasWhen: # noqa: ARG002 + def when(self, *predicates: IntoPandasLikeExpr, **conditions: Any) -> PandasWhen: # noqa: ARG002 plx = self.__class__(self._implementation, self._backend_version) condition = plx.all_horizontal(*predicates) - return PandasWhen(condition) + return PandasWhen(condition, self._implementation, self._backend_version) -class InnerPandasWhen: - def __init__(self, implementation: Implementation, backend_version: tuple[int, ...], condition: PandasLikeExpr, value: Any, otherise_value: Any = None) -> None: +class PandasWhen: + def __init__(self, condition: PandasLikeExpr, implementation: Implementation, backend_version: tuple[int, ...], then_value: Any = None, otherise_value: Any = None) -> None: self._implementation = implementation self._backend_version = backend_version self._condition = condition - self._value = value + self._then_value = then_value self._otherwise_value = otherise_value def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @@ -278,20 +278,18 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: condition = self._condition._call(df)[0] - value_series = 
plx._create_broadcast_series_from_scalar(self._value, condition) - none_series = plx._create_broadcast_series_from_scalar(self._otherwise_value, condition) + value_series = plx._create_broadcast_series_from_scalar(self._then_value, condition) + otherwise_series = plx._create_broadcast_series_from_scalar(self._otherwise_value, condition) return [ - value_series.zip_with(condition, none_series) + value_series.zip_with(condition, otherwise_series) ] -class PandasWhen: - def __init__(self, condition: PandasLikeExpr) -> None: - self._condition = condition - def then(self, value: Any) -> PandasThen: + self._then_value = value + return PandasThen( - InnerPandasWhen(self._condition._implementation, self._condition._backend_version, self._condition, value), + self, depth=0, function_name="whenthen", root_names=None, @@ -304,7 +302,7 @@ class PandasThen(PandasLikeExpr): def __init__( self, - call: InnerPandasWhen, + call: PandasWhen, *, depth: int, function_name: str, From 8283f2481294ee5216c9e67e52069b1ab797bd34 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 19:08:48 +0300 Subject: [PATCH 12/64] make api identical --- narwhals/_pandas_like/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 10a31b19e..06af399d3 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -258,7 +258,7 @@ def concat( ) raise NotImplementedError - def when(self, *predicates: IntoPandasLikeExpr, **conditions: Any) -> PandasWhen: # noqa: ARG002 + def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002 plx = self.__class__(self._implementation, self._backend_version) condition = plx.all_horizontal(*predicates) return PandasWhen(condition, self._implementation, self._backend_version) From f1c667eb57f5b7ea791a811dd74c24903f12f78d Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 
19:12:09 +0300 Subject: [PATCH 13/64] remove unnecessary diff --- .gitignore | 2 +- narwhals/_pandas_like/expr.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e3bba127f..3911158a8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ todo.md site/ .coverage.* .nox -docs/api-completeness.md +docs/api-completeness.md \ No newline at end of file diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index d5ec89c32..f846da610 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -473,6 +473,7 @@ def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 self._expr, "dt", "to_string", format ) + class PandasLikeExprNameNamespace: def __init__(self: Self, expr: PandasLikeExpr) -> None: self._expr = expr From 74937ea9f37d263c3efac218ebd0afbeddc4d9b5 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 12:15:36 +0300 Subject: [PATCH 14/64] add when documentation --- narwhals/expr.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/narwhals/expr.py b/narwhals/expr.py index 3dff0c676..5aba8cc61 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3622,6 +3622,51 @@ def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 + """ + Start a `when-then-otherwise` expression. + Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. + Chained when-then operations should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to `True` will be picked. + If none of the conditions are `True`, an optional `.otherwise()` can be appended at the end. 
If not appended, and none of the conditions are `True`, `None` will be returned. + + Parameters: + predicates + Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name. + constraints + Apply conditions as `col_name = value` keyword arguments that are treated as equality matches, such as `x = 123`. As with the predicates parameter, multiple conditions are implicitly combined using `&`. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]}) + >>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]}) + + We define a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df_any): + ... from narwhals.expr import when + ... return df_any.with_columns(when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")) + + We can then pass either pandas or polars to `func`: + + >>> func(df_pd) + a b a_when + 0 1 5 5 + 1 2 10 5 + 2 3 15 6 + >>> func(df_pl) + shape: (3, 3) + ┌─────┬─────┬────────┐ + │ a ┆ b ┆ a_when │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i32 │ + ╞═════╪═════╪════════╡ + │ 1 ┆ 5 ┆ 5 │ + │ 2 ┆ 10 ┆ 5 │ + │ 3 ┆ 15 ┆ 6 │ + └─────┴─────┴────────┘ + """ return When(reduce(lambda a, b: a & b, flatten([predicates]))) From 5b030d660ecbbc38e2e2fcb3b5bafd1a8fb8c9a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 09:22:15 +0000 Subject: [PATCH 15/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/namespace.py | 31 ++++++++++++++++++++---------- narwhals/expr.py | 12 +++++++++--- tests/test_when.py | 1 + 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 06af399d3..51c82648f 
100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -258,13 +258,21 @@ def concat( ) raise NotImplementedError - def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002 + def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002 plx = self.__class__(self._implementation, self._backend_version) condition = plx.all_horizontal(*predicates) return PandasWhen(condition, self._implementation, self._backend_version) + class PandasWhen: - def __init__(self, condition: PandasLikeExpr, implementation: Implementation, backend_version: tuple[int, ...], then_value: Any = None, otherise_value: Any = None) -> None: + def __init__( + self, + condition: PandasLikeExpr, + implementation: Implementation, + backend_version: tuple[int, ...], + then_value: Any = None, + otherise_value: Any = None, + ) -> None: self._implementation = implementation self._backend_version = backend_version self._condition = condition @@ -274,18 +282,21 @@ def __init__(self, condition: PandasLikeExpr, implementation: Implementation, ba def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._pandas_like.namespace import PandasLikeNamespace - plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self._backend_version) + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self._backend_version + ) condition = self._condition._call(df)[0] - value_series = plx._create_broadcast_series_from_scalar(self._then_value, condition) - otherwise_series = plx._create_broadcast_series_from_scalar(self._otherwise_value, condition) - return [ - value_series.zip_with(condition, otherwise_series) - ] + value_series = plx._create_broadcast_series_from_scalar( + self._then_value, condition + ) + otherwise_series = plx._create_broadcast_series_from_scalar( + self._otherwise_value, condition + ) + return 
[value_series.zip_with(condition, otherwise_series)] def then(self, value: Any) -> PandasThen: - self._then_value = value return PandasThen( @@ -298,8 +309,8 @@ def then(self, value: Any) -> PandasThen: backend_version=self._condition._backend_version, ) -class PandasThen(PandasLikeExpr): +class PandasThen(PandasLikeExpr): def __init__( self, call: PandasWhen, diff --git a/narwhals/expr.py b/narwhals/expr.py index 5aba8cc61..6eb0dd135 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3605,6 +3605,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) ) + class When: def __init__(self, condition: Expr) -> None: self._condition = condition @@ -3614,14 +3615,16 @@ def __init__(self, condition: Expr) -> None: def then(self, value: Any) -> Then: return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) + class Then(Expr): - def __init__(self, call) -> None: # noqa: ANN001 + def __init__(self, call) -> None: # noqa: ANN001 self._call = call def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 + +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 """ Start a `when-then-otherwise` expression. Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. @@ -3646,7 +3649,10 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When >>> @nw.narwhalify ... def func(df_any): ... from narwhals.expr import when - ... return df_any.with_columns(when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")) + ... + ... return df_any.with_columns( + ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... 
) We can then pass either pandas or polars to `func`: diff --git a/tests/test_when.py b/tests/test_when.py index cc95cc347..90df13180 100644 --- a/tests/test_when.py +++ b/tests/test_when.py @@ -31,6 +31,7 @@ def test_when(request: Any, constructor: Any) -> None: } compare_dicts(result, expected) + def test_when_otherwise(request: Any, constructor: Any) -> None: if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) From 7390e1aeacd5034ecf41cbd9214b199003604a2c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 13:05:29 +0300 Subject: [PATCH 16/64] address mypy issues --- narwhals/_pandas_like/namespace.py | 5 ++++- narwhals/expr.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 51c82648f..cd1deee18 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -332,6 +332,9 @@ def __init__( self._output_names = output_names def otherwise(self, value: Any) -> PandasLikeExpr: - self._call._otherwise_value = value + # type ignore because we are setting the `_call` attribute to a + # callable object of type `PandasWhen`, base class has the attribute as + # only a `Callable` + self._call._otherwise_value = value # type: ignore self._function_name = "whenotherwise" return self diff --git a/narwhals/expr.py b/narwhals/expr.py index 6eb0dd135..2ef1e5308 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3617,7 +3617,7 @@ def then(self, value: Any) -> Then: class Then(Expr): - def __init__(self, call) -> None: # noqa: ANN001 + def __init__(self, call: Callable[[Any], Any]) -> None: self._call = call def otherwise(self, value: Any) -> Expr: From 279e3ad5796ebf50e20967f57114eb82e7df9929 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 10:08:20 +0000 Subject: [PATCH 17/64] [pre-commit.ci] auto fixes from pre-commit.com 
hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index cd1deee18..9411ba500 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -335,6 +335,6 @@ def otherwise(self, value: Any) -> PandasLikeExpr: # type ignore because we are setting the `_call` attribute to a # callable object of type `PandasWhen`, base class has the attribute as # only a `Callable` - self._call._otherwise_value = value # type: ignore + self._call._otherwise_value = value # type: ignore self._function_name = "whenotherwise" return self From 63048ee3752b7c13a7773e6fdbda21bdea478249 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 13:11:57 +0300 Subject: [PATCH 18/64] address ruff type-ignore blanket issue --- narwhals/_pandas_like/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 9411ba500..e05df30e8 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -335,6 +335,6 @@ def otherwise(self, value: Any) -> PandasLikeExpr: # type ignore because we are setting the `_call` attribute to a # callable object of type `PandasWhen`, base class has the attribute as # only a `Callable` - self._call._otherwise_value = value # type: ignore + self._call._otherwise_value = value # type: ignore[attr-defined] self._function_name = "whenotherwise" return self From e96af891e41c163c581db3cce0c1039448b953e1 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 15:00:42 +0300 Subject: [PATCH 19/64] support `Iterable[Expr]` in the pandas api --- narwhals/_pandas_like/namespace.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py 
index e05df30e8..840414525 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -258,9 +258,13 @@ def concat( ) raise NotImplementedError - def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002 + def when( + self, + *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], + **constraints: Any, # noqa: ARG002 + ) -> PandasWhen: plx = self.__class__(self._implementation, self._backend_version) - condition = plx.all_horizontal(*predicates) + condition = plx.all_horizontal(*flatten(predicates)) return PandasWhen(condition, self._implementation, self._backend_version) From d4f0e9cc78df6a3cf577ce7409ebe4cad83e3848 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 15:24:34 +0300 Subject: [PATCH 20/64] move when test file to a better location --- tests/{ => expr_and_series}/test_when.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => expr_and_series}/test_when.py (100%) diff --git a/tests/test_when.py b/tests/expr_and_series/test_when.py similarity index 100% rename from tests/test_when.py rename to tests/expr_and_series/test_when.py From 99d9899e42c63d671e26f6a86ddffa2364477657 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 15:30:06 +0300 Subject: [PATCH 21/64] make when test filename similar to other tests --- tests/expr_and_series/{test_when.py => when_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/expr_and_series/{test_when.py => when_test.py} (100%) diff --git a/tests/expr_and_series/test_when.py b/tests/expr_and_series/when_test.py similarity index 100% rename from tests/expr_and_series/test_when.py rename to tests/expr_and_series/when_test.py From 71e542d4f611d4e2ffeb8a031671b02bc300b887 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 2024 12:38:19 +0300 Subject: [PATCH 22/64] add simple when --- narwhals/_pandas_like/expr.py | 10 +- narwhals/expr.py | 29 ++ 
narwhals/expressions/whenthen.py | 0 tests/test_common.py | 718 +++++++++++++++++++++++++++++++ 4 files changed, 756 insertions(+), 1 deletion(-) create mode 100644 narwhals/expressions/whenthen.py create mode 100644 tests/test_common.py diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index f846da610..40dc15232 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -4,11 +4,11 @@ from typing import Any from typing import Callable from typing import Literal +from typing import Iterable from narwhals._expression_parsing import reuse_series_implementation from narwhals._expression_parsing import reuse_series_namespace_implementation from narwhals._pandas_like.series import PandasLikeSeries - if TYPE_CHECKING: from typing_extensions import Self @@ -343,6 +343,14 @@ def cat(self: Self) -> PandasLikeExprCatNamespace: def name(self: Self) -> PandasLikeExprNameNamespace: return PandasLikeExprNameNamespace(self) + def when(self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any) -> PandasWhen: + # TODO: Support conditions + from narwhals._pandas_like.namespace import PandasNamespace + + plx = PandasNamespace(self._implementation) + condition = plx.all_horizontal(*predicates) + return PandasWhen(self, condition) + class PandasLikeExprCatNamespace: def __init__(self, expr: PandasLikeExpr) -> None: diff --git a/narwhals/expr.py b/narwhals/expr.py index 2ef1e5308..c14256936 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -12,6 +12,9 @@ from narwhals.dtypes import translate_dtype from narwhals.utils import flatten + +from functools import reduce + if TYPE_CHECKING: from typing_extensions import Self @@ -3605,6 +3608,32 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) ) +class When: + def __init__(self, condition: Expr) -> None: + self._condition = condition + + def then(self, value: Any) -> Then: + return Then(self, value=value) + +class Then(Expr): + def __init__(self, when: 
When, *, value: Any) -> None: + self._when = when + self._then_value = value + + def func(plx): + return plx.when(self._when._condition._call(plx)).then(self._then_value) + + self._call = func + + def otherwise(self, value: Any) -> Expr: + def func(plx): + return plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value) + + return Expr(func) + +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: + return When(reduce(lambda a, b: a & b, flatten([predicates]))) + class When: def __init__(self, condition: Expr) -> None: diff --git a/narwhals/expressions/whenthen.py b/narwhals/expressions/whenthen.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 000000000..4162d7b24 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,718 @@ +from __future__ import annotations + +import os +import warnings +from typing import Any +from typing import Literal + +import numpy as np +import pandas as pd +import polars as pl +import pytest +from pandas.testing import assert_series_equal as pd_assert_series_equal +from polars.testing import assert_series_equal as pl_assert_series_equal + +import narwhals as nw +from narwhals.utils import parse_version +from tests.utils import compare_dicts + +df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) +if parse_version(pd.__version__) >= parse_version("1.5.0"): + df_pandas_pyarrow = pd.DataFrame( + {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + ).astype( + { + "a": "Int64[pyarrow]", + "b": "Int64[pyarrow]", + "z": "Float64[pyarrow]", + } + ) + df_pandas_nullable = pd.DataFrame( + {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + ).astype( + { + "a": "Int64", + "b": "Int64", + "z": "Float64", + } + ) +else: # pragma: no cover + df_pandas_pyarrow = df_pandas + df_pandas_nullable = df_pandas +df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) +df_lazy = 
pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) +df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) +df_lazy_na = pl.LazyFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) +df_right_pandas = pd.DataFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) +df_right_lazy = pl.LazyFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) + +if os.environ.get("CI", None): + try: + import modin.pandas as mpd + except ImportError: # pragma: no cover + df_mpd = df_pandas.copy() + else: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + df_mpd = mpd.DataFrame( + pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) + ) +else: # pragma: no cover + df_mpd = df_pandas.copy() + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_polars, df_lazy, df_pandas_nullable, df_pandas_pyarrow], +) +def test_sort(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.sort("a", "b") + result_native = nw.to_native(result) + expected = { + "a": [1, 2, 3], + "b": [4, 6, 4], + "z": [7.0, 9.0, 8.0], + } + compare_dicts(result_native, expected) + result = df.sort("a", "b", descending=[True, False]) + result_native = nw.to_native(result) + expected = { + "a": [3, 2, 1], + "b": [4, 6, 4], + "z": [8.0, 9.0, 7.0], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], +) +def test_filter(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.filter(nw.col("a") > 1) + result_native = nw.to_native(result) + expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_polars], +) +def test_filter_series(df_raw: Any) -> None: + df = nw.DataFrame(df_raw).with_columns(mask=nw.col("a") > 1) + result = df.filter(df["mask"]).drop("mask") + result_native = nw.to_native(result) + expected = {"a": [3, 2], "b": 
[4, 6], "z": [8.0, 9.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], +) +def test_add(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.with_columns( + c=nw.col("a") + nw.col("b"), + d=nw.col("a") - nw.col("a").mean(), + e=nw.col("a") - nw.col("a").std(), + ) + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "c": [5, 7, 8], + "d": [-1.0, 1.0, 0.0], + "e": [0.0, 2.0, 1.0], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], +) +def test_std(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.select( + nw.col("a").std().alias("a_ddof_default"), + nw.col("a").std(ddof=1).alias("a_ddof_1"), + nw.col("a").std(ddof=0).alias("a_ddof_0"), + nw.col("b").std(ddof=2).alias("b_ddof_2"), + nw.col("z").std(ddof=0).alias("z_ddof_0"), + ) + result_native = nw.to_native(result) + expected = { + "a_ddof_default": [1.0], + "a_ddof_1": [1.0], + "a_ddof_0": [0.816497], + "b_ddof_2": [1.632993], + "z_ddof_0": [0.816497], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], +) +def test_double(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.with_columns(nw.all() * 2) + result_native = nw.to_native(result) + expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} + compare_dicts(result_native, expected) + result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) + result_native = nw.to_native(result) + expected = {"o": [1, 3, 2], "a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], +) +def test_select(df_raw: Any) -> 
None: + df = nw.LazyFrame(df_raw) + result = df.select("a") + result_native = nw.to_native(result) + expected = {"a": [1, 3, 2]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy, df_pandas_nullable]) +def test_sumh(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.with_columns(horizonal_sum=nw.sum_horizontal(nw.col("a"), nw.col("b"))) + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "horizonal_sum": [5, 7, 8], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_sumh_literal(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.with_columns(horizonal_sum=nw.sum_horizontal("a", nw.col("b"))) + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "horizonal_sum": [5, 7, 8], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_sum_all(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.select(nw.all().sum()) + result_native = nw.to_native(result) + expected = {"a": [6], "b": [14], "z": [24.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_double_selected(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.select(nw.col("a", "b") * 2) + result_native = nw.to_native(result) + expected = {"a": [2, 6, 4], "b": [8, 8, 12]} + compare_dicts(result_native, expected) + result = df.select("z", nw.col("a", "b") * 2) + result_native = nw.to_native(result) + expected = {"z": [7, 8, 9], "a": [2, 6, 4], "b": [8, 8, 12]} + compare_dicts(result_native, expected) + result = df.select("a").select(nw.col("a") + nw.all()) + result_native = 
nw.to_native(result) + expected = {"a": [2, 6, 4]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_rename(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.rename({"a": "x", "b": "y"}) + result_native = nw.to_native(result) + expected = {"x": [1, 3, 2], "y": [4, 4, 6], "z": [7.0, 8, 9]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_join(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + df_right = df + result = df.join(df_right, left_on=["a", "b"], right_on=["a", "b"], how="inner") + result_native = nw.to_native(result) + expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "z_right": [7.0, 8, 9]} + compare_dicts(result_native, expected) + + with pytest.raises(NotImplementedError): + result = df.join(df_right, left_on="a", right_on="a", how="left") # type: ignore[arg-type] + + result = df.collect().join(df_right.collect(), left_on="a", right_on="a", how="inner") # type: ignore[assignment] + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "b_right": [4, 4, 6], + "z": [7.0, 8, 9], + "z_right": [7.0, 8, 9], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_schema(df_raw: Any) -> None: + result = nw.LazyFrame(df_raw).schema + expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} + assert result == expected + result = nw.LazyFrame(df_raw).collect().schema + expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} + assert result == expected + result = nw.LazyFrame(df_raw).columns # type: ignore[assignment] + expected = ["a", "b", "z"] # type: ignore[assignment] + assert result == expected + result = nw.LazyFrame(df_raw).collect().columns # type: ignore[assignment] + 
expected = ["a", "b", "z"] # type: ignore[assignment] + assert result == expected + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_columns(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = df.columns + expected = ["a", "b", "z"] + assert result == expected + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +def test_lazy_instantiation(df_raw: Any) -> None: + result = nw.LazyFrame(df_raw) + result_native = nw.to_native(result) + expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_lazy]) +def test_lazy_instantiation_error(df_raw: Any) -> None: + with pytest.raises( + TypeError, match="Can't instantiate DataFrame from Polars LazyFrame." + ): + _ = nw.DataFrame(df_raw).shape + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) +def test_eager_instantiation(df_raw: Any) -> None: + result = nw.DataFrame(df_raw) + result_native = nw.to_native(result) + expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + compare_dicts(result_native, expected) + + +def test_accepted_dataframes() -> None: + array = np.array([[0, 4.0], [2, 5]]) + with pytest.raises( + TypeError, + match="Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: ", + ): + nw.DataFrame(array) + with pytest.raises( + TypeError, + match="Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: ", + ): + nw.LazyFrame(array) + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) +@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") +def test_convert_pandas(df_raw: Any) -> None: + result = nw.from_native(df_raw).to_pandas() # type: ignore[union-attr] + expected = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) + pd.testing.assert_frame_equal(result, expected) + + 
+@pytest.mark.parametrize( + "df_raw", [df_polars, df_pandas, df_mpd, df_pandas_nullable, df_pandas_pyarrow] +) +@pytest.mark.filterwarnings( + r"ignore:np\.find_common_type is deprecated\.:DeprecationWarning" +) +def test_convert_numpy(df_raw: Any) -> None: + result = nw.DataFrame(df_raw).to_numpy() + expected = np.array([[1, 3, 2], [4, 4, 6], [7.0, 8, 9]]).T + np.testing.assert_array_equal(result, expected) + assert result.dtype == "float64" + result = nw.DataFrame(df_raw).__array__() + np.testing.assert_array_equal(result, expected) + assert result.dtype == "float64" + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) +def test_shape(df_raw: Any) -> None: + result = nw.DataFrame(df_raw).shape + expected = (3, 3) + assert result == expected + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +def test_expr_binary(df_raw: Any) -> None: + result = nw.LazyFrame(df_raw).with_columns( + a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")), + b=nw.col("z") / (2 - nw.col("b")), + c=nw.col("a") + nw.col("b") / 2, + d=nw.col("a") - nw.col("b"), + e=((nw.col("a") > nw.col("b")) & (nw.col("a") >= nw.col("z"))).cast(nw.Int64), + f=( + (nw.col("a") < nw.col("b")) + | (nw.col("a") <= nw.col("z")) + | (nw.col("a") == 1) + ).cast(nw.Int64), + g=nw.col("a") != 1, + h=(False & (nw.col("a") != 1)), + i=(False | (nw.col("a") != 1)), + j=2 ** nw.col("a"), + k=2 // nw.col("a"), + l=nw.col("a") // 2, + m=nw.col("a") ** 2, + n=nw.col("a") % 2, + o=2 % nw.col("a"), + ) + result_native = nw.to_native(result) + expected = { + "a": [4, 3.333333, 3.5], + "b": [-3.5, -4.0, -2.25], + "z": [7.0, 8.0, 9.0], + "c": [3, 5, 5], + "d": [-3, -1, -4], + "e": [0, 0, 0], + "f": [1, 1, 1], + "g": [False, True, True], + "h": [False, False, False], + "i": [False, True, True], + "j": [2, 8, 4], + "k": [2, 0, 1], + "l": [0, 1, 1], + "m": [1, 9, 4], + "n": [1, 1, 0], + "o": [0, 2, 0], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", 
[df_polars, df_pandas, df_lazy]) +def test_expr_unary(df_raw: Any) -> None: + result = ( + nw.from_native(df_raw) + .with_columns( + a_mean=nw.col("a").mean(), + a_sum=nw.col("a").sum(), + b_nunique=nw.col("b").n_unique(), + z_min=nw.col("z").min(), + z_max=nw.col("z").max(), + ) + .select(nw.col("a_mean", "a_sum", "b_nunique", "z_min", "z_max").unique()) + ) + result_native = nw.to_native(result) + expected = {"a_mean": [2], "a_sum": [6], "b_nunique": [2], "z_min": [7], "z_max": [9]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +def test_expr_transform(df_raw: Any) -> None: + result = nw.LazyFrame(df_raw).with_columns( + a=nw.col("a").is_between(-1, 1), b=nw.col("b").is_in([4, 5]) + ) + result_native = nw.to_native(result) + expected = {"a": [True, False, False], "b": [True, True, False], "z": [7, 8, 9]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_lazy]) +def test_expr_min_max(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result_min = nw.to_native(df.select(nw.min("a", "b", "z"))) + result_max = nw.to_native(df.select(nw.max("a", "b", "z"))) + expected_min = {"a": [1], "b": [4], "z": [7]} + expected_max = {"a": [3], "b": [6], "z": [9]} + compare_dicts(result_min, expected_min) + compare_dicts(result_max, expected_max) + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +def test_expr_sample(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result_shape = nw.to_native(df.select(nw.col("a").sample(n=2)).collect()).shape + expected = (2, 1) + assert result_shape == expected + result_shape = nw.to_native(df.collect()["a"].sample(n=2)).shape + expected = (2,) # type: ignore[assignment] + assert result_shape == expected + + +@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) +def test_expr_na(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result_nna = nw.to_native( + 
df.filter((~nw.col("a").is_null()) & (~df.collect()["z"].is_null())) + ) + expected = {"a": [2], "b": [6], "z": [9]} + compare_dicts(result_nna, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_head(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = nw.to_native(df.head(2)) + expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} + compare_dicts(result, expected) + result = nw.to_native(df.collect().head(2)) + expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} + compare_dicts(result, expected) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_unique(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = nw.to_native(df.unique("b").sort("b")) + expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]} + compare_dicts(result, expected) + result = nw.to_native(df.collect().unique("b").sort("b")) + expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]} + compare_dicts(result, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) +def test_drop_nulls(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = nw.to_native(df.select(nw.col("a").drop_nulls())) + expected = {"a": [3, 2]} + compare_dicts(result, expected) + result = nw.to_native(df.select(df.collect()["a"].drop_nulls())) + expected = {"a": [3, 2]} + compare_dicts(result, expected) + + +@pytest.mark.parametrize( + ("df_raw", "df_raw_right"), [(df_pandas, df_right_pandas), (df_lazy, df_right_lazy)] +) +def test_concat_horizontal(df_raw: Any, df_raw_right: Any) -> None: + df_left = nw.LazyFrame(df_raw) + df_right = nw.LazyFrame(df_raw_right) + result = nw.concat([df_left, df_right], how="horizontal") + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8, 9], + "c": [6, 12, -1], + "d": [0, -4, 2], + } + compare_dicts(result_native, expected) + + with pytest.raises(ValueError, 
match="No items"): + nw.concat([]) + + +@pytest.mark.parametrize( + ("df_raw", "df_raw_right"), [(df_pandas, df_right_pandas), (df_lazy, df_right_lazy)] +) +def test_concat_vertical(df_raw: Any, df_raw_right: Any) -> None: + df_left = nw.LazyFrame(df_raw).collect().rename({"a": "c", "b": "d"}).lazy().drop("z") + df_right = nw.LazyFrame(df_raw_right) + result = nw.concat([df_left, df_right], how="vertical") + result_native = nw.to_native(result) + expected = {"c": [1, 3, 2, 6, 12, -1], "d": [4, 4, 6, 0, -4, 2]} + compare_dicts(result_native, expected) + with pytest.raises(ValueError, match="No items"): + nw.concat([], how="vertical") + with pytest.raises(Exception, match="unable to vstack"): + nw.concat([df_left, df_right.rename({"d": "i"})], how="vertical").collect() # type: ignore[union-attr] + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_lazy(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.lazy() + assert isinstance(result, nw.LazyFrame) + + +def test_to_dict() -> None: + df = nw.DataFrame(df_pandas) + result = df.to_dict(as_series=True) + expected = { + "a": pd.Series([1, 3, 2], name="a"), + "b": pd.Series([4, 4, 6], name="b"), + "z": pd.Series([7.0, 8, 9], name="z"), + } + for key in expected: + pd_assert_series_equal(nw.to_native(result[key]), expected[key]) + + df = nw.DataFrame(df_polars) + result = df.to_dict(as_series=True) + expected = { + "a": pl.Series("a", [1, 3, 2]), + "b": pl.Series("b", [4, 4, 6]), + "z": pl.Series("z", [7.0, 8, 9]), + } + for key in expected: + pl_assert_series_equal(nw.to_native(result[key]), expected[key]) + + +@pytest.mark.parametrize( + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] +) +def test_any_all(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = nw.to_native(df.select((nw.all() > 1).all())) + expected = {"a": [False], "b": [True], "z": [True]} + compare_dicts(result, expected) + result = nw.to_native(df.select((nw.all() > 1).any())) + expected = 
{"a": [True], "b": [True], "z": [True]} + compare_dicts(result, expected) + + +def test_invalid() -> None: + df = nw.LazyFrame(df_pandas) + with pytest.raises(ValueError, match="Multi-output"): + df.select(nw.all() + nw.all()) + with pytest.raises(TypeError, match="Perhaps you:"): + df.select([pl.col("a")]) # type: ignore[list-item] + with pytest.raises(TypeError, match="Perhaps you:"): + df.select([nw.col("a").cast(pl.Int64)]) + + +@pytest.mark.parametrize("df_raw", [df_pandas]) +def test_reindex(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.select("b", df["a"].sort(descending=True)) + expected = {"b": [4, 4, 6], "a": [3, 2, 1]} + compare_dicts(result, expected) + result = df.select("b", nw.col("a").sort(descending=True)) + compare_dicts(result, expected) + + s = df["a"] + result_s = s > s.sort() + assert not result_s[0] + assert result_s[1] + assert not result_s[2] + result = df.with_columns(s.sort()) + expected = {"a": [1, 2, 3], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} # type: ignore[list-item] + compare_dicts(result, expected) + with pytest.raises(ValueError, match="Multi-output expressions are not supported"): + nw.to_native(df.with_columns(nw.all() + nw.all())) + + +@pytest.mark.parametrize( + ("df_raw", "df_raw_right"), + [(df_pandas, df_polars), (df_polars, df_pandas)], +) +def test_library(df_raw: Any, df_raw_right: Any) -> None: + df_left = nw.LazyFrame(df_raw) + df_right = nw.LazyFrame(df_raw_right) + with pytest.raises( + NotImplementedError, match="Cross-library comparisons aren't supported" + ): + nw.concat([df_left, df_right], how="horizontal") + with pytest.raises( + NotImplementedError, match="Cross-library comparisons aren't supported" + ): + nw.concat([df_left, df_right], how="vertical") + with pytest.raises( + NotImplementedError, match="Cross-library comparisons aren't supported" + ): + df_left.join(df_right, left_on=["a"], right_on=["a"], how="inner") + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def 
test_is_duplicated(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = nw.concat([df, df.head(1)]).is_duplicated() # type: ignore [union-attr] + expected = np.array([True, False, False, True]) + assert (result.to_numpy() == expected).all() + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize(("threshold", "expected"), [(0, False), (10, True)]) +def test_is_empty(df_raw: Any, threshold: Any, expected: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.filter(nw.col("a") > threshold).is_empty() + assert result == expected + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_is_unique(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = nw.concat([df, df.head(1)]).is_unique() # type: ignore [union-attr] + expected = np.array([False, True, True, False]) + assert (result.to_numpy() == expected).all() + + +@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na.collect()]) +def test_null_count(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = nw.to_native(df.null_count()) + expected = {"a": [1], "b": [0], "z": [1]} + compare_dicts(result, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize( + ("interpolation", "expected"), + [ + ("lower", {"a": [1.0], "b": [4.0], "z": [7.0]}), + ("higher", {"a": [2.0], "b": [4.0], "z": [8.0]}), + ("midpoint", {"a": [1.5], "b": [4.0], "z": [7.5]}), + ("linear", {"a": [1.6], "b": [4.0], "z": [7.6]}), + ("nearest", {"a": [2.0], "b": [4.0], "z": [8.0]}), + ], +) +def test_quantile( + df_raw: Any, + interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + expected: dict[str, list[float]], +) -> None: + q = 0.3 + + df = nw.from_native(df_raw) + result = nw.to_native( + df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) + ) + compare_dicts(result, expected) + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_when(df_raw: Any) -> None: + df = 
nw.DataFrame(df_raw) + result = df.with_columns( + a=nw.when(nw.col("a") > 2, 1).otherwise(0), + b=nw.when(nw.col("a") > 2, 1).when(nw.col("a") < 1, -1).otherwise(0), + ) + expected = {"a": [0, 1, 0], "b": [0, 1, 0], "z": [7.0, 8.0, 9.0]} + compare_dicts(result, expected) From 8b1355a24e6ffbe3d5682f1b4143f3fc16bf08c9 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 2024 12:44:10 +0300 Subject: [PATCH 23/64] lint with ruff --- narwhals/_pandas_like/expr.py | 9 +++++---- narwhals/expr.py | 3 --- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 40dc15232..58bb46ed7 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -3,12 +3,13 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Literal from typing import Iterable +from typing import Literal + +from narwhals._pandas_like.series import PandasSeries +from narwhals._pandas_like.utils import reuse_series_implementation +from narwhals._pandas_like.utils import reuse_series_namespace_implementation -from narwhals._expression_parsing import reuse_series_implementation -from narwhals._expression_parsing import reuse_series_namespace_implementation -from narwhals._pandas_like.series import PandasLikeSeries if TYPE_CHECKING: from typing_extensions import Self diff --git a/narwhals/expr.py b/narwhals/expr.py index c14256936..3a876cc27 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -12,9 +12,6 @@ from narwhals.dtypes import translate_dtype from narwhals.utils import flatten - -from functools import reduce - if TYPE_CHECKING: from typing_extensions import Self From eb361649024c0698cab6d52c3a9f6a6e9382c2ed Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 17 Jul 2024 17:06:18 +0300 Subject: [PATCH 24/64] use lambda expression --- narwhals/expr.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git 
a/narwhals/expr.py b/narwhals/expr.py index 3a876cc27..6af61ed96 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3617,18 +3617,12 @@ def __init__(self, when: When, *, value: Any) -> None: self._when = when self._then_value = value - def func(plx): - return plx.when(self._when._condition._call(plx)).then(self._then_value) - - self._call = func + self._call = lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value) def otherwise(self, value: Any) -> Expr: - def func(plx): - return plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value) - - return Expr(func) + return Expr(lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 return When(reduce(lambda a, b: a & b, flatten([predicates]))) From c9b09bfb370ee0e6f5d95f0021ace8bd45fad01d Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 16:38:38 +0300 Subject: [PATCH 25/64] Fix errors from the migration --- narwhals/_pandas_like/expr.py | 64 +++++++++++++++++++++++++++++++++-- narwhals/expr.py | 15 ++++---- tests/test_where.py | 47 +++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 10 deletions(-) create mode 100644 tests/test_where.py diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 58bb46ed7..34ed69fdb 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Iterable from typing import Literal from narwhals._pandas_like.series import PandasSeries @@ -352,7 +351,6 @@ def when(self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any condition = plx.all_horizontal(*predicates) return PandasWhen(self, condition) - class 
PandasLikeExprCatNamespace: def __init__(self, expr: PandasLikeExpr) -> None: self._expr = expr @@ -479,6 +477,7 @@ def total_nanoseconds(self) -> PandasLikeExpr: def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 return reuse_series_namespace_implementation( +<<<<<<< HEAD self._expr, "dt", "to_string", format ) @@ -635,3 +634,64 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._expr._implementation, backend_version=self._expr._backend_version, ) +||||||| parent of f3770b7 (Fix errors from the migration) + self._expr, "dt", "to_string", format + ) + +class PandasWhen: + def __init__(self, condition: PandasLikeExpr) -> None: + self._condition = condition + + def then(self, value: Any) -> PandasThen: + return PandasThen(self, value=value, implementation=self._condition._implementation) + +class PandasThen(PandasLikeExpr): + def __init__(self, when: PandasWhen, *, value: Any, implementation: Implementation, backend_version: tuple[int, ...]) -> None: + self._when = when + self._then_value = value + self._implementation = implementation + self.backend_version = backend_version + + def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + from narwhals._pandas_like.namespace import PandasLikeNamespace + + plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self.backend_version) + + condition = self._when._condition._call(df)[0] + + value_series = plx._create_series_from_scalar(self._then_value, condition) + none_series = plx._create_series_from_scalar(None, condition) + return [ + value_series.zip_with(condition, none_series) + ] + + self._call = func + self._depth = 0 + self._function_name = "whenthen" + self._root_names = None + self._output_names = None + + def otherwise(self, value: Any) -> PandasLikeExpr: + def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + from narwhals._pandas_like.namespace import PandasLikeNamespace + plx = PandasLikeNamespace(implementation=self._implementation, 
backend_version=self.backend_version) + condition = self._when._condition._call(df)[0] + value_series = plx._create_series_from_scalar(self._then_value, condition) + otherwise_series = plx._create_series_from_scalar(value, condition) + return [ + value_series.zip_with(condition, otherwise_series) + ] + + return PandasLikeExpr( + func, + depth=0, + function_name="whenthenotherwise", + root_names=None, + output_names=None, + implementation=self._implementation, + backend_version=self.backend_version, + ) +======= + self._expr, "dt", "to_string", format + ) +>>>>>>> f3770b7 (Fix errors from the migration) diff --git a/narwhals/expr.py b/narwhals/expr.py index 6af61ed96..cc394487a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3016,7 +3016,7 @@ def to_string(self, format: str) -> Expr: # noqa: A002 of trailing zeros. Nonetheless, this is probably consistent enough for most applications. - If you have an application where this is not enough, please open an issue + If you have an application here this is not enough, please open an issue and let us know. 
Examples: @@ -3608,19 +3608,18 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: def __init__(self, condition: Expr) -> None: self._condition = condition + self._then_value = None + self._otehrwise_value = None def then(self, value: Any) -> Then: - return Then(self, value=value) + return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) class Then(Expr): - def __init__(self, when: When, *, value: Any) -> None: - self._when = when - self._then_value = value - - self._call = lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value) + def __init__(self, call) -> None: # noqa: ANN001 + self._call = call def otherwise(self, value: Any) -> Expr: - return Expr(lambda plx: plx.when(self._when._condition._call(plx)).then(self._then_value).otherwise(value)) + return Expr(lambda plx: self._call(plx).otherwise(value)) def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 return When(reduce(lambda a, b: a & b, flatten([predicates]))) diff --git a/tests/test_where.py b/tests/test_where.py new file mode 100644 index 000000000..3661db0c5 --- /dev/null +++ b/tests/test_where.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import narwhals.stable.v1 as nw +from narwhals.expression import when +from tests.utils import compare_dicts + +data = { + "a": [1, 1, 2], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], +} + + +def test_when(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) + expected = { + "a": [1, 1, 2], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, 3, None], + } + compare_dicts(result, expected) + +def test_when_otherwise(request: Any, 
constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) + expected = { + "a": [1, 1, 2], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, 3, 6], + } + compare_dicts(result, expected) From 2b1eabcb0eddccabbe3acc3ccf5810b11ac47fa5 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 16:49:41 +0300 Subject: [PATCH 26/64] remove unnecessary changes --- narwhals/_pandas_like/expr.py | 9 ++++----- narwhals/expr.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 34ed69fdb..a16bb768a 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -477,11 +477,9 @@ def total_nanoseconds(self) -> PandasLikeExpr: def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 return reuse_series_namespace_implementation( -<<<<<<< HEAD self._expr, "dt", "to_string", format ) - class PandasLikeExprNameNamespace: def __init__(self: Self, expr: PandasLikeExpr) -> None: self._expr = expr @@ -634,9 +632,6 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._expr._implementation, backend_version=self._expr._backend_version, ) -||||||| parent of f3770b7 (Fix errors from the migration) - self._expr, "dt", "to_string", format - ) class PandasWhen: def __init__(self, condition: PandasLikeExpr) -> None: @@ -695,3 +690,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: self._expr, "dt", "to_string", format ) >>>>>>> f3770b7 (Fix errors from the migration) +||||||| parent of a7f442a (remove unnecessary changes) +>>>>>>> main +======= +>>>>>>> a7f442a (remove unnecessary changes) diff --git a/narwhals/expr.py b/narwhals/expr.py index cc394487a..6997c42f6 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ 
-3016,7 +3016,7 @@ def to_string(self, format: str) -> Expr: # noqa: A002 of trailing zeros. Nonetheless, this is probably consistent enough for most applications. - If you have an application here this is not enough, please open an issue + If you have an application where this is not enough, please open an issue and let us know. Examples: From 2ef564e93a8dd14f93b932f0592def68003ffb89 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 22 Jul 2024 19:12:09 +0300 Subject: [PATCH 27/64] remove unnecessary diff --- narwhals/_pandas_like/expr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index a16bb768a..5a8e63d47 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -480,6 +480,7 @@ def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 self._expr, "dt", "to_string", format ) + class PandasLikeExprNameNamespace: def __init__(self: Self, expr: PandasLikeExpr) -> None: self._expr = expr From 0e4773d51cee4ac4e62fed33f233874b5a6323c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 12:56:07 +0000 Subject: [PATCH 28/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/expr.py | 7 +++++-- tests/test_common.py | 1 + tests/test_where.py | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index 6997c42f6..61199b5fa 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3605,6 +3605,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) ) + class When: def __init__(self, condition: Expr) -> None: self._condition = condition @@ -3614,14 +3615,16 @@ def __init__(self, condition: Expr) -> None: def then(self, value: Any) -> Then: return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) + class Then(Expr): - def __init__(self, call) -> 
None: # noqa: ANN001 + def __init__(self, call) -> None: # noqa: ANN001 self._call = call def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 + +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 return When(reduce(lambda a, b: a & b, flatten([predicates]))) diff --git a/tests/test_common.py b/tests/test_common.py index 4162d7b24..8abc0b765 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -707,6 +707,7 @@ def test_quantile( ) compare_dicts(result, expected) + @pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) def test_when(df_raw: Any) -> None: df = nw.DataFrame(df_raw) diff --git a/tests/test_where.py b/tests/test_where.py index 3661db0c5..776480ece 100644 --- a/tests/test_where.py +++ b/tests/test_where.py @@ -31,6 +31,7 @@ def test_when(request: Any, constructor: Any) -> None: } compare_dicts(result, expected) + def test_when_otherwise(request: Any, constructor: Any) -> None: if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) From 151fe14b3e8c2f96c67365712e01567f6a86a377 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 16:01:50 +0300 Subject: [PATCH 29/64] fix rebase error --- narwhals/_pandas_like/expr.py | 14 +++----------- narwhals/expressions/whenthen.py | 0 2 files changed, 3 insertions(+), 11 deletions(-) delete mode 100644 narwhals/expressions/whenthen.py diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 5a8e63d47..52fdc0068 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -5,9 +5,9 @@ from typing import Callable from typing import Literal -from narwhals._pandas_like.series import PandasSeries -from narwhals._pandas_like.utils import reuse_series_implementation -from narwhals._pandas_like.utils import 
reuse_series_namespace_implementation +from narwhals._expression_parsing import reuse_series_implementation +from narwhals._expression_parsing import reuse_series_namespace_implementation +from narwhals._pandas_like.series import PandasLikeSeries if TYPE_CHECKING: from typing_extensions import Self @@ -687,11 +687,3 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self.backend_version, ) -======= - self._expr, "dt", "to_string", format - ) ->>>>>>> f3770b7 (Fix errors from the migration) -||||||| parent of a7f442a (remove unnecessary changes) ->>>>>>> main -======= ->>>>>>> a7f442a (remove unnecessary changes) diff --git a/narwhals/expressions/whenthen.py b/narwhals/expressions/whenthen.py deleted file mode 100644 index e69de29bb..000000000 From add7b896f8996ab97d582f4756658000085eab63 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 16:03:23 +0300 Subject: [PATCH 30/64] remove files left from wrong rebase --- tests/test_common.py | 719 ------------------------------------------- tests/test_where.py | 48 --- 2 files changed, 767 deletions(-) delete mode 100644 tests/test_common.py delete mode 100644 tests/test_where.py diff --git a/tests/test_common.py b/tests/test_common.py deleted file mode 100644 index 8abc0b765..000000000 --- a/tests/test_common.py +++ /dev/null @@ -1,719 +0,0 @@ -from __future__ import annotations - -import os -import warnings -from typing import Any -from typing import Literal - -import numpy as np -import pandas as pd -import polars as pl -import pytest -from pandas.testing import assert_series_equal as pd_assert_series_equal -from polars.testing import assert_series_equal as pl_assert_series_equal - -import narwhals as nw -from narwhals.utils import parse_version -from tests.utils import compare_dicts - -df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) -if parse_version(pd.__version__) >= parse_version("1.5.0"): - 
df_pandas_pyarrow = pd.DataFrame( - {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - ).astype( - { - "a": "Int64[pyarrow]", - "b": "Int64[pyarrow]", - "z": "Float64[pyarrow]", - } - ) - df_pandas_nullable = pd.DataFrame( - {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - ).astype( - { - "a": "Int64", - "b": "Int64", - "z": "Float64", - } - ) -else: # pragma: no cover - df_pandas_pyarrow = df_pandas - df_pandas_nullable = df_pandas -df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) -df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) -df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) -df_lazy_na = pl.LazyFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) -df_right_pandas = pd.DataFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) -df_right_lazy = pl.LazyFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) - -if os.environ.get("CI", None): - try: - import modin.pandas as mpd - except ImportError: # pragma: no cover - df_mpd = df_pandas.copy() - else: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - df_mpd = mpd.DataFrame( - pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) - ) -else: # pragma: no cover - df_mpd = df_pandas.copy() - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_polars, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_sort(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.sort("a", "b") - result_native = nw.to_native(result) - expected = { - "a": [1, 2, 3], - "b": [4, 6, 4], - "z": [7.0, 9.0, 8.0], - } - compare_dicts(result_native, expected) - result = df.sort("a", "b", descending=[True, False]) - result_native = nw.to_native(result) - expected = { - "a": [3, 2, 1], - "b": [4, 6, 4], - "z": [8.0, 9.0, 7.0], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def 
test_filter(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.filter(nw.col("a") > 1) - result_native = nw.to_native(result) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_polars], -) -def test_filter_series(df_raw: Any) -> None: - df = nw.DataFrame(df_raw).with_columns(mask=nw.col("a") > 1) - result = df.filter(df["mask"]).drop("mask") - result_native = nw.to_native(result) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_add(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns( - c=nw.col("a") + nw.col("b"), - d=nw.col("a") - nw.col("a").mean(), - e=nw.col("a") - nw.col("a").std(), - ) - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "c": [5, 7, 8], - "d": [-1.0, 1.0, 0.0], - "e": [0.0, 2.0, 1.0], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_std(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select( - nw.col("a").std().alias("a_ddof_default"), - nw.col("a").std(ddof=1).alias("a_ddof_1"), - nw.col("a").std(ddof=0).alias("a_ddof_0"), - nw.col("b").std(ddof=2).alias("b_ddof_2"), - nw.col("z").std(ddof=0).alias("z_ddof_0"), - ) - result_native = nw.to_native(result) - expected = { - "a_ddof_default": [1.0], - "a_ddof_1": [1.0], - "a_ddof_0": [0.816497], - "b_ddof_2": [1.632993], - "z_ddof_0": [0.816497], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_double(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns(nw.all() 
* 2) - result_native = nw.to_native(result) - expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} - compare_dicts(result_native, expected) - result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) - result_native = nw.to_native(result) - expected = {"o": [1, 3, 2], "a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", - [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow], -) -def test_select(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select("a") - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy, df_pandas_nullable]) -def test_sumh(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns(horizonal_sum=nw.sum_horizontal(nw.col("a"), nw.col("b"))) - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "horizonal_sum": [5, 7, 8], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_sumh_literal(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.with_columns(horizonal_sum=nw.sum_horizontal("a", nw.col("b"))) - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "horizonal_sum": [5, 7, 8], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_sum_all(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select(nw.all().sum()) - result_native = nw.to_native(result) - expected = {"a": [6], "b": [14], "z": [24.0]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, 
df_pandas_pyarrow] -) -def test_double_selected(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.select(nw.col("a", "b") * 2) - result_native = nw.to_native(result) - expected = {"a": [2, 6, 4], "b": [8, 8, 12]} - compare_dicts(result_native, expected) - result = df.select("z", nw.col("a", "b") * 2) - result_native = nw.to_native(result) - expected = {"z": [7, 8, 9], "a": [2, 6, 4], "b": [8, 8, 12]} - compare_dicts(result_native, expected) - result = df.select("a").select(nw.col("a") + nw.all()) - result_native = nw.to_native(result) - expected = {"a": [2, 6, 4]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_rename(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.rename({"a": "x", "b": "y"}) - result_native = nw.to_native(result) - expected = {"x": [1, 3, 2], "y": [4, 4, 6], "z": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_join(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - df_right = df - result = df.join(df_right, left_on=["a", "b"], right_on=["a", "b"], how="inner") - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "z_right": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - with pytest.raises(NotImplementedError): - result = df.join(df_right, left_on="a", right_on="a", how="left") # type: ignore[arg-type] - - result = df.collect().join(df_right.collect(), left_on="a", right_on="a", how="inner") # type: ignore[assignment] - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "b_right": [4, 4, 6], - "z": [7.0, 8, 9], - "z_right": [7.0, 8, 9], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def 
test_schema(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw).schema - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - assert result == expected - result = nw.LazyFrame(df_raw).collect().schema - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - assert result == expected - result = nw.LazyFrame(df_raw).columns # type: ignore[assignment] - expected = ["a", "b", "z"] # type: ignore[assignment] - assert result == expected - result = nw.LazyFrame(df_raw).collect().columns # type: ignore[assignment] - expected = ["a", "b", "z"] # type: ignore[assignment] - assert result == expected - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_columns(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = df.columns - expected = ["a", "b", "z"] - assert result == expected - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def test_lazy_instantiation(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw) - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_lazy]) -def test_lazy_instantiation_error(df_raw: Any) -> None: - with pytest.raises( - TypeError, match="Can't instantiate DataFrame from Polars LazyFrame." 
- ): - _ = nw.DataFrame(df_raw).shape - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) -def test_eager_instantiation(df_raw: Any) -> None: - result = nw.DataFrame(df_raw) - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - -def test_accepted_dataframes() -> None: - array = np.array([[0, 4.0], [2, 5]]) - with pytest.raises( - TypeError, - match="Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: ", - ): - nw.DataFrame(array) - with pytest.raises( - TypeError, - match="Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: ", - ): - nw.LazyFrame(array) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) -@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") -def test_convert_pandas(df_raw: Any) -> None: - result = nw.from_native(df_raw).to_pandas() # type: ignore[union-attr] - expected = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) - pd.testing.assert_frame_equal(result, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_polars, df_pandas, df_mpd, df_pandas_nullable, df_pandas_pyarrow] -) -@pytest.mark.filterwarnings( - r"ignore:np\.find_common_type is deprecated\.:DeprecationWarning" -) -def test_convert_numpy(df_raw: Any) -> None: - result = nw.DataFrame(df_raw).to_numpy() - expected = np.array([[1, 3, 2], [4, 4, 6], [7.0, 8, 9]]).T - np.testing.assert_array_equal(result, expected) - assert result.dtype == "float64" - result = nw.DataFrame(df_raw).__array__() - np.testing.assert_array_equal(result, expected) - assert result.dtype == "float64" - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) -def test_shape(df_raw: Any) -> None: - result = nw.DataFrame(df_raw).shape - expected = (3, 3) - assert result == expected - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def 
test_expr_binary(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw).with_columns( - a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")), - b=nw.col("z") / (2 - nw.col("b")), - c=nw.col("a") + nw.col("b") / 2, - d=nw.col("a") - nw.col("b"), - e=((nw.col("a") > nw.col("b")) & (nw.col("a") >= nw.col("z"))).cast(nw.Int64), - f=( - (nw.col("a") < nw.col("b")) - | (nw.col("a") <= nw.col("z")) - | (nw.col("a") == 1) - ).cast(nw.Int64), - g=nw.col("a") != 1, - h=(False & (nw.col("a") != 1)), - i=(False | (nw.col("a") != 1)), - j=2 ** nw.col("a"), - k=2 // nw.col("a"), - l=nw.col("a") // 2, - m=nw.col("a") ** 2, - n=nw.col("a") % 2, - o=2 % nw.col("a"), - ) - result_native = nw.to_native(result) - expected = { - "a": [4, 3.333333, 3.5], - "b": [-3.5, -4.0, -2.25], - "z": [7.0, 8.0, 9.0], - "c": [3, 5, 5], - "d": [-3, -1, -4], - "e": [0, 0, 0], - "f": [1, 1, 1], - "g": [False, True, True], - "h": [False, False, False], - "i": [False, True, True], - "j": [2, 8, 4], - "k": [2, 0, 1], - "l": [0, 1, 1], - "m": [1, 9, 4], - "n": [1, 1, 0], - "o": [0, 2, 0], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_lazy]) -def test_expr_unary(df_raw: Any) -> None: - result = ( - nw.from_native(df_raw) - .with_columns( - a_mean=nw.col("a").mean(), - a_sum=nw.col("a").sum(), - b_nunique=nw.col("b").n_unique(), - z_min=nw.col("z").min(), - z_max=nw.col("z").max(), - ) - .select(nw.col("a_mean", "a_sum", "b_nunique", "z_min", "z_max").unique()) - ) - result_native = nw.to_native(result) - expected = {"a_mean": [2], "a_sum": [6], "b_nunique": [2], "z_min": [7], "z_max": [9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def test_expr_transform(df_raw: Any) -> None: - result = nw.LazyFrame(df_raw).with_columns( - a=nw.col("a").is_between(-1, 1), b=nw.col("b").is_in([4, 5]) - ) - result_native = nw.to_native(result) - expected = {"a": [True, False, False], "b": 
[True, True, False], "z": [7, 8, 9]} - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_lazy]) -def test_expr_min_max(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result_min = nw.to_native(df.select(nw.min("a", "b", "z"))) - result_max = nw.to_native(df.select(nw.max("a", "b", "z"))) - expected_min = {"a": [1], "b": [4], "z": [7]} - expected_max = {"a": [3], "b": [6], "z": [9]} - compare_dicts(result_min, expected_min) - compare_dicts(result_max, expected_max) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) -def test_expr_sample(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result_shape = nw.to_native(df.select(nw.col("a").sample(n=2)).collect()).shape - expected = (2, 1) - assert result_shape == expected - result_shape = nw.to_native(df.collect()["a"].sample(n=2)).shape - expected = (2,) # type: ignore[assignment] - assert result_shape == expected - - -@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) -def test_expr_na(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result_nna = nw.to_native( - df.filter((~nw.col("a").is_null()) & (~df.collect()["z"].is_null())) - ) - expected = {"a": [2], "b": [6], "z": [9]} - compare_dicts(result_nna, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_head(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.head(2)) - expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} - compare_dicts(result, expected) - result = nw.to_native(df.collect().head(2)) - expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_unique(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.unique("b").sort("b")) - expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]} 
- compare_dicts(result, expected) - result = nw.to_native(df.collect().unique("b").sort("b")) - expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) -def test_drop_nulls(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.select(nw.col("a").drop_nulls())) - expected = {"a": [3, 2]} - compare_dicts(result, expected) - result = nw.to_native(df.select(df.collect()["a"].drop_nulls())) - expected = {"a": [3, 2]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize( - ("df_raw", "df_raw_right"), [(df_pandas, df_right_pandas), (df_lazy, df_right_lazy)] -) -def test_concat_horizontal(df_raw: Any, df_raw_right: Any) -> None: - df_left = nw.LazyFrame(df_raw) - df_right = nw.LazyFrame(df_raw_right) - result = nw.concat([df_left, df_right], how="horizontal") - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8, 9], - "c": [6, 12, -1], - "d": [0, -4, 2], - } - compare_dicts(result_native, expected) - - with pytest.raises(ValueError, match="No items"): - nw.concat([]) - - -@pytest.mark.parametrize( - ("df_raw", "df_raw_right"), [(df_pandas, df_right_pandas), (df_lazy, df_right_lazy)] -) -def test_concat_vertical(df_raw: Any, df_raw_right: Any) -> None: - df_left = nw.LazyFrame(df_raw).collect().rename({"a": "c", "b": "d"}).lazy().drop("z") - df_right = nw.LazyFrame(df_raw_right) - result = nw.concat([df_left, df_right], how="vertical") - result_native = nw.to_native(result) - expected = {"c": [1, 3, 2, 6, 12, -1], "d": [4, 4, 6, 0, -4, 2]} - compare_dicts(result_native, expected) - with pytest.raises(ValueError, match="No items"): - nw.concat([], how="vertical") - with pytest.raises(Exception, match="unable to vstack"): - nw.concat([df_left, df_right.rename({"d": "i"})], how="vertical").collect() # type: ignore[union-attr] - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def 
test_lazy(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.lazy() - assert isinstance(result, nw.LazyFrame) - - -def test_to_dict() -> None: - df = nw.DataFrame(df_pandas) - result = df.to_dict(as_series=True) - expected = { - "a": pd.Series([1, 3, 2], name="a"), - "b": pd.Series([4, 4, 6], name="b"), - "z": pd.Series([7.0, 8, 9], name="z"), - } - for key in expected: - pd_assert_series_equal(nw.to_native(result[key]), expected[key]) - - df = nw.DataFrame(df_polars) - result = df.to_dict(as_series=True) - expected = { - "a": pl.Series("a", [1, 3, 2]), - "b": pl.Series("b", [4, 4, 6]), - "z": pl.Series("z", [7.0, 8, 9]), - } - for key in expected: - pl_assert_series_equal(nw.to_native(result[key]), expected[key]) - - -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_any_all(df_raw: Any) -> None: - df = nw.LazyFrame(df_raw) - result = nw.to_native(df.select((nw.all() > 1).all())) - expected = {"a": [False], "b": [True], "z": [True]} - compare_dicts(result, expected) - result = nw.to_native(df.select((nw.all() > 1).any())) - expected = {"a": [True], "b": [True], "z": [True]} - compare_dicts(result, expected) - - -def test_invalid() -> None: - df = nw.LazyFrame(df_pandas) - with pytest.raises(ValueError, match="Multi-output"): - df.select(nw.all() + nw.all()) - with pytest.raises(TypeError, match="Perhaps you:"): - df.select([pl.col("a")]) # type: ignore[list-item] - with pytest.raises(TypeError, match="Perhaps you:"): - df.select([nw.col("a").cast(pl.Int64)]) - - -@pytest.mark.parametrize("df_raw", [df_pandas]) -def test_reindex(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.select("b", df["a"].sort(descending=True)) - expected = {"b": [4, 4, 6], "a": [3, 2, 1]} - compare_dicts(result, expected) - result = df.select("b", nw.col("a").sort(descending=True)) - compare_dicts(result, expected) - - s = df["a"] - result_s = s > s.sort() - assert not result_s[0] - assert result_s[1] 
- assert not result_s[2] - result = df.with_columns(s.sort()) - expected = {"a": [1, 2, 3], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} # type: ignore[list-item] - compare_dicts(result, expected) - with pytest.raises(ValueError, match="Multi-output expressions are not supported"): - nw.to_native(df.with_columns(nw.all() + nw.all())) - - -@pytest.mark.parametrize( - ("df_raw", "df_raw_right"), - [(df_pandas, df_polars), (df_polars, df_pandas)], -) -def test_library(df_raw: Any, df_raw_right: Any) -> None: - df_left = nw.LazyFrame(df_raw) - df_right = nw.LazyFrame(df_raw_right) - with pytest.raises( - NotImplementedError, match="Cross-library comparisons aren't supported" - ): - nw.concat([df_left, df_right], how="horizontal") - with pytest.raises( - NotImplementedError, match="Cross-library comparisons aren't supported" - ): - nw.concat([df_left, df_right], how="vertical") - with pytest.raises( - NotImplementedError, match="Cross-library comparisons aren't supported" - ): - df_left.join(df_right, left_on=["a"], right_on=["a"], how="inner") - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def test_is_duplicated(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = nw.concat([df, df.head(1)]).is_duplicated() # type: ignore [union-attr] - expected = np.array([True, False, False, True]) - assert (result.to_numpy() == expected).all() - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -@pytest.mark.parametrize(("threshold", "expected"), [(0, False), (10, True)]) -def test_is_empty(df_raw: Any, threshold: Any, expected: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.filter(nw.col("a") > threshold).is_empty() - assert result == expected - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def test_is_unique(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = nw.concat([df, df.head(1)]).is_unique() # type: ignore [union-attr] - expected = np.array([False, True, True, False]) - assert (result.to_numpy() == 
expected).all() - - -@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na.collect()]) -def test_null_count(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = nw.to_native(df.null_count()) - expected = {"a": [1], "b": [0], "z": [1]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -@pytest.mark.parametrize( - ("interpolation", "expected"), - [ - ("lower", {"a": [1.0], "b": [4.0], "z": [7.0]}), - ("higher", {"a": [2.0], "b": [4.0], "z": [8.0]}), - ("midpoint", {"a": [1.5], "b": [4.0], "z": [7.5]}), - ("linear", {"a": [1.6], "b": [4.0], "z": [7.6]}), - ("nearest", {"a": [2.0], "b": [4.0], "z": [8.0]}), - ], -) -def test_quantile( - df_raw: Any, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], - expected: dict[str, list[float]], -) -> None: - q = 0.3 - - df = nw.from_native(df_raw) - result = nw.to_native( - df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) - ) - compare_dicts(result, expected) - - -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) -def test_when(df_raw: Any) -> None: - df = nw.DataFrame(df_raw) - result = df.with_columns( - a=nw.when(nw.col("a") > 2, 1).otherwise(0), - b=nw.when(nw.col("a") > 2, 1).when(nw.col("a") < 1, -1).otherwise(0), - ) - expected = {"a": [0, 1, 0], "b": [0, 1, 0], "z": [7.0, 8.0, 9.0]} - compare_dicts(result, expected) diff --git a/tests/test_where.py b/tests/test_where.py deleted file mode 100644 index 776480ece..000000000 --- a/tests/test_where.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -import narwhals.stable.v1 as nw -from narwhals.expression import when -from tests.utils import compare_dicts - -data = { - "a": [1, 1, 2], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], -} - - -def test_when(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor): - 
request.applymarker(pytest.mark.xfail) - - df = nw.from_native(constructor(data)) - result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) - expected = { - "a": [1, 1, 2], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], - "a_when": [3, 3, None], - } - compare_dicts(result, expected) - - -def test_when_otherwise(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor): - request.applymarker(pytest.mark.xfail) - - df = nw.from_native(constructor(data)) - result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) - expected = { - "a": [1, 1, 2], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], - "a_when": [3, 3, 6], - } - compare_dicts(result, expected) From fd21c78d8fa8c31fac04c7a920bdfae93b6b8daa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:08:47 +0000 Subject: [PATCH 31/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/expr.py | 37 ++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 52fdc0068..fa6d4988a 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -343,7 +343,9 @@ def cat(self: Self) -> PandasLikeExprCatNamespace: def name(self: Self) -> PandasLikeExprNameNamespace: return PandasLikeExprNameNamespace(self) - def when(self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any) -> PandasWhen: + def when( + self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any + ) -> PandasWhen: # TODO: Support conditions from narwhals._pandas_like.namespace import PandasNamespace @@ -351,6 +353,7 @@ def when(self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any condition = 
plx.all_horizontal(*predicates) return PandasWhen(self, condition) + class PandasLikeExprCatNamespace: def __init__(self, expr: PandasLikeExpr) -> None: self._expr = expr @@ -634,15 +637,26 @@ def to_uppercase(self: Self) -> PandasLikeExpr: backend_version=self._expr._backend_version, ) + class PandasWhen: def __init__(self, condition: PandasLikeExpr) -> None: self._condition = condition def then(self, value: Any) -> PandasThen: - return PandasThen(self, value=value, implementation=self._condition._implementation) + return PandasThen( + self, value=value, implementation=self._condition._implementation + ) + class PandasThen(PandasLikeExpr): - def __init__(self, when: PandasWhen, *, value: Any, implementation: Implementation, backend_version: tuple[int, ...]) -> None: + def __init__( + self, + when: PandasWhen, + *, + value: Any, + implementation: Implementation, + backend_version: tuple[int, ...], + ) -> None: self._when = when self._then_value = value self._implementation = implementation @@ -651,15 +665,15 @@ def __init__(self, when: PandasWhen, *, value: Any, implementation: Implementati def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._pandas_like.namespace import PandasLikeNamespace - plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self.backend_version) + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self.backend_version + ) condition = self._when._condition._call(df)[0] value_series = plx._create_series_from_scalar(self._then_value, condition) none_series = plx._create_series_from_scalar(None, condition) - return [ - value_series.zip_with(condition, none_series) - ] + return [value_series.zip_with(condition, none_series)] self._call = func self._depth = 0 @@ -670,13 +684,14 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def otherwise(self, value: Any) -> PandasLikeExpr: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from 
narwhals._pandas_like.namespace import PandasLikeNamespace - plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self.backend_version) + + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self.backend_version + ) condition = self._when._condition._call(df)[0] value_series = plx._create_series_from_scalar(self._then_value, condition) otherwise_series = plx._create_series_from_scalar(value, condition) - return [ - value_series.zip_with(condition, otherwise_series) - ] + return [value_series.zip_with(condition, otherwise_series)] return PandasLikeExpr( func, From 2f03bd0dbb4979bbec1e24ef8c4fcc504752e778 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 16:41:54 +0300 Subject: [PATCH 32/64] chore: remove all wrong rebase leftover code --- narwhals/_pandas_like/expr.py | 76 ----------------------------------- narwhals/expr.py | 22 ---------- 2 files changed, 98 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index fa6d4988a..f846da610 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -343,16 +343,6 @@ def cat(self: Self) -> PandasLikeExprCatNamespace: def name(self: Self) -> PandasLikeExprNameNamespace: return PandasLikeExprNameNamespace(self) - def when( - self, *predicates: PandasExpr | Iterable[PandasExpr], **conditions: Any - ) -> PandasWhen: - # TODO: Support conditions - from narwhals._pandas_like.namespace import PandasNamespace - - plx = PandasNamespace(self._implementation) - condition = plx.all_horizontal(*predicates) - return PandasWhen(self, condition) - class PandasLikeExprCatNamespace: def __init__(self, expr: PandasLikeExpr) -> None: @@ -636,69 +626,3 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._expr._implementation, backend_version=self._expr._backend_version, ) - - -class PandasWhen: - def __init__(self, condition: PandasLikeExpr) -> None: - self._condition = condition - - def 
then(self, value: Any) -> PandasThen: - return PandasThen( - self, value=value, implementation=self._condition._implementation - ) - - -class PandasThen(PandasLikeExpr): - def __init__( - self, - when: PandasWhen, - *, - value: Any, - implementation: Implementation, - backend_version: tuple[int, ...], - ) -> None: - self._when = when - self._then_value = value - self._implementation = implementation - self.backend_version = backend_version - - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - from narwhals._pandas_like.namespace import PandasLikeNamespace - - plx = PandasLikeNamespace( - implementation=self._implementation, backend_version=self.backend_version - ) - - condition = self._when._condition._call(df)[0] - - value_series = plx._create_series_from_scalar(self._then_value, condition) - none_series = plx._create_series_from_scalar(None, condition) - return [value_series.zip_with(condition, none_series)] - - self._call = func - self._depth = 0 - self._function_name = "whenthen" - self._root_names = None - self._output_names = None - - def otherwise(self, value: Any) -> PandasLikeExpr: - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - from narwhals._pandas_like.namespace import PandasLikeNamespace - - plx = PandasLikeNamespace( - implementation=self._implementation, backend_version=self.backend_version - ) - condition = self._when._condition._call(df)[0] - value_series = plx._create_series_from_scalar(self._then_value, condition) - otherwise_series = plx._create_series_from_scalar(value, condition) - return [value_series.zip_with(condition, otherwise_series)] - - return PandasLikeExpr( - func, - depth=0, - function_name="whenthenotherwise", - root_names=None, - output_names=None, - implementation=self._implementation, - backend_version=self.backend_version, - ) diff --git a/narwhals/expr.py b/narwhals/expr.py index 61199b5fa..2ef1e5308 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3616,28 +3616,6 @@ def then(self, value: 
Any) -> Then: return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) -class Then(Expr): - def __init__(self, call) -> None: # noqa: ANN001 - self._call = call - - def otherwise(self, value: Any) -> Expr: - return Expr(lambda plx: self._call(plx).otherwise(value)) - - -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 - return When(reduce(lambda a, b: a & b, flatten([predicates]))) - - -class When: - def __init__(self, condition: Expr) -> None: - self._condition = condition - self._then_value = None - self._otehrwise_value = None - - def then(self, value: Any) -> Then: - return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) - - class Then(Expr): def __init__(self, call: Callable[[Any], Any]) -> None: self._call = call From 43e267072ff11e32c14deb74fdbe5ac6fad482d1 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 16:50:40 +0300 Subject: [PATCH 33/64] feat: add chaining for polars --- narwhals/expr.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/narwhals/expr.py b/narwhals/expr.py index 2ef1e5308..78ba99ed5 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3623,6 +3623,30 @@ def __init__(self, call: Callable[[Any], Any]) -> None: def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) + def when( + *predicates: IntoExpr | Iterable[IntoExpr], + **constraints: Any, # noqa: ARG002 + ) -> ChainedWhen: + return ChainedWhen(reduce(lambda a, b: a & b, flatten([predicates]))) + + +class ChainedWhen: + def __init__(self, condition: Expr) -> None: + self._condition = condition + self._then_value = None + self._otehrwise_value = None + + def then(self, value: Any) -> ChainedThen: + return ChainedThen(lambda plx: plx.when(self._condition._call(plx)).then(value)) + + +class ChainedThen(Expr): + def __init__(self, call: Callable[[Any], Any]) -> None: + self._call = call + + def otherwise(self, 
value: Any) -> Expr: + return Expr(lambda plx: self._call(plx).otherwise(value)) + def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 """ From 071ec9f3f800df0813876b418e93d8fcd0844992 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 23 Jul 2024 16:52:39 +0300 Subject: [PATCH 34/64] chore: remove unused fields --- narwhals/expr.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index 78ba99ed5..87bc4d57b 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3609,8 +3609,6 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: def __init__(self, condition: Expr) -> None: self._condition = condition - self._then_value = None - self._otehrwise_value = None def then(self, value: Any) -> Then: return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) @@ -3633,8 +3631,6 @@ def when( class ChainedWhen: def __init__(self, condition: Expr) -> None: self._condition = condition - self._then_value = None - self._otehrwise_value = None def then(self, value: Any) -> ChainedThen: return ChainedThen(lambda plx: plx.when(self._condition._call(plx)).then(value)) From 9a49db6c08e971f1c5c311773e8a2b5bf7ff5223 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 16:19:38 +0300 Subject: [PATCH 35/64] bug: fix bug in chaining feat: add pandas when chaining --- narwhals/_pandas_like/namespace.py | 108 +++++++++++++++++++++++++++++ narwhals/expr.py | 12 +++- tests/expr_and_series/when_test.py | 33 +++++++-- 3 files changed, 145 insertions(+), 8 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 840414525..e268f2b60 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -282,6 +282,7 @@ def __init__( self._condition = condition self._then_value = then_value self._otherwise_value = otherise_value + self._already_set = self._condition def 
__call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._pandas_like.namespace import PandasLikeNamespace @@ -335,6 +336,15 @@ def __init__( self._root_names = root_names self._output_names = output_names + def when(self, condition: PandasLikeExpr) -> PandasChainedWhen: + return PandasChainedWhen( + self._call, # type: ignore[arg-type] + condition, + depth=self._depth + 1, + implementation=self._implementation, + backend_version=self._backend_version, + ) + def otherwise(self, value: Any) -> PandasLikeExpr: # type ignore because we are setting the `_call` attribute to a # callable object of type `PandasWhen`, base class has the attribute as @@ -342,3 +352,101 @@ def otherwise(self, value: Any) -> PandasLikeExpr: self._call._otherwise_value = value # type: ignore[attr-defined] self._function_name = "whenotherwise" return self + + +class PandasChainedWhen: + def __init__( + self, + above_when: PandasWhen | PandasChainedWhen, + condition: PandasLikeExpr, + depth: int, + implementation: Implementation, + backend_version: tuple[int, ...], + then_value: Any = None, + otherise_value: Any = None, + ) -> None: + self._implementation = implementation + self._depth = depth + self._backend_version = backend_version + self._condition = condition + self._above_when = above_when + self._then_value = then_value + self._otherwise_value = otherise_value + + # TODO @aivanoved: this is way slow as during computation time this takes + # quadratic time need to improve this to linear time + self._condition = self._condition & (~self._above_when._already_set) # type: ignore[has-type] + self._already_set = self._above_when._already_set | self._condition # type: ignore[has-type] + + def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + from narwhals._pandas_like.namespace import PandasLikeNamespace + + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self._backend_version + ) + + set_then = 
self._condition._call(df)[0] + already_set = self._already_set._call(df)[0] + + value_series = plx._create_broadcast_series_from_scalar( + self._then_value, set_then + ) + otherwise_series = plx._create_broadcast_series_from_scalar( + self._otherwise_value, set_then + ) + + above_result = self._above_when(df)[0] + + result = value_series.zip_with(set_then, above_result).zip_with( + already_set, otherwise_series + ) + + return [result] + + def then(self, value: Any) -> PandasChainedThen: + self._then_value = value + return PandasChainedThen( + self, + depth=self._depth, + implementation=self._implementation, + function_name="chainedwhen", + root_names=None, + output_names=None, + backend_version=self._backend_version, + ) + + +class PandasChainedThen(PandasLikeExpr): + def __init__( + self, + call: PandasChainedWhen, + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + implementation: Implementation, + backend_version: tuple[int, ...], + ) -> None: + self._implementation = implementation + self._backend_version = backend_version + + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + + def when(self, condition: PandasLikeExpr) -> PandasChainedWhen: + return PandasChainedWhen( + self._call, # type: ignore[arg-type] + condition, + depth=self._depth + 1, + implementation=self._implementation, + backend_version=self._backend_version, + ) + + def otherwise(self, value: Any) -> PandasLikeExpr: + self._call._otherwise_value = value # type: ignore[attr-defined] + self._function_name = "chainedwhenotherwise" + return self diff --git a/narwhals/expr.py b/narwhals/expr.py index 87bc4d57b..c7d27b7a6 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3622,18 +3622,24 @@ def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) def when( + self, *predicates: IntoExpr | 
Iterable[IntoExpr], **constraints: Any, # noqa: ARG002 ) -> ChainedWhen: - return ChainedWhen(reduce(lambda a, b: a & b, flatten([predicates]))) + return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) class ChainedWhen: - def __init__(self, condition: Expr) -> None: + def __init__(self, above_then: Then, condition: Expr) -> None: + self._above_then = above_then self._condition = condition def then(self, value: Any) -> ChainedThen: - return ChainedThen(lambda plx: plx.when(self._condition._call(plx)).then(value)) + return ChainedThen( + lambda plx: self._above_then._call(plx) + .when(self._condition._call(plx)) + .then(value) + ) class ChainedThen(Expr): diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 90df13180..01abefda3 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -9,7 +9,7 @@ from tests.utils import compare_dicts data = { - "a": [1, 1, 2], + "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [4.1, 5.0, 6.0], "d": [True, False, True], @@ -23,11 +23,11 @@ def test_when(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { - "a": [1, 1, 2], + "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [4.1, 5.0, 6.0], "d": [True, False, True], - "a_when": [3, 3, None], + "a_when": [3, None, None], } compare_dicts(result, expected) @@ -39,10 +39,33 @@ def test_when_otherwise(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { - "a": [1, 1, 2], + "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [4.1, 5.0, 6.0], "d": [True, False, True], - "a_when": [3, 3, 6], + "a_when": [3, 6, 6], + } + compare_dicts(result, expected) + + +def test_chained_when(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + 
request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns( + when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .otherwise(7) + .alias("a_when"), + ) + expected = { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, 5, 7], } compare_dicts(result, expected) From 47a8cdd20dd8c06b52fc6a8f1c119a9d503163d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:20:20 +0000 Subject: [PATCH 36/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/namespace.py | 2 +- narwhals/expr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index e268f2b60..8a7733639 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -261,7 +261,7 @@ def concat( def when( self, *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], - **constraints: Any, # noqa: ARG002 + **constraints: Any, ) -> PandasWhen: plx = self.__class__(self._implementation, self._backend_version) condition = plx.all_horizontal(*flatten(predicates)) diff --git a/narwhals/expr.py b/narwhals/expr.py index c7d27b7a6..18988ac3a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3624,7 +3624,7 @@ def otherwise(self, value: Any) -> Expr: def when( self, *predicates: IntoExpr | Iterable[IntoExpr], - **constraints: Any, # noqa: ARG002 + **constraints: Any, ) -> ChainedWhen: return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) From b71847b63dae136e3b09702149dc81bf8efc5991 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 16:32:28 +0300 Subject: [PATCH 37/64] bug: add chaing from chained then test: add test for multiple chained conditions --- narwhals/expr.py | 11 
+++++- tests/expr_and_series/when_test.py | 62 +++++++++++++++++++++--------- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index 18988ac3a..f23d5917c 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3624,13 +3624,13 @@ def otherwise(self, value: Any) -> Expr: def when( self, *predicates: IntoExpr | Iterable[IntoExpr], - **constraints: Any, + **constraints: Any, # noqa: ARG002 ) -> ChainedWhen: return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) class ChainedWhen: - def __init__(self, above_then: Then, condition: Expr) -> None: + def __init__(self, above_then: Then | ChainedThen, condition: Expr) -> None: self._above_then = above_then self._condition = condition @@ -3646,6 +3646,13 @@ class ChainedThen(Expr): def __init__(self, call: Callable[[Any], Any]) -> None: self._call = call + def when( + self, + *predicates: IntoExpr | Iterable[IntoExpr], + **constraints: Any, # noqa: ARG002 + ) -> ChainedWhen: + return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) + def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 01abefda3..703dd5c4c 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -9,10 +9,10 @@ from tests.utils import compare_dicts data = { - "a": [1, 2, 3], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], + "a": [1, 2, 3, 4, 5], + "b": ["a", "b", "c", "d", "e"], + "c": [4.1, 5.0, 6.0, 7.0, 8.0], + "d": [True, False, True, False, True], } @@ -23,11 +23,11 @@ def test_when(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { - "a": [1, 2, 3], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], - "a_when": [3, None, 
None], + "a": [1, 2, 3, 4, 5], + "b": ["a", "b", "c", "d", "e"], + "c": [4.1, 5.0, 6.0, 7.0, 8.0], + "d": [True, False, True, False, True], + "a_when": [3, None, None, None, None], } compare_dicts(result, expected) @@ -39,11 +39,11 @@ def test_when_otherwise(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { - "a": [1, 2, 3], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], - "a_when": [3, 6, 6], + "a": [1, 2, 3, 4, 5], + "b": ["a", "b", "c", "d", "e"], + "c": [4.1, 5.0, 6.0, 7.0, 8.0], + "d": [True, False, True, False, True], + "a_when": [3, 6, 6, 6, 6], } compare_dicts(result, expected) @@ -62,10 +62,34 @@ def test_chained_when(request: Any, constructor: Any) -> None: .alias("a_when"), ) expected = { - "a": [1, 2, 3], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], - "a_when": [3, 5, 7], + "a": [1, 2, 3, 4, 5], + "b": ["a", "b", "c", "d", "e"], + "c": [4.1, 5.0, 6.0, 7.0, 8.0], + "d": [True, False, True, False, True], + "a_when": [3, 5, 7, 7, 7], + } + compare_dicts(result, expected) + + +def test_when_with_multiple_conditions(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + result = df.with_columns( + when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .when(nw.col("a") == 3) + .then(7) + .otherwise(9) + .alias("a_when"), + ) + expected = { + "a": [1, 2, 3, 4, 5], + "b": ["a", "b", "c", "d", "e"], + "c": [4.1, 5.0, 6.0, 7.0, 8.0], + "d": [True, False, True, False, True], + "a_when": [3, 5, 7, 9, 9], } compare_dicts(result, expected) From b5d7df851bebc2d62f542ccdd478085467198258 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:33:50 +0000 Subject: [PATCH 38/64] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index f23d5917c..ebf3f12b2 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3624,7 +3624,7 @@ def otherwise(self, value: Any) -> Expr: def when( self, *predicates: IntoExpr | Iterable[IntoExpr], - **constraints: Any, # noqa: ARG002 + **constraints: Any, ) -> ChainedWhen: return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) @@ -3649,7 +3649,7 @@ def __init__(self, call: Callable[[Any], Any]) -> None: def when( self, *predicates: IntoExpr | Iterable[IntoExpr], - **constraints: Any, # noqa: ARG002 + **constraints: Any, ) -> ChainedWhen: return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) From 2ddd993a4749f88ebc3f879ef58cc233c5781761 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 16:39:14 +0300 Subject: [PATCH 39/64] docs: add when to api reference --- docs/api-reference/narwhals.md | 1 + narwhals/__init__.py | 2 ++ narwhals/expr.py | 4 +--- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index f6976b4c6..fe188e7db 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -25,6 +25,7 @@ Here are the top-level functions available in Narwhals. 
- narwhalify - sum - sum_horizontal + - when - show_versions - to_native show_source: false diff --git a/narwhals/__init__.py b/narwhals/__init__.py index b9175d192..8b1e4f80f 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -32,6 +32,7 @@ from narwhals.expr import min from narwhals.expr import sum from narwhals.expr import sum_horizontal +from narwhals.expr import when from narwhals.functions import concat from narwhals.functions import get_level from narwhals.functions import show_versions @@ -69,6 +70,7 @@ "mean", "sum", "sum_horizontal", + "when", "DataFrame", "LazyFrame", "Series", diff --git a/narwhals/expr.py b/narwhals/expr.py index ebf3f12b2..02a77a867 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3681,10 +3681,8 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When >>> @nw.narwhalify ... def func(df_any): - ... from narwhals.expr import when - ... ... return df_any.with_columns( - ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") ... 
) We can then pass either pandas or polars to `func`: From 489c46350e1111f949e7ff80eb0dfdc13e54b730 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 18:13:37 +0300 Subject: [PATCH 40/64] bug: allow constraints to be passed to pandas implementation --- narwhals/_pandas_like/namespace.py | 34 ++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 8a7733639..257d84830 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -264,7 +264,17 @@ def when( **constraints: Any, ) -> PandasWhen: plx = self.__class__(self._implementation, self._backend_version) - condition = plx.all_horizontal(*flatten(predicates)) + import narwhals as nw + + if predicates: + condition = plx.all_horizontal(*flatten(predicates)) + elif constraints: + condition = plx.all_horizontal( + *(nw.col(name) == value for name, value in constraints.items()) + ) + else: + msg = "Must provide at least one predicate or constraint" + raise ValueError(msg) return PandasWhen(condition, self._implementation, self._backend_version) @@ -437,7 +447,27 @@ def __init__( self._root_names = root_names self._output_names = output_names - def when(self, condition: PandasLikeExpr) -> PandasChainedWhen: + def when( + self, + *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], + **constraints: Any, + ) -> PandasChainedWhen: + from narwhals._pandas_like.namespace import PandasLikeNamespace + + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self._backend_version + ) + if predicates: + condition = plx.all_horizontal(*flatten(predicates)) + elif constraints: + import narwhals as nw + + condition = plx.all_horizontal( + *(nw.col(name) == value for name, value in constraints.items()) + ) + else: + msg = "Must provide at least one predicate or constraint" + raise ValueError(msg) return PandasChainedWhen( 
self._call, # type: ignore[arg-type] condition, From bb3847e2842d8e3a4371041ea300f33ec0e6ad11 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 18:24:19 +0300 Subject: [PATCH 41/64] misc: fix typo --- narwhals/_pandas_like/namespace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 257d84830..b32e8640b 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -273,7 +273,7 @@ def when( *(nw.col(name) == value for name, value in constraints.items()) ) else: - msg = "Must provide at least one predicate or constraint" + msg = "Must provide at least one predicates or constraints" raise ValueError(msg) return PandasWhen(condition, self._implementation, self._backend_version) @@ -466,7 +466,7 @@ def when( *(nw.col(name) == value for name, value in constraints.items()) ) else: - msg = "Must provide at least one predicate or constraint" + msg = "Must provide at least one predicates or constraints" raise ValueError(msg) return PandasChainedWhen( self._call, # type: ignore[arg-type] From 37cc634e91c0df3aeb63de8adf9e9733f4624029 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 17:29:54 +0000 Subject: [PATCH 42/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 840414525..dee544b31 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -261,7 +261,7 @@ def concat( def when( self, *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], - **constraints: Any, # noqa: ARG002 + **constraints: Any, ) -> PandasWhen: plx = self.__class__(self._implementation, 
self._backend_version) condition = plx.all_horizontal(*flatten(predicates)) From 0454ac4a0b41d785f8d4e75f30c9d5013e1464ec Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 20:59:18 +0300 Subject: [PATCH 43/64] misc: keep api the same --- narwhals/_pandas_like/namespace.py | 13 ++++++++++++- narwhals/expr.py | 23 +++++++++++++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index dee544b31..48767e80f 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -264,7 +264,18 @@ def when( **constraints: Any, ) -> PandasWhen: plx = self.__class__(self._implementation, self._backend_version) - condition = plx.all_horizontal(*flatten(predicates)) + if predicates: + condition = plx.all_horizontal(*flatten(predicates)) + elif constraints: + import narwhals as nw + + condition = plx.all_horizontal( + *flatten((nw.col(key) == value) for key, value in constraints.items()) + ) + else: + msg = "Must provide either predicates or constraints" + raise ValueError(msg) + return PandasWhen(condition, self._implementation, self._backend_version) diff --git a/narwhals/expr.py b/narwhals/expr.py index 2ef1e5308..3a91faaa8 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -3607,13 +3606,21 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: - def __init__(self, condition: Expr) -> None: - self._condition = condition - self._then_value = None - self._otehrwise_value = None + def __init__( + self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any + ) -> None: + self._predicates = flatten([predicates]) + self._constraints = constraints + + def _extract_predicates(self, plx: Any) -> Any: + return [extract_compliant(plx, v) for v in 
self._predicates] def then(self, value: Any) -> Then: - return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) + return Then( + lambda plx: plx.when( + *self._extract_predicates(plx), **self._constraints + ).then(value) + ) class Then(Expr): @@ -3624,7 +3631,7 @@ def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: """ Start a `when-then-otherwise` expression. Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. @@ -3673,7 +3680,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When │ 3 ┆ 15 ┆ 6 │ └─────┴─────┴────────┘ """ - return When(reduce(lambda a, b: a & b, flatten([predicates]))) + return When(*predicates, **constraints) def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: From 4ad28b7b4966a160b8292ce8f96a2d2ad27cf372 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 25 Jul 2024 21:06:16 +0300 Subject: [PATCH 44/64] test: add test for multiple predicates --- tests/expr_and_series/when_test.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 90df13180..47b19e70b 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -9,7 +9,7 @@ from tests.utils import compare_dicts data = { - "a": [1, 1, 2], + "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [4.1, 5.0, 6.0], "d": [True, False, True], @@ -23,11 +23,11 @@ def test_when(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { - 
"a": [1, 1, 2], + "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [4.1, 5.0, 6.0], "d": [True, False, True], - "a_when": [3, 3, None], + "a_when": [3, None, None], } compare_dicts(result, expected) @@ -39,10 +39,28 @@ def test_when_otherwise(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { - "a": [1, 1, 2], + "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [4.1, 5.0, 6.0], "d": [True, False, True], - "a_when": [3, 3, 6], + "a_when": [3, 6, 6], + } + compare_dicts(result, expected) + + +def test_multiple_conditions(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns( + when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") + ) + expected = { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, None, None], } compare_dicts(result, expected) From 0ded39307f9294a8faf8c70cc260fdd90a898f23 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 29 Jul 2024 03:53:12 +0300 Subject: [PATCH 45/64] misc: make when stable --- narwhals/stable/v1.py | 74 ++++++++++++++++++++++++++++++ tests/expr_and_series/when_test.py | 2 +- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 18f520bdf..0760f0c8f 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -34,6 +34,9 @@ from narwhals.dtypes import UInt64 from narwhals.dtypes import Unknown from narwhals.expr import Expr as NwExpr +from narwhals.expr import Then as NwThen +from narwhals.expr import When as NwWhen +from narwhals.expr import when as nw_when from narwhals.functions import concat from narwhals.functions import show_versions from narwhals.schema import Schema as NwSchema @@ -1391,6 +1394,76 @@ def get_level( return 
nw.get_level(obj) +class When(NwWhen): + @classmethod + def from_when(cls, when: NwWhen) -> Self: + return cls(*when._predicates, **when._constraints) + + def then(self, value: Any) -> Then: + return Then( + lambda plx: plx.when( + *self._extract_predicates(plx), **self._constraints + ).then(value) + ) + + +class Then(NwThen): + def otherwise(self, value: Any) -> Expr: + return _stableify(super().otherwise(value)) + + +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: + """ + Start a `when-then-otherwise` expression. + Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. + Chained when-then operations should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to `True` will be picked. + If none of the conditions are `True`, an optional `.otherwise()` can be appended at the end. If not appended, and none of the conditions are `True`, `None` will be returned. + + Parameters: + predicates + Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name. + constraints + Apply conditions as `col_name = value` keyword arguments that are treated as equality matches, such as `x = 123`. As with the predicates parameter, multiple conditions are implicitly combined using `&`. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]}) + >>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]}) + + We define a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df_any): + ... from narwhals.expr import when + ... + ... return df_any.with_columns( + ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... 
) + + We can then pass either pandas or polars to `func`: + + >>> func(df_pd) + a b a_when + 0 1 5 5 + 1 2 10 5 + 2 3 15 6 + >>> func(df_pl) + shape: (3, 3) + ┌─────┬─────┬────────┐ + │ a ┆ b ┆ a_when │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i32 │ + ╞═════╪═════╪════════╡ + │ 1 ┆ 5 ┆ 5 │ + │ 2 ┆ 10 ┆ 5 │ + │ 3 ┆ 15 ┆ 6 │ + └─────┴─────┴────────┘ + """ + return When.from_when(nw_when(*predicates, **constraints)) + + __all__ = [ "selectors", "concat", @@ -1412,6 +1485,7 @@ def get_level( "mean", "sum", "sum_horizontal", + "when", "DataFrame", "LazyFrame", "Series", diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 47b19e70b..b17d92e61 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -5,7 +5,7 @@ import pytest import narwhals.stable.v1 as nw -from narwhals.expr import when +from narwhals.stable.v1 import when from tests.utils import compare_dicts data = { From 3280a3ca3fad62ed03e16158eadd96b6ee596cc1 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 29 Jul 2024 12:56:46 +0300 Subject: [PATCH 46/64] bug: make stable v1 `Then` a stable expr `Expr` docs: update stable v1 `when` docs to use stable api --- narwhals/stable/v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 0760f0c8f..ee025dc1c 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -1407,7 +1407,7 @@ def then(self, value: Any) -> Then: ) -class Then(NwThen): +class Then(NwThen, Expr): def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) @@ -1428,7 +1428,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When Examples: >>> import pandas as pd >>> import polars as pl - >>> import narwhals as nw + >>> import narwhals.stable.v1 as nw >>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]}) >>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]}) From 
5c6deed0332c56f8aca102b8c82ffa7154a52024 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 29 Jul 2024 13:09:12 +0300 Subject: [PATCH 47/64] bug: fix when constraints pandas implementation --- narwhals/_pandas_like/namespace.py | 4 +--- tests/expr_and_series/when_test.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 48767e80f..c46785ae3 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -267,10 +267,8 @@ def when( if predicates: condition = plx.all_horizontal(*flatten(predicates)) elif constraints: - import narwhals as nw - condition = plx.all_horizontal( - *flatten((nw.col(key) == value) for key, value in constraints.items()) + *flatten([plx.col(key) == value for key, value in constraints.items()]) ) else: msg = "Must provide either predicates or constraints" diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index b17d92e61..adbdbe061 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -64,3 +64,19 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None: "a_when": [3, None, None], } compare_dicts(result, expected) + + +def test_when_constraint(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.with_columns(when(a=1).then(value=3).alias("a_when")) + expected = { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": [4.1, 5.0, 6.0], + "d": [True, False, True], + "a_when": [3, None, None], + } + compare_dicts(result, expected) From 8688491cababc9ee29fc116c8d42b78dcde739de Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 29 Jul 2024 13:27:36 +0300 Subject: [PATCH 48/64] test: stabalise all paths and test error on no arg --- narwhals/_pandas_like/namespace.py | 4 ++-- 
narwhals/stable/v1.py | 10 +++++----- tests/expr_and_series/when_test.py | 9 +++++++++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index f14211ac8..fc1bf8626 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -276,8 +276,8 @@ def when( *flatten([plx.col(key) == value for key, value in constraints.items()]) ) else: - msg = "Must provide either predicates or constraints" - raise ValueError(msg) + msg = "at least one predicate or constraint must be provided" + raise TypeError(msg) return PandasWhen(condition, self._implementation, self._backend_version) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 62a64a427..778152c74 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -1478,14 +1478,14 @@ def from_when(cls, when: NwWhen) -> Self: return cls(*when._predicates, **when._constraints) def then(self, value: Any) -> Then: - return Then( - lambda plx: plx.when( - *self._extract_predicates(plx), **self._constraints - ).then(value) - ) + return Then.from_then(super().then(value)) class Then(NwThen, Expr): + @classmethod + def from_then(cls, then: NwThen) -> Self: + return cls(then._call) + def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index adbdbe061..e8aeb2d3c 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -80,3 +80,12 @@ def test_when_constraint(request: Any, constructor: Any) -> None: "a_when": [3, None, None], } compare_dicts(result, expected) + + +def test_no_arg_when_fail(request: Any, constructor: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + with pytest.raises(TypeError): + df.with_columns(when().then(value=3).alias("a_when")) From 
81039bf59cc016d9c30af17f6a1be75cda6408e1 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 29 Jul 2024 13:32:46 +0300 Subject: [PATCH 49/64] misc: add when to main api --- docs/api-reference/narwhals.md | 1 + narwhals/__init__.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index 16bc6621c..42c1a2e44 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -27,6 +27,7 @@ Here are the top-level functions available in Narwhals. - narwhalify - sum - sum_horizontal + - when - show_versions - to_native show_source: false diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 3ad1468af..c31459447 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -33,6 +33,7 @@ from narwhals.expr import min from narwhals.expr import sum from narwhals.expr import sum_horizontal +from narwhals.expr import when from narwhals.functions import concat from narwhals.functions import from_dict from narwhals.functions import get_level @@ -73,6 +74,7 @@ "mean", "sum", "sum_horizontal", + "when", "DataFrame", "LazyFrame", "Series", From af83c505399d690083817da9223a9a8b9a0d7d53 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 29 Jul 2024 17:11:23 +0300 Subject: [PATCH 50/64] fix: fix when chain --- narwhals/_pandas_like/namespace.py | 67 +++++++++++++++++------------- narwhals/expr.py | 33 +++++++++------ tests/expr_and_series/when_test.py | 4 +- 3 files changed, 59 insertions(+), 45 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 219df3a3d..47342fcc6 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -268,18 +268,29 @@ def when( *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], **constraints: Any, ) -> PandasWhen: - plx = self.__class__(self._implementation, self._backend_version) - if predicates: - condition = plx.all_horizontal(*flatten(predicates)) - 
elif constraints: - condition = plx.all_horizontal( - *flatten([plx.col(key) == value for key, value in constraints.items()]) - ) - else: - msg = "at least one predicate or constraint must be provided" - raise TypeError(msg) + return PandasWhen( + when_processing(self, *predicates, **constraints), + self._implementation, + self._backend_version, + ) + + +def when_processing( + plx: PandasLikeNamespace, + *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], + **constraints: Any, +) -> PandasLikeExpr: + if predicates: + condition = plx.all_horizontal(*flatten(predicates)) + elif constraints: + condition = plx.all_horizontal( + *flatten([plx.col(key) == value for key, value in constraints.items()]) + ) + else: + msg = "at least one predicate or constraint must be provided" + raise TypeError(msg) - return PandasWhen(condition, self._implementation, self._backend_version) + return condition class PandasWhen: @@ -350,10 +361,18 @@ def __init__( self._root_names = root_names self._output_names = output_names - def when(self, condition: PandasLikeExpr) -> PandasChainedWhen: + def when( + self, + *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], + **constraints: Any, + ) -> PandasChainedWhen: return PandasChainedWhen( self._call, # type: ignore[arg-type] - condition, + when_processing( + PandasLikeNamespace(self._implementation, self._backend_version), + *predicates, + **constraints, + ), depth=self._depth + 1, implementation=self._implementation, backend_version=self._backend_version, @@ -456,25 +475,13 @@ def when( *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], **constraints: Any, ) -> PandasChainedWhen: - from narwhals._pandas_like.namespace import PandasLikeNamespace - - plx = PandasLikeNamespace( - implementation=self._implementation, backend_version=self._backend_version - ) - if predicates: - condition = plx.all_horizontal(*flatten(predicates)) - elif constraints: - import narwhals as nw - - condition = plx.all_horizontal( - 
*(nw.col(name) == value for name, value in constraints.items()) - ) - else: - msg = "Must provide at least one predicates or constraints" - raise ValueError(msg) return PandasChainedWhen( self._call, # type: ignore[arg-type] - condition, + when_processing( + PandasLikeNamespace(self._implementation, self._backend_version), + *predicates, + **constraints, + ), depth=self._depth + 1, implementation=self._implementation, backend_version=self._backend_version, diff --git a/narwhals/expr.py b/narwhals/expr.py index fd3ba41c8..919df74fd 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -3644,22 +3643,22 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) +def _extract_predicates(plx: Any, predicates: Iterable[IntoExpr]) -> Any: + return [extract_compliant(plx, v) for v in predicates] + + class When: def __init__( self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any ) -> None: - self._predicates = flatten([predicates]) + self._predicates = predicates self._constraints = constraints - self._then_value = None - self._otehrwise_value = None - - def _extract_predicates(self, plx: Any) -> Any: - return [extract_compliant(plx, v) for v in self._predicates] def then(self, value: Any) -> Then: return Then( lambda plx: plx.when( - *self._extract_predicates(plx), **self._constraints + *_extract_predicates(plx, flatten([self._predicates])), + **self._constraints, ).then(value) ) @@ -3676,18 +3675,26 @@ def when( *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any, ) -> ChainedWhen: - return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) + return ChainedWhen(self, *predicates, **constraints) class ChainedWhen: - def __init__(self, above_then: Then | ChainedThen, condition: Expr) -> None: + def __init__( + self, + above_then: Then | ChainedThen, + 
*predicates: IntoExpr | Iterable[IntoExpr], + **conditions: Any, + ) -> None: self._above_then = above_then - self._condition = condition + self._predicates = predicates + self._conditions = conditions def then(self, value: Any) -> ChainedThen: return ChainedThen( lambda plx: self._above_then._call(plx) - .when(self._condition._call(plx)) + .when( + *_extract_predicates(plx, flatten([self._predicates])), **self._conditions + ) .then(value) ) @@ -3701,7 +3708,7 @@ def when( *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any, ) -> ChainedWhen: - return ChainedWhen(self, reduce(lambda a, b: a & b, flatten([predicates]))) + return ChainedWhen(self, *predicates, **constraints) def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 6652aac2d..5dd3abf3c 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -4,8 +4,8 @@ import pytest -import narwhals.stable.v1 as nw -from narwhals.stable.v1 import when +import narwhals as nw +from narwhals import when from tests.utils import compare_dicts data = { From 1196fab51299a2aadedb45f5c629163ec7e23afb Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 30 Jul 2024 10:33:56 +0300 Subject: [PATCH 51/64] misc: remove constraints --- narwhals/_pandas_like/namespace.py | 7 +------ narwhals/expr.py | 17 ++++------------- narwhals/stable/v1.py | 8 +++----- tests/expr_and_series/when_test.py | 16 ---------------- 4 files changed, 8 insertions(+), 40 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index fc1bf8626..4f9c9fa3e 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -266,17 +266,12 @@ def concat( def when( self, *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], - **constraints: Any, ) -> PandasWhen: plx = self.__class__(self._implementation, 
self._backend_version) if predicates: condition = plx.all_horizontal(*flatten(predicates)) - elif constraints: - condition = plx.all_horizontal( - *flatten([plx.col(key) == value for key, value in constraints.items()]) - ) else: - msg = "at least one predicate or constraint must be provided" + msg = "at least one predicate needs to be provided" raise TypeError(msg) return PandasWhen(condition, self._implementation, self._backend_version) diff --git a/narwhals/expr.py b/narwhals/expr.py index d99c91068..fa8fa2101 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3644,21 +3644,14 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: - def __init__( - self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any - ) -> None: + def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicates = flatten([predicates]) - self._constraints = constraints def _extract_predicates(self, plx: Any) -> Any: return [extract_compliant(plx, v) for v in self._predicates] def then(self, value: Any) -> Then: - return Then( - lambda plx: plx.when( - *self._extract_predicates(plx), **self._constraints - ).then(value) - ) + return Then(lambda plx: plx.when(*self._extract_predicates(plx)).then(value)) class Then(Expr): @@ -3669,7 +3662,7 @@ def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: +def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: """ Start a `when-then-otherwise` expression. Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. @@ -3679,8 +3672,6 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When Parameters: predicates Condition(s) that must be met in order to apply the subsequent statement. 
Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name. - constraints - Apply conditions as `col_name = value` keyword arguments that are treated as equality matches, such as `x = 123`. As with the predicates parameter, multiple conditions are implicitly combined using `&`. Examples: >>> import pandas as pd @@ -3718,7 +3709,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When │ 3 ┆ 15 ┆ 6 │ └─────┴─────┴────────┘ """ - return When(*predicates, **constraints) + return When(*predicates) def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 778152c74..a0ae5e01f 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -1475,7 +1475,7 @@ def get_level( class When(NwWhen): @classmethod def from_when(cls, when: NwWhen) -> Self: - return cls(*when._predicates, **when._constraints) + return cls(*when._predicates) def then(self, value: Any) -> Then: return Then.from_then(super().then(value)) @@ -1490,7 +1490,7 @@ def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: +def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: """ Start a `when-then-otherwise` expression. Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. @@ -1500,8 +1500,6 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When Parameters: predicates Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name. 
- constraints - Apply conditions as `col_name = value` keyword arguments that are treated as equality matches, such as `x = 123`. As with the predicates parameter, multiple conditions are implicitly combined using `&`. Examples: >>> import pandas as pd @@ -1539,7 +1537,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When │ 3 ┆ 15 ┆ 6 │ └─────┴─────┴────────┘ """ - return When.from_when(nw_when(*predicates, **constraints)) + return When.from_when(nw_when(*predicates)) def from_dict( diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index e8aeb2d3c..8d8d554b4 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -66,22 +66,6 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None: compare_dicts(result, expected) -def test_when_constraint(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor): - request.applymarker(pytest.mark.xfail) - - df = nw.from_native(constructor(data)) - result = df.with_columns(when(a=1).then(value=3).alias("a_when")) - expected = { - "a": [1, 2, 3], - "b": ["a", "b", "c"], - "c": [4.1, 5.0, 6.0], - "d": [True, False, True], - "a_when": [3, None, None], - } - compare_dicts(result, expected) - - def test_no_arg_when_fail(request: Any, constructor: Any) -> None: if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) From 899110335b7fdfa10e062e117b2a74eb7700fb12 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 30 Jul 2024 10:46:14 +0300 Subject: [PATCH 52/64] docs: remove wrong import --- narwhals/stable/v1.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index a0ae5e01f..572760478 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -1512,10 +1512,8 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: >>> @nw.narwhalify ... def func(df_any): - ... 
from narwhals.expr import when - ... ... return df_any.with_columns( - ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") ... ) We can then pass either pandas or polars to `func`: From beba175fd3a371f9bcfbcc73a060b9ffaaa8a1d2 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 30 Jul 2024 10:47:07 +0300 Subject: [PATCH 53/64] docs: remove wrong import in stable --- narwhals/stable/v1.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index a0ae5e01f..572760478 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -1512,10 +1512,8 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: >>> @nw.narwhalify ... def func(df_any): - ... from narwhals.expr import when - ... ... return df_any.with_columns( - ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") ... ) We can then pass either pandas or polars to `func`: From 45684d4d76b322816f3cc345dd1306284a390918 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 30 Jul 2024 10:52:37 +0300 Subject: [PATCH 54/64] docs: remove wrong import in main docstring --- narwhals/expr.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index fa8fa2101..22473a86b 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3684,10 +3684,8 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: >>> @nw.narwhalify ... def func(df_any): - ... from narwhals.expr import when - ... ... return df_any.with_columns( - ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") ... 
) We can then pass either pandas or polars to `func`: From 10245c4552a07e235d74a9dcb95080365e18d319 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 30 Jul 2024 11:47:01 +0300 Subject: [PATCH 55/64] misc: make when the chaining stable --- narwhals/stable/v1.py | 72 +++++++++++++++++++++++++++--- tests/expr_and_series/when_test.py | 17 +++---- 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 572760478..aac1ba6b3 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -33,6 +33,8 @@ from narwhals.dtypes import UInt32 from narwhals.dtypes import UInt64 from narwhals.dtypes import Unknown +from narwhals.expr import ChainedThen as NwChainedThen +from narwhals.expr import ChainedWhen as NwChainedWhen from narwhals.expr import Expr as NwExpr from narwhals.expr import Then as NwThen from narwhals.expr import When as NwWhen @@ -479,12 +481,38 @@ def _stableify(obj: NwSeries) -> Series: ... @overload def _stableify(obj: NwExpr) -> Expr: ... @overload +def _stableify(when_then: NwWhen) -> When: ... +@overload +def _stableify(when_then: NwChainedWhen) -> ChainedWhen: ... +@overload +def _stableify(when_then: NwThen) -> Then: ... +@overload +def _stableify(when_then: NwChainedThen) -> ChainedThen: ... +@overload def _stableify(obj: Any) -> Any: ... 
def _stableify( - obj: NwDataFrame[IntoFrameT] | NwLazyFrame[IntoFrameT] | NwSeries | NwExpr | Any, -) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series | Expr | Any: + obj: NwDataFrame[IntoFrameT] + | NwLazyFrame[IntoFrameT] + | NwSeries + | NwExpr + | NwWhen + | NwChainedWhen + | NwThen + | NwChainedThen + | Any, +) -> ( + DataFrame[IntoFrameT] + | LazyFrame[IntoFrameT] + | Series + | Expr + | When + | ChainedWhen + | Then + | ChainedThen + | Any +): if isinstance(obj, NwDataFrame): return DataFrame( obj._compliant_frame, @@ -500,6 +528,14 @@ def _stableify( obj._compliant_series, level=obj._level, ) + elif isinstance(obj, NwChainedWhen): + return ChainedWhen.from_base(obj) + if isinstance(obj, NwWhen): + return When.from_base(obj) + elif isinstance(obj, NwChainedThen): + return ChainedThen.from_base(obj) + elif isinstance(obj, NwThen): + return Then.from_base(obj) if isinstance(obj, NwExpr): return Expr(obj._call) return obj @@ -1474,18 +1510,42 @@ def get_level( class When(NwWhen): @classmethod - def from_when(cls, when: NwWhen) -> Self: + def from_base(cls, when: NwWhen) -> Self: return cls(*when._predicates) def then(self, value: Any) -> Then: - return Then.from_then(super().then(value)) + return _stableify(super().then(value)) class Then(NwThen, Expr): @classmethod - def from_then(cls, then: NwThen) -> Self: + def from_base(cls, then: NwThen) -> Self: return cls(then._call) + def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: + return _stableify(super().when(*predicates)) + + def otherwise(self, value: Any) -> Expr: + return _stableify(super().otherwise(value)) + + +class ChainedWhen(NwChainedWhen): + @classmethod + def from_base(cls, chained_when: NwChainedWhen) -> Self: + return cls(_stableify(chained_when._above_then), *chained_when._predicates) + + def then(self, value: Any) -> ChainedThen: + return _stableify(super().then(value)) + + +class ChainedThen(NwChainedThen, Expr): + @classmethod + def from_base(cls, 
chained_then: NwChainedThen) -> Self: + return cls(chained_then._call) + + def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: + return _stableify(super().when(*predicates)) + def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) @@ -1535,7 +1595,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: │ 3 ┆ 15 ┆ 6 │ └─────┴─────┴────────┘ """ - return When.from_when(nw_when(*predicates)) + return _stableify(nw_when(*predicates)) def from_dict( diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 1ecbda210..a18095743 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -4,8 +4,7 @@ import pytest -import narwhals as nw -from narwhals import when +import narwhals.stable.v1 as nw from tests.utils import compare_dicts data = { @@ -21,7 +20,7 @@ def test_when(request: Any, constructor: Any) -> None: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when")) + result = df.with_columns(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { "a": [1, 2, 3, 4, 5], "b": ["a", "b", "c", "d", "e"], @@ -37,7 +36,9 @@ def test_when_otherwise(request: Any, constructor: Any) -> None: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) + result = df.with_columns( + nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when") + ) expected = { "a": [1, 2, 3, 4, 5], "b": ["a", "b", "c", "d", "e"], @@ -54,7 +55,7 @@ def test_chained_when(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns( - when(nw.col("a") == 1) + nw.when(nw.col("a") == 1) .then(3) .when(nw.col("a") == 2) .then(5) @@ -76,7 +77,7 @@ def test_when_with_multiple_conditions(request: Any, constructor: Any) 
-> None: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.with_columns( - when(nw.col("a") == 1) + nw.when(nw.col("a") == 1) .then(3) .when(nw.col("a") == 2) .then(5) @@ -101,7 +102,7 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) result = df.with_columns( - when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") + nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") ) expected = { "a": [1, 2, 3, 4, 5], @@ -119,4 +120,4 @@ def test_no_arg_when_fail(request: Any, constructor: Any) -> None: df = nw.from_native(constructor(data)) with pytest.raises(TypeError): - df.with_columns(when().then(value=3).alias("a_when")) + df.with_columns(nw.when().then(value=3).alias("a_when")) From 09bea00cbc5f70ca6a80b5a20c990677baf2e85c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:02:41 +0300 Subject: [PATCH 56/64] feat: add when then chaining back --- narwhals/_pandas_like/namespace.py | 260 +++++++++++++++++------------ narwhals/expr.py | 60 +++---- tests/expr_and_series/when_test.py | 127 ++++++++------ 3 files changed, 269 insertions(+), 178 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 753b49f69..dae72573c 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -291,6 +291,8 @@ def __init__( self._then_value = then_value self._otherwise_value = otherwise_value + self._already_set = self._condition + def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._expression_parsing import parse_into_expr from narwhals._pandas_like.namespace import PandasLikeNamespace @@ -372,7 +374,7 @@ def __init__( self._root_names = root_names self._output_names = output_names - def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr: + def otherwise(self, value: PandasLikeExpr | PandasLikeSeries 
| Any) -> PandasThen: # type ignore because we are setting the `_call` attribute to a # callable object of type `PandasWhen`, base class has the attribute as # only a `Callable` @@ -380,106 +382,158 @@ def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLik self._function_name = "whenotherwise" return self + def when(self, *predicates: IntoPandasLikeExpr) -> PandasChainedWhen: + plx = PandasLikeNamespace(self._implementation, self._backend_version) + if predicates: + condition = plx.all_horizontal(*predicates) + else: + msg = "at least one predicate needs to be provided" + raise TypeError(msg) + return PandasChainedWhen( + self, condition, self._depth + 1, self._implementation, self._backend_version + ) + + +class PandasChainedWhen: + def __init__( + self, + above_then: PandasThen | PandasChainedThen, + condition: PandasLikeExpr, + depth: int, + implementation: Implementation, + backend_version: tuple[int, ...], + then_value: Any = None, + otherise_value: Any = None, + ) -> None: + self._implementation = implementation + self._depth = depth + self._backend_version = backend_version + self._condition = condition + self._above_then = above_then + self._then_value = then_value + self._otherwise_value = otherise_value + + # TODO @aivanoved: this is way slow as during computation time this takes + # quadratic time need to improve this to linear time + self._above_already_set = self._above_then._call._already_set # type: ignore[attr-defined] + self._already_set = self._above_already_set | self._condition -# class PandasChainedWhen: -# def __init__( -# self, -# above_when: PandasWhen | PandasChainedWhen, -# condition: PandasLikeExpr, -# depth: int, -# implementation: Implementation, -# backend_version: tuple[int, ...], -# then_value: Any = None, -# otherise_value: Any = None, -# ) -> None: -# self._implementation = implementation -# self._depth = depth -# self._backend_version = backend_version -# self._condition = condition -# self._above_when = 
above_when -# self._then_value = then_value -# self._otherwise_value = otherise_value -# -# # TODO @aivanoved: this is way slow as during computation time this takes -# # quadratic time need to improve this to linear time -# self._condition = self._condition & (~self._above_when._already_set) # type: ignore[has-type] -# self._already_set = self._above_when._already_set | self._condition # type: ignore[has-type] -# -# def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: -# from narwhals._pandas_like.namespace import PandasLikeNamespace -# -# plx = PandasLikeNamespace( -# implementation=self._implementation, backend_version=self._backend_version -# ) -# -# set_then = self._condition._call(df)[0] -# already_set = self._already_set._call(df)[0] -# -# value_series = plx._create_broadcast_series_from_scalar( -# self._then_value, set_then -# ) -# otherwise_series = plx._create_broadcast_series_from_scalar( -# self._otherwise_value, set_then -# ) -# -# above_result = self._above_when(df)[0] -# -# result = value_series.zip_with(set_then, above_result).zip_with( -# already_set, otherwise_series -# ) -# -# return [result] -# -# def then(self, value: Any) -> PandasChainedThen: -# self._then_value = value -# return PandasChainedThen( -# self, -# depth=self._depth, -# implementation=self._implementation, -# function_name="chainedwhen", -# root_names=None, -# output_names=None, -# backend_version=self._backend_version, -# ) -# -# -# class PandasChainedThen(PandasLikeExpr): -# def __init__( -# self, -# call: PandasChainedWhen, -# *, -# depth: int, -# function_name: str, -# root_names: list[str] | None, -# output_names: list[str] | None, -# implementation: Implementation, -# backend_version: tuple[int, ...], -# ) -> None: -# self._implementation = implementation -# self._backend_version = backend_version -# -# self._call = call -# self._depth = depth -# self._function_name = function_name -# self._root_names = root_names -# self._output_names = output_names -# -# 
def when( -# self, -# *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], -# ) -> PandasChainedWhen: -# return PandasChainedWhen( -# self._call, # type: ignore[arg-type] -# when_processing( -# PandasLikeNamespace(self._implementation, self._backend_version), -# *predicates, -# ), -# depth=self._depth + 1, -# implementation=self._implementation, -# backend_version=self._backend_version, -# ) -# -# def otherwise(self, value: Any) -> PandasLikeExpr: -# self._call._otherwise_value = value # type: ignore[attr-defined] -# self._function_name = "chainedwhenotherwise" -# return self + def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + from narwhals._expression_parsing import parse_into_expr + from narwhals._pandas_like.namespace import PandasLikeNamespace + + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self._backend_version + ) + + condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type] + try: + value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type] + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + value_series = condition.__class__._from_iterable( # type: ignore[call-arg] + [self._then_value] * len(condition), + name="literal", + index=condition._native_series.index, + implementation=self._implementation, + backend_version=self._backend_version, + ) + value_series = cast(PandasLikeSeries, value_series) + + set_then = condition + set_then_native = set_then._native_series + above_already_set = parse_into_expr(self._above_already_set, namespace=plx)._call( + df # type: ignore[arg-type] + )[0] + + value_series_native = value_series._native_series + + above_result = self._above_then._call(df)[0] + above_result_native = above_result._native_series + set_then_native = set_then._native_series + above_already_set_native = above_already_set._native_series + if 
self._otherwise_value is None: + return [ + above_result._from_native_series( + value_series_native.where( + ~above_already_set_native & set_then_native, above_result_native + ) + ) + ] + + try: + otherwise_series = parse_into_expr( + self._otherwise_value, namespace=plx + )._call(df)[0] # type: ignore[arg-type] + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + otherwise_series = condition.__class__._from_iterable( # type: ignore[call-arg] + [self._otherwise_value] * len(condition), + name="literal", + index=condition._native_series.index, + implementation=self._implementation, + backend_version=self._backend_version, + ) + otherwise_series = cast(PandasLikeSeries, otherwise_series) + return [ + above_result.zip_with( + above_already_set, value_series.zip_with(set_then, otherwise_series) + ) + ] + + def then(self, value: Any) -> PandasChainedThen: + self._then_value = value + return PandasChainedThen( + self, + depth=self._depth, + implementation=self._implementation, + function_name="chainedwhen", + root_names=None, + output_names=None, + backend_version=self._backend_version, + ) + + +class PandasChainedThen(PandasLikeExpr): + def __init__( + self, + call: PandasChainedWhen, + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + implementation: Implementation, + backend_version: tuple[int, ...], + ) -> None: + self._implementation = implementation + self._backend_version = backend_version + + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + + def when( + self, + *predicates: IntoPandasLikeExpr, + ) -> PandasChainedWhen: + plx = PandasLikeNamespace(self._implementation, self._backend_version) + if predicates: + condition = plx.all_horizontal(*predicates) + else: + msg = "at least one predicate needs to be provided" + raise TypeError(msg) + return 
PandasChainedWhen( + self, + condition, + depth=self._depth + 1, + implementation=self._implementation, + backend_version=self._backend_version, + ) + + def otherwise(self, value: Any) -> PandasChainedThen: + self._call._otherwise_value = value # type: ignore[attr-defined] + self._function_name = "chainedwhenotherwise" + return self diff --git a/narwhals/expr.py b/narwhals/expr.py index 5c5ff7d2e..a8407915a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3995,16 +3995,17 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) +def _extract_predicates(plx: Any, predicates: IntoExpr | Iterable[IntoExpr]) -> Any: + return [extract_compliant(plx, v) for v in flatten([predicates])] + + class When: def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicates = flatten([predicates]) - def _extract_predicates(self, plx: Any) -> Any: - return [extract_compliant(plx, v) for v in self._predicates] - def then(self, value: Any) -> Then: return Then( - lambda plx: plx.when(*self._extract_predicates(plx)).then( + lambda plx: plx.when(*_extract_predicates(plx, self._predicates)).then( extract_compliant(plx, value) ) ) @@ -4017,36 +4018,39 @@ def __init__(self, call: Callable[[Any], Any]) -> None: def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value))) + def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: + return ChainedWhen(self, *predicates) -# class ChainedWhen: -# def __init__( -# self, -# above_then: Then | ChainedThen, -# *predicates: IntoExpr | Iterable[IntoExpr], -# ) -> None: -# self._above_then = above_then -# self._predicates = predicates -# def then(self, value: Any) -> ChainedThen: -# return ChainedThen( -# lambda plx: self._above_then._call(plx) -# .when(*_extract_predicates(plx, flatten([self._predicates]))) -# .then(value) -# ) +class ChainedWhen: + def __init__( + self, + above_then: Then | ChainedThen, + *predicates: 
IntoExpr | Iterable[IntoExpr], + ) -> None: + self._above_then = above_then + self._predicates = flatten([predicates]) + def then(self, value: Any) -> ChainedThen: + return ChainedThen( + lambda plx: self._above_then._call(plx) + .when(*_extract_predicates(plx, self._predicates)) + .then(value) + ) -# class ChainedThen(Expr): -# def __init__(self, call: Callable[[Any], Any]) -> None: -# self._call = call -# def when( -# self, -# *predicates: IntoExpr | Iterable[IntoExpr], -# ) -> ChainedWhen: -# return ChainedWhen(self, *predicates) +class ChainedThen(Expr): + def __init__(self, call: Callable[[Any], Any]) -> None: + self._call = call -# def otherwise(self, value: Any) -> Expr: -# return Expr(lambda plx: self._call(plx).otherwise(value)) + def when( + self, + *predicates: IntoExpr | Iterable[IntoExpr], + ) -> ChainedWhen: + return ChainedWhen(self, *predicates) + + def otherwise(self, value: Any) -> Expr: + return Expr(lambda plx: self._call(plx).otherwise(value)) def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 50d69f5f5..0e3049440 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -16,6 +16,12 @@ "e": [7.0, 2.0, 1.1], } +large_data = { + "a": [1, 2, 3, 4, 5, 6], + "b": ["a", "b", "c", "d", "e", "f"], + "c": [True, False, True, False, True, False], +} + def test_when(constructor: Any) -> None: df = nw.from_native(constructor(data)) @@ -136,53 +142,80 @@ def test_when_then_otherwise_into_expr(request: Any, constructor: Any) -> None: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e")) - expected = {"c": [7, 5, 6]} + result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e").alias("a_when")) + expected = {"a_when": [7, 5, 6]} + compare_dicts(result, expected) + + +def test_chained_when(request: Any, constructor: Any) 
-> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + nw.when(nw.col("a") == 1).then(3).when(nw.col("a") == 2).then(5).alias("a_when"), + ) + expected = { + "a_when": [3, 5, np.nan], + } + compare_dicts(result, expected) + + +def test_chained_when_otherewise(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + nw.when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .otherwise(7) + .alias("a_when"), + ) + expected = { + "a_when": [3, 5, 7], + } compare_dicts(result, expected) -# def test_chained_when(request: Any, constructor: Any) -> None: -# if "pyarrow_table" in str(constructor): -# request.applymarker(pytest.mark.xfail) - -# df = nw.from_native(constructor(data)) -# result = df.with_columns( -# nw.when(nw.col("a") == 1) -# .then(3) -# .when(nw.col("a") == 2) -# .then(5) -# .otherwise(7) -# .alias("a_when"), -# ) -# expected = { -# "a": [1, 2, 3, 4, 5], -# "b": ["a", "b", "c", "d", "e"], -# "c": [4.1, 5.0, 6.0, 7.0, 8.0], -# "d": [True, False, True, False, True], -# "a_when": [3, 5, 7, 7, 7], -# } -# compare_dicts(result, expected) - - -# def test_when_with_multiple_conditions(request: Any, constructor: Any) -> None: -# if "pyarrow_table" in str(constructor): -# request.applymarker(pytest.mark.xfail) -# df = nw.from_native(constructor(data)) -# result = df.with_columns( -# nw.when(nw.col("a") == 1) -# .then(3) -# .when(nw.col("a") == 2) -# .then(5) -# .when(nw.col("a") == 3) -# .then(7) -# .otherwise(9) -# .alias("a_when"), -# ) -# expected = { -# "a": [1, 2, 3, 4, 5], -# "b": ["a", "b", "c", "d", "e"], -# "c": [4.1, 5.0, 6.0, 7.0, 8.0], -# "d": [True, False, True, False, True], -# "a_when": [3, 5, 7, 9, 9], -# } -# compare_dicts(result, 
expected) +def test_multi_chained_when(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(large_data)) + result = df.select( + nw.when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .when(nw.col("a") == 3) + .then(7) + .alias("a_when"), + ) + expected = { + "a_when": [3, 5, 7, np.nan, np.nan, np.nan], + } + compare_dicts(result, expected) + + +def test_multi_chained_when_otherewise(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(large_data)) + result = df.select( + nw.when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .when(nw.col("a") == 3) + .then(7) + .otherwise(9) + .alias("a_when"), + ) + expected = { + "a_when": [3, 5, 7, 9, 9, 9], + } + compare_dicts(result, expected) From 2c37729220750b149c17f4fce8083831f4a41d4a Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:05:02 +0300 Subject: [PATCH 57/64] misc: fix typo --- tests/expr_and_series/when_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 0e3049440..5b60edfa9 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -200,7 +200,7 @@ def test_multi_chained_when(request: Any, constructor: Any) -> None: compare_dicts(result, expected) -def test_multi_chained_when_otherewise(request: Any, constructor: Any) -> None: +def test_multi_chained_when_otherwise(request: Any, constructor: Any) -> None: if "dask" in str(constructor) or "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) From 64cde0b65d8a55be03e7ef1f23530f4e87e52c0c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:06:55 +0300 
Subject: [PATCH 58/64] misc: remove unnecessary file --- docs/api-completeness.md | 221 --------------------------------------- 1 file changed, 221 deletions(-) delete mode 100644 docs/api-completeness.md diff --git a/docs/api-completeness.md b/docs/api-completeness.md deleted file mode 100644 index f24ccc4b9..000000000 --- a/docs/api-completeness.md +++ /dev/null @@ -1,221 +0,0 @@ -# API Completeness - -Narwhals has two different level of support for libraries: "full" and "interchange". - -Libraries for which we have full support we intend to support the whole Narwhals API, however this is a work in progress. - -In the following table it is possible to check which method is implemented for which backend. - -!!! info - - - "pandas-like" means pandas, cuDF and Modin - - Polars supports all the methods (by design) - -| Class | Method | pandas-like | arrow | -|-------------------------|--------------------|--------------------|--------------------| -| DataFrame | clone | :white_check_mark: | :white_check_mark: | -| DataFrame | collect_schema | :white_check_mark: | :white_check_mark: | -| DataFrame | columns | :white_check_mark: | :white_check_mark: | -| DataFrame | drop | :white_check_mark: | :white_check_mark: | -| DataFrame | drop_nulls | :white_check_mark: | :white_check_mark: | -| DataFrame | filter | :white_check_mark: | :white_check_mark: | -| DataFrame | gather_every | :white_check_mark: | :white_check_mark: | -| DataFrame | get_column | :white_check_mark: | :white_check_mark: | -| DataFrame | group_by | :white_check_mark: | :white_check_mark: | -| DataFrame | head | :white_check_mark: | :white_check_mark: | -| DataFrame | is_duplicated | :white_check_mark: | :white_check_mark: | -| DataFrame | is_empty | :white_check_mark: | :white_check_mark: | -| DataFrame | is_unique | :white_check_mark: | :white_check_mark: | -| DataFrame | item | :white_check_mark: | :white_check_mark: | -| DataFrame | iter_rows | :white_check_mark: | :white_check_mark: | -| DataFrame | 
join | :white_check_mark: | :white_check_mark: | -| DataFrame | lazy | :white_check_mark: | :white_check_mark: | -| DataFrame | null_count | :white_check_mark: | :white_check_mark: | -| DataFrame | pipe | :x: | :x: | -| DataFrame | rename | :white_check_mark: | :white_check_mark: | -| DataFrame | rows | :white_check_mark: | :white_check_mark: | -| DataFrame | schema | :white_check_mark: | :white_check_mark: | -| DataFrame | select | :white_check_mark: | :white_check_mark: | -| DataFrame | shape | :white_check_mark: | :white_check_mark: | -| DataFrame | sort | :white_check_mark: | :white_check_mark: | -| DataFrame | tail | :white_check_mark: | :white_check_mark: | -| DataFrame | to_dict | :white_check_mark: | :white_check_mark: | -| DataFrame | to_numpy | :white_check_mark: | :white_check_mark: | -| DataFrame | to_pandas | :white_check_mark: | :white_check_mark: | -| DataFrame | unique | :white_check_mark: | :white_check_mark: | -| DataFrame | with_columns | :white_check_mark: | :white_check_mark: | -| DataFrame | with_row_index | :white_check_mark: | :white_check_mark: | -| DataFrame | write_parquet | :white_check_mark: | :white_check_mark: | -| Expr | abs | :white_check_mark: | :white_check_mark: | -| Expr | alias | :white_check_mark: | :white_check_mark: | -| Expr | all | :white_check_mark: | :white_check_mark: | -| Expr | any | :white_check_mark: | :white_check_mark: | -| Expr | arg_true | :white_check_mark: | :white_check_mark: | -| Expr | cast | :white_check_mark: | :white_check_mark: | -| Expr | cat | :white_check_mark: | :white_check_mark: | -| Expr | clip | :white_check_mark: | :x: | -| Expr | count | :white_check_mark: | :white_check_mark: | -| Expr | cum_sum | :white_check_mark: | :white_check_mark: | -| Expr | diff | :white_check_mark: | :white_check_mark: | -| Expr | drop_nulls | :white_check_mark: | :white_check_mark: | -| Expr | dt | :white_check_mark: | :white_check_mark: | -| Expr | fill_null | :white_check_mark: | :white_check_mark: | -| Expr | 
filter | :white_check_mark: | :white_check_mark: | -| Expr | gather_every | :white_check_mark: | :white_check_mark: | -| Expr | head | :white_check_mark: | :white_check_mark: | -| Expr | is_between | :white_check_mark: | :white_check_mark: | -| Expr | is_duplicated | :white_check_mark: | :white_check_mark: | -| Expr | is_first_distinct | :white_check_mark: | :white_check_mark: | -| Expr | is_in | :white_check_mark: | :white_check_mark: | -| Expr | is_last_distinct | :white_check_mark: | :white_check_mark: | -| Expr | is_null | :white_check_mark: | :white_check_mark: | -| Expr | is_unique | :white_check_mark: | :white_check_mark: | -| Expr | len | :white_check_mark: | :white_check_mark: | -| Expr | max | :white_check_mark: | :white_check_mark: | -| Expr | mean | :white_check_mark: | :white_check_mark: | -| Expr | min | :white_check_mark: | :white_check_mark: | -| Expr | n_unique | :white_check_mark: | :white_check_mark: | -| Expr | name | :white_check_mark: | :white_check_mark: | -| Expr | null_count | :white_check_mark: | :white_check_mark: | -| Expr | over | :white_check_mark: | :x: | -| Expr | quantile | :white_check_mark: | :white_check_mark: | -| Expr | round | :white_check_mark: | :x: | -| Expr | sample | :white_check_mark: | :white_check_mark: | -| Expr | shift | :white_check_mark: | :x: | -| Expr | sort | :white_check_mark: | :white_check_mark: | -| Expr | std | :white_check_mark: | :white_check_mark: | -| Expr | str | :white_check_mark: | :white_check_mark: | -| Expr | sum | :white_check_mark: | :white_check_mark: | -| Expr | tail | :white_check_mark: | :white_check_mark: | -| Expr | unique | :white_check_mark: | :white_check_mark: | -| ExprCatNamespace | get_categories | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | day | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | hour | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | microsecond | :white_check_mark: | :white_check_mark: | -| 
ExprDateTimeNamespace | millisecond | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | minute | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | month | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | nanosecond | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | ordinal_day | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | second | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | to_string | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | total_microseconds | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | total_milliseconds | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | total_minutes | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | total_nanoseconds | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | total_seconds | :white_check_mark: | :white_check_mark: | -| ExprDateTimeNamespace | year | :white_check_mark: | :white_check_mark: | -| ExprNameNamespace | keep | :white_check_mark: | :white_check_mark: | -| ExprNameNamespace | map | :white_check_mark: | :white_check_mark: | -| ExprNameNamespace | prefix | :white_check_mark: | :white_check_mark: | -| ExprNameNamespace | suffix | :white_check_mark: | :white_check_mark: | -| ExprNameNamespace | to_lowercase | :white_check_mark: | :white_check_mark: | -| ExprNameNamespace | to_uppercase | :white_check_mark: | :white_check_mark: | -| ExprStringNamespace | contains | :white_check_mark: | :white_check_mark: | -| ExprStringNamespace | ends_with | :white_check_mark: | :white_check_mark: | -| ExprStringNamespace | head | :x: | :x: | -| ExprStringNamespace | slice | :white_check_mark: | :white_check_mark: | -| ExprStringNamespace | starts_with | :white_check_mark: | :white_check_mark: | -| ExprStringNamespace | tail | :x: | :x: | -| ExprStringNamespace | to_datetime | :white_check_mark: | 
:white_check_mark: | -| ExprStringNamespace | to_lowercase | :white_check_mark: | :white_check_mark: | -| ExprStringNamespace | to_uppercase | :white_check_mark: | :white_check_mark: | -| LazyFrame | clone | :white_check_mark: | :white_check_mark: | -| LazyFrame | collect | :white_check_mark: | :white_check_mark: | -| LazyFrame | collect_schema | :white_check_mark: | :white_check_mark: | -| LazyFrame | columns | :white_check_mark: | :white_check_mark: | -| LazyFrame | drop | :white_check_mark: | :white_check_mark: | -| LazyFrame | drop_nulls | :white_check_mark: | :white_check_mark: | -| LazyFrame | filter | :white_check_mark: | :white_check_mark: | -| LazyFrame | gather_every | :white_check_mark: | :white_check_mark: | -| LazyFrame | group_by | :white_check_mark: | :white_check_mark: | -| LazyFrame | head | :white_check_mark: | :white_check_mark: | -| LazyFrame | join | :white_check_mark: | :white_check_mark: | -| LazyFrame | lazy | :white_check_mark: | :white_check_mark: | -| LazyFrame | pipe | :x: | :x: | -| LazyFrame | rename | :white_check_mark: | :white_check_mark: | -| LazyFrame | schema | :white_check_mark: | :white_check_mark: | -| LazyFrame | select | :white_check_mark: | :white_check_mark: | -| LazyFrame | sort | :white_check_mark: | :white_check_mark: | -| LazyFrame | tail | :white_check_mark: | :white_check_mark: | -| LazyFrame | unique | :white_check_mark: | :white_check_mark: | -| LazyFrame | with_columns | :white_check_mark: | :white_check_mark: | -| LazyFrame | with_row_index | :white_check_mark: | :white_check_mark: | -| Series | abs | :white_check_mark: | :white_check_mark: | -| Series | alias | :white_check_mark: | :white_check_mark: | -| Series | all | :white_check_mark: | :white_check_mark: | -| Series | any | :white_check_mark: | :white_check_mark: | -| Series | arg_true | :white_check_mark: | :white_check_mark: | -| Series | cast | :white_check_mark: | :white_check_mark: | -| Series | cat | :white_check_mark: | :white_check_mark: | -| Series 
| clip | :white_check_mark: | :x: | -| Series | count | :white_check_mark: | :white_check_mark: | -| Series | cum_sum | :white_check_mark: | :white_check_mark: | -| Series | diff | :white_check_mark: | :white_check_mark: | -| Series | drop_nulls | :white_check_mark: | :white_check_mark: | -| Series | dt | :white_check_mark: | :white_check_mark: | -| Series | dtype | :white_check_mark: | :white_check_mark: | -| Series | fill_null | :white_check_mark: | :white_check_mark: | -| Series | filter | :white_check_mark: | :white_check_mark: | -| Series | gather_every | :white_check_mark: | :white_check_mark: | -| Series | head | :white_check_mark: | :white_check_mark: | -| Series | is_between | :white_check_mark: | :white_check_mark: | -| Series | is_duplicated | :white_check_mark: | :white_check_mark: | -| Series | is_empty | :white_check_mark: | :white_check_mark: | -| Series | is_first_distinct | :white_check_mark: | :white_check_mark: | -| Series | is_in | :white_check_mark: | :white_check_mark: | -| Series | is_last_distinct | :white_check_mark: | :white_check_mark: | -| Series | is_null | :white_check_mark: | :white_check_mark: | -| Series | is_sorted | :white_check_mark: | :white_check_mark: | -| Series | is_unique | :white_check_mark: | :white_check_mark: | -| Series | item | :white_check_mark: | :white_check_mark: | -| Series | len | :white_check_mark: | :white_check_mark: | -| Series | max | :white_check_mark: | :white_check_mark: | -| Series | mean | :white_check_mark: | :white_check_mark: | -| Series | min | :white_check_mark: | :white_check_mark: | -| Series | n_unique | :white_check_mark: | :white_check_mark: | -| Series | name | :white_check_mark: | :white_check_mark: | -| Series | null_count | :white_check_mark: | :white_check_mark: | -| Series | quantile | :white_check_mark: | :white_check_mark: | -| Series | round | :white_check_mark: | :x: | -| Series | sample | :white_check_mark: | :white_check_mark: | -| Series | shape | :white_check_mark: | 
:white_check_mark: | -| Series | shift | :white_check_mark: | :x: | -| Series | sort | :white_check_mark: | :white_check_mark: | -| Series | std | :white_check_mark: | :white_check_mark: | -| Series | str | :white_check_mark: | :white_check_mark: | -| Series | sum | :white_check_mark: | :white_check_mark: | -| Series | tail | :white_check_mark: | :white_check_mark: | -| Series | to_dummies | :white_check_mark: | :white_check_mark: | -| Series | to_frame | :white_check_mark: | :white_check_mark: | -| Series | to_list | :white_check_mark: | :white_check_mark: | -| Series | to_numpy | :white_check_mark: | :white_check_mark: | -| Series | to_pandas | :white_check_mark: | :white_check_mark: | -| Series | unique | :white_check_mark: | :white_check_mark: | -| Series | value_counts | :white_check_mark: | :white_check_mark: | -| Series | zip_with | :white_check_mark: | :white_check_mark: | -| SeriesCatNamespace | get_categories | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | day | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | hour | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | microsecond | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | millisecond | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | minute | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | month | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | nanosecond | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | ordinal_day | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | second | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | to_string | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | total_microseconds | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | total_milliseconds | :white_check_mark: | :white_check_mark: | -| 
SeriesDateTimeNamespace | total_minutes | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | total_nanoseconds | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | total_seconds | :white_check_mark: | :white_check_mark: | -| SeriesDateTimeNamespace | year | :white_check_mark: | :white_check_mark: | -| SeriesStringNamespace | contains | :white_check_mark: | :white_check_mark: | -| SeriesStringNamespace | ends_with | :white_check_mark: | :white_check_mark: | -| SeriesStringNamespace | head | :x: | :x: | -| SeriesStringNamespace | slice | :white_check_mark: | :white_check_mark: | -| SeriesStringNamespace | starts_with | :white_check_mark: | :white_check_mark: | -| SeriesStringNamespace | tail | :x: | :x: | -| SeriesStringNamespace | to_lowercase | :white_check_mark: | :white_check_mark: | -| SeriesStringNamespace | to_uppercase | :white_check_mark: | :white_check_mark: | \ No newline at end of file From 53de101bda6ef3ee7de136a28072f05ed2065c08 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:08:34 +0300 Subject: [PATCH 59/64] misc: remove unused function --- narwhals/_pandas_like/namespace.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index dae72573c..ab2ee8037 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -86,17 +86,6 @@ def _create_series_from_scalar( backend_version=self._backend_version, ) - def _create_broadcast_series_from_scalar( - self, value: Any, series: PandasLikeSeries - ) -> PandasLikeSeries: - return PandasLikeSeries._from_iterable( - [value] * len(series._native_series), - name=series._native_series.name, - index=series._native_series.index, - implementation=self._implementation, - backend_version=self._backend_version, - ) - def _create_expr_from_series(self, series: PandasLikeSeries) -> PandasLikeExpr: return PandasLikeExpr( lambda _df: [series], From 
05d2c5f54490fd93be75cd7485953505ceecd7f7 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:30:04 +0300 Subject: [PATCH 60/64] misc: add some stability --- narwhals/stable/v1.py | 71 ++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 0720980c1..e7dbd1043 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -33,9 +33,8 @@ from narwhals.dtypes import UInt32 from narwhals.dtypes import UInt64 from narwhals.dtypes import Unknown - -# from narwhals.expr import ChainedThen as NwChainedThen -# from narwhals.expr import ChainedWhen as NwChainedWhen +from narwhals.expr import ChainedThen as NwChainedThen +from narwhals.expr import ChainedWhen as NwChainedWhen from narwhals.expr import Expr as NwExpr from narwhals.expr import Then as NwThen from narwhals.expr import When as NwWhen @@ -490,13 +489,9 @@ def _stableify(obj: NwSeries) -> Series: ... @overload def _stableify(obj: NwExpr) -> Expr: ... @overload -def _stableify(when_then: NwWhen) -> When: ... -# @overload -# def _stableify(when_then: NwChainedWhen) -> ChainedWhen: ... +def _stableify(obj: NwWhen) -> When: ... @overload -def _stableify(when_then: NwThen) -> Then: ... -# @overload -# def _stableify(when_then: NwChainedThen) -> ChainedThen: ... +def _stableify(obj: NwChainedWhen) -> ChainedWhen: ... @overload def _stableify(obj: Any) -> Any: ... 
@@ -507,9 +502,9 @@ def _stableify( | NwSeries | NwExpr | NwWhen - # | NwChainedWhen + | NwChainedWhen | NwThen - # | NwChainedThen + | NwChainedThen | Any, ) -> ( DataFrame[IntoFrameT] @@ -517,9 +512,9 @@ def _stableify( | Series | Expr | When - # | ChainedWhen + | ChainedWhen | Then - # | ChainedThen + | ChainedThen | Any ): if isinstance(obj, NwDataFrame): @@ -537,12 +532,12 @@ def _stableify( obj._compliant_series, level=obj._level, ) - # elif isinstance(obj, NwChainedWhen): - # return ChainedWhen.from_base(obj) + elif isinstance(obj, NwChainedWhen): + return ChainedWhen.from_base(obj) if isinstance(obj, NwWhen): return When.from_when(obj) - # elif isinstance(obj, NwChainedThen): - # return ChainedThen.from_base(obj) + elif isinstance(obj, NwChainedThen): + return ChainedThen.from_base(obj) elif isinstance(obj, NwThen): return Then.from_then(obj) if isinstance(obj, NwExpr): @@ -1741,6 +1736,27 @@ def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) +class ChainedWhen(NwChainedWhen): + @classmethod + def from_base(cls, chained_when: NwChainedWhen) -> Self: + return cls(_stableify(chained_when._above_then), *chained_when._predicates) # type: ignore[arg-type] + + def then(self, value: Any) -> ChainedThen: + return _stableify(super().then(value)) # type: ignore[return-value] + + +class ChainedThen(NwChainedThen, Expr): + @classmethod + def from_base(cls, chained_then: NwChainedThen) -> Self: + return cls(chained_then._call) + + def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: + return _stableify(super().when(*predicates)) + + def otherwise(self, value: Any) -> Expr: + return _stableify(super().otherwise(value)) + + def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: """ Start a `when-then-otherwise` expression. 
@@ -1841,27 +1857,6 @@ def new_series( ) -# class ChainedWhen(NwChainedWhen): -# @classmethod -# def from_base(cls, chained_when: NwChainedWhen) -> Self: -# return cls(_stableify(chained_when._above_then), *chained_when._predicates) - -# def then(self, value: Any) -> ChainedThen: -# return _stableify(super().then(value)) - - -# class ChainedThen(NwChainedThen, Expr): -# @classmethod -# def from_base(cls, chained_then: NwChainedThen) -> Self: -# return cls(chained_then._call) - -# def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: -# return _stableify(super().when(*predicates)) - -# def otherwise(self, value: Any) -> Expr: -# return _stableify(super().otherwise(value)) - - def from_dict( data: dict[str, Any], schema: dict[str, DType] | Schema | None = None, From fbb6c26d4d6bf71af78bddb04eb826f7c56e9a0e Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:31:17 +0300 Subject: [PATCH 61/64] misc: stabilify when-then --- narwhals/stable/v1.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index e7dbd1043..dde7ca5fd 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -1735,6 +1735,9 @@ def from_then(cls, then: NwThen) -> Self: def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) + def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: + return _stableify(super().when(*predicates)) + class ChainedWhen(NwChainedWhen): @classmethod From e0819ce6996008b4245dd11730fa51ad8bb967fe Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 16 Sep 2024 12:30:43 +0300 Subject: [PATCH 62/64] tests: increase the test coverage for when-then-otherwise --- narwhals/stable/v1.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 4758e938f..e5282c38b 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -539,11 +539,11 @@ def _stableify( 
elif isinstance(obj, NwChainedWhen): return ChainedWhen.from_base(obj) if isinstance(obj, NwWhen): - return When.from_when(obj) + return When.from_base(obj) elif isinstance(obj, NwChainedThen): return ChainedThen.from_base(obj) elif isinstance(obj, NwThen): - return Then.from_then(obj) + return Then.from_base(obj) if isinstance(obj, NwExpr): return Expr(obj._call) return obj @@ -1724,16 +1724,16 @@ def get_level( class When(NwWhen): @classmethod - def from_when(cls, when: NwWhen) -> Self: + def from_base(cls, when: NwWhen) -> Self: return cls(*when._predicates) def then(self, value: Any) -> Then: - return Then.from_then(super().then(value)) + return Then.from_base(super().then(value)) class Then(NwThen, Expr): @classmethod - def from_then(cls, then: NwThen) -> Self: + def from_base(cls, then: NwThen) -> Self: return cls(then._call) def otherwise(self, value: Any) -> Expr: @@ -1809,7 +1809,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: │ 3 ┆ 15 ┆ 6 │ └─────┴─────┴────────┘ """ - return When.from_when(nw_when(*predicates)) + return _stableify(nw_when(*predicates)) def new_series( From 8b32e5d4d436ae7272ac268b6daa3dbb6256055b Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 16 Sep 2024 12:33:00 +0300 Subject: [PATCH 63/64] tests: actually increase the test coverage of when-then-otherwise --- tests/expr_and_series/when_test.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index ed4a2ccd6..fbfc5f932 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -228,3 +228,30 @@ def test_multi_chained_when_otherwise(request: Any, constructor: Any) -> None: "a_when": [3, 5, 7, 9, 9, 9], } compare_dicts(result, expected) + + +def test_then_when_no_condition(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + 
+ df = nw.from_native(constructor(data)) + + with pytest.raises((TypeError, ValueError)): + df.select(nw.when(nw.col("a") == 1).then(value=3).when().then(value=7)) + + +def test_then_chained_when_no_condition(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + + with pytest.raises((TypeError, ValueError)): + df.select( + nw.when(nw.col("a") == 1) + .then(value=3) + .when(nw.col("a") == 3) + .then(value=7) + .when() + .then(value=9) + ) From 9427b33a6097b28d2a160c522597d03a5ad18855 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov <54221777+aivanoved@users.noreply.github.com> Date: Mon, 16 Sep 2024 12:37:11 +0300 Subject: [PATCH 64/64] Update narwhals/_pandas_like/namespace.py --- narwhals/_pandas_like/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index ab2ee8037..0be930556 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -419,7 +419,7 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: try: value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type] except TypeError: - # `self._otherwise_value` is a scalar and can't be converted to an expression + # `self._then_value` is a scalar and can't be converted to an expression value_series = condition.__class__._from_iterable( # type: ignore[call-arg] [self._then_value] * len(condition), name="literal",