[fix] Sampling Parameters related improvements (#80)
oandreeva-nv authored Jan 9, 2025
1 parent d061556 commit 80dd037
Showing 5 changed files with 373 additions and 52 deletions.
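In short: the backend's sampling parameters can now carry a guided_decoding entry, JSON-encoded, whose fields mirror vllm.sampling_params.GuidedDecodingParams, and the accuracy test gains a guided-decoding case plus a baseline generator for it. A minimal sketch of the parameters dict the new test builds (the variable name is illustrative; the values are exactly those used in test_guided_decoding below):

import json

# Sampling parameters as sent to the Triton vLLM backend by the new test.
# "guided_decoding" is a JSON string; its fields (choice, backend) mirror
# vllm.sampling_params.GuidedDecodingParams.
sampling_parameters = {
    "temperature": 0,
    "top_p": 1,
    "guided_decoding": json.dumps(
        {"choice": ["Positive", "Negative"], "backend": "outlines"}
    ),
}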
118 changes: 105 additions & 13 deletions ci/L0_backend_vllm/accuracy_test/accuracy_test.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -26,6 +26,7 @@

 import argparse
 import asyncio
+import json
 import pickle
 import sys
 import unittest
@@ -36,6 +37,7 @@
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import GuidedDecodingParams
 from vllm.utils import random_uuid

 sys.path.append("../../common")
@@ -53,14 +55,22 @@
     "The future of AI is",
 ]

+GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "]
+
 SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1}


-async def generate_python_vllm_output(prompt, llm_engine):
+async def generate_python_vllm_output(
+    prompt,
+    llm_engine,
+    sampling_params=SamplingParams(**SAMPLING_PARAMETERS),
+    guided_generation=None,
+):
     request_id = random_uuid()
-    sampling_params = SamplingParams(**SAMPLING_PARAMETERS)
     python_vllm_output = None
     last_output = None
+    if guided_generation:
+        sampling_params.guided_decoding = guided_generation

     async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id):
         last_output = vllm_output
@@ -69,24 +79,28 @@ async def generate_python_vllm_output(prompt, llm_engine):
     python_vllm_output = [
         (prompt + output.text).encode("utf-8") for output in last_output.outputs
     ]

     return python_vllm_output


-def prepare_vllm_baseline_outputs():
+def prepare_vllm_baseline_outputs(
+    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None
+):
     """
     Helper function that starts an async vLLM engine and generates output for each
-    prompt in `PROMPTS`. Saves resulted baselines in `vllm_baseline_output.pkl`
+    prompt in `prompts`. Saves the resulting baselines in `export_file`
     for further use.
     """
     llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**VLLM_ENGINE_CONFIG))
     python_vllm_output = []
-    for i in range(len(PROMPTS)):
+    for i in range(len(prompts)):
         python_vllm_output.extend(
-            asyncio.run(generate_python_vllm_output(PROMPTS[i], llm_engine))
+            asyncio.run(
+                generate_python_vllm_output(
+                    prompts[i], llm_engine, guided_generation=guided_generation
+                )
+            )
         )

-    with open("vllm_baseline_output.pkl", "wb") as f:
+    with open(export_file, "wb") as f:
         pickle.dump(python_vllm_output, f)

     return
@@ -96,6 +110,9 @@ class VLLMTritonAccuracyTest(TestResultCollector):
     def setUp(self):
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
         self.vllm_model_name = "vllm_opt"
+
+    def test_vllm_model(self):
+        # Reading and verifying baseline data
         self.python_vllm_output = []
         with open("vllm_baseline_output.pkl", "rb") as f:
             self.python_vllm_output = pickle.load(f)
@@ -116,11 +133,9 @@ def setUp(self):
             ),
         )

-    def test_vllm_model(self):
         user_data = UserData()
         stream = False
         triton_vllm_output = []
-
         self.triton_client.start_stream(callback=partial(callback, user_data))
         for i in range(len(PROMPTS)):
             request_data = create_vllm_request(
@@ -131,7 +146,7 @@ def test_vllm_model(self):
                 request_id=request_data["request_id"],
                 inputs=request_data["inputs"],
                 outputs=request_data["outputs"],
-                parameters=SAMPLING_PARAMETERS,
+                parameters=request_data["parameters"],
             )

         for i in range(len(PROMPTS)):
@@ -146,6 +161,63 @@ def test_vllm_model(self):
         self.triton_client.stop_stream()
         self.assertEqual(sorted(self.python_vllm_output), sorted(triton_vllm_output))

+    def test_guided_decoding(self):
+        # Reading and verifying baseline data
+        self.python_vllm_output = []
+        with open("vllm_guided_baseline_output.pkl", "rb") as f:
+            self.python_vllm_output = pickle.load(f)
+
+        self.assertNotEqual(
+            self.python_vllm_output,
+            [],
+            "Loaded baseline outputs' list should not be empty",
+        )
+        self.assertIsNotNone(
+            self.python_vllm_output, "Loaded baseline outputs' list should not be None"
+        )
+        self.assertEqual(
+            len(self.python_vllm_output),
+            len(GUIDED_PROMPTS),
+            "Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
+                len(GUIDED_PROMPTS), len(self.python_vllm_output)
+            ),
+        )
+
+        user_data = UserData()
+        stream = False
+        triton_vllm_output = []
+
+        self.triton_client.start_stream(callback=partial(callback, user_data))
+        sampling_params = dict(SAMPLING_PARAMETERS)
+        guided_decoding_params = {
+            "choice": ["Positive", "Negative"],
+            "backend": "outlines",
+        }
+        sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
+        for i in range(len(GUIDED_PROMPTS)):
+            request_data = create_vllm_request(
+                GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
+            )
+            self.triton_client.async_stream_infer(
+                model_name=self.vllm_model_name,
+                request_id=request_data["request_id"],
+                inputs=request_data["inputs"],
+                outputs=request_data["outputs"],
+                parameters=request_data["parameters"],
+            )
+
+        for i in range(len(GUIDED_PROMPTS)):
+            result = user_data._completed_requests.get()
+            self.assertIsNot(type(result), InferenceServerException, str(result))
+
+            output = result.as_numpy("text_output")
+            self.assertIsNotNone(output, "`text_output` should not be None")
+
+            triton_vllm_output.extend(output)
+
+        self.triton_client.stop_stream()
+        self.assertEqual(sorted(self.python_vllm_output), sorted(triton_vllm_output))
+
     def tearDown(self):
         self.triton_client.close()

@@ -159,9 +231,29 @@ def tearDown(self):
         default=False,
         help="Generates baseline output for accuracy tests",
     )
+    parser.add_argument(
+        "--generate-guided-baseline",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Generates baseline output for guided-decoding accuracy tests",
+    )
     FLAGS = parser.parse_args()
     if FLAGS.generate_baseline:
         prepare_vllm_baseline_outputs()
         exit(0)

+    if FLAGS.generate_guided_baseline:
+        guided_decoding_params = {
+            "choice": ["Positive", "Negative"],
+            "backend": "outlines",
+        }
+        guided_generation = GuidedDecodingParams(**guided_decoding_params)
+        prepare_vllm_baseline_outputs(
+            export_file="vllm_guided_baseline_output.pkl",
+            prompts=GUIDED_PROMPTS,
+            guided_generation=guided_generation,
+        )
+        exit(0)
+
     unittest.main()
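For reference, the guided-decoding path that the baseline generator above exercises, distilled into a standalone sketch. Everything except the model name comes from this diff; facebook/opt-125m is an assumed stand-in for the engine settings the test actually loads from VLLM_ENGINE_CONFIG, and classify is an illustrative helper name:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid


async def classify(prompt, engine):
    # Greedy sampling, constrained to one of two labels via the outlines backend.
    params = SamplingParams(temperature=0, top_p=1)
    params.guided_decoding = GuidedDecodingParams(
        choice=["Positive", "Negative"], backend="outlines"
    )
    last = None
    async for out in engine.generate(prompt, params, random_uuid()):
        last = out
    return last.outputs[0].text


if __name__ == "__main__":
    # Assumed model; the test reads its engine arguments from VLLM_ENGINE_CONFIG.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    result = asyncio.run(
        classify("Classify intent of the sentence: Harry Potter is underrated. ", engine)
    )
    print(result)  # expected to be exactly "Positive" or "Negative"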
8 changes: 6 additions & 2 deletions ci/L0_backend_vllm/accuracy_test/test.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -37,7 +37,7 @@ TEST_RESULT_FILE='test_results.txt'
 CLIENT_PY="./accuracy_test.py"
 SAMPLE_MODELS_REPO="../../../samples/model_repository"
 VLLM_ENGINE_LOG="vllm_engine.log"
-EXPECTED_NUM_TESTS=1
+EXPECTED_NUM_TESTS=2

 rm -rf models && mkdir -p models
 cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
@@ -50,6 +50,10 @@ set +e
 # memory issues: https://github.com/vllm-project/vllm/issues/2248
 python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
 wait $BASELINE_PID
+
+python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
+wait $BASELINE_PID
+
 set -e

 run_server
