From 5fb2d335be04dd61d3717b66e4614b2986895a13 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Apr 2024 16:36:46 +0000 Subject: [PATCH 1/2] Fix item_id for M5 dataset --- src/gluonts/dataset/repository/_m5.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gluonts/dataset/repository/_m5.py b/src/gluonts/dataset/repository/_m5.py index 6d9e6b80a8..71bef110d5 100644 --- a/src/gluonts/dataset/repository/_m5.py +++ b/src/gluonts/dataset/repository/_m5.py @@ -112,16 +112,17 @@ def generate_m5_dataset( len(state_ids_un), ] - # Build target series - train_ids = ( - sales_train_validation["item_id"].str + # Compute unique ID in case `id` column is missing + sales_train_validation["id"] = ( + sales_train_validation["item_id"].astype("str") + "_" - + sales_train_validation["store_id"].str + + sales_train_validation["store_id"].astype("str") ) + # Build target series + train_ids = sales_train_validation["id"] train_df = sales_train_validation.drop( ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"], axis=1, - errors="ignore", ) test_target_values = train_df.values.copy() train_target_values = [ts[:-prediction_length] for ts in train_df.values] From 6967ef94654a92bc471aa4e841391d1e124c9c86 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Apr 2024 16:52:24 +0000 Subject: [PATCH 2/2] Address PR comment --- src/gluonts/dataset/repository/_m5.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gluonts/dataset/repository/_m5.py b/src/gluonts/dataset/repository/_m5.py index 71bef110d5..58af999f0a 100644 --- a/src/gluonts/dataset/repository/_m5.py +++ b/src/gluonts/dataset/repository/_m5.py @@ -113,11 +113,12 @@ def generate_m5_dataset( ] # Compute unique ID in case `id` column is missing - sales_train_validation["id"] = ( - sales_train_validation["item_id"].astype("str") - + "_" - + sales_train_validation["store_id"].astype("str") - ) + if "id" not in sales_train_validation.columns: + sales_train_validation["id"] = ( + sales_train_validation["item_id"].astype("str") + + "_" + + sales_train_validation["store_id"].astype("str") + ) # Build target series train_ids = sales_train_validation["id"] train_df = sales_train_validation.drop(