From 0ea35e9473fcfda6b60ac3b155fcf87fc1aa1c02 Mon Sep 17 00:00:00 2001
From: Lorenzo Stella
Date: Wed, 6 Dec 2023 21:24:35 +0100
Subject: [PATCH] Fix Rotbaum to handle short series (#3073)

---
 src/gluonts/ext/rotbaum/_preprocess.py |  9 ++--
 test/ext/rotbaum/test_rotbaum_smoke.py | 69 +++++++++++++++++++++++++-
 2 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py
index b068730829..6baee5be68 100644
--- a/src/gluonts/ext/rotbaum/_preprocess.py
+++ b/src/gluonts/ext/rotbaum/_preprocess.py
@@ -452,9 +452,12 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
         end_index = starting_index + self.context_window_size
         if starting_index < 0:
             prefix = [None] * abs(starting_index)
+            time_series_window = time_series["target"]
         else:
             prefix = []
-        time_series_window = time_series["target"][starting_index:end_index]
+            time_series_window = time_series["target"][
+                starting_index:end_index
+            ]
         only_lag_features, transform_dict = self._pre_transform(
             time_series_window, self.subtract_mean, self.count_nans
         )
@@ -480,10 +483,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
             list(
                 chain(
                     *[
-                        list(ent[0]) + list(ent[1].values())
+                        prefix + list(ent[0]) + list(ent[1].values())
                         for ent in [
                             self._pre_transform(
-                                ts[starting_index:end_index],
+                                ts if prefix else ts[starting_index:end_index],
                                 self.subtract_mean,
                                 self.count_nans,
                             )
diff --git a/test/ext/rotbaum/test_rotbaum_smoke.py b/test/ext/rotbaum/test_rotbaum_smoke.py
index 2634644660..93d1e96dd5 100644
--- a/test/ext/rotbaum/test_rotbaum_smoke.py
+++ b/test/ext/rotbaum/test_rotbaum_smoke.py
@@ -12,10 +12,12 @@
 # permissions and limitations under the License.
 
 import pytest
+import numpy as np
 
-from gluonts.ext.rotbaum import TreeEstimator
+from gluonts.ext.rotbaum import TreeEstimator, TreePredictor
 from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features
 
+from gluonts.dataset.common import ListDataset
 
 # TODO: Add support for categorical and dynamic features.
 
@@ -59,3 +61,68 @@ def test_rotbaum_smoke(datasets):
     predictor = estimator.train(dataset_train)
     forecasts = list(predictor.predict(dataset_test))
     assert len(forecasts) == len(dataset_test)
+
+
+def test_short_history_item_pred():
+    prediction_length = 7
+    freq = "D"
+
+    dataset = ListDataset(
+        data_iter=[
+            {
+                "start": "2017-10-11",
+                "item_id": "item_1",
+                "target": np.array(
+                    [
+                        1.0,
+                        9.0,
+                        2.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        5.0,
+                        3.0,
+                        4.0,
+                        2.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        6.0,
+                    ]
+                ),
+                "feat_static_cat": np.array([0.0, 0.0], dtype=float),
+                "past_feat_dynamic_real": np.array(
+                    [
+                        [1.0222e06 for i in range(14)],
+                        [750.0 for i in range(14)],
+                    ]
+                ),
+            },
+            {
+                "start": "2017-10-11",
+                "item_id": "item_2",
+                "target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]),
+                "feat_static_cat": np.array([0.0, 1.0], dtype=float),
+                "past_feat_dynamic_real": np.array(
+                    [[0 for i in range(5)], [750.0 for i in range(5)]]
+                ),
+            },
+        ],
+        freq=freq,
+    )
+
+    predictor = TreePredictor(
+        freq=freq,
+        prediction_length=prediction_length,
+        quantiles=[0.1, 0.5, 0.9],
+        max_n_datapts=50000,
+        method="QuantileRegression",
+        use_past_feat_dynamic_real=True,
+        use_feat_dynamic_real=False,
+        use_feat_dynamic_cat=False,
+        use_feat_static_real=False,
+        cardinality="auto",
+    )
+    predictor = predictor.train(dataset)
+    forecasts = list(predictor.predict(dataset))
+    assert forecasts[1].quantile(0.5).shape[0] == prediction_length