From fe1599ee78811fee01587fb925d36a907e818f3c Mon Sep 17 00:00:00 2001
From: 2U1
Date: Wed, 28 Aug 2024 10:08:04 +0900
Subject: [PATCH] feat: add functionality to set `vision_lr` and `resampler_lr`

fix: resolve issue where saving was not functioning correctly
chore: update args with hyperparameters for improved fine-tuning performance
---
 finetune/finetune.py      | 30 +++++++++++----
 finetune/finetune_ds.sh   |  4 +-
 finetune/finetune_lora.sh |  8 +++-
 finetune/trainer.py       | 81 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 112 insertions(+), 11 deletions(-)

diff --git a/finetune/finetune.py b/finetune/finetune.py
index 0c596d7b..287078e1 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -53,6 +53,8 @@ class TrainingArguments(transformers.TrainingArguments):
     llm_type: str = field(default="minicpm")
     use_lora: Optional[bool] = field(default=False)
     max_slice_nums: Optional[int] = field(default=9)
+    vision_lr: Optional[float] = None
+    resampler_lr: Optional[float] = None
 
 
 @dataclass
@@ -74,12 +76,25 @@ def rank0_print(*args):
     if local_rank == 0:
         print(*args)
 
-
-def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+                                   output_dir: str):
     """Collects the state dict and dump to disk."""
-    if trainer.args.should_save and trainer.args.local_rank == 0:
-        trainer.save_model(output_dir,)
+    if trainer.deepspeed:
+        trainer.accelerator.wait_for_everyone()
+        torch.cuda.synchronize()
+        trainer.save_model(output_dir)
+        return
+
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {
+            key: value.cpu()
+            for key, value in state_dict.items()
+        }
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+        trainer.model.config.save_pretrained(output_dir)
 
 
 def make_supervised_data_module(
     tokenizer: transformers.PreTrainedTokenizer,
@@ -202,6 +217,7 @@ def train():
         trust_remote_code=True,
         torch_dtype=compute_dtype,
         device_map=device_map,
+        attn_implementation="flash_attention_2"
     )
 
     tokenizer = AutoTokenizer.from_pretrained(
@@ -250,7 +266,6 @@ def get_input_embeddings(self):
 
     rank0_print(f'llm_type={llm_type}')
 
-    # Load data
     if hasattr(model.config, "slice_config"):
         model.config.slice_config.max_slice_nums = training_args.max_slice_nums
@@ -291,9 +306,8 @@ def get_input_embeddings(self):
 
     safe_save_model_for_hf_trainer(
         trainer=trainer,
-        output_dir=training_args.output_dir,
-        bias=lora_args.lora_bias)
+        output_dir=training_args.output_dir)
 
 
 if __name__ == "__main__":
-    train()
+    train()
\ No newline at end of file
diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh
index c0494715..9582629b 100644
--- a/finetune/finetune_ds.sh
+++ b/finetune/finetune_ds.sh
@@ -53,7 +53,9 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --save_strategy "steps" \
     --save_steps 1000 \
     --save_total_limit 10 \
-    --learning_rate 1e-6 \
+    --learning_rate 1e-5 \
+    --vision_lr 2e-6 \
+    --resampler_lr 2e-6 \
     --weight_decay 0.1 \
     --adam_beta2 0.95 \
     --warmup_ratio 0.01 \
diff --git a/finetune/finetune_lora.sh b/finetune/finetune_lora.sh
index 19437b1a..7a7b837d 100644
--- a/finetune/finetune_lora.sh
+++ b/finetune/finetune_lora.sh
@@ -41,7 +41,9 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --tune_vision true \
    --tune_llm false \
     --use_lora true \
-    --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
+    --lora_r 64 \
+    --lora_alpha 128 \
+    --lora_target_modules "llm\..*layers\.\d+\.(self_attn|mlp)\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)" \
     --model_max_length $MODEL_MAX_Length \
     --max_slice_nums 9 \
     --max_steps 10000 \
@@ -56,7 +58,9 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --save_strategy "steps" \
     --save_steps 1000 \
     --save_total_limit 10 \
-    --learning_rate 1e-6 \
+    --learning_rate 1e-4 \
+    --vision_lr 2e-6 \
+    --resampler_lr 2e-6 \
     --weight_decay 0.1 \
     --adam_beta2 0.95 \
     --warmup_ratio 0.01 \
diff --git a/finetune/trainer.py b/finetune/trainer.py
index 7da95ed8..378e728e 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -10,6 +10,87 @@
 
 
 class CPMTrainer(Trainer):
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+        """
+        if is_sagemaker_mp_enabled():
+            return super().create_optimizer()
+
+        opt_model = self.model
+
+        if self.optimizer is None:
+            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            lr_mapper = {}
+            if self.args.resampler_lr is not None:
+                lr_mapper["resampler"] = self.args.resampler_lr
+            if self.args.vision_lr is not None:
+                lr_mapper["vpm"] = self.args.vision_lr
+            if len(lr_mapper) > 0:
+                special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
+                optimizer_grouped_parameters = [
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+                        "weight_decay": self.args.weight_decay,
+                    },
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+                        "weight_decay": 0.0,
+                    },
+                ]
+                for module_keyword, lr in lr_mapper.items():
+                    module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
+                    optimizer_grouped_parameters.extend(
+                        [
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
+                                "weight_decay": self.args.weight_decay,
+                                "lr": lr,
+                            },
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
+                                "weight_decay": 0.0,
+                                "lr": lr,
+                            },
+                        ]
+                    )
+            else:
+                optimizer_grouped_parameters = [
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
+                        "weight_decay": self.args.weight_decay,
+                    },
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
+                        "weight_decay": 0.0,
+                    },
+                ]
+
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            if optimizer_cls.__name__ == "Adam8bit":
+                import bitsandbytes
+
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                skipped = 0
+                for module in opt_model.modules():
+                    if isinstance(module, nn.Embedding):
+                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                        logger.info(f"skipped {module}: {skipped/2**20}M params")
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                logger.info(f"skipped: {skipped/2**20}M params")
+
+        return self.optimizer
+
+
     def compute_loss(self, model, inputs, return_outputs=False):
         if "labels" in inputs:
             labels = inputs.pop("labels")