From 21d1a50f3b7673f00ea6027eb6cd43a530682c77 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 2 Dec 2024 13:44:50 +0800
Subject: [PATCH 1/6] add quant support to llava projector

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py |  6 ++---
 vllm/model_executor/models/llava.py        | 28 ++++++++++++++--------
 2 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 0e12bc5691538..7255cda541d8b 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1152,9 +1152,9 @@ def _load_weights(self, model_config: ModelConfig,
                         shard_name, weight_name)
                     break
 
-            if quant_param_name not in param_dict:
-                raise ValueError(
-                    f"Parameter {quant_param_name} not found in the model.")
+            # if quant_param_name not in param_dict:
+            #     raise ValueError(
+            #         f"Parameter {quant_param_name} not found in the model.")
 
             if quant_param_name not in stacked_quant_state_dict:
                 stacked_quant_state_dict[quant_param_name] = {}
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index db7fa82ceb9b7..c2b61d43d7962 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -13,6 +13,7 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -63,21 +64,26 @@ class LlavaImageEmbeddingInputs(TypedDict):
 class LlavaMultiModalProjector(nn.Module):
 
     def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str):
+                 projector_hidden_act: str, quant_config: QuantizationConfig=None,
+                 prefix=""):
         super().__init__()
 
-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
+                                             text_hidden_size,
+                                             bias=True,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.linear_1")
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_2 = RowParallelLinear(text_hidden_size,
+                                          text_hidden_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.linear_2")
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_1(image_features)
+        hidden_states, _ = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states
 
 
@@ -325,7 +331,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act)
+            projector_hidden_act=config.projector_hidden_act,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
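For orientation on [PATCH 1/6]: the projector's `nn.Linear` layers are swapped for vLLM's `ColumnParallelLinear`/`RowParallelLinear` so that `quant_config` can reach the projector weights, and because those layers return an `(output, output_bias)` tuple, the forward pass now unpacks the result. Below is a minimal plain-PyTorch sketch of what the module computes, not vLLM's API; the activation is hard-coded on the assumption that `projector_hidden_act` is "gelu", as in LLaVA-1.5 checkpoints.

```python
# Reference-only sketch: functionally what LlavaMultiModalProjector computes.
# The real module uses ColumnParallelLinear/RowParallelLinear, whose calls
# return an (output, output_bias) tuple -- hence the `hidden_states, _ = ...`
# unpacking in the diff -- and accept quant_config for weight quantization.
import torch
import torch.nn as nn


class ReferenceProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int):
        super().__init__()
        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
        self.act = nn.GELU()  # assumes projector_hidden_act == "gelu"
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        return self.linear_2(hidden_states)


# Shape check: 576 CLIP patch tokens at hidden size 1024 -> text hidden 4096.
projector = ReferenceProjector(1024, 4096)
print(projector(torch.randn(1, 576, 1024)).shape)  # torch.Size([1, 576, 4096])
```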
From 96859730c87b2636352b4cb228c7756d597c0aae Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 2 Dec 2024 13:59:41 +0800
Subject: [PATCH 2/6] add weight loading tracker to bnb loader

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 7255cda541d8b..07cfaf2c12aea 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1120,7 +1120,15 @@ def _load_weights(self, model_config: ModelConfig,
                                                  model_config.revision,
                                                  pre_quant, load_8bit))
 
-        model.load_weights(qweight_iterator)
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError(
+                    "Following weights were not initialized from "
+                    f"checkpoint: {weights_not_loaded}")
 
         torch.cuda.empty_cache()
 
@@ -1131,6 +1139,11 @@ def _load_weights(self, model_config: ModelConfig,
         from vllm.model_executor.models.utils import is_pp_missing_parameter
 
         for quant_param_name in quant_state_dict:
+            # Models like Clip/Siglip may skip some layers in initialization,
+            # causing unused quant_param_name in state_dict.
+            if quant_param_name not in param_dict:
+                continue
+
             if is_pp_missing_parameter(quant_param_name, model):
                 continue
 
@@ -1152,10 +1165,6 @@ def _load_weights(self, model_config: ModelConfig,
                         shard_name, weight_name)
                     break
 
-            # if quant_param_name not in param_dict:
-            #     raise ValueError(
-            #         f"Parameter {quant_param_name} not found in the model.")
-
             if quant_param_name not in stacked_quant_state_dict:
                 stacked_quant_state_dict[quant_param_name] = {}

From 8cd3d894759af84c23e991f0565994fea18d9db0 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 2 Dec 2024 23:51:01 +0800
Subject: [PATCH 3/6] fix quant_param_name not in param_dict location

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 07cfaf2c12aea..487277cf32d6b 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1139,11 +1139,6 @@ def _load_weights(self, model_config: ModelConfig,
         from vllm.model_executor.models.utils import is_pp_missing_parameter
 
         for quant_param_name in quant_state_dict:
-            # Models like Clip/Siglip may skip some layers in initialization,
-            # causing unused quant_param_name in state_dict.
-            if quant_param_name not in param_dict:
-                continue
-
             if is_pp_missing_parameter(quant_param_name, model):
                 continue
 
@@ -1160,6 +1160,11 @@ def _load_weights(self, model_config: ModelConfig,
                         shard_name, weight_name)
                     break
 
+            # Models like Clip/Siglip may skip some layers in initialization,
+            # causing unused quant_param_name in state_dict.
+            if quant_param_name not in param_dict:
+                continue
+
             if quant_param_name not in stacked_quant_state_dict:
                 stacked_quant_state_dict[quant_param_name] = {}
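[PATCH 2/6] and [PATCH 3/6] together make the BnB loader tolerant of vision towers (e.g. CLIP/SigLIP) whose trailing layers are skipped at initialization, while still catching genuinely missing weights via a set difference against the model's own parameter names. A standalone sketch of that tracking contract follows, with illustrative names rather than vLLM's actual signatures:

```python
# Illustrative sketch (not vLLM's API) of the weight-loading tracker above:
# load_weights returns the set of parameter names it initialized, and the
# loader diffs that against the model's full parameter list.
from typing import Iterable, Optional, Set, Tuple

import torch


def check_all_loaded(named_params: Iterable[Tuple[str, torch.Tensor]],
                     loaded_weights: Optional[Set[str]]) -> None:
    # Some models may have the weights-loading tracker unimplemented; their
    # load_weights returns None and the check is skipped.
    if loaded_weights is None:
        return
    weights_to_load = {name for name, _ in named_params}
    weights_not_loaded = weights_to_load - loaded_weights
    if weights_not_loaded:
        raise ValueError("Following weights were not initialized from "
                         f"checkpoint: {weights_not_loaded}")


# Passes: every parameter was reported as loaded.
check_all_loaded([("a.weight", torch.zeros(1))], loaded_weights={"a.weight"})
# Would raise: "b.weight" is missing from the loaded set.
# check_all_loaded([("b.weight", torch.zeros(1))], loaded_weights=set())
```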
From 40c44fc22e33c2fcd6e32b9e1a0066154ba00cf1 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 2 Dec 2024 23:58:52 +0800
Subject: [PATCH 4/6] address todo

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/models/llava.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index c2b61d43d7962..d875874a851c8 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -60,7 +60,6 @@ class LlavaImageEmbeddingInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
 
 
-# TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
     def __init__(self, vision_hidden_size: int, text_hidden_size: int,

From f904d328019b3e85ea319f4a70f8bfa566712563 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Tue, 3 Dec 2024 00:00:37 +0800
Subject: [PATCH 5/6] code format

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py |  5 ++---
 vllm/model_executor/models/llava.py        | 10 +++++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 487277cf32d6b..b4921cc80797f 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1126,9 +1126,8 @@ def _load_weights(self, model_config: ModelConfig,
         if loaded_weights is not None:
             weights_not_loaded = weights_to_load - loaded_weights
             if weights_not_loaded:
-                raise ValueError(
-                    "Following weights were not initialized from "
-                    f"checkpoint: {weights_not_loaded}")
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
 
         torch.cuda.empty_cache()
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index d875874a851c8..931b59e3682d5 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -13,7 +13,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -62,8 +63,11 @@ class LlavaImageEmbeddingInputs(TypedDict):
 
 class LlavaMultiModalProjector(nn.Module):
 
-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str, quant_config: QuantizationConfig=None,
+    def __init__(self,
+                 vision_hidden_size: int,
+                 text_hidden_size: int,
+                 projector_hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
                  prefix=""):
         super().__init__()
 
From ae080c865a57b686ed2bc3791e298fb4d1653ab3 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 3 Dec 2024 00:10:34 +0800
Subject: [PATCH 6/6] Update vllm/model_executor/models/llava.py

Co-authored-by: Cyrus Leung
---
 vllm/model_executor/models/llava.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 931b59e3682d5..d375c1c9da2a9 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -68,7 +68,7 @@ def __init__(self,
                  text_hidden_size: int,
                  projector_hidden_act: str,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix=""):
+                 prefix: str = ""):
         super().__init__()
 
         self.linear_1 = ColumnParallelLinear(vision_hidden_size,
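With the series applied end to end, a BitsAndBytes-quantized LLaVA should load with its `multi_modal_projector` quantized instead of tripping the old "Parameter ... not found in the model" error. A minimal smoke test, assuming the in-flight bitsandbytes flags vLLM exposed around the time of this series (`quantization` and `load_format` both set to "bitsandbytes") and a hypothetical local image file `example.jpg`:

```python
# Smoke-test sketch; flags and prompt format are assumptions based on
# vLLM's multimodal/bitsandbytes usage at the time of this series.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf",
          quantization="bitsandbytes",
          load_format="bitsandbytes")

image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```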