Skip to content

Commit

Permalink
add initial flow matching loss
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Dec 10, 2024
1 parent 70fe524 commit b9b5a45
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 78 deletions.
13 changes: 6 additions & 7 deletions src/gluonts/torch/model/seg_diff/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class SegDiffEstimator(PyTorchLightningEstimator):
Weight decay regularization parameter (default: ``1e-8``).
scaling
Scaling parameter can be "mean", "std" or None.
distr_output
Distribution to use to evaluate observations and sample predictions
(default: StudentTOutput()).
# distr_output
# Distribution to use to evaluate observations and sample predictions
# (default: StudentTOutput()).
batch_size
The size of the batches to be used for training (default: 32).
num_batches_per_epoch
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
lr: float = 1e-3,
weight_decay: float = 1e-8,
scaling: Optional[str] = "mean",
distr_output: Output = StudentTOutput(),
# distr_output: Output = StudentTOutput(),
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -141,7 +141,7 @@ def __init__(

self.lr = lr
self.weight_decay = weight_decay
self.distr_output = distr_output
# self.distr_output = distr_output
self.scaling = scaling
self.patch_len = patch_len
self.d_model = d_model
Expand Down Expand Up @@ -201,7 +201,7 @@ def create_lightning_module(self) -> pl.LightningModule:
"activation": self.activation,
"norm_first": self.norm_first,
"num_decoder_layers": self.num_decoder_layers,
"distr_output": self.distr_output,
# "distr_output": self.distr_output,
"scaling": self.scaling,
},
)
Expand Down Expand Up @@ -229,7 +229,6 @@ def _create_instance_splitter(
+ (
[FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else []
),
dummy_value=self.distr_output.value_in_support,
)

def create_training_data_loader(
Expand Down
Loading

0 comments on commit b9b5a45

Please sign in to comment.