diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 9d71e4ecc4a37..a9ab4fc9b621e 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -22,6 +22,7 @@ class RequestFuncInput:
prompt_len: int
output_len: int
model: str
+ model_name: Optional[str] = None
best_of: int = 1
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
@@ -78,7 +79,7 @@ async def async_request_tgi(
continue
chunk_bytes = chunk_bytes.decode("utf-8")
- #NOTE: Sometimes TGI returns a ping response without
+ # NOTE: Sometimes TGI returns a ping response without
# any data, we should skip it.
if chunk_bytes.startswith(":"):
continue
@@ -235,7 +236,8 @@ async def async_request_openai_completions(
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
- "model": request_func_input.model,
+ "model": request_func_input.model_name \
+ if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"best_of": request_func_input.best_of,
@@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
- "model": request_func_input.model,
+ "model": request_func_input.model_name \
+ if request_func_input.model_name else request_func_input.model,
"messages": [
{
"role": "user",
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 4eb0e1f8ac903..53186e10c5452 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -525,6 +525,7 @@ async def benchmark(
api_url: str,
base_url: str,
model_id: str,
+ model_name: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int],
@@ -553,6 +554,7 @@ async def benchmark(
"Multi-modal content is only supported on 'openai-chat' backend.")
test_input = RequestFuncInput(
model=model_id,
+ model_name=model_name,
prompt=test_prompt,
api_url=api_url,
prompt_len=test_prompt_len,
@@ -573,6 +575,7 @@ async def benchmark(
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
+ model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
prompt_len=test_prompt_len,
@@ -616,6 +619,7 @@ async def limited_request_func(request_func_input, pbar):
async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id,
+ model_name=model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
@@ -780,6 +784,7 @@ def main(args: argparse.Namespace):
backend = args.backend
model_id = args.model
+ model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
@@ -877,6 +882,7 @@ def main(args: argparse.Namespace):
api_url=api_url,
base_url=base_url,
model_id=model_id,
+ model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
@@ -1222,5 +1228,12 @@ def main(args: argparse.Namespace):
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
+ parser.add_argument("--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. "
+ "If not specified, the model name will be the "
+ "same as the ``--model`` argument. ")
+
args = parser.parse_args()
main(args)
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 4a8ba2a01a8de..1ef384987bca3 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
- -
+ - ✅︎
* - `UltravoxModel`
- Ultravox
- T + AE+
diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py
index 16e256e040a74..2fd22f0cc88ec 100644
--- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py
+++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py
@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
pixel_values = preprocess_result["pixel_values"]
image_grid_thw = preprocess_result["image_grid_thw"]
- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
for image_batch in image_batches_:
cur_batch_image_count = len(image_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
- grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+ grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in image_grid_thw[image_counter:image_counter +
- cur_batch_image_count]
- ])
+ cur_batch_image_count])
result.append({
"image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]
- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
- grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+ grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[video_counter:video_counter +
- cur_batch_video_count]
- ])
+ cur_batch_video_count])
result.append({
"video_embeds":
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 938c838617e8b..cb0521cfe80a7 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -69,6 +69,7 @@ class _HfExamplesInfo:
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
+ "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index a06956ce18a93..272206d4502e9 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
-hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
\ No newline at end of file
+hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
+None, mgleize/fairseq2-dummy-Llama-3.2-1B, main
\ No newline at end of file
diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py
index 199731bdc21fe..7a3786456d0d6 100644
--- a/tests/weight_loading/test_weight_loading.py
+++ b/tests/weight_loading/test_weight_loading.py
@@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
- with vllm_runner(model_name=MODEL_NAME,
- revision=REVISION,
- dtype=torch.half if QUANTIZATION == "gptq" else "auto",
- quantization=QUANTIZATION,
- max_model_len=MAX_MODEL_LEN,
- tensor_parallel_size=2) as model:
+ with vllm_runner(
+ model_name=MODEL_NAME,
+ revision=REVISION,
+ dtype=torch.half if QUANTIZATION == "gptq" else "auto",
+ quantization=None if QUANTIZATION == "None" else QUANTIZATION,
+ max_model_len=MAX_MODEL_LEN,
+ tensor_parallel_size=2) as model:
output = model.generate_greedy("Hello world!", max_tokens=20)
print(output)
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 157e3f7f39c9c..d7f4dcb7a20fc 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -25,23 +25,30 @@
logger = init_logger(__name__)
+@dataclasses.dataclass
+class InductorArtifact:
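+ """
+ One Inductor compilation artifact: ``hash_str`` is the key used to look the
+ compiled graph up in Inductor's FX graph cache, and ``file_path`` is the
+ generated Python file holding the compiled callable (taken from
+ ``current_callable.__code__.co_filename``).
+ """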
+ hash_str: str = ""
+ file_path: str = ""
+
+
class InductorHashCache:
"""
Disk format: a Python list of tuples, each tuple is
- (runtime_shape, graph_index, hash_str)
+ (runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is
runtime_shape, and the value is a dict of graph_index to hash_str.
- The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+ The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
we don't use json here because json doesn't support int as key.
TODO: better off-the-shelf solution to serialize the data?
"""
def __init__(self, cache_dir: str, disabled: bool = False):
- self.cache: defaultdict = defaultdict(dict)
+ self.cache: Dict[Optional[int],
+ Dict[int, InductorArtifact]] = defaultdict(dict)
self.disabled = disabled
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir,
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
- for runtime_shape, graph_index, hash_str in list_data:
- self.cache[runtime_shape][graph_index] = hash_str
+ for item in list_data:
+ runtime_shape = item[0]
+ graph_index = item[1]
+ hash_str = item[2]
+ # for compatibility with the old cache format,
+ # which does not include file_path.
+ # NOTE: after running the new code, the file_path
+ # will be updated.
+ file_path = "" if len(item) == 3 else item[3]
+ self.cache[runtime_shape][graph_index] = InductorArtifact(
+ hash_str=hash_str, file_path=file_path)
def serialize(self) -> str:
data = []
- for runtime_shape, graph_index_to_hash_str in self.cache.items():
- for graph_index, hash_str in graph_index_to_hash_str.items():
- data.append((runtime_shape, graph_index, hash_str))
+ for runtime_shape, value in self.cache.items():
+ for graph_index, inductor_artifact in value.items():
+ data.append(
+ (runtime_shape, graph_index, inductor_artifact.hash_str,
+ inductor_artifact.file_path))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)
@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]
- def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+ def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
if self.disabled:
raise KeyError("cannot read from disabled cache")
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]
- def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+ def __setitem__(self, key: Tuple[Optional[int], int],
+ value: InductorArtifact):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
- hash_str = cache_data[(runtime_shape, graph_index)]
+ inductor_artifact = cache_data[(runtime_shape, graph_index)]
+ hash_str = inductor_artifact.hash_str
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
"Inductor cache lookup failed. Please remove"
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)
+ inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
# Inductor calling convention (function signature):
# f(list) -> tuple
@@ -224,19 +245,20 @@ def compiled_graph(*args):
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
- from torch._inductor.codecache import compiled_fx_graph_hash
+
+ inductor_artifact = InductorArtifact()
+ from torch._inductor.codecache import (FxGraphCache,
+ compiled_fx_graph_hash)
+ original_load = FxGraphCache.load
+
+ def hijack_load(*args, **kwargs):
+ inductor_compiled_graph = original_load(*args, **kwargs)
+ inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
+ return inductor_compiled_graph
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
- # store the hash in the cache
- nonlocal cache_data
- cache_data[(runtime_shape, graph_index)] = out[0]
- if graph_index == 0:
- # adds some info logging for the first graph
- logger.info("Cache the graph of shape %s for later use",
- str(runtime_shape))
- logger.debug("store the %s-th graph for shape %s via hash %s",
- graph_index, str(runtime_shape), out[0])
+ inductor_artifact.hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
if not cache_data.disabled:
# compilation cache is enabled, patch several functions
+ # hijack to get the compiled graph itself
+ stack.enter_context(
+ patch("torch._inductor.codecache.FxGraphCache.load",
+ hijack_load))
+
# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)
-
+ # store the inductor_artifact in the cache
+ cache_data[(runtime_shape, graph_index)] = inductor_artifact
+ if graph_index == 0:
+ # adds some info logging for the first graph
+ logger.info("Cache the graph of shape %s for later use",
+ str(runtime_shape))
+ logger.debug(
+ "store the %s-th graph for shape %s via hash %s from file %s",
+ graph_index, str(runtime_shape), inductor_artifact.hash_str,
+ inductor_artifact.file_path)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 10513111ea7f1..38f284794b8db 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- - if it is a single integer, the corresponding dimension of the argument
- will be marked as dynamic.
+ - if it is a single integer (can be negative), the corresponding dimension
+ of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
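+ - negative integers are normalized against each tensor's rank at runtime,
+ so ``-1`` marks the last dimension as dynamic (this PR uses it for Qwen2,
+ whose ``positions`` may be ``(seq_len, )`` or ``(3, seq_len)`` with mrope).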
@@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
+ dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
+ # In case dims is specified with negative indexing
+ dims = [
+ arg.ndim + dim if dim < 0 else dim for dim in dims
+ ]
torch._dynamo.mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
+ # In case dims is specified with negative indexing
+ dims = [
+ tensor.ndim + dim if dim < 0 else dim
+ for dim in dims
+ ]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(
diff --git a/vllm/config.py b/vllm/config.py
index ac5a4c91b1738..4698a05020332 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2862,17 +2862,8 @@ def model_post_init(self, __context: Any) -> None:
"vllm.unified_attention_with_output",
]
else:
- # v0 can use full graph compilation without splitting,
- # splitting is optional.
- # right now we still need it. kv cache shape
- # will be included in the graph if we don't split
- # the graph.
- # TODO: hide kv cache in static forward context
- # so that inductor does not see it.
- self.splitting_ops = [
- "vllm.unified_attention",
- "vllm.unified_attention_with_output",
- ]
+ # v0 uses full graph compilation
+ self.splitting_ops = []
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 4ace03ff1184e..7780e2dfa317d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -35,6 +35,7 @@ def __init__(
):
self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -161,7 +162,7 @@ def send_kv_caches_and_hidden_states(
end_layer = model_executable.model.end_layer
model_config = model_executable.model.config
- num_heads = model_config.num_key_value_heads
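+ # each tensor-parallel rank only holds num_key_value_heads / tp_size
+ # KV heads, so the per-rank KV cache slices are shaped accordingly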
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index f10a8fb8e03cf..2d8594cb8aafa 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -298,8 +298,11 @@ def __call__(self, input_ids: list[int],
# token_bitmask is a CPU tensor for use with accept_token and
# fill_next_token_bitmask so we move it to the device of scores
device_type = scores.device.type
+ dtype = scores.dtype
if device_type != "cuda":
- scores = scores.to("cpu").unsqueeze(0)
+ # xgrammar on cpu only supports float32 scores
+ # see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22
+ scores = scores.to("cpu").float().unsqueeze(0)
# Note: In this method, if the tensors have different dimensions
# on CPU device fails, but on GPU it runs without error. Hence the
@@ -307,7 +310,7 @@ def __call__(self, input_ids: list[int],
xgr.apply_token_bitmask_inplace(scores,
self.token_bitmask.to(scores.device))
if device_type != "cuda":
- scores = scores.to(device_type).squeeze()
+ scores = scores.to(dtype).to(device_type).squeeze()
return scores
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 00ae64bbe6388..52263e96fb9f9 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -344,11 +344,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
+ # bitsandbytes loads the weights of the specific portion
+ # no need to narrow
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
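+ # models can also set `is_sharded_weight` on the parameter themselves
+ # (e.g. the fairseq2 Llama loader added in this PR) when the checkpoint
+ # is already tp-sharded and must not be narrowed again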
param_data = param.data
- # bitsandbytes loads the weights of the specific portion
- # no need to narrow here
- if output_dim is not None and not use_bitsandbytes_4bit:
+ if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -546,6 +548,11 @@ def weight_loader(self,
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
+ # bitsandbytes loads the weights of the specific portion
+ # no need to narrow
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
@@ -554,9 +561,7 @@ def weight_loader(self,
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
- # bitsandbytes loads the weights of the specific portion
- # no need to narrow here
- if not use_bitsandbytes_4bit:
+ if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
@@ -941,6 +946,11 @@ def weight_loader(self,
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
+ # bitsandbytes loads the weights of the specific portion
+ # no need to narrow
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
@@ -964,9 +974,7 @@ def weight_loader(self,
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
- # bitsandbytes loads the weights of the specific portion
- # no need to narrow here
- if not use_bitsandbytes_4bit:
+ if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
@@ -1070,6 +1078,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
+ # bitsandbytes loads the weights of the specific portion
+ # no need to narrow
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1085,9 +1097,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
param_data = param.data
- # bitsandbytes loads the weights of the specific portion
- # no need to narrow here
- if input_dim is not None and not use_bitsandbytes_4bit:
+ if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 3fcd81a3c4213..d071cfe888f05 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -841,6 +841,37 @@ def get_input_positions(
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
+ llm_positions, mrope_position_delta = \
+ MRotaryEmbedding.get_input_positions_tensor(
+ input_tokens,
+ image_grid_thw,
+ video_grid_thw,
+ image_token_id,
+ video_token_id,
+ vision_start_token_id,
+ vision_end_token_id,
+ spatial_merge_size,
+ context_len,
+ seq_len,
+ )
+
+ return llm_positions.tolist(), mrope_position_delta
+
+ @staticmethod
+ def get_input_positions_tensor(
+ input_tokens: List[int],
+ image_grid_thw: Union[List[List[int]], torch.Tensor],
+ video_grid_thw: Union[List[List[int]], torch.Tensor],
+ image_token_id: int,
+ video_token_id: int,
+ vision_start_token_id: int,
+ vision_end_token_id: int,
+ spatial_merge_size: int,
+ context_len: int = 0,
+ seq_len: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, int]:
+ """Get mrope input positions and delta value."""
+
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ def get_input_positions(
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
- return llm_positions.tolist(), mrope_position_delta
+ return llm_positions, mrope_position_delta
@staticmethod
def get_next_input_positions(
@@ -930,6 +961,17 @@ def get_next_input_positions(
seq_len + mrope_position_delta)) for _ in range(3)
]
+ @staticmethod
+ def get_next_input_positions_tensor(
+ mrope_position_delta: int,
+ context_len: int,
+ seq_len: int,
+ ) -> torch.Tensor:
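+ # Tensor counterpart of get_next_input_positions(): produces a
+ # (3, seq_len - context_len) tensor with the same positions repeated
+ # across the three mrope sections.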
+ return torch.arange(
+ mrope_position_delta + context_len,
+ mrope_position_delta + seq_len,
+ ).expand(3, -1)
+
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 9fe0db62435a0..f697c3245f098 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -182,6 +182,9 @@ class Source:
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
+ allow_patterns_overrides: Optional[list[str]] = None
+ """If defined, weights will load exclusively using these patterns."""
+
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
@@ -218,6 +221,7 @@ def _prepare_weights(
model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool,
+ allow_patterns_overrides: Optional[list[str]],
) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
@@ -249,6 +253,9 @@ def _prepare_weights(
if fall_back_to_pt:
allow_patterns += ["*.pt"]
+ if allow_patterns_overrides is not None:
+ allow_patterns = allow_patterns_overrides
+
if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
@@ -298,7 +305,8 @@ def _get_weights_iterator(
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
- source.model_or_path, source.revision, source.fall_back_to_pt)
+ source.model_or_path, source.revision, source.fall_back_to_pt,
+ source.allow_patterns_overrides)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
@@ -340,6 +348,8 @@ def _get_all_weights(
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
True),
+ allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
+ None),
)
yield from self._get_weights_iterator(primary_weights)
@@ -353,7 +363,8 @@ def _get_all_weights(
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model,
model_config.revision,
- fall_back_to_pt=True)
+ fall_back_to_pt=True,
+ allow_patterns_overrides=None)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
@@ -1105,15 +1116,22 @@ def _load_weights(self, model_config: ModelConfig,
weight_name,
index,
) in self.modules_mapping.inverse_packed_mapping.items():
- shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
- if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
+ shard_pos = quant_param_name.find(shard_name)
+ can_correct_rename = (shard_pos > 0) and (
+ quant_param_name[shard_pos - 1] == ".")
+ # If the quant_param_name is packed, it won't occur in the
+ # param_dict before renaming.
+ new_quant_param_name = quant_param_name.replace(
+ shard_name, weight_name)
+ need_rename = (quant_param_name not in param_dict) \
+ and (new_quant_param_name in param_dict)
+ if can_correct_rename and need_rename:
shard_index = index
- quant_param_name = quant_param_name.replace(
- shard_name, weight_name)
+ quant_param_name = new_quant_param_name
break
# Models like Clip/Siglip may skip some layers in initialization,
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 7e37ce3086e6b..d5f9b4d19e5ca 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -41,7 +41,7 @@
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -605,9 +605,50 @@ def forward(
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
+ ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: Set[str] = set()
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ if "rotary_pos_emb.inv_freq" in name:
+ continue
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_substr={".word_embeddings": ""}, )
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -660,52 +701,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self, weights: Iterable[Tuple[str,
- torch.Tensor]]) -> Set[str]:
- # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
- merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
- "transformer.vision.linear_proj.merged_proj.weight": {
- "transformer.vision.linear_proj.gate_proj.weight": None,
- "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
- }
- }
-
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded_params: Set[str] = set()
- for name, loaded_weight in weights:
- is_weight_to_be_merge = False
- for _, merged_weight_dict in merged_weights_dict.items():
- if name in merged_weight_dict:
- assert merged_weight_dict[name] is None
- merged_weight_dict[name] = loaded_weight
- is_weight_to_be_merge = True
- if is_weight_to_be_merge:
- continue
- if "rotary_pos_emb.inv_freq" in name:
- continue
- if "word_embeddings" in name:
- name = name.replace(".word_embeddings", "")
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- if is_pp_missing_parameter(name, self):
- continue
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
-
- for combined_name, merged_weight_dict in merged_weights_dict.items():
- if combined_name in params_dict:
- param = params_dict[combined_name]
- combined_weight = torch.cat(list(merged_weight_dict.values()),
- dim=0)
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
- weight_loader(param, combined_weight)
- loaded_params.add(combined_name)
- return loaded_params
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
class ChatGLM(ChatGLMBaseModel):
@@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
+
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
@@ -777,7 +776,7 @@ def __new__(
) -> None:
config = vllm_config.model_config.hf_config
# Initialize VL
- if hasattr(config, "visual"):
+ if hasattr(config, "vision_config"):
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
# Initialize LLM
else:
diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py
new file mode 100644
index 0000000000000..b93a68680375d
--- /dev/null
+++ b/vllm/model_executor/models/fairseq2_llama.py
@@ -0,0 +1,151 @@
+# Copyright 2024 The vLLM team.
+# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Llama model for fairseq2 weights."""
+
+from typing import Iterable, Set, Tuple
+
+import torch
+from torch.nn import Parameter
+
+from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.linear import set_weight_attrs
+from vllm.model_executor.models.llama import LlamaForCausalLM
+
+from .utils import AutoWeightsLoader, WeightsMapper
+
+
+class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ # For the model loader to read only the relevant checkpoint files
+ self.allow_patterns_overrides = [
+ # either the full checkpoint
+ "model.pt",
+ # or the tp-sharded checkpoint of the current rank
+ f"model.{self.tp_rank}.pt",
+ ]
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ # fairseq2's serialization adds a wrapper to the usual .pt state_dict:
+ # { "model_key": my_model_name, "my_model_name": state_dict }
+ # which we first need to unpack
+ weights_wrapped = dict(weights)
+ weights = weights_wrapped[
+ weights_wrapped["model_key"]].items() # type: ignore
+
+ # remap keys
+ fs2_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "decoder_frontend.embed.": "model.embed_tokens.",
+ "decoder.": "model.",
+ "final_proj.": "lm_head.",
+ },
+ orig_to_new_substr={
+ ".self_attn_layer_norm.": ".input_layernorm.",
+ ".ffn_layer_norm.": ".post_attention_layernorm.",
+ ".self_attn.output_proj.": ".self_attn.o_proj.",
+ ".ffn.gate_proj.": ".mlp.gate_proj.",
+ ".ffn.inner_proj.": ".mlp.up_proj.",
+ ".ffn.output_proj.": ".mlp.down_proj.",
+ ".layer_norm.": ".norm.",
+ },
+ )
+ weights = fs2_to_vllm_mapper.apply(weights)
+
+ params = dict(self.named_parameters())
+
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."]
+ if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(
+ (self.reshape_fairseq2_weights(name, loaded_weight, params)
+ for name, loaded_weight in weights))
+
+ def flag_sharded_weights(self, params: dict[str, Parameter]):
+ """Sets the `is_sharded_weight` flag to True for all sharded weights"""
+ for name, param in params.items():
+ modules = name.split(".")
+ if "norm" in name and len(param.size()) < 2:
+ # layer norms are not sharded
+ continue
+ elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
+ # for now we repeat embedding layers for compatibility
+ continue
+ else:
+ # all other layers are sharded
+ set_weight_attrs(param, {"is_sharded_weight": True})
+
+ def reshape_fairseq2_weights(
+ self,
+ name: str,
+ loaded_weight: torch.Tensor,
+ params: dict[str, Parameter],
+ ) -> Tuple[str, torch.Tensor]:
+ """Reshape fairseq2's weights."""
+
+ def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
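+ # Reorder q/k projection rows from fairseq2's interleaved rotary
+ # layout to the half-split layout vLLM's rotary embedding expects
+ # (the same permutation as HF's Llama conversion script).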
+ attn_in = self.config.head_dim * n_heads
+ # check for a sharded weight on dim 0
+ if attn_in // self.tp_size == w.size()[0]:
+ attn_in //= self.tp_size
+ n_heads //= self.tp_size
+ attn_out = self.config.hidden_size
+ return (w.view(n_heads, attn_in // n_heads // 2, 2,
+ attn_out).transpose(1,
+ 2).reshape(attn_in, attn_out))
+
+ modules = name.split(".")
+
+ # q/k projections are permuted for rotary embeddings (see permute above)
+ if "k_proj" in modules:
+ loaded_weight = permute(loaded_weight,
+ self.config.num_key_value_heads)
+
+ elif "q_proj" in modules:
+ loaded_weight = permute(loaded_weight,
+ self.config.num_attention_heads)
+
+ # We make the loaded weights compatible with both
+ # full checkpoints and tp sharded checkpoints.
+ # Embeddings are repeated to fit the vocab size.
+ # Other weights are flagged for the weight_loader calls.
+ if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
+ # Embeddings are sharded on dim 0
+ dim = 0
+ # In fairseq2, vocab size has to be divisible by tp_size
+ # so we don't worry about padding
+ if self.tp_size > 1 and loaded_weight.shape[
+ dim] < self.config.vocab_size:
+ assert loaded_weight.shape[
+ dim] * self.tp_size == self.config.vocab_size, \
+ "vocab_size should be divisible by tp_size."
+ repeats = [1] * len(loaded_weight.size())
+ repeats[dim] = self.tp_size
+ # repeat to match vocab size and to be easily 'narrow'able
+ loaded_weight = loaded_weight.repeat(repeats)
+ set_weight_attrs(params[name], {"is_sharded_weight": False})
+ # if embeddings are sharded, the rest is too
+ if "embed_tokens" in modules:
+ self.flag_sharded_weights(params)
+
+ return name, loaded_weight
diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py
index 39a5736eb199b..51922e6f2d03d 100644
--- a/vllm/model_executor/models/glm4_vision_encoder.py
+++ b/vllm/model_executor/models/glm4_vision_encoder.py
@@ -42,7 +42,8 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
torch.Tensor
Transformed tensor with shape (B, L, D)
"""
- images = images.to(self.proj.weight.device)
+ images = images.to(device=self.proj.weight.device,
+ dtype=self.proj.weight.dtype)
x = self.proj(images)
x = x.flatten(2).transpose(1, 2)
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 722fff98d5c19..6cceded43a79d 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -5,9 +5,11 @@
import torch
import torch.nn as nn
+from packaging.version import Version
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
SiglipVisionConfig)
+from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
@@ -716,6 +718,27 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loader.load_weights(weights)
+class MantisProcessingInfo(LlavaProcessingInfo):
+
+ def get_hf_processor(self):
+ hf_config = self.get_hf_config()
+ vision_info = self.get_vision_encoder_info()
+
+ if Version(TRANSFORMERS_VERSION) < Version("4.48"):
+ # BUG: num_additional_image_tokens = 0 but treated as 1,
+ # so we set vision_feature_select_strategy to None to offset this
+ vision_feature_select_strategy = None
+ else:
+ # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
+ vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501
+
+ return self.ctx.get_hf_processor(
+ LlavaProcessor,
+ patch_size=vision_info.get_patch_size(),
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply(
@@ -794,7 +817,7 @@ def get_replacement_mantis(item_idx: int):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
- info=LlavaProcessingInfo,
+ info=MantisProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index c9283e0c5ba20..6faa79f65d8de 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
- if input_key == "pixel_values" and "images" not in modalities:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
- if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index d015f60c6d065..82de1c3574090 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -256,7 +256,15 @@ def forward(
return hidden_states, residual
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ })
class Qwen2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 0dff9595c6c08..47d56175261e4 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -36,8 +36,9 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
- NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalInputsV2, MultiModalKwargs,
+ NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -153,29 +154,24 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, Any],
) -> BatchFeature:
- mm_data = dict(mm_data)
- audios = mm_data.pop("audios", [])
-
- if audios:
- mm_data["audios"] = audios
-
- feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
- mm_kwargs = dict(
- **mm_kwargs,
- sampling_rate=feature_extractor.sampling_rate,
- )
- else:
- # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
- pass
+ # Text-only input not supported in composite processor
+ if not mm_data or not mm_data.get("audios", []):
+ prompt_ids = self.info.get_tokenizer().encode(prompt)
+ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
+ mm_kwargs = dict(
+ **mm_kwargs,
+ sampling_rate=feature_extractor.sampling_rate,
+ )
- processed_outputs = super()._call_hf_processor(
+ return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
- return processed_outputs
-
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -192,8 +188,14 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
- hf_config = self.info.get_hf_config()
- placeholder = hf_config.audio_token_index
+ processor = self.info.get_hf_processor()
+
+ # Use getattr with default to be compatible with transformers<4.48
+ audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
+ audio_bos_token = getattr(processor, "audio_bos_token",
+ "<|audio_bos|>")
+ audio_eos_token = getattr(processor, "audio_eos_token",
+ "<|audio_eos|>")
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
if feature_attention_mask is None:
@@ -214,12 +216,16 @@ def get_replacement_qwen2_audio(item_idx: int):
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model")
- return [placeholder] * num_placeholders
+ return "".join([
+ audio_bos_token,
+ audio_token * num_placeholders,
+ audio_eos_token,
+ ])
return [
PromptReplacement(
modality="audio",
- target=[placeholder],
+ target=audio_token,
replacement=get_replacement_qwen2_audio,
)
]
@@ -234,6 +240,26 @@ def _always_apply_prompt_replacements(self) -> bool:
# tokens than the number of audio items)
return not hasattr(self.info.get_hf_processor(), "audio_token")
+ def apply(
+ self,
+ prompt: Union[str, list[int]],
+ mm_data: MultiModalDataDict,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> MultiModalInputsV2:
+ result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+
+ # Only <|AUDIO|> tokens should be considered as placeholders,
+ # so we ignore the audio_bos_token and audio_eos_token
+ result["mm_placeholders"] = {
+ modality: [
+ PlaceholderRange(offset=p["offset"] + 1,
+ length=p["length"] - 2) for p in ps
+ ]
+ for modality, ps in result["mm_placeholders"].items()
+ }
+
+ return result
+
@MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index d00e5d362c8bc..34d5c8ad089a3 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -67,11 +67,15 @@
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper,
- init_vllm_registered_model, maybe_prefix)
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
+# Maximum number of frames per video for the profile run
+_MAX_FRAMES_PER_VIDEO = 16
+
# === Vision Inputs === #
@@ -135,7 +139,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict):
- List[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
- (concatenation of all videos' feature tensors).
+ (concatenation of all videos' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
@@ -611,6 +615,7 @@ def forward(
# adapter
x = self.merger(x)
+
return x
def load_weights(self, weights: Iterable[Tuple[str,
@@ -874,8 +879,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int:
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
-
- num_frames = max(max_total_frames // max(max_videos, 1), 1)
+ num_frames = min(max(max_total_frames // max(max_videos, 1), 1),
+ _MAX_FRAMES_PER_VIDEO)
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
if num_frames > 1 and num_frames % 2 == 1:
@@ -955,13 +960,14 @@ def _get_prompt_replacements(
"image": hf_processor.image_token,
"video": hf_processor.video_token,
}
+
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)
- num_tokens = grid_thw.prod() // merge_length
+ num_tokens = grid_thw.prod().item() // merge_length
return placeholder[modality] * num_tokens
return [
@@ -1047,11 +1053,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2VLConfig = vllm_config.model_config.hf_config
- cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
- assert not cache_config.enable_prefix_caching, \
- "Qwen2-VL currently does not support prefix caching"
self.config = config
self.multimodal_config = multimodal_config
@@ -1173,59 +1176,82 @@ def _parse_and_validate_video_input(
video_embeds=video_embeds,
video_grid_thw=video_grid_thw)
- def _process_image_input(self,
- image_input: Qwen2VLImageInputs) -> torch.Tensor:
+ def _process_image_input(
+ self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+
if image_input["type"] == "image_embeds":
- return image_input["image_embeds"].type(self.visual.dtype)
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+ # Split concatenated embeddings for each image item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
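+ # e.g. a single image with grid_thw [1, 16, 16] and merge_size 2
+ # contributes 1 * 16 * 16 // 4 = 64 embeddings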
+
+ return image_embeds.split(sizes.tolist())
+
+ def _process_video_input(
+ self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
- pixel_values = image_input["pixel_values"].type(self.visual.dtype)
- image_embeds = self.visual(pixel_values,
- grid_thw=image_input["image_grid_thw"])
- return image_embeds
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
- def _process_video_input(self,
- video_input: Qwen2VLVideoInputs) -> torch.Tensor:
if video_input["type"] == "video_embeds":
- return video_input["video_embeds"].type(self.visual.dtype)
+ video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
- pixel_values_videos = video_input["pixel_values_videos"].type(
- self.visual.dtype)
- video_embeds = self.visual(pixel_values_videos,
- grid_thw=video_input["video_grid_thw"])
- return video_embeds
+ # Split concatenated embeddings for each video item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
- def _merge_multimodal_embeddings(
- self,
- input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor,
- multimodal_embeddings: torch.Tensor,
- placeholder_token_id: int,
- ) -> torch.Tensor:
- mask = (input_ids == placeholder_token_id)
- inputs_embeds[mask, :] = multimodal_embeddings
- return inputs_embeds
+ return video_embeds.split(sizes.tolist())
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
+ modalities["videos"] = self._parse_and_validate_video_input(
+ **kwargs)
+
+ return modalities
def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
- image_input = self._parse_and_validate_image_input(**kwargs)
- video_input = self._parse_and_validate_video_input(**kwargs)
- if image_input is None and video_input is None:
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
return None
- # We make a tuple of each embedding with its modality string. This is a
- # temporary workaround for models to handle mixed modalities when
- # get_multimodal_embeddings and get_input_embeddings are called
- # separately.
- # TODO(ywang96): Add support for mixed-modality inference for v1.
- multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
-
- if image_input is not None:
- image_embeds = self._process_image_input(image_input)
- multimodal_embeddings.append((image_embeds, "image"))
- if video_input is not None:
- video_embeds = self._process_video_input(video_input)
- multimodal_embeddings.append((video_embeds, "video"))
+        # The resulting multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "videos":
+ video_input = modalities["videos"]
+ video_embeddings = self._process_video_input(video_input)
+ multimodal_embeddings += video_embeddings
return multimodal_embeddings
@@ -1237,21 +1263,9 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
- for embeddings, modality in multimodal_embeddings:
- if modality == "image":
- inputs_embeds = self._merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- embeddings,
- placeholder_token_id=self.config.image_token_id,
- )
- if modality == "video":
- inputs_embeds = self._merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- embeddings,
- placeholder_token_id=self.config.video_token_id,
- )
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ [self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def forward(
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index a71f7f7029c7d..311f91472783b 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -47,6 +47,7 @@
"DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+ "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index b1ac7c92a0be9..33517b0ce2a44 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -138,7 +138,7 @@ def _call_hf_processor(
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only input not supported in composite processor
- if not mm_data:
+ if not mm_data or not mm_data.get("audios", []):
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@@ -147,13 +147,6 @@ def _call_hf_processor(
audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
- if not audios:
- return super()._call_hf_processor(
- prompt=prompt,
- mm_data=mm_data,
- mm_kwargs=mm_kwargs,
- )
-
feature_extractor = self.info.get_feature_extractor()
mm_kwargs = dict(
**mm_kwargs,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index c97acffa1a719..f57dfded0a62f 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -22,10 +22,10 @@
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
-from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
- DbrxConfig, DeepseekVLV2Config,
- EAGLEConfig, ExaoneConfig,
- H2OVLChatConfig,
+from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
+ Cohere2Config, DbrxConfig,
+ DeepseekVLV2Config, EAGLEConfig,
+ ExaoneConfig, H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
@@ -52,6 +52,7 @@
}
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+ "aria": AriaConfig,
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index f065c56124605..807ef4fbfd0c0 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -1,3 +1,4 @@
+from vllm.transformers_utils.configs.aria import AriaConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -23,6 +24,7 @@
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
+ "AriaConfig",
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",
diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py
index d253da0d96a34..f4b531225b5d0 100644
--- a/vllm/transformers_utils/configs/aria.py
+++ b/vllm/transformers_utils/configs/aria.py
@@ -1,7 +1,32 @@
+# Copyright 2024 Rhymes AI. All rights reserved.
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Mapping
+
+from transformers import PretrainedConfig
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig)
from transformers.models.llama.configuration_llama import LlamaConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
class AriaVisionConfig(Idefics2VisionConfig):
model_type = "aria_vision_model"
@@ -45,3 +70,96 @@ def __init__(
self.moe_num_experts = moe_num_experts
self.moe_topk = moe_topk
self.moe_num_shared_experts = moe_num_shared_experts
+
+
+class AriaConfig(PretrainedConfig):
+ """
+ Configuration class for the Aria model.
+
+ This class handles the configuration for both the vision and text
+ components of the Aria model, as well as additional parameters for
+ image token handling and projector mapping.
+
+ Args:
+ vision_config (AriaVisionConfig or dict): Configuration for the vision
+ component.
+ text_config (AriaMoELMConfig or dict): Configuration for the text
+ component.
+ projector_patch_to_query_dict (dict): Mapping of patch sizes to query
+ dimensions.
+ ignore_index (int): Index to ignore in loss calculation.
+ image_token_index (int): Index used to represent image tokens.
+ **kwargs: Additional keyword arguments passed to the parent class.
+
+ Attributes:
+ model_type (str): Type of the model, set to "aria".
+ is_composition (bool): Whether the model is a composition of multiple
+ components.
+ ignore_index (int): Index to ignore in loss calculation.
+ image_token_index (int): Index used to represent image tokens.
+ projector_patch_to_query_dict (dict): Mapping of patch sizes to query
+ dimensions.
+ vision_config (AriaVisionConfig): Configuration for the vision
+ component.
+ text_config (AriaMoELMConfig): Configuration for the text component.
+ """
+
+ model_type = "aria"
+ is_composition = False
+
+ def __init__(
+ self,
+ vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008
+ text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008
+ projector_patch_to_query_dict: Mapping[int, int] = {
+ 1225: 128,
+ 4900: 256,
+ },
+ ignore_index=-100,
+ image_token_index=32000,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.ignore_index = ignore_index
+ self.image_token_index = image_token_index
+ self.tie_word_embeddings = tie_word_embeddings
+ attn_implementation = kwargs.pop("attn_implementation", None)
+
+ # Set the default attention implementation to flash_attention_2 if not
+ # specified
+ self._attn_implementation = ("flash_attention_2"
+ if attn_implementation is None else
+ attn_implementation)
+
+ # Convert the keys and values of projector_patch_to_query_dict to
+ # integers. This ensures consistency even if they were provided as
+ # strings.
+ self.projector_patch_to_query_dict = {
+ int(k): int(v)
+ for k, v in projector_patch_to_query_dict.items()
+ }
+
+ if isinstance(vision_config, dict) and "model_type" in vision_config:
+ vision_config = AriaVisionConfig(**vision_config)
+ if attn_implementation is None:
+ vision_attn_implementation = "flash_attention_2"
+ elif attn_implementation == "sdpa":
+ logger.warning("SDPA is not supported for vit, using "
+ "flash_attention_2 instead")
+ vision_attn_implementation = "flash_attention_2"
+ else:
+ vision_attn_implementation = attn_implementation
+ vision_config._attn_implementation = vision_attn_implementation
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict) and "model_type" in text_config:
+ text_attn_implementation = ("sdpa" if attn_implementation is None
+ else attn_implementation)
+ text_config = AriaMoELMConfig(**text_config)
+ text_config._attn_implementation = text_attn_implementation
+
+ self.text_config = text_config
+
+ # This is needed for the static kv cache
+ self.num_hidden_layers = self.text_config.num_hidden_layers
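
A quick, illustrative usage of the new config class, showing the key/value normalization and the default attention backend described in the comments above (assumes the Aria config module is importable as added in this diff):

from vllm.transformers_utils.configs.aria import AriaConfig

cfg = AriaConfig(
    # String keys/values on purpose: they get normalized to ints.
    projector_patch_to_query_dict={"1225": "128", "4900": "256"})

print(cfg.projector_patch_to_query_dict)  # {1225: 128, 4900: 256}
print(cfg._attn_implementation)           # flash_attention_2 (default)
print(cfg.image_token_index)              # 32000
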
diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index a9deee881f41a..841df3994fba2 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -27,6 +27,17 @@
_GLOBAL_RUNTIME_DATA: Dict[str, Union[str, int, bool]] = {}
+_USAGE_ENV_VARS_TO_COLLECT = [
+ "VLLM_USE_MODELSCOPE",
+ "VLLM_USE_TRITON_FLASH_ATTN",
+ "VLLM_ATTENTION_BACKEND",
+ "VLLM_USE_FLASHINFER_SAMPLER",
+ "VLLM_PP_LAYER_PARTITION",
+ "VLLM_USE_TRITON_AWQ",
+ "VLLM_USE_V1",
+ "VLLM_ENABLE_V1_MULTIPROCESSING",
+]
+
def set_runtime_usage_data(key: str, value: Union[str, int, bool]) -> None:
"""Set global usage data that will be sent with every usage heartbeat."""
@@ -122,6 +133,7 @@ def __init__(self) -> None:
self.gpu_count: Optional[int] = None
self.gpu_type: Optional[str] = None
self.gpu_memory_per_device: Optional[int] = None
+ self.env_var_json: Optional[str] = None
# vLLM Information
self.model_architecture: Optional[str] = None
@@ -176,6 +188,12 @@ def _report_usage_once(self, model_architecture: str,
self.vllm_version = VLLM_VERSION
self.model_architecture = model_architecture
+ # Environment variables
+ self.env_var_json = json.dumps({
+ env_var: getattr(envs, env_var)
+ for env_var in _USAGE_ENV_VARS_TO_COLLECT
+ })
+
# Metadata
self.log_time = _get_current_timestamp_ns()
self.source = envs.VLLM_USAGE_SOURCE
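
The heartbeat now also records a JSON snapshot of a fixed allow-list of vLLM environment variables. A self-contained sketch of the collection pattern (using a stand-in namespace instead of vllm.envs; the values shown are illustrative, not defaults):

import json
from types import SimpleNamespace

# Stand-in for vllm.envs: each listed variable is read via getattr().
envs = SimpleNamespace(VLLM_USE_MODELSCOPE=False, VLLM_USE_V1=True)
_USAGE_ENV_VARS_TO_COLLECT = ["VLLM_USE_MODELSCOPE", "VLLM_USE_V1"]

env_var_json = json.dumps({
    env_var: getattr(envs, env_var)
    for env_var in _USAGE_ENV_VARS_TO_COLLECT
})
print(env_var_json)  # {"VLLM_USE_MODELSCOPE": false, "VLLM_USE_V1": true}
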
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 40494e64b22f0..28d8e39053874 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -30,6 +30,9 @@ class CachedRequestState:
num_computed_tokens: int
output_token_ids: List[int]
+ mrope_positions: Optional[torch.Tensor] = None
+ mrope_position_delta: Optional[int] = None
+
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index aa63d9414c296..87a1cd7f9e627 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -14,6 +14,7 @@
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
@@ -139,6 +140,32 @@ def __init__(
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
+
+ # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.model_config.uses_mrope:
+ # NOTE: `mrope_positions` is implemented as a permuted tensor to
+ # satisfy the following properties to allow `torch.compile` to work
+ # properly:
+ # - shape: (3, <variable>)
+ # - stride: (1, 3)
+ # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
+
+ # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+ # the modality of inputs. For text-only inputs, each dimension has
+ # identical position IDs, making M-RoPE functionally equivalent to
+ # 1D-RoPE.
+ # See page 5 of https://arxiv.org/abs/2409.12191
+ self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
+ dtype=torch.int64,
+ device=self.device)
+ self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory)
+
+ self.mrope_positions = self.mrope_positions.permute((1, 0))
+ self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
+
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
@@ -246,6 +273,35 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
+
+ # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.model_config.uses_mrope:
+ image_grid_thw = []
+ video_grid_thw = []
+ for mm_input in self.requests[req_id].mm_inputs:
+ if mm_input.get("image_grid_thw") is not None:
+ image_grid_thw.extend(
+ mm_input["image_grid_thw"].tolist())
+ if mm_input.get("video_grid_thw") is not None:
+ video_grid_thw.extend(
+ mm_input["video_grid_thw"].tolist())
+
+ hf_config = self.model_config.hf_config
+
+ self.requests[req_id].mrope_positions, \
+ self.requests[req_id].mrope_position_delta = \
+ MRotaryEmbedding.get_input_positions_tensor(
+ self.requests[req_id].prompt_token_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ image_token_id=hf_config.image_token_id,
+ video_token_id=hf_config.video_token_id,
+ vision_start_token_id=hf_config.vision_start_token_id,
+ vision_end_token_id=hf_config.vision_end_token_id,
+ spatial_merge_size=hf_config.vision_config.
+ spatial_merge_size,
+ )
+
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
@@ -313,6 +369,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
arange,
out=positions_np)
+ # Calculate M-RoPE positions.
+ # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.model_config.uses_mrope:
+ self._calc_mrope_positions(scheduler_output)
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -359,8 +420,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
- self.positions[:total_num_scheduled_tokens].copy_(
- self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+ if self.model_config.uses_mrope:
+ # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+ self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True)
+ else:
+ # Common case (1D positions)
+ self.positions[:total_num_scheduled_tokens].copy_(
+ self.positions_cpu[:total_num_scheduled_tokens],
+ non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
@@ -472,6 +541,61 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
+ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+ mrope_pos_ptr = 0
+ num_reqs = self.input_batch.num_reqs
+ for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+ assert req_id is not None
+
+ req = self.requests[req_id]
+ assert req.mrope_positions is not None
+
+ num_computed_tokens = \
+ self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = \
+ scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = len(req.prompt_token_ids)
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0,
+ num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(
+ 0, num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+ # prompt's mrope_positions are pre-computed
+ dst_start = mrope_pos_ptr
+ dst_end = mrope_pos_ptr + prompt_part_len
+ src_start = num_computed_tokens
+ src_end = num_computed_tokens + prompt_part_len
+
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
+ req.mrope_positions[:, src_start:src_end]
+
+ mrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's mrope_positions on-the-fly
+ dst_start = mrope_pos_ptr
+ dst_end = mrope_pos_ptr + completion_part_len
+
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
+ MRotaryEmbedding.get_next_input_positions_tensor(
+ req.mrope_position_delta,
+ context_len=num_computed_tokens +
+ prompt_part_len,
+ seq_len=num_computed_tokens +
+ prompt_part_len +
+ completion_part_len,
+ )
+
+ mrope_pos_ptr += completion_part_len
+
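
A worked example of the prompt/completion split in _calc_mrope_positions (hypothetical numbers): with a 10-token prompt, 8 tokens already computed, and 5 tokens scheduled this step, 2 positions are copied from the precomputed prompt M-RoPE tensor and 3 are generated from mrope_position_delta:

num_prompt_tokens = 10
num_computed_tokens = 8
num_scheduled_tokens = 5

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
    prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)     # 2
    completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)  # 3
else:
    prompt_part_len = num_scheduled_tokens
    completion_part_len = 0

assert (prompt_part_len, completion_part_len) == (2, 3)
# The prompt part copies req.mrope_positions[:, 8:10]; the completion part is
# built on the fly for context_len=10 .. seq_len=13 via the rotary helper.
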
def _prepare_sampling(
self,
scheduler_output: "SchedulerOutput",
@@ -618,9 +742,12 @@ def execute_model(
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
+ positions = self.mrope_positions[:, :num_input_tokens] \
+ if self.model_config.uses_mrope \
+ else self.positions[:num_input_tokens]
hidden_states = self.model(
input_ids=input_ids,
- positions=self.positions[:num_input_tokens],
+ positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=inputs_embeds,
@@ -707,9 +834,12 @@ def _dummy_run(
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
with set_forward_context(None, self.vllm_config):
+ positions = self.mrope_positions[:, :num_tokens] \
+ if self.model_config.uses_mrope \
+ else self.positions[:num_tokens]
hidden_states = model(
input_ids=input_ids,
- positions=self.positions[:num_tokens],
+ positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
inputs_embeds=inputs_embeds,