Skip to content

Commit

Permalink
feat: Adding ewm_mean (#1298)
Browse files Browse the repository at this point in the history
  • Loading branch information
DeaMariaLeon authored Nov 19, 2024
1 parent d5aa778 commit d7c1d4f
Show file tree
Hide file tree
Showing 14 changed files with 727 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- cum_sum
- diff
- drop_nulls
- ewm_mean
- fill_null
- filter
- gather_every
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- diff
- drop_nulls
- dtype
- ewm_mean
- fill_null
- filter
- gather_every
Expand Down
7 changes: 7 additions & 0 deletions docs/css/code_select.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/* Pygments Generic.Prompt (.gp) and Generic.Output (.go) spans:
   excluded from text selection and rendered in red. */
.highlight .gp,
.highlight .go {
  -webkit-user-select: none; /* Safari */
  -moz-user-select: none; /* Firefox */
  -ms-user-select: none; /* Internet Explorer/Edge */
  user-select: none; /* standard property, listed after the prefixed ones */
  color: red;
}
10 changes: 10 additions & 0 deletions docs/javascripts/katex.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// On every `document$` emission, run KaTeX auto-render over the emitted body.
document$.subscribe(function renderPageMath(page) {
  // Delimiter pairs recognised as math; `display: true` renders block math,
  // `display: false` renders inline math.
  const delimiters = [
    { left: "$$", right: "$$", display: true },
    { left: "$", right: "$", display: false },
    { left: "\\(", right: "\\)", display: false },
    { left: "\\[", right: "\\]", display: true },
  ];

  renderMathInElement(page.body, { delimiters: delimiters });
});
9 changes: 9 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,12 @@ markdown_extensions:
- pymdownx.emoji:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- pymdownx.arithmatex:
generic: true
extra_javascript:
- javascripts/katex.js
- https://unpkg.com/katex@0/dist/katex.min.js
- https://unpkg.com/katex@0/dist/contrib/auto-render.min.js

extra_css:
- https://unpkg.com/katex@0/dist/katex.min.css
14 changes: 14 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,20 @@ def round(self, decimals: int) -> Self:
returns_scalar=False,
)

