Skip to content

Commit

Permalink
[Misc] Add BNB quantization for PaliGemmaForConditionalGeneration (#12237)
Browse files Browse the repository at this point in the history

Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee authored Jan 21, 2025
1 parent 9691255 commit 1f1542a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
13 changes: 12 additions & 1 deletion vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,18 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

# Maps each fused (packed) projection layer to the per-tensor shard names
# found in the HF checkpoint. Quantization backends (e.g. BitsAndBytes) use
# this mapping to locate and pack the individual q/k/v and gate/up weights
# into vLLM's fused qkv_proj / gate_up_proj modules at load time.
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
Expand Down
14 changes: 10 additions & 4 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,16 @@ def __init__(

self.config = config
self.activation_fn = get_act_fn(config.hidden_act)

# For quantization, we require the hidden size to be a multiple of 64
quantizable = (config.hidden_size % 64 == 0
and config.intermediate_size % 64 == 0)
# Special handling for BNB quantization
if quant_config and quant_config.get_name() == "bitsandbytes":
quantizable = True
else:
# For other quantization, we require the hidden size to be a
# multiple of 64
quantizable = (
config.hidden_size % 64 == 0
and config.intermediate_size % 64 == 0
)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
Expand Down

0 comments on commit 1f1542a

Please sign in to comment.