Skip to content

Commit

Permalink
feat: Expr.rolling_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Oct 30, 2024
1 parent 136889e commit 92cfd45
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 1 deletion.
17 changes: 17 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,23 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
dtypes=self._dtypes,
)

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute a rolling mean by delegating to the series-level method.

    Every argument is forwarded unchanged, by keyword, to the series method
    named `rolling_mean` through `reuse_series_implementation`.

    Arguments:
        window_size: Forwarded as `window_size`.
        weights: Forwarded as `weights`.
        min_periods: Forwarded as `min_periods`.
        center: Forwarded as `center`.

    Returns:
        Whatever `reuse_series_implementation` builds for this expression.
    """
    # Name of the series method that performs the actual computation.
    series_method = "rolling_mean"
    return reuse_series_implementation(
        self,
        series_method,
        window_size=window_size,
        weights=weights,
        min_periods=min_periods,
        center=center,
    )

def mode(self: Self) -> Self:
    """Delegate to the series-level `mode` via `reuse_series_implementation`."""
    # Name of the series method that performs the actual computation.
    series_method = "mode"
    return reuse_series_implementation(self, series_method)

Expand Down
33 changes: 33 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,39 @@ def func(_input: Any, dtype: DType | type[DType]) -> Any:
returns_scalar=False,
)

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute a rolling mean over a fixed-size window for the Dask backend.

    Arguments:
        window_size: Passed to the series' `rolling` call as `window`.
        weights: Must be `None`; weighted windows are rejected.
        min_periods: Passed to `rolling` as `min_periods`.
        center: Passed to `rolling` as `center`.

    Returns:
        The expression produced by `self._from_call` for the rolling mean.

    Raises:
        NotImplementedError: If `weights` is not `None`.
    """
    # Reject weights before building any expression.
    if weights is not None:
        msg = (
            "`weights` argument is not supported in `rolling_mean` for Dask backend."
        )
        raise NotImplementedError(msg)

    def _windowed_mean(
        series: dask_expr.Series,
        window: int,
        periods: int | None,
        centered: bool,  # noqa: FBT001
    ) -> dask_expr.Series:
        # Build the rolling window first, then reduce each window to its mean.
        roller = series.rolling(window=window, min_periods=periods, center=centered)
        return roller.mean()

    # The three options are handed over positionally, in the same order as
    # the helper's parameters after the series itself.
    return self._from_call(
        _windowed_mean,
        "rolling_mean",
        window_size,
        min_periods,
        center,
        returns_scalar=False,
    )


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
17 changes: 17 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,23 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self:
def mode(self: Self) -> Self:
    """Delegate to the series-level `mode` via `reuse_series_implementation`."""
    # Name of the series method that performs the actual computation.
    series_method = "mode"
    return reuse_series_implementation(self, series_method)

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute a rolling mean by delegating to the series-level method.

    Every argument is forwarded unchanged, by keyword, to the series method
    named `rolling_mean` through `reuse_series_implementation`.

    Arguments:
        window_size: Forwarded as `window_size`.
        weights: Forwarded as `weights`.
        min_periods: Forwarded as `min_periods`.
        center: Forwarded as `center`.

    Returns:
        Whatever `reuse_series_implementation` builds for this expression.
    """
    # Name of the series method that performs the actual computation.
    series_method = "rolling_mean"
    return reuse_series_implementation(
        self,
        series_method,
        window_size=window_size,
        weights=weights,
        min_periods=min_periods,
        center=center,
    )

@property
def str(self: Self) -> PandasLikeExprStringNamespace:
    """Return a `PandasLikeExprStringNamespace` wrapping this expression."""
    return PandasLikeExprStringNamespace(self)
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,8 @@ def rolling_mean(
) -> Self:
if weights is not None:
    # Reject `weights` up front: this backend raises instead of applying them.
    # The message names the method (`rolling_mean`) and the active backend.
    msg = (
        "`weights` argument is not supported in `rolling_mean` for "
        f"{self._implementation} backend."
    )
    raise NotImplementedError(msg)

Expand Down
81 changes: 81 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,87 @@ def mode(self: Self) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).mode())

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None = None,
    *,
    min_periods: int | None = None,
    center: bool = False,
) -> Self:
    """
    Apply a rolling mean (moving mean) over the values.

    A window of length `window_size` traverses the values. The values that
    fill this window are (optionally) multiplied by the entries of `weights`,
    and the resulting values are aggregated to their mean.

    The window at a given row includes the row itself and the
    `window_size - 1` elements before it.

    Arguments:
        window_size: The length of the window in number of elements.
        weights: An optional list with the same length as the window that
            will be multiplied elementwise with the values in the window.
            Not every backend supports it: the pandas-like and Dask
            implementations raise `NotImplementedError` when it is given.
        min_periods: The number of values in the window that should be
            non-null before computing a result. If set to `None` (default),
            it will be set equal to `window_size`.
        center: Set the labels at the center of the window.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> data = {"a": [100, 200, 300]}
        >>> df_pd = pd.DataFrame(data)
        >>> df_pl = pl.DataFrame(data)
        >>> df_pa = pa.table(data)

        We define a library agnostic function:

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.with_columns(b=nw.col("a").rolling_mean(window_size=2))

        We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

        >>> func(df_pd)
             a      b
        0  100    NaN
        1  200  150.0
        2  300  250.0

        >>> func(df_pl)
        shape: (3, 2)
        ┌─────┬───────┐
        │ a   ┆ b     │
        │ --- ┆ ---   │
        │ i64 ┆ f64   │
        ╞═════╪═══════╡
        │ 100 ┆ null  │
        │ 200 ┆ 150.0 │
        │ 300 ┆ 250.0 │
        └─────┴───────┘

        >>> func(df_pa)  # doctest:+ELLIPSIS
        pyarrow.Table
        a: int64
        b: double
        ----
        a: [[100,200,300]]
        b: [[null,150,250]]
    """
    # Build the result with the same class as `self`; the wrapped callable
    # resolves this expression first and then calls `rolling_mean` on it.
    expr_cls = self.__class__
    return expr_cls(
        lambda plx: self._call(plx).rolling_mean(
            window_size=window_size,
            weights=weights,
            min_periods=min_periods,
            center=center,
        )
    )

@property
def str(self: Self) -> ExprStringNamespace[Self]:
    """Return an `ExprStringNamespace` wrapping this expression."""
    return ExprStringNamespace(self)
Expand Down

0 comments on commit 92cfd45

Please sign in to comment.