diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index b5e3eb628..e4906f64d 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -165,9 +165,9 @@ def main(**kwargs):
         else:
             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
-        longest_seq_length, longest_seq_ix = get_longest_seq_length(
-            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
-        )
+        longest_seq_length, _ = get_longest_seq_length(torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset]))
+    else:
+        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
 
     print(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
diff --git a/QEfficient/finetune/data/sampler.py b/QEfficient/finetune/data/sampler.py
new file mode 100644
index 000000000..6050f2e92
--- /dev/null
+++ b/QEfficient/finetune/data/sampler.py
@@ -0,0 +1,62 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths, kind="mergesort")
+        if self.drop_last:
+            ids = ids[: len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(
+        self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
+    ) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+        )
+        self.num_replicas = num_replicas
+        self.rank = rank
+
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index d2f71c76b..0f555bccb 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -16,10 +16,12 @@
     PrefixTuningConfig,
 )
 from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 import QEfficient.finetune.configs.dataset_config as datasets
 from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
 from QEfficient.finetune.configs.training import train_config
+from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
 
 
@@ -63,41 +65,30 @@ def generate_peft_config(train_config, kwargs):
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-
     dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-
     update_config(dataset_config, **kwargs)
     return dataset_config
 
 
-# def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
-#     kwargs = {}
-#     batch_size = (
-#         train_config.batch_size_training
-#         if mode == "train"
-#         else train_config.val_batch_size
-#     )
-#     if train_config.batching_strategy == "padding":
-#         kwargs["batch_sampler"] = LengthBasedBatchSampler(
-#             dataset, batch_size, drop_last=True, shuffle=mode == "train"
-#         )
-#         kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
-#     # kwargs["collate_fn"] = default_data_collator
-#     return kwargs
-
-
 def get_dataloader_kwargs(train_config: train_config, dataset, dataset_processer, mode):
     kwargs = {}
     batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
-    kwargs["batch_size"] = batch_size
-    kwargs["drop_last"] = True
-    kwargs["collate_fn"] = default_data_collator
     # use a distributed sampler to split data between devices
     if train_config.enable_ddp:
+        kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+            dataset,
+            batch_size=batch_size,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+        )
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    else:
         kwargs["sampler"] = data_utils.DistributedSampler(
             dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
        )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
     return kwargs
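
For reference, below is a minimal usage sketch of the samplers introduced in sampler.py, paired with DataCollatorForSeq2Seq the same way get_dataloader_kwargs does. It is an illustration under stated assumptions, not part of the patch: dataset items are assumed to be dicts of token-id lists (the shape LengthBasedBatchSampler.__init__ expects), the distributed branch assumes an already-initialized process group, and the "gpt2" tokenizer is only a placeholder.

# Illustration only (not part of the patch): wiring the new samplers into a DataLoader.
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.data import DataCollatorForSeq2Seq

from QEfficient.finetune.data.sampler import (
    DistributedLengthBasedBatchSampler,
    LengthBasedBatchSampler,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# Toy dataset: the sampler reads the first key ("input_ids") of each dict and sorts
# samples by that length, so every batch holds similarly sized sequences.
dataset = [{"input_ids": [1] * n, "labels": [1] * n} for n in (5, 3, 9, 2, 7, 4, 8, 6)]

if dist.is_available() and dist.is_initialized():
    # Mirrors the enable_ddp branch of get_dataloader_kwargs: each rank draws a
    # disjoint, strided slice of the length-sorted batches.
    batch_sampler = DistributedLengthBasedBatchSampler(
        dataset,
        batch_size=2,
        rank=dist.get_rank(),
        num_replicas=dist.get_world_size(),
    )
else:
    batch_sampler = LengthBasedBatchSampler(dataset, batch_size=2, drop_last=True, shuffle=True)

# DataCollatorForSeq2Seq pads each batch only up to that batch's longest sequence,
# which is what makes length-sorted batching cheaper than padding to a global max.
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=DataCollatorForSeq2Seq(tokenizer))

for batch in loader:
    print(batch["input_ids"].shape)  # e.g. torch.Size([2, <batch-local max length>])

Because whole batches (rather than individual samples) are shuffled, sequences of similar length stay together across epochs, so the per-batch padding done by the collator remains small.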