Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
Browse files Browse the repository at this point in the history
…-ultravox-lora-dec-16
  • Loading branch information
jeejeelee committed Jan 20, 2025
2 parents 575b5dc + df450aa commit 7cb7eba
Show file tree
Hide file tree
Showing 31 changed files with 852 additions and 234 deletions.
9 changes: 6 additions & 3 deletions benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class RequestFuncInput:
prompt_len: int
output_len: int
model: str
model_name: Optional[str] = None
best_of: int = 1
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
Expand Down Expand Up @@ -78,7 +79,7 @@ async def async_request_tgi(
continue
chunk_bytes = chunk_bytes.decode("utf-8")

#NOTE: Sometimes TGI returns a ping response without
# NOTE: Sometimes TGI returns a ping response without
# any data, we should skip it.
if chunk_bytes.startswith(":"):
continue
Expand Down Expand Up @@ -235,7 +236,8 @@ async def async_request_openai_completions(

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"best_of": request_func_input.best_of,
Expand Down Expand Up @@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model": request_func_input.model,
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"messages": [
{
"role": "user",
Expand Down
13 changes: 13 additions & 0 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ async def benchmark(
api_url: str,
base_url: str,
model_id: str,
model_name: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int],
Expand Down Expand Up @@ -553,6 +554,7 @@ async def benchmark(
"Multi-modal content is only supported on 'openai-chat' backend.")
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=api_url,
prompt_len=test_prompt_len,
Expand All @@ -573,6 +575,7 @@ async def benchmark(
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
prompt_len=test_prompt_len,
Expand Down Expand Up @@ -616,6 +619,7 @@ async def limited_request_func(request_func_input, pbar):
async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
Expand Down Expand Up @@ -780,6 +784,7 @@ def main(args: argparse.Namespace):

backend = args.backend
model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode

Expand Down Expand Up @@ -877,6 +882,7 @@ def main(args: argparse.Namespace):
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
Expand Down Expand Up @@ -1222,5 +1228,12 @@ def main(args: argparse.Namespace):
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')

parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")

args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
-
- ✅︎
* - `UltravoxModel`
- Ultravox
- T + A<sup>E+</sup>
Expand Down
18 changes: 8 additions & 10 deletions tests/models/decoder_only/vision_language/test_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def batch_make_image_embeddings(
pixel_values = preprocess_result["pixel_values"]
image_grid_thw = preprocess_result["image_grid_thw"]

# pixel values to embeddinds & grid_thws
# pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
Expand All @@ -124,11 +124,10 @@ def batch_make_image_embeddings(
for image_batch in image_batches_:
cur_batch_image_count = len(image_batch)
merge_size = image_processor.merge_size
cur_batch_embed_len = sum([
grid_thw.prod() // merge_size // merge_size
cur_batch_embed_len = sum(
grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in image_grid_thw[image_counter:image_counter +
cur_batch_image_count]
])
cur_batch_image_count])

result.append({
"image_embeds":
Expand Down Expand Up @@ -187,7 +186,7 @@ def batch_make_video_embeddings(
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]

# pixel values to embeddinds & grid_thws
# pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
Expand All @@ -206,11 +205,10 @@ def batch_make_video_embeddings(
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
cur_batch_embed_len = sum([
grid_thw.prod() // merge_size // merge_size
cur_batch_embed_len = sum(
grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[video_counter:video_counter +
cur_batch_video_count]
])
cur_batch_video_count])

result.append({
"video_embeds":
Expand Down
1 change: 1 addition & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class _HfExamplesInfo:
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
Expand Down
3 changes: 2 additions & 1 deletion tests/weight_loading/models.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
None, mgleize/fairseq2-dummy-Llama-3.2-1B, main
13 changes: 7 additions & 6 deletions tests/weight_loading/test_weight_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:
with vllm_runner(
model_name=MODEL_NAME,
revision=REVISION,
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=None if QUANTIZATION == "None" else QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:

output = model.generate_greedy("Hello world!", max_tokens=20)
print(output)
Expand Down
80 changes: 58 additions & 22 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,30 @@
logger = init_logger(__name__)


@dataclasses.dataclass
class InductorArtifact:
hash_str: str = ""
file_path: str = ""


class InductorHashCache:
"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str)
(runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is
runtime_shape, and the value is a dict of graph_index to hash_str.
The data is essentially `Dict[Optional[int], Dict[int, str]]`,
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
we don't use json here because json doesn't support int as key.
TODO: better off-the-shelf solution to serialize the data?
"""

def __init__(self, cache_dir: str, disabled: bool = False):
self.cache: defaultdict = defaultdict(dict)
self.cache: Dict[Optional[int],
Dict[int, InductorArtifact]] = defaultdict(dict)
self.disabled = disabled
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir,
Expand All @@ -66,14 +73,25 @@ def deserialize(self, data: str):
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
for runtime_shape, graph_index, hash_str in list_data:
self.cache[runtime_shape][graph_index] = hash_str
for item in list_data:
runtime_shape = item[0]
graph_index = item[1]
hash_str = item[2]
# for compatibility of old version,
# where we don't have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path = "" if len(item) == 3 else item[3]
self.cache[runtime_shape][graph_index] = InductorArtifact(
hash_str=hash_str, file_path=file_path)

def serialize(self) -> str:
data = []
for runtime_shape, graph_index_to_hash_str in self.cache.items():
for graph_index, hash_str in graph_index_to_hash_str.items():
data.append((runtime_shape, graph_index, hash_str))
for runtime_shape, value in self.cache.items():
for graph_index, inductor_artifact in value.items():
data.append(
(runtime_shape, graph_index, inductor_artifact.hash_str,
inductor_artifact.file_path))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)

Expand All @@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]

def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
if self.disabled:
raise KeyError("cannot read from disabled cache")
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]

def __setitem__(self, key: Tuple[Optional[int], int], value: str):
def __setitem__(self, key: Tuple[Optional[int], int],
value: InductorArtifact):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
Expand Down Expand Up @@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
hash_str = cache_data[(runtime_shape, graph_index)]
inductor_artifact = cache_data[(runtime_shape, graph_index)]
hash_str = inductor_artifact.hash_str
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
Expand All @@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
"Inductor cache lookup failed. Please remove"
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa

# Inductor calling convention (function signature):
# f(list) -> tuple
Expand All @@ -224,19 +245,20 @@ def compiled_graph(*args):
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
from torch._inductor.codecache import compiled_fx_graph_hash

inductor_artifact = InductorArtifact()
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
original_load = FxGraphCache.load

def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph

def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
# store the hash in the cache
nonlocal cache_data
cache_data[(runtime_shape, graph_index)] = out[0]
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug("store the %s-th graph for shape %s via hash %s",
graph_index, str(runtime_shape), out[0])
inductor_artifact.hash_str = out[0]
return out

def _check_can_cache(*args, **kwargs):
Expand All @@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
if not cache_data.disabled:
# compilation cache is enabled, patch several functions

# hijack to get the compiled graph itself
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache.load",
hijack_load))

# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
Expand All @@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)

# store the inductor_artifact in the cache
cache_data[(runtime_shape, graph_index)] = inductor_artifact
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug(
"store the %s-th graph for shape %s via hash %s from file %s",
graph_index, str(runtime_shape), inductor_artifact.hash_str,
inductor_artifact.file_path)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
Expand Down
14 changes: 12 additions & 2 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
- if it is a single integer (can be negative), the corresponding dimension
of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
Expand Down Expand Up @@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [
arg.ndim + dim if dim < 0 else dim for dim in dims
]
torch._dynamo.mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [
tensor.ndim + dim if dim < 0 else dim
for dim in dims
]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(
Expand Down
Loading

0 comments on commit 7cb7eba

Please sign in to comment.