def ewm_mean(
    self: Self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_periods: int = 1,
    ignore_nulls: bool = False,
) -> NoReturn:
    """Reject exponentially-weighted moving averages on the Dask backend.

    The keyword arguments mirror the `ewm_mean` signature of the other
    backends but are never used: this method unconditionally raises.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = "`Expr.ewm_mean` is not supported for the Dask backend"
    raise NotImplementedError(unsupported_message)

def unique(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead."
Expand Down
23 changes: 23 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,29 @@ def is_in(self, other: Any) -> Self:
def arg_true(self) -> Self:
    """Delegate to the series-level `arg_true` via `reuse_series_implementation`."""
    return reuse_series_implementation(self, "arg_true")

def ewm_mean(
    self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_periods: int = 1,
    ignore_nulls: bool = False,
) -> Self:
    """Exponentially-weighted moving average, delegated to the series method.

    Every keyword argument is forwarded unchanged to the series-level
    `ewm_mean` through `reuse_series_implementation`.
    """
    ewm_options = {
        "com": com,
        "span": span,
        "half_life": half_life,
        "alpha": alpha,
        "adjust": adjust,
        "min_periods": min_periods,
        "ignore_nulls": ignore_nulls,
    }
    return reuse_series_implementation(self, "ewm_mean", **ewm_options)

def filter(self, *predicates: Any) -> Self:
plx = self.__narwhals_namespace__()
other = plx.all_horizontal(*predicates)
Expand Down
19 changes: 19 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,25 @@ def dtype(self: Self) -> DType:
self._native_series, self._dtypes, self._implementation
)

def ewm_mean(
    self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_periods: int = 1,
    ignore_nulls: bool = False,
) -> PandasLikeSeries:
    """Compute the exponentially-weighted moving average of the native series.

    The decay parameters are handed to the native `.ewm(...)` window
    (`half_life` maps to `halflife`, `ignore_nulls` to `ignore_na`) and its
    `.mean()` is taken. Positions that were missing in the input are reset
    to null in the output.
    """
    native = self._native_series
    # Record the null positions before computing the window.
    null_positions = native.isna()
    window = native.ewm(
        com=com,
        span=span,
        halflife=half_life,
        alpha=alpha,
        min_periods=min_periods,
        adjust=adjust,
        ignore_na=ignore_nulls,
    )
    smoothed = window.mean()
    smoothed[null_positions] = None
    return self._from_native_series(smoothed)

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
if isinstance(values, self.__class__):
# .copy() is necessary in some pre-2.2 versions of pandas to avoid
Expand Down
30 changes: 30 additions & 0 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.utils import Implementation
from narwhals.utils import parse_version

if TYPE_CHECKING:
import polars as pl
Expand Down Expand Up @@ -49,6 +50,35 @@ def cast(self, dtype: DType) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._dtypes)
return self._from_native_expr(expr.cast(dtype_pl))

def ewm_mean(
    self: Self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_periods: int = 1,
    ignore_nulls: bool = False,
) -> Self:
    """Compute the exponentially-weighted moving average of the native expression.

    All keyword arguments are forwarded unchanged to the native Polars
    `ewm_mean`, and the result is re-wrapped with `_from_native_expr`.

    Raises:
        NotImplementedError: If the installed Polars is older than 0.20.31.
    """
    import polars as pl  # ignore-banned-import()

    # Strict comparison: 0.20.31 itself is supported, as the message states.
    if parse_version(pl.__version__) < (0, 20, 31):  # pragma: no cover
        msg = "`ewm_mean` not implemented for polars older than 0.20.31"
        raise NotImplementedError(msg)
    native_result = self._native_expr.ewm_mean(
        com=com,
        span=span,
        half_life=half_life,
        alpha=alpha,
        adjust=adjust,
        min_periods=min_periods,
        ignore_nulls=ignore_nulls,
    )
    return self._from_native_expr(native_result)

def map_batches(
self,
function: Callable[[Any], Self],
Expand Down
27 changes: 27 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,33 @@ def to_dummies(
result, backend_version=self._backend_version, dtypes=self._dtypes
)

def ewm_mean(
    self: Self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_periods: int = 1,
    ignore_nulls: bool = False,
) -> Self:
    """Compute the exponentially-weighted moving average of the native series.

    All keyword arguments are forwarded unchanged to the native series'
    `ewm_mean`, and the result is re-wrapped with `_from_native_series`.

    Raises:
        NotImplementedError: If the backend version is below 0.20.31.
    """
    if self._backend_version < (0, 20, 31):  # pragma: no cover
        msg = "`ewm_mean` not implemented for polars older than 0.20.31"
        raise NotImplementedError(msg)
    native_series = self._native_series
    smoothed = native_series.ewm_mean(
        com=com,
        span=span,
        half_life=half_life,
        alpha=alpha,
        adjust=adjust,
        min_periods=min_periods,
        ignore_nulls=ignore_nulls,
    )
    return self._from_native_series(smoothed)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
if self._backend_version < (0, 20, 6):
result = self._native_series.sort(descending=descending)
Expand Down
97 changes: 97 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,103 @@ def all(self) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).all())

def ewm_mean(
    self: Self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_periods: int = 1,
    ignore_nulls: bool = False,
) -> Self:
    r"""Compute exponentially-weighted moving average.

    !!! warning
        This functionality is considered **unstable**. It may be changed at any point
        without it being considered a breaking change.

    Arguments:
        com: Specify decay in terms of center of mass, $\gamma$, with <br> $\alpha = \frac{1}{1+\gamma}\forall\gamma\geq0$
        span: Specify decay in terms of span, $\theta$, with <br> $\alpha = \frac{2}{\theta + 1} \forall \theta \geq 1$
        half_life: Specify decay in terms of half-life, $\tau$, with <br> $\alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \forall \tau > 0$
        alpha: Specify smoothing factor alpha directly, $0 < \alpha \leq 1$.
        adjust: Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings

            - When `adjust=True` (the default) the EW function is calculated
              using weights $w_i = (1 - \alpha)^i$
            - When `adjust=False` the EW function is calculated recursively by
              $$
              y_0=x_0
              $$
              $$
              y_t = (1 - \alpha)y_{t - 1} + \alpha x_t
              $$
        min_periods: Minimum number of observations in window required to have a value, (otherwise result is null).
        ignore_nulls: Ignore missing values when calculating weights.

            - When `ignore_nulls=False` (default), weights are based on absolute
              positions.
              For example, the weights of $x_0$ and $x_2$ used in
              calculating the final weighted average of $[x_0, None, x_2]$ are
              $(1-\alpha)^2$ and $1$ if `adjust=True`, and
              $(1-\alpha)^2$ and $\alpha$ if `adjust=False`.
            - When `ignore_nulls=True`, weights are based
              on relative positions. For example, the weights of
              $x_0$ and $x_2$ used in calculating the final weighted
              average of $[x_0, None, x_2]$ are
              $1-\alpha$ and $1$ if `adjust=True`,
              and $1-\alpha$ and $\alpha$ if `adjust=False`.

    Returns:
        Expr

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import narwhals as nw
        >>> data = {"a": [1, 2, 3]}
        >>> df_pd = pd.DataFrame(data)
        >>> df_pl = pl.DataFrame(data)

        We define a library agnostic function:

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.select(nw.col("a").ewm_mean(com=1, ignore_nulls=False))

        We can then pass either pandas or Polars to `func`:

        >>> func(df_pd)
                  a
        0  1.000000
        1  1.666667
        2  2.428571

        >>> func(df_pl)  # doctest: +NORMALIZE_WHITESPACE
        shape: (3, 1)
        ┌──────────┐
        │ a        │
        │ ---      │
        │ f64      │
        ╞══════════╡
        │ 1.0      │
        │ 1.666667 │
        │ 2.428571 │
        └──────────┘
    """
    # Gather the options once; the deferred call forwards them unchanged to
    # the backend expression's `ewm_mean`.
    ewm_options = {
        "com": com,
        "span": span,
        "half_life": half_life,
        "alpha": alpha,
        "adjust": adjust,
        "min_periods": min_periods,
        "ignore_nulls": ignore_nulls,
    }
    return self.__class__(lambda plx: self._call(plx).ewm_mean(**ewm_options))

def mean(self) -> Self:
"""Get mean value.
Expand Down
Loading

0 comments on commit d7c1d4f

Please sign in to comment.