From 686c96ecf583b1e6ab40f5c0936d67167a19e91c Mon Sep 17 00:00:00 2001
From: Joe Ruether
Date: Tue, 21 Jan 2025 16:19:08 -0600
Subject: [PATCH] Handle weight aliases

---
 mergekit/scripts/extract_lora.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py
index 69c010bb..8c739e70 100644
--- a/mergekit/scripts/extract_lora.py
+++ b/mergekit/scripts/extract_lora.py
@@ -15,6 +15,7 @@
 from mergekit.card import generate_card_lora
 from mergekit.common import ModelReference
 from mergekit.io import LazyTensorLoader
+from mergekit.architecture import get_architecture_info
 
 
 def low_rank_decomposition(
@@ -230,6 +231,8 @@ def extract_lora(
     :return: A tuple containing LoRA weights dictionary and ranks dictionary.
     """
 
+    base_architecture = get_architecture_info(base_model_ref.config())
+
     base_loader = LazyTensorLoader(
         base_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle)
     )
@@ -241,8 +244,18 @@ def extract_lora(
     ranks = {}
 
     for module_type, module_name in tqdm(module_details):
-        base_weight = base_loader.get_tensor(f"{module_name}.weight")
-        finetuned_weight = finetuned_loader.get_tensor(f"{module_name}.weight")
+
+        module_weight_name = f"{module_name}.weight"
+
+        aliases = []
+        for weight_info in base_architecture.all_weights(base_model_ref.config()):
+            all_names = [weight_info.name] + list(weight_info.aliases or []) + list(weight_info.tied_names or [])
+            if module_weight_name in all_names:
+                aliases = all_names
+                break
+
+        base_weight = base_loader.get_tensor(module_weight_name, aliases=aliases)
+        finetuned_weight = finetuned_loader.get_tensor(module_weight_name, aliases=aliases)
 
         if module_type == "to_save":
             lora_weights[
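
In case the change is easier to follow in isolation: the patch collects every name a weight may be stored under in a checkpoint (its canonical name, any aliases, and the names of weights it is tied to) and hands the whole set to the loader, so that a request for e.g. a tied lm_head.weight can still be served by a checkpoint that only stores model.embed_tokens.weight. Below is a minimal, self-contained sketch of that resolution logic; WeightInfo, resolve_weight, the tensor index, and the example names here are illustrative stand-ins, not mergekit's actual API.

    # Sketch of alias resolution with stand-in types (not mergekit's classes).
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class WeightInfo:
        name: str
        aliases: Optional[List[str]] = None
        tied_names: Optional[List[str]] = None

    def resolve_weight(name: str, all_weights: List[WeightInfo], tensor_index: dict):
        """Return the tensor stored under `name` or any of its known aliases."""
        # Collect every name this weight may appear under in a checkpoint.
        candidates = [name]
        for info in all_weights:
            all_names = [info.name] + list(info.aliases or []) + list(info.tied_names or [])
            if name in all_names:
                candidates = all_names
                break
        # Try each candidate until one is actually present in the checkpoint.
        for candidate in candidates:
            if candidate in tensor_index:
                return tensor_index[candidate]
        raise KeyError(f"No tensor found for {name} (tried {candidates})")

    # Example: an output head tied to the input embeddings, where the
    # checkpoint only stores the embedding tensor.
    weights = [WeightInfo(name="lm_head.weight", tied_names=["model.embed_tokens.weight"])]
    index = {"model.embed_tokens.weight": "<tensor>"}
    print(resolve_weight("lm_head.weight", weights, index))  # -> "<tensor>"

Doing the lookup once per module and passing the resulting list to both loaders, as the patch does via the aliases= argument, keeps the base and fine-tuned checkpoints resolving against the same alias set.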