
Commit

Merge branch 'instructlab:main' into feature/add-try-catch-import-to-deepspeed
Harthi7 authored Oct 25, 2024
2 parents c44a78a + 03d1b62 commit 9a37873
Showing 6 changed files with 237 additions and 107 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ py-cpuinfo
 # we set this to be above 0a0 so that it doesn't
 # replace custom pytorch images with the 2.3.0
 torch>=2.3.0a0
-transformers>=4.41.2
+transformers>=4.45.2
 accelerate>=0.34.2
 datasets>=2.15.0
 numba
3 changes: 2 additions & 1 deletion src/instructlab/training/config.py
@@ -171,8 +171,9 @@ class TrainingArgs(BaseModel):
     save_samples: int
     learning_rate: float
     warmup_steps: int
-    is_padding_free: bool
     random_seed: int = 42
+    use_dolomite: bool = False
+    is_padding_free: bool = False  # TODO: deprecate
     checkpoint_at_epoch: bool = True
     accelerate_full_state_at_epoch: bool = True
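Note on the config change above: TrainingArgs gains a use_dolomite flag (default False), while is_padding_free now defaults to False and is marked for deprecation. Below is a minimal standalone sketch of the new defaults, assuming pydantic v2; TrainingFlags is a hypothetical stand-in, not the real TrainingArgs, which has many more required fields.

from pydantic import BaseModel


class TrainingFlags(BaseModel):
    random_seed: int = 42
    use_dolomite: bool = False  # new flag: opt in to the dolomite checkpoint/padding-free path
    is_padding_free: bool = False  # kept for backward compatibility; slated for deprecation


flags = TrainingFlags(use_dolomite=True)
print(flags.model_dump())  # {'random_seed': 42, 'use_dolomite': True, 'is_padding_free': False}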
25 changes: 21 additions & 4 deletions src/instructlab/training/data_process.py
@@ -199,7 +199,7 @@ def print_masked_samples(data, tokenizer, is_pretrain, num_proc):
     def get_masked_and_orig_text(sample):
         labels = sample["labels"]
         input_ids = sample["input_ids"]
-        mask_id = get_sp_token(tokenizer, "<MASK>")[0]
+        mask_id = get_sp_token(tokenizer, "<|MASK|>")[0]
         label = [mask_id if tk == -100 else tk for tk in labels]
         text = tokenizer.decode(label)
         orig_text = tokenizer.decode(input_ids)
@@ -239,7 +239,7 @@ def main(args: DataProcessArgs):
 
     # Adding after tokenizer setup as these are temp tokens, not to be saved
     tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<MASK>"]}
+        {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<|MASK|>"]}
     )
 
     try:
@@ -347,9 +347,26 @@ def main(args: DataProcessArgs):
     )
 
     # extract only labels and messages formatted into a new dataset
-    data_with_labels = data_with_labels.select_columns(["labels", "input_ids"])
+    data_with_labels = data_with_labels.map(
+        lambda x: {
+            "len": len(x["input_ids"]),
+        },
+        num_proc=NUM_PROC,
+    )
+    data_with_labels = data_with_labels.select_columns(["labels", "input_ids", "len"])
+    # MASK and both pretrain tokens should not be in the final tokens, those are special tokens added only for data processing purposes.
+    max_id = len(tokenizer) - 3
+    final_valid_data = data_with_labels.filter(
+        lambda x: all(tk < max_id for tk in x["labels"]), num_proc=NUM_PROC
+    )
+    # Dropping samples that could break training due to oob ids
+    if len(final_valid_data) < len(data_with_labels):
+        dropped_samples = len(data_with_labels) - len(final_valid_data)
+        print(
+            f"\033[93mWarning: {dropped_samples} samples were dropped because they contained token IDs greater than or equal to {max_id}.\033[0m"
+        )
     # use path to get the stem of the file
-    data_with_labels.to_json(Path(args.data_output_path) / f"data.jsonl")
+    final_valid_data.to_json(Path(args.data_output_path) / "data.jsonl")
 
 
 if __name__ == "__main__":
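Note on the data_process.py change above: the script now caches a per-sample "len" column and drops any sample whose labels contain one of the three temporary special tokens (<|pretrain|>, <|/pretrain|>, <|MASK|>), whose IDs sit at the end of the tokenizer vocabulary. Below is a standalone sketch of that filter, assuming a toy in-memory dataset and a pretend tokenizer length of 10; it is an illustration, not the script itself.

from datasets import Dataset

toy = Dataset.from_dict(
    {
        "labels": [[1, 2, 3], [4, 9, 5]],  # 9 stands in for a temporary special token
        "input_ids": [[1, 2, 3], [4, 9, 5]],
    }
)
tokenizer_len = 10  # pretend len(tokenizer) after adding the 3 temp tokens
max_id = tokenizer_len - 3  # any label id >= max_id is a temp token

final_valid = toy.filter(lambda x: all(tk < max_id for tk in x["labels"]))
print(f"dropped {len(toy) - len(final_valid)} of {len(toy)} samples")  # dropped 1 of 2 samples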
64 changes: 31 additions & 33 deletions src/instructlab/training/main_ds.py
@@ -58,8 +58,10 @@
     StreamablePopen,
     add_noisy_embeddings,
     apply_gradient_checkpointing,
+    check_flash_attn_enabled,
+    check_valid_train_args,
     convert_loss_to_reduce_sum,
-    ensure_loadable_granite_checkpoint,
+    ensure_loadable_dolomite_checkpoint,
     get_projection_layer_names,
     load_latest_full_state,
     prepare_peft_model,
@@ -101,7 +103,7 @@ def setup_optimizer(args, model):
     return optimizer
 
 
-def setup_model(args, tokenizer, train_loader, grad_accum):
+def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
     bnb_config = None
     if args.lora_r > 0 and args.lora_quant_bits == 4:
         # Third Party
@@ -119,15 +121,11 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
         "torch_dtype": torch.bfloat16,
         "quantization_config": bnb_config,
     }
-    if not args.disable_flash_attn:
+    if flash_enabled:
         base_model_args["attn_implementation"] = "flash_attention_2"
-    elif args.is_granite:
-        raise RuntimeError(
-            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
-        )
 
-    if args.is_granite:
-        with ensure_loadable_granite_checkpoint(
+    if args.use_dolomite:
+        with ensure_loadable_dolomite_checkpoint(
             args.model_name_or_path, args.output_dir
         ) as path:
             base_model_args["pretrained_model_name_or_path"] = path
@@ -182,9 +180,10 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
         "Starcoder2ForCausalLM",
         "GemmaForCausalLM",
         "MixtralForCausalLM",
+        "GraniteForCausalLM",
     ], f"Model class name: {model.__class__.__name__} is not supported."
 
-    model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
+    model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
     model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)
 
     # handling of gradient checkpointing
@@ -229,15 +228,15 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             target_modules=args.lora_target_modules,
         )
         model = prepare_peft_model(
-            model, peft_config, gradient_checkpointing=not args.is_granite
+            model, peft_config, gradient_checkpointing=not args.use_dolomite
         )
 
-    elif not args.is_granite:
+    elif not args.use_dolomite:
         model.gradient_checkpointing_enable()
 
     # granite gradient checkpointing is handled uniformly
     # for both lora and full here
-    if args.is_granite:
+    if args.use_dolomite:
         block_name = model._no_split_modules[0]
         apply_gradient_checkpointing(
             model,
@@ -269,6 +268,9 @@ def make_inputs_require_grad(module, input, output):
         deepcopy(train_loader),
         lr_scheduler,
     )
+    # Necessary so that Accelerate does not step once per GPU
+    # see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69
+    lr_scheduler.split_batches = True
     return model, lr_scheduler, optimizer, accelerator
 
 
@@ -398,8 +400,8 @@ def train(
         num_loss_counted_tokens = float(
             torch.tensor([batch.pop("num_loss_counted_tokens")])
         )
-        micro_batch_size = float(len(batch["input_ids"]))
-        if not args.is_granite:
+        micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
+        if not args.use_dolomite:
             for k in batch:
                 batch[k] = batch[k].to(local_rank)
             output = model(
@@ -470,7 +472,7 @@ def train(
                     "batch_size": int(micro_batch_size),
                     "total_loss": float(log_loss / num_loss_counted_tokens),
                     "samples_seen": samples_seen,
-                    # "gradnorm": global_grad_norm,
+                    "gradnorm": global_grad_norm,
                     # "weight_norm": weight_norm,
                 }
             )
@@ -558,6 +560,8 @@ def main(args):
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()
 
+    flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite)
+
     dataset = setup_dataset(
         args.data_path,
         mock=args.mock_data,
@@ -570,7 +574,7 @@ def main(args):
         avg_sample_len=dataset.get_lengths().mean(),
         effective_batch_size=args.effective_batch_size,
         max_batch_len_per_gpu=args.max_batch_len,
-        is_padding=not args.is_granite,
+        is_padding=not (args.use_dolomite or flash_enabled),
         dataset=dataset,
         seed=args.seed,
     )
@@ -593,7 +597,8 @@ def main(args):
         dataset,
         tokenizer.pad_token_id,
         num_workers=8,
-        is_granite=args.is_granite,
+        use_dolomite=args.use_dolomite,
+        flash_enabled=flash_enabled,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
         samples_per_gpu=args.samples_per_gpu,
@@ -612,7 +617,8 @@ def main(args):
         dataset,
         tokenizer.pad_token_id,
         num_workers=8,
-        is_granite=args.is_granite,
+        use_dolomite=args.use_dolomite,
+        flash_enabled=flash_enabled,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
        samples_per_gpu=args.samples_per_gpu,
@@ -636,7 +642,7 @@ def main(args):
     )
 
     model, lr_scheduler, optimizer, accelerator = setup_model(
-        args, tokenizer, train_loader, grad_accum
+        args, tokenizer, train_loader, grad_accum, flash_enabled
     )
 
     load_latest_full_state(args=args, accelerator=accelerator)
@@ -662,11 +668,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
-    # early validation logic here
-    if train_args.max_batch_len < train_args.max_seq_len:
-        raise ValueError(
-            f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
-        )
+    check_valid_train_args(train_args)
 
     if train_args.process_data:
         dp.main(
@@ -720,14 +722,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.mock_len:
         command.append(f"--mock_len={train_args.mock_len}")
 
-    if train_args.is_padding_free:
-        command.append("--is_granite")
+    if train_args.use_dolomite:
+        command.append("--use_dolomite")
 
     if train_args.disable_flash_attn:
-        if train_args.is_padding_free:
-            raise RuntimeError(
-                "ERROR: Trying to use padding-free transformer without flash attention is not supported"
-            )
         command.append("--disable_flash_attn")
 
     if train_args.lora:
@@ -918,7 +916,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         default="SHARD_GRAD_OP",
         help="Sharding strategy to be used for FSDP distributed training.",
     )
-    parser.add_argument("--is_granite", action="store_true")
+    parser.add_argument("--use_dolomite", action="store_true")
     parser.add_argument("--lora_r", type=int, default=0)  # set to > 0 to activate lora
     parser.add_argument("--lora_alpha", type=int, default=32)
     parser.add_argument("--lora_dropout", type=float, default=0.1)
@@ -1007,7 +1005,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
 --save_samples=250000 \
 --log_level="INFO" \
 --fsdp_sharding_strategy="SHARD_GRAD_OP" \
---is_granite \
+--use_dolomite \
 --max_batch_len 70000 \
 --seed=42
 """
25 changes: 16 additions & 9 deletions src/instructlab/training/token_dataset.py
@@ -17,12 +17,15 @@
 class TokenDataset(Dataset):
     def __init__(self, data_path):
         self.data = load_dataset("json", data_files=data_path, split="train")
-        self.lengths = np.array(
-            self.data.map(
-                lambda x: {"len": len(x["input_ids"])},
-                num_proc=8,
-            )["len"]
-        )
+        if "len" not in self.data.column_names:
+            self.lengths = np.array(
+                self.data.map(
+                    lambda x: {"len": len(x["input_ids"])},
+                    num_proc=8,
+                )["len"]
+            )
+        else:
+            self.lengths = np.array(self.data["len"])
 
     def __len__(self):
         return len(self.data)
@@ -87,15 +90,19 @@ def setup_dataloader(
     dataset: Dataset,
     pad_token_id: int,
     num_workers: int = 8,
-    is_granite=False,
+    use_dolomite=False,
+    flash_enabled=True,
     max_batch_len=60000,
     packing_max_batch_len=60000,
     samples_per_gpu=None,
     sampler="multipack",
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
-        pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len
+        pad_token_id,
+        use_dolomite=use_dolomite,
+        flash_enabled=flash_enabled,
+        max_batch_len=max_batch_len,
     )
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -108,7 +115,7 @@ def setup_dataloader(
             num_replicas=world_size,
             rank=rank,
             seed=seed,
-            padding=not is_granite,
+            padding=not flash_enabled,
         )
         sampler = {"batch_sampler": sampler}
     elif sampler == "distributed":
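Note on the token_dataset.py change above: TokenDataset now reuses the precomputed "len" column written by data_process.py and only falls back to a map() pass when the column is missing. Below is a standalone sketch of that fast path, assuming toy in-memory data instead of the JSONL file the trainer actually loads.

import numpy as np
from datasets import Dataset

data = Dataset.from_dict(
    {"input_ids": [[1, 2, 3], [4, 5]], "labels": [[1, 2, 3], [4, 5]], "len": [3, 2]}
)

if "len" in data.column_names:
    lengths = np.array(data["len"])  # fast path: reuse the cached lengths
else:
    lengths = np.array(data.map(lambda x: {"len": len(x["input_ids"])})["len"])

print(lengths)  # [3 2]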
(The diff for the remaining changed file was not loaded on this page.)
