diff --git a/vllm/config.py b/vllm/config.py index b88b0decaa1ca..df8e37a13285d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -278,20 +278,17 @@ def _init_pooler_config( pooling_returned_token_ids: Optional[List[int]] = None ) -> Optional["PoolerConfig"]: if self.task == "embedding": - pooling_type_ = pooling_type - normalize_ = pooling_norm pooling_config = get_pooling_config(self.model, self.revision) if pooling_config is not None: - pooling_type_ = pooling_config["pooling_type"] - normalize_ = pooling_config["normalize"] - # override if user specifies pooling_type and/or pooling_norm - if pooling_type is not None: - pooling_type_ = pooling_type - if pooling_norm is not None: - normalize_ = pooling_norm + # override if user does not + # specify pooling_type and/or pooling_norm + if pooling_type is None: + pooling_type = pooling_config["pooling_type"] + if pooling_norm is None: + pooling_norm = pooling_config["normalize"] return PoolerConfig( - pooling_type=pooling_type_, - pooling_norm=normalize_, + pooling_type=pooling_type, + pooling_norm=pooling_norm, pooling_softmax=pooling_softmax, pooling_step_tag_id=pooling_step_tag_id, pooling_returned_token_ids=pooling_returned_token_ids) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 253b345c3f6ea..231a32919effb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -235,9 +235,9 @@ def get_config( return config -def get_hf_file_to_dict(file_name, - model, - revision, +def get_hf_file_to_dict(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main', token: Optional[str] = None): """ Downloads a file from the Hugging Face Hub and returns @@ -273,13 +273,13 @@ def get_hf_file_to_dict(file_name, file_path = Path(hf_hub_file) with open(file_path, "r") as file: - config_dict = json.load(file) - - return config_dict + return json.load(file) return None -def get_pooling_config(model, revision='main', token: Optional[str] = 
None): +def get_pooling_config(model: str, + revision: Optional[str] = 'main', + token: Optional[str] = None): """ This function gets the pooling and normalize config from the model - only applies to @@ -326,7 +326,7 @@ def get_pooling_config(model, revision='main', token: Optional[str] = None): return None -def get_pooling_config_name(pooling_name): +def get_pooling_config_name(pooling_name: str) -> Union[str, None]: if "pooling_mode_" in pooling_name: pooling_name = pooling_name.replace("pooling_mode_", "") @@ -345,10 +345,11 @@ def get_pooling_config_name(pooling_name): except NotImplementedError as e: logger.debug("Pooling type not supported", e) return None + return None -def get_sentence_transformer_tokenizer_config(model, - revision='main', +def get_sentence_transformer_tokenizer_config(model: str, + revision: Optional[str] = 'main', token: Optional[str] = None): """ Returns the tokenization configuration dictionary for a @@ -448,8 +449,8 @@ def _reduce_modelconfig(mc: ModelConfig): exc_info=e) -def load_params_config(model, - revision, +def load_params_config(model: Union[str, Path], + revision: Optional[str], token: Optional[str] = None) -> PretrainedConfig: # This function loads a params.json config which # should be used when loading models in mistral format