From 59b1ac810069509593acf34f0544f6e53f29b702 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Apr 2024 16:15:22 +0200 Subject: [PATCH] Fix loaders for M5 & ETT datasets (#3155) *Description of changes:* - Fix how `item_id` is obtained for M5 and ETT datasets - Fix `lxml` dependency range By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup --- pyproject.toml | 2 +- requirements/requirements-docs.txt | 1 + src/gluonts/dataset/repository/_ett_small.py | 8 ++++---- src/gluonts/dataset/repository/_m5.py | 6 +++++- src/gluonts/json.py | 3 ++- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 302a8657b9..262389ee98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ filterwarnings = "ignore" [tool.ruff] line-length = 79 -ignore = [ +lint.ignore = [ # line-length is handled by black "E501", diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt index 4e0fc592ae..645e7632c3 100644 --- a/requirements/requirements-docs.txt +++ b/requirements/requirements-docs.txt @@ -3,6 +3,7 @@ ipykernel~=6.5 nbconvert~=6.5.1 nbsphinx~=0.8.8 notedown +lxml~=5.1.0 pytest-runner~=2.11 recommonmark sphinx~=4.0 diff --git a/src/gluonts/dataset/repository/_ett_small.py b/src/gluonts/dataset/repository/_ett_small.py index 757f595d96..8370d4feaa 100644 --- a/src/gluonts/dataset/repository/_ett_small.py +++ b/src/gluonts/dataset/repository/_ett_small.py @@ -39,7 +39,7 @@ def generate_ett_small_dataset( dfs.append(df) test = [] - for df in dfs: + for region, df in enumerate(dfs): start = pd.Period(df["date"][0], freq=freq) for col in df.columns: if col in ["date"]: @@ -47,13 +47,13 @@ def generate_ett_small_dataset( test.append( { "start": start, - 
"item_id": col, + "item_id": f"{col}_{region}", "target": df[col].values, } ) train = [] - for df in dfs: + for region, df in enumerate(dfs): start = pd.Period(df["date"][0], freq=freq) for col in df.columns: if col in ["date"]: @@ -61,7 +61,7 @@ def generate_ett_small_dataset( train.append( { "start": start, - "item_id": col, + "item_id": f"{col}_{region}", "target": df[col].values[:-prediction_length], } ) diff --git a/src/gluonts/dataset/repository/_m5.py b/src/gluonts/dataset/repository/_m5.py index cea591dc5f..6d9e6b80a8 100644 --- a/src/gluonts/dataset/repository/_m5.py +++ b/src/gluonts/dataset/repository/_m5.py @@ -113,7 +113,11 @@ def generate_m5_dataset( ] # Build target series - train_ids = sales_train_validation["item_id"] + train_ids = ( + sales_train_validation["item_id"] + + "_" + + sales_train_validation["store_id"] + ) train_df = sales_train_validation.drop( ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"], axis=1, diff --git a/src/gluonts/json.py b/src/gluonts/json.py index 58e9e1fd8a..c92ce4f8cc 100644 --- a/src/gluonts/json.py +++ b/src/gluonts/json.py @@ -26,7 +26,8 @@ character if set to `True`. """ -__all__ = [ # noqa +# ruff: noqa: F822 +__all__ = [ "variant", "dump", "dumps",