diff --git a/src/gluonts/shell/sagemaker/dyn.py b/src/gluonts/shell/sagemaker/dyn.py
index 0717c85f93..d0e4bbe5ea 100644
--- a/src/gluonts/shell/sagemaker/dyn.py
+++ b/src/gluonts/shell/sagemaker/dyn.py
@@ -30,6 +30,8 @@
 from pathlib import Path
 from typing import Optional
 
+from gluonts.util import safe_extractall
+
 
 class Installer:
     def __init__(self, packages):
@@ -63,10 +65,12 @@ def pip_install(self, path: Path):
     def install(self, path):
         if path.is_file():
             if tarfile.is_tarfile(path):
-                self.handle_archive(tarfile.open, path)
+                self.handle_archive(tarfile.open, safe_extractall, path)
 
             elif zipfile.is_zipfile(path):
-                self.handle_archive(zipfile.ZipFile, path)
+                self.handle_archive(
+                    zipfile.ZipFile, zipfile.ZipFile.extractall, path
+                )
 
             elif path.suffix == ".py":
                 self.copy_install(path)
@@ -80,14 +84,14 @@ def install(self, path):
                 for subpath in path.iterdir():
                     self.install(subpath)
 
-    def handle_archive(self, open_fn, path):
+    def handle_archive(self, open_fn, extractall_fn, path):
         with open_fn(path) as archive:
             tempdir = tempfile.mkdtemp()
             self.cleanups.append(
                 partial(shutil.rmtree, tempdir, ignore_errors=True)
             )
 
-            archive.extractall(tempdir)
+            extractall_fn(archive, tempdir)
             self.install(Path(tempdir))
 
 
diff --git a/src/gluonts/time_feature/_base.py b/src/gluonts/time_feature/_base.py
index ab38bb10ec..418a5299b1 100644
--- a/src/gluonts/time_feature/_base.py
+++ b/src/gluonts/time_feature/_base.py
@@ -176,7 +176,17 @@ def __call__(self, index: pd.PeriodIndex) -> np.ndarray:
 
 
 def norm_freq_str(freq_str: str) -> str:
-    return freq_str.split("-")[0]
+    base_freq = freq_str.split("-")[0]
+
+    # Pandas has start and end frequencies, e.g. `AS` and `A` for yearly start
+    # and yearly end frequencies. We don't make that difference and instead
+    # rely only on the end frequencies which don't have the `S` suffix.
+    # Note: Secondly ("S") frequency exists, where we don't want to remove the
+    # "S"!
+    if len(base_freq) >= 2 and base_freq.endswith("S"):
+        return base_freq[:-1]
+
+    return base_freq
 
 
 def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py
index b336a2e8ce..80f7122a3c 100644
--- a/src/gluonts/torch/model/estimator.py
+++ b/src/gluonts/torch/model/estimator.py
@@ -174,8 +174,8 @@ def train_model(
 
         validation_data_loader = None
 
-        with env._let(max_idle_transforms=max(len(training_data), 100)):
-            if validation_data is not None:
+        if validation_data is not None:
+            with env._let(max_idle_transforms=max(len(validation_data), 100)):
                 transformed_validation_data = transformation.apply(
                     validation_data, is_train=True
                 )
diff --git a/test/time_feature/test_base.py b/test/time_feature/test_base.py
new file mode 100644
index 0000000000..b97f43cb93
--- /dev/null
+++ b/test/time_feature/test_base.py
@@ -0,0 +1,31 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from pandas.tseries.frequencies import to_offset
+
+from gluonts.time_feature import norm_freq_str
+
+
+def test_norm_freq_str():
+    assert norm_freq_str(to_offset("Y").name) == "A"
+    assert norm_freq_str(to_offset("YS").name) == "A"
+    assert norm_freq_str(to_offset("A").name) == "A"
+    assert norm_freq_str(to_offset("AS").name) == "A"
+
+    assert norm_freq_str(to_offset("Q").name) == "Q"
+    assert norm_freq_str(to_offset("QS").name) == "Q"
+
+    assert norm_freq_str(to_offset("M").name) == "M"
+    assert norm_freq_str(to_offset("MS").name) == "M"
+
+    assert norm_freq_str(to_offset("S").name) == "S"