From ca0d5f31d01a6688ac1b15ad0f6024d27f0dfb23 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Wed, 6 Nov 2024 09:27:06 -0500
Subject: [PATCH] Remove ScaledActivation for AWQ (#10057)

Signed-off-by: mgoin
Signed-off-by: Tyler Michael Smith
---
 vllm/model_executor/layers/activation.py      | 37 ++-----
 .../layers/quantization/aqlm.py               |  3 --
 .../model_executor/layers/quantization/awq.py |  3 --
 .../layers/quantization/awq_marlin.py         |  3 --
 .../layers/quantization/base_config.py        |  8 ----
 .../layers/quantization/bitsandbytes.py       |  3 --
 .../compressed_tensors/compressed_tensors.py  |  3 --
 .../layers/quantization/deepspeedfp.py        |  3 --
 .../layers/quantization/experts_int8.py       |  3 --
 .../layers/quantization/fbgemm_fp8.py         |  3 --
 .../model_executor/layers/quantization/fp8.py |  3 --
 .../layers/quantization/gguf.py               |  3 --
 .../layers/quantization/gptq.py               |  3 --
 .../layers/quantization/gptq_marlin.py        |  3 --
 .../layers/quantization/gptq_marlin_24.py     |  3 --
 .../layers/quantization/ipex_quant.py         |  6 ---
 .../layers/quantization/marlin.py             |  3 --
 .../layers/quantization/modelopt.py           |  3 --
 .../layers/quantization/neuron_quant.py       |  3 --
 .../model_executor/layers/quantization/qqq.py |  3 --
 .../layers/quantization/tpu_int8.py           |  3 --
 vllm/model_executor/models/bart.py            |  8 ++--
 vllm/model_executor/models/bloom.py           |  2 +-
 vllm/model_executor/models/falcon.py          |  2 +-
 vllm/model_executor/models/gpt2.py            |  3 +-
 vllm/model_executor/models/gpt_bigcode.py     |  3 +-
 vllm/model_executor/models/gpt_j.py           |  3 +-
 vllm/model_executor/models/gpt_neox.py        |  3 +-
 vllm/model_executor/models/mpt.py             |  2 +-
 vllm/model_executor/models/opt.py             |  3 +-
 vllm/model_executor/models/persimmon.py       |  2 +-
 vllm/model_executor/models/phi.py             |  2 +-
 vllm/model_executor/models/qwen.py            |  2 +-
 vllm/model_executor/models/starcoder2.py      |  3 +-
 34 files changed, 19 insertions(+), 124 deletions(-)

diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index e347ca80ff765..34d65ed51ef3f 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -9,7 +9,6 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import LazyDict
 
@@ -277,28 +276,14 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
 })
 
 
-def get_act_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_fn(act_fn_name: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
-    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_REGISTRY[act_fn_name]
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
@@ -307,25 +292,11 @@ def get_act_fn(
 })
 
 
-def get_act_and_mul_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
     """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
-    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
index c88ca340ebcc5..72c89fe2b0e48 100644
--- a/vllm/model_executor/layers/quantization/aqlm.py
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -213,9 +213,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AQLMLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class AQLMLinearMethod(LinearMethodBase):
     """Linear method for AQLM.
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 38dd1f2e10fcd..d83528e9ec79c 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -77,9 +77,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AWQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-
 
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index ea69bee45f8d9..4d1a837d11585 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -127,9 +127,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AWQMoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 75fa8249cd3c2..6dfac8aad5358 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -133,11 +133,3 @@ def get_quant_method(self, layer: torch.nn.Module,
         method.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def get_scaled_act_names(self) -> List[str]:
-        """Returns the activation function names that should be post-scaled.
-
-        For now, this is only used by AWQ.
- """ - raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 78965d7b9495c..39965ac9115c2 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -114,9 +114,6 @@ def get_quant_method(self, layer: torch.nn.Module, return BitsAndBytesLinearMethod(self) return None - def get_scaled_act_names(self) -> List[str]: - return [] - def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]): # Split the prefix into its dot-separated components diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ecc345f116c37..4f5758a42dbbc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -45,9 +45,6 @@ def __init__(self, def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) - def get_scaled_act_names(self) -> List[str]: - return [] - def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.float16, torch.bfloat16] diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 29484801dc380..36598b3e2990f 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -50,9 +50,6 @@ def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": def get_linear_method(self) -> "DeepSpeedFPLinearMethod": return DeepSpeedFPLinearMethod(self) - def get_scaled_act_names(self) -> List[str]: - return [] - @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.half, torch.bfloat16] diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 116a4ea0aed89..97297970d9317 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -45,9 +45,6 @@ def get_quant_method(self, layer: torch.nn.Module, return ExpertsInt8MoEMethod(self) return None - def get_scaled_act_names(self) -> List[str]: - return [] - class ExpertsInt8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 825d01d1b3551..7b71e13b50ccc 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -64,9 +64,6 @@ def get_quant_method(self, layer: torch.nn.Module, return FBGEMMFp8LinearMethod(self) return None - def get_scaled_act_names(self) -> List[str]: - return [] - class FBGEMMFp8LinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d34579b7099bb..978e727bc7cb3 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -92,9 +92,6 @@ def get_quant_method(self, layer: torch.nn.Module, return Fp8KVCacheMethod(self) return None - def get_scaled_act_names(self) -> List[str]: - return [] - class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. 
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index d73b9f6d92832..24138662eb25c 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -48,9 +48,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GGUFEmbeddingMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 1cfadb4f42ca8..0aa605e62454e 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -80,9 +80,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GPTQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ExllamaState(Enum):
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index b97dd108d6785..1f72e3afbbce5 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -125,9 +125,6 @@ def get_quant_method(
             return GPTQMarlinMoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
index 0971aedba4c3c..07552c0f13348 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -127,9 +127,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GPTQMarlin24LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class GPTQMarlin24LinearMethod(LinearMethodBase):
     """Linear method for Marlin24.
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index e54052632e468..43f4502f7455c 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -93,12 +93,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return self.quant_method(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        if self.method == "awq":
-            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-        else:
-            return []
-
 
 class IPEXAWQLinearMethod(AWQLinearMethod):
     """AWQ linear method using IPEX for the CPU backend.
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 8f1b5370b4538..20212e672eab0 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -110,9 +110,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return MarlinLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class MarlinLinearMethod(LinearMethodBase):
     """Linear method for Marlin.
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 9694f2b8208e2..a1b3eeb43cbee 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -68,9 +68,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return ModelOptFp8KVCacheMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """
diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py
index 2624981f6a614..2d5cdfa165775 100644
--- a/vllm/model_executor/layers/quantization/neuron_quant.py
+++ b/vllm/model_executor/layers/quantization/neuron_quant.py
@@ -57,9 +57,6 @@ def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
             "Neuron Quantization is only supported through"
             " transformers_neuronx.")
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_quantization_config(self):
         from transformers_neuronx.config import QuantizationConfig
         return QuantizationConfig(quant_dtype=self.quant_dtype,
diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py
index 5bc3737520865..2ccd082029610 100644
--- a/vllm/model_executor/layers/quantization/qqq.py
+++ b/vllm/model_executor/layers/quantization/qqq.py
@@ -112,9 +112,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return QQQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class QQQLinearMethod(LinearMethodBase):
     """Linear method for QQQ.
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
index be8235b468f68..605c3a38644ac 100644
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -50,9 +50,6 @@ def get_quant_method(self, layer: Module,
             return TPUInt8LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class TPUInt8LinearMethod(LinearMethodBase):
     """Int8 Linear method for TPU Quant. """
""" diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 0543ca978b7dd..85de1a8115b8b 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -393,8 +393,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.activation_fn = get_act_fn(config.activation_function, - quant_config) + self.activation_fn = get_act_fn(config.activation_function) ffn_hidden_size = self.embed_dim ffn_intermediate_size = config.encoder_ffn_dim @@ -405,7 +404,7 @@ def __init__( bias=ffn_has_bias, quant_config=quant_config, ) - self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size) + self.act = get_act_fn("gelu") self.fc2 = RowParallelLinear( ffn_intermediate_size, ffn_hidden_size, @@ -473,8 +472,7 @@ def __init__( config=config, cache_config=cache_config, quant_config=quant_config) - self.activation_fn = get_act_fn(config.activation_function, - quant_config) + self.activation_fn = get_act_fn(config.activation_function) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) ''' diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 83ff39a30fbe3..b2c109a21d4cf 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -146,7 +146,7 @@ def __init__( 4 * hidden_size, quant_config=quant_config, ) - self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) + self.gelu_impl = get_act_fn("gelu") self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ad07fc3b3776e..6f8a7a7015c79 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -212,7 +212,7 @@ def __init__( bias=config.bias, skip_bias_add=True, quant_config=quant_config) - self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) + self.act = get_act_fn("gelu") self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) self.dense_4h_to_h = RowParallelLinear( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index a06200c4b7e08..8147037ed2a32 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -123,8 +123,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.c_proj", ) - self.act = get_act_fn(config.activation_function, quant_config, - intermediate_size) + self.act = get_act_fn(config.activation_function) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.c_fc(hidden_states) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 7612ea641d95c..9f44fa76abcba 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -135,8 +135,7 @@ def __init__( bias=True, quant_config=quant_config, ) - self.act = get_act_fn(config.activation_function, quant_config, - intermediate_size) + self.act = get_act_fn(config.activation_function) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.c_fc(hidden_states) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index b28a6081b868f..6fcccdfb112d8 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -130,8 +130,7 @@ def __init__( hidden_size, quant_config=quant_config, ) - 
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 931052c7cccf0..d3f86558ecc7e 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -128,8 +128,7 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index fdd8af79b5470..7f0658f4cb2b0 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -153,7 +153,7 @@ def __init__(
             bias=not config.no_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 7a76e4a0906db..d140f4237b1ca 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -147,8 +147,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
+        self.activation_fn = get_act_fn(config.activation_function)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index bd4a9f698bacd..112bf6f3ed1af 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -60,7 +60,7 @@ def __init__(self,
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = get_act_fn(config.hidden_act, quant_config)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 492122450b237..d308f4913314c 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -152,7 +152,7 @@ def __init__(self,
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 3a0e33e8a3eff..4044ddbbcca3d 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -203,7 +203,7 @@ def __init__(
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
-        self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act_fn = get_act_fn("gelu")
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index b24c5dadb2b2b..a5e4155fb4d2c 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -139,8 +139,7 @@ def __init__(self,
             bias=config.use_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
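
Usage after this patch, as a minimal sketch grounded only in what the hunks
above show (get_act_fn lower-cases the name, "gelu" is a registered key, and
unknown names raise ValueError; the failing name below is hypothetical):

    import torch
    from vllm.model_executor.layers.activation import get_act_fn

    act = get_act_fn("GELU")      # name is lower-cased before the registry lookup
    out = act(torch.randn(2, 8))  # a plain nn.Module, no ScaledActivation wrapper

    try:
        get_act_fn("not_a_real_act")  # hypothetical name, exercises the error path
    except ValueError as err:
        print(err)  # Activation function 'not_a_real_act' is not supported.

Call sites that previously threaded quant_config and intermediate_size through
(bart, bloom, falcon, gpt2, gpt_bigcode, gpt_j, gpt_neox, mpt, opt, persimmon,
phi, qwen, starcoder2) now pass only the activation name.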