From 4f61a0563e0d7a1d9f32459cb28bdae19b65f6a8 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 20 Oct 2023 07:40:39 +0000 Subject: [PATCH] Fix WaveNet inputs --- src/gluonts/torch/model/wavenet/estimator.py | 1 + src/gluonts/torch/model/wavenet/lightning_module.py | 4 ++++ src/gluonts/torch/model/wavenet/module.py | 13 +++++++++++++ 3 files changed, 18 insertions(+) diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index e7fea4b0d7..6f7b72f11b 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -54,6 +54,7 @@ PREDICTION_INPUT_NAMES = [ "feat_static_cat", + "feat_static_real", "past_target", "past_observed_values", "past_time_feat", diff --git a/src/gluonts/torch/model/wavenet/lightning_module.py b/src/gluonts/torch/model/wavenet/lightning_module.py index 85dd0a6671..f4e7043ebb 100644 --- a/src/gluonts/torch/model/wavenet/lightning_module.py +++ b/src/gluonts/torch/model/wavenet/lightning_module.py @@ -53,6 +53,7 @@ def training_step(self, batch, batch_idx: int): # type: ignore Execute training step. """ feat_static_cat = batch["feat_static_cat"] + feat_static_real = batch["feat_static_real"] past_target = batch["past_target"] past_observed_values = batch["past_observed_values"] past_time_feat = batch["past_time_feat"] @@ -63,6 +64,7 @@ def training_step(self, batch, batch_idx: int): # type: ignore train_loss = self.model.loss( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_target=past_target, past_observed_values=past_observed_values, past_time_feat=past_time_feat, @@ -87,6 +89,7 @@ def validation_step(self, batch, batch_idx: int): # type: ignore Execute validation step. """ feat_static_cat = batch["feat_static_cat"] + feat_static_real = batch["feat_static_real"] past_target = batch["past_target"] past_observed_values = batch["past_observed_values"] past_time_feat = batch["past_time_feat"] @@ -97,6 +100,7 @@ def validation_step(self, batch, batch_idx: int): # type: ignore val_loss = self.model.loss( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_target=past_target, past_observed_values=past_observed_values, past_time_feat=past_time_feat, diff --git a/src/gluonts/torch/model/wavenet/module.py b/src/gluonts/torch/model/wavenet/module.py index c34c5203ab..4a6f3f95b1 100644 --- a/src/gluonts/torch/model/wavenet/module.py +++ b/src/gluonts/torch/model/wavenet/module.py @@ -143,6 +143,7 @@ def __init__( + num_feat_dynamic_real + num_feat_static_real + int(use_log_scale_feature) # the log(scale) + + 1 # for observed value indicator ) self.use_log_scale_feature = use_log_scale_feature @@ -217,6 +218,7 @@ def get_receptive_field(dilation_depth: int, num_stacks: int) -> int: def get_full_features( self, feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, past_observed_values: torch.Tensor, past_time_feat: torch.Tensor, future_time_feat: torch.Tensor, @@ -230,6 +232,8 @@ def get_full_features( ---------- feat_static_cat Static categorical features: (batch_size, num_cat_features) + feat_static_real + Static real-valued features: (batch_size, num_feat_static_real) past_observed_values Observed value indicator for the past target: (batch_size, receptive_field) @@ -256,6 +260,7 @@ def get_full_features( static_feat = torch.cat( [static_feat, torch.log(scale + 1.0)], dim=1 ) + static_feat = torch.cat([static_feat, feat_static_real], dim=1) repeated_static_feat = torch.repeat_interleave( static_feat[..., None], self.prediction_length + self.receptive_field, @@ -361,6 +366,7 @@ def base_net( def loss( self, feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, past_target: torch.Tensor, past_observed_values: torch.Tensor, past_time_feat: torch.Tensor, @@ -375,6 +381,8 @@ def loss( ---------- feat_static_cat Static categorical features: (batch_size, num_cat_features) + feat_static_real + Static real-valued features: (batch_size, num_feat_static_real) past_target Past target: (batch_size, receptive_field) past_observed_values @@ -401,6 +409,7 @@ def loss( full_target = torch.cat([past_target, future_target], dim=-1).long() full_features = self.get_full_features( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_observed_values=past_observed_values, past_time_feat=past_time_feat, future_time_feat=future_time_feat, @@ -457,6 +466,7 @@ def _initialize_conv_queues( def forward( self, feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, past_target: torch.Tensor, past_observed_values: torch.Tensor, past_time_feat: torch.Tensor, @@ -472,6 +482,8 @@ def forward( ---------- feat_static_cat Static categorical features: (batch_size, num_cat_features) + feat_static_real + Static real-valued features: (batch_size, num_feat_static_real) past_target Past target: (batch_size, receptive_field) past_observed_values @@ -508,6 +520,7 @@ def forward( past_target = past_target.long() full_features = self.get_full_features( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_observed_values=past_observed_values, past_time_feat=past_time_feat, future_time_feat=future_time_feat,