From 9be6dfd92c1cbee364d884c937d179a85c92ce7d Mon Sep 17 00:00:00 2001
From: Oleksandr Shchur
Date: Thu, 12 Oct 2023 14:37:17 +0200
Subject: [PATCH] Add flag to toggle log scale feature for WaveNet model (#3012)

---
 src/gluonts/torch/model/wavenet/estimator.py |  7 +++++++
 src/gluonts/torch/model/wavenet/module.py    | 11 ++++++++---
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py
index aa3db218e5..ab9fc9db54 100644
--- a/src/gluonts/torch/model/wavenet/estimator.py
+++ b/src/gluonts/torch/model/wavenet/estimator.py
@@ -85,6 +85,7 @@ def __init__(
         cardinality: List[int] = [1],
         seasonality: Optional[int] = None,
         embedding_dimension: int = 5,
+        use_log_scale_feature: bool = True,
         time_features: Optional[List[TimeFeature]] = None,
         lr: float = 1e-3,
         weight_decay: float = 1e-8,
@@ -140,6 +141,10 @@ def __init__(
         embedding_dimension, optional
             The dimension of the embeddings for categorical features,
             by default 5
+        use_log_scale_feature, optional
+            If True, the logarithm of the scale of the past data will be
+            used as an additional static feature,
+            by default True
         time_features, optional
             List of time features, from :py:mod:`gluonts.time_feature`,
             by default None
@@ -187,6 +192,7 @@ def __init__(
         self.num_residual_channels = num_residual_channels
         self.num_skip_channels = num_skip_channels
         self.num_stacks = num_stacks
+        self.use_log_scale_feature = use_log_scale_feature
         self.time_features = unwrap_or(
             time_features, time_features_from_frequency_str(freq)
         )
@@ -382,6 +388,7 @@ def create_lightning_module(self) -> pl.LightningModule:
                 pred_length=self.prediction_length,
                 num_parallel_samples=self.num_parallel_samples,
                 temperature=self.temperature,
+                use_log_scale_feature=self.use_log_scale_feature,
             ),
         )

diff --git a/src/gluonts/torch/model/wavenet/module.py b/src/gluonts/torch/model/wavenet/module.py
index 10404ccf79..c34c5203ab 100644
--- a/src/gluonts/torch/model/wavenet/module.py
+++ b/src/gluonts/torch/model/wavenet/module.py
@@ -130,6 +130,7 @@ def __init__(
         embedding_dimension: int = 5,
         num_parallel_samples: int = 100,
         temperature: float = 1.0,
+        use_log_scale_feature: bool = True,
     ):
         super().__init__()

@@ -141,8 +142,9 @@ def __init__(
             embedding_dimension * len(cardinality)
             + num_feat_dynamic_real
             + num_feat_static_real
-            + 1  # the log(scale)
+            + int(use_log_scale_feature)  # the log(scale)
         )
+        self.use_log_scale_feature = use_log_scale_feature

         # 1 extra bin to account for extreme values
         self.n_bins = len(bin_values) + 1
@@ -249,8 +251,11 @@ def get_full_features(
             network.
             Shape: (batch_size, num_features, receptive_field + pred_length)
         """
-        embedded_cat = self.feature_embedder(feat_static_cat.long())
-        static_feat = torch.cat([embedded_cat, torch.log(scale + 1.0)], dim=1)
+        static_feat = self.feature_embedder(feat_static_cat.long())
+        if self.use_log_scale_feature:
+            static_feat = torch.cat(
+                [static_feat, torch.log(scale + 1.0)], dim=1
+            )
         repeated_static_feat = torch.repeat_interleave(
             static_feat[..., None],
             self.prediction_length + self.receptive_field,
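
Usage sketch (not part of the patch): the new flag is exposed on the torch
WaveNetEstimator constructor. The frequency, prediction length, and values
below are illustrative assumptions, not taken from this change.

    # Minimal sketch, assuming the estimator is importable as below;
    # freq and prediction_length are illustrative placeholders.
    from gluonts.torch.model.wavenet import WaveNetEstimator

    estimator = WaveNetEstimator(
        freq="H",
        prediction_length=24,
        use_log_scale_feature=False,  # skip the log(scale + 1.0) static feature
    )
    # With the flag off, get_full_features no longer concatenates
    # log(scale + 1.0) to the embedded static features, and the network's
    # feature size shrinks by one, matching the
    # `+ int(use_log_scale_feature)` term in module.py.

Defaulting the flag to True preserves the previous behavior, where
log(scale + 1.0) was always appended, so existing configurations are
unaffected by this patch.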