Skip to content

Commit

Permalink
added Molmo Lora Support
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Vogler <[email protected]>
  • Loading branch information
Matthias Vogler committed Dec 24, 2024
1 parent 5f4bbf7 commit 1ca7a96
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,17 +1122,18 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"att_proj": ["att_proj"],
"attn_out": ["attn_out"],
"ff_proj": ["ff_proj"],
"ff_out": ["ff_out"],
}
supported_lora_modules = [
"transformer.blocks.22.att_proj",
"transformer.blocks.22.ff_proj",
"transformer.blocks.23.att_proj",
"transformer.blocks.23.ff_proj",
"transformer.blocks.16.att_proj",
"transformer.blocks.16.ff_proj",
"transformer.blocks.8.att_proj",
"transformer.blocks.8.ff_proj",
"transformer.blocks.20.att_proj",
"att_proj",
"ff_proj",
]
embedding_modules = {}
embedding_padding_modules = {}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Optional[LoRAConfig] = None):
super().__init__()
config = vllm_config.model_config.hf_config
Expand Down Expand Up @@ -1164,6 +1165,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Op
self.model.make_empty_intermediate_tensors)

self.lora_config = lora_config


def _parse_and_validate_image_input(
self,
Expand Down

0 comments on commit 1ca7a96

Please sign in to comment.