
Commit

add date-time features as patches
kashif committed Jan 12, 2025
1 parent c8685f6 commit 27071bd
Showing 2 changed files with 114 additions and 65 deletions.
84 changes: 46 additions & 38 deletions src/gluonts/torch/model/seg_diff/estimator.py
@@ -21,6 +21,7 @@
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.time_feature import time_features_from_frequency_str
from gluonts.transform import (
Transformation,
AddObservedValuesIndicator,
@@ -30,18 +31,25 @@
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
RenameFields,
AddTimeFeatures,
AddAgeFeature,
VstackFeatures,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor

from .lightning_module import SegDiffLightningModule

PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"]
PREDICTION_INPUT_NAMES = [
f"past_{FieldName.TARGET}",
f"past_{FieldName.OBSERVED_VALUES}",
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]

TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
"future_target",
"future_observed_values",
f"future_{FieldName.TARGET}",
f"future_{FieldName.OBSERVED_VALUES}",
]


@@ -163,6 +171,8 @@ def __init__(
min_future=self.prediction_length
)

self.time_features = time_features_from_frequency_str("s")

def create_transformation(self) -> Transformation:
return (
SelectFields(
@@ -179,11 +189,32 @@ def create_transformation(self) -> Transformation:
),
allow_missing=True,
)
+ RenameFields({FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME})
+ AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
)
+ AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
)
+ AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
)
+ VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.num_feat_dynamic_real > 0
else []
),
)
)

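For readers following along, here is a minimal standalone sketch (not part of the diff) of what the new transformation chain produces for a single toy entry. The start timestamp, series length, and prediction length are illustrative, and the optional stacking of a user-supplied FEAT_DYNAMIC_REAL field is omitted:

# Standalone sketch: output of the new time/age-feature chain on a toy entry.
import numpy as np
import pandas as pd
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import time_features_from_frequency_str
from gluonts.transform import (
    AddAgeFeature,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    VstackFeatures,
)

prediction_length = 8
time_features = time_features_from_frequency_str("s")

chain = (
    AddTimeFeatures(
        start_field=FieldName.START,
        target_field=FieldName.TARGET,
        output_field=FieldName.FEAT_TIME,
        time_features=time_features,
        pred_length=prediction_length,
    )
    + AddAgeFeature(
        target_field=FieldName.TARGET,
        output_field=FieldName.FEAT_AGE,
        pred_length=prediction_length,
        log_scale=True,
    )
    + AddObservedValuesIndicator(
        target_field=FieldName.TARGET,
        output_field=FieldName.OBSERVED_VALUES,
    )
    + VstackFeatures(
        output_field=FieldName.FEAT_TIME,
        input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE],
    )
)

entry = {
    FieldName.START: pd.Period("2025-01-01 00:00:00", freq="s"),
    FieldName.TARGET: np.random.randn(64).astype(np.float32),
}
out = next(chain([entry], is_train=False))
# At prediction time FEAT_TIME covers the target plus the horizon:
# shape (len(time_features) + 1, 64 + prediction_length), with the
# log-scaled age feature stacked as the last row.
print(out[FieldName.FEAT_TIME].shape)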
def create_lightning_module(self) -> pl.LightningModule:
@@ -197,14 +228,15 @@ def create_lightning_module(self) -> pl.LightningModule:
"d_model": self.d_model,
"nhead": self.nhead,
"dim_feedforward": self.dim_feedforward,
"num_feat_dynamic_real": self.num_feat_dynamic_real,
"dropout": self.dropout,
"activation": self.activation,
"norm_first": self.norm_first,
"num_decoder_layers": self.num_decoder_layers,
# "distr_output": self.distr_output,
"scaling": self.scaling,
"n_steps": self.n_steps,
"num_feat_dynamic_real": len(self.time_features)
+ 1
+ self.num_feat_dynamic_real,
},
)

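The feature-count arithmetic handed to the module above can be spelled out as follows (a sketch; the number of calendar features returned for a per-second frequency depends on the gluonts version, and the user-supplied count is an illustrative value):

# Sketch of the num_feat_dynamic_real bookkeeping passed to the module.
from gluonts.time_feature import time_features_from_frequency_str

time_features = time_features_from_frequency_str("s")
user_dynamic_real = 0  # the estimator's own num_feat_dynamic_real (illustrative)

# calendar features + one log-scaled age feature + user-supplied dynamic reals
module_num_feat_dynamic_real = len(time_features) + 1 + user_dynamic_real
print(module_num_feat_dynamic_real)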
@@ -227,10 +259,10 @@ def _create_instance_splitter(
instance_sampler=instance_sampler,
past_length=self.context_length,
future_length=self.prediction_length,
time_series_fields=[FieldName.OBSERVED_VALUES]
+ (
[FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else []
),
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
)

def create_training_data_loader(
@@ -248,15 +280,7 @@ def create_training_data_loader(
instances,
batch_size=self.batch_size,
shuffle_buffer_length=shuffle_buffer_length,
field_names=TRAINING_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
field_names=TRAINING_INPUT_NAMES,
output_type=torch.tensor,
num_batches_per_epoch=self.num_batches_per_epoch,
)
@@ -270,15 +294,7 @@ def create_validation_data_loader(
return as_stacked_batches(
instances,
batch_size=self.batch_size,
field_names=TRAINING_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
field_names=TRAINING_INPUT_NAMES,
output_type=torch.tensor,
)

@@ -289,15 +305,7 @@ def create_predictor(

return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
95 changes: 68 additions & 27 deletions src/gluonts/torch/model/seg_diff/module.py
@@ -56,25 +56,33 @@ def __init__(self, patch_size: int, patch_stride: int) -> None:
self.patch_stride = patch_stride

def forward(self, x: torch.Tensor) -> torch.Tensor:
length = x.shape[-1]
# Ensure input is at least 3D
if x.ndim == 1:
x = x.unsqueeze(0).unsqueeze(-1) # [L] -> [1, L, 1]
elif x.ndim == 2:
x = x.unsqueeze(-1) # [B, L] -> [B, L, 1]

if length % self.patch_size != 0:
padding_size = (
*x.shape[:-1],
self.patch_size - (length % self.patch_size),
)
batch_size, seq_len, feat_dim = x.shape

# Handle padding if needed
if seq_len % self.patch_size != 0:
padding_size = self.patch_size - (seq_len % self.patch_size)
padding = torch.full(
size=padding_size,
size=(batch_size, padding_size, feat_dim),
fill_value=torch.nan,
dtype=x.dtype,
device=x.device,
)
x = torch.concat((padding, x), dim=-1)
x = torch.cat((padding, x), dim=1)
seq_len = x.shape[1]

# Unfold along sequence dimension
x = x.unfold(
dimension=-1, size=self.patch_size, step=self.patch_stride
)
return x
dimension=1, size=self.patch_size, step=self.patch_stride
) # [B, num_patches, patch_size, Feature]

# Reshape to [B, num_patches, patch_size * Feature]
return x.reshape(batch_size, -1, self.patch_size * feat_dim)


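Outside the diff, a short sketch of the padding/unfold/reshape mechanics that the rewritten forward performs, without the class wrapper (the toy sizes are illustrative):

# Toy illustration of the patching mechanics above.
import torch

patch_size, patch_stride = 16, 16
x = torch.randn(4, 100, 3)                        # [batch, time, features]

# left-pad with NaNs so the time axis is a multiple of patch_size (100 -> 112)
pad = (-x.shape[1]) % patch_size
padding = torch.full((x.shape[0], pad, x.shape[2]), torch.nan)
x = torch.cat((padding, x), dim=1)

# unfold along time: [4, 112, 3] -> [4, 7, 3, 16], then flatten each patch
patches = x.unfold(dimension=1, size=patch_size, step=patch_stride)
patches = patches.reshape(x.shape[0], -1, patch_size * x.shape[2])
print(patches.shape)                              # torch.Size([4, 7, 48])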
class ResidualBlock(nn.Module):
@@ -531,15 +539,15 @@ def params_from_decoder_output(
)

# scale the input
past_target_scaled, loc, scale = self.scaler(
target_scaled, loc, scale = self.scaler(
past_target, past_observed_values
)
patched_past_target = self.patch(past_target_scaled)
patched_target = self.patch(target_scaled)

# do patching for time features as well
if self.num_feat_dynamic_real > 0:
time_feat = torch.cat((past_time_feat, future_time_feat), dim=1)
patched_time_feat = self.patch(time_feat)
# if self.num_feat_dynamic_real > 0:
# time_feat = torch.cat((past_time_feat, future_time_feat), dim=1)
# patched_time_feat = self.patch(time_feat)

# add loc and scale to past_target_patches as additional features
log_abs_loc = loc.sign() * loc.abs().log1p()
@@ -548,14 +556,32 @@
expanded_static_feat = unsqueeze_expand(
torch.cat([log_abs_loc, log_scale], dim=-1),
dim=1,
size=patched_past_target.shape[1],
size=patched_target.shape[1],
)
inputs = torch.cat((patched_past_target, expanded_static_feat), dim=-1)
inputs = torch.cat((patched_target, expanded_static_feat), dim=-1)

if self.num_feat_dynamic_real > 0:
inputs = torch.cat((inputs, patched_time_feat), dim=-1)
if future_time_feat is not None:
past_time_feat = torch.cat(
(past_time_feat, future_time_feat), dim=1
)
patched_time_feat = self.patch(past_time_feat)[:, 1:, :]

if future_target is not None:
# shift the time feature patches by one and pad the very last patch with zeros:
patched_time_feat = torch.cat(
(
patched_time_feat,
torch.zeros_like(patched_time_feat[:, -1, :]).unsqueeze(1),
),
dim=1,
)

# if self.num_feat_dynamic_real > 0:
# inputs = torch.cat((inputs, patched_time_feat), dim=-1)
# project the input embeddings to the model dimension
input_embeddings = self.input_patch_embedding(inputs)
input_embeddings = self.input_patch_embedding(
torch.cat((inputs, patched_time_feat), dim=-1)
)

# causal mask for the transformer decoder
mask = nn.Transformer.generate_square_subsequent_mask(
@@ -580,11 +606,12 @@ def loss(
flow_cond, loc, scale = self.params_from_decoder_output(
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat,
future_target=future_target,
future_observed_values=future_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat,
)

# Get patches for target
target = self.patch(
(torch.cat((past_target, future_target), dim=1) - loc) / scale
Expand All @@ -593,7 +620,7 @@ def loss(
# Flow matching loss
x_1 = target[:, 1:, :] # Target patches
x_0 = torch.randn_like(x_1) # Random noise source distribution
# x_0 = target[:, :-1, :]
# x_0 = target[:, :-1, :] + torch.randn_like(target[:, :-1, :]) * 0.7

return self.flow.compute_loss(
x_1=x_1, x_0=x_0, cond=flow_cond[:, :-1, :]
@@ -667,6 +694,9 @@ def forward(
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat[:, : self.patch_len]
if future_time_feat is not None
else None,
)

# Initialize samples for each batch
@@ -678,6 +708,10 @@
self.patch_len,
device=past_target.device,
)
# add it to the very last patch of past_target of size self.patch_len
# x = x + (
# (past_target[:, -self.patch_len :] - loc) / scale
# ).repeat_interleave(num_parallel_samples, dim=0)

# # the very last patch from past_target
# x = (
@@ -735,6 +769,15 @@ def forward(
future_samples_flat = future_samples_flat.view(
batch_size * num_parallel_samples, -1
)
# Calculate the current offset for future_time_feat
time_feat_offset = min(
total_samples + self.patch_len, self.prediction_length
)
current_future_time_feat = (
repeat_future_time_feat[:, total_samples:time_feat_offset]
if future_time_feat is not None
else None
)

flow_cond, loc, scale = self.params_from_decoder_output(
past_target=repeat_past_target,
@@ -744,9 +787,7 @@
else None,
future_target=future_samples_flat,
future_observed_values=torch.ones_like(future_samples_flat),
future_time_feat=repeat_future_time_feat
if future_time_feat is not None
else None,
future_time_feat=current_future_time_feat,
)

# Sample new noise for next patch

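Finally, a toy illustration (not from the commit) of the shift-by-one alignment introduced by the time-feature patching above: the calendar features covering past plus future are patched exactly like the target, the first time-feature patch is dropped, and each target patch ends up paired with the calendar features of the window the decoder is asked to generate next. The trailing zero patch that the diff appends at training time to keep patch counts aligned is omitted here.

# Toy illustration of the shift-by-one time-feature alignment (sizes illustrative).
import torch

P = 4                                           # patch length
target = torch.randn(1, 12, 1)                  # 3 target patches
time_feat = torch.randn(1, 16, 1)               # past + future calendar grid, 4 patches


def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    b, t, f = x.shape
    return x.unfold(1, p, p).reshape(b, -1, p * f)


target_patches = patchify(target, P)             # [1, 3, 4]
time_patches = patchify(time_feat, P)[:, 1:, :]  # drop the first patch -> [1, 3, 4]

# target patch i is now paired with the calendar features of window i + 1,
# i.e. the window the decoder is conditioned to produce next
decoder_inputs = torch.cat((target_patches, time_patches), dim=-1)  # [1, 3, 8]
print(decoder_inputs.shape)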