From 7bb6d4e63f8ab12daa925307936ebb76ba2904a3 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 5 Nov 2024 01:04:40 +0800 Subject: [PATCH] [Misc]Reduce BNB static variable (#9987) Signed-off-by: Jee Jee Li --- vllm/model_executor/model_loader/loader.py | 40 +++++++++++----------- vllm/model_executor/models/falcon.py | 2 -- vllm/model_executor/models/gemma.py | 3 -- vllm/model_executor/models/gemma2.py | 2 -- vllm/model_executor/models/llama.py | 2 -- vllm/model_executor/models/minicpmv.py | 8 ----- vllm/model_executor/models/mllama.py | 2 -- vllm/model_executor/models/opt.py | 2 -- vllm/model_executor/models/phi.py | 2 -- vllm/model_executor/models/qwen2.py | 3 -- 10 files changed, 20 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 5edb951343ae0..c3e0290f270ae 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -28,7 +28,8 @@ get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import (ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( @@ -727,6 +728,10 @@ class BitsAndBytesModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) + # Save the module names without sharding. + self.unsharded_weights_modules: List[str] = [] + # Save the module names that are sharded by column. + self.column_sharded_weights_modules: List[str] = [] # we don't need to quantize the whole model, only the target modules # that are specified in the adapter config file. If the adapter config # file is not provided, we will quantize the default modules. @@ -744,8 +749,6 @@ def __init__(self, load_config: LoadConfig): with open(config_file_path, "r") as f: config = json.load(f) self.target_modules = config["target_modules"] - # Save the module names without sharding. - self.unsharded_weights_modules: List[str] = [] def _get_config_file(self, qlora_adapter: str) -> str: is_local = os.path.isdir(qlora_adapter) @@ -971,9 +974,9 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, for module in self.unsharded_weights_modules): weight_sub_tensor = weight_tensor # Shard by column - elif any(module in weight_name - for module in self.column_parallel_weights_modules): - + elif any( + weight_name.startswith(module) + for module in self.column_sharded_weights_modules): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) @@ -1028,20 +1031,17 @@ def _load_weights(self, model_config: ModelConfig, else: self.target_modules = self.default_target_modules - if hasattr(model, 'column_parallel_weights_modules'): - self.column_parallel_weights_modules = \ - model.column_parallel_weights_modules - else: - self.column_parallel_weights_modules = [] - # Some modules like `ReplicatedLinear` should not have their weights - # sharded. The reason for implementing it this way is to avoid new - # static variable in the model implementation. - # TODO: Can we reduce the static variables needed for BNB based on - # model information? - self.unsharded_weights_modules = [ - name for name, module in model.named_modules() - if isinstance(module, (ReplicatedLinear, )) - ] + for name, module in model.named_modules(): + # Some modules like `ReplicatedLinear` should not have their weights + # sharded. The reason for implementing it this way is to avoid new + # static variable in the model implementation. + if isinstance(module, (ReplicatedLinear, )): + self.unsharded_weights_modules.append(name) + # In TP, these weights are partitioned along the column + # dimension (dim=-1) + elif isinstance(module, (RowParallelLinear, )): + self.column_sharded_weights_modules.append(name) + self.model_type = type(model).__name__ logger.info("Loading weights with BitsAndBytes quantization. " diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 36c85e37783ab..c376347811965 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -401,8 +401,6 @@ class FalconForCausalLM(nn.Module, SupportsPP): ".dense_h_to_4h.", ".dense_4h_to_h.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."] def __init__( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 57b2b43c82f89..029178af61da0 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -350,7 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "gate_up_proj", "down_proj", ] - # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", @@ -361,8 +360,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 693f32160a289..9238ed839c9de 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -390,8 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8a9e5203972be..38a31f420cec9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -464,8 +464,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c1f714bb25680..f90df6b7df036 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -854,10 +854,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): # resampler ".kv_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [ - ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2." - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -1008,10 +1004,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): # resampler ".kv_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [ - ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2." - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index a03155ac32a61..d30b9addd09f1 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1062,8 +1062,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): # so we can't add a dot in front of it. "multi_modal_projector." ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 10cca8b56268a..7521ab749e10f 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -343,8 +343,6 @@ class OPTForCausalLM(nn.Module, SupportsPP): default_bitsandbytes_target_modules = [ ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".out_proj.", ".fc2."] def __init__( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 497eae4e8905b..4e7935a7636c5 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -274,8 +274,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): default_bitsandbytes_target_modules = [ ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense." ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".fc2.", ".dense."] embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index db7556b3b5f4b..72b286fe6f6d6 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -395,9 +395,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0),