Skip to content

Commit

Permalink
Sorting in dataset, sample based on length for DDP
Browse files Browse the repository at this point in the history
Signed-off-by: Mamta Singh <[email protected]>
  • Loading branch information
quic-mamta committed Jan 5, 2025
1 parent 7350c96 commit 1734001
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 25 deletions.
6 changes: 3 additions & 3 deletions QEfficient/cloud/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def main(**kwargs):
else:
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

longest_seq_length, longest_seq_ix = get_longest_seq_length(
torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
)
longest_seq_length, _ = get_longest_seq_length(torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset]))
else:
longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
print(
f"The longest sequence length in the train data is {longest_seq_length}, "
f"passed context length is {train_config.context_length} and overall model's context length is "
Expand Down
62 changes: 62 additions & 0 deletions QEfficient/finetune/data/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import random
from itertools import islice

import numpy as np
import torch


class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
if isinstance(next(iter(data_source)), dict):
first_key = next(iter(next(iter(data_source)).keys()))
self.lengths = [len(d[first_key]) for d in data_source]
else:
self.lengths = [len(d) for d in data_source]
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle = shuffle

def __iter__(self):
ids = np.argsort(self.lengths, kind="mergesort")
if self.drop_last:
ids = ids[: len(ids) // self.batch_size * self.batch_size]

batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]

if self.shuffle:
random.shuffle(batches)

for b in batches:
yield b

def __len__(self):
if self.drop_last:
return len(self.lengths) // self.batch_size
else:
return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)


class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
def __init__(
self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
) -> None:
random.seed(seed)
self.batch_sampler = LengthBasedBatchSampler(
data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
)
self.num_replicas = num_replicas
self.rank = rank

def __iter__(self):
max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

def __len__(self):
return len(self.batch_sampler) // self.num_replicas
35 changes: 13 additions & 22 deletions QEfficient/finetune/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

import QEfficient.finetune.configs.dataset_config as datasets
from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
from QEfficient.finetune.configs.training import train_config
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC


Expand Down Expand Up @@ -63,41 +65,30 @@ def generate_peft_config(train_config, kwargs):

def generate_dataset_config(train_config, kwargs):
names = tuple(DATASET_PREPROC.keys())

assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"

dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()

update_config(dataset_config, **kwargs)

return dataset_config


# def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
# kwargs = {}
# batch_size = (
# train_config.batch_size_training
# if mode == "train"
# else train_config.val_batch_size
# )
# if train_config.batching_strategy == "padding":
# kwargs["batch_sampler"] = LengthBasedBatchSampler(
# dataset, batch_size, drop_last=True, shuffle=mode == "train"
# )
# kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
# # kwargs["collate_fn"] = default_data_collator
# return kwargs


def get_dataloader_kwargs(train_config: train_config, dataset, dataset_processer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = default_data_collator
# use a distributed sampler to split data between devices
if train_config.enable_ddp:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
else:
kwargs["sampler"] = data_utils.DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = default_data_collator
return kwargs

0 comments on commit 1734001

Please sign in to comment.