[FEATURE] Enables offline /score for embedding models
Signed-off-by: Gabriel Marinho <[email protected]>
gmarinho2 committed Jan 16, 2025
1 parent f35ec46 commit 8afdaae
Showing 2 changed files with 86 additions and 37 deletions.
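For context, a minimal sketch of the offline flow this commit enables. The model name and the printed attribute are illustrative assumptions, not taken from the diff:

    # Hypothetical example: offline scoring with an embedding model.
    from vllm import LLM

    # Any pooling/embedding model; bge-base is only an assumed example.
    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

    outputs = llm.score("What is the capital of France?",
                        ["Paris is the capital of France.",
                         "The Eiffel Tower is in Paris."])
    for output in outputs:
        print(output.outputs.score)  # one similarity score per pair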
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -929,6 +929,11 @@ def is_cross_encoder(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_cross_encoder_model(architectures)
 
+    @property
+    def has_pooling(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_pooling_model(architectures)
+
     @property
     def supported_runner_types(self) -> Set[RunnerType]:
         return {_TASK_RUNNER[task] for task in self.supported_tasks}
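A sketch of how the new property can be queried; the model name is an assumed example, and the property simply asks the ModelRegistry whether the declared architectures support pooling:

    from vllm import LLM

    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # assumed example
    model_config = llm.llm_engine.model_config

    # Mirrors is_cross_encoder: both consult ModelRegistry with the
    # architectures listed in the HF config.
    print(model_config.has_pooling)       # True for embedding/pooling models
    print(model_config.is_cross_encoder)  # False for pure embedding models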
118 changes: 81 additions & 37 deletions vllm/entrypoints/llm.py
@@ -4,6 +4,7 @@
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
                     Union, cast, overload)
 
+import torch
 from tqdm import tqdm
 from typing_extensions import deprecated
@@ -983,6 +984,7 @@ def score(
             A list of ``ScoringRequestOutput`` objects containing the
             generated scores in the same order as the input prompts.
         """
+
         runner_type = self.llm_engine.model_config.runner_type
         if runner_type != "pooling":
             messages = ["LLM.score() is only supported for pooling models."]
@@ -998,25 +1000,20 @@ def score(
 
             raise ValueError(" ".join(messages))
 
-        if not self.llm_engine.model_config.is_cross_encoder:
-            raise ValueError("Your model does not support cross encoding")
-        if self.llm_engine.model_config.task != "score":
-            raise ValueError("Score API is only enabled for `--task score`")
-
-        tokenizer = self.llm_engine.get_tokenizer()
-
-        if isinstance(tokenizer, MistralTokenizer):
+        if self.llm_engine.model_config.task not in ["embed", "score"]:
             raise ValueError(
-                "MistralTokenizer not supported for cross-encoding")
+                "Score API is only enabled for `--task embed or score`")
 
-        # the tokenizer for models such as
-        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
-        # lists of tokens to the `text` and `text_pair` kwargs
+        tokenizer = self.llm_engine.get_tokenizer()
 
         def ensure_str(prompt: SingletonPrompt):
             if isinstance(prompt, dict):
                 if "multi_modal_data" in prompt:
                     raise ValueError("Multi-modal prompt is not "
-                                     "supported for cross encoding")
+                                     "supported for scoring")
                 elif "prompt_token_ids" in prompt:
                     prompt = tokenizer.decode(
                         cast(TokensPrompt, prompt)["prompt_token_ids"])
@@ -1042,40 +1039,87 @@ def ensure_str(prompt: SingletonPrompt):
 
         if len(text_2) == 0:
             raise ValueError("At least one text_pair element must be given")
 
-        if len(text_1) == 1:
-            text_1 = text_1 * len(text_2)
+        if self.llm_engine.model_config.is_cross_encoder:
+            if len(text_1) == 1:
+                text_1 = text_1 * len(text_2)
 
-        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
-        pooling_params = PoolingParams()
+            input_pairs = [(t1, t2)
+                           for t1, t2 in zip(text_1, text_2)]  # output_pairs?
 
-        tokenization_kwargs: Dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+            if isinstance(tokenizer, MistralTokenizer):
+                raise ValueError(
+                    "MistralTokenizer not supported for cross-encoding")
 
-        parsed_prompts = []
+            pooling_params = PoolingParams()
 
-        for q, t in input_pairs:
-            prompt_inputs = tokenizer(text=q,
-                                      text_pair=t,
-                                      **tokenization_kwargs)
-            engine_prompt = TokensPrompt(
-                prompt_token_ids=prompt_inputs["input_ids"],
-                token_type_ids=prompt_inputs.get("token_type_ids"))
-            parsed_prompts.append(engine_prompt)
+            tokenization_kwargs: Dict[str, Any] = {}
+            if truncate_prompt_tokens is not None:
+                tokenization_kwargs["truncation"] = True
+                tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+            parsed_prompts = []
+
+            for q, t in input_pairs:
+                prompt_inputs = tokenizer(text=q,
+                                          text_pair=t,
+                                          **tokenization_kwargs)
+                engine_prompt = TokensPrompt(
+                    prompt_token_ids=prompt_inputs["input_ids"],
+                    token_type_ids=prompt_inputs.get("token_type_ids"))
+                parsed_prompts.append(engine_prompt)
+
+            self._validate_and_add_requests(
+                prompts=parsed_prompts,
+                params=pooling_params,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
 
-        self._validate_and_add_requests(
-            prompts=parsed_prompts,
-            params=pooling_params,
-            lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request,
-        )
+            outputs = self._run_engine(use_tqdm=use_tqdm)
+            items = self.engine_class.validate_outputs(outputs,
+                                                       PoolingRequestOutput)
 
-        outputs = self._run_engine(use_tqdm=use_tqdm)
-        items = self.engine_class.validate_outputs(outputs,
-                                                   PoolingRequestOutput)
+            return [ScoringRequestOutput.from_base(item) for item in items]
+
+        elif self.llm_engine.model_config.runner_type == "pooling":
+            text = text_1 + text_2
+            encoded_text = self.encode(text)
+            encoded_text_1 = encoded_text[0:len(text_1)]
+            encoded_text_2 = encoded_text[len(text_1):]
+
+            if len(encoded_text_1) == 1:
+                encoded_text_1 = encoded_text_1 * len(encoded_text_2)
+
+            output_pairs = [(t1, t2)
+                            for t1, t2 in zip(encoded_text_1, encoded_text_2)]
+
+            scores = []
+            cosSim = torch.nn.CosineSimilarity(0)
+
+            for embed_1, embed_2 in output_pairs:
+                pairs_score = cosSim(embed_1.outputs.data,
+                                     embed_2.outputs.data)
+
+                if getattr(tokenizer, "pad_token", None) is None:
+                    tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
+                else:
+                    tokens = embed_1.prompt_token_ids + [
+                        tokenizer.pad_token_type_id
+                    ] + embed_2.prompt_token_ids
+
+                scores.append(
+                    PoolingRequestOutput(
+                        request_id="unk",  # "unk" or an incremental id?
+                        outputs=pairs_score,
+                        prompt_token_ids=tokens,
+                        finished=True))
+
+            items = self.engine_class.validate_outputs(scores,
+                                                       PoolingRequestOutput)
+            return [ScoringRequestOutput.from_base(item) for item in items]
 
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        raise ValueError(
+            "Your model does not support cross encoding and pooling.")
 
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
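The scoring math in the new pooling branch reduces to a cosine similarity between the two pooled embeddings. A standalone sketch of just that computation (the embedding dimension is an assumption; vLLM pools each prompt into a 1-D tensor):

    import torch

    # CosineSimilarity(0) reduces along dim 0, so two 1-D embeddings
    # collapse to a single scalar score.
    cos_sim = torch.nn.CosineSimilarity(dim=0)

    embed_1 = torch.randn(768)  # stand-in for embed_1.outputs.data
    embed_2 = torch.randn(768)  # stand-in for embed_2.outputs.data

    score = cos_sim(embed_1, embed_2)
    print(float(score))  # scalar in [-1.0, 1.0]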
