From 4a481ca7f657f0d1aad035a70f9b302b9a047732 Mon Sep 17 00:00:00 2001 From: Xingchen Ma Date: Mon, 15 Jan 2024 16:11:42 +0100 Subject: [PATCH] Add TiDE model (#3096) --- src/gluonts/torch/__init__.py | 2 + src/gluonts/torch/model/tide/__init__.py | 18 + src/gluonts/torch/model/tide/estimator.py | 412 ++++++++++++++++ .../torch/model/tide/lightning_module.py | 123 +++++ src/gluonts/torch/model/tide/module.py | 439 ++++++++++++++++++ test/torch/model/test_estimators.py | 19 + 6 files changed, 1013 insertions(+) create mode 100644 src/gluonts/torch/model/tide/__init__.py create mode 100644 src/gluonts/torch/model/tide/estimator.py create mode 100644 src/gluonts/torch/model/tide/lightning_module.py create mode 100644 src/gluonts/torch/model/tide/module.py diff --git a/src/gluonts/torch/__init__.py b/src/gluonts/torch/__init__.py index 3c3d8f9dda..ac572920e7 100644 --- a/src/gluonts/torch/__init__.py +++ b/src/gluonts/torch/__init__.py @@ -16,6 +16,7 @@ "PyTorchPredictor", "DeepNPTSEstimator", "DeepAREstimator", + "TiDEEstimator", "SimpleFeedForwardEstimator", "TemporalFusionTransformerEstimator", "WaveNetEstimator", @@ -28,6 +29,7 @@ from .model.predictor import PyTorchPredictor from .model.deep_npts import DeepNPTSEstimator from .model.deepar import DeepAREstimator +from .model.tide import TiDEEstimator from .model.simple_feedforward import SimpleFeedForwardEstimator from .model.tft import TemporalFusionTransformerEstimator from .model.wavenet import WaveNetEstimator diff --git a/src/gluonts/torch/model/tide/__init__.py b/src/gluonts/torch/model/tide/__init__.py new file mode 100644 index 0000000000..80b41e58c5 --- /dev/null +++ b/src/gluonts/torch/model/tide/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from .module import TiDEModel +from .lightning_module import TiDELightningModule +from .estimator import TiDEEstimator + +__all__ = ["TiDEModel", "TiDELightningModule", "TiDEEstimator"] diff --git a/src/gluonts/torch/model/tide/estimator.py b/src/gluonts/torch/model/tide/estimator.py new file mode 100644 index 0000000000..ab6c004c02 --- /dev/null +++ b/src/gluonts/torch/model/tide/estimator.py @@ -0,0 +1,412 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import Optional, Iterable, Dict, Any, List + +import torch +import lightning.pytorch as pl + +from gluonts.core.component import validated +from gluonts.dataset.common import Dataset +from gluonts.dataset.field_names import FieldName +from gluonts.dataset.loader import as_stacked_batches +from gluonts.itertools import Cyclic +from gluonts.model.forecast_generator import DistributionForecastGenerator +from gluonts.time_feature import ( + minute_of_hour, + hour_of_day, + day_of_month, + day_of_week, + day_of_year, + month_of_year, + week_of_year, +) +from gluonts.transform import ( + Transformation, + Chain, + RemoveFields, + SetField, + AsNumpyArray, + AddObservedValuesIndicator, + AddTimeFeatures, + VstackFeatures, + InstanceSplitter, + ValidationSplitSampler, + TestSplitSampler, + ExpectedNumInstanceSampler, + InstanceSampler, +) + +from gluonts.torch.model.estimator import PyTorchLightningEstimator +from gluonts.torch.model.predictor import PyTorchPredictor +from gluonts.torch.distributions import ( + DistributionOutput, + StudentTOutput, +) + +from .lightning_module import TiDELightningModule + +PREDICTION_INPUT_NAMES = [ + "feat_static_real", + "feat_static_cat", + "past_time_feat", + "past_target", + "past_observed_values", + "future_time_feat", +] + +TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ + "future_target", + "future_observed_values", +] + + +class TiDEEstimator(PyTorchLightningEstimator): + """ + An estimator training the TiDE model form the paper + https://arxiv.org/abs/2304.08424 extended for probabilistic forecasting. + + This class is uses the model defined in ``TiDEModel``, + and wraps it into a ``TiDELightningModule`` for training + purposes: training is performed using PyTorch Lightning's ``pl.Trainer`` + class. + + Parameters + ---------- + freq + Frequency of the data to train on and predict. + prediction_length + Length of the prediction horizon. + context_length + Number of time steps prior to prediction time that the model + takes as inputs (default: ``prediction_length``). + feat_proj_hidden_dim + Size of the feature projection layer (default: 4). + encoder_hidden_dim + Size of the dense encoder layer (default: 4). + decoder_hidden_dim + Size of the dense decoder layer (default: 4). + temporal_hidden_dim + Size of the temporal decoder layer (default: 4). + distr_hidden_dim + Size of the distribution projection layer (default: 4). + num_layers_encoder + Number of layers in dense encoder (default: 1). + num_layers_decoder + Number of layers in dense decoder (default: 1). + decoder_output_dim + Output size of dense decoder (default: 4). + dropout_rate + Dropout regularization parameter (default: 0.3). + num_feat_dynamic_proj + Output size of feature projection layer (default: 2). + num_feat_dynamic_real + Number of dynamic real features in the data (default: 0). + num_feat_static_real + Number of static real features in the data (default: 0). + num_feat_static_cat + Number of static categorical features in the data (default: 0). + cardinality + Number of values of each categorical feature. + This must be set if ``num_feat_static_cat > 0`` (default: None). + embedding_dimension + Dimension of the embeddings for categorical features + (default: ``[16 for cat in cardinality]``). + layer_norm + Enable layer normalization or not (default: False). + lr + Learning rate (default: ``1e-3``). + weight_decay + Weight decay regularization parameter (default: ``1e-8``). + patience + Patience parameter for learning rate scheduler (default: 10). + distr_output + Distribution to use to evaluate observations and sample predictions + (default: StudentTOutput()). + scaling + Which scaling method to use to scale the target values (default: mean). + batch_size + The size of the batches to be used for training (default: 32). + num_batches_per_epoch + Number of batches to be processed in each training epoch + (default: 50). + trainer_kwargs + Additional arguments to provide to ``pl.Trainer`` for construction. + train_sampler + Controls the sampling of windows during training. + validation_sampler + Controls the sampling of windows during validation. + """ + + @validated() + def __init__( + self, + freq: str, + prediction_length: int, + context_length: Optional[int] = None, + feat_proj_hidden_dim: Optional[int] = None, + encoder_hidden_dim: Optional[int] = None, + decoder_hidden_dim: Optional[int] = None, + temporal_hidden_dim: Optional[int] = None, + distr_hidden_dim: Optional[int] = None, + num_layers_encoder: Optional[int] = None, + num_layers_decoder: Optional[int] = None, + decoder_output_dim: Optional[int] = None, + dropout_rate: Optional[float] = None, + num_feat_dynamic_proj: Optional[int] = None, + num_feat_dynamic_real: int = 0, + num_feat_static_real: int = 0, + num_feat_static_cat: int = 0, + cardinality: Optional[List[int]] = None, + embedding_dimension: Optional[List[int]] = None, + layer_norm: bool = False, + lr: float = 1e-3, + weight_decay: float = 1e-8, + patience: int = 10, + scaling: Optional[str] = "mean", + distr_output: DistributionOutput = StudentTOutput(), + batch_size: int = 32, + num_batches_per_epoch: int = 50, + trainer_kwargs: Optional[Dict[str, Any]] = None, + train_sampler: Optional[InstanceSampler] = None, + validation_sampler: Optional[InstanceSampler] = None, + ) -> None: + default_trainer_kwargs = { + "max_epochs": 100, + "gradient_clip_val": 10.0, + } + if trainer_kwargs is not None: + default_trainer_kwargs.update(trainer_kwargs) + super().__init__(trainer_kwargs=default_trainer_kwargs) + + self.freq = freq + self.prediction_length = prediction_length + self.context_length = context_length or prediction_length + self.feat_proj_hidden_dim = feat_proj_hidden_dim or 4 + self.encoder_hidden_dim = encoder_hidden_dim or 4 + self.decoder_hidden_dim = decoder_hidden_dim or 4 + self.temporal_hidden_dim = temporal_hidden_dim or 4 + self.distr_hidden_dim = distr_hidden_dim or 4 + self.num_layers_encoder = num_layers_encoder or 1 + self.num_layers_decoder = num_layers_decoder or 1 + self.decoder_output_dim = decoder_output_dim or 4 + self.dropout_rate = dropout_rate or 0.3 + + self.num_feat_dynamic_proj = num_feat_dynamic_proj or 2 + self.num_feat_dynamic_real = num_feat_dynamic_real + self.num_feat_static_real = num_feat_static_real + self.num_feat_static_cat = num_feat_static_cat + self.cardinality = ( + cardinality if cardinality and num_feat_static_cat > 0 else [1] + ) + self.embedding_dimension = ( + embedding_dimension + if embedding_dimension is not None or self.cardinality is None + else [16 for cat in self.cardinality] + ) + + self.layer_norm = layer_norm + self.lr = lr + self.weight_decay = weight_decay + self.patience = patience + self.distr_output = distr_output + self.scaling = scaling + self.batch_size = batch_size + self.num_batches_per_epoch = num_batches_per_epoch + + self.train_sampler = train_sampler or ExpectedNumInstanceSampler( + num_instances=1.0, min_future=prediction_length + ) + self.validation_sampler = validation_sampler or ValidationSplitSampler( + min_future=prediction_length + ) + + def create_transformation(self) -> Transformation: + remove_field_names = [] + if self.num_feat_static_real == 0: + remove_field_names.append(FieldName.FEAT_STATIC_REAL) + if self.num_feat_dynamic_real == 0: + remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) + + return Chain( + [RemoveFields(field_names=remove_field_names)] + + ( + [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])] + if not self.num_feat_static_cat > 0 + else [] + ) + + ( + [ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, value=[0.0] + ) + ] + if not self.num_feat_static_real > 0 + else [] + ) + + [ + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=int, + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, + expected_ndim=1, + ), + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=1 + len(self.distr_output.event_shape), + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=[ + minute_of_hour, + hour_of_day, + day_of_month, + day_of_week, + day_of_year, + month_of_year, + week_of_year, + ], + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.num_feat_dynamic_real > 0 + else [] + ), + drop_inputs=False, + ), + AsNumpyArray(FieldName.FEAT_TIME, expected_ndim=2), + ] + ) + + def create_lightning_module(self) -> pl.LightningModule: + return TiDELightningModule( + lr=self.lr, + weight_decay=self.weight_decay, + patience=self.patience, + model_kwargs={ + "context_length": self.context_length, + "prediction_length": self.prediction_length, + "num_feat_dynamic_real": 7 + self.num_feat_dynamic_real, + "num_feat_dynamic_proj": self.num_feat_dynamic_proj, + "num_feat_static_real": max(1, self.num_feat_static_real), + "num_feat_static_cat": max(1, self.num_feat_static_cat), + "cardinality": self.cardinality, + "embedding_dimension": self.embedding_dimension, + "feat_proj_hidden_dim": self.feat_proj_hidden_dim, + "encoder_hidden_dim": self.encoder_hidden_dim, + "decoder_hidden_dim": self.decoder_hidden_dim, + "temporal_hidden_dim": self.temporal_hidden_dim, + "distr_hidden_dim": self.distr_hidden_dim, + "decoder_output_dim": self.decoder_output_dim, + "dropout_rate": self.dropout_rate, + "num_layers_encoder": self.num_layers_encoder, + "num_layers_decoder": self.num_layers_decoder, + "layer_norm": self.layer_norm, + "distr_output": self.distr_output, + "scaling": self.scaling, + }, + ) + + def _create_instance_splitter( + self, module: TiDELightningModule, mode: str + ): + assert mode in ["training", "validation", "test"] + + instance_sampler = { + "training": self.train_sampler, + "validation": self.validation_sampler, + "test": TestSplitSampler(), + }[mode] + + return InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=instance_sampler, + past_length=self.context_length, + future_length=self.prediction_length, + time_series_fields=[ + FieldName.FEAT_TIME, + FieldName.OBSERVED_VALUES, + ], + dummy_value=self.distr_output.value_in_support, + ) + + def create_training_data_loader( + self, + data: Dataset, + module: TiDELightningModule, + shuffle_buffer_length: Optional[int] = None, + **kwargs, + ) -> Iterable: + data = Cyclic(data).stream() + instances = self._create_instance_splitter(module, "training").apply( + data, is_train=True + ) + return as_stacked_batches( + instances, + batch_size=self.batch_size, + shuffle_buffer_length=shuffle_buffer_length, + field_names=TRAINING_INPUT_NAMES, + output_type=torch.tensor, + num_batches_per_epoch=self.num_batches_per_epoch, + ) + + def create_validation_data_loader( + self, + data: Dataset, + module: TiDELightningModule, + **kwargs, + ) -> Iterable: + instances = self._create_instance_splitter(module, "validation").apply( + data, is_train=True + ) + return as_stacked_batches( + instances, + batch_size=self.batch_size, + field_names=TRAINING_INPUT_NAMES, + output_type=torch.tensor, + ) + + def create_predictor( + self, + transformation: Transformation, + module, + ) -> PyTorchPredictor: + prediction_splitter = self._create_instance_splitter(module, "test") + + return PyTorchPredictor( + input_transform=transformation + prediction_splitter, + input_names=PREDICTION_INPUT_NAMES, + prediction_net=module, + forecast_generator=DistributionForecastGenerator( + self.distr_output + ), + batch_size=self.batch_size, + prediction_length=self.prediction_length, + device="auto", + ) diff --git a/src/gluonts/torch/model/tide/lightning_module.py b/src/gluonts/torch/model/tide/lightning_module.py new file mode 100644 index 0000000000..e23b2872d6 --- /dev/null +++ b/src/gluonts/torch/model/tide/lightning_module.py @@ -0,0 +1,123 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import lightning.pytorch as pl +import torch +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from gluonts.core.component import validated + +from gluonts.itertools import select +from gluonts.torch.model.lightning_util import has_validation_loop + + +from .module import TiDEModel + + +class TiDELightningModule(pl.LightningModule): + """ + A ``pl.LightningModule`` class that can be used to train a + ``TiDEModel`` with PyTorch Lightning. + + This is a thin layer around a (wrapped) ``TiDEModel`` object, + that exposes the methods to evaluate training and validation loss. + + Parameters + ---------- + model_kwargs + Keyword arguments to construct the ``TiDEModel`` to be trained. + lr + Learning rate. + weight_decay + Weight decay regularization parameter. + patience + Patience parameter for learning rate scheduler. + """ + + @validated() + def __init__( + self, + model_kwargs: dict, + lr: float = 1e-3, + weight_decay: float = 1e-8, + patience: int = 10, + ): + super().__init__() + self.save_hyperparameters() + self.model = TiDEModel(**model_kwargs) + self.lr = lr + self.weight_decay = weight_decay + self.patience = patience + self.inputs = self.model.describe_inputs() + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def training_step(self, batch, batch_idx: int): # type: ignore + """ + Execute training step. + """ + train_loss = self.model.loss( + **select(self.inputs, batch), + future_target=batch["future_target"], + future_observed_values=batch["future_observed_values"], + ).mean() + self.log( + "train_loss", + train_loss, + on_epoch=True, + on_step=False, + prog_bar=True, + ) + return train_loss + + def validation_step(self, batch, batch_idx: int): # type: ignore + """ + Execute validation step. + """ + val_loss = self.model.loss( + **select(self.inputs, batch), + future_target=batch["future_target"], + future_observed_values=batch["future_observed_values"], + ).mean() + + self.log( + "val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True + ) + return val_loss + + def configure_optimizers(self): + """ + Returns the optimizer to use. + """ + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + monitor = ( + "val_loss" if has_validation_loop(self.trainer) else "train_loss" + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau( + optimizer=optimizer, + mode="min", + factor=0.5, + patience=self.patience, + ), + "monitor": monitor, + }, + } diff --git a/src/gluonts/torch/model/tide/module.py b/src/gluonts/torch/model/tide/module.py new file mode 100644 index 0000000000..875a05f8b2 --- /dev/null +++ b/src/gluonts/torch/model/tide/module.py @@ -0,0 +1,439 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import List, Tuple + +import torch +from torch import nn + +from gluonts.core.component import validated +from gluonts.torch.modules.feature import FeatureEmbedder +from gluonts.model import Input, InputSpec +from gluonts.torch.distributions import DistributionOutput +from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler +from gluonts.torch.model.simple_feedforward import make_linear_layer +from gluonts.torch.util import weighted_average + + +class ResBlock(nn.Module): + def __init__( + self, + dim_in: int, + dim_hidden: int, + dim_out: int, + dropout_rate: float, + layer_norm: bool, + ): + super().__init__() + + self.fc = nn.Sequential( + make_linear_layer(dim_in, dim_hidden), + nn.ReLU(), + make_linear_layer(dim_hidden, dim_out), + nn.Dropout(p=dropout_rate), + ) + self.skip = make_linear_layer(dim_in, dim_out) + if layer_norm: + self.ln = nn.LayerNorm(dim_out) + self.layer_norm = layer_norm + + def forward(self, x): + if self.layer_norm: + return self.ln(self.fc(x) + self.skip(x)) + return self.fc(x) + self.skip(x) + + +class FeatureProjection(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + dropout_rate: float, + layer_norm: bool, + ): + super().__init__() + + self.proj = ResBlock( + input_dim, + hidden_dim, + output_dim, + dropout_rate, + layer_norm, + ) + + def forward(self, x): + return self.proj(x) + + +class DenseEncoder(nn.Module): + def __init__( + self, + num_layers: int, + input_dim: int, + hidden_dim: int, + dropout_rate: float, + layer_norm: bool, + ): + super().__init__() + + layers = [] + layers.append( + ResBlock( + input_dim, hidden_dim, hidden_dim, dropout_rate, layer_norm + ) + ) + for i in range(num_layers - 1): + layers.append( + ResBlock( + hidden_dim, + hidden_dim, + hidden_dim, + dropout_rate, + layer_norm, + ) + ) + self.encoder = nn.Sequential(*layers) + + def forward(self, x): + return self.encoder(x) + + +class DenseDecoder(nn.Module): + def __init__( + self, + num_layers: int, + hidden_dim: int, + output_dim: int, + dropout_rate: float, + layer_norm: bool, + ): + super().__init__() + + layers = [] + for i in range(num_layers - 1): + layers.append( + ResBlock( + hidden_dim, + hidden_dim, + hidden_dim, + dropout_rate, + layer_norm, + ) + ) + layers.append( + ResBlock( + hidden_dim, + hidden_dim, + output_dim, + dropout_rate, + layer_norm, + ) + ) + + self.decoder = nn.Sequential(*layers) + + def forward(self, x): + return self.decoder(x) + + +class TemporalDecoder(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + dropout_rate: float, + layer_norm: bool, + ): + super().__init__() + + self.temporal_decoder = ResBlock( + input_dim, + hidden_dim, + output_dim, + dropout_rate, + layer_norm, + ) + + def forward(self, x): + return self.temporal_decoder(x) + + +class TiDEModel(nn.Module): + """ + + Parameters + ---------- + context_length + Number of time steps prior to prediction time that the model + takes as inputs. + prediction_length + Length of the prediction horizon. + num_feat_dynamic_proj + Output size of feature projection layer. + num_feat_dynamic_real + Number of dynamic real features in the data. + num_feat_static_real + Number of static real features in the data. + num_feat_static_cat + Number of static categorical features in the data. + cardinality + Number of values of each categorical feature. + This must be set if ``num_feat_static_cat > 0``. + embedding_dimension + Dimension of the embeddings for categorical features. + feat_proj_hidden_dim + Size of the feature projection layer. + encoder_hidden_dim + Size of the dense encoder layer. + decoder_hidden_dim + Size of the dense decoder layer. + temporal_hidden_dim + Size of the temporal decoder layer. + distr_hidden_dim + Size of the distribution projection layer. + decoder_output_dim + Output size of dense decoder. + dropout_rate + Dropout regularization parameter. + num_layers_encoder + Number of layers in dense encoder. + num_layers_decoder + Number of layers in dense decoder. + layer_norm + Enable layer normalization or not. + distr_output + Distribution to use to evaluate observations and sample predictions. + scaling + Which scaling method to use to scale the target values. + + + """ + + @validated() + def __init__( + self, + context_length: int, + prediction_length: int, + num_feat_dynamic_real: int, + num_feat_dynamic_proj: int, + num_feat_static_real: int, + num_feat_static_cat: int, + cardinality: List[int], + embedding_dimension: List[int], + feat_proj_hidden_dim: int, + encoder_hidden_dim: int, + decoder_hidden_dim: int, + temporal_hidden_dim: int, + distr_hidden_dim: int, + decoder_output_dim: int, + dropout_rate: float, + num_layers_encoder: int, + num_layers_decoder: int, + layer_norm: bool, + distr_output: DistributionOutput, + scaling: str, + ) -> None: + super().__init__() + + assert context_length > 0 + assert prediction_length > 0 + assert num_feat_dynamic_real > 0 + assert num_feat_static_real > 0 + assert num_feat_static_cat > 0 + assert len(cardinality) == num_feat_static_cat + assert len(embedding_dimension) == num_feat_static_cat + + self.context_length = context_length + self.prediction_length = prediction_length + self.num_feat_dynamic_real = num_feat_dynamic_real + self.num_feat_dynamic_proj = num_feat_dynamic_proj + self.num_feat_static_real = num_feat_static_real + self.num_feat_static_cat = num_feat_static_cat + self.embedding_dimension = embedding_dimension + self.feat_proj_hidden_dim = feat_proj_hidden_dim + self.num_layers_encoder = num_layers_encoder + self.num_layers_decoder = num_layers_decoder + self.encoder_hidden_dim = encoder_hidden_dim + self.decoder_hidden_dim = decoder_hidden_dim + self.temporal_hidden_dim = temporal_hidden_dim + self.distr_hidden_dim = distr_hidden_dim + self.decoder_output_dim = decoder_output_dim + self.dropout_rate = dropout_rate + + self.proj_flatten_dim = ( + context_length + prediction_length + ) * num_feat_dynamic_proj + encoder_input_dim = ( + context_length + + num_feat_static_real + + sum(self.embedding_dimension) + + self.proj_flatten_dim + ) + self.temporal_decoder_input_dim = ( + decoder_output_dim + num_feat_dynamic_proj + ) + + self.embedder = FeatureEmbedder( + cardinalities=cardinality, + embedding_dims=self.embedding_dimension, + ) + + self.feat_proj = FeatureProjection( + num_feat_dynamic_real, + feat_proj_hidden_dim, + num_feat_dynamic_proj, + dropout_rate, + layer_norm, + ) + self.dense_encoder = DenseEncoder( + num_layers_encoder, + encoder_input_dim, + encoder_hidden_dim, + dropout_rate, + layer_norm, + ) + self.dense_decoder = DenseDecoder( + num_layers_encoder, + decoder_hidden_dim, + prediction_length * decoder_output_dim, + dropout_rate, + layer_norm, + ) + self.temporal_decoder = TemporalDecoder( + self.temporal_decoder_input_dim, + temporal_hidden_dim, + distr_hidden_dim, + dropout_rate, + layer_norm, + ) + self.loopback_skip = make_linear_layer( + self.context_length, self.prediction_length * distr_hidden_dim + ) + + self.distr_output = distr_output + if scaling == "mean": + self.scaler = MeanScaler(keepdim=True) + elif scaling == "std": + self.scaler = StdScaler(keepdim=True) + else: + self.scaler = NOPScaler(keepdim=True) + + self.args_proj = self.distr_output.get_args_proj(self.distr_hidden_dim) + + def describe_inputs(self, batch_size=1) -> InputSpec: + return InputSpec( + { + "past_target": Input( + shape=(batch_size, self.context_length), dtype=torch.float + ), + "past_observed_values": Input( + shape=(batch_size, self.context_length), dtype=torch.float + ), + "past_time_feat": Input( + shape=( + batch_size, + self.context_length, + self.num_feat_dynamic_real, + ), + dtype=torch.float, + ), + "feat_static_real": Input( + shape=(batch_size, self.num_feat_static_real), + dtype=torch.float, + ), + "feat_static_cat": Input( + shape=(batch_size, self.num_feat_static_cat), + dtype=torch.long, + ), + "future_time_feat": Input( + shape=( + batch_size, + self.prediction_length, + self.num_feat_dynamic_real, + ), + dtype=torch.float, + ), + }, + torch.zeros, + ) + + def forward( + self, + feat_static_real: torch.Tensor, + feat_static_cat: torch.Tensor, + past_time_feat: torch.Tensor, + past_target: torch.Tensor, + past_observed_values: torch.Tensor, + future_time_feat: torch.Tensor, + ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: + past_target_scaled, loc, scale = self.scaler( + past_target, past_observed_values + ) + + embedded_cat = self.embedder(feat_static_cat) + time_feat = torch.cat((past_time_feat, future_time_feat), dim=-2) + proj = self.feat_proj(time_feat) + proj_future = proj[..., -self.prediction_length :, :] + proj_flatten = proj.view(-1, self.proj_flatten_dim) + encoder_input = torch.cat( + (past_target_scaled, feat_static_real, embedded_cat, proj_flatten), + dim=-1, + ) + encoder_output = self.dense_encoder(encoder_input) + decoder_output = self.dense_decoder(encoder_output) + + temporal_decoder_input = torch.cat( + ( + decoder_output.view( + -1, self.prediction_length, self.decoder_output_dim + ), + proj_future, + ), + dim=-1, + ) + out = self.temporal_decoder(temporal_decoder_input) + out = ( + self.loopback_skip(past_target_scaled).view( + -1, self.prediction_length, self.distr_hidden_dim + ) + + out + ) + + distr_args = self.args_proj(out) + return distr_args, loc, scale + + def loss( + self, + feat_static_real: torch.Tensor, + feat_static_cat: torch.Tensor, + past_time_feat: torch.Tensor, + past_target: torch.Tensor, + past_observed_values: torch.Tensor, + future_time_feat: torch.Tensor, + future_target: torch.Tensor, + future_observed_values: torch.Tensor, + ): + distr_args, loc, scale = self( + feat_static_real=feat_static_real, + feat_static_cat=feat_static_cat, + past_time_feat=past_time_feat, + past_target=past_target, + past_observed_values=past_observed_values, + future_time_feat=future_time_feat, + ) + loss = self.distr_output.loss( + target=future_target, distr_args=distr_args, loc=loc, scale=scale + ) + return weighted_average(loss, weights=future_observed_values, dim=-1) diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index 79bf9b8a7a..b2317549c4 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -34,6 +34,7 @@ from gluonts.torch.model.simple_feedforward import SimpleFeedForwardEstimator from gluonts.torch.model.d_linear import DLinearEstimator from gluonts.torch.model.patch_tst import PatchTSTEstimator +from gluonts.torch.model.tide import TiDEEstimator from gluonts.torch.model.lag_tst import LagTSTEstimator from gluonts.torch.model.tft import TemporalFusionTransformerEstimator from gluonts.torch.model.wavenet import WaveNetEstimator @@ -139,6 +140,13 @@ num_batches_per_epoch=3, trainer_kwargs=dict(max_epochs=2), ), + lambda dataset: TiDEEstimator( + freq=dataset.metadata.freq, + prediction_length=dataset.metadata.prediction_length, + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), lambda dataset: WaveNetEstimator( freq=dataset.metadata.freq, prediction_length=dataset.metadata.prediction_length, @@ -207,6 +215,17 @@ def test_estimator_constant_dataset( trainer_kwargs=dict(max_epochs=2), distr_output=ImplicitQuantileNetworkOutput(), ), + lambda freq, prediction_length: TiDEEstimator( + freq=freq, + prediction_length=prediction_length, + batch_size=4, + num_batches_per_epoch=3, + num_feat_dynamic_real=3, + num_feat_static_real=1, + num_feat_static_cat=2, + cardinality=[2, 2], + trainer_kwargs=dict(max_epochs=2), + ), lambda freq, prediction_length: MQF2MultiHorizonEstimator( freq=freq, prediction_length=prediction_length,