forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Implement disagg prefill by StatelessProcessGroup (vllm-projec…
…t#10502) This PR provides initial support for single-node disaggregated prefill in 1P1D scenario. Signed-off-by: KuntaiDu <[email protected]> Co-authored-by: ApostaC <[email protected]> Co-authored-by: YaoJiayi <[email protected]> Signed-off-by: Andrew Feldman <[email protected]>
- Loading branch information
1 parent
bcdb5b8
commit 88f7f57
Showing
33 changed files
with
2,525 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
144 changes: 144 additions & 0 deletions
144
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
#!/bin/bash | ||
|
||
# benchmark the overhead of disaggregated prefill. | ||
# methodology: | ||
# - send all request to prefill vLLM instance. It will buffer KV cache. | ||
# - then send all request to decode instance. | ||
# - The TTFT of decode instance is the overhead. | ||
|
||
set -ex | ||
|
||
kill_gpu_processes() { | ||
# kill all processes on GPU. | ||
pkill -f pt_main_thread | ||
sleep 10 | ||
|
||
# remove vllm config file | ||
rm -rf ~/.config/vllm | ||
|
||
# Print the GPU memory usage | ||
# so that we know if all GPU processes are killed. | ||
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) | ||
# The memory usage should be 0 MB. | ||
echo "GPU 0 Memory Usage: $gpu_memory_usage MB" | ||
} | ||
|
||
wait_for_server() { | ||
# wait for vllm server to start | ||
# return 1 if vllm server crashes | ||
local port=$1 | ||
timeout 1200 bash -c " | ||
until curl -s localhost:${port}/v1/completions > /dev/null; do | ||
sleep 1 | ||
done" && return 0 || return 1 | ||
} | ||
|
||
|
||
benchmark() { | ||
|
||
export VLLM_LOGGING_LEVEL=DEBUG | ||
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') | ||
|
||
# compare chunked prefill with disaggregated prefill | ||
|
||
results_folder="./results" | ||
model="meta-llama/Meta-Llama-3.1-8B-Instruct" | ||
dataset_name="sonnet" | ||
dataset_path="../sonnet_4x.txt" | ||
num_prompts=10 | ||
qps=$1 | ||
prefix_len=50 | ||
input_len=2048 | ||
output_len=$2 | ||
|
||
|
||
CUDA_VISIBLE_DEVICES=0 python3 \ | ||
-m vllm.entrypoints.openai.api_server \ | ||
--model meta-llama/Meta-Llama-3.1-8B-Instruct \ | ||
--port 8100 \ | ||
--max-model-len 10000 \ | ||
--gpu-memory-utilization 0.6 \ | ||
--kv-transfer-config \ | ||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & | ||
|
||
|
||
CUDA_VISIBLE_DEVICES=1 python3 \ | ||
-m vllm.entrypoints.openai.api_server \ | ||
--model meta-llama/Meta-Llama-3.1-8B-Instruct \ | ||
--port 8200 \ | ||
--max-model-len 10000 \ | ||
--gpu-memory-utilization 0.6 \ | ||
--kv-transfer-config \ | ||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & | ||
|
||
wait_for_server 8100 | ||
wait_for_server 8200 | ||
|
||
# let the prefill instance finish prefill | ||
python3 ../benchmark_serving.py \ | ||
--backend vllm \ | ||
--model $model \ | ||
--dataset-name $dataset_name \ | ||
--dataset-path $dataset_path \ | ||
--sonnet-input-len $input_len \ | ||
--sonnet-output-len "$output_len" \ | ||
--sonnet-prefix-len $prefix_len \ | ||
--num-prompts $num_prompts \ | ||
--port 8100 \ | ||
--save-result \ | ||
--result-dir $results_folder \ | ||
--result-filename disagg_prefill_2xtp4.json \ | ||
--request-rate "inf" | ||
|
||
|
||
# send the request to decode. | ||
# The TTFT of this command will be the overhead of disagg prefill impl. | ||
python3 ../benchmark_serving.py \ | ||
--backend vllm \ | ||
--model $model \ | ||
--dataset-name $dataset_name \ | ||
--dataset-path $dataset_path \ | ||
--sonnet-input-len $input_len \ | ||
--sonnet-output-len "$output_len" \ | ||
--sonnet-prefix-len $prefix_len \ | ||
--num-prompts $num_prompts \ | ||
--port 8200 \ | ||
--save-result \ | ||
--result-dir $results_folder \ | ||
--result-filename disagg_prefill_2xtp4.json \ | ||
--request-rate "$qps" | ||
kill_gpu_processes | ||
|
||
} | ||
|
||
|
||
main() { | ||
|
||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl) | ||
(which jq) || (apt-get -y install jq) | ||
(which socat) || (apt-get -y install socat) | ||
|
||
pip install quart httpx | ||
|
||
cd "$(dirname "$0")" | ||
|
||
cd .. | ||
# create sonnet-4x.txt | ||
echo "" > sonnet_4x.txt | ||
for _ in {1..4} | ||
do | ||
cat sonnet.txt >> sonnet_4x.txt | ||
done | ||
cd disagg_benchmarks | ||
|
||
rm -rf results | ||
mkdir results | ||
|
||
default_qps=1 | ||
default_output_len=1 | ||
benchmark $default_qps $default_output_len | ||
|
||
} | ||
|
||
|
||
main "$@" |
164 changes: 164 additions & 0 deletions
164
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
#!/bin/bash | ||
|
||
# Requirement: 8x H100 GPUs. | ||
|
||
|
||
# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV | ||
# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests | ||
# Resource: 8x H100 | ||
# Approaches: | ||
# 1. Chunked prefill: 1 vllm instance with tp=8 | ||
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 | ||
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance | ||
# Prefilling instance: max_output_token=1 | ||
# Decoding instance: force the input tokens be the same across requests to bypass prefilling | ||
|
||
set -ex | ||
|
||
kill_gpu_processes() { | ||
# kill all processes on GPU. | ||
pgrep pt_main_thread | xargs -r kill -9 | ||
pgrep python3 | xargs -r kill -9 | ||
for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done | ||
sleep 1 | ||
} | ||
|
||
wait_for_server() { | ||
# wait for vllm server to start | ||
# return 1 if vllm server crashes | ||
local port=$1 | ||
timeout 1200 bash -c " | ||
until curl -s localhost:${port}/v1/completions > /dev/null; do | ||
sleep 1 | ||
done" && return 0 || return 1 | ||
} | ||
|
||
|
||
launch_chunked_prefill() { | ||
model="meta-llama/Meta-Llama-3.1-8B-Instruct" | ||
# disagg prefill | ||
CUDA_VISIBLE_DEVICES=0 python3 \ | ||
-m vllm.entrypoints.openai.api_server \ | ||
--model $model \ | ||
--port 8100 \ | ||
--max-model-len 10000 \ | ||
--enable-chunked-prefill \ | ||
--gpu-memory-utilization 0.6 & | ||
CUDA_VISIBLE_DEVICES=1 python3 \ | ||
-m vllm.entrypoints.openai.api_server \ | ||
--model $model \ | ||
--port 8200 \ | ||
--max-model-len 10000 \ | ||
--enable-chunked-prefill \ | ||
--gpu-memory-utilization 0.6 & | ||
wait_for_server 8100 | ||
wait_for_server 8200 | ||
python3 round_robin_proxy.py & | ||
sleep 1 | ||
} | ||
|
||
|
||
launch_disagg_prefill() { | ||
model="meta-llama/Meta-Llama-3.1-8B-Instruct" | ||
# disagg prefill | ||
CUDA_VISIBLE_DEVICES=0 python3 \ | ||
-m vllm.entrypoints.openai.api_server \ | ||
--model $model \ | ||
--port 8100 \ | ||
--max-model-len 10000 \ | ||
--gpu-memory-utilization 0.6 \ | ||
--kv-transfer-config \ | ||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & | ||
|
||
CUDA_VISIBLE_DEVICES=1 python3 \ | ||
-m vllm.entrypoints.openai.api_server \ | ||
--model $model \ | ||
--port 8200 \ | ||
--max-model-len 10000 \ | ||
--gpu-memory-utilization 0.6 \ | ||
--kv-transfer-config \ | ||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & | ||
|
||
wait_for_server 8100 | ||
wait_for_server 8200 | ||
python3 disagg_prefill_proxy_server.py & | ||
sleep 1 | ||
} | ||
|
||
|
||
benchmark() { | ||
results_folder="./results" | ||
model="meta-llama/Meta-Llama-3.1-8B-Instruct" | ||
dataset_name="sonnet" | ||
dataset_path="../sonnet_4x.txt" | ||
num_prompts=100 | ||
qps=$1 | ||
prefix_len=50 | ||
input_len=1024 | ||
output_len=$2 | ||
tag=$3 | ||
|
||
python3 ../benchmark_serving.py \ | ||
--backend vllm \ | ||
--model $model \ | ||
--dataset-name $dataset_name \ | ||
--dataset-path $dataset_path \ | ||
--sonnet-input-len $input_len \ | ||
--sonnet-output-len "$output_len" \ | ||
--sonnet-prefix-len $prefix_len \ | ||
--num-prompts $num_prompts \ | ||
--port 8000 \ | ||
--save-result \ | ||
--result-dir $results_folder \ | ||
--result-filename "$tag"-qps-"$qps".json \ | ||
--request-rate "$qps" | ||
|
||
sleep 2 | ||
|
||
} | ||
|
||
|
||
main() { | ||
|
||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl) | ||
(which jq) || (apt-get -y install jq) | ||
(which socat) || (apt-get -y install socat) | ||
|
||
pip install quart httpx matplotlib aiohttp | ||
|
||
cd "$(dirname "$0")" | ||
|
||
cd .. | ||
# create sonnet-4x.txt so that we can sample 2048 tokens for input | ||
echo "" > sonnet_4x.txt | ||
for _ in {1..4} | ||
do | ||
cat sonnet.txt >> sonnet_4x.txt | ||
done | ||
cd disagg_benchmarks | ||
|
||
rm -rf results | ||
mkdir results | ||
|
||
default_output_len=6 | ||
|
||
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') | ||
|
||
launch_chunked_prefill | ||
for qps in 2 4 6 8; do | ||
benchmark $qps $default_output_len chunked_prefill | ||
done | ||
kill_gpu_processes | ||
|
||
launch_disagg_prefill | ||
for qps in 2 4 6 8; do | ||
benchmark $qps $default_output_len disagg_prefill | ||
done | ||
kill_gpu_processes | ||
|
||
python3 visualize_benchmark_results.py | ||
|
||
} | ||
|
||
|
||
main "$@" |
61 changes: 61 additions & 0 deletions
61
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
|
||
import aiohttp | ||
from quart import Quart, make_response, request | ||
|
||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) | ||
|
||
app = Quart(__name__) | ||
|
||
|
||
async def forward_request(url, data): | ||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: | ||
headers = { | ||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" | ||
} | ||
async with session.post(url=url, json=data, | ||
headers=headers) as response: | ||
if response.status == 200: | ||
# if response.headers.get('Transfer-Encoding') == 'chunked': | ||
if True: | ||
async for chunk_bytes in response.content.iter_chunked( | ||
1024): | ||
yield chunk_bytes | ||
else: | ||
content = await response.read() | ||
yield content | ||
|
||
|
||
@app.route('/v1/completions', methods=['POST']) | ||
async def handle_request(): | ||
try: | ||
original_request_data = await request.get_json() | ||
|
||
prefill_request = original_request_data.copy() | ||
# change max_tokens = 1 to let it only do prefill | ||
prefill_request['max_tokens'] = 1 | ||
|
||
# finish prefill | ||
async for _ in forward_request('http://localhost:8100/v1/completions', | ||
prefill_request): | ||
continue | ||
|
||
# return decode | ||
generator = forward_request('http://localhost:8200/v1/completions', | ||
original_request_data) | ||
response = await make_response(generator) | ||
response.timeout = None | ||
|
||
return response | ||
|
||
except Exception as e: | ||
import sys | ||
import traceback | ||
exc_info = sys.exc_info() | ||
print("Error occurred in disagg prefill proxy server") | ||
print(e) | ||
print("".join(traceback.format_exception(*exc_info))) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(port=8000) |
Oops, something went wrong.