diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 51e9752e..7eb66c02 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -156,6 +156,9 @@ class TrainingArgs(BaseModel): os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" ) + # this field determins if ibm_legacy_tmpl should be used instead + use_legacy_sp_tokens: bool = False + # this field specifies the filepath to the training dataset before processing data_path: str ckpt_output_dir: str diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 4c04da0f..d2708656 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -691,6 +691,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: """ check_valid_train_args(train_args) + # switch out generic tmpl for legacy tmpl if requested + if train_args.use_legacy_sp_tokens: + train_args.chat_tmpl_path = os.path.join( + os.path.dirname(__file__), "chat_templates/ibm_legacy_tmpl.py" + ) + if train_args.process_data: dp.main( DataProcessArgs(