From d37339ef93b152222686e0d1de4ade9119e5575f Mon Sep 17 00:00:00 2001
From: Gabriel Marinho
Date: Tue, 14 Jan 2025 14:58:28 -0300
Subject: [PATCH] [FEATURE] Enables /score endpoint for embedding models

Signed-off-by: Gabriel Marinho
---
 vllm/config.py          |   5 ++
 vllm/entrypoints/llm.py | 121 ++++++++++++++++++++++++++++------------
 2 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 59b509d5a961e..6c2472df973d2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -929,6 +929,11 @@ def is_cross_encoder(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_cross_encoder_model(architectures)
 
+    @property
+    def has_pooling(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_pooling_model(architectures)
+
     @property
     def supported_runner_types(self) -> Set[RunnerType]:
         return {_TASK_RUNNER[task] for task in self.supported_tasks}
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index acb4db85632a8..d00c0ec33b06a 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -4,6 +4,7 @@
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                     Type, Union, cast, overload)
 
+import torch
 from tqdm import tqdm
 from typing_extensions import deprecated
 
@@ -998,25 +999,16 @@ def score(
             raise ValueError(" ".join(messages))
 
-        if not self.llm_engine.model_config.is_cross_encoder:
-            raise ValueError("Your model does not support cross encoding")
-        if self.llm_engine.model_config.task != "score":
-            raise ValueError("Score API is only enabled for `--task score`")
-
-        tokenizer = self.llm_engine.get_tokenizer()
-
-        if isinstance(tokenizer, MistralTokenizer):
-            raise ValueError(
-                "MistralTokenizer not supported for cross-encoding")
-
         # the tokenizer for models such as
         # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
         # lists of tokens to the `text` and `text_pair` kwargs
+        tokenizer = self.llm_engine.get_tokenizer()
+
         def ensure_str(prompt: SingletonPrompt):
             if isinstance(prompt, dict):
                 if "multi_modal_data" in prompt:
                     raise ValueError("Multi-modal prompt is not "
-                                     "supported for cross encoding")
+                                     "supported for scoring")
                 elif "prompt_token_ids" in prompt:
                     prompt = tokenizer.decode(
                         cast(TokensPrompt, prompt)["prompt_token_ids"])
@@ -1045,37 +1037,92 @@ def ensure_str(prompt: SingletonPrompt):
         if len(text_1) == 1:
             text_1 = text_1 * len(text_2)
 
-        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
-        pooling_params = PoolingParams()
+        if self.llm_engine.model_config.task != "score":
+            raise ValueError("Score API is only enabled for `--task score`")
 
-        tokenization_kwargs: Dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+        if self.llm_engine.model_config.is_cross_encoder:
+            input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
 
-        parsed_prompts = []
+            if isinstance(tokenizer, MistralTokenizer):
+                raise ValueError(
+                    "MistralTokenizer not supported for cross-encoding")
 
-        for q, t in input_pairs:
-            prompt_inputs = tokenizer(text=q,
-                                      text_pair=t,
-                                      **tokenization_kwargs)
-            engine_prompt = TokensPrompt(
-                prompt_token_ids=prompt_inputs["input_ids"],
-                token_type_ids=prompt_inputs.get("token_type_ids"))
-            parsed_prompts.append(engine_prompt)
+            pooling_params = PoolingParams()
 
-        self._validate_and_add_requests(
-            prompts=parsed_prompts,
-            params=pooling_params,
-            lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request,
-        )
+            tokenization_kwargs: Dict[str, Any] = {}
+            if truncate_prompt_tokens is not None:
+                tokenization_kwargs["truncation"] = True
+                tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+            parsed_prompts = []
+
+            for q, t in input_pairs:
+                prompt_inputs = tokenizer(text=q,
+                                          text_pair=t,
+                                          **tokenization_kwargs)
+                engine_prompt = TokensPrompt(
+                    prompt_token_ids=prompt_inputs["input_ids"],
+                    token_type_ids=prompt_inputs.get("token_type_ids"))
+                parsed_prompts.append(engine_prompt)
+
+            self._validate_and_add_requests(
+                prompts=parsed_prompts,
+                params=pooling_params,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
 
-        outputs = self._run_engine(use_tqdm=use_tqdm)
-        items = self.engine_class.validate_outputs(outputs,
-                                                   PoolingRequestOutput)
+            outputs = self._run_engine(use_tqdm=use_tqdm)
+            items = self.engine_class.validate_outputs(outputs,
+                                                       PoolingRequestOutput)
 
-        return [ScoringRequestOutput.from_base(item) for item in items]
+            return [ScoringRequestOutput.from_base(item) for item in items]
+
+        elif self.llm_engine.model_config.has_pooling:
+            text = text_1 + text_2
+            encoded_text = self.encode(text)
+
+            encoded_text_1 = encoded_text[0:len(text_1)]
+            encoded_text_2 = encoded_text[len(text_1):]
+
+            input_pairs_2 = [(t1, t2)
+                             for t1, t2 in zip(encoded_text_1, encoded_text_2)]
+
+            scores = []
+            cosSim = torch.nn.CosineSimilarity(0)
+
+            if tokenizer.pad_token is None:
+                for token_1, token_2 in input_pairs_2:
+                    pairs_score = cosSim(token_1.outputs.data,
+                                         token_2.outputs.data)
+                    tokens = token_1.prompt_token_ids + token_2.prompt_token_ids
+
+                    scores.append(
+                        PoolingRequestOutput(request_id="unk",
+                                             outputs=pairs_score,
+                                             prompt_token_ids=tokens,
+                                             finished=True))
+            else:
+                for token_1, token_2 in input_pairs_2:
+                    pairs_score = cosSim(token_1.outputs.data,
+                                         token_2.outputs.data)
+                    tokens = token_1.prompt_token_ids + [
+                        tokenizer.pad_token_type_id
+                    ] + token_2.prompt_token_ids
+
+                    scores.append(
+                        PoolingRequestOutput(request_id="unk",
+                                             outputs=pairs_score,
+                                             prompt_token_ids=tokens,
+                                             finished=True))
+
+            items = self.engine_class.validate_outputs(scores,
+                                                       PoolingRequestOutput)
+
+            return [ScoringRequestOutput.from_base(item) for item in items]
+
+        raise ValueError(
+            "Your model does not support cross encoding or pooling.")
 
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
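
Not part of the patch: a minimal usage sketch of the new embedding-model
scoring path, for reviewers. The model name, the `task="score"` constructor
argument, and the `outputs.score` attribute read at the end are assumptions
based on the diff above, not verified against this branch.

    from vllm import LLM

    # `task="score"` satisfies the task check added in the diff.
    llm = LLM(model="sentence-transformers/all-MiniLM-L6-v2", task="score")

    # A single text_1 is broadcast against every entry of text_2 (the
    # `len(text_1) == 1` branch); each pair yields one similarity score.
    outputs = llm.score(
        "What is the capital of France?",
        ["Paris is the capital of France.", "The cat sat on the mat."],
    )
    for out in outputs:
        print(out.outputs.score)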
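
A self-contained illustration of the pairwise score computed in the
`has_pooling` branch: `torch.nn.CosineSimilarity(0)` applied to two 1-D
pooled-embedding tensors (toy values, chosen only for the example).

    import torch

    # Same operator the patch uses: cosine similarity along dim 0 of two
    # 1-D tensors, returning a scalar tensor in [-1, 1].
    cos_sim = torch.nn.CosineSimilarity(dim=0)

    emb_1 = torch.tensor([0.1, 0.20, 0.30])  # toy pooled embedding, text_1
    emb_2 = torch.tensor([0.1, 0.25, 0.28])  # toy pooled embedding, text_2

    print(cos_sim(emb_1, emb_2))  # close to 1.0 for near-parallel vectors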