diff --git a/docs/api-reference/expr.md b/docs/api-reference/expr.md index 627e1cb4b..694ae504b 100644 --- a/docs/api-reference/expr.md +++ b/docs/api-reference/expr.md @@ -18,6 +18,7 @@ - cum_sum - diff - drop_nulls + - ewm_mean - fill_null - filter - gather_every diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index 12ca72208..d0cf7875f 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -23,6 +23,7 @@ - diff - drop_nulls - dtype + - ewm_mean - fill_null - filter - gather_every diff --git a/docs/css/code_select.css b/docs/css/code_select.css new file mode 100644 index 000000000..259cf11c5 --- /dev/null +++ b/docs/css/code_select.css @@ -0,0 +1,7 @@ +.highlight .gp, .highlight .go { /* Generic.Prompt, Generic.Output */ + user-select: none; + -webkit-user-select: none; /* Safari */ + -moz-user-select: none; /* Firefox */ + -ms-user-select: none; /* Internet Explorer/Edge */ + color: red; +} \ No newline at end of file diff --git a/docs/javascripts/katex.js b/docs/javascripts/katex.js new file mode 100644 index 000000000..3828300a7 --- /dev/null +++ b/docs/javascripts/katex.js @@ -0,0 +1,10 @@ +document$.subscribe(({ body }) => { + renderMathInElement(body, { + delimiters: [ + { left: "$$", right: "$$", display: true }, + { left: "$", right: "$", display: false }, + { left: "\\(", right: "\\)", display: false }, + { left: "\\[", right: "\\]", display: true } + ], + }) + }) \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index e7268e2b4..0358bd2c1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -118,3 +118,12 @@ markdown_extensions: - pymdownx.emoji: emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_generator: !!python/name:material.extensions.emoji.to_svg +- pymdownx.arithmatex: + generic: true +extra_javascript: + - javascripts/katex.js + - https://unpkg.com/katex@0/dist/katex.min.js + - https://unpkg.com/katex@0/dist/contrib/auto-render.min.js + +extra_css: + - https://unpkg.com/katex@0/dist/katex.min.css \ No newline at end of file diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 912d60295..b70635dc6 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -539,6 +539,20 @@ def round(self, decimals: int) -> Self: returns_scalar=False, ) + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> NoReturn: + msg = "`Expr.ewm_mean` is not supported for the Dask backend" + raise NotImplementedError(msg) + def unique(self) -> NoReturn: # We can't (yet?) allow methods which modify the index msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead." diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 802010977..fa769e790 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -285,6 +285,29 @@ def is_in(self, other: Any) -> Self: def arg_true(self) -> Self: return reuse_series_implementation(self, "arg_true") + def ewm_mean( + self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + return reuse_series_implementation( + self, + "ewm_mean", + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + def filter(self, *predicates: Any) -> Self: plx = self.__narwhals_namespace__() other = plx.all_horizontal(*predicates) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index a1b379260..9cbdfd6af 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -175,6 +175,25 @@ def dtype(self: Self) -> DType: self._native_series, self._dtypes, self._implementation ) + def ewm_mean( + self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> PandasLikeSeries: + ser = self._native_series + mask_na = ser.isna() + result = ser.ewm( + com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls + ).mean() + result[mask_na] = None + return self._from_native_series(result) + def scatter(self, indices: int | Sequence[int], values: Any) -> Self: if isinstance(values, self.__class__): # .copy() is necessary in some pre-2.2 versions of pandas to avoid diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 7f2bb3c7f..6a28bfbc4 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -9,6 +9,7 @@ from narwhals._polars.utils import extract_native from narwhals._polars.utils import narwhals_to_native_dtype from narwhals.utils import Implementation +from narwhals.utils import parse_version if TYPE_CHECKING: import polars as pl @@ -49,6 +50,35 @@ def cast(self, dtype: DType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_expr(expr.cast(dtype_pl)) + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + import polars as pl # ignore-banned-import() + + if parse_version(pl.__version__) <= (0, 20, 31): # pragma: no cover + msg = "`ewm_mean` not implemented for polars older than 0.20.31" + raise NotImplementedError(msg) + expr = self._native_expr + return self._from_native_expr( + expr.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + ) + def map_batches( self, function: Callable[[Any], Self], diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 9013d6ab4..a1251199e 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -225,6 +225,33 @@ def to_dummies( result, backend_version=self._backend_version, dtypes=self._dtypes ) + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + if self._backend_version < (0, 20, 31): # pragma: no cover + msg = "`ewm_mean` not implemented for polars older than 0.20.31" + raise NotImplementedError(msg) + expr = self._native_series + return self._from_native_series( + expr.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + ) + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: if self._backend_version < (0, 20, 6): result = self._native_series.sort(descending=descending) diff --git a/narwhals/expr.py b/narwhals/expr.py index 850ba7835..ddacd4fff 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -418,6 +418,103 @@ def all(self) -> Self: """ return self.__class__(lambda plx: self._call(plx).all()) + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + r"""Compute exponentially-weighted moving average. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$ + span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$ + half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$ + alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$. + adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings + - When `adjust=True` (the default) the EW function is calculated + using weights $w_i = (1 - \alpha)^i$ + - When `adjust=False` the EW function is calculated recursively by + $$ + y_0=x_0 + $$ + $$ + y_t = (1 - \alpha)y_{t - 1} + \alpha x_t + $$ + min_periods: Minimum number of observations in window required to have a value, (otherwise result is null). + ignore_nulls: Ignore missing values when calculating weights. + + - When `ignore_nulls=False` (default), weights are based on absolute + positions. + For example, the weights of $x_0$ and $x_2$ used in + calculating the final weighted average of $[x_0, None, x_2]$ are + $(1-\alpha)^2$ and $1$ if `adjust=True`, and + $(1-\alpha)^2$ and $\alpha$ if `adjust=False`. + - When `ignore_nulls=True`, weights are based + on relative positions. For example, the weights of + $x_0$ and $x_2$ used in calculating the final weighted + average of $[x_0, None, x_2]$ are + $1-\alpha$ and $1$ if `adjust=True`, + and $1-\alpha$ and $\alpha$ if `adjust=False`. + + Returns: + Expr + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data = {"a": [1, 2, 3]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=False)) + + We can then pass either pandas or Polars to `func`: + + >>> func(df_pd) + a + 0 1.000000 + 1 1.666667 + 2 2.428571 + + >>> func(df_pl) # doctest: +NORMALIZE_WHITESPACE + shape: (3, 1) + ┌──────────┐ + │ a │ + │ --- │ + │ f64 │ + ╞══════════╡ + │ 1.0 │ + │ 1.666667 │ + │ 2.428571 │ + └──────────┘ + """ + return self.__class__( + lambda plx: self._call(plx).ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + ) + def mean(self) -> Self: """Get mean value. diff --git a/narwhals/series.py b/narwhals/series.py index 78801f78c..01829cb04 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -376,6 +376,101 @@ def name(self) -> str: """ return self._compliant_series.name # type: ignore[no-any-return] + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + r"""Compute exponentially-weighted moving average. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$ + span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$ + half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$ + alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$. + adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings + + - When `adjust=True` (the default) the EW function is calculated + using weights $w_i = (1 - \alpha)^i$ + - When `adjust=False` the EW function is calculated recursively by + $$ + y_0=x_0 + $$ + $$ + y_t = (1 - \alpha)y_{t - 1} + \alpha x_t + $$ + min_periods: Minimum number of observations in window required to have a value (otherwise result is null). + ignore_nulls: Ignore missing values when calculating weights. + + - When `ignore_nulls=False` (default), weights are based on absolute + positions. + For example, the weights of $x_0$ and $x_2$ used in + calculating the final weighted average of $[x_0, None, x_2]$ are + $(1-\alpha)^2$ and $1$ if `adjust=True`, and + $(1-\alpha)^2$ and $\alpha$ if `adjust=False`. + - When `ignore_nulls=True`, weights are based + on relative positions. For example, the weights of + $x_0$ and $x_2$ used in calculating the final weighted + average of $[x_0, None, x_2]$ are + $1-\alpha$ and $1$ if `adjust=True`, + and $1-\alpha$ and $\alpha$ if `adjust=False`. + + Returns: + Series + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data = [1, 2, 3] + >>> s_pd = pd.Series(name="a", data=data) + >>> s_pl = pl.Series(name="a", values=data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(s): + ... return s.ewm_mean(com=1, ignore_nulls=False) + + We can then pass either pandas or Polars to `func`: + + >>> func(s_pd) + 0 1.000000 + 1 1.666667 + 2 2.428571 + Name: a, dtype: float64 + + >>> func(s_pl) # doctest: +NORMALIZE_WHITESPACE + shape: (3,) + Series: 'a' [f64] + [ + 1.0 + 1.666667 + 2.428571 + ] + """ + return self._from_compliant_series( + self._compliant_series.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + ) + def cast(self: Self, dtype: DType | type[DType]) -> Self: """Cast between data types. diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index d35ecd434..12016d80d 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -493,6 +493,107 @@ def value_counts( sort=sort, parallel=parallel, name=name, normalize=normalize ) + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + r"""Compute exponentially-weighted moving average. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$ + span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$ + half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$ + alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$. + adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings + + - When `adjust=True` (the default) the EW function is calculated + using weights $w_i = (1 - \alpha)^i$ + - When `adjust=False` the EW function is calculated recursively by + $$ + y_0=x_0 + $$ + $$ + y_t = (1 - \alpha)y_{t - 1} + \alpha x_t + $$ + min_periods: Minimum number of observations in window required to have a value (otherwise result is null). + ignore_nulls: Ignore missing values when calculating weights. + + - When `ignore_nulls=False` (default), weights are based on absolute + positions. + For example, the weights of $x_0$ and $x_2$ used in + calculating the final weighted average of $[x_0, None, x_2]$ are + $(1-\alpha)^2$ and $1$ if `adjust=True`, and + $(1-\alpha)^2$ and $\alpha$ if `adjust=False`. + - When `ignore_nulls=True`, weights are based + on relative positions. For example, the weights of + $x_0$ and $x_2$ used in calculating the final weighted + average of $[x_0, None, x_2]$ are + $1-\alpha$ and $1$ if `adjust=True`, + and $1-\alpha$ and $\alpha$ if `adjust=False`. + + Returns: + Series + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data = [1, 2, 3] + >>> s_pd = pd.Series(name="a", data=data) + >>> s_pl = pl.Series(name="a", values=data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(s): + ... return s.ewm_mean(com=1, ignore_nulls=False) + + We can then pass either pandas or Polars to `func`: + + >>> func(s_pd) + 0 1.000000 + 1 1.666667 + 2 2.428571 + Name: a, dtype: float64 + + >>> func(s_pl) # doctest: +NORMALIZE_WHITESPACE + shape: (3,) + Series: 'a' [f64] + [ + 1.0 + 1.666667 + 2.428571 + ] + """ + from narwhals.exceptions import NarwhalsUnstableWarning + from narwhals.utils import find_stacklevel + + msg = ( + "`Series.ewm_mean` is being called from the stable API although considered " + "an unstable feature." + ) + warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel()) + return super().ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + def rolling_sum( self: Self, window_size: int, @@ -589,6 +690,110 @@ class Expr(NwExpr): def _l1_norm(self) -> Self: return super()._taxicab_norm() + def ewm_mean( + self: Self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_periods: int = 1, + ignore_nulls: bool = False, + ) -> Self: + r"""Compute exponentially-weighted moving average. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$ + span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$ + half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$ + alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$. + adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings + + - When `adjust=True` (the default) the EW function is calculated + using weights $w_i = (1 - \alpha)^i$ + - When `adjust=False` the EW function is calculated recursively by + $$ + y_0=x_0 + $$ + $$ + y_t = (1 - \alpha)y_{t - 1} + \alpha x_t + $$ + min_periods: Minimum number of observations in window required to have a value, (otherwise result is null). + ignore_nulls: Ignore missing values when calculating weights. + + - When `ignore_nulls=False` (default), weights are based on absolute + positions. + For example, the weights of $x_0$ and $x_2$ used in + calculating the final weighted average of $[x_0, None, x_2]$ are + $(1-\alpha)^2$ and $1$ if `adjust=True`, and + $(1-\alpha)^2$ and $\alpha$ if `adjust=False`. + - When `ignore_nulls=True`, weights are based + on relative positions. For example, the weights of + $x_0$ and $x_2$ used in calculating the final weighted + average of $[x_0, None, x_2]$ are + $1-\alpha$ and $1$ if `adjust=True`, + and $1-\alpha$ and $\alpha$ if `adjust=False`. + + Returns: + Expr + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data = {"a": [1, 2, 3]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=False)) + + We can then pass either pandas or Polars to `func`: + + >>> func(df_pd) + a + 0 1.000000 + 1 1.666667 + 2 2.428571 + + >>> func(df_pl) # doctest: +NORMALIZE_WHITESPACE + shape: (3, 1) + ┌──────────┐ + │ a │ + │ --- │ + │ f64 │ + ╞══════════╡ + │ 1.0 │ + │ 1.666667 │ + │ 2.428571 │ + └──────────┘ + """ + from narwhals.exceptions import NarwhalsUnstableWarning + from narwhals.utils import find_stacklevel + + msg = ( + "`Expr.ewm_mean` is being called from the stable API although considered " + "an unstable feature." + ) + warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel()) + return super().ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) + def rolling_sum( self: Self, window_size: int, diff --git a/tests/expr_and_series/ewm_test.py b/tests/expr_and_series/ewm_test.py new file mode 100644 index 000000000..e541a5bfe --- /dev/null +++ b/tests/expr_and_series/ewm_test.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import POLARS_VERSION +from tests.utils import Constructor +from tests.utils import ConstructorEager +from tests.utils import assert_equal_data + +data = {"a": [1, 1, 2], "b": [1, 2, 3]} + + +@pytest.mark.filterwarnings( + "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." +) +def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( + "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31) + ): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a", "b").ewm_mean(com=1)) + expected = { + "a": [1.0, 1.0, 1.5714285714285714], + "b": [1.0, 1.6666666666666667, 2.4285714285714284], + } + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Series.ewm_mean` is being called from the stable API although considered an unstable feature." +) +def test_ewm_mean_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or ( + "polars" in str(constructor_eager) and POLARS_VERSION <= (0, 20, 31) + ): + request.applymarker(pytest.mark.xfail) + + series = nw.from_native(constructor_eager(data), eager_only=True)["a"] + result = series.ewm_mean(com=1) + expected = {"a": [1.0, 1.0, 1.5714285714285714]} + assert_equal_data({"a": result}, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." +) +@pytest.mark.parametrize( + ("adjust", "expected"), + [ + ( + True, + { + "a": [1.0, 1.0, 1.5714285714285714], + "b": [1.0, 1.6666666666666667, 2.4285714285714284], + }, + ), + ( + False, + { + "a": [1.0, 1.0, 1.5], + "b": [1.0, 1.5, 2.25], + }, + ), + ], +) +def test_ewm_mean_expr_adjust( + request: pytest.FixtureRequest, + constructor: Constructor, + adjust: bool, # noqa: FBT001 + expected: dict[str, list[float]], +) -> None: + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( + "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31) + ): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a", "b").ewm_mean(com=1, adjust=adjust)) + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." +) +def test_ewm_mean_dask_raise() -> None: + pytest.importorskip("dask") + pytest.importorskip("dask_expr", exc_type=ImportError) + import dask.dataframe as dd + + df = nw.from_native(dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}))) + with pytest.raises( + NotImplementedError, + match="`Expr.ewm_mean` is not supported for the Dask backend", + ): + df.select(nw.col("a").ewm_mean(com=1)) + + +@pytest.mark.filterwarnings( + "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." +) +@pytest.mark.parametrize( + ("ignore_nulls", "expected"), + [ + ( + True, + { + "a": [ + 2.0, + 3.3333333333333335, + None, + 3.142857142857143, + ] + }, + ), + ( + False, + { + "a": [ + 2.0, + 3.3333333333333335, + None, + 3.090909090909091, + ] + }, + ), + ], +) +def test_ewm_mean_nulls( + request: pytest.FixtureRequest, + ignore_nulls: bool, # noqa: FBT001 + expected: dict[str, list[float]], + constructor: Constructor, +) -> None: + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( + "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31) + ): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor({"a": [2.0, 4.0, None, 3.0]})) + result = df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=ignore_nulls)) + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." +) +def test_ewm_mean_params( + request: pytest.FixtureRequest, + constructor: Constructor, +) -> None: + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( + "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31) + ): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor({"a": [2, 5, 3]})) + expected: dict[str, list[float | None]] = {"a": [2.0, 4.0, 3.4285714285714284]} + assert_equal_data( + df.select(nw.col("a").ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True)), + expected, + ) + + expected = {"a": [2.0, 4.500000000000001, 3.2903225806451615]} + assert_equal_data( + df.select(nw.col("a").ewm_mean(span=1.5, adjust=True, ignore_nulls=True)), + expected, + ) + + expected = {"a": [2.0, 3.1101184251576903, 3.0693702609187237]} + assert_equal_data( + df.select(nw.col("a").ewm_mean(half_life=1.5, adjust=False)), expected + ) + + expected = {"a": [None, 4.0, 3.4285714285714284]} + assert_equal_data( + df.select( + nw.col("a").ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=True) + ), + expected, + ) + + with pytest.raises(ValueError, match="mutually exclusive"): + df.select(nw.col("a").ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False))