From dd8449f59de8f06d5b780141b527b183c605286c Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 14 Nov 2023 10:49:14 +0100 Subject: [PATCH] Backports v0.13.8 (#3054) * Refactor tests for `ev.aggregations` (#3038) * Fix edge cases in metric computation (#3037) * Rotbaum: Add item-id to forecast. (#3049) * Fix mypy checks (#3052) * add init --------- Co-authored-by: Jasper --- src/gluonts/ev/aggregations.py | 16 ++- src/gluonts/ev/metrics.py | 6 +- src/gluonts/ext/rotbaum/_predictor.py | 4 +- src/gluonts/model/evaluation.py | 2 +- src/gluonts/mx/model/deepstate/issm.py | 4 +- test/ev/test_aggregations.py | 145 +++++++++++++-------- test/ev/test_metrics.py | 173 +++++++++++++++++++++++++ test/evaluation/__init__.py | 12 ++ 8 files changed, 295 insertions(+), 67 deletions(-) create mode 100644 test/ev/test_metrics.py create mode 100644 test/evaluation/__init__.py diff --git a/src/gluonts/ev/aggregations.py b/src/gluonts/ev/aggregations.py index c56753c854..098dfa02a4 100644 --- a/src/gluonts/ev/aggregations.py +++ b/src/gluonts/ev/aggregations.py @@ -48,7 +48,9 @@ class Sum(Aggregation): partial_result: Optional[Union[List[np.ndarray], np.ndarray]] = None def step(self, values: np.ndarray) -> None: - summed_values = np.ma.sum(values, axis=self.axis) + assert self.axis is None or isinstance(self.axis, tuple) + + summed_values = np.nansum(values, axis=self.axis) if self.axis is None or 0 in self.axis: if self.partial_result is None: @@ -61,9 +63,11 @@ def step(self, values: np.ndarray) -> None: def get(self) -> np.ndarray: if self.axis is None or 0 in self.axis: - return np.ma.copy(self.partial_result) + assert isinstance(self.partial_result, np.ndarray) + return np.copy(self.partial_result) - return np.ma.concatenate(self.partial_result) + assert isinstance(self.partial_result, list) + return np.concatenate(self.partial_result) @dataclass @@ -100,11 +104,13 @@ def step(self, values: np.ndarray) -> None: if self.partial_result is None: self.partial_result = [] - mean_values = np.ma.mean(values, axis=self.axis) + mean_values = np.nanmean(values, axis=self.axis) + assert isinstance(self.partial_result, list) self.partial_result.append(mean_values) def get(self) -> np.ndarray: if self.axis is None or 0 in self.axis: return self.partial_result / self.n - return np.ma.concatenate(self.partial_result) + assert isinstance(self.partial_result, list) + return np.concatenate(self.partial_result) diff --git a/src/gluonts/ev/metrics.py b/src/gluonts/ev/metrics.py index fd9483b159..34deab2cd8 100644 --- a/src/gluonts/ev/metrics.py +++ b/src/gluonts/ev/metrics.py @@ -284,7 +284,7 @@ def mean(**quantile_losses: np.ndarray) -> np.ndarray: [quantile_loss for quantile_loss in quantile_losses.values()], axis=0, ) - return np.ma.mean(stacked_quantile_losses, axis=0) + return np.mean(stacked_quantile_losses, axis=0) def __call__(self, axis: Optional[int] = None) -> DerivedEvaluator: return DerivedEvaluator( @@ -307,7 +307,7 @@ def mean(**quantile_losses: np.ndarray) -> np.ndarray: [quantile_loss for quantile_loss in quantile_losses.values()], axis=0, ) - return np.ma.mean(stacked_quantile_losses, axis=0) + return np.mean(stacked_quantile_losses, axis=0) def __call__(self, axis: Optional[int] = None) -> DerivedEvaluator: return DerivedEvaluator( @@ -332,7 +332,7 @@ def mean( [np.abs(coverages[f"coverage[{q}]"] - q) for q in quantile_levels], axis=0, ) - return np.ma.mean(intermediate_result, axis=0) + return np.mean(intermediate_result, axis=0) def __call__(self, axis: Optional[int] = None) -> 
DerivedEvaluator: return DerivedEvaluator( diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py index 860a7e564a..482f4adbd1 100644 --- a/src/gluonts/ext/rotbaum/_predictor.py +++ b/src/gluonts/ext/rotbaum/_predictor.py @@ -50,12 +50,13 @@ def __init__( featurized_data: List, start_date: pd.Period, prediction_length: int, + item_id: Optional[str] = None, ): self.models = models self.featurized_data = featurized_data self.start_date = start_date self.prediction_length = prediction_length - self.item_id = None + self.item_id = item_id self.lead_time = None def quantile(self, q: float) -> np.ndarray: @@ -333,6 +334,7 @@ def predict( [featurized_data], start_date=forecast_start(ts), prediction_length=self.prediction_length, + item_id=ts.get("item_id"), ) def explain( diff --git a/src/gluonts/model/evaluation.py b/src/gluonts/model/evaluation.py index 5b6269114e..36f0f22ef6 100644 --- a/src/gluonts/model/evaluation.py +++ b/src/gluonts/model/evaluation.py @@ -203,7 +203,7 @@ def evaluate_forecasts( ) if index0 is not None: index0_repeated = np.take(index0, indices=index_arrays[0], axis=0) - index_arrays = (*zip(*index0_repeated), *index_arrays[1:]) + index_arrays = (*zip(*index0_repeated), *index_arrays[1:]) # type: ignore index = pd.MultiIndex.from_arrays(index_arrays) flattened_metrics = valmap(np.ravel, metrics_values) diff --git a/src/gluonts/mx/model/deepstate/issm.py b/src/gluonts/mx/model/deepstate/issm.py index 121c7bd6df..1078959978 100644 --- a/src/gluonts/mx/model/deepstate/issm.py +++ b/src/gluonts/mx/model/deepstate/issm.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import List, Tuple +from typing import List, Sequence, Tuple from pandas.tseries.frequencies import to_offset @@ -31,7 +31,7 @@ ) -def _make_block_diagonal(blocks: List[Tensor]) -> Tensor: +def _make_block_diagonal(blocks: Sequence[Tensor]) -> Tensor: assert ( len(blocks) > 0 ), "You need at least one tensor to make a block-diagonal tensor" diff --git a/test/ev/test_aggregations.py b/test/ev/test_aggregations.py index 4fe636562f..c3969efd53 100644 --- a/test/ev/test_aggregations.py +++ b/test/ev/test_aggregations.py @@ -18,61 +18,53 @@ from gluonts.ev import Mean, Sum from gluonts.itertools import power_set -VALUE_STREAM = [ - [ - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - ], - [ - np.array([[0, np.nan], [0, 0]]), - np.array([[0, 5], [-5, np.nan]]), - ], - [ - np.full(shape=(3, 3), fill_value=1), - np.full(shape=(1, 3), fill_value=4), - ], -] - -SUM_RES_AXIS_NONE = [ - 0, - 0, - 21, -] - -SUM_RES_AXIS_0 = [ - np.zeros(5), - np.array([-5, 5]), - np.array([7, 7, 7]), -] -SUM_RES_AXIS_1 = [ - np.zeros(9), - np.array([0, 0, 5, -5]), - np.array([3, 3, 3, 12]), -] - - -MEAN_RES_AXIS_NONE = [ - np.nan, - 0, - 1.75, -] - -MEAN_RES_AXIS_0 = [ - np.full(5, np.nan), - np.array([-1.25, 2.5]), - np.array([1.75, 1.75, 1.75]), -] -MEAN_RES_AXIS_1 = [ - np.full(9, np.nan), - np.array([0, 0, 2.5, -5]), - np.array([1, 1, 1, 4]), -] - @pytest.mark.parametrize( "value_stream, res_axis_none, res_axis_0, res_axis_1", - zip(VALUE_STREAM, SUM_RES_AXIS_NONE, SUM_RES_AXIS_0, SUM_RES_AXIS_1), + [ + ( + [ + np.full((3, 5), 0.0), + np.full((3, 5), 0.0), + np.full((3, 5), 0.0), + ], + 0.0, + np.zeros(5), + np.zeros(9), + ), + ( + np.ma.masked_invalid( + [ + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + ] + ), + 0, + np.zeros(5), + 
np.zeros(9), + ), + ( + np.ma.masked_invalid( + [ + np.array([[0, np.nan], [0, 0]]), + np.array([[0, 5], [-5, np.nan]]), + ] + ), + 0, + np.array([-5, 5]), + np.array([0, 0, 5, -5]), + ), + ( + [ + np.full(shape=(3, 3), fill_value=1), + np.full(shape=(1, 3), fill_value=4), + ], + 21, + np.array([7, 7, 7]), + np.array([3, 3, 3, 12]), + ), + ], ) def test_Sum(value_stream, res_axis_none, res_axis_0, res_axis_1): for axis, expected_result in zip( @@ -80,14 +72,57 @@ def test_Sum(value_stream, res_axis_none, res_axis_0, res_axis_1): ): sum = Sum(axis=axis) for values in value_stream: - sum.step(np.ma.masked_invalid(values)) + sum.step(values) np.testing.assert_almost_equal(sum.get(), expected_result) @pytest.mark.parametrize( "value_stream, res_axis_none, res_axis_0, res_axis_1", - zip(VALUE_STREAM, MEAN_RES_AXIS_NONE, MEAN_RES_AXIS_0, MEAN_RES_AXIS_1), + [ + ( + [ + np.full((3, 5), 0.0), + np.full((3, 5), 0.0), + np.full((3, 5), 0.0), + ], + 0.0, + np.zeros(5), + np.zeros(9), + ), + ( + np.ma.masked_invalid( + [ + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + ] + ), + np.nan, + np.full(5, np.nan), + np.full(9, np.nan), + ), + ( + np.ma.masked_invalid( + [ + np.array([[0, np.nan], [0, 0]]), + np.array([[0, 5], [-5, np.nan]]), + ] + ), + 0, + np.array([-1.25, 2.5]), + np.array([0, 0, 2.5, -5]), + ), + ( + [ + np.full(shape=(3, 3), fill_value=1), + np.full(shape=(1, 3), fill_value=4), + ], + 1.75, + np.array([1.75, 1.75, 1.75]), + np.array([1, 1, 1, 4]), + ), + ], ) def test_Mean(value_stream, res_axis_none, res_axis_0, res_axis_1): for axis, expected_result in zip( @@ -95,7 +130,7 @@ def test_Mean(value_stream, res_axis_none, res_axis_0, res_axis_1): ): mean = Mean(axis=axis) for values in value_stream: - mean.step(np.ma.masked_invalid(values)) + mean.step(values) np.testing.assert_almost_equal(mean.get(), expected_result) diff --git a/test/ev/test_metrics.py b/test/ev/test_metrics.py new file mode 100644 index 0000000000..aa38452560 --- /dev/null +++ b/test/ev/test_metrics.py @@ -0,0 +1,173 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +from typing import Optional + +import numpy as np +import pytest + +from gluonts.ev.metrics import ( + Coverage, + MAECoverage, + MSE, + MAPE, + SMAPE, + MASE, + WeightedSumQuantileLoss, + MeanWeightedSumQuantileLoss, + ND, + RMSE, + NRMSE, +) +from gluonts.ev.ts_stats import seasonal_error + + +METRICS = [ + Coverage(0.5), + MAECoverage([0.1, 0.5, 0.9]), + MSE(), + MAPE(), + SMAPE(), + MASE(), + WeightedSumQuantileLoss(0.5), + MeanWeightedSumQuantileLoss([0.1, 0.5, 0.9]), + ND(), + RMSE(), + NRMSE(), +] + + +@pytest.mark.parametrize( + "metric", + METRICS, +) +@pytest.mark.parametrize("axis", [None, (0, 1), (0,), (1,), ()]) +def test_metric_shape(metric, axis: Optional[tuple]): + input_length = 20 + label_length = 5 + num_entries = 7 + + data = [ + { + "input": np.random.normal(size=(1, input_length)), + "label": np.random.normal(size=(1, label_length)), + "0.1": np.random.normal(size=(1, label_length)), + "0.5": np.random.normal(size=(1, label_length)), + "0.9": np.random.normal(size=(1, label_length)), + "mean": np.random.normal(size=(1, label_length)), + } + for _ in range(num_entries) + ] + + for entry in data: + entry["seasonal_error"] = seasonal_error( + entry["input"], seasonality=1, time_axis=1 + ) + + evaluator = metric(axis=axis) + for entry in data: + evaluator.update(entry) + metric_value = evaluator.get() + + if axis is None or axis == (0, 1): + assert isinstance(metric_value, float) + elif axis == (0,): + assert isinstance(metric_value, np.ndarray) + assert metric_value.shape == (label_length,) + elif axis == (1,): + assert isinstance(metric_value, np.ndarray) + assert metric_value.shape == (num_entries,) + elif axis == (): + assert isinstance(metric_value, np.ndarray) + assert metric_value.shape == (num_entries, label_length) + else: + raise ValueError("unsupported axis") + + return metric_value + + +@pytest.mark.parametrize( + "metric", + [ + ND(), + MASE(), + MAPE(), + NRMSE(), + WeightedSumQuantileLoss(0.5), + MeanWeightedSumQuantileLoss([0.1, 0.5, 0.9]), + ], +) +@pytest.mark.parametrize("axis", [None, (0, 1), (0,), (1,), ()]) +def test_metric_inf(metric, axis: Optional[tuple]): + time_series_length = 3 + number_of_entries = 2 + + data = { + "label": np.zeros((1, time_series_length)), + "0.5": np.ones((1, time_series_length)), + "0.1": np.ones((1, time_series_length)), + "0.9": np.ones((1, time_series_length)), + "mean": np.ones((1, time_series_length)), + "seasonal_error": 0.0, + } + + evaluator = metric(axis=axis) + for _ in range(number_of_entries): + evaluator.update(data) + + result = evaluator.get() + expected = np.full((number_of_entries, time_series_length), np.inf).sum( + axis=axis + ) + + assert result.shape == expected.shape + assert np.allclose(result, expected) + + +@pytest.mark.parametrize( + "metric", + [ + ND(), + MASE(), + MAPE(), + SMAPE(), + NRMSE(), + WeightedSumQuantileLoss(0.5), + MeanWeightedSumQuantileLoss([0.1, 0.5, 0.9]), + ], +) +@pytest.mark.parametrize("axis", [None, (0, 1), (0,), (1,), ()]) +def test_metric_nan(metric, axis: Optional[tuple]): + time_series_length = 3 + number_of_entries = 2 + + data = { + "label": np.zeros((1, time_series_length)), + "0.5": np.zeros((1, time_series_length)), + "0.1": np.zeros((1, time_series_length)), + "0.9": np.zeros((1, time_series_length)), + "mean": np.zeros((1, time_series_length)), + "seasonal_error": 0.0, + } + + evaluator = metric(axis=axis) + for _ in range(number_of_entries): + evaluator.update(data) + + result = evaluator.get() + expected = np.full((number_of_entries, time_series_length), 
np.nan).sum( + axis=axis + ) + + assert result.shape == expected.shape + assert np.allclose(result, expected, equal_nan=True) diff --git a/test/evaluation/__init__.py b/test/evaluation/__init__.py new file mode 100644 index 0000000000..f342912f9b --- /dev/null +++ b/test/evaluation/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License.
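
For reference, a minimal sketch (illustrative only, not part of the patch) of driving the refactored Sum aggregation the way the updated tests above do: plain float arrays are now reduced with np.nansum/np.nanmean internally, while explicitly masked arrays can still be passed via np.ma.masked_invalid as in test_Sum/test_Mean.

    import numpy as np
    from gluonts.ev import Sum

    # Accumulate partial sums over a stream of value batches, as in
    # test_Sum above; axis=(0,) reduces over the batch dimension.
    sum_agg = Sum(axis=(0,))
    for values in [np.full((3, 3), 1.0), np.full((1, 3), 4.0)]:
        sum_agg.step(values)
    print(sum_agg.get())  # array([7., 7., 7.]), matching the test expectation

    # axis=None keeps a single running total instead.
    total = Sum(axis=None)
    for values in [np.full((3, 3), 1.0), np.full((1, 3), 4.0)]:
        total.step(values)
    print(total.get())  # 21.0, matching the test expectation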