class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that puts samples of similar length into the same batch.

    Sample indices are ordered by length (stable sort, so samples of equal
    length keep their dataset order), cut into consecutive chunks of
    ``batch_size`` and, optionally, the order of the resulting batches is
    shuffled. The content of each batch is never shuffled.

    Args:
        data_source: Iterable of samples. If the first sample is a ``dict``,
            the length of a sample is ``len(sample[first_key])`` where
            ``first_key`` is the first key of that first sample; otherwise it
            is ``len(sample)``.
        batch_size: Number of samples per batch; must be positive.
        drop_last: If true, the trailing samples (in length order, i.e. the
            longest ones) that do not fill a whole batch are dropped.
        shuffle: If true, batch order is shuffled on every iteration using the
            process-global ``random`` module state.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Fetch the first sample only once; a private sentinel distinguishes
        # "empty data source" from a sample that happens to be None.
        missing = object()
        first_sample = next(iter(data_source), missing)
        if first_sample is missing:
            # Empty data source: nothing to batch instead of a bare StopIteration.
            self.lengths = []
        elif isinstance(first_sample, dict):
            first_key = next(iter(first_sample))
            self.lengths = [len(sample[first_key]) for sample in data_source]
        else:
            self.lengths = [len(sample) for sample in data_source]

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        """Yield batches as lists of integer dataset indices."""
        # Stable sort keeps the original order among samples of equal length.
        # tolist() converts np.int64 to plain ints, matching the list-of-ints
        # contract of torch.utils.data.BatchSampler.
        order = np.argsort(self.lengths, kind="mergesort").tolist()
        if self.drop_last:
            usable = len(order) // self.batch_size * self.batch_size
            order = order[:usable]

        batches = [order[start : start + self.batch_size] for start in range(0, len(order), self.batch_size)]

        if self.shuffle:
            # Only the order of batches changes; uses the global RNG so that
            # callers seeding `random` get reproducible batch orders.
            random.shuffle(batches)

        yield from batches

    def __len__(self) -> int:
        """Return the number of batches produced by one iteration."""
        full_batches, remainder = divmod(len(self.lengths), self.batch_size)
        if self.drop_last or remainder == 0:
            return full_batches
        return full_batches + 1
class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Split length-grouped batches across the ranks of a distributed job.

    Batches are built by a ``LengthBasedBatchSampler`` with ``drop_last=True``.
    Rank ``r`` receives batches ``r, r + num_replicas, r + 2 * num_replicas, ...``
    of the (optionally shuffled) batch list; surplus batches that cannot be
    distributed evenly are dropped so that every rank yields the same number
    of batches.

    Args:
        data_source: Samples handed to ``LengthBasedBatchSampler``.
        batch_size: Number of samples per batch.
        num_replicas: Number of participating ranks; must be positive.
        rank: Rank of the current process, in ``[0, num_replicas)``.
        shuffle: If true, the order of batches is shuffled on every iteration.
        seed: Seed for the shuffle; must be the same on all ranks.

    Raises:
        ValueError: If ``num_replicas`` is not positive or ``rank`` is out of
            range.
    """

    def __init__(
        self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
    ) -> None:
        if num_replicas <= 0:
            raise ValueError(f"num_replicas must be a positive integer, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

        # Kept for backward compatibility: constructing this sampler has always
        # re-seeded the process-global RNG.
        random.seed(seed)
        # Every rank must shuffle the batches identically, otherwise ranks would
        # pick overlapping or missing batches. A private generator guarantees
        # that regardless of what else consumes the global `random` state.
        self._rng = random.Random(seed)
        self.shuffle = shuffle
        # The inner sampler only builds the length-sorted batches; shuffling is
        # done here with the private generator.
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=False
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        """Return an iterator over the batches assigned to this rank."""
        batches = list(self.batch_sampler)
        if self.shuffle:
            self._rng.shuffle(batches)
        # Truncate to a multiple of num_replicas so all ranks see equally many batches.
        usable = len(batches) // self.num_replicas * self.num_replicas
        return islice(batches, self.rank, usable, self.num_replicas)

    def __len__(self) -> int:
        """Return the number of batches this rank yields per iteration."""
        return len(self.batch_sampler) // self.num_replicas
kwargs["sampler"] = data_utils.DistributedSampler( + # dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True + # ) return kwargs