From 8d72bb20fae1a8a9d6ec6dcb2a833a190e1225d3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 4 Nov 2024 08:51:31 -0800 Subject: [PATCH] [4/N] make quant config first-class citizen (#9978) Signed-off-by: youkaichao --- vllm/config.py | 38 ++++++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 34 ++----------------- 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0870eb9f70709..814e00c8785f0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -23,9 +23,13 @@ from ray.util.placement_group import PlacementGroup from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.loader import BaseModelLoader from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) +else: + QuantizationConfig = None logger = init_logger(__name__) @@ -1966,6 +1970,35 @@ class VllmConfig: decoding_config: Optional[DecodingConfig] = None observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None + quant_config: Optional[QuantizationConfig] = None + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import ( + get_quant_config) + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1983,3 +2016,8 @@ def __post_init__(self): if self.prompt_adapter_config: self.prompt_adapter_config.verify_with_model_config( self.model_config) + + if self.quant_config is None and \ + self.model_config is not None and self.load_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 07adf7c01eaaf..5edb951343ae0 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, + get_gguf_extra_tensor_names, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models import (has_inner_state, supports_lora, @@ -93,32 +93,6 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: - """Get the quantization config.""" - if model_config.quantization is not None: - quant_config = get_quant_config(model_config, load_config) - capability_tuple = current_platform.get_device_capability() - - if capability_tuple is not None: - capability = capability_tuple.to_int() - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} " - "is not supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - return quant_config - return None - - def _get_model_initialization_kwargs( model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], @@ -185,7 +159,6 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module: lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config - load_config = vllm_config.load_config model_class, _ = get_model_architecture(model_config) return build_model( @@ -193,7 +166,7 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module: vllm_config, model_config.hf_config, cache_config=cache_config, - quant_config=_get_quantization_config(model_config, load_config), + quant_config=vllm_config.quant_config, lora_config=lora_config, multimodal_config=model_config.multimodal_config, scheduler_config=scheduler_config, @@ -518,8 +491,7 @@ def _load_model_serialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config( - model_config, self.load_config) + quant_config = vllm_config.quant_config extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, model_config.multimodal_config) extra_kwargs["quant_config"] = quant_config