Add flag to toggle log scale feature for WaveNet model (#3012)
shchur authored Oct 12, 2023
1 parent 0bc0a45 commit 9be6dfd
Showing 2 changed files with 15 additions and 3 deletions.
src/gluonts/torch/model/wavenet/estimator.py (7 additions, 0 deletions)
@@ -85,6 +85,7 @@ def __init__(
         cardinality: List[int] = [1],
         seasonality: Optional[int] = None,
         embedding_dimension: int = 5,
+        use_log_scale_feature: bool = True,
         time_features: Optional[List[TimeFeature]] = None,
         lr: float = 1e-3,
         weight_decay: float = 1e-8,
@@ -140,6 +141,10 @@ def __init__(
         embedding_dimension, optional
             The dimension of the embeddings for categorical features,
             by default 5
+        use_log_scale_feature, optional
+            If True, logarithm of the scale of the past data will be used as an
+            additional static feature,
+            by default True
         time_features, optional
             List of time features, from :py:mod:`gluonts.time_feature`,
             by default None
@@ -187,6 +192,7 @@ def __init__(
         self.num_residual_channels = num_residual_channels
         self.num_skip_channels = num_skip_channels
         self.num_stacks = num_stacks
+        self.use_log_scale_feature = use_log_scale_feature
         self.time_features = unwrap_or(
             time_features, time_features_from_frequency_str(freq)
         )
@@ -382,6 +388,7 @@ def create_lightning_module(self) -> pl.LightningModule:
                 pred_length=self.prediction_length,
                 num_parallel_samples=self.num_parallel_samples,
                 temperature=self.temperature,
+                use_log_scale_feature=self.use_log_scale_feature,
             ),
         )

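As an aside (not part of the commit): a minimal usage sketch of the new flag, assuming the estimator is exposed as gluonts.torch.model.wavenet.WaveNetEstimator and accepts the usual freq / prediction_length arguments; the other values here are hypothetical.

from gluonts.torch.model.wavenet import WaveNetEstimator

# Only use_log_scale_feature comes from this commit; the rest is a made-up setup.
estimator = WaveNetEstimator(
    freq="H",
    prediction_length=24,
    use_log_scale_feature=False,  # drop log(scale) from the static features
)

# The default (True) keeps the pre-existing behaviour of feeding log(scale)
# to the network as an extra static feature.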
src/gluonts/torch/model/wavenet/module.py (8 additions, 3 deletions)
@@ -130,6 +130,7 @@ def __init__(
         embedding_dimension: int = 5,
         num_parallel_samples: int = 100,
         temperature: float = 1.0,
+        use_log_scale_feature: bool = True,
     ):
         super().__init__()

@@ -141,8 +142,9 @@ def __init__(
             embedding_dimension * len(cardinality)
             + num_feat_dynamic_real
             + num_feat_static_real
-            + 1  # the log(scale)
+            + int(use_log_scale_feature)  # the log(scale)
         )
+        self.use_log_scale_feature = use_log_scale_feature

         # 1 extra bin to accounts for extreme values
         self.n_bins = len(bin_values) + 1
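As an aside on the hunk above (not part of the diff): a small, self-contained sketch of how the conditioning feature count changes with the flag, using made-up dimensions for everything except the int(use_log_scale_feature) term.

embedding_dimension = 5
cardinality = [3, 10]          # two hypothetical categorical features
num_feat_dynamic_real = 4
num_feat_static_real = 1

for use_log_scale_feature in (True, False):
    num_features = (
        embedding_dimension * len(cardinality)
        + num_feat_dynamic_real
        + num_feat_static_real
        + int(use_log_scale_feature)  # 1 if log(scale) is kept, 0 otherwise
    )
    print(use_log_scale_feature, num_features)  # True -> 16, False -> 15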
@@ -249,8 +251,11 @@ def get_full_features(
             network.
             Shape: (batch_size, num_features, receptive_field + pred_length)
         """
-        embedded_cat = self.feature_embedder(feat_static_cat.long())
-        static_feat = torch.cat([embedded_cat, torch.log(scale + 1.0)], dim=1)
+        static_feat = self.feature_embedder(feat_static_cat.long())
+        if self.use_log_scale_feature:
+            static_feat = torch.cat(
+                [static_feat, torch.log(scale + 1.0)], dim=1
+            )
         repeated_static_feat = torch.repeat_interleave(
             static_feat[..., None],
             self.prediction_length + self.receptive_field,
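As an aside (not part of the commit): a standalone sketch of the changed branch in get_full_features, with hypothetical tensor sizes standing in for the embedder output and the per-series scale.

import torch

batch_size, embed_dim, pred_length, receptive_field = 2, 10, 24, 64
embedded_cat = torch.randn(batch_size, embed_dim)  # stands in for self.feature_embedder(...)
scale = torch.rand(batch_size, 1) * 100 + 1.0      # scale of the past data, one value per series
use_log_scale_feature = True

static_feat = embedded_cat
if use_log_scale_feature:
    # Append log(scale + 1) as one extra static feature column.
    static_feat = torch.cat([static_feat, torch.log(scale + 1.0)], dim=1)

# Broadcast the static features over the time axis, as the module does next.
repeated_static_feat = torch.repeat_interleave(
    static_feat[..., None],
    pred_length + receptive_field,
    dim=-1,
)
print(static_feat.shape)           # torch.Size([2, 11]) with the flag, [2, 10] without
print(repeated_static_feat.shape)  # torch.Size([2, 11, 88])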
