added molmo lora
Signed-off-by: Matthias Vogler <[email protected]>
Matthias Vogler committed Dec 24, 2024
1 parent a491d6f commit 5f4bbf7
Showing 1 changed file with 17 additions and 5 deletions.
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -15,7 +15,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.layer import MultiHeadAttention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -43,7 +43,7 @@
                           SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -1121,9 +1121,19 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
-class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    supported_lora_modules = [
+        "transformer.blocks.22.att_proj",
+        "transformer.blocks.22.ff_proj",
+        "transformer.blocks.23.att_proj",
+        "transformer.blocks.23.ff_proj",
+        "transformer.blocks.16.att_proj",
+        "transformer.blocks.16.ff_proj",
+        "transformer.blocks.8.att_proj",
+        "transformer.blocks.8.ff_proj",
+        "transformer.blocks.20.att_proj",
+    ]
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Optional[LoRAConfig] = None):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
@@ -1152,6 +1162,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+        self.lora_config = lora_config
 
     def _parse_and_validate_image_input(
         self,
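With this change, Molmo advertises LoRA support via the SupportsLoRA interface, and an adapter trained on the listed att_proj/ff_proj modules can be attached per request. Below is a minimal usage sketch against vLLM's offline LLM API; the adapter name, adapter path, prompt, and sampling settings are placeholders and not part of this commit.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA when constructing the engine; Molmo needs trust_remote_code.
llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    enable_lora=True,
)

# Attach the adapter per request; "/path/to/molmo-lora" is a placeholder.
outputs = llm.generate(
    "Describe the image in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("molmo-lora", 1, "/path/to/molmo-lora"),
)
print(outputs[0].outputs[0].text)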
