Skip to content

Commit

Permalink
Backports for v0.12.8 (#2850)
Browse files Browse the repository at this point in the history
* Fix pd.Period serialization (#2827)

Co-authored-by: Abdul Fatir Ansari <[email protected]>

* Remove second call to create_lightning_module on torch estimator (#2834)

Co-authored-by: Pablo Vicente <[email protected]>

* Fix torch DeepAREstimator in case `context_length=1` (#2841)

* Ingore hidden files in FileDataset by default. (#2847)

* fix black

* fix typo

* Remove .to_timestamp() to fix interval plotting (#2800)

Co-authored-by: Abdul Fatir Ansari <[email protected]>
Co-authored-by: Jasper <[email protected]>

---------

Co-authored-by: Abdul Fatir <[email protected]>
Co-authored-by: Abdul Fatir Ansari <[email protected]>
Co-authored-by: Pablo Vicente <[email protected]>
Co-authored-by: Pablo Vicente <[email protected]>
Co-authored-by: Jasper <[email protected]>
  • Loading branch information
6 people authored May 11, 2023
1 parent 74e8fa9 commit 28ed1d0
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/core/serde/pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def encode_pd_period(v: pd.Period) -> Any:
"""
return {
"__kind__": Kind.Instance,
"class": "pandas.Timestamp",
"class": "pandas.Period",
"args": encode([str(v)]),
"kwargs": {"freq": v.freqstr},
}
Expand Down
4 changes: 4 additions & 0 deletions src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def FileDataset(
pattern="*",
levels=2,
translate=None,
ignore_hidden=True,
) -> Dataset:
path = Path(path)

Expand All @@ -168,6 +169,9 @@ def FileDataset(
assert path.is_file()
paths = [path]

if ignore_hidden:
paths = [path for path in paths if not path.name.startswith(".")]

loaders = []
for subpath in paths:
if loader_class is None:
Expand Down
12 changes: 5 additions & 7 deletions src/gluonts/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,12 @@ def alpha_for_percentile(p):
i_p50 = len(percentiles_sorted) // 2

p50_data = ps_data[i_p50]
p50_series = pd.Series(data=p50_data, index=self.index.to_timestamp())
p50_series = pd.Series(data=p50_data, index=self.index)
p50_series.plot(color=color, ls="-", label=f"{label_prefix}median")

if show_mean:
mean_data = np.mean(self._sorted_samples, axis=0)
pd.Series(data=mean_data, index=self.index.to_timestamp()).plot(
pd.Series(data=mean_data, index=self.index).plot(
color=color,
ls=":",
label=f"{label_prefix}mean",
Expand All @@ -355,7 +355,7 @@ def alpha_for_percentile(p):
ptile = percentiles_sorted[i]
alpha = alpha_for_percentile(ptile)
plt.fill_between(
self.index.to_timestamp(),
self.index,
ps_data[i],
ps_data[-i - 1],
facecolor=color,
Expand All @@ -366,9 +366,7 @@ def alpha_for_percentile(p):
)
# Hack to create labels for the error intervals. Doesn't actually
# plot anything, because we only pass a single data point
pd.Series(
data=p50_data[:1], index=self.index.to_timestamp()[:1]
).plot(
pd.Series(data=p50_data[:1], index=self.index[:1]).plot(
color=color,
alpha=alpha,
linewidth=10,
Expand Down Expand Up @@ -718,7 +716,7 @@ def plot(self, label=None, output_file=None, keys=None, *args, **kwargs):
keys = self.forecast_keys

for k, v in zip(keys, self.forecast_array):
pd.Series(data=v, index=self.index.to_timestamp()).plot(
pd.Series(data=v, index=self.index).plot(
label=f"{label_prefix}q{k}",
*args,
**kwargs,
Expand Down
17 changes: 11 additions & 6 deletions src/gluonts/torch/model/deepar/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from gluonts.torch.util import (
lagged_sequence_values,
repeat_along_dim,
take_last,
unsqueeze_expand,
)
from gluonts.itertools import prod
Expand Down Expand Up @@ -269,13 +270,15 @@ def unroll_lagged_rnn(
time_feat = (
torch.cat(
(
past_time_feat[..., -self.context_length + 1 :, :],
take_last(
past_time_feat, dim=-2, num=self.context_length - 1
),
future_time_feat,
),
dim=-2,
)
if future_time_feat is not None
else past_time_feat[..., -self.context_length + 1 :, :]
else take_last(past_time_feat, dim=-2, num=self.context_length - 1)
)

features = torch.cat((expanded_static_feat, time_feat), dim=-1)
Expand Down Expand Up @@ -501,14 +504,16 @@ def loss(
)
else:
distr = self.output_distribution(params, scale)
context_target = past_target[:, -self.context_length + 1 :]
context_target = take_last(
past_target, dim=-1, num=self.context_length - 1
)
target = torch.cat(
(context_target, future_target_reshaped),
dim=1,
)
context_observed = past_observed_values[
:, -self.context_length + 1 :
]
context_observed = take_last(
past_observed_values, dim=-1, num=self.context_length - 1
)
observed_values = torch.cat(
(context_observed, future_observed_reshaped), dim=1
)
Expand Down
2 changes: 0 additions & 2 deletions src/gluonts/torch/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def train_model(
num_workers=num_workers,
)

training_network = self.create_lightning_module()

if from_predictor is not None:
training_network.load_state_dict(
from_predictor.network.state_dict()
Expand Down
23 changes: 23 additions & 0 deletions src/gluonts/torch/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,29 @@ def slice_along_dim(a: torch.Tensor, dim: int, slice_: slice) -> torch.Tensor:
return a[idx]


def take_last(a: torch.Tensor, dim: int, num: int) -> torch.Tensor:
"""
Take last elements from a given tensor along a given dimension.
Parameters
----------
a
Original tensor to slice.
dim
Dimension to slice over.
num
Number of trailing elements to retain (non-negative).
Returns
-------
torch.Tensor
A tensor with the same size as the input one, except dimension
``dim`` which has length equal to ``num``.
"""
assert num >= 0
return slice_along_dim(a, dim, slice(a.shape[dim] - num, None))


def unsqueeze_expand(a: torch.Tensor, dim: int, size: int) -> torch.Tensor:
"""
Unsqueeze a dimension and expand over it in one go.
Expand Down
9 changes: 9 additions & 0 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@
loss=NegativeLogLikelihood(beta=0.1),
scaling=False,
),
lambda dataset: DeepAREstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
context_length=1,
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
scaling=False,
),
lambda dataset: MQF2MultiHorizonEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
Expand Down

0 comments on commit 28ed1d0

Please sign in to comment.