-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathlora_script.py
482 lines (432 loc) · 18.9 KB
/
lora_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based off the lora script here: https://github.com/artidoro/qlora/blob/main/qlora.py
import copy
import json
import os
import deepspeed
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import logging
import torch
import transformers
from torch.nn.utils.rnn import pad_sequence
import argparse
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
set_seed,
Seq2SeqTrainer,
)
from datasets import load_dataset
import deepspeed.comm as dist
from deepspeed.linear import LoRAConfig, QuantizationConfig
# Module-level logger for this script.
logger = logging.getLogger(__name__)

# Label value written into `labels` for positions that must not contribute to
# the loss (prompt tokens when not training on the source, and padding).
IGNORE_INDEX: int = -100
@dataclass
class ModelArguments:
    """Command-line options selecting the model/tokenizer checkpoints and Hub access."""

    model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3.1-405B")
    tokenizer_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3.1-405B")
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."},
    )
    # Holds the token string itself (not a boolean switch).
    auth_token: Optional[str] = field(
        default=None,
        metadata={"help": "Hugging Face access token passed as `token` when downloading the tokenizer and model."},
    )
@dataclass
class DataArguments:
    """Command-line options describing the fine-tuning data.

    Parsed by ``transformers.HfArgumentParser`` in ``train``; each field's
    ``metadata["help"]`` becomes its ``--help`` text.
    """

    # Neither eval_dataset_size nor max_eval_samples is read elsewhere in this file.
    eval_dataset_size: int = field(default=1024, metadata={"help": "Size of validation dataset."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."},
    )
    # Token budgets handed to DataCollatorForCausalLM for the prompt and the response.
    source_max_len: int = field(
        default=1024,
        metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    target_max_len: int = field(
        default=256,
        metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    dataset: str = field(default='alpaca', metadata={"help": "Which dataset to finetune on. See datamodule for options."})
    # When left as None, the format is inferred from `dataset`; format_dataset in
    # make_data_module also understands alpaca-clean, oasst1 and input-output.
    dataset_format: Optional[str] = field(
        default=None,
        metadata={"help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]"},
    )
@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    """Seq2SeqTrainingArguments extended with LoRA, quantization and MMLU options.

    The fields from ``output_dir`` onward appear to redeclare standard Hugging Face
    training options purely to give them script-specific defaults.
    """

    cache_dir: Optional[str] = field(
        default=None
    )
    train_on_source: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to train on the input in addition to the target text."}
    )
    # The mmlu_* / do_mmlu_eval options are declared but not read elsewhere in this file.
    mmlu_split: Optional[str] = field(
        default='eval',
        metadata={"help": "The MMLU split to run on"}
    )
    mmlu_dataset: Optional[str] = field(
        default='mmlu-fs',
        metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."}
    )
    do_mmlu_eval: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to run the MMLU evaluation."}
    )
    max_mmlu_samples: Optional[int] = field(
        default=None,
        metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMLU dataset."}
    )
    mmlu_source_max_len: int = field(
        default=2048,
        metadata={"help": "Maximum source sequence length for mmlu."}
    )
    full_finetune: bool = field(
        default=False,
        metadata={"help": "Finetune the entire model without adapters."}
    )
    # LoRA adapter hyper-parameters (forwarded to deepspeed.linear.LoRAConfig).
    lora_r: int = field(
        default=64,
        metadata={"help": "Lora R dimension."}
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": " Lora alpha."}
    )
    # Frozen base-weight handling (quantization / sharding / CPU offload).
    quantize: bool = field(
        default=False,
        metadata={"help": "quantize frozen base weights or not."}
    )
    bits: int = field(
        default=8,
        metadata={"help": "How many bits to use for quantization."}
    )
    base_weight_sharding: bool = field(
        default=False,
        metadata={"help": "Shard base weights with DP world size, similar to ZeRO-3."}
    )
    offload: bool = field(
        default=False,
        metadata={"help": "Offload the base weights to CPU."}
    )
    offload_ratio: float = field(
        default=0.0,
        metadata={"help": "Fraction of base weights to offload to CPU."}
    )
    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
    per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
    gradient_accumulation_steps: int = field(default=1, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
    max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'})
    weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})  # use lora dropout instead for regularization if needed
    learning_rate: float = field(default=0.0002, metadata={"help": 'The learning rate'})
    remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'})
    max_grad_norm: float = field(default=1.0, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
    gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'})
    activation_checkpointing: bool = field(default=True, metadata={"help": 'Currently unused by this script; gradient checkpointing is controlled by --gradient_checkpointing.'})
    do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
    lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
    warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'})
    logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
    group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
    # NOTE: Hugging Face only uses save_steps when save_strategy == 'steps'.
    save_strategy: str = field(default='epoch', metadata={"help": 'When to save checkpoints'})
    save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
def get_accelerate_model(args):
    """Load the tokenizer and the causal LM, building it under ``deepspeed.linear.Init``.

    Args:
        args: merged namespace of ModelArguments, DataArguments and
            TrainingArguments fields (built in ``train``).

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``full_finetune`` is set while ``bits`` is not 16/32, or
            ``quantize`` is set while ``bits`` is not 8.
    """
    # Validate with real exceptions: `assert` is stripped under `python -O`.
    if args.full_finetune and args.bits not in (16, 32):
        raise ValueError(f"full_finetune requires bits to be 16 or 32, got {args.bits}")

    if args.full_finetune:
        lora_config = None
    else:
        # Shard frozen base weights across the whole world when requested,
        # otherwise keep a single unsharded copy.
        base_weight_shards = dist.get_world_size() if args.base_weight_sharding else 1
        lora_config = LoRAConfig(
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            base_weight_sharding=base_weight_shards,
            offload=args.offload,
            offload_ratio=args.offload_ratio,
            target_mods=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
        )

    if args.quantize:
        if args.bits != 8:
            raise ValueError("currently deepspeed only supports fp8 for llama")
        quantization_config = QuantizationConfig(q_bits=args.bits)
    else:
        quantization_config = None

    # Tokenizer
    print(f"Loading tokenizer {args.tokenizer_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name_or_path,
        padding_side="right",
        use_fast=False,  # Fast tokenizer giving issues.
        token=args.auth_token,
    )
    if tokenizer.pad_token_id is None:
        # No dedicated pad token: reuse EOS for padding.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

    # fp16 runs keep fp32 weights (same choice as the qlora script this is based
    # on; presumably so mixed-precision grad scaling sees fp32 params — confirm),
    # bf16 runs load bf16 weights, everything else fp32.
    model_dtype = torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)

    print(f'loading base model {args.model_name_or_path}...')
    # NOTE(review): deepspeed.linear.Init presumably applies lora_config /
    # quant_config to linear layers while the model is constructed — see its docs.
    with deepspeed.linear.Init(lora_config=lora_config, quant_config=quantization_config):
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            cache_dir=args.cache_dir,
            torch_dtype=model_dtype,
            trust_remote_code=args.trust_remote_code,
            token=args.auth_token,
        )
    print("created model")
    model.config.torch_dtype = model_dtype

    # Honour --gradient_checkpointing instead of enabling it unconditionally.
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        # Presumably needed so checkpointed blocks receive inputs that require
        # grad while the embedding weights are frozen.
        model.enable_input_require_grads()
    model.config.use_cache = False  # KV cache is not needed during training
    return model, tokenizer
def print_trainable_parameters(args, model):
    """
    Prints the number of trainable parameters in the model.

    Args:
        args: unused; kept so existing call sites keep working.
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Avoid ZeroDivisionError when the model exposes no (non-empty) parameters.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || "
        f"trainable: {trainable_pct}"
    )
@dataclass
class DataCollatorForCausalLM(object):
    """Turn ``{'input', 'output'}`` examples into fixed-width causal-LM batches.

    Each example becomes ``bos_token + input`` (truncated to ``source_max_len``
    tokens) followed by ``output + eos_token`` (truncated to ``target_max_len``
    tokens). The concatenation is right-padded once to a fixed width of
    ``source_max_len + target_max_len``. When ``predict_with_generate`` is set,
    only the source is used, padded to ``source_max_len``, and no labels are emitted.

    Attributes:
        tokenizer: provides ``bos_token``, ``eos_token``, ``pad_token_id`` and is
            callable on a list of strings, returning ``{'input_ids': [[...], ...]}``.
        source_max_len: token budget for the prompt part.
        target_max_len: token budget for the response part.
        train_on_source: if True, prompt tokens are supervised too; otherwise they
            are labelled ``IGNORE_INDEX``.
        predict_with_generate: if True, emit source-only ``input_ids`` and no labels.
    """

    tokenizer: "transformers.PreTrainedTokenizer"
    source_max_len: int
    target_max_len: int
    train_on_source: bool
    predict_with_generate: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Collate a batch.

        Returns:
            dict with ``input_ids`` (long), ``attention_mask`` (bool, True on real
            tokens) and, unless ``predict_with_generate``, ``labels`` (long, with
            padding and — unless ``train_on_source`` — prompt positions set to
            ``IGNORE_INDEX``).
        """
        sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
        targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
        # Tokenize WITHOUT padding. Padding source and target separately would put
        # pad tokens between prompt and response and leak the target's padding
        # into the labels; the concatenated sequence is padded once below instead.
        source_ids = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )['input_ids']
        target_ids = self.tokenizer(
            targets,
            max_length=self.target_max_len,
            truncation=True,
            add_special_tokens=False,
        )['input_ids']

        pad_id = self.tokenizer.pad_token_id
        # Fixed output width so every batch has the same shape.
        seq_len = self.source_max_len
        if not self.predict_with_generate:
            seq_len += self.target_max_len

        input_rows, mask_rows, label_rows = [], [], []
        for src, tgt in zip(source_ids, target_ids):
            if self.predict_with_generate:
                ids = list(src)
            else:
                ids = list(src) + list(tgt)
                if self.train_on_source:
                    labels = list(ids)
                else:
                    labels = [IGNORE_INDEX] * len(src) + list(tgt)
                # Padding never contributes to the loss.
                label_rows.append(labels + [IGNORE_INDEX] * (seq_len - len(ids)))
            n_pad = seq_len - len(ids)
            input_rows.append(ids + [pad_id] * n_pad)
            # Build the mask from lengths rather than `ne(pad_token_id)`: the pad id
            # may equal the EOS id, and the real EOS must stay attended.
            mask_rows.append([True] * len(ids) + [False] * n_pad)

        n_rows = len(input_rows)
        data_dict = {
            'input_ids': torch.tensor(input_rows, dtype=torch.long).reshape(n_rows, seq_len),
            'attention_mask': torch.tensor(mask_rows, dtype=torch.bool).reshape(n_rows, seq_len),
        }
        if not self.predict_with_generate:
            data_dict['labels'] = torch.tensor(label_rows, dtype=torch.long).reshape(n_rows, seq_len)
        return data_dict
