From 93a2f41e9a4183861869a85c2398d57deb161146 Mon Sep 17 00:00:00 2001
From: Ilango Rajagopal
Date: Tue, 24 Dec 2024 11:36:12 +0530
Subject: [PATCH] fix: gsm8k padding only if ctx len is passed (#202)

* fix: gsm8k padding only if ctx len is passed

Signed-off-by: Ilango Rajagopal

* fix: remove length column from gsm8k

Signed-off-by: Ilango Rajagopal

---------

Signed-off-by: Ilango Rajagopal
---
 QEfficient/finetune/dataset/gsm8k_dataset.py | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/QEfficient/finetune/dataset/gsm8k_dataset.py b/QEfficient/finetune/dataset/gsm8k_dataset.py
index 293fd275..cfe73ef9 100644
--- a/QEfficient/finetune/dataset/gsm8k_dataset.py
+++ b/QEfficient/finetune/dataset/gsm8k_dataset.py
@@ -5,7 +5,6 @@
 #
 # -----------------------------------------------------------------------------
 
-import math
 from typing import Dict
 
 from datasets import Dataset, load_dataset
@@ -46,12 +45,12 @@ def tokenize_and_mask(row: Dict[str, str], *, tokenizer, instruction) -> Dict[st
 
     labels = [-100] * len(ques_ids) + ans_ids
 
-    inputs = {"input_ids": input_ids, "labels": labels, "length": len(input_ids)}
+    inputs = {"input_ids": input_ids, "labels": labels}
     return inputs
 
 
 def pad_to_max_length(row: Dict[str, list], *, tokenizer, max_length: int) -> Dict[str, list]:
-    length = row["length"]
+    length = len(row["input_ids"])
     return {
         "input_ids": row["input_ids"] + [tokenizer.pad_token_id] * (max_length - length),
         "attention_mask": [1] * length + [0] * (max_length - length),
@@ -73,17 +72,13 @@ def get_gsm8k_dataset(
         remove_columns=["question", "answer"],
     )
 
-    if context_length is None:
-        context_length = max(ds["length"])
-        context_length = 2 ** round(math.log2(context_length))
-        # context_length = 128
+    if context_length is not None:
+        ds = ds.filter(lambda x: len(x["input_ids"]) <= context_length)
+        ds = ds.map(
+            pad_to_max_length,
+            fn_kwargs={"tokenizer": tokenizer, "max_length": context_length},
+        )
 
-    ds = ds.filter(lambda x: x["length"] <= context_length)
-    ds = ds.map(
-        pad_to_max_length,
-        fn_kwargs={"tokenizer": tokenizer, "max_length": context_length},
-        remove_columns=["length"],
-    )
     ds.set_format("torch")
 
     return ds
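
A minimal usage sketch of the patched behavior. Only the tokenizer and
context_length parameters of get_gsm8k_dataset are visible in this diff, so
the keyword-only call and the "gpt2" tokenizer below are assumptions for
illustration; the full signature may require additional arguments.

    from transformers import AutoTokenizer

    from QEfficient.finetune.dataset.gsm8k_dataset import get_gsm8k_dataset

    # Illustrative tokenizer; pad_to_max_length needs a pad token to be set.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Without context_length, rows now keep their natural lengths:
    # no filtering, no padding, no power-of-two rounding.
    ds_raw = get_gsm8k_dataset(tokenizer=tokenizer, context_length=None)

    # With context_length, rows longer than 512 tokens are dropped and the
    # rest are right-padded to exactly 512 tokens, with attention_mask set
    # to 0 over the padding.
    ds_padded = get_gsm8k_dataset(tokenizer=tokenizer, context_length=512)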