From 22578e4ec54aed25a9d88cedea0925a21caf9124 Mon Sep 17 00:00:00 2001 From: Jaideep Rao Date: Tue, 12 Nov 2024 23:31:53 -0500 Subject: [PATCH] feat: add toggle to pick legacy chat tmpl for granite Signed-off-by: Jaideep Rao --- src/instructlab/training/config.py | 3 +++ src/instructlab/training/main_ds.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 51e9752e..7eb66c02 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -156,6 +156,9 @@ class TrainingArgs(BaseModel): os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" ) + # this field determines if ibm_legacy_tmpl should be used instead + use_legacy_sp_tokens: bool = False + # this field specifies the filepath to the training dataset before processing data_path: str ckpt_output_dir: str diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 4c04da0f..d2708656 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -691,6 +691,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: """ check_valid_train_args(train_args) + # switch out generic tmpl for legacy tmpl if requested + if train_args.use_legacy_sp_tokens: + train_args.chat_tmpl_path = os.path.join( + os.path.dirname(__file__), "chat_templates/ibm_legacy_tmpl.py" + ) + if train_args.process_data: dp.main( DataProcessArgs(