diff --git a/src/gluonts/ext/r_forecast/_univariate_predictor.py b/src/gluonts/ext/r_forecast/_univariate_predictor.py index eb2b2e60c5..5416a79f86 100644 --- a/src/gluonts/ext/r_forecast/_univariate_predictor.py +++ b/src/gluonts/ext/r_forecast/_univariate_predictor.py @@ -121,6 +121,7 @@ def __init__( params["intervals"] = sorted( set([level for level, _ in intervals_info]) ) + params.pop("quantiles") self.params.update(params) diff --git a/src/gluonts/model/forecast.py b/src/gluonts/model/forecast.py index 6cf9dea630..f41d001212 100644 --- a/src/gluonts/model/forecast.py +++ b/src/gluonts/model/forecast.py @@ -12,6 +12,7 @@ # permissions and limitations under the License. import re +import logging from dataclasses import field from typing import Callable, Dict, List, Optional, Union, Tuple @@ -22,6 +23,8 @@ from gluonts.core.component import validated from gluonts import maybe +logger = logging.getLogger(__name__) + def _linear_interpolation( xs: np.ndarray, ys: np.ndarray, x: float @@ -313,7 +316,7 @@ def plot( # If no color is provided, we use matplotlib's internal color cycle. # Note: This is an internal API and might change in the future. color = maybe.unwrap_or_else( - color, lambda: next(ax._get_lines.prop_cycler)["color"] + color, lambda: ax._get_lines.get_next_color() ) # Plot median forecast @@ -656,7 +659,11 @@ def mean(self) -> np.ndarray: """ if "mean" in self._forecast_dict: return self._forecast_dict["mean"] - + logger.warning( + "The mean prediction is not stored in the forecast data; " + "the median is being returned instead. " + "This behaviour may change in the future." + ) return self.quantile("p50") def dim(self) -> int: