Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Refactor benchmark_throughput.py #9779

Merged
merged 10 commits into from
Nov 4, 2024
73 changes: 49 additions & 24 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import random
import time
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
import uvloop
Expand All @@ -15,16 +15,35 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.

Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
) -> List[SampleRequest]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

Expand All @@ -41,7 +60,7 @@ def sample_requests(
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_dataset: List[SampleRequest] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
Expand All @@ -60,31 +79,34 @@ def sample_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len))

return filtered_dataset


def run_vllm(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

use_beam_search = False
Expand All @@ -94,11 +116,11 @@ def run_vllm(
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
Expand All @@ -112,7 +134,7 @@ def run_vllm(


async def run_vllm_async(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
Expand All @@ -123,17 +145,17 @@ async def run_vllm_async(
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

generators = []
Expand All @@ -149,7 +171,7 @@ async def run_vllm_async(


def run_hf(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
Expand Down Expand Up @@ -207,14 +229,14 @@ def run_hf(


def run_mii(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]

start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
Expand Down Expand Up @@ -270,9 +292,10 @@ def main(args: argparse.Namespace):
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
Expand All @@ -299,7 +322,9 @@ def main(args: argparse.Namespace):
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the dataset. The dataset is expected to "
"be a json in form of List[Dict[..., conversations: "
"List[Dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--input-len",
type=int,
default=None,
Expand Down