Skip to content

Commit

Permalink
Fix item_id for M5 dataset (awslabs#3156)
Browse files Browse the repository at this point in the history
*Description of changes:*
- Fix the `item_id` calculation for the M5 dataset in case `id` column
is missing in the original dataset

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
shchur authored and kashif committed Jun 15, 2024
1 parent 59b1ac8 commit 4d33ea6
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/gluonts/dataset/repository/_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,18 @@ def generate_m5_dataset(
len(state_ids_un),
]

# Compute unique ID in case `id` column is missing
if "id" not in sales_train_validation.columns:
sales_train_validation["id"] = (
sales_train_validation["item_id"].astype("str")
+ "_"
+ sales_train_validation["store_id"].astype("str")
)
# Build target series
train_ids = (
sales_train_validation["item_id"].str
+ "_"
+ sales_train_validation["store_id"].str
)
train_ids = sales_train_validation["id"]
train_df = sales_train_validation.drop(
["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"],
axis=1,
errors="ignore",
)
test_target_values = train_df.values.copy()
train_target_values = [ts[:-prediction_length] for ts in train_df.values]
Expand Down

0 comments on commit 4d33ea6

Please sign in to comment.