[FEATURE] Enables /score endpoint for embedding models
Signed-off-by: Gabriel Marinho <[email protected]>
gmarinho2 committed Jan 14, 2025
1 parent f35ec46 commit d37339e
Showing 2 changed files with 89 additions and 37 deletions.
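In short, this commit lets LLM.score() serve embedding (pooling) models in addition to cross-encoders: both texts are embedded and the pair is scored by cosine similarity. A minimal usage sketch, assuming a pooling-capable embedding model loaded with `--task score` (the model name below is illustrative, not part of the diff):

from vllm import LLM

# Hedged sketch: after this change, any pooling-capable embedding model
# should be scoreable; the model name here is only an example.
llm = LLM(model="sentence-transformers/all-MiniLM-L6-v2", task="score")
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is in Paris."])
for output in outputs:
    print(output.outputs.score)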
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -929,6 +929,11 @@ def is_cross_encoder(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_cross_encoder_model(architectures)

@property
def has_pooling(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_pooling_model(architectures)

@property
def supported_runner_types(self) -> Set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
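For orientation, the new property pairs with the existing `is_cross_encoder` check to drive the routing added in llm.py below. A hedged sketch of that branching (`model_config` stands in for a vllm ModelConfig instance; only `has_pooling` itself comes from this diff):

# Illustrative only: mirrors the dispatch introduced in LLM.score().
if model_config.is_cross_encoder:
    route = "native cross-encoder scoring"
elif model_config.has_pooling:
    route = "embed both texts, then cosine similarity"
else:
    raise ValueError("Model supports neither cross encoding nor pooling.")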
121 changes: 84 additions & 37 deletions vllm/entrypoints/llm.py
@@ -4,6 +4,7 @@
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)

import torch
from tqdm import tqdm
from typing_extensions import deprecated

@@ -998,25 +999,16 @@ def score(

raise ValueError(" ".join(messages))

if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support cross encoding")
if self.llm_engine.model_config.task != "score":
raise ValueError("Score API is only enabled for `--task score`")

tokenizer = self.llm_engine.get_tokenizer()

if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")

# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer()

def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):
if "multi_modal_data" in prompt:
raise ValueError("Multi-modal prompt is not "
"supported for cross encoding")
"supported for scoring")
elif "prompt_token_ids" in prompt:
prompt = tokenizer.decode(
cast(TokensPrompt, prompt)["prompt_token_ids"])
@@ -1045,37 +1037,92 @@ def ensure_str(prompt: SingletonPrompt):
if len(text_1) == 1:
text_1 = text_1 * len(text_2)

input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
pooling_params = PoolingParams()
if self.llm_engine.model_config.task != "score":
raise ValueError("Score API is only enabled for `--task score`")

tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
if self.llm_engine.model_config.is_cross_encoder:
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

parsed_prompts = []
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")

for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
pooling_params = PoolingParams()

self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens

parsed_prompts = []

for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)

self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)

outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)

return [ScoringRequestOutput.from_base(item) for item in items]
return [ScoringRequestOutput.from_base(item) for item in items]

elif self.llm_engine.model_config.has_pooling:
text = text_1 + text_2
encoded_text = self.encode(text)

encoded_text_1 = encoded_text[0:len(text_1)]
encoded_text_2 = encoded_text[len(text_1):]

input_pairs_2 = [(t1, t2)
for t1, t2 in zip(encoded_text_1, encoded_text_2)]

scores = []
cosSim = torch.nn.CosineSimilarity(dim=0)

if tokenizer.pad_token is None:

[CI annotation] mypy (3.9–3.12), line 1094 in vllm/entrypoints/llm.py: Item "MistralTokenizer" of "Union[Any, MistralTokenizer]" has no attribute "pad_token" [union-attr]
for token_1, token_2 in input_pairs_2:
pairs_score = cosSim(token_1.outputs.data,
token_2.outputs.data)
tokens = token_1.prompt_token_ids + token_2.prompt_token_ids

scores.append(
PoolingRequestOutput(request_id="unk",
outputs=pairs_score,
prompt_token_ids=tokens,
finished=True))
else:
for token_1, token_2 in input_pairs_2:
pairs_score = cosSim(token_1.outputs.data,
token_2.outputs.data)
tokens = token_1.prompt_token_ids + [
tokenizer.pad_token_type_id
] + token_2.prompt_token_ids

[CI annotation] mypy (3.9–3.12), line 1110 in vllm/entrypoints/llm.py: Item "MistralTokenizer" of "Union[Any, MistralTokenizer]" has no attribute "pad_token_type_id" [union-attr]

scores.append(
PoolingRequestOutput(request_id="unk",
outputs=pairs_score,
prompt_token_ids=tokens,
finished=True))

items = self.engine_class.validate_outputs(scores,
PoolingRequestOutput)

return [ScoringRequestOutput.from_base(item) for item in items]

raise ValueError(
"Your model does not support cross encoding or pooling.")

def start_profile(self) -> None:
self.llm_engine.start_profile()
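For reference, the pair tokenization in the cross-encoder branch can be reproduced standalone with the tokenizer named in the code comment above; this sketch uses transformers directly and is not part of the diff:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
# One packed sequence per (query, candidate) pair; token_type_ids mark the
# segment each token belongs to, matching the TokensPrompt fields above.
enc = tok(text="What is the capital of France?",
          text_pair="Paris is the capital of France.",
          truncation=True, max_length=512)
print(enc["input_ids"])       # [CLS] query [SEP] candidate [SEP]
print(enc["token_type_ids"])  # 0 for query tokens, 1 for the candidate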
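Likewise, the core arithmetic of the embedding branch in isolation (the values are illustrative; dim=0 matches the 1-D pooled vectors the branch compares):

import torch

# Two pooled embeddings for one (text_1, text_2) pair, made up for the demo.
emb_1 = torch.tensor([0.1, 0.3, 0.5])
emb_2 = torch.tensor([0.2, 0.1, 0.4])

cos_sim = torch.nn.CosineSimilarity(dim=0)
print(cos_sim(emb_1, emb_2).item())  # scalar similarity in [-1.0, 1.0]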
