fix: don't add eot token if add_eot_token knob is False

EeyoreLee committed Dec 20, 2023
1 parent 8e4cdd8 commit 1a9875d
Showing 2 changed files with 21 additions and 18 deletions.
37 changes: 20 additions & 17 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
@@ -162,7 +162,7 @@ def __getitem__(self, idx):
 
 
 def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
-                         end_of_conversation_token, max_seq_len):
+                         end_of_conversation_token, max_seq_len, add_eot_token=True):
     prompt_dataset = []
     chosen_dataset = []
     reject_dataset = []
@@ -172,7 +172,8 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
             chosen_sentence = raw_dataset.get_prompt_and_chosen(
                 tmp_data)  # the accept response
             if chosen_sentence is not None:
-                chosen_sentence += end_of_conversation_token
+                if add_eot_token is True:
+                    chosen_sentence += end_of_conversation_token
                 chosen_token = tokenizer(chosen_sentence,
                                          max_length=max_seq_len,
                                          padding="max_length",
@@ -195,8 +196,9 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
             reject_sentence = raw_dataset.get_prompt_and_rejected(
                 tmp_data)  # the accept response
             if chosen_sentence is not None and reject_sentence is not None:
-                chosen_sentence += end_of_conversation_token  # the accept response
-                reject_sentence += end_of_conversation_token
+                if add_eot_token is True:
+                    chosen_sentence += end_of_conversation_token  # the accept response
+                    reject_sentence += end_of_conversation_token
                 chosen_token = tokenizer(chosen_sentence,
                                          max_length=max_seq_len,
                                          padding="max_length",
@@ -207,12 +209,7 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                                          padding="max_length",
                                          truncation=True,
                                          return_tensors="pt")
-                chosen_token["input_ids"] = chosen_token["input_ids"]
-                chosen_token["attention_mask"] = chosen_token["attention_mask"]
                 chosen_dataset.append(chosen_token)
-
-                reject_token["input_ids"] = reject_token["input_ids"]
-                reject_token["attention_mask"] = reject_token["attention_mask"]
                 reject_dataset.append(reject_token)
     print(
         f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
@@ -241,7 +238,7 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
 
 def create_dataset(local_rank, dataset_name, data_split, output_path,
                    train_phase, seed, tokenizer, end_of_conversation_token,
-                   max_seq_len, rebuild):
+                   max_seq_len, rebuild, add_eot_token=True):
     raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank)
     train_dataset = raw_dataset.get_train_data()
     train_index = get_raw_dataset_split_index(local_rank, output_path,
@@ -253,7 +250,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
     train_dataset = create_dataset_split(train_dataset, raw_dataset,
                                          train_phase, tokenizer,
                                          end_of_conversation_token,
-                                         max_seq_len)
+                                         max_seq_len, add_eot_token=add_eot_token)
 
     eval_dataset = raw_dataset.get_eval_data()
     eval_index = get_raw_dataset_split_index(local_rank, output_path,
@@ -264,7 +261,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
     eval_dataset = Subset(eval_dataset, eval_index)
     eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase,
                                         tokenizer, end_of_conversation_token,
-                                        max_seq_len)
+                                        max_seq_len, add_eot_token=add_eot_token)
     return train_dataset, eval_dataset


@@ -277,11 +274,14 @@ def create_prompt_dataset(local_rank,
                           tokenizer,
                           max_seq_len,
                           end_of_conversation_token="<|endoftext|>",
-                          sft_only_data_path=[],
-                          reload=False):
+                          sft_only_data_path=None,
+                          reload=False,
+                          add_eot_token=True):
     """
     Creates the prompt dataset
     """
+    if sft_only_data_path is None:
+        sft_only_data_path = []
     os.makedirs(output_path, exist_ok=True)
     fname = "_".join(data_path)
     sft_cache_key = "_".join(sft_only_data_path)
@@ -311,7 +311,8 @@ def create_prompt_dataset(local_rank,
                 tokenizer,
                 end_of_conversation_token,
                 max_seq_len,
-                rebuild=reload)
+                rebuild=reload,
+                add_eot_token=add_eot_token)
         else:  # Blending datasets.
             train_datasets = []
             eval_datasets = []
@@ -328,7 +329,8 @@ def create_prompt_dataset(local_rank,
                     tokenizer,
                     end_of_conversation_token,
                     max_seq_len,
-                    rebuild=reload)
+                    rebuild=reload,
+                    add_eot_token=add_eot_token)
                 train_datasets.append(train_dataset)
                 eval_datasets.append(eval_dataset)
                 train_size += len(train_dataset)
@@ -357,7 +359,8 @@ def create_prompt_dataset(local_rank,
                     tokenizer,
                     end_of_conversation_token,
                     max_seq_len,
-                    rebuild=reload)
+                    rebuild=reload,
+                    add_eot_token=add_eot_token)
                 sft_train_datasets.append(sft_train_dataset)
                 sft_eval_datasets.append(sft_eval_dataset)
                 sft_train_size += len(sft_train_dataset)
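Besides threading the new add_eot_token knob through, the -277/+274 hunk above also swaps the mutable default sft_only_data_path=[] for None plus an in-body fallback. The list is only joined into a cache key, never mutated, so this is defensive rather than a bug fix, but it avoids Python's shared-mutable-default pitfall. A minimal sketch of that pitfall (the helper names below are illustrative, not from this repository):

# Why `def f(acc=[])` is risky: the default list is created once and shared.
def append_bad(item, acc=[]):        # same list object reused on every call
    acc.append(item)
    return acc

def append_good(item, acc=None):     # fresh list per call unless one is passed
    if acc is None:
        acc = []
    acc.append(item)
    return acc

print(append_bad("a"), append_bad("b"))    # ['a', 'b'] ['a', 'b']
print(append_good("a"), append_good("b"))  # ['a'] ['b']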
2 changes: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def main():
     train_dataset, eval_dataset = create_prompt_dataset(
         args.local_rank, args.data_path, args.data_split,
         args.data_output_path, train_phase, args.seed, tokenizer,
-        args.max_seq_len)
+        args.max_seq_len, add_eot_token=args.add_eot_token)
 
     # DataLoaders creation:
     data_collator = DataCollatorReward()
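The knob defaults to True, so existing callers keep appending end_of_conversation_token; passing add_eot_token=False now skips the append in every dataset phase. A minimal caller-side sketch of the new argument (the --add_eot_token flag wiring, model name, and dataset choices below are assumptions for illustration, not part of this commit; it also assumes the dschat package is importable):

# Illustrative only: argparse wiring and dataset/tokenizer choices are assumptions.
import argparse

from transformers import AutoTokenizer

from dschat.utils.data.data_utils import create_prompt_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--add_eot_token",
                    action="store_true",
                    help="Append end_of_conversation_token to each sample.")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
tokenizer.pad_token = tokenizer.eos_token

train_dataset, eval_dataset = create_prompt_dataset(
    0,                      # local_rank (single-process example)
    ["Dahoas/rm-static"],   # data_path: list of dataset names
    "2,4,4",                # data_split across the three training phases
    "/tmp/data_files",      # output_path for cached splits
    2,                      # train_phase
    1234,                   # seed
    tokenizer,
    512,                    # max_seq_len
    end_of_conversation_token="<|endoftext|>",
    add_eot_token=args.add_eot_token)  # False unless --add_eot_token is passed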
