From e3fc99f92c1b03943d18660c0d9056c4093f3a33 Mon Sep 17 00:00:00 2001 From: Jasper Schulz Date: Tue, 27 Aug 2019 21:47:50 -0400 Subject: [PATCH] Removed `get_granularity`. (#265) Replaced with `pandas.tseries.frequencies.to_offset`. Relates to: #264 --- src/gluonts/time_feature/__init__.py | 3 +- src/gluonts/time_feature/lag.py | 112 ++++++++++++--------------- 2 files changed, 50 insertions(+), 65 deletions(-) diff --git a/src/gluonts/time_feature/__init__.py b/src/gluonts/time_feature/__init__.py index 63d2d89493..3b95f71bc5 100644 --- a/src/gluonts/time_feature/__init__.py +++ b/src/gluonts/time_feature/__init__.py @@ -25,7 +25,7 @@ from .holiday import SPECIAL_DATE_FEATURES, SpecialDateFeatureSet -from .lag import get_granularity, get_lags_for_frequency +from .lag import get_lags_for_frequency __all__ = [ "DayOfMonth", @@ -38,7 +38,6 @@ "WeekOfYear", "SPECIAL_DATE_FEATURES", "SpecialDateFeatureSet", - "get_granularity", "get_lags_for_frequency", ] diff --git a/src/gluonts/time_feature/lag.py b/src/gluonts/time_feature/lag.py index a28994244c..1acef476f8 100644 --- a/src/gluonts/time_feature/lag.py +++ b/src/gluonts/time_feature/lag.py @@ -17,6 +17,8 @@ # Third-party imports import numpy as np +from pandas.tseries.frequencies import to_offset +from pandas.tseries import offsets # First-party imports from gluonts.time_feature import ( @@ -31,26 +33,6 @@ ) -def get_granularity(freq_str: str) -> Tuple[int, str]: - """ - Splits a frequency string such as "7D" into the multiple 7 and the base - granularity "D". - - Parameters - ---------- - - freq_str - Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. - """ - freq_regex = r"\s*((\d+)?)\s*([^\d]\w*)" - m = re.match(freq_regex, freq_str) - assert m is not None, "Cannot parse frequency string: %s" % freq_str - groups = m.groups() - multiple = int(groups[1]) if groups[1] is not None else 1 - granularity = groups[2] - return multiple, granularity - - def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: """ Returns a list of time features that will be appropriate for the given frequency string. @@ -62,38 +44,40 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. """ - multiple, granularity = get_granularity(freq_str) - if granularity == "M": - feature_classes = [MonthOfYear] - elif granularity == "W": - feature_classes = [DayOfMonth, WeekOfYear] - elif granularity in ["D", "B"]: - feature_classes = [DayOfWeek, DayOfMonth, DayOfYear] - elif granularity == "H": - feature_classes = [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear] - elif granularity in ["min", "T"]: - feature_classes = [ + + features_by_offsets = { + offsets.YearOffset: [], + offsets.MonthOffset: [MonthOfYear], + offsets.Week: [DayOfMonth, WeekOfYear], + offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], + offsets.Minute: [ MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear, - ] - else: - supported_freq_msg = f""" - Unsupported frequency {freq_str} + ], + } - The following frequencies are supported: + offset = to_offset(freq_str) - M - monthly - W - week - D - daily - H - hourly - min - minutely - """ - raise RuntimeError(supported_freq_msg) + for offset_type, feature_classes in features_by_offsets.items(): + if isinstance(offset, offset_type): + return [cls() for cls in feature_classes] - return [cls() for cls in feature_classes] + supported_freq_msg = f""" + Unsupported frequency {freq_str} + + The following frequencies are supported: + + M - monthly + W - week + D - daily + H - hourly + min - minutely + """ + raise RuntimeError(supported_freq_msg) def _make_lags(middle: int, delta: int) -> np.ndarray: @@ -126,8 +110,6 @@ def get_lags_for_frequency( Maximum number of lags; by default all generated lags are returned """ - multiple, granularity = get_granularity(freq_str) - # Lags are target values at the same `season` (+/- delta) but in the previous cycle. def _make_lags_for_minute(multiple, num_cycles=3): # We use previous ``num_cycles`` hours to generate lags @@ -161,29 +143,33 @@ def _make_lags_for_month(multiple, num_cycles=3): _make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1) ] - if granularity == "M": - lags = _make_lags_for_month(multiple) - elif granularity == "W": - lags = _make_lags_for_week(multiple) - elif granularity == "D": - lags = _make_lags_for_day(multiple) + _make_lags_for_week( - multiple / 7.0 + # multiple, granularity = get_granularity(freq_str) + offset = to_offset(freq_str) + + if offset.name == "M": + lags = _make_lags_for_month(offset.n) + elif offset.name == "W-SUN": + lags = _make_lags_for_week(offset.n) + elif offset.name == "D": + lags = _make_lags_for_day(offset.n) + _make_lags_for_week( + offset.n / 7.0 ) - elif granularity == "B": + elif offset.name == "B": # todo find good lags for business day lags = [] - elif granularity == "H": + elif offset.name == "H": lags = ( - _make_lags_for_hour(multiple) - + _make_lags_for_day(multiple / 24.0) - + _make_lags_for_week(multiple / (24.0 * 7)) + _make_lags_for_hour(offset.n) + + _make_lags_for_day(offset.n / 24.0) + + _make_lags_for_week(offset.n / (24.0 * 7)) ) - elif granularity == "min": + # minutes + elif offset.name == "T": lags = ( - _make_lags_for_minute(multiple) - + _make_lags_for_hour(multiple / 60.0) - + _make_lags_for_day(multiple / (60.0 * 24)) - + _make_lags_for_week(multiple / (60.0 * 24 * 7)) + _make_lags_for_minute(offset.n) + + _make_lags_for_hour(offset.n / 60.0) + + _make_lags_for_day(offset.n / (60.0 * 24)) + + _make_lags_for_week(offset.n / (60.0 * 24 * 7)) ) else: raise Exception("invalid frequency")