Skip to content

Commit

Permalink
fix: gsm8k padding only if ctx len is passed (#202)
Browse files Browse the repository at this point in the history
* fix: gsm8k padding only if ctx len is passed

Signed-off-by: Ilango Rajagopal <[email protected]>

* fix: remove length column from gsm8k

Signed-off-by: Ilango Rajagopal <[email protected]>

---------

Signed-off-by: Ilango Rajagopal <[email protected]>
  • Loading branch information
irajagop authored Dec 24, 2024
1 parent dc2c509 commit 93a2f41
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions QEfficient/finetune/dataset/gsm8k_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#
# -----------------------------------------------------------------------------

import math
from typing import Dict

from datasets import Dataset, load_dataset
Expand Down Expand Up @@ -46,12 +45,12 @@ def tokenize_and_mask(row: Dict[str, str], *, tokenizer, instruction) -> Dict[st

labels = [-100] * len(ques_ids) + ans_ids

inputs = {"input_ids": input_ids, "labels": labels, "length": len(input_ids)}
inputs = {"input_ids": input_ids, "labels": labels}
return inputs


def pad_to_max_length(row: Dict[str, list], *, tokenizer, max_length: int) -> Dict[str, list]:
length = row["length"]
length = len(row["input_ids"])
return {
"input_ids": row["input_ids"] + [tokenizer.pad_token_id] * (max_length - length),
"attention_mask": [1] * length + [0] * (max_length - length),
Expand All @@ -73,17 +72,13 @@ def get_gsm8k_dataset(
remove_columns=["question", "answer"],
)

if context_length is None:
context_length = max(ds["length"])
context_length = 2 ** round(math.log2(context_length))
# context_length = 128
if context_length is not None:
ds = ds.filter(lambda x: x["length"] <= context_length)
ds = ds.map(
pad_to_max_length,
fn_kwargs={"tokenizer": tokenizer, "max_length": context_length},
)

ds = ds.filter(lambda x: x["length"] <= context_length)
ds = ds.map(
pad_to_max_length,
fn_kwargs={"tokenizer": tokenizer, "max_length": context_length},
remove_columns=["length"],
)
ds.set_format("torch")

return ds

0 comments on commit 93a2f41

Please sign in to comment.