From 42adcc81a9ac6d3072af31720e2709b0baaba1ff Mon Sep 17 00:00:00 2001 From: Maximilien de Bayser Date: Sun, 24 Nov 2024 23:56:20 -0300 Subject: [PATCH] Support Cross encoder models (#10400) Signed-off-by: Max de Bayser Signed-off-by: Max de Bayser Signed-off-by: Flavia Beo Co-authored-by: Flavia Beo Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- .../serving/openai_compatible_server.md | 142 ++++++++++++ examples/openai_cross_encoder_score.py | 58 +++++ tests/conftest.py | 20 ++ tests/entrypoints/openai/test_score.py | 93 ++++++++ .../models/embedding/language/test_scoring.py | 95 ++++++++ tests/models/registry.py | 9 + tests/models/test_registry.py | 23 +- vllm/config.py | 5 + vllm/core/scheduler.py | 1 + vllm/entrypoints/llm.py | 124 +++++++++- vllm/entrypoints/openai/api_server.py | 35 ++- vllm/entrypoints/openai/protocol.py | 36 +++ vllm/entrypoints/openai/serving_score.py | 215 ++++++++++++++++++ vllm/inputs/data.py | 18 ++ vllm/inputs/preprocess.py | 2 + vllm/model_executor/layers/pooler.py | 64 ++++++ vllm/model_executor/models/bert.py | 128 ++++++++++- vllm/model_executor/models/interfaces.py | 36 +++ vllm/model_executor/models/registry.py | 23 +- vllm/model_executor/models/roberta.py | 179 ++++++++++++--- vllm/multimodal/inputs.py | 5 +- vllm/outputs.py | 45 +++- vllm/sequence.py | 9 + vllm/transformers_utils/config.py | 15 ++ vllm/worker/cpu_embedding_model_runner.py | 4 + vllm/worker/cpu_model_runner.py | 13 ++ vllm/worker/embedding_model_runner.py | 7 +- vllm/worker/model_runner.py | 28 +++ 28 files changed, 1370 insertions(+), 62 deletions(-) create mode 100644 examples/openai_cross_encoder_score.py create mode 100644 tests/entrypoints/openai/test_score.py create mode 100644 tests/models/embedding/language/test_scoring.py create mode 100644 vllm/entrypoints/openai/serving_score.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 79d032bf8b211..c39cef85897ed 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -44,6 +44,148 @@ We currently support the following OpenAI APIs: - This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst). - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.* +## Score API for Cross Encoder Models + +vLLM supports *cross encoders models* at the **/v1/score** endpoint, which is not an OpenAI API standard endpoint. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). + +A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1. + +### Example of usage for a pair of a string and a list of texts + +In this case, the model will compare the first given text to each of the texts containing the list. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "text_1": "What is the capital of France?", + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693570, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 0.001094818115234375 + ] + }, + { + "index": 1, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + +### Example of usage for a pair of two lists of texts + +In this case, the model will compare the one by one, making pairs by same index correspondent in each list. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": [ + "What is the capital of Brazil?", + "What is the capital of France?" + ], + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 1 + ] + }, + { + "index": 1, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + +### Example of usage for a pair of two strings + +In this case, the model will compare the strings of texts. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": "What is the capital of France?", + "text_2": "The capital of France is Paris." +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + ## Extra Parameters vLLM supports a set of parameters that are not part of the OpenAI API. diff --git a/examples/openai_cross_encoder_score.py b/examples/openai_cross_encoder_score.py new file mode 100644 index 0000000000000..8c32eea5dd252 --- /dev/null +++ b/examples/openai_cross_encoder_score.py @@ -0,0 +1,58 @@ +"""Examples Python client Score for Cross Encoder Models +""" + +import argparse +import json +import pprint + +import requests + + +def post_http_request(prompt: json, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3") + args = parser.parse_args() + api_url = f"http://{args.host}:{args.port}/v1/score" + + model_name = args.model + + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 is string and text_2 is a list:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) + + text_1 = [ + "What is the capital of Brazil?", "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 and text_2 are lists:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) + + text_1 = "What is the capital of Brazil?" + text_2 = "The capital of Brazil is Brasilia." + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 and text_2 are strings:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 0dc1cc6e83c18..29707f975e2a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -265,6 +265,7 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[..., BatchEncoding] = identity, @@ -282,6 +283,14 @@ def __init__( device="cpu", trust_remote_code=True, ).to(dtype=torch_dtype)) + elif is_cross_encoder: + # Lazy init required for AMD CI + from sentence_transformers import CrossEncoder + self.model = CrossEncoder(model_name, + device="cpu", + trust_remote_code=True) + self.model.model = self.wrap_device(self.model.model)\ + .to(dtype=torch_dtype) else: model_kwargs = model_kwargs if model_kwargs is not None else {} self.model = self.wrap_device( @@ -625,6 +634,9 @@ def generate_encoder_decoder_greedy_logprobs_limit( def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) + def predict(self, prompts: List[List[str]]) -> torch.Tensor: + return self.model.predict(prompts, convert_to_tensor=True) + def __enter__(self): return self @@ -898,6 +910,14 @@ def encode( req_outputs = self.model.encode(inputs) return [req_output.outputs.embedding for req_output in req_outputs] + def score( + self, + text_1: Union[str, List[str]], + text_2: Union[str, List[str]], + ) -> List[List[float]]: + req_outputs = self.model.score(text_1, text_2) + return [req_output.outputs.embedding for req_output in req_outputs] + def __enter__(self): return self diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py new file mode 100644 index 0000000000000..7565ff7192f67 --- /dev/null +++ b/tests/entrypoints/openai/test_score.py @@ -0,0 +1,93 @@ +import pytest +import requests + +from vllm.entrypoints.openai.protocol import ScoreResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "BAAI/bge-reranker-v2-m3" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_str_text_2_list(server: RemoteOpenAIServer, + model_name: str): + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + assert score.data[0].score[0] <= 0.01 + assert score.data[1].score[0] >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_list_text_2_list(server: RemoteOpenAIServer, + model_name: str): + text_1 = [ + "What is the capital of the United States?", + "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + assert score.data[0].score[0] <= 0.01 + assert score.data[1].score[0] >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_str_text_2_str(server: RemoteOpenAIServer, + model_name: str): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 1 + assert score.data[0].score[0] >= 0.9 diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/embedding/language/test_scoring.py new file mode 100644 index 0000000000000..30fa5ea7b36c0 --- /dev/null +++ b/tests/models/embedding/language/test_scoring.py @@ -0,0 +1,95 @@ +"""Compare the embedding outputs of HF and vLLM models. + +Run `pytest tests/models/embedding/language/test_embedding.py`. +""" +import math + +import pytest + +MODELS = [ + "cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert + "BAAI/bge-reranker-v2-m3", # Roberta +] + +TEXTS_1 = [ + "What is the capital of France?", + "What is the capital of Germany?", +] + +TEXTS_2 = [ + "The capital of France is Paris.", + "The capital of Germany is Berlin.", +] + + +@pytest.fixture(scope="module", params=MODELS) +def model_name(request): + yield request.param + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): + + text_pair = [TEXTS_1[0], TEXTS_2[0]] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict([text_pair]).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) + + assert len(vllm_outputs) == 1 + assert len(hf_outputs) == 1 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): + + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[0], TEXTS_2[1]], + ] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str): + + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[1], TEXTS_2[1]], + ] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) diff --git a/tests/models/registry.py b/tests/models/registry.py index 3848367b6126c..fa0818c4f0bd1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -135,6 +135,7 @@ class _HfExamplesInfo: "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"), # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), @@ -143,6 +144,13 @@ class _HfExamplesInfo: "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 } +_CROSS_ENCODER_EXAMPLE_MODELS = { + # [Text-only] + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 +} + _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501 @@ -195,6 +203,7 @@ class _HfExamplesInfo: _EXAMPLE_MODELS = { **_TEXT_GENERATION_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS, + **_CROSS_ENCODER_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS, } diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index e462dae3dc688..289ea66b5ebc5 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -6,7 +6,10 @@ from vllm.model_executor.models import (is_embedding_model, is_text_generation_model, supports_multimodal) -from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS, + _EMBEDDING_MODELS, _MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, @@ -29,22 +32,28 @@ def test_registry_imports(model_arch): model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS) + embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS} assert is_embedding_model(model_cls) is (model_arch - in _EMBEDDING_MODELS) + in embedding_models) assert supports_multimodal(model_cls) is (model_arch in _MULTIMODAL_MODELS) @fork_new_process_for_each_test -@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ - ("LlamaForCausalLM", False, False), - ("MllamaForConditionalGeneration", True, False), - ("LlavaForConditionalGeneration", True, True), +@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ + ("LlamaForCausalLM", False, False, False), + ("MllamaForConditionalGeneration", True, False, False), + ("LlavaForConditionalGeneration", True, True, False), + ("BertForSequenceClassification", False, False, True), + ("RobertaForSequenceClassification", False, False, True), + ("XLMRobertaForSequenceClassification", False, False, True), ]) -def test_registry_is_multimodal(model_arch, is_mm, init_cuda): +def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce + if init_cuda and current_platform.is_cuda_alike(): assert not torch.cuda.is_initialized() diff --git a/vllm/config.py b/vllm/config.py index f163665e2c063..4ea56a14cabba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -712,6 +712,11 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_cross_encoder(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_cross_encoder_model(architectures) + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 841e65c488fc6..530cbdc3a9190 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1357,6 +1357,7 @@ def schedule( encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, + token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c211ec5aee080..e07f4c04abd84 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -20,7 +20,7 @@ apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -817,6 +817,128 @@ def encode( return self.engine_class.validate_outputs(outputs, EmbeddingRequestOutput) + def score( + self, + text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], + text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + /, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates similarity scores for all pairs . + + The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case + the text_1 sentence will be replicated N times to pair with the text_2 + sentences. The input pairs are used to build a list of prompts for the + cross encoder model. This class automatically batches the prompts, + considering the memory constraint. For the best performance, put all + of your texts into a single list and pass it to this method. + + Args: + text_1: can be a single prompt or a list of prompts, in which + case it has to have the same length as the text_2 list + text_2: The texts to pair with the query to form the input + to the LLM. See :class:`~vllm.inputs.PromptType` for + more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``EmbeddingRequestOutput`` objects containing the + generated scores in the same order as the input prompts. + """ + task = self.llm_engine.model_config.task + if task != "embedding": + messages = ["LLM.score() is only supported for embedding models."] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "embedding" in supported_tasks: + messages.append( + "Your model supports the 'embedding' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task embedding`.") + + raise ValueError(" ".join(messages)) + + if not self.llm_engine.model_config.is_cross_encoder: + raise ValueError("Your model does not support the cross encoding") + + tokenizer = self.llm_engine.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + # the tokenizer for models such as + # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing + # lists of tokens to the `text` and `text_pair` kwargs + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for cross encoding") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [ensure_str(t) for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + text_2 = [ensure_str(t) for t in text_2] + + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] + pooling_params = PoolingParams() + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + EmbeddingRequestOutput) + def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b0fe061f5db4a..2b1f14b89b1f2 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -45,6 +45,7 @@ EmbeddingRequest, EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, + ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) @@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: return request.app.state.openai_serving_embedding +def score(request: Request) -> Optional[OpenAIServingScores]: + return request.app.state.openai_serving_scores + + def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization @@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) +@router.post("/v1/score") +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - chat = app.state.openai_serving_chat - err = chat.create_error_response(message=str(exc)) + err = ErrorResponse(message=str(exc), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -565,6 +589,13 @@ def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embedding" else None + state.openai_serving_scores = OpenAIServingScores( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger + ) if (model_config.task == "embedding" \ + and model_config.is_cross_encoder) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f343732174014..ee94a9413f098 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -806,6 +806,27 @@ def to_pooling_params(self): EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] +class ScoreRequest(OpenAIBaseModel): + model: str + text_1: Union[List[str], str] + text_2: Union[List[str], str] + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-chat-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-chat-embedding-pooling-params + + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -876,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel): usage: UsageInfo +class ScoreResponseData(OpenAIBaseModel): + index: int + object: str = "score" + score: Union[List[float], str] + + +class ScoreResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[ScoreResponseData] + usage: UsageInfo + + class FunctionCall(OpenAIBaseModel): name: str arguments: str diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py new file mode 100644 index 0000000000000..156fea6f47982 --- /dev/null +++ b/vllm/entrypoints/openai/serving_score.py @@ -0,0 +1,215 @@ +import asyncio +import time +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, + ScoreResponse, ScoreResponseData, + UsageInfo) +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import EmbeddingRequestOutput +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + + +def request_output_to_score_response( + final_res_batch: List[EmbeddingRequestOutput], request_id: str, + created_time: int, model_name: str) -> ScoreResponse: + data: List[ScoreResponseData] = [] + score = None + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + if final_res is not None: + score = final_res.outputs.embedding + score_data = ScoreResponseData(index=idx, score=score) + data.append(score_data) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ScoreResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str], + str]) -> List: + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [t for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + text_2 = [t for t in text_2] + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + return [(t1, t2) for t1, t2 in zip(text_1, text_2)] + + +class OpenAIServingScores(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger) + + async def create_score( + self, + request: ScoreRequest, + raw_request: Optional[Request] = None, + ) -> Union[ScoreResponse, ErrorResponse]: + """ + Score API similar to Sentence Transformers cross encoder + + See https://sbert.net/docs/package_reference/cross_encoder + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"score-{random_uuid()}" + created_time = int(time.monotonic()) + truncate_prompt_tokens = request.truncate_prompt_tokens + + request_prompts = [] + engine_prompts = [] + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for embedding models") + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + if not self.model_config.is_cross_encoder: + raise ValueError("Model is not cross encoder.") + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + + input_pairs = make_pairs(request.text_1, request.text_2) + + for q, t in input_pairs: + request_prompt = f"{q}{tokenizer.sep_token}{t}" + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators( + *generators, + is_cancelled=raw_request.is_disconnected if raw_request else None, + ) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * num_prompts + + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch) + + response = request_output_to_score_response( + final_res_batch_checked, request_id, created_time, model_name) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 07ff9faa50f13..fb7dbbebd7b90 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -38,6 +38,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" + token_type_ids: NotRequired[List[int]] + """A list of token type IDs to pass to the cross encoder model.""" + multi_modal_data: NotRequired["MultiModalDataDict"] """ DEPRECATED: Optional multi-modal data to pass to the model, @@ -133,6 +136,9 @@ class TokenInputs(TypedDict): prompt_token_ids: List[int] """The token IDs of the prompt.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. @@ -160,6 +166,7 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: List[int], + token_type_ids: Optional[List[int]] = None, prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, @@ -170,6 +177,8 @@ def token_inputs( if prompt is not None: inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data if multi_modal_placeholders is not None: @@ -234,6 +243,15 @@ def prompt_token_ids(self) -> List[int]: assert_never(inputs) + @cached_property + def token_type_ids(self) -> List[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("token_type_ids", []) + + assert_never(inputs) + @cached_property def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 853257c5ad71f..3d606817e90aa 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -305,6 +305,7 @@ def _prompt_to_llm_inputs( tokens_content = parsed["content"] prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") @@ -318,6 +319,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index df1978241340b..f9437b4112ceb 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -3,11 +3,14 @@ import torch import torch.nn as nn +from transformers import PretrainedConfig from vllm.config import PoolerConfig from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) class PoolingType(IntEnum): @@ -152,3 +155,64 @@ def forward( ] return PoolerOutput(outputs=pooled_outputs) + + +class CrossEncodingPooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ + + def __init__( + self, + config: PretrainedConfig, + classifier: nn.Module, + pooler: Optional[nn.Module] = None, + ): + super().__init__() + self.classifier = classifier + self.pooler = pooler + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools sentence pair scores from the hidden_states.""" + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + offset = 0 + pooled_data_lst = [] + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + + if self.pooler is not None: + final_shape_tensor = self.pooler(pooled_data_i) + else: + final_shape_tensor = self.classifier(pooled_data_i) + + pooled_data_lst.append(final_shape_tensor) + offset += prompt_len + + pooled_output = torch.stack(pooled_data_lst) + + if self.pooler is not None: + # apply classifier once on the full batch if possible + pooled_output = self.classifier(pooled_output) + logits = self.default_activation_function(pooled_output) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in logits + ] + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d8301a36acb01..1fc87bc650d92 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -11,14 +11,18 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, + PoolingType) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) from .utils import maybe_prefix @@ -48,7 +52,9 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() @@ -58,17 +64,34 @@ def forward( # Position embeddings. position_embeddings = self.position_embeddings(position_ids) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[0, :] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + class BertEncoder(nn.Module): def __init__(self, @@ -309,7 +332,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", - embedding_class: type = BertEmbedding): + embedding_class: type = BertEmbedding, + add_pooling_layer: bool = False): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -319,6 +343,7 @@ def __init__(self, cache_config, quant_config, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(config) if add_pooling_layer else None def forward( self, @@ -328,13 +353,17 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids) - + assert hasattr(attn_metadata, "seq_lens_tensor") + hidden_states = self.embeddings( + input_ids=input_ids, + seq_lens=attn_metadata.seq_lens_tensor, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states, kv_caches, attn_metadata) def load_weights(self, weights: Iterable[Tuple[str, @@ -349,7 +378,7 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if self.pooler is None and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -430,3 +459,78 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: pooling_type=PoolingType.CLS, normalize=True, softmax=False) + + +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + add_pooling_layer=True) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self._pooler = CrossEncodingPooler(config, self.classifier, + self.bert.pooler) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("bert."): + yield (name[len("bert."):], weight) + else: + self_weights.append((name, weight)) + + self.bert.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.bert(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index dcead65115132..4f0c75b2c6a57 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,6 +7,8 @@ from vllm.logger import init_logger from vllm.utils import supports_kw +from .interfaces_base import is_embedding_model + if TYPE_CHECKING: from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.sequence import IntermediateTensors @@ -350,3 +352,37 @@ def is_attention_free( return isinstance(model, _IsAttentionFreeType) return isinstance(model, IsAttentionFree) + + +@runtime_checkable +class SupportsCrossEncoding(Protocol): + """The interface required for all models that support cross encoding.""" + + supports_cross_encoding: ClassVar[Literal[True]] = True + + +@overload +def supports_cross_encoding( + model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: + ... + + +@overload +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: + ... + + +def _supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + + if isinstance(model, type): + return isinstance(model, SupportsCrossEncoding) + + return isinstance(model, SupportsCrossEncoding) + + +def supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + return is_embedding_model(model) and _supports_cross_encoding(model) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 22c2e328bfb65..789ffb4d3bde0 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -21,7 +21,8 @@ from vllm.platforms import current_platform from .interfaces import (has_inner_state, is_attention_free, - supports_multimodal, supports_pp) + supports_cross_encoding, supports_multimodal, + supports_pp) from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) @@ -100,6 +101,7 @@ # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), "RobertaModel": ("roberta", "RobertaEmbeddingModel"), + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), @@ -121,6 +123,14 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, } +_CROSS_ENCODER_MODELS = { + "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "RobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), +} + _MULTIMODAL_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), @@ -159,6 +169,7 @@ _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, + **_CROSS_ENCODER_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, } @@ -193,6 +204,7 @@ class _ModelInfo: is_text_generation_model: bool is_embedding_model: bool + supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool has_inner_state: bool @@ -203,6 +215,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": return _ModelInfo( is_text_generation_model=is_text_generation_model(model), is_embedding_model=is_embedding_model(model), + supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), @@ -415,6 +428,12 @@ def is_embedding_model( ) -> bool: return self.inspect_model_cls(architectures).is_embedding_model + def is_cross_encoder_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).supports_cross_encoding + def is_multimodal_model( self, architectures: Union[str, List[str]], @@ -489,4 +508,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index c1dcdd36ec3de..5a296e311f079 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -6,10 +6,17 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.interfaces import SupportsCrossEncoding +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) class RobertaEmbedding(nn.Module): @@ -39,34 +46,93 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() - - # Input embeddings. inputs_embeds = self.word_embeddings(input_ids) - # TODO: figure out if there is a better way - # to make to make position ids start at padding_idx + 1 + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - position_ids += self.padding_idx + 1 + pos_list = [] + token_list = [] + offset = 0 + for seq_len in seq_lens: + pos_list.append(position_ids[offset:offset + seq_len]) + token_list.append(input_ids[offset:offset + seq_len]) + offset += seq_len + + new_pos_list = [] + for positions, tokens in zip(pos_list, token_list): + # Verify assumption that incoming position are + # always a sequence from 0 to N. + expected_pos = torch.arange(positions.size()[0], + dtype=torch.long, + device=inputs_embeds.device) + assert torch.equal(positions, expected_pos) + new_pos_list.append( + create_position_ids_from_input_ids(tokens, self.padding_idx)) + position_ids = torch.cat(new_pos_list) # Position embeddings. position_embeddings = self.position_embeddings(position_ids) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) - + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +# Adapted from transformers +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + + incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + + past_key_values_length) * mask + + return incremental_indices.long() + padding_idx + + +# Adapted from transformers +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: RobertaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[0, :] # take token (equiv. to [CLS]) + x = self.dense(x) + x = torch.tanh(x) + x = self.out_proj(x) + return x + + class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. @@ -85,6 +151,62 @@ def _build_model(self, prefix=prefix, embedding_class=RobertaEmbedding) + +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.roberta = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + self._pooler = CrossEncodingPooler(config, self.classifier) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("roberta."): + yield (name[len("roberta."):], weight) + else: + self_weights.append((name, weight)) + + self.roberta.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + def forward( self, input_ids: Optional[torch.Tensor], @@ -93,25 +215,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # Verify assumption that position are always a sequence from - # 0 to N. (Actually here we just check 0 and N to simplify). - # This is important to fix the position which are assumed to - # start from padding_idx + 1 instead of 0 in the Roberta models. - assert hasattr(attn_metadata, "seq_lens_tensor") - cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0) - start_pos = torch.cat( - (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device), - cumulative[:-1])) - assert len(torch.nonzero(positions[start_pos])) == 0 - end_pos = cumulative - 1 - last_tokens = attn_metadata.seq_lens_tensor - 1 - assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0 - - return super().forward(input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + return self.roberta(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 8e67a552afe12..640c7c04b8817 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -6,7 +6,7 @@ import torch import torch.types from PIL.Image import Image -from typing_extensions import TypeAlias +from typing_extensions import NotRequired, TypeAlias from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -208,6 +208,9 @@ class MultiModalInputsV2(TypedDict): prompt_token_ids: List[int] """The processed token IDs which includes placeholder tokens.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" diff --git a/vllm/outputs.py b/vllm/outputs.py index 4ae9b377ae693..2d256803edfe8 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -60,7 +60,6 @@ class EmbeddingOutput: embedding: The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide. """ - embedding: List[float] def __repr__(self) -> str: @@ -363,6 +362,50 @@ def __repr__(self): f"finished={self.finished})") +@dataclass +class ScoreOutput: + """The output data of one completion output of a request. + + Args: + score: The score, which is a list of floats. + index: The correspondent text index of the score. + """ + index: int + score: List[float] + + def __repr__(self) -> str: + return (f"ScoreOutput(" + f"score={self.score}), " + f"index={self.index})") + + +class ScoreRequestOutput: + """ + The output data of an score request to the LLM. + + Args: + request_id (str): A unique identifier for the score request. + outputs (score): The embedding results for the given input. + """ + + def __init__(self, request_id: str, outputs: "ScoreOutput"): + self.request_id = request_id + self.outputs = outputs + + def __repr__(self): + """ + Returns a string representation of an ScoreRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the ScoreRequestOutput instance. + """ + return (f"ScoreRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}") + + class RequestOutputFactory: @staticmethod diff --git a/vllm/sequence.py b/vllm/sequence.py index a1cc8fc3b09de..669124319c4f4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -449,6 +449,10 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: return self.inputs.prompt_embeds + @property + def token_type_ids(self) -> List[int]: + return self.inputs.token_type_ids + @property def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.multi_modal_data @@ -687,6 +691,10 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) + @property + def token_type_ids(self) -> Optional[List[int]]: + return self.first_seq.token_type_ids + @property def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data @@ -909,6 +917,7 @@ class SequenceGroupMetadata( default_factory=lambda: SequenceGroupState()) # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. + token_type_ids: Optional[List[int]] = None multi_modal_data: Optional[Any] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 59096753c395d..70d18d40b7aa7 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,6 +9,7 @@ from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) +from torch import nn from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -31,6 +32,7 @@ UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import resolve_obj_by_qualname if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -577,3 +579,16 @@ def try_get_generation_config( return GenerationConfig.from_model_config(config) except OSError: # Not found return None + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + if (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + + function_name = config.sbert_ce_default_activation_function + assert function_name.startswith("torch.nn.modules."), \ + "Loading of activation functions is restricted to " \ + "torch.nn.modules for security reasons" + return resolve_obj_by_qualname(function_name)() + else: + return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py index 978de73df6b70..3954e4c4c8a5b 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_embedding_model_runner.py @@ -50,6 +50,9 @@ def execute_model( ] model_executable = self.model + cross_enc_kwargs = {} + if model_input.token_type_ids is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids execute_model_kwargs = { "input_ids": model_input.input_tokens, @@ -61,6 +64,7 @@ def execute_model( model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), + **cross_enc_kwargs, "intermediate_tensors": intermediate_tensors, } diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7cab476d7fca4..b08171d79f002 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_type_ids: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None @@ -54,6 +55,7 @@ def as_broadcastable_tensor_dict( tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -83,6 +85,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -112,6 +115,7 @@ def __init__(self, use_mrope: bool): self.input_tokens: List[int] = [] self.input_positions: Optional[ List[int]] = [] if not self.use_mrope else None + self.token_type_ids: Optional[List[int]] = [] self.seq_lens: List[int] = [] self.query_lens: List[int] = [] self.prefill_block_tables: List[List[int]] = [] @@ -165,6 +169,10 @@ def build(self) -> ModelInputForCPU: if not input_data.use_mrope else input_data.input_mrope_positions, dtype=torch.long, device="cpu") + token_type_ids = torch.tensor(input_data.token_type_ids, + dtype=torch.long, + device="cpu") \ + if input_data.token_type_ids else None # For multi-modal models multi_modal_kwargs = None @@ -178,6 +186,7 @@ def build(self) -> ModelInputForCPU: return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, + token_type_ids=token_type_ids, seq_lens=input_data.seq_lens, query_lens=input_data.query_lens, attn_metadata=attn_metadata, @@ -285,6 +294,7 @@ def _compute_prompt_input_tokens(self, data: ModelInputData, tokens = seq_data.get_token_ids() tokens = tokens[context_len:seq_len] token_positions = range(context_len, seq_len) + token_types = seq_group_metadata.token_type_ids # For encoder-only models, the block_table is None, # and there is no need to initialize the slot_mapping. @@ -301,6 +311,9 @@ def _compute_prompt_input_tokens(self, data: ModelInputData, if data.input_positions is not None: data.input_positions.extend(token_positions) + if data.token_type_ids is not None: + data.token_type_ids.extend(token_types if token_types else []) + # Update fields data.input_tokens.extend(tokens) data.num_prefills += 1 diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 4a55d91e71484..f56805918fd15 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -97,6 +97,10 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + cross_enc_kwargs = {} + if model_input.token_types is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_types + with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, @@ -105,7 +109,8 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device)) + device=self.device), + **cross_enc_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 13301b876217d..1f654a9cce465 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None @@ -200,6 +201,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore + self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore @@ -226,6 +228,7 @@ def __init__( # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, + token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). @@ -291,6 +294,12 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() + if token_types: + self.token_types = token_types + else: + for seq_id in range(len(self.seq_ids)): + self.token_types[seq_id].clear() + self.mrope_input_positions = None if seq_lens: @@ -354,6 +363,7 @@ def __init__( else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] + self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] @@ -386,6 +396,7 @@ def __post_init__(self): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] + self.token_types = [[] for _ in range(self.n_seqs)] self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs @@ -498,12 +509,15 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.token_types[seq_idx].extend( + token_types if token_types else []) inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: @@ -561,6 +575,8 @@ def _compute_for_prefix_cache_hit( seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][uncomputed_start:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + uncomputed_start:] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len @@ -575,6 +591,8 @@ def _compute_for_prefix_cache_hit( seq_idx][-1:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][-1:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + -1:] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 @@ -803,9 +821,12 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. input_tokens = [] + token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) + for cur_token_types in inter_data.token_types: + token_types.extend(cur_token_types) if not input_tokens: # This may happen when all prefill requests hit @@ -874,6 +895,12 @@ def build(self) -> ModelInputForGPU: input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, self.runner.pin_memory) + + token_types_tensor = async_tensor_h2d(token_types, torch.long, + self.runner.device, + self.runner.pin_memory) \ + if token_types else None + if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( @@ -952,6 +979,7 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, + token_types=token_types_tensor, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens,