Commit 340326f

This PR implements support in the training library for saving LoRA models
when training with FSDP as the distributed backend. This is accomplished by
creating a copy of the LoRA model on the CPU, loading in the state dict after
gathering it from the distributed model, and saving the copy after merging
the trained adapters into its base weights. Afterwards, the CPU copy is
discarded and training continues.

Signed-off-by: Oleg S <[email protected]>

RobotSail committed Oct 23, 2024
1 parent 7b7894b commit 340326f
Showing 3 changed files with 178 additions and 57 deletions.
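At a high level, the new save path works as follows. This is a condensed, illustrative sketch of the flow implemented in save_fsdp_lora_model further down, not a drop-in snippet; base_model_args and lora_config stand in for the values that setup_model() stores on args, and base_model_args is assumed not to carry a device_map:

    # Sketch only: mirrors the structure of save_fsdp_lora_model in this commit.
    import torch.distributed as dist
    from peft import LoraModel
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoModelForCausalLM

    def save_merged_lora(model, tokenizer, accelerator, base_model_args, lora_config, output_dir):
        # 1. Gather the full, unsharded state dict onto rank 0, offloaded to CPU.
        sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config):
            state = model.state_dict()

        if accelerator.is_main_process:
            # 2. Rebuild the base model on CPU and re-apply the LoRA structure.
            cpu_copy = AutoModelForCausalLM.from_pretrained(**base_model_args, device_map="cpu")
            cpu_copy = LoraModel(cpu_copy, lora_config, "default")
            # 3. Load the gathered weights, merge the adapters into the base weights, and save.
            cpu_copy.load_state_dict(state)
            cpu_copy.merge_and_unload(progressbar=True)
            cpu_copy.save_pretrained(output_dir, safe_serialization=True)
            tokenizer.save_pretrained(output_dir)
            # 4. Discard the CPU copy; training continues with the sharded model.
            del cpu_copy

        dist.barrier()
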
51 changes: 11 additions & 40 deletions src/instructlab/training/main_ds.py
@@ -42,8 +42,8 @@
add_noisy_embeddings,
apply_gradient_checkpointing,
convert_loss_to_reduce_sum,
create_lora_config,
ensure_loadable_granite_checkpoint,
get_projection_layer_names,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
@@ -114,13 +114,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
base_model_args["use_padding_free_transformer"] = True
model = GPTDolomiteForCausalLM.from_pretrained(
**base_model_args,
use_padding_free_transformer=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(**base_model_args)

# store the base model args so we can recall them later if saving a LoRA model
args.base_model_args = base_model_args

if len(tokenizer) > model.config.vocab_size:
print(
f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
@@ -175,46 +178,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
# - with the exception of granite, which handles it
# in the later stanza
if args.lora_r > 0:
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)
lora_config = create_lora_config(model, args)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.is_granite
model,
lora_config,
args.distributed_training_framework,
gradient_checkpointing=not args.is_granite,
)

args.lora_config = lora_config
elif not args.is_granite:
model.gradient_checkpointing_enable()

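With the LoRA setup factored out of main_ds.py into utils.py, the call pattern reduces to roughly the following. This is an illustrative fragment, not part of the diff; it assumes model, and an args namespace carrying the lora_* fields and distributed_training_framework used above:

    # Sketch of the refactored LoRA setup path; not part of the diff itself.
    from instructlab.training.utils import create_lora_config, prepare_peft_model

    lora_config = create_lora_config(model, args)  # validates lora_target_modules against the model
    model = prepare_peft_model(
        model,
        lora_config,
        args.distributed_training_framework,  # "fsdp" or "deepspeed"
        gradient_checkpointing=not args.is_granite,
    )
    args.lora_config = lora_config  # recalled later when saving the merged LoRA model

Storing lora_config (and base_model_args) on args is what later lets save_fsdp_lora_model rebuild the adapter structure on the CPU copy.
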
29 changes: 18 additions & 11 deletions src/instructlab/training/setup_accelerator.py
@@ -3,12 +3,10 @@

# Third Party
from accelerate import Accelerator
from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP,
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import PreTrainedModel
import torch

# First Party
@@ -51,34 +49,43 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
return ds_plugin


def get_fsdp_config(args, model):
def get_fsdp_config(args, model: PreTrainedModel):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

block_name = model._no_split_modules[0]

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=partial(
wrap_policy = None
if args.lora_r > 0:
wrap_policy = fsdp_auto_wrap_policy(model)
else:
wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
get_module_class_from_name(model, block_name),
},
),
)

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy,
limit_all_gathers=True,
mixed_precision_policy=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
backward_prefetch=BackwardPrefetch.BACKWARD_POST,
sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
)
if args.lora_r > 0:
fsdp_plugin.use_orig_params = False

return fsdp_plugin


def setup_accelerator(args, model, grad_accum):
def setup_accelerator(args, model: PreTrainedModel, grad_accum):
if args.distributed_training_framework == "deepspeed":
# Third Party
from deepspeed import DeepSpeedEngine
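For LoRA runs, the FSDP plugin now uses PEFT's fsdp_auto_wrap_policy and sets use_orig_params=False, so that the frozen base weights and the trainable adapter weights land in separate FSDP units. A minimal sketch of how the resulting plugin is fed to Accelerate, assuming the get_fsdp_config above; the Accelerator construction shown here is illustrative rather than a copy of setup_accelerator:

    # Illustrative only: wire the FSDP plugin from get_fsdp_config into Accelerate.
    from accelerate import Accelerator

    fsdp_plugin = get_fsdp_config(args, model)  # PEFT wrap policy when args.lora_r > 0
    accelerator = Accelerator(
        fsdp_plugin=fsdp_plugin,
        gradient_accumulation_steps=grad_accum,
        mixed_precision="bf16",  # matches the bf16 MixedPrecision policy above
    )
    model = accelerator.prepare(model)
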
155 changes: 149 additions & 6 deletions src/instructlab/training/utils.py
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from argparse import Namespace
from collections import OrderedDict
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple
import importlib
import inspect
import logging
@@ -21,25 +22,32 @@

# Third Party
# pylint: disable=no-name-in-module
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from instructlab.dolomite.hf_models import (
GPTDolomiteConfig,
export_to_huggingface,
import_from_huggingface,
)
from rich.logging import RichHandler
from torch import distributed as dist
from torch import nn
from torch.distributed import get_rank, is_initialized
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
from transformers import PreTrainedModel
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
import numpy as np
import torch
import torch.nn.functional as F

# First Party
from instructlab.training.config import DistributedBackend


def retrieve_chat_template(chat_tmpl_path):
try:
@@ -304,16 +312,137 @@ def patch_target_module(
setattr(source, obj_name_to_patch, replace_with)


def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
"""Checks if a module or its children are an instance of one of the provided classes.
Args:
module (nn.Module): A PyTorch module.
wrapped_classes (Tuple): A tuple of potential classes the module could be.
Returns:
bool: True if the module or any of its children are instances of one of `wrapped_classes`, False otherwise.
"""
if isinstance(module, wrapped_classes):
return True

for m in module.children():
if wraps(m, wrapped_classes):
return True

return False


def create_lora_config(model: PreTrainedModel, args: Namespace) -> "peft.LoraConfig":
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

return LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)


def save_fsdp_lora_model(
args: Namespace,
model: FSDP,
tokenizer: PreTrainedTokenizer,
accelerator: Accelerator,
output_dir: Path,
):
"""Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original
model with the trained LoRA adapters merged into the copy.
This function creates a full copy of the model being trained and stores it in CPU memory.
If you encounter out-of-memory errors on the CPU, this copy is a likely culprit.
Args:
args (Namespace): Args received by the ArgumentParser.
model (FSDP): FSDP model as prepared by `accelerate.Accelerator`.
tokenizer (PreTrainedTokenizer): Tokenizer saved alongside the merged model.
accelerator (Accelerator): The given accelerator object.
output_dir (Path): Directory where the merged model and tokenizer are written.
"""
# Third Party
from peft import LoraConfig, LoraModel

if accelerator.distributed_type != DistributedType.FSDP:
raise RuntimeError(
"`save_fsdp_lora_model` was called when FSDP was not being used."
)
if not wraps(model, FSDP):
raise RuntimeError(
"`save_fsdp_lora_model` was called but provided model is not an FSDP model."
)
if not wraps(model, LoraModel):
raise RuntimeError(
"`save_fsdp_lora_model` was called but provided model is not a LoRA model."
)

# okay now that validation is out of the way, we are free to implement saving
lora_conf: LoraConfig = args.lora_config
sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config):
state = model.state_dict()

if accelerator.is_main_process:
# remove device_map from args list so we can load the model on CPU
old_device_map = args.base_model_args.pop("device_map", None)
model_copy = AutoModelForCausalLM.from_pretrained(
**args.base_model_args, device_map="cpu"
)
model_copy = LoraModel(model_copy, lora_conf, "default")
model_copy.load_state_dict(state)
model_copy.merge_and_unload(progressbar=True)
model_copy.save_pretrained(output_dir, safe_serialization=True)
model.config.to_json_file(f"{output_dir}/config.json")
tokenizer.save_pretrained(output_dir)
del model_copy
if old_device_map:
# return the previous device_map so it can be used later on if needed
args.base_model_args["device_map"] = old_device_map

dist.barrier()


def prepare_peft_model(
model,
model: PreTrainedModel,
peft_config,
distributed_backend: str,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": True},
mixed_precision="bf16",
):
# will guard this
# Third Party
from peft import (
LoraModel,
PeftConfig,
PeftModel,
get_peft_model,
@@ -355,7 +484,11 @@ def make_inputs_require_grad(module, input, output):
make_inputs_require_grad
)

model = get_peft_model(model, peft_config)
if distributed_backend == DistributedBackend.FSDP.value:
# FSDP doesn't like `get_peft_model` as it leads to dtype mismatches
model = LoraModel(model, peft_config, "default")
else:
model = get_peft_model(model, peft_config)
if mixed_precision == "bf16" and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)

@@ -630,7 +763,7 @@ def _copy_no_lora_dict(state_dict):


def save_dict_accelerate(
accelerator,
accelerator: Accelerator,
state_to_save,
save_directory,
max_shard_size="5GB",
@@ -681,6 +814,16 @@ def save_hf_format_accelerate(
CONFIG_NAME = "config.json"
output_config_file = output_dir / CONFIG_NAME

if is_lora and accelerator.distributed_type == DistributedType.FSDP:
save_fsdp_lora_model(
args=args,
model=model,
tokenizer=tokenizer,
accelerator=accelerator,
output_dir=output_dir,
)
return

get_state_dict_unpatched = accelerator.get_state_dict

def _get_state_dict_patched(model, unwrap=False):
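Because the adapters are merged before saving, the directory written by save_fsdp_lora_model is intended to hold a plain Hugging Face checkpoint that downstream consumers can load without PEFT. For example, with merged_dir as a placeholder for the output_dir passed at save time:

    # The merged output is a regular HF checkpoint; no PEFT wrapper is needed to load it.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    merged_dir = "path/to/output_dir"  # placeholder
    model = AutoModelForCausalLM.from_pretrained(merged_dir)
    tokenizer = AutoTokenizer.from_pretrained(merged_dir)
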
