From 1f1542afa915e0975d2b63559424403e5e8aae2c Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Tue, 21 Jan 2025 15:49:08 +0800
Subject: [PATCH] [Misc]Add BNB quantization for PaliGemmaForConditionalGeneration (#12237)

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/paligemma.py | 13 ++++++++++++-
 vllm/model_executor/models/siglip.py    | 14 ++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index f9ad0c67adaba..ed9ae1887259e 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -136,7 +136,18 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
 @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
-
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index cca42842bc06e..211e5dc80066e 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -344,10 +344,16 @@ def __init__(
         self.config = config

         self.activation_fn = get_act_fn(config.hidden_act)
-
-        # For quantization, we require the hidden size to be a multiple of 64
-        quantizable = (config.hidden_size % 64 == 0
-                       and config.intermediate_size % 64 == 0)
+        # Special handling for BNB quantization
+        if quant_config and quant_config.get_name() == "bitsandbytes":
+            quantizable = True
+        else:
+            # For other quantization, we require the hidden size to be a
+            # multiple of 64
+            quantizable = (
+                config.hidden_size % 64 == 0
+                and config.intermediate_size % 64 == 0
+            )
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
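
Usage note (not part of the patch). Declaring packed_modules_mapping on the model class lets vLLM's bitsandbytes weight loader map the fused qkv_proj and gate_up_proj modules back to the original q/k/v and gate/up sub-modules of the checkpoint, and the SigLIP MLP change stops rejecting BNB when the hidden sizes are not multiples of 64. Below is a minimal sketch of how the change would be exercised through vLLM's offline LLM API; the checkpoint name, prompt string, and image path are illustrative assumptions rather than part of this commit, while quantization="bitsandbytes" together with load_format="bitsandbytes" is the existing way to request in-flight BNB quantization.

# Sketch: running PaliGemma with in-flight bitsandbytes quantization in vLLM.
# Checkpoint name, prompt, and image path below are placeholder assumptions.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="google/paligemma-3b-mix-224",  # assumed PaliGemma checkpoint
    quantization="bitsandbytes",          # quantize weights with BNB at load time
    load_format="bitsandbytes",
)

image = Image.open("example.jpg")  # placeholder image path

outputs = llm.generate(
    {
        "prompt": "caption en",  # PaliGemma-style task prompt (assumed)
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)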