Skip to content

Commit

Permalink
Removed get_granularity. (#265)
Browse files Browse the repository at this point in the history
Replaced with `pandas.tseries.frequencies.to_offset`.

Relates to: #264
  • Loading branch information
Jasper Schulz authored Aug 28, 2019
1 parent 7ef55e4 commit e3fc99f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 65 deletions.
3 changes: 1 addition & 2 deletions src/gluonts/time_feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .holiday import SPECIAL_DATE_FEATURES, SpecialDateFeatureSet

from .lag import get_granularity, get_lags_for_frequency
from .lag import get_lags_for_frequency

__all__ = [
"DayOfMonth",
Expand All @@ -38,7 +38,6 @@
"WeekOfYear",
"SPECIAL_DATE_FEATURES",
"SpecialDateFeatureSet",
"get_granularity",
"get_lags_for_frequency",
]

Expand Down
112 changes: 49 additions & 63 deletions src/gluonts/time_feature/lag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# Third-party imports
import numpy as np
from pandas.tseries.frequencies import to_offset
from pandas.tseries import offsets

# First-party imports
from gluonts.time_feature import (
Expand All @@ -31,26 +33,6 @@
)


def get_granularity(freq_str: str) -> Tuple[int, str]:
"""
Splits a frequency string such as "7D" into the multiple 7 and the base
granularity "D".
Parameters
----------
freq_str
Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
"""
freq_regex = r"\s*((\d+)?)\s*([^\d]\w*)"
m = re.match(freq_regex, freq_str)
assert m is not None, "Cannot parse frequency string: %s" % freq_str
groups = m.groups()
multiple = int(groups[1]) if groups[1] is not None else 1
granularity = groups[2]
return multiple, granularity


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
"""
Returns a list of time features that will be appropriate for the given frequency string.
Expand All @@ -62,38 +44,40 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
"""
multiple, granularity = get_granularity(freq_str)
if granularity == "M":
feature_classes = [MonthOfYear]
elif granularity == "W":
feature_classes = [DayOfMonth, WeekOfYear]
elif granularity in ["D", "B"]:
feature_classes = [DayOfWeek, DayOfMonth, DayOfYear]
elif granularity == "H":
feature_classes = [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear]
elif granularity in ["min", "T"]:
feature_classes = [

features_by_offsets = {
offsets.YearOffset: [],
offsets.MonthOffset: [MonthOfYear],
offsets.Week: [DayOfMonth, WeekOfYear],
offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
offsets.Minute: [
MinuteOfHour,
HourOfDay,
DayOfWeek,
DayOfMonth,
DayOfYear,
]
else:
supported_freq_msg = f"""
Unsupported frequency {freq_str}
],
}

The following frequencies are supported:
offset = to_offset(freq_str)

M - monthly
W - week
D - daily
H - hourly
min - minutely
"""
raise RuntimeError(supported_freq_msg)
for offset_type, feature_classes in features_by_offsets.items():
if isinstance(offset, offset_type):
return [cls() for cls in feature_classes]

return [cls() for cls in feature_classes]
supported_freq_msg = f"""
Unsupported frequency {freq_str}
The following frequencies are supported:
M - monthly
W - week
D - daily
H - hourly
min - minutely
"""
raise RuntimeError(supported_freq_msg)


def _make_lags(middle: int, delta: int) -> np.ndarray:
Expand Down Expand Up @@ -126,8 +110,6 @@ def get_lags_for_frequency(
Maximum number of lags; by default all generated lags are returned
"""

multiple, granularity = get_granularity(freq_str)

# Lags are target values at the same `season` (+/- delta) but in the previous cycle.
def _make_lags_for_minute(multiple, num_cycles=3):
# We use previous ``num_cycles`` hours to generate lags
Expand Down Expand Up @@ -161,29 +143,33 @@ def _make_lags_for_month(multiple, num_cycles=3):
_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)
]

if granularity == "M":
lags = _make_lags_for_month(multiple)
elif granularity == "W":
lags = _make_lags_for_week(multiple)
elif granularity == "D":
lags = _make_lags_for_day(multiple) + _make_lags_for_week(
multiple / 7.0
# multiple, granularity = get_granularity(freq_str)
offset = to_offset(freq_str)

if offset.name == "M":
lags = _make_lags_for_month(offset.n)
elif offset.name == "W-SUN":
lags = _make_lags_for_week(offset.n)
elif offset.name == "D":
lags = _make_lags_for_day(offset.n) + _make_lags_for_week(
offset.n / 7.0
)
elif granularity == "B":
elif offset.name == "B":
# todo find good lags for business day
lags = []
elif granularity == "H":
elif offset.name == "H":
lags = (
_make_lags_for_hour(multiple)
+ _make_lags_for_day(multiple / 24.0)
+ _make_lags_for_week(multiple / (24.0 * 7))
_make_lags_for_hour(offset.n)
+ _make_lags_for_day(offset.n / 24.0)
+ _make_lags_for_week(offset.n / (24.0 * 7))
)
elif granularity == "min":
# minutes
elif offset.name == "T":
lags = (
_make_lags_for_minute(multiple)
+ _make_lags_for_hour(multiple / 60.0)
+ _make_lags_for_day(multiple / (60.0 * 24))
+ _make_lags_for_week(multiple / (60.0 * 24 * 7))
_make_lags_for_minute(offset.n)
+ _make_lags_for_hour(offset.n / 60.0)
+ _make_lags_for_day(offset.n / (60.0 * 24))
+ _make_lags_for_week(offset.n / (60.0 * 24 * 7))
)
else:
raise Exception("invalid frequency")
Expand Down

0 comments on commit e3fc99f

Please sign in to comment.