From 5f4bbf7ea2e5cd4f41e47cc501a430d8b4791deb Mon Sep 17 00:00:00 2001
From: Matthias Vogler
Date: Mon, 23 Dec 2024 14:17:59 +0100
Subject: [PATCH] Add Molmo LoRA support

Signed-off-by: Matthias Vogler
---
 vllm/model_executor/models/molmo.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 9f744b6918818..53b0329a9f4f2 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -15,7 +15,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.layer import MultiHeadAttention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -43,7 +43,7 @@
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -1121,9 +1121,19 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
-class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    supported_lora_modules = [
+        "transformer.blocks.22.att_proj",
+        "transformer.blocks.22.ff_proj",
+        "transformer.blocks.23.att_proj",
+        "transformer.blocks.23.ff_proj",
+        "transformer.blocks.16.att_proj",
+        "transformer.blocks.16.ff_proj",
+        "transformer.blocks.8.att_proj",
+        "transformer.blocks.8.ff_proj",
+        "transformer.blocks.20.att_proj",
+    ]
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Optional[LoRAConfig] = None):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
@@ -1152,6 +1162,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+        self.lora_config = lora_config
 
     def _parse_and_validate_image_input(
         self,