diff --git a/docs/api-reference/expr.md b/docs/api-reference/expr.md
index 627e1cb4b..694ae504b 100644
--- a/docs/api-reference/expr.md
+++ b/docs/api-reference/expr.md
@@ -18,6 +18,7 @@
- cum_sum
- diff
- drop_nulls
+ - ewm_mean
- fill_null
- filter
- gather_every
diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md
index 12ca72208..d0cf7875f 100644
--- a/docs/api-reference/series.md
+++ b/docs/api-reference/series.md
@@ -23,6 +23,7 @@
- diff
- drop_nulls
- dtype
+ - ewm_mean
- fill_null
- filter
- gather_every
diff --git a/docs/css/code_select.css b/docs/css/code_select.css
new file mode 100644
index 000000000..259cf11c5
--- /dev/null
+++ b/docs/css/code_select.css
@@ -0,0 +1,7 @@
+.highlight .gp, .highlight .go { /* Generic.Prompt, Generic.Output */
+ user-select: none;
+ -webkit-user-select: none; /* Safari */
+ -moz-user-select: none; /* Firefox */
+ -ms-user-select: none; /* Internet Explorer/Edge */
+ color: red;
+}
\ No newline at end of file
diff --git a/docs/javascripts/katex.js b/docs/javascripts/katex.js
new file mode 100644
index 000000000..3828300a7
--- /dev/null
+++ b/docs/javascripts/katex.js
@@ -0,0 +1,10 @@
+document$.subscribe(({ body }) => {
+ renderMathInElement(body, {
+ delimiters: [
+ { left: "$$", right: "$$", display: true },
+ { left: "$", right: "$", display: false },
+ { left: "\\(", right: "\\)", display: false },
+ { left: "\\[", right: "\\]", display: true }
+ ],
+ })
+ })
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index e7268e2b4..0358bd2c1 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -118,3 +118,12 @@ markdown_extensions:
- pymdownx.emoji:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
+- pymdownx.arithmatex:
+ generic: true
+extra_javascript:
+ - javascripts/katex.js
+ - https://unpkg.com/katex@0/dist/katex.min.js
+ - https://unpkg.com/katex@0/dist/contrib/auto-render.min.js
+
+extra_css:
+ - https://unpkg.com/katex@0/dist/katex.min.css
\ No newline at end of file
diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py
index 912d60295..b70635dc6 100644
--- a/narwhals/_dask/expr.py
+++ b/narwhals/_dask/expr.py
@@ -539,6 +539,20 @@ def round(self, decimals: int) -> Self:
returns_scalar=False,
)
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> NoReturn:
+ msg = "`Expr.ewm_mean` is not supported for the Dask backend"
+ raise NotImplementedError(msg)
+
def unique(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead."
diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py
index 802010977..fa769e790 100644
--- a/narwhals/_pandas_like/expr.py
+++ b/narwhals/_pandas_like/expr.py
@@ -285,6 +285,29 @@ def is_in(self, other: Any) -> Self:
def arg_true(self) -> Self:
return reuse_series_implementation(self, "arg_true")
+ def ewm_mean(
+ self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ return reuse_series_implementation(
+ self,
+ "ewm_mean",
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+
def filter(self, *predicates: Any) -> Self:
plx = self.__narwhals_namespace__()
other = plx.all_horizontal(*predicates)
diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py
index a1b379260..9cbdfd6af 100644
--- a/narwhals/_pandas_like/series.py
+++ b/narwhals/_pandas_like/series.py
@@ -175,6 +175,25 @@ def dtype(self: Self) -> DType:
self._native_series, self._dtypes, self._implementation
)
+ def ewm_mean(
+ self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> PandasLikeSeries:
+ ser = self._native_series
+ mask_na = ser.isna()
+ result = ser.ewm(
+ com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
+ ).mean()
+ result[mask_na] = None
+ return self._from_native_series(result)
+
def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
if isinstance(values, self.__class__):
# .copy() is necessary in some pre-2.2 versions of pandas to avoid
diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py
index 7f2bb3c7f..6a28bfbc4 100644
--- a/narwhals/_polars/expr.py
+++ b/narwhals/_polars/expr.py
@@ -9,6 +9,7 @@
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.utils import Implementation
+from narwhals.utils import parse_version
if TYPE_CHECKING:
import polars as pl
@@ -49,6 +50,35 @@ def cast(self, dtype: DType) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._dtypes)
return self._from_native_expr(expr.cast(dtype_pl))
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ import polars as pl # ignore-banned-import()
+
+ if parse_version(pl.__version__) <= (0, 20, 31): # pragma: no cover
+ msg = "`ewm_mean` not implemented for polars older than 0.20.31"
+ raise NotImplementedError(msg)
+ expr = self._native_expr
+ return self._from_native_expr(
+ expr.ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+ )
+
def map_batches(
self,
function: Callable[[Any], Self],
diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py
index 9013d6ab4..a1251199e 100644
--- a/narwhals/_polars/series.py
+++ b/narwhals/_polars/series.py
@@ -225,6 +225,33 @@ def to_dummies(
result, backend_version=self._backend_version, dtypes=self._dtypes
)
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ if self._backend_version < (0, 20, 31): # pragma: no cover
+ msg = "`ewm_mean` not implemented for polars older than 0.20.31"
+ raise NotImplementedError(msg)
+ expr = self._native_series
+ return self._from_native_series(
+ expr.ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+ )
+
def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
if self._backend_version < (0, 20, 6):
result = self._native_series.sort(descending=descending)
diff --git a/narwhals/expr.py b/narwhals/expr.py
index 850ba7835..ddacd4fff 100644
--- a/narwhals/expr.py
+++ b/narwhals/expr.py
@@ -418,6 +418,103 @@ def all(self) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).all())
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ r"""Compute exponentially-weighted moving average.
+
+ !!! warning
+ This functionality is considered **unstable**. It may be changed at any point
+ without it being considered a breaking change.
+
+ Arguments:
+ com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$
+ span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$
+ half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$
+ alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$.
+ adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings
+ - When `adjust=True` (the default) the EW function is calculated
+ using weights $w_i = (1 - \alpha)^i$
+ - When `adjust=False` the EW function is calculated recursively by
+ $$
+ y_0=x_0
+ $$
+ $$
+ y_t = (1 - \alpha)y_{t - 1} + \alpha x_t
+ $$
+ min_periods: Minimum number of observations in window required to have a value, (otherwise result is null).
+ ignore_nulls: Ignore missing values when calculating weights.
+
+ - When `ignore_nulls=False` (default), weights are based on absolute
+ positions.
+ For example, the weights of $x_0$ and $x_2$ used in
+ calculating the final weighted average of $[x_0, None, x_2]$ are
+ $(1-\alpha)^2$ and $1$ if `adjust=True`, and
+ $(1-\alpha)^2$ and $\alpha$ if `adjust=False`.
+ - When `ignore_nulls=True`, weights are based
+ on relative positions. For example, the weights of
+ $x_0$ and $x_2$ used in calculating the final weighted
+ average of $[x_0, None, x_2]$ are
+ $1-\alpha$ and $1$ if `adjust=True`,
+ and $1-\alpha$ and $\alpha$ if `adjust=False`.
+
+ Returns:
+ Expr
+
+ Examples:
+ >>> import pandas as pd
+ >>> import polars as pl
+ >>> import narwhals as nw
+ >>> data = {"a": [1, 2, 3]}
+ >>> df_pd = pd.DataFrame(data)
+ >>> df_pl = pl.DataFrame(data)
+
+ We define a library agnostic function:
+
+ >>> @nw.narwhalify
+ ... def func(df):
+ ... return df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=False))
+
+ We can then pass either pandas or Polars to `func`:
+
+ >>> func(df_pd)
+ a
+ 0 1.000000
+ 1 1.666667
+ 2 2.428571
+
+ >>> func(df_pl) # doctest: +NORMALIZE_WHITESPACE
+ shape: (3, 1)
+ ┌──────────┐
+ │ a │
+ │ --- │
+ │ f64 │
+ ╞══════════╡
+ │ 1.0 │
+ │ 1.666667 │
+ │ 2.428571 │
+ └──────────┘
+ """
+ return self.__class__(
+ lambda plx: self._call(plx).ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+ )
+
def mean(self) -> Self:
"""Get mean value.
diff --git a/narwhals/series.py b/narwhals/series.py
index 78801f78c..01829cb04 100644
--- a/narwhals/series.py
+++ b/narwhals/series.py
@@ -376,6 +376,101 @@ def name(self) -> str:
"""
return self._compliant_series.name # type: ignore[no-any-return]
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ r"""Compute exponentially-weighted moving average.
+
+ !!! warning
+ This functionality is considered **unstable**. It may be changed at any point
+ without it being considered a breaking change.
+
+ Arguments:
+ com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$
+ span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$
+ half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$
+ alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$.
+ adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings
+
+ - When `adjust=True` (the default) the EW function is calculated
+ using weights $w_i = (1 - \alpha)^i$
+ - When `adjust=False` the EW function is calculated recursively by
+ $$
+ y_0=x_0
+ $$
+ $$
+ y_t = (1 - \alpha)y_{t - 1} + \alpha x_t
+ $$
+ min_periods: Minimum number of observations in window required to have a value (otherwise result is null).
+ ignore_nulls: Ignore missing values when calculating weights.
+
+ - When `ignore_nulls=False` (default), weights are based on absolute
+ positions.
+ For example, the weights of $x_0$ and $x_2$ used in
+ calculating the final weighted average of $[x_0, None, x_2]$ are
+ $(1-\alpha)^2$ and $1$ if `adjust=True`, and
+ $(1-\alpha)^2$ and $\alpha$ if `adjust=False`.
+ - When `ignore_nulls=True`, weights are based
+ on relative positions. For example, the weights of
+ $x_0$ and $x_2$ used in calculating the final weighted
+ average of $[x_0, None, x_2]$ are
+ $1-\alpha$ and $1$ if `adjust=True`,
+ and $1-\alpha$ and $\alpha$ if `adjust=False`.
+
+ Returns:
+ Series
+
+ Examples:
+ >>> import pandas as pd
+ >>> import polars as pl
+ >>> import narwhals as nw
+ >>> data = [1, 2, 3]
+ >>> s_pd = pd.Series(name="a", data=data)
+ >>> s_pl = pl.Series(name="a", values=data)
+
+ We define a library agnostic function:
+
+ >>> @nw.narwhalify
+ ... def func(s):
+ ... return s.ewm_mean(com=1, ignore_nulls=False)
+
+ We can then pass either pandas or Polars to `func`:
+
+ >>> func(s_pd)
+ 0 1.000000
+ 1 1.666667
+ 2 2.428571
+ Name: a, dtype: float64
+
+ >>> func(s_pl) # doctest: +NORMALIZE_WHITESPACE
+ shape: (3,)
+ Series: 'a' [f64]
+ [
+ 1.0
+ 1.666667
+ 2.428571
+ ]
+ """
+ return self._from_compliant_series(
+ self._compliant_series.ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+ )
+
def cast(self: Self, dtype: DType | type[DType]) -> Self:
"""Cast between data types.
diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py
index d35ecd434..12016d80d 100644
--- a/narwhals/stable/v1/__init__.py
+++ b/narwhals/stable/v1/__init__.py
@@ -493,6 +493,107 @@ def value_counts(
sort=sort, parallel=parallel, name=name, normalize=normalize
)
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ r"""Compute exponentially-weighted moving average.
+
+ !!! warning
+ This functionality is considered **unstable**. It may be changed at any point
+ without it being considered a breaking change.
+
+ Arguments:
+ com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$
+ span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$
+ half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$
+ alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$.
+ adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings
+
+ - When `adjust=True` (the default) the EW function is calculated
+ using weights $w_i = (1 - \alpha)^i$
+ - When `adjust=False` the EW function is calculated recursively by
+ $$
+ y_0=x_0
+ $$
+ $$
+ y_t = (1 - \alpha)y_{t - 1} + \alpha x_t
+ $$
+ min_periods: Minimum number of observations in window required to have a value (otherwise result is null).
+ ignore_nulls: Ignore missing values when calculating weights.
+
+ - When `ignore_nulls=False` (default), weights are based on absolute
+ positions.
+ For example, the weights of $x_0$ and $x_2$ used in
+ calculating the final weighted average of $[x_0, None, x_2]$ are
+ $(1-\alpha)^2$ and $1$ if `adjust=True`, and
+ $(1-\alpha)^2$ and $\alpha$ if `adjust=False`.
+ - When `ignore_nulls=True`, weights are based
+ on relative positions. For example, the weights of
+ $x_0$ and $x_2$ used in calculating the final weighted
+ average of $[x_0, None, x_2]$ are
+ $1-\alpha$ and $1$ if `adjust=True`,
+ and $1-\alpha$ and $\alpha$ if `adjust=False`.
+
+ Returns:
+ Series
+
+ Examples:
+ >>> import pandas as pd
+ >>> import polars as pl
+ >>> import narwhals as nw
+ >>> data = [1, 2, 3]
+ >>> s_pd = pd.Series(name="a", data=data)
+ >>> s_pl = pl.Series(name="a", values=data)
+
+ We define a library agnostic function:
+
+ >>> @nw.narwhalify
+ ... def func(s):
+ ... return s.ewm_mean(com=1, ignore_nulls=False)
+
+ We can then pass either pandas or Polars to `func`:
+
+ >>> func(s_pd)
+ 0 1.000000
+ 1 1.666667
+ 2 2.428571
+ Name: a, dtype: float64
+
+ >>> func(s_pl) # doctest: +NORMALIZE_WHITESPACE
+ shape: (3,)
+ Series: 'a' [f64]
+ [
+ 1.0
+ 1.666667
+ 2.428571
+ ]
+ """
+ from narwhals.exceptions import NarwhalsUnstableWarning
+ from narwhals.utils import find_stacklevel
+
+ msg = (
+ "`Series.ewm_mean` is being called from the stable API although considered "
+ "an unstable feature."
+ )
+ warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel())
+ return super().ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+
def rolling_sum(
self: Self,
window_size: int,
@@ -589,6 +690,110 @@ class Expr(NwExpr):
def _l1_norm(self) -> Self:
return super()._taxicab_norm()
+ def ewm_mean(
+ self: Self,
+ *,
+ com: float | None = None,
+ span: float | None = None,
+ half_life: float | None = None,
+ alpha: float | None = None,
+ adjust: bool = True,
+ min_periods: int = 1,
+ ignore_nulls: bool = False,
+ ) -> Self:
+ r"""Compute exponentially-weighted moving average.
+
+ !!! warning
+ This functionality is considered **unstable**. It may be changed at any point
+ without it being considered a breaking change.
+
+ Arguments:
+ com: Specify decay in terms of center of mass, $\gamma$, with
$\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$
+ span: Specify decay in terms of span, $\theta$, with
$\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$
+ half_life: Specify decay in terms of half-life, $\tau$, with
$\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$
+ alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$.
+ adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings
+
+ - When `adjust=True` (the default) the EW function is calculated
+ using weights $w_i = (1 - \alpha)^i$
+ - When `adjust=False` the EW function is calculated recursively by
+ $$
+ y_0=x_0
+ $$
+ $$
+ y_t = (1 - \alpha)y_{t - 1} + \alpha x_t
+ $$
+ min_periods: Minimum number of observations in window required to have a value, (otherwise result is null).
+ ignore_nulls: Ignore missing values when calculating weights.
+
+ - When `ignore_nulls=False` (default), weights are based on absolute
+ positions.
+ For example, the weights of $x_0$ and $x_2$ used in
+ calculating the final weighted average of $[x_0, None, x_2]$ are
+ $(1-\alpha)^2$ and $1$ if `adjust=True`, and
+ $(1-\alpha)^2$ and $\alpha$ if `adjust=False`.
+ - When `ignore_nulls=True`, weights are based
+ on relative positions. For example, the weights of
+ $x_0$ and $x_2$ used in calculating the final weighted
+ average of $[x_0, None, x_2]$ are
+ $1-\alpha$ and $1$ if `adjust=True`,
+ and $1-\alpha$ and $\alpha$ if `adjust=False`.
+
+ Returns:
+ Expr
+
+ Examples:
+ >>> import pandas as pd
+ >>> import polars as pl
+ >>> import narwhals as nw
+ >>> data = {"a": [1, 2, 3]}
+ >>> df_pd = pd.DataFrame(data)
+ >>> df_pl = pl.DataFrame(data)
+
+ We define a library agnostic function:
+
+ >>> @nw.narwhalify
+ ... def func(df):
+ ... return df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=False))
+
+ We can then pass either pandas or Polars to `func`:
+
+ >>> func(df_pd)
+ a
+ 0 1.000000
+ 1 1.666667
+ 2 2.428571
+
+ >>> func(df_pl) # doctest: +NORMALIZE_WHITESPACE
+ shape: (3, 1)
+ ┌──────────┐
+ │ a │
+ │ --- │
+ │ f64 │
+ ╞══════════╡
+ │ 1.0 │
+ │ 1.666667 │
+ │ 2.428571 │
+ └──────────┘
+ """
+ from narwhals.exceptions import NarwhalsUnstableWarning
+ from narwhals.utils import find_stacklevel
+
+ msg = (
+ "`Expr.ewm_mean` is being called from the stable API although considered "
+ "an unstable feature."
+ )
+ warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel())
+ return super().ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ min_periods=min_periods,
+ ignore_nulls=ignore_nulls,
+ )
+
def rolling_sum(
self: Self,
window_size: int,
diff --git a/tests/expr_and_series/ewm_test.py b/tests/expr_and_series/ewm_test.py
new file mode 100644
index 000000000..e541a5bfe
--- /dev/null
+++ b/tests/expr_and_series/ewm_test.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import pandas as pd
+import pytest
+
+import narwhals.stable.v1 as nw
+from tests.utils import POLARS_VERSION
+from tests.utils import Constructor
+from tests.utils import ConstructorEager
+from tests.utils import assert_equal_data
+
+data = {"a": [1, 1, 2], "b": [1, 2, 3]}
+
+
+@pytest.mark.filterwarnings(
+ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
+)
+def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+ if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
+ "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31)
+ ):
+ request.applymarker(pytest.mark.xfail)
+
+ df = nw.from_native(constructor(data))
+ result = df.select(nw.col("a", "b").ewm_mean(com=1))
+ expected = {
+ "a": [1.0, 1.0, 1.5714285714285714],
+ "b": [1.0, 1.6666666666666667, 2.4285714285714284],
+ }
+ assert_equal_data(result, expected)
+
+
+@pytest.mark.filterwarnings(
+ "ignore:`Series.ewm_mean` is being called from the stable API although considered an unstable feature."
+)
+def test_ewm_mean_series(
+ request: pytest.FixtureRequest, constructor_eager: ConstructorEager
+) -> None:
+ if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or (
+ "polars" in str(constructor_eager) and POLARS_VERSION <= (0, 20, 31)
+ ):
+ request.applymarker(pytest.mark.xfail)
+
+ series = nw.from_native(constructor_eager(data), eager_only=True)["a"]
+ result = series.ewm_mean(com=1)
+ expected = {"a": [1.0, 1.0, 1.5714285714285714]}
+ assert_equal_data({"a": result}, expected)
+
+
+@pytest.mark.filterwarnings(
+ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
+)
+@pytest.mark.parametrize(
+ ("adjust", "expected"),
+ [
+ (
+ True,
+ {
+ "a": [1.0, 1.0, 1.5714285714285714],
+ "b": [1.0, 1.6666666666666667, 2.4285714285714284],
+ },
+ ),
+ (
+ False,
+ {
+ "a": [1.0, 1.0, 1.5],
+ "b": [1.0, 1.5, 2.25],
+ },
+ ),
+ ],
+)
+def test_ewm_mean_expr_adjust(
+ request: pytest.FixtureRequest,
+ constructor: Constructor,
+ adjust: bool, # noqa: FBT001
+ expected: dict[str, list[float]],
+) -> None:
+ if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
+ "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31)
+ ):
+ request.applymarker(pytest.mark.xfail)
+
+ df = nw.from_native(constructor(data))
+ result = df.select(nw.col("a", "b").ewm_mean(com=1, adjust=adjust))
+ assert_equal_data(result, expected)
+
+
+@pytest.mark.filterwarnings(
+ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
+)
+def test_ewm_mean_dask_raise() -> None:
+ pytest.importorskip("dask")
+ pytest.importorskip("dask_expr", exc_type=ImportError)
+ import dask.dataframe as dd
+
+ df = nw.from_native(dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]})))
+ with pytest.raises(
+ NotImplementedError,
+ match="`Expr.ewm_mean` is not supported for the Dask backend",
+ ):
+ df.select(nw.col("a").ewm_mean(com=1))
+
+
+@pytest.mark.filterwarnings(
+ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
+)
+@pytest.mark.parametrize(
+ ("ignore_nulls", "expected"),
+ [
+ (
+ True,
+ {
+ "a": [
+ 2.0,
+ 3.3333333333333335,
+ None,
+ 3.142857142857143,
+ ]
+ },
+ ),
+ (
+ False,
+ {
+ "a": [
+ 2.0,
+ 3.3333333333333335,
+ None,
+ 3.090909090909091,
+ ]
+ },
+ ),
+ ],
+)
+def test_ewm_mean_nulls(
+ request: pytest.FixtureRequest,
+ ignore_nulls: bool, # noqa: FBT001
+ expected: dict[str, list[float]],
+ constructor: Constructor,
+) -> None:
+ if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
+ "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31)
+ ):
+ request.applymarker(pytest.mark.xfail)
+
+ df = nw.from_native(constructor({"a": [2.0, 4.0, None, 3.0]}))
+ result = df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=ignore_nulls))
+ assert_equal_data(result, expected)
+
+
+@pytest.mark.filterwarnings(
+ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
+)
+def test_ewm_mean_params(
+ request: pytest.FixtureRequest,
+ constructor: Constructor,
+) -> None:
+ if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
+ "polars" in str(constructor) and POLARS_VERSION <= (0, 20, 31)
+ ):
+ request.applymarker(pytest.mark.xfail)
+
+ df = nw.from_native(constructor({"a": [2, 5, 3]}))
+ expected: dict[str, list[float | None]] = {"a": [2.0, 4.0, 3.4285714285714284]}
+ assert_equal_data(
+ df.select(nw.col("a").ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True)),
+ expected,
+ )
+
+ expected = {"a": [2.0, 4.500000000000001, 3.2903225806451615]}
+ assert_equal_data(
+ df.select(nw.col("a").ewm_mean(span=1.5, adjust=True, ignore_nulls=True)),
+ expected,
+ )
+
+ expected = {"a": [2.0, 3.1101184251576903, 3.0693702609187237]}
+ assert_equal_data(
+ df.select(nw.col("a").ewm_mean(half_life=1.5, adjust=False)), expected
+ )
+
+ expected = {"a": [None, 4.0, 3.4285714285714284]}
+ assert_equal_data(
+ df.select(
+ nw.col("a").ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=True)
+ ),
+ expected,
+ )
+
+ with pytest.raises(ValueError, match="mutually exclusive"):
+ df.select(nw.col("a").ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False))