kashif committed Oct 17, 2023
1 parent de9c852 commit 5ea1e2d
Showing 1 changed file with 5 additions and 1 deletion.
src/gluonts/torch/model/i_transformer/module.py: 6 changes (5 additions & 1 deletion)
@@ -29,7 +29,7 @@ class ITransformerModel(nn.Module):
     Parameters
     ----------
-    imput_size
+    input_size
         Number of multivariates to predict.
     prediction_length
         Number of time points to predict.
@@ -93,8 +93,10 @@ def __init__(
         else:
             self.scaler = NOPScaler(keepdim=True, dim=1)
 
+        # project each variate plus mean and std to d_model dimension
         self.emebdding = nn.Linear(context_length + 2, d_model)
 
+        # transformer encoder
         layer_norm_eps: float = 1e-5
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=d_model,
@@ -111,10 +113,12 @@ def __init__(
             encoder_layer, num_encoder_layers, encoder_norm
         )
 
+        # project each variate to prediction length number of latent variables
         self.projection = nn.Linear(
             d_model, prediction_length * d_model // nhead
         )
 
+        # project each prediction length latent to distribution parameters
         self.args_proj = self.distr_output.get_args_proj(d_model // nhead)
 
     def describe_inputs(self, batch_size=1) -> InputSpec:
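For readers skimming the diff, the following is a minimal standalone sketch of the layer stack that the new comments describe: an embedding over the context window plus its mean and std, a transformer encoder attending across variates, and a projection to per-step latents. The hyperparameter values, the batch_first setting, and the final reshape are illustrative assumptions rather than code taken from GluonTS, and the distribution head (self.args_proj / distr_output) is omitted.

import torch
import torch.nn as nn

# Illustrative values (assumptions, not GluonTS defaults)
batch_size, input_size = 8, 7            # 7 variates
context_length, prediction_length = 96, 24
d_model, nhead, num_encoder_layers = 64, 4, 2

# project each variate plus mean and std to d_model dimension
embedding = nn.Linear(context_length + 2, d_model)

# transformer encoder that attends across the variate dimension
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=nhead, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

# project each variate to prediction_length latents of size d_model // nhead
projection = nn.Linear(d_model, prediction_length * d_model // nhead)

# (batch, variates, context + 2): scaled context window plus its mean and std
x = torch.randn(batch_size, input_size, context_length + 2)
h = encoder(embedding(x))                # (8, 7, d_model)
latents = projection(h).reshape(
    batch_size, input_size, prediction_length, d_model // nhead
)                                        # (8, 7, 24, 16), fed to the distribution head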