# Alpaca-style prompt templates, keyed by whether the example has a non-empty
# `input` field. Placeholders are filled with str.format(**example).
ALPACA_PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: "
    ),
)


def extract_alpaca_dataset(example):
    """Render an Alpaca record into a single prompt string.

    Returns a one-key dict ``{'input': prompt}``; the template with an
    ``### Input:`` section is chosen only when ``example['input']`` is present
    and not the empty string.
    """
    has_context = example.get("input", "") != ""
    template_key = "prompt_input" if has_context else "prompt_no_input"
    return {'input': ALPACA_PROMPT_DICT[template_key].format(**example)}
def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
    """
    Make dataset and collator for supervised fine-tuning.

    Datasets are expected to have the following columns: { `input`, `output` }

    Available datasets to be selected with `dataset` argument:
        - alpaca, 52002 examples
        - alpaca-clean, 51942 examples
        - chip2 (OIG), 210289 examples
        - self-instruct, 82612 examples
        - hh-rlhf (Anthropic), 160800 examples
        - longform, 23.7k examples
        - oasst1 (OpenAssistant) primary message tree only, 9,846 examples

    Coming soon:
        - unnatural instructions core, 66010 examples
        - unnatural instructions full, 240670 examples
        - alpaca-gpt4, 52002 examples
        - unnatural-instructions-gpt4, 9000 examples
        - supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
        - flan (FLAN v2), up to 20M examples available
        - vicuna

    Args:
        tokenizer: tokenizer used for the chat template and by the collator.
        args: merged namespace of ModelArguments / DataArguments / TrainingArguments.

    Returns:
        dict with ``train_dataset`` (``None`` unless ``args.do_train``) and
        ``data_collator``.

    Raises:
        NotImplementedError: for ``vicuna`` or any dataset name not listed above.
    """
    def load_data(dataset_name):
        """Load the named dataset from the Hugging Face Hub."""
        if dataset_name == 'alpaca':
            return load_dataset("tatsu-lab/alpaca")
        elif dataset_name == 'alpaca-clean':
            return load_dataset("yahma/alpaca-cleaned")
        elif dataset_name == 'chip2':
            return load_dataset("laion/OIG", data_files='unified_chip2.jsonl')
        elif dataset_name == 'self-instruct':
            return load_dataset("yizhongw/self_instruct", name='self_instruct')
        elif dataset_name == 'hh-rlhf':
            return load_dataset("Anthropic/hh-rlhf")
        elif dataset_name == 'longform':
            return load_dataset("akoksal/LongForm")
        elif dataset_name == 'oasst1':
            return load_dataset("timdettmers/openassistant-guanaco")
        elif dataset_name == 'vicuna':
            raise NotImplementedError("Vicuna data was not released.")
        else:
            # Fail fast instead of returning None and crashing later.
            raise NotImplementedError("Add your dataset implementation here.")

    def format_dataset(dataset, dataset_format):
        """Map the raw dataset onto `input`/`output` columns and drop the rest."""
        # An explicit --dataset_format wins; otherwise infer it from --dataset.
        fmt = args.dataset if dataset_format is None else dataset_format
        if fmt in ('alpaca', 'alpaca-clean'):
            dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
        elif fmt == 'chip2':
            dataset = dataset.map(lambda x: {
                'input': x['text'].split('\n<bot>: ')[0].replace('<human>: ', ''),
                'output': x['text'].split('\n<bot>: ')[1],
            })
        elif fmt == 'self-instruct':
            for old, new in [["prompt", "input"], ["completion", "output"]]:
                dataset = dataset.rename_column(old, new)
        elif fmt == 'hh-rlhf':
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': x['chosen']
            })
        elif fmt == 'oasst1':
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': x['text'],
            })
        elif fmt in ('input-output', 'longform'):
            # Leave as is. NOTE(review): LongForm is assumed to already provide
            # `input`/`output` columns (as in the upstream qlora script) — confirm.
            pass
        else:
            dataset = dataset.map(lambda x: {
                'input': x['prompt'],
                'output': x['completion'],
            })
        # Remove unused columns.
        dataset = dataset.remove_columns(
            [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
        )
        return dataset

    def apply_chat_template(prompt: str) -> str:
        """Wrap the prompt as a single user turn using the tokenizer's chat template."""
        # NOTE(review): chat templates often emit BOS text themselves, while
        # DataCollatorForCausalLM prepends bos_token again — check for a doubled BOS.
        return tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)

    # Load dataset.
    dataset = load_data(args.dataset)
    dataset = format_dataset(dataset, args.dataset_format)
    dataset = dataset.map(lambda example: {"input": apply_chat_template(example["input"])})

    # Preview at most 5 rows so tiny splits do not raise.
    preview_split = dataset['train']
    print("!! sample values from the data !!")
    print(list(preview_split.select(range(min(5, len(preview_split))))))

    if args.do_train:
        train_dataset = dataset['train']
        if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
            train_dataset = train_dataset.select(range(args.max_train_samples))
        if args.group_by_length:
            # Character-length proxy used for length-grouped batching.
            train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

    data_collator = DataCollatorForCausalLM(
        tokenizer=tokenizer,
        source_max_len=args.source_max_len,
        target_max_len=args.target_max_len,
        train_on_source=args.train_on_source,
        predict_with_generate=args.predict_with_generate,
    )
    return dict(
        train_dataset=train_dataset if args.do_train else None,
        data_collator=data_collator
    )
