
Add argument to enable and disable padding for ddp
Signed-off-by: Mamta Singh <[email protected]>
quic-mamta committed Jan 9, 2025
1 parent 2e569e0 commit 00a8587
Showing 2 changed files with 16 additions and 10 deletions.
1 change: 1 addition & 0 deletions QEfficient/finetune/configs/training.py
@@ -38,6 +38,7 @@ class train_config:
    save_metrics: bool = True  # saves training metrics to a json file for later plotting
    intermediate_step_save: int = 1000
    batching_strategy: str = "packing"
    enable_sorting_for_ddp: bool = True

    # TODO: vbaddi: Uncomment post adding qaic to Pytorch Profiler
    # flop_counter: bool = False  # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
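The new flag defaults to keeping length-based sorting on under DDP. Below is a minimal sketch (not part of the commit) of how it might be toggled from user code; it assumes train_config can be instantiated with its defaults and that fields are overridden by plain attribute assignment, and the values shown are illustrative only.

    from QEfficient.finetune.configs.training import train_config

    # Sketch only: turn off sorting so the DDP path falls back to a plain
    # DistributedSampler, and rely on padding to a fixed context length instead.
    config = train_config()
    config.enable_sorting_for_ddp = False
    config.context_length = 2048  # illustrative; per the check below, padding is only allowed when sorting is off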
25 changes: 15 additions & 10 deletions QEfficient/finetune/utils/config_utils.py
@@ -75,22 +75,27 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.enable_ddp:
        if train_config.enable_sorting_for_ddp:
            if train_config.context_length:
                raise ValueError(
                    "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding"
                )
            else:
                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                    dataset,
                    batch_size=batch_size,
                    rank=dist.get_rank(),
                    num_replicas=dist.get_world_size(),
                    shuffle=False,
                )
                kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
        else:
            kwargs["sampler"] = data_utils.DistributedSampler(
                dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
            )
            kwargs["batch_size"] = batch_size
            kwargs["drop_last"] = True
            kwargs["collate_fn"] = default_data_collator
    else:
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
