Sort dataset based on length for DDP
Signed-off-by: Mamta Singh <[email protected]>
quic-mamta committed Jan 3, 2025
1 parent 25f1ce1 commit 21041e8
Showing 2 changed files with 72 additions and 3 deletions.
62 changes: 62 additions & 0 deletions QEfficient/finetune/data/sampler.py
@@ -0,0 +1,62 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import random
from itertools import islice

import numpy as np
import torch


class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
        if isinstance(next(iter(data_source)), dict):
            first_key = next(iter(next(iter(data_source)).keys()))
            self.lengths = [len(d[first_key]) for d in data_source]
        else:
            self.lengths = [len(d) for d in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        ids = np.argsort(self.lengths, kind="mergesort")
        if self.drop_last:
            ids = ids[: len(ids) // self.batch_size * self.batch_size]

        batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]

        if self.shuffle:
            random.shuffle(batches)

        for b in batches:
            yield b

    def __len__(self):
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        else:
            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)


class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    def __init__(
        self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
    ) -> None:
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

    def __len__(self):
        return len(self.batch_sampler) // self.num_replicas
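
Note (illustration only, not part of the commit): a minimal Python sketch of how the two samplers behave on a toy dataset. The dataset contents and the rank/num_replicas values below are invented for the example; in real training they come from the tokenized dataset and torch.distributed.

    import torch
    from torch.utils.data import DataLoader

    from QEfficient.finetune.data.sampler import (
        DistributedLengthBasedBatchSampler,
        LengthBasedBatchSampler,
    )

    # Toy dataset: each sample is a dict whose "input_ids" list varies in length.
    data = [{"input_ids": list(range(n))} for n in (5, 2, 9, 3, 7, 4, 8, 6)]

    # Single-process case: indices are sorted by length, then grouped into batches,
    # so similar-length samples land in the same batch and padding stays small.
    for batch in LengthBasedBatchSampler(data, batch_size=2, drop_last=True, shuffle=False):
        print(batch.tolist())  # expected: [1, 3], [5, 0], [7, 4], [6, 2]

    # DDP case: rank r takes every num_replicas-th batch via islice, so ranks see
    # disjoint batches. rank/num_replicas normally come from torch.distributed.
    sampler = DistributedLengthBasedBatchSampler(data, batch_size=2, num_replicas=2, rank=0, shuffle=False)
    loader = DataLoader(data, batch_sampler=sampler, collate_fn=lambda b: b)
    for batch in loader:
        print(batch)  # rank 0 gets the 1st and 3rd length-sorted batches

Sorting by length before batching keeps samples of similar length together, which reduces padding per batch; shuffling the batches (rather than the individual samples) preserves that grouping while still randomizing batch order across epochs.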
13 changes: 10 additions & 3 deletions QEfficient/finetune/utils/config_utils.py
@@ -9,7 +9,6 @@
from dataclasses import asdict

import torch.distributed as dist
-import torch.utils.data as data_utils
from peft import (
    AdaptionPromptConfig,
    LoraConfig,
@@ -20,6 +19,7 @@
import QEfficient.finetune.configs.dataset_config as datasets
from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
from QEfficient.finetune.configs.training import train_config
+from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC


@@ -97,7 +97,14 @@ def get_dataloader_kwargs(train_config: train_config, dataset, dataset_processer
    kwargs["collate_fn"] = default_data_collator
    # use a distributed sampler to split data between devices
    if train_config.enable_ddp:
-       kwargs["sampler"] = data_utils.DistributedSampler(
-           dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
+       kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+           dataset,
+           batch_size=batch_size,
+           rank=dist.get_rank(),
+           num_replicas=dist.get_world_size(),
+           shuffle=mode == "train",
        )
+       # kwargs["sampler"] = data_utils.DistributedSampler(
+       #     dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
+       # )
    return kwargs
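
Note (illustration only, not part of the commit): a hedged sketch of how a kwargs dict built this way plugs into a DataLoader. The build_loader helper and its non-DDP branch are hypothetical; the real wiring lives in get_dataloader_kwargs and its caller, and the sketch assumes the torch.distributed process group is already initialized. PyTorch's DataLoader rejects batch_size, shuffle, sampler, or drop_last alongside batch_sampler, which is why the DDP branch sets only batch_sampler.

    import torch.distributed as dist
    from torch.utils.data import DataLoader

    from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler


    def build_loader(dataset, batch_size, mode="train", enable_ddp=True, collate_fn=None):
        kwargs = {"collate_fn": collate_fn}
        if enable_ddp:
            # DDP: one length-sorted batch sampler per rank; it replaces both
            # `sampler` and `batch_size`, which DataLoader forbids next to batch_sampler.
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
            )
        else:
            # Single-device fallback (invented for this sketch): plain batching.
            kwargs["batch_size"] = batch_size
            kwargs["shuffle"] = mode == "train"
        return DataLoader(dataset, **kwargs)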