def _redact_secrets(args):
    """Return a copy of ``args`` that is safe to print.

    Non-empty string values of attributes whose name ends in ``token``
    (e.g. ``auth_token``) become ``'<redacted>'``; all else is copied unchanged.
    """
    return argparse.Namespace(**{
        name: '<redacted>' if name.endswith('token') and isinstance(value, str) and value else value
        for name, value in vars(args).items()
    })


def train():
    """Parse CLI arguments, build model and data, train, and write metrics.json."""
    dist.init_distributed()
    # Autograd anomaly detection slows every backward pass considerably; enable
    # it only when debugging by setting LORA_DETECT_ANOMALY=1.
    if os.environ.get("LORA_DETECT_ANOMALY") == "1":
        torch.autograd.set_detect_anomaly(True)

    hfparser = transformers.HfArgumentParser((
        ModelArguments, DataArguments, TrainingArguments
    ))
    model_args, data_args, training_args, extra_args = \
        hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    args = argparse.Namespace(
        **vars(model_args), **vars(data_args), **vars(training_args)
    )
    # Do not print the raw namespace: it contains the Hugging Face auth token.
    print(_redact_secrets(args))

    # Seed before building the model so any random initialisation performed
    # during construction (e.g. LoRA adapter weights) is reproducible.
    set_seed(args.seed)
    model, tokenizer = get_accelerate_model(args)
    model.config.use_cache = False
    print('loaded model')

    data_module = make_data_module(tokenizer=tokenizer, args=args)
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k: v for k, v in data_module.items() if k != 'predict_dataset'},
    )
    trainer.model.train()
    print_trainable_parameters(args, model)

    all_metrics = {"run_name": args.run_name}
    # Training
    if args.do_train:
        logger.info("*** Train ***")
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        all_metrics.update(metrics)

    # Only the main process writes metrics.json; otherwise every rank races on
    # the same file (and output_dir may not exist yet on some nodes).
    if trainer.is_world_process_zero():
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
            fout.write(json.dumps(all_metrics))
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    train()