Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
This PR implements the capability of the training library to save
LoRA models when training with FSDP as the distributed backend.
This is accomplished by creating a copy of the LoRA model on the CPU,
loading in the state dict after gathering it from the distributed model,
and saving after merging the adapters back into the original model.
Afterwards, the CPU copy is discarded and training continues.

Signed-off-by: Oleg S <[email protected]>
  • Loading branch information
RobotSail committed Oct 23, 2024
1 parent 7b7894b commit 4f0820c
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 54 deletions.
66 changes: 25 additions & 41 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Standard
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
import argparse
import math
Expand Down Expand Up @@ -43,7 +44,7 @@
apply_gradient_checkpointing,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
get_projection_layer_names,
create_lora_config,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
Expand Down Expand Up @@ -114,13 +115,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
base_model_args["use_padding_free_transformer"] = True
model = GPTDolomiteForCausalLM.from_pretrained(
**base_model_args,
use_padding_free_transformer=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(**base_model_args)

# store the base model args so we can recall them later if saving a LoRA model
args.base_model_args = base_model_args

if len(tokenizer) > model.config.vocab_size:
print(
f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
Expand Down Expand Up @@ -174,47 +178,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
# it is handled differently for lora and full
# - with the exception of granite, which handles it
# in the later stanza
distributed_backend = args.distributed_training_framework

Check warning on line 181 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / pylint

W0612: Unused variable 'distributed_backend' (unused-variable)
if args.lora_r > 0:
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)
lora_config = create_lora_config(model, args)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.is_granite
model,
lora_config,
args.distributed_training_framework,
gradient_checkpointing=not args.is_granite,
)

args.lora_config = lora_config
elif not args.is_granite:
model.gradient_checkpointing_enable()

Expand Down Expand Up @@ -529,7 +502,11 @@ def main(args):
#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group("nccl")
nccl_timeout: timedelta | None = None
if args.debug:
# surely we won't need any more than this... right?
nccl_timeout = timedelta(days=1)
torch.distributed.init_process_group("nccl", timeout=nccl_timeout)
args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()
torch.distributed.all_reduce(tensor)
Expand Down Expand Up @@ -932,6 +909,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
),
)
# hidden argument for our own sake
parser.add_argument(
"--debug",
help="Enables settings for debugging. For example, the NCCL timeout increases so more time can be spent in breakpoints.",
action="store_true",
default=False,
)
parser.add_argument("--disable_flash_attn", action="store_true")
args = parser.parse_args()
set_random_seed(args.seed)
Expand Down
25 changes: 18 additions & 7 deletions src/instructlab/training/setup_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

# Third Party
from accelerate import Accelerator
from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP,
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import (
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import torch
from transformers import PreTrainedModel

# First Party
from instructlab.training.config import DeepSpeedOptions
Expand Down Expand Up @@ -51,34 +53,43 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption
return ds_plugin


def get_fsdp_config(args, model):
def get_fsdp_config(args, model: PreTrainedModel):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

block_name = model._no_split_modules[0]

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=partial(
wrap_policy = None
if args.lora_r > 0:
wrap_policy = fsdp_auto_wrap_policy(model)
else:
wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
get_module_class_from_name(model, block_name),
},
),
)

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy,
limit_all_gathers=True,
mixed_precision_policy=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
backward_prefetch=BackwardPrefetch.BACKWARD_POST,
sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
)
if args.lora_r > 0:
fsdp_plugin.use_orig_params = False

return fsdp_plugin


def setup_accelerator(args, model, grad_accum):
def setup_accelerator(args, model: PreTrainedModel, grad_accum):
if args.distributed_training_framework == "deepspeed":
# Third Party
from deepspeed import DeepSpeedEngine
Expand Down
Loading

0 comments on commit 4f0820c

Please sign in to comment.