diff --git a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py index 89598164..b2a7e13e 100644 --- a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py +++ b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,6 +26,7 @@ import argparse import asyncio +import json import pickle import sys import unittest @@ -36,6 +37,7 @@ from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import GuidedDecodingParams from vllm.utils import random_uuid sys.path.append("../../common") @@ -53,14 +55,22 @@ "The future of AI is", ] +GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "] + SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1} -async def generate_python_vllm_output(prompt, llm_engine): +async def generate_python_vllm_output( + prompt, + llm_engine, + sampling_params=SamplingParams(**SAMPLING_PARAMETERS), + guided_generation=None, +): request_id = random_uuid() - sampling_params = SamplingParams(**SAMPLING_PARAMETERS) python_vllm_output = None last_output = None + if guided_generation: + sampling_params.guided_decoding = guided_generation async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id): last_output = vllm_output @@ -69,24 +79,28 @@ async def generate_python_vllm_output(prompt, llm_engine): python_vllm_output = [ (prompt + output.text).encode("utf-8") for output in last_output.outputs ] - return python_vllm_output -def prepare_vllm_baseline_outputs(): +def prepare_vllm_baseline_outputs( + export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None +): """ Helper function that starts async vLLM engine and generates output for each - prompt in `PROMPTS`. Saves resulted baselines in `vllm_baseline_output.pkl` + prompt in `prompts`. Saves resulted baselines in `vllm_baseline_output.pkl` for further use. 
""" llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**VLLM_ENGINE_CONFIG)) python_vllm_output = [] - for i in range(len(PROMPTS)): + for i in range(len(prompts)): python_vllm_output.extend( - asyncio.run(generate_python_vllm_output(PROMPTS[i], llm_engine)) + asyncio.run( + generate_python_vllm_output( + prompts[i], llm_engine, guided_generation=guided_generation + ) + ) ) - - with open("vllm_baseline_output.pkl", "wb") as f: + with open(export_file, "wb") as f: pickle.dump(python_vllm_output, f) return @@ -96,6 +110,9 @@ class VLLMTritonAccuracyTest(TestResultCollector): def setUp(self): self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001") self.vllm_model_name = "vllm_opt" + + def test_vllm_model(self): + # Reading and verifying baseline data self.python_vllm_output = [] with open("vllm_baseline_output.pkl", "rb") as f: self.python_vllm_output = pickle.load(f) @@ -116,11 +133,9 @@ def setUp(self): ), ) - def test_vllm_model(self): user_data = UserData() stream = False triton_vllm_output = [] - self.triton_client.start_stream(callback=partial(callback, user_data)) for i in range(len(PROMPTS)): request_data = create_vllm_request( @@ -131,7 +146,7 @@ def test_vllm_model(self): request_id=request_data["request_id"], inputs=request_data["inputs"], outputs=request_data["outputs"], - parameters=SAMPLING_PARAMETERS, + parameters=request_data["parameters"], ) for i in range(len(PROMPTS)): @@ -146,6 +161,63 @@ def test_vllm_model(self): self.triton_client.stop_stream() self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort()) + def test_guided_decoding(self): + # Reading and verifying baseline data + self.python_vllm_output = [] + with open("vllm_guided_baseline_output.pkl", "rb") as f: + self.python_vllm_output = pickle.load(f) + + self.assertNotEqual( + self.python_vllm_output, + [], + "Loaded baseline outputs' list should not be empty", + ) + self.assertIsNotNone( + self.python_vllm_output, "Loaded baseline outputs' list should not be None" + ) + self.assertEqual( + len(self.python_vllm_output), + len(GUIDED_PROMPTS), + "Unexpected number of baseline outputs loaded, expected {}, but got {}".format( + len(GUIDED_PROMPTS), len(self.python_vllm_output) + ), + ) + + user_data = UserData() + stream = False + triton_vllm_output = [] + + self.triton_client.start_stream(callback=partial(callback, user_data)) + sampling_params = SAMPLING_PARAMETERS + guided_decoding_params = { + "choice": ["Positive", "Negative"], + "backend": "outlines", + } + sampling_params["guided_decoding"] = json.dumps(guided_decoding_params) + for i in range(len(GUIDED_PROMPTS)): + request_data = create_vllm_request( + GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name + ) + self.triton_client.async_stream_infer( + model_name=self.vllm_model_name, + request_id=request_data["request_id"], + inputs=request_data["inputs"], + outputs=request_data["outputs"], + parameters=request_data["parameters"], + ) + + for i in range(len(GUIDED_PROMPTS)): + result = user_data._completed_requests.get() + self.assertIsNot(type(result), InferenceServerException, str(result)) + + output = result.as_numpy("text_output") + self.assertIsNotNone(output, "`text_output` should not be None") + + triton_vllm_output.extend(output) + + self.triton_client.stop_stream() + self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort()) + def tearDown(self): self.triton_client.close() @@ -159,9 +231,29 @@ def tearDown(self): default=False, help="Generates baseline output for accuracy 
tests", ) + parser.add_argument( + "--generate-guided-baseline", + action="store_true", + required=False, + default=False, + help="Generates baseline output for accuracy tests", + ) FLAGS = parser.parse_args() if FLAGS.generate_baseline: prepare_vllm_baseline_outputs() exit(0) + if FLAGS.generate_guided_baseline: + guided_decoding_params = { + "choice": ["Positive", "Negative"], + "backend": "outlines", + } + guided_generation = GuidedDecodingParams(**guided_decoding_params) + prepare_vllm_baseline_outputs( + export_file="vllm_guided_baseline_output.pkl", + prompts=GUIDED_PROMPTS, + guided_generation=guided_generation, + ) + exit(0) + unittest.main() diff --git a/ci/L0_backend_vllm/accuracy_test/test.sh b/ci/L0_backend_vllm/accuracy_test/test.sh index b0b1c1b2..75093b6b 100755 --- a/ci/L0_backend_vllm/accuracy_test/test.sh +++ b/ci/L0_backend_vllm/accuracy_test/test.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -37,7 +37,7 @@ TEST_RESULT_FILE='test_results.txt' CLIENT_PY="./accuracy_test.py" SAMPLE_MODELS_REPO="../../../samples/model_repository" VLLM_ENGINE_LOG="vllm_engine.log" -EXPECTED_NUM_TESTS=1 +EXPECTED_NUM_TESTS=2 rm -rf models && mkdir -p models cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt @@ -50,6 +50,10 @@ set +e # memory issues: https://github.com/vllm-project/vllm/issues/2248 python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$! wait $BASELINE_PID + +python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$! +wait $BASELINE_PID + set -e run_server diff --git a/ci/L0_multi_gpu_vllm/multi_lora/test.sh b/ci/L0_multi_gpu_vllm/multi_lora/test.sh index b561a2d0..bcc52770 100755 --- a/ci/L0_multi_gpu_vllm/multi_lora/test.sh +++ b/ci/L0_multi_gpu_vllm/multi_lora/test.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -38,6 +38,60 @@ CLIENT_PY="./multi_lora_test.py" DOWNLOAD_PY="./download.py" SAMPLE_MODELS_REPO="../../../samples/model_repository" EXPECTED_NUM_TESTS=2 +GENERATE_ENDPOINT="localhost:8000/v2/models/vllm_llama_multi_lora/generate" +CHECK_FOR_ERROR=true + +make_api_call() { + local endpoint="$1" + local data="$2" + curl -X POST "$endpoint" --data-binary @- <<< "$data" +} + +check_response() { + local response="$1" + local expected_response="$2" + local error_message="$3" + local check_error="${4:-false}" + + if [ -z "$response" ]; then + echo -e "Expected a non-empty response from server" + echo -e "\n***\n*** $error_message \n***" + return 1 + fi + + local response_text=$(echo "$response" | jq '.text_output // empty') + local response_error=$(echo "$response" | jq '.error // empty') + + if [ "$check_error" = true ]; then + if [[ -n "$response_text" ]]; then + echo -e "Server didn't return an error." + echo "$response" + echo -e "\n***\n*** $error_message \n***" + return 1 + elif [[ "$expected_response" != "$response_error" ]]; then + echo -e "Expected error message doesn't match actual response." + echo "Expected: $expected_response." 
+ echo "Received: $response_error" + echo -e "\n***\n*** $error_message\n***" + return 1 + fi + else + if [[ ! -z "$response_error" ]]; then + echo -e "Received an error from server." + echo "$response" + echo -e "\n***\n*** $error_message \n***" + return 1 + elif [[ "$expected_response" != "$response_text" ]]; then + echo "Expected response doesn't match actual" + echo "Expected: $expected_response." + echo "Received: $response_text" + echo -e "\n***\n*** $error_message \n***" + return 1 + fi + fi + + return 0 +} # first we download weights pip install -U huggingface_hub @@ -106,6 +160,39 @@ else RET=1 fi fi + +# Test generate endpoint + LoRA enabled (boolean flag) +EXPECTED_RESPONSE='" I love soccer. I play soccer every day.\nInstruct: Tell me"' +DATA='{ + "text_input": "Instruct: Tell me more about soccer\nOutput:", + "parameters": { + "stream": false, + "temperature": 0, + "top_p":1, + "lora_name": "sheep", + "exclude_input_in_output": true + } +}' +RESPONSE=$(make_api_call "$GENERATE_ENDPOINT" "$DATA") +check_response "$RESPONSE" "$EXPECTED_RESPONSE" "Valid LoRA + Generate Endpoint Test FAILED." || RET=1 + +EXPECTED_RESPONSE="\"LoRA unavailable is not supported, we currently support ['doll', 'sheep']\"" +DATA='{ + "text_input": "Instruct: Tell me more about soccer\nOutput:", + "parameters": { + "stream": false, + "temperature": 0, + "top_p":1, + "lora_name": "unavailable", + "exclude_input_in_output": true + } +}' +RESPONSE=$(make_api_call "$GENERATE_ENDPOINT" "$DATA") +check_response "$RESPONSE" "$EXPECTED_RESPONSE" "Invalid LoRA + Generate Endpoint Test FAILED." $CHECK_FOR_ERROR || RET=1 + +unset EXPECTED_RESPONSE +unset RESPONSE +unset DATA set -e kill $SERVER_PID @@ -151,6 +238,39 @@ else RET=1 fi fi + +# Test generate endpoint + LoRA enabled (str flag) +EXPECTED_RESPONSE='" I think it is a very interesting subject.\n\nInstruct: What do you"' +DATA='{ + "text_input": "Instruct: What do you think of Computer Science?\nOutput:", + "parameters": { + "stream": false, + "temperature": 0, + "top_p":1, + "lora_name": "doll", + "exclude_input_in_output": true + } +}' +RESPONSE=$(make_api_call "$GENERATE_ENDPOINT" "$DATA") +check_response "$RESPONSE" "$EXPECTED_RESPONSE" "Valid LoRA + Generate Endpoint Test FAILED." || RET=1 + +EXPECTED_RESPONSE="\"LoRA unavailable is not supported, we currently support ['doll', 'sheep']\"" +DATA='{ + "text_input": "Instruct: What do you think of Computer Science?\nOutput:", + "parameters": { + "stream": false, + "temperature": 0, + "top_p":1, + "lora_name": "unavailable", + "exclude_input_in_output": true + } +}' +RESPONSE=$(make_api_call "$GENERATE_ENDPOINT" "$DATA") +check_response "$RESPONSE" "$EXPECTED_RESPONSE" "Invalid LoRA + Generate Endpoint Test FAILED." $CHECK_FOR_ERROR || RET=1 + +unset EXPECTED_RESPONSE +unset RESPONSE +unset DATA set -e kill $SERVER_PID @@ -197,6 +317,22 @@ else RET=1 fi fi + +# Test generate endpoint + LoRA disabled (boolean flag) +EXPECTED_RESPONSE='"LoRA feature is not enabled."' +DATA='{ + "text_input": "Instruct: What do you think of Computer Science?\nOutput:", + "parameters": { + "stream": false, + "temperature": 0, + "top_p":1, + "lora_name": "doll", + "exclude_input_in_output": true + } +}' +RESPONSE=$(make_api_call "$GENERATE_ENDPOINT" "$DATA") +check_response "$RESPONSE" "$EXPECTED_RESPONSE" "Disabled LoRA + Generate Endpoint Test FAILED." 
$CHECK_FOR_ERROR || RET=1 + set -e kill $SERVER_PID @@ -243,6 +379,22 @@ else RET=1 fi fi + +# Test generate endpoint + LoRA disabled (str flag) +EXPECTED_RESPONSE='"LoRA feature is not enabled."' +DATA='{ + "text_input": "Instruct: What do you think of Computer Science?\nOutput:", + "parameters": { + "stream": false, + "temperature": 0, + "top_p":1, + "lora_name": "doll", + "exclude_input_in_output": true + } +}' +RESPONSE=$(make_api_call "$GENERATE_ENDPOINT" "$DATA") +check_response "$RESPONSE" "$EXPECTED_RESPONSE" "Disabled LoRA + Generate Endpoint Test FAILED." $CHECK_FOR_ERROR > $CLIENT_LOG 2>&1 || RET=1 + set -e kill $SERVER_PID diff --git a/src/model.py b/src/model.py index 4c351f14..19ff713e 100644 --- a/src/model.py +++ b/src/model.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -43,10 +43,10 @@ build_async_engine_client_from_engine_args, ) from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from utils.metrics import VllmStatLogger +from utils.vllm_backend_utils import TritonSamplingParams _VLLM_ENGINE_ARGS_FILENAME = "model.json" _MULTI_LORA_ARGS_FILENAME = "multi_lora.json" @@ -430,9 +430,8 @@ async def _generate(self, request): additional_outputs, ) = self._get_input_tensors(request) - sampling_params_dict = self._get_sampling_params_dict(parameters) - lora_name = sampling_params_dict.pop("lora_name", None) - sampling_params = SamplingParams(**sampling_params_dict) + sampling_params = TritonSamplingParams.from_dict(parameters, self.logger) + lora_name = sampling_params.lora_name lora_request = None if lora_name is not None: lora_id = str(self.supported_loras.index(lora_name) + 1) @@ -564,8 +563,8 @@ def _get_input_tensors(self, request): ) # parameters / sampling_parameters - # An alternative mechanism to receive serialized parameters as an input tensor, - # because request parameters are not yet supported via BLS. + # An alternative mechanism to receive serialized parameters as an input + # tensor, because request parameters are not yet supported via BLS. sampling_parameters = pb_utils.get_input_tensor_by_name( request, "sampling_parameters" ) @@ -704,33 +703,6 @@ def _create_response( return pb_utils.InferenceResponse(output_tensors=output_tensors) - def _get_sampling_params_dict(self, params_json): - params_dict = json.loads(params_json) - - # Special parsing for the supported sampling parameters - bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"] - for k in bool_keys: - if k in params_dict: - params_dict[k] = bool(params_dict[k]) - - float_keys = [ - "frequency_penalty", - "length_penalty", - "presence_penalty", - "temperature", - "top_p", - ] - for k in float_keys: - if k in params_dict: - params_dict[k] = float(params_dict[k]) - - int_keys = ["best_of", "max_tokens", "min_tokens", "n", "top_k"] - for k in int_keys: - if k in params_dict: - params_dict[k] = int(params_dict[k]) - - return params_dict - def _verify_loras(self, request): # We will check if the requested lora exists here, if not we will send a # response with `LoRA not found` information. 
In this way we may avoid @@ -743,9 +715,10 @@ def _verify_loras(self, request): ) if parameters_input_tensor: parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") - sampling_params_dict = self._get_sampling_params_dict(parameters) - lora_name = sampling_params_dict.pop("lora_name", None) + else: + parameters = request.parameters() + lora_name = json.loads(parameters).pop("lora_name", None) if lora_name is not None: if not self.enable_lora: lora_error = pb_utils.TritonError("LoRA feature is not enabled.") diff --git a/src/utils/vllm_backend_utils.py b/src/utils/vllm_backend_utils.py new file mode 100644 index 00000000..8d330fb8 --- /dev/null +++ b/src/utils/vllm_backend_utils.py @@ -0,0 +1,100 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +from typing import Optional + +from vllm.sampling_params import GuidedDecodingParams, SamplingParams + + +class TritonSamplingParams(SamplingParams): + """ + Extended sampling parameters for text generation via + Triton Inference Server and vLLM backend. + + Attributes: + lora_name (Optional[str]): The name of the LoRA (Low-Rank Adaptation) + to use for inference. + """ + + lora_name: Optional[str] = None + + def __repr__(self) -> str: + """ + Returns a string representation of the `TritonSamplingParams` object. + + This method overrides the `__repr__` method of the parent class + to include additional attributes in the string representation. + + Returns: + A string representation of the object. + """ + base = super().__repr__() + return f"{base}, lora_name={self.lora_name}" + + @staticmethod + def from_dict( + params_dict_str: str, logger: "pb_utils.Logger" + ) -> "TritonSamplingParams": + """ + Creates a `TritonSamplingParams` object from a dictionary string. + + This method parses a JSON string containing sampling parameters, + converts the values to appropriate types, and creates a + `TritonSamplingParams` object. + + Args: + params_dict (str): A JSON string containing sampling parameters. 
+            logger (pb_utils.Logger): Triton Inference Server logger object.
+
+        Returns:
+            TritonSamplingParams: An instance of `TritonSamplingParams`, or None if the parameters could not be parsed.
+        """
+        try:
+            params_dict = json.loads(params_dict_str)
+            vllm_params_dict = SamplingParams.__annotations__
+            type_mapping = {
+                int: int,
+                float: float,
+                bool: bool,
+                str: str,
+                Optional[int]: int,
+            }
+            for key, value in params_dict.items():
+                if key == "guided_decoding":
+                    params_dict[key] = GuidedDecodingParams(**json.loads(value))
+                elif key in vllm_params_dict:
+                    vllm_type = vllm_params_dict[key]
+                    if vllm_type in type_mapping:
+                        params_dict[key] = type_mapping[vllm_type](params_dict[key])
+
+            return TritonSamplingParams(**params_dict)
+
+        except Exception as e:
+            logger.log_error(
+                f"[vllm] Failed to create `TritonSamplingParams`: {e}"
+            )
+            return None
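
Note: the guided decoding path above is only exercised through the gRPC streaming client in `accuracy_test.py`, but the same `guided_decoding` entry can in principle be sent through the HTTP generate endpoint as well, since `TritonSamplingParams.from_dict` receives the request parameters as a JSON string in either case. The sketch below is illustrative only and is not part of this change: it assumes the accuracy-test server is also serving HTTP on port 8000 with the model named `vllm_opt`, and that the generate endpoint forwards the `parameters` object to the backend unchanged, as the LoRA tests above rely on. The `guided_decoding` value is itself a JSON-encoded string, exactly as `test_guided_decoding` builds it with `json.dumps`.

    # Hypothetical request, mirroring the payload shape used by the LoRA tests above.
    GENERATE_ENDPOINT="localhost:8000/v2/models/vllm_opt/generate"
    DATA='{
        "text_input": "Classify intent of the sentence: Harry Potter is underrated. ",
        "parameters": {
            "stream": false,
            "temperature": 0,
            "top_p": 1,
            "exclude_input_in_output": true,
            "guided_decoding": "{\"choice\": [\"Positive\", \"Negative\"], \"backend\": \"outlines\"}"
        }
    }'
    curl -X POST "$GENERATE_ENDPOINT" --data-binary @- <<< "$DATA"

On the server side, `from_dict` decodes the outer parameters JSON, rebuilds the `guided_decoding` string into `GuidedDecodingParams(choice=["Positive", "Negative"], backend="outlines")`, and converts the remaining keys through the `SamplingParams` type mapping, so `text_output` should come back as one of the two allowed choices.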