diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index b714259e..ac750ebe 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -38,6 +38,7 @@ class train_config:
     save_metrics: bool = True  # saves training metrics to a json file for later plotting
     intermediate_step_save: int = 1000
     batching_strategy: str = "packing"
+    enable_sorting_for_ddp: bool = True
 
     # TODO: vbaddi: Uncomment post adding qaic to Pytorch Profiler
     # flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index 7b13f8a6..58344b19 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -75,22 +75,27 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
     kwargs = {}
     batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
     if train_config.enable_ddp:
-        if train_config.context_length:
+        if train_config.enable_sorting_for_ddp:
+            if train_config.context_length:
+                raise ValueError(
+                    "Sorting cannot be done with padding. Please disable sorting or pass context_length as None to disable padding."
+                )
+            else:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=False,
+                )
+                kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+        else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
             kwargs["collate_fn"] = default_data_collator
-        else:
-            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                dataset,
-                batch_size=batch_size,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-                shuffle=False,
-            )
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
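
Note: the new `enable_sorting_for_ddp` path batches samples of similar length together so that `DataCollatorForSeq2Seq` can pad each batch only up to its own longest sample. Below is a minimal, self-contained sketch of that idea, with toy lengths and a naive round-robin shard across ranks; it illustrates the technique, not the actual `DistributedLengthBasedBatchSampler` implementation:

```python
# Toy illustration of length-based batching for DDP: sort sample indices by
# token count, shard them round-robin across ranks, then chunk into batches.
# Samples in a batch end up similarly sized, so dynamic padding wastes little.
lengths = [5, 42, 7, 40, 6, 41, 8, 39]  # token count per dataset sample
order = sorted(range(len(lengths)), key=lambda i: lengths[i])

batch_size, rank, world_size = 2, 0, 2  # illustrative DDP setup, not real dist state
shard = order[rank::world_size]
batches = [shard[i : i + batch_size] for i in range(0, len(shard), batch_size)]
print(batches)  # e.g. [[0, 2], [7, 3]] -- each batch holds similar-length samples
```

Because each batch is padded only to its own maximum length, this scheme conflicts with a fixed `context_length` pad target, which is why the patch raises a `ValueError` when both are set.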