Skip to content

Commit

Permalink
Fix Rotbaum to handle short series (awslabs#3073)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Dec 6, 2023
1 parent 471cebf commit 0ea35e9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
9 changes: 6 additions & 3 deletions src/gluonts/ext/rotbaum/_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,12 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
end_index = starting_index + self.context_window_size
if starting_index < 0:
prefix = [None] * abs(starting_index)
time_series_window = time_series["target"]
else:
prefix = []
time_series_window = time_series["target"][starting_index:end_index]
time_series_window = time_series["target"][
starting_index:end_index
]
only_lag_features, transform_dict = self._pre_transform(
time_series_window, self.subtract_mean, self.count_nans
)
Expand All @@ -480,10 +483,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
list(
chain(
*[
list(ent[0]) + list(ent[1].values())
prefix + list(ent[0]) + list(ent[1].values())
for ent in [
self._pre_transform(
ts[starting_index:end_index],
ts if prefix else ts[starting_index:end_index],
self.subtract_mean,
self.count_nans,
)
Expand Down
69 changes: 68 additions & 1 deletion test/ext/rotbaum/test_rotbaum_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# permissions and limitations under the License.

import pytest
import numpy as np

from gluonts.ext.rotbaum import TreeEstimator
from gluonts.ext.rotbaum import TreeEstimator, TreePredictor

from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features
from gluonts.dataset.common import ListDataset

# TODO: Add support for categorical and dynamic features.

Expand Down Expand Up @@ -59,3 +61,68 @@ def test_rotbaum_smoke(datasets):
predictor = estimator.train(dataset_train)
forecasts = list(predictor.predict(dataset_test))
assert len(forecasts) == len(dataset_test)


def test_short_history_item_pred():
prediction_length = 7
freq = "D"

dataset = ListDataset(
data_iter=[
{
"start": "2017-10-11",
"item_id": "item_1",
"target": np.array(
[
1.0,
9.0,
2.0,
0.0,
0.0,
1.0,
5.0,
3.0,
4.0,
2.0,
0.0,
0.0,
1.0,
6.0,
]
),
"feat_static_cat": np.array([0.0, 0.0], dtype=float),
"past_feat_dynamic_real": np.array(
[
[1.0222e06 for i in range(14)],
[750.0 for i in range(14)],
]
),
},
{
"start": "2017-10-11",
"item_id": "item_2",
"target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]),
"feat_static_cat": np.array([0.0, 1.0], dtype=float),
"past_feat_dynamic_real": np.array(
[[0 for i in range(5)], [750.0 for i in range(5)]]
),
},
],
freq=freq,
)

predictor = TreePredictor(
freq=freq,
prediction_length=prediction_length,
quantiles=[0.1, 0.5, 0.9],
max_n_datapts=50000,
method="QuantileRegression",
use_past_feat_dynamic_real=True,
use_feat_dynamic_real=False,
use_feat_dynamic_cat=False,
use_feat_static_real=False,
cardinality="auto",
)
predictor = predictor.train(dataset)
forecasts = list(predictor.predict(dataset))
assert forecasts[1].quantile(0.5).shape[0] == prediction_length

0 comments on commit 0ea35e9

Please sign in to comment.