From e4c679169c0904ced354e25d8feafdcd9b460bb5 Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Tue, 23 Apr 2024 14:49:13 +0200 Subject: [PATCH] Fix type checks --- src/gluonts/model/seasonal_agg/_predictor.py | 4 +- test/model/seasonal_agg/test_seasonal_agg.py | 345 +++++++++++++++---- 2 files changed, 288 insertions(+), 61 deletions(-) diff --git a/src/gluonts/model/seasonal_agg/_predictor.py b/src/gluonts/model/seasonal_agg/_predictor.py index 3cd3da38cf..be3fedf522 100644 --- a/src/gluonts/model/seasonal_agg/_predictor.py +++ b/src/gluonts/model/seasonal_agg/_predictor.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import Callable, Optional +from typing import Callable import numpy as np @@ -32,7 +32,7 @@ class SeasonalAggregatePredictor(RepresentablePredictor): Seasonal aggegate forecaster. For each time series :math:`y`, this predictor produces a forecast - :math:`\\tilde{y}(T+k) = f\big(y(T+k-h), y(T+k-2h), \ldots, + :math:`\\tilde{y}(T+k) = f\big(y(T+k-h), y(T+k-2h), ..., y(T+k-mh)\big)`, where :math:`T` is the forecast time, :math:`k = 0, ...,` `prediction_length - 1`, :math:`m =`num_seasons`, :math:`h =`season_length` and :math:`f =`agg_fun`. diff --git a/test/model/seasonal_agg/test_seasonal_agg.py b/test/model/seasonal_agg/test_seasonal_agg.py index a3fde1df97..321116ce07 100644 --- a/test/model/seasonal_agg/test_seasonal_agg.py +++ b/test/model/seasonal_agg/test_seasonal_agg.py @@ -52,23 +52,73 @@ def get_prediction( [ # same as seasonal naive ([1, 1, 1], [1], 1, 1, 1, np.nanmean, LastValueImputation()), - ([1, 10, 2, 20], [1.5, 15], 2, 2, 2, np.nanmean, LastValueImputation()), + ( + [1, 10, 2, 20], + [1.5, 15], + 2, + 2, + 2, + np.nanmean, + LastValueImputation(), + ), # check predictions repeat seasonally - ([1, 10, 2, 20], [1.5, 15, 1.5, 15], 4, 2, 2, np.nanmean, - LastValueImputation()), - ([1, 10, 2, 20], [1.5, 15, 1.5], 3, 2, 2, np.nanmean, - LastValueImputation()), + ( + [1, 10, 2, 20], + [1.5, 15, 1.5, 15], + 4, + 2, + 2, + np.nanmean, + LastValueImputation(), + ), + ( + [1, 10, 2, 20], + [1.5, 15, 1.5], + 3, + 2, + 2, + np.nanmean, + LastValueImputation(), + ), # check `nanmedian` - ([1, 10, 2, 20, 3, 30], [2, 20, 2, 20], 4, 2, 3, np.nanmedian, - LastValueImputation()), - ([1, 10, 2, 20, 3, 30], [2, 20, 2], 3, 2, 3, np.nanmedian, - LastValueImputation()), + ( + [1, 10, 2, 20, 3, 30], + [2, 20, 2, 20], + 4, + 2, + 3, + np.nanmedian, + LastValueImputation(), + ), + ( + [1, 10, 2, 20, 3, 30], + [2, 20, 2], + 3, + 2, + 3, + np.nanmedian, + LastValueImputation(), + ), # check `nanmax` - ([1, 10, 2, 20, 3, 30], [3, 30, 3, 30], 4, 2, 3, np.nanmax, - LastValueImputation()), + ( + [1, 10, 2, 20, 3, 30], + [3, 30, 3, 30], + 4, + 2, + 3, + np.nanmax, + LastValueImputation(), + ), # check `nanmin` - ([1, 10, 2, 20, 3, 30], [1, 10, 1, 10], 4, 2, 3, np.nanmin, - LastValueImputation()), + ( + [1, 10, 2, 20, 3, 30], + [1, 10, 1, 10], + 4, + 2, + 3, + np.nanmin, + LastValueImputation(), + ), # data is shorter than season length ([1, 2, 3], [2], 1, 4, 1, np.nanmean, LastValueImputation()), ([10, 1, 100], [10], 1, 4, 1, np.nanmedian, LastValueImputation()), @@ -76,39 +126,123 @@ def get_prediction( ([10, 1, 100], [1], 1, 4, 1, np.nanmin, LastValueImputation()), # data not available for all seasons ([1, 2, 3, 4, 5], [3] * 4, 4, 4, 2, np.nanmean, LastValueImputation()), - ([10, 20, 40, 50, 21], [21] * 4, 4, 4, 2, np.nanmedian, - LastValueImputation()), - ([10, 20, 40, 50, 21], [50] * 4, 4, 4, 2, np.nanmax, - LastValueImputation()), - ([10, 20, 40, 50, 21], [10] * 4, 4, 4, 2, np.nanmin, - LastValueImputation()), + ( + [10, 20, 40, 50, 21], + [21] * 4, + 4, + 4, + 2, + np.nanmedian, + LastValueImputation(), + ), + ( + [10, 20, 40, 50, 21], + [50] * 4, + 4, + 4, + 2, + np.nanmax, + LastValueImputation(), + ), + ( + [10, 20, 40, 50, 21], + [10] * 4, + 4, + 4, + 2, + np.nanmin, + LastValueImputation(), + ), # missing values with imputation ([np.nan], [0], 1, 1, 2, np.nanmean, LastValueImputation()), ([np.nan], [0], 1, 1, 2, np.nanmedian, LastValueImputation()), ([1, 4, np.nan], [3], 1, 3, 2, np.nanmean, LastValueImputation()), ([1, 4, np.nan], [4], 1, 3, 2, np.nanmedian, LastValueImputation()), - ([1, 10, np.nan, 1, 10, np.nan], [1, 10, 10], 3, 3, 2, np.nanmean, - LastValueImputation()), - ([1, 10, np.nan, 1, 10, np.nan], [1, 10, 10], 3, 3, 2, np.nanmedian, - LastValueImputation()), - ([1, 10, np.nan, 1, 10, np.nan], [1, 10, 10, 1, 10], 5, 3, 2, - np.nanmax, LastValueImputation()), - ([1, 10, np.nan, 1, 10, np.nan], [1, 10, 10, 1, 10], 5, 3, 2, - np.nanmin, LastValueImputation()), + ( + [1, 10, np.nan, 1, 10, np.nan], + [1, 10, 10], + 3, + 3, + 2, + np.nanmean, + LastValueImputation(), + ), + ( + [1, 10, np.nan, 1, 10, np.nan], + [1, 10, 10], + 3, + 3, + 2, + np.nanmedian, + LastValueImputation(), + ), + ( + [1, 10, np.nan, 1, 10, np.nan], + [1, 10, 10, 1, 10], + 5, + 3, + 2, + np.nanmax, + LastValueImputation(), + ), + ( + [1, 10, np.nan, 1, 10, np.nan], + [1, 10, 10, 1, 10], + 5, + 3, + 2, + np.nanmin, + LastValueImputation(), + ), # missing values without imputation ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmean, LeavesMissingValues()), - ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmedian, - LeavesMissingValues()), + ( + [1, 3, np.nan], + [np.nan], + 1, + 1, + 1, + np.nanmedian, + LeavesMissingValues(), + ), ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmax, LeavesMissingValues()), ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmin, LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 1, np.nanmean, - LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 1, np.nanmedian, - LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 1, np.nanmax, - LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 1, np.nanmin, - LeavesMissingValues()), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 1, + np.nanmean, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 1, + np.nanmedian, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 1, + np.nanmax, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 1, + np.nanmin, + LeavesMissingValues(), + ), ([1, 3, np.nan], [3], 1, 1, 2, np.nanmean, LeavesMissingValues()), ([1, 3, np.nan], [3], 1, 1, 2, np.nanmedian, LeavesMissingValues()), ([1, 3, np.nan], [3], 1, 1, 2, np.nanmax, LeavesMissingValues()), @@ -117,24 +251,83 @@ def get_prediction( ([1, 3, np.nan], [3], 1, 2, 1, np.nanmedian, LeavesMissingValues()), ([1, 3, np.nan], [3], 1, 2, 1, np.nanmax, LeavesMissingValues()), ([1, 3, np.nan], [3], 1, 2, 1, np.nanmin, LeavesMissingValues()), - ([1, 3, np.nan], [3, np.nan], 2, 2, 1, np.nanmean, - LeavesMissingValues()), - ([1, 3, np.nan], [3, np.nan], 2, 2, 1, np.nanmedian, - LeavesMissingValues()), - ([1, 3, np.nan], [3, np.nan], 2, 2, 1, np.nanmax, - LeavesMissingValues()), - ([1, 3, np.nan], [3, np.nan], 2, 2, 1, np.nanmin, - LeavesMissingValues()), + ( + [1, 3, np.nan], + [3, np.nan], + 2, + 2, + 1, + np.nanmean, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [3, np.nan], + 2, + 2, + 1, + np.nanmedian, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [3, np.nan], + 2, + 2, + 1, + np.nanmax, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [3, np.nan], + 2, + 2, + 1, + np.nanmin, + LeavesMissingValues(), + ), # check if `nanmean` works when some seasons have missing values ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmean, LeavesMissingValues()), - ([1, 3, np.nan], [3, 3, 3], 3, 1, 2, np.nanmean, LeavesMissingValues()), + ( + [1, 3, np.nan], + [3, 3, 3], + 3, + 1, + 2, + np.nanmean, + LeavesMissingValues(), + ), # check if `mean` works when some seasons have missing values - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 2, np.mean, LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 3, 3, 1, 2, np.mean, LeavesMissingValues()), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 2, + np.mean, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [np.nan] * 3, + 3, + 1, + 2, + np.mean, + LeavesMissingValues(), + ), # check if `nanmedian` works when some seasons have missing values ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmedian, LeavesMissingValues()), - ([1, 3, np.nan], [3, 3, 3], 3, 1, 2, np.nanmedian, - LeavesMissingValues()), + ( + [1, 3, np.nan], + [3, 3, 3], + 3, + 1, + 2, + np.nanmedian, + LeavesMissingValues(), + ), # check if `nanmax` works when some seasons have missing values ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmax, LeavesMissingValues()), ([1, 3, np.nan], [3, 3, 3], 3, 1, 2, np.nanmax, LeavesMissingValues()), @@ -142,14 +335,43 @@ def get_prediction( ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmin, LeavesMissingValues()), ([1, 3, np.nan], [3, 3, 3], 3, 1, 2, np.nanmin, LeavesMissingValues()), # check if `mean` works when some seasons have missing values - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 2, np.median, - LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 3, 3, 1, 2, np.median, - LeavesMissingValues()), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 2, + np.median, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [np.nan] * 3, + 3, + 1, + 2, + np.median, + LeavesMissingValues(), + ), # check if `median` works when some seasons have missing values - ([1, 3, np.nan], [np.nan] * 2, 2, 1, 2, np.median, LeavesMissingValues()), - ([1, 3, np.nan], [np.nan] * 3, 3, 1, 2, np.median, - LeavesMissingValues()), + ( + [1, 3, np.nan], + [np.nan] * 2, + 2, + 1, + 2, + np.median, + LeavesMissingValues(), + ), + ( + [1, 3, np.nan], + [np.nan] * 3, + 3, + 1, + 2, + np.median, + LeavesMissingValues(), + ), # check if `max` works when some seasons have missing values ([1, 3, np.nan], [np.nan] * 2, 2, 1, 2, np.max, LeavesMissingValues()), ([1, 3, np.nan], [np.nan] * 3, 3, 1, 2, np.max, LeavesMissingValues()), @@ -159,8 +381,13 @@ def get_prediction( ], ) def test_predictor( - data, expected_output, prediction_length, season_length, - num_seasons, agg_fun, imputation_method + data, + expected_output, + prediction_length, + season_length, + num_seasons, + agg_fun, + imputation_method, ): prediction = get_prediction( data,