diff --git a/src/gluonts/model/seasonal_agg/_predictor.py b/src/gluonts/model/seasonal_agg/_predictor.py index c0ee4d0551..3cd3da38cf 100644 --- a/src/gluonts/model/seasonal_agg/_predictor.py +++ b/src/gluonts/model/seasonal_agg/_predictor.py @@ -29,16 +29,18 @@ class SeasonalAggregatePredictor(RepresentablePredictor): """ - Seasonal average forecaster. + Seasonal aggegate forecaster. For each time series :math:`y`, this predictor produces a forecast - :math:`\\tilde{y}(T+k) = y(T+k-h)`, where :math:`T` is the forecast time, - :math:`k = 0, ...,` `prediction_length - 1`, and :math:`h =` - `season_length`. + :math:`\\tilde{y}(T+k) = f\big(y(T+k-h), y(T+k-2h), \ldots, + y(T+k-mh)\big)`, where :math:`T` is the forecast time, + :math:`k = 0, ...,` `prediction_length - 1`, :math:`m =`num_seasons`, + :math:`h =`season_length` and :math:`f =`agg_fun`. - If `prediction_length > season_length`, then the season is repeated - multiple times. If a time series is shorter than season_length, then the - mean observed value is used as prediction. + If `prediction_length > season_length` :math:\times `num_seasons`, then the + seasonal aggregate is repeated multiple times. If a time series is shorter + than season_length` :math:\times `num_seasons`, then the `agg_fun` is + applied to the full time series. Parameters ---------- @@ -46,6 +48,10 @@ class SeasonalAggregatePredictor(RepresentablePredictor): Number of time points to predict. season_length Seasonality used to make predictions. + num_seasons + Number of seasons to aggregate. + agg_fun + Aggregate function. imputation_method The imputation method to use in case of missing values. Defaults to :py:class:`LastValueImputation` which replaces each missing @@ -107,7 +113,6 @@ def predict_item(self, item: DataEntry) -> Forecast: ] for j in range(self.num_seasons) ] - print(indices) samples = self.agg_fun(target[indices], axis=0).reshape( (1, self.prediction_length) )