
Commit

Merge branch 'main' into openai-tests-structured-outputs
mgoin committed Dec 19, 2024
2 parents 1ca8f58 + e461c26 commit d1b540c
Showing 65 changed files with 1,822 additions and 1,743 deletions.
39 changes: 30 additions & 9 deletions CMakeLists.txt
@@ -273,15 +273,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

#
# The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
# For Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -290,12 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 sparse or quantized models on "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building cutlass_c3x as no compatible archs found "
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()

@@ -329,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

#
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
else()
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
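
As a rough runtime counterpart to this gate, the same check can be approximated from Python with PyTorch (a sketch only; the helper name is made up, and it mirrors the CUDA-12.2-or-later and Hopper-only conditions described above rather than reading any build output):

    import torch

    def sparse_scaled_mm_buildable() -> bool:
        # Approximate the CMake gate for the 2:4 sparse kernels:
        # CUDA toolkit >= 12.2 and a Hopper (compute capability 9.0) device.
        if torch.version.cuda is None or not torch.cuda.is_available():
            return False
        cuda_major, cuda_minor = (int(v) for v in torch.version.cuda.split(".")[:2])
        return (cuda_major, cuda_minor) >= (12, 2) and \
            torch.cuda.get_device_capability() == (9, 0)

    print(sparse_scaled_mm_buildable())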


#
# Machete kernels
102 changes: 89 additions & 13 deletions benchmarks/benchmark_throughput.py
@@ -4,7 +4,8 @@
import json
import random
import time
from typing import List, Optional
from functools import cache
from typing import Dict, List, Optional, Tuple

import torch
import uvloop
@@ -17,8 +18,11 @@
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@@ -28,15 +32,17 @@ class SampleRequest:
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
lora_request: Optional LoRARequest specifying the LoRA to use.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
lora_request: Optional[LoRARequest] = None
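
For illustration, a request that carries a LoRA adapter could be built like this (a minimal sketch; the prompt, token counts, and adapter path are placeholder values):

    example_request = SampleRequest(
        prompt="Summarize the plot of Hamlet in two sentences.",
        prompt_len=11,
        expected_output_len=64,
        lora_request=LoRARequest(lora_name="1",
                                 lora_int_id=1,
                                 lora_path="/path/to/lora/adapter"))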


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}")


@cache
def lora_path_on_disk(lora_path: str) -> str:
return get_adapter_absolute_path(lora_path)


lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}


def get_random_lora_request(
args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
global lora_tokenizer_cache
lora_id = random.randint(1, args.max_loras)
lora_request = LoRARequest(lora_name=str(lora_id),
lora_int_id=lora_id,
lora_path=lora_path_on_disk(args.lora_path))
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
return lora_request, lora_tokenizer_cache[lora_id]
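
A quick sketch of how this helper might be exercised on its own (assuming an argparse namespace that carries max_loras and lora_path; the adapter path is a placeholder, and repeated calls that draw the same lora_id reuse the cached tokenizer):

    ns = argparse.Namespace(max_loras=4, lora_path="/path/to/lora/adapter")
    lora_req, lora_tok = get_random_lora_request(ns)
    print(lora_req.lora_int_id, lora_tok is not None)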


def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:

dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for data in dataset:
for data in tqdm(dataset,
total=len(filtered_dataset),
desc="sampling requests"):
if len(filtered_dataset) == num_requests:
break

@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)

request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer

# Tokenize the prompts and completions.
prompt_token_ids = tokenizer(prompt).input_ids
completion_token_ids = tokenizer(completion).input_ids
prompt_token_ids = request_tokenizer(prompt).input_ids
completion_token_ids = request_tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))
multi_modal_data=multi_modal_data,
lora_request=lora_request))

return filtered_dataset

Expand Down Expand Up @@ -146,14 +184,21 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests: Optional[List[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]

use_beam_search = False

if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
@@ -185,6 +230,7 @@ async def run_vllm_async(
# Add the requests to the engine.
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
lora_requests: List[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests.append(request.lora_request)

generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
for i, (prompt, sp,
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
generator = llm.generate(prompt,
sp,
lora_request=lr,
request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
vocab_size = tokenizer.vocab_size
requests = []
for _ in range(args.num_prompts):

request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer

# Synthesize a prompt with the given input length.
candidate_ids = [
random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for _ in range(5): # Max attempts to correct
candidate_prompt = tokenizer.decode(candidate_ids)
tokenized_len = len(tokenizer.encode(candidate_prompt))
candidate_prompt = request_tokenizer.decode(candidate_ids)
tokenized_len = len(request_tokenizer.encode(candidate_prompt))

if tokenized_len == args.input_len:
break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
requests.append(
SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len,
expected_output_len=args.output_len))
expected_output_len=args.output_len,
lora_request=lora_request))
else:
requests = sample_requests(tokenizer, args)

@@ -422,6 +482,14 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
@@ -431,6 +499,8 @@ def main(args: argparse.Namespace):
assert args.output_len is not None
else:
assert args.input_len is None
if args.enable_lora:
assert args.lora_path is not None

if args.backend == "vllm":
if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@ def main(args: argparse.Namespace):
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.enable_lora:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
if args.enable_lora:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
main(args)
2 changes: 2 additions & 0 deletions csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -1,3 +1,5 @@
#pragma once

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
2 changes: 2 additions & 0 deletions csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,3 +1,5 @@
#pragma once

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

/*
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -163,6 +163,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);

bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& e,
torch::Tensor const& a_scales,
6 changes: 3 additions & 3 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -28,9 +28,9 @@
using namespace cute;

/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
