Backports v0.11.11 (#2675)
* Faster index building in PandasDataset (#2663)

* Speed up `PandasDataset.from_long_dataframe` (#2665)

* Fix `DateSplitter` when split date is before start (#2670)

* Remove creation of ragged sequences in MultivariateGrouper (#2671)

Co-authored-by: Abdul Fatir Ansari <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

---------

Co-authored-by: Huibin Shen <[email protected]>
Co-authored-by: Gerald Woo <[email protected]>
Co-authored-by: Abdul Fatir <[email protected]>
Co-authored-by: Abdul Fatir Ansari <[email protected]>
5 people authored Feb 20, 2023
1 parent 97a67ee commit 5231fa0
Showing 4 changed files with 46 additions and 17 deletions.
src/gluonts/dataset/multivariate_grouper.py (18 additions, 15 deletions)
@@ -17,6 +17,7 @@
 import numpy as np
 import pandas as pd
 
+from gluonts.itertools import batcher
 from gluonts.core.component import validated
 from gluonts.dataset.common import DataEntry, Dataset, ListDataset
 from gluonts.dataset.field_names import FieldName
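
Note: `batcher`, newly imported from `gluonts.itertools`, splits an iterable into consecutive lists of a given size, with a shorter final list if the length is not divisible. A minimal sketch of the behavior the grouper relies on below:

    from gluonts.itertools import batcher

    # Consecutive batches of size 3; the final batch holds the remainder.
    assert list(batcher(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]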
@@ -128,10 +129,14 @@ def _group_all(self, dataset: Dataset) -> Dataset:
     def _prepare_train_data(self, dataset: Dataset) -> Dataset:
         logging.info("group training time series to datasets")
 
+        # Creates a single multivariate time series from the
+        # univariate series in the dataset
         grouped_data = self._transform_target(self._align_data_entry, dataset)
-        for data in dataset:
-            fields = data.keys()
-            break
+        grouped_data[FieldName.TARGET] = np.vstack(
+            grouped_data[FieldName.TARGET]
+        )
+
+        fields = next(iter(dataset), {}).keys()
         if FieldName.FEAT_DYNAMIC_REAL in fields:
             grouped_data[FieldName.FEAT_DYNAMIC_REAL] = np.vstack(
                 [data[FieldName.FEAT_DYNAMIC_REAL] for data in dataset],
@@ -150,21 +155,19 @@ def _prepare_test_data(self, dataset: Dataset) -> Dataset:
         logging.info("group test time series to datasets")
 
         grouped_data = self._transform_target(self._left_pad_data, dataset)
-        # splits test dataset with rolling date into N R^d time series where
-        # N is the number of rolling evaluation dates
-        split_dataset = np.split(
-            grouped_data[FieldName.TARGET], self.num_test_dates
-        )
+
+        # Splits test dataset with rolling date into N R^d time series,
+        # where N is the number of rolling evaluation dates
+        assert len(grouped_data[FieldName.TARGET]) % self.num_test_dates == 0
+        split_size = len(grouped_data[FieldName.TARGET]) // self.num_test_dates
+        split_dataset = batcher(grouped_data[FieldName.TARGET], split_size)
+
+        fields = next(iter(dataset), {}).keys()
         all_entries = list()
         for dataset_at_test_date in split_dataset:
             grouped_data = dict()
-            grouped_data[FieldName.TARGET] = np.array(
-                list(dataset_at_test_date), dtype=np.float32
-            )
-            for data in dataset:
-                fields = data.keys()
-                break
+            grouped_data[FieldName.TARGET] = np.vstack(dataset_at_test_date)
 
             if FieldName.FEAT_DYNAMIC_REAL in fields:
                 grouped_data[FieldName.FEAT_DYNAMIC_REAL] = np.vstack(
                     [data[FieldName.FEAT_DYNAMIC_REAL] for data in dataset],
@@ -202,7 +205,7 @@ def _left_pad_data(self, data: DataEntry) -> np.ndarray:
 
     @staticmethod
     def _transform_target(funcs, dataset: Dataset) -> DataEntry:
-        return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}
+        return {FieldName.TARGET: [funcs(data) for data in dataset]}
 
     def _restrict_max_dimensionality(self, data: DataEntry) -> DataEntry:
         """
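Why `_transform_target` now returns a plain list: wrapping per-series targets of unequal length in `np.array(...)` produces a ragged `dtype=object` array, which NumPy deprecates (newer versions raise) and which costs an extra copy. Keeping a list and stacking only once the rows are aligned avoids this. A small illustrative sketch, not taken from the diff:

    import numpy as np

    short, long = np.zeros(5), np.zeros(7)

    # Ragged input: np.array([short, long]) triggers a deprecation warning
    # (or a ValueError on newer NumPy) and yields a dtype=object array.
    targets = [short, long]  # keep a plain list instead

    # Once lengths match, the rows stack cleanly into an (N, T) matrix:
    stacked = np.vstack([np.zeros(5), np.ones(5)])
    assert stacked.shape == (2, 5)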
src/gluonts/dataset/pandas.py (11 additions, 2 deletions)
@@ -138,7 +138,9 @@ def _pair_to_dataentry(
             df = df.to_frame(name=self.target)
 
         if self.timestamp:
-            df.index = pd.PeriodIndex(df[self.timestamp], freq=self.freq)
+            df.index = pd.DatetimeIndex(df[self.timestamp]).to_period(
+                freq=self.freq
+            )
 
         if not self.assume_sorted:
             df.sort_index(inplace=True)
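
This appears to be the speed-up from #2663: building a `DatetimeIndex` first and converting with `.to_period(freq=...)` is faster than constructing a `PeriodIndex` directly from the raw column, presumably because the datetime parsing path is better vectorized. Both routes should yield the same index; a rough sketch of the equivalence (timings omitted, machine-dependent):

    import pandas as pd

    ts = pd.Series(pd.date_range("2021-01-01", periods=1000, freq="H"))

    slow = pd.PeriodIndex(ts, freq="H")              # old code path
    fast = pd.DatetimeIndex(ts).to_period(freq="H")  # new code path
    assert (slow == fast).all()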
@@ -187,7 +189,11 @@ def __str__(self) -> str:
 
     @classmethod
     def from_long_dataframe(
-        cls, dataframe: pd.DataFrame, item_id: str, **kwargs
+        cls,
+        dataframe: pd.DataFrame,
+        item_id: str,
+        timestamp: Optional[str] = None,
+        **kwargs,
     ) -> "PandasDataset":
         """
         Construct ``PandasDataset`` out of a long dataframe. A long dataframe
@@ -211,6 +217,9 @@ def from_long_dataframe(
         PandasDataset
             Gluonts dataset based on ``pandas.DataFrame``s.
         """
+        if timestamp is not None:
+            dataframe.index = pd.to_datetime(dataframe[timestamp])
+
         if not isinstance(dataframe.index, DatetimeIndexOpsMixin):
             dataframe.index = pd.to_datetime(dataframe.index)
         return cls(dataframes=dataframe.groupby(item_id), **kwargs)
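With the new `timestamp` argument, `from_long_dataframe` can be pointed at a timestamp column directly instead of requiring the datetimes to already be the index. A minimal usage sketch (column names and the `freq` value are illustrative):

    import pandas as pd
    from gluonts.dataset.pandas import PandasDataset

    df = pd.DataFrame(
        {
            "timestamp": ["2021-01-01", "2021-01-02"] * 2,
            "item_id": ["A", "A", "B", "B"],
            "target": [1.0, 2.0, 3.0, 4.0],
        }
    )

    # One univariate series per item_id, indexed by the timestamp column.
    ds = PandasDataset.from_long_dataframe(
        df, item_id="item_id", timestamp="timestamp", freq="D"
    )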
src/gluonts/dataset/split.py (2 additions, 0 deletions)
@@ -101,6 +101,8 @@ def periods_between(
     >>> periods_between(start, end)
     9
     """
+    if start > end:
+        return 0
     return ((end - start).n // start.freq.n) + 1
 
 
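The guard handles the empty range explicitly: for `start <= end` the formula counts both endpoints inclusively, while previously a split date before the series start produced a negative count (the `DateSplitter` bug fixed in #2670). A worked check, mirroring the new tests below:

    import pandas as pd
    from gluonts.dataset.split import periods_between

    # Inclusive count: Jan 1 through Jan 9 at daily frequency
    # is (8 // 1) + 1 == 9 periods.
    assert periods_between(
        pd.Period("2021-01-01", "D"), pd.Period("2021-01-09", "D")
    ) == 9

    # start > end now returns 0; without the guard the formula
    # would have given (-8 // 1) + 1 == -7.
    assert periods_between(
        pd.Period("2021-01-09", "D"), pd.Period("2021-01-01", "D")
    ) == 0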
test/dataset/test_split.py (15 additions, 0 deletions)
@@ -95,6 +95,21 @@ def test_time_series_slice():
             pd.Period("2021-01-01 11", "2H"),
             6,
         ),
+        (
+            pd.Period("2021-03-04", freq="2D"),
+            pd.Period("2021-03-02", freq="2D"),
+            0,
+        ),
+        (
+            pd.Period("2021-03-04", freq="2D"),
+            pd.Period("2021-03-04", freq="2D"),
+            1,
+        ),
+        (
+            pd.Period("2021-03-03 23:00", freq="30T"),
+            pd.Period("2021-03-03 03:29", freq="30T"),
+            0,
+        ),
     ],
 )
 def test_periods_between(start, end, count):
