diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 9d71e4ecc4a37..a9ab4fc9b621e 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -22,6 +22,7 @@ class RequestFuncInput: prompt_len: int output_len: int model: str + model_name: Optional[str] = None best_of: int = 1 logprobs: Optional[int] = None extra_body: Optional[dict] = None @@ -78,7 +79,7 @@ async def async_request_tgi( continue chunk_bytes = chunk_bytes.decode("utf-8") - #NOTE: Sometimes TGI returns a ping response without + # NOTE: Sometimes TGI returns a ping response without # any data, we should skip it. if chunk_bytes.startswith(":"): continue @@ -235,7 +236,8 @@ async def async_request_openai_completions( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "best_of": request_func_input.best_of, @@ -328,7 +330,8 @@ async def async_request_openai_chat_completions( if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "messages": [ { "role": "user", diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 4eb0e1f8ac903..53186e10c5452 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -525,6 +525,7 @@ async def benchmark( api_url: str, base_url: str, model_id: str, + model_name: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], logprobs: Optional[int], @@ -553,6 +554,7 @@ async def benchmark( "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, + model_name=model_name, prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, @@ -573,6 +575,7 @@ async def benchmark( if profile: print("Starting profiler...") profile_input = RequestFuncInput(model=model_id, + model_name=model_name, prompt=test_prompt, api_url=base_url + "/start_profile", prompt_len=test_prompt_len, @@ -616,6 +619,7 @@ async def limited_request_func(request_func_input, pbar): async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput(model=model_id, + model_name=model_name, prompt=prompt, api_url=api_url, prompt_len=prompt_len, @@ -780,6 +784,7 @@ def main(args: argparse.Namespace): backend = args.backend model_id = args.model + model_name = args.served_model_name tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_mode = args.tokenizer_mode @@ -877,6 +882,7 @@ def main(args: argparse.Namespace): api_url=api_url, base_url=base_url, model_id=model_id, + model_name=model_name, tokenizer=tokenizer, input_requests=input_requests, logprobs=args.logprobs, @@ -1222,5 +1228,12 @@ def main(args: argparse.Namespace): 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + args = parser.parse_args() main(args) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 4a8ba2a01a8de..1ef384987bca3 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ - `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ - - + - ✅︎ * - `UltravoxModel` - Ultravox - T + AE+ diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 16e256e040a74..2fd22f0cc88ec 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -105,7 +105,7 @@ def batch_make_image_embeddings( pixel_values = preprocess_result["pixel_values"] image_grid_thw = preprocess_result["image_grid_thw"] - # pixel values to embeddinds & grid_thws + # pixel values to embeddings & grid_thws with torch.no_grad(): visual = llm.llm_engine.model_executor.driver_worker. \ model_runner.model.visual @@ -124,11 +124,10 @@ def batch_make_image_embeddings( for image_batch in image_batches_: cur_batch_image_count = len(image_batch) merge_size = image_processor.merge_size - cur_batch_embed_len = sum([ - grid_thw.prod() // merge_size // merge_size + cur_batch_embed_len = sum( + grid_thw.prod(-1) // merge_size // merge_size for grid_thw in image_grid_thw[image_counter:image_counter + - cur_batch_image_count] - ]) + cur_batch_image_count]) result.append({ "image_embeds": @@ -187,7 +186,7 @@ def batch_make_video_embeddings( pixel_values = preprocess_result["pixel_values_videos"] video_grid_thw = preprocess_result["video_grid_thw"] - # pixel values to embeddinds & grid_thws + # pixel values to embeddings & grid_thws with torch.no_grad(): visual = llm.llm_engine.model_executor.driver_worker.\ model_runner.model.visual @@ -206,11 +205,10 @@ def batch_make_video_embeddings( for video_batch in video_batches_: cur_batch_video_count = len(video_batch) merge_size = image_processor.merge_size - cur_batch_embed_len = sum([ - grid_thw.prod() // merge_size // merge_size + cur_batch_embed_len = sum( + grid_thw.prod(-1) // merge_size // merge_size for grid_thw in video_grid_thw[video_counter:video_counter + - cur_batch_video_count] - ]) + cur_batch_video_count]) result.append({ "video_embeds": diff --git a/tests/models/registry.py b/tests/models/registry.py index 938c838617e8b..cb0521cfe80a7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -69,6 +69,7 @@ class _HfExamplesInfo: "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 trust_remote_code=True), "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 + "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index a06956ce18a93..272206d4502e9 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main qqq, HandH1998/QQQ-Llama-3-8b-g128, main qqq, HandH1998/QQQ-Llama-3-8b, main -hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main \ No newline at end of file +hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main +None, mgleize/fairseq2-dummy-Llama-3.2-1B, main \ No newline at end of file diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py index 199731bdc21fe..7a3786456d0d6 100644 --- a/tests/weight_loading/test_weight_loading.py +++ b/tests/weight_loading/test_weight_loading.py @@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner): """ Test parameter weight loading with tp>1. """ - with vllm_runner(model_name=MODEL_NAME, - revision=REVISION, - dtype=torch.half if QUANTIZATION == "gptq" else "auto", - quantization=QUANTIZATION, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=2) as model: + with vllm_runner( + model_name=MODEL_NAME, + revision=REVISION, + dtype=torch.half if QUANTIZATION == "gptq" else "auto", + quantization=None if QUANTIZATION == "None" else QUANTIZATION, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=2) as model: output = model.generate_greedy("Hello world!", max_tokens=20) print(output) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 157e3f7f39c9c..d7f4dcb7a20fc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -25,23 +25,30 @@ logger = init_logger(__name__) +@dataclasses.dataclass +class InductorArtifact: + hash_str: str = "" + file_path: str = "" + + class InductorHashCache: """ Disk format: a Python list of tuples, each tuple is - (runtime_shape, graph_index, hash_str) + (runtime_shape, graph_index, hash_str, file_path) We use list of tuple for readability. In-memory format: a defaultdict of dict, where the key is runtime_shape, and the value is a dict of graph_index to hash_str. - The data is essentially `Dict[Optional[int], Dict[int, str]]`, + The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`, we don't use json here because json doesn't support int as key. TODO: better off-the-shelf solution to serialize the data? """ def __init__(self, cache_dir: str, disabled: bool = False): - self.cache: defaultdict = defaultdict(dict) + self.cache: Dict[Optional[int], + Dict[int, InductorArtifact]] = defaultdict(dict) self.disabled = disabled self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, @@ -66,14 +73,25 @@ def deserialize(self, data: str): # because it is a safe way to parse Python literals. # do not use eval(), it is unsafe. list_data = ast.literal_eval(data) - for runtime_shape, graph_index, hash_str in list_data: - self.cache[runtime_shape][graph_index] = hash_str + for item in list_data: + runtime_shape = item[0] + graph_index = item[1] + hash_str = item[2] + # for compatibility of old version, + # where we don't have file_path. + # NOTE: after running the new code, the file_path + # will be updated. + file_path = "" if len(item) == 3 else item[3] + self.cache[runtime_shape][graph_index] = InductorArtifact( + hash_str=hash_str, file_path=file_path) def serialize(self) -> str: data = [] - for runtime_shape, graph_index_to_hash_str in self.cache.items(): - for graph_index, hash_str in graph_index_to_hash_str.items(): - data.append((runtime_shape, graph_index, hash_str)) + for runtime_shape, value in self.cache.items(): + for graph_index, inductor_artifact in value.items(): + data.append( + (runtime_shape, graph_index, inductor_artifact.hash_str, + inductor_artifact.file_path)) printer = pprint.PrettyPrinter(indent=4) return printer.pformat(data) @@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool: return runtime_shape in self.cache and graph_index in self.cache[ runtime_shape] - def __getitem__(self, key: Tuple[Optional[int], int]) -> str: + def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact: if self.disabled: raise KeyError("cannot read from disabled cache") runtime_shape, graph_index = key return self.cache[runtime_shape][graph_index] - def __setitem__(self, key: Tuple[Optional[int], int], value: str): + def __setitem__(self, key: Tuple[Optional[int], int], + value: InductorArtifact): # setitem for disabled cache is fine, because we # don't actually write to the disk runtime_shape, graph_index = key @@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule, if (runtime_shape, graph_index) in cache_data: # we compiled this graph before # so we can directly lookup the compiled graph via hash - hash_str = cache_data[(runtime_shape, graph_index)] + inductor_artifact = cache_data[(runtime_shape, graph_index)] + hash_str = inductor_artifact.hash_str if graph_index == 0: # adds some info logging for the first graph logger.info( @@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule, "Inductor cache lookup failed. Please remove" f"the cache file {cache_data.cache_file_path} and try again." # noqa ) + inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa # Inductor calling convention (function signature): # f(list) -> tuple @@ -224,19 +245,20 @@ def compiled_graph(*args): # the assumption is that we don't have nested Inductor compilation. # compiled_fx_graph_hash will only be called once, and we can hook # it to get the hash of the compiled graph directly. - from torch._inductor.codecache import compiled_fx_graph_hash + + inductor_artifact = InductorArtifact() + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + original_load = FxGraphCache.load + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + return inductor_compiled_graph def hijack_compiled_fx_graph_hash(*args, **kwargs): out = compiled_fx_graph_hash(*args, **kwargs) - # store the hash in the cache - nonlocal cache_data - cache_data[(runtime_shape, graph_index)] = out[0] - if graph_index == 0: - # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug("store the %s-th graph for shape %s via hash %s", - graph_index, str(runtime_shape), out[0]) + inductor_artifact.hash_str = out[0] return out def _check_can_cache(*args, **kwargs): @@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv: if not cache_data.disabled: # compilation cache is enabled, patch several functions + # hijack to get the compiled graph itself + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache.load", + hijack_load)) + # for hijacking the hash of the compiled graph stack.enter_context( patch("torch._inductor.codecache.compiled_fx_graph_hash", @@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv: compiled_graph = compile_fx(graph, example_inputs, config_patches=current_config) - + # store the inductor_artifact in the cache + cache_data[(runtime_shape, graph_index)] = inductor_artifact + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s via hash %s from file %s", + graph_index, str(runtime_shape), inductor_artifact.hash_str, + inductor_artifact.file_path) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 10513111ea7f1..38f284794b8db 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): During runtime, when we actually mark dimensions of tensors, it depends on the value of arguments: - - if it is a single integer, the corresponding dimension of the argument - will be marked as dynamic. + - if it is a single integer (can be negative), the corresponding dimension + of the argument will be marked as dynamic. - if it is `None`, ignored. - if it is `IntermediateTensors`, all the tensors in the intermediate tensors will be marked as dynamic. @@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs): for k, dims in dynamic_arg_dims.items(): arg = bound_args.arguments.get(k) if arg is not None: + dims = [dims] if isinstance(dims, int) else dims if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] torch._dynamo.mark_dynamic(arg, dims) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] torch._dynamo.mark_dynamic(tensor, dims) else: raise ValueError( diff --git a/vllm/config.py b/vllm/config.py index ac5a4c91b1738..4698a05020332 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2862,17 +2862,8 @@ def model_post_init(self, __context: Any) -> None: "vllm.unified_attention_with_output", ] else: - # v0 can use full graph compilation without splitting, - # splitting is optional. - # right now we still need it. kv cache shape - # will be included in the graph if we don't split - # the graph. - # TODO: hide kv cache in static forward context - # so that inductor does not see it. - self.splitting_ops = [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - ] + # v0 uses full graph compilation + self.splitting_ops = [] for k, v in self.inductor_passes.items(): if not isinstance(v, str): diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 4ace03ff1184e..7780e2dfa317d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -35,6 +35,7 @@ def __init__( ): self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size if self.config.kv_connector == "PyNcclConnector": from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( @@ -161,7 +162,7 @@ def send_kv_caches_and_hidden_states( end_layer = model_executable.model.end_layer model_config = model_executable.model.config - num_heads = model_config.num_key_value_heads + num_heads = int(model_config.num_key_value_heads / self.tp_size) hidden_size = model_config.hidden_size num_attention_heads = model_config.num_attention_heads head_size = int(hidden_size / num_attention_heads) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index f10a8fb8e03cf..2d8594cb8aafa 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -298,8 +298,11 @@ def __call__(self, input_ids: list[int], # token_bitmask is a CPU tensor for use with accept_token and # fill_next_token_bitmask so we move it to the device of scores device_type = scores.device.type + dtype = scores.dtype if device_type != "cuda": - scores = scores.to("cpu").unsqueeze(0) + # xgrammar on cpu only supports float32 scores + # see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22 + scores = scores.to("cpu").float().unsqueeze(0) # Note: In this method, if the tensors have different dimensions # on CPU device fails, but on GPU it runs without error. Hence the @@ -307,7 +310,7 @@ def __call__(self, input_ids: list[int], xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) if device_type != "cuda": - scores = scores.to(device_type).squeeze() + scores = scores.to(dtype).to(device_type).squeeze() return scores diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 00ae64bbe6388..52263e96fb9f9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -344,11 +344,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit param_data = param.data - # bitsandbytes loads the weights of the specific portion - # no need to narrow here - if output_dim is not None and not use_bitsandbytes_4bit: + if output_dim is not None and not is_sharded_weight: shard_size = param_data.shape[output_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, @@ -546,6 +548,11 @@ def weight_loader(self, use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + if use_bitsandbytes_4bit: shard_size = loaded_weight.shape[output_dim] shard_offset = loaded_weight.shape[output_dim] * \ @@ -554,9 +561,7 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size - # bitsandbytes loads the weights of the specific portion - # no need to narrow here - if not use_bitsandbytes_4bit: + if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for AQLM codebooks. @@ -941,6 +946,11 @@ def weight_loader(self, use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), @@ -964,9 +974,7 @@ def weight_loader(self, shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size - # bitsandbytes loads the weights of the specific portion - # no need to narrow here - if not use_bitsandbytes_4bit: + if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -1070,6 +1078,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit # Special case for GGUF is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1085,9 +1097,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - # bitsandbytes loads the weights of the specific portion - # no need to narrow here - if input_dim is not None and not use_bitsandbytes_4bit: + if input_dim is not None and not is_sharded_weight: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 3fcd81a3c4213..d071cfe888f05 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -841,6 +841,37 @@ def get_input_positions( ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" + llm_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + input_tokens, + image_grid_thw, + video_grid_thw, + image_token_id, + video_token_id, + vision_start_token_id, + vision_end_token_id, + spatial_merge_size, + context_len, + seq_len, + ) + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_input_positions_tensor( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.tolist() if isinstance(video_grid_thw, torch.Tensor): @@ -916,7 +947,7 @@ def get_input_positions( len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] - return llm_positions.tolist(), mrope_position_delta + return llm_positions, mrope_position_delta @staticmethod def get_next_input_positions( @@ -930,6 +961,17 @@ def get_next_input_positions( seq_len + mrope_position_delta)) for _ in range(3) ] + @staticmethod + def get_next_input_positions_tensor( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> torch.Tensor: + return torch.arange( + mrope_position_delta + context_len, + mrope_position_delta + seq_len, + ).expand(3, -1) + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 9fe0db62435a0..f697c3245f098 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -182,6 +182,9 @@ class Source: fall_back_to_pt: bool = True """Whether .pt weights can be used.""" + allow_patterns_overrides: Optional[list[str]] = None + """If defined, weights will load exclusively using these patterns.""" + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -218,6 +221,7 @@ def _prepare_weights( model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], ) -> Tuple[str, List[str], bool]: """Prepare weights for the model. @@ -249,6 +253,9 @@ def _prepare_weights( if fall_back_to_pt: allow_patterns += ["*.pt"] + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + if not is_local: hf_folder = download_weights_from_hf( model_name_or_path, @@ -298,7 +305,8 @@ def _get_weights_iterator( ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - source.model_or_path, source.revision, source.fall_back_to_pt) + source.model_or_path, source.revision, source.fall_back_to_pt, + source.allow_patterns_overrides) if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False @@ -340,6 +348,8 @@ def _get_all_weights( prefix="", fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", + None), ) yield from self._get_weights_iterator(primary_weights) @@ -353,7 +363,8 @@ def _get_all_weights( def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision, - fall_back_to_pt=True) + fall_back_to_pt=True, + allow_patterns_overrides=None) def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config @@ -1105,15 +1116,22 @@ def _load_weights(self, model_config: ModelConfig, weight_name, index, ) in self.modules_mapping.inverse_packed_mapping.items(): - shard_pos = quant_param_name.find(shard_name) # Some models, such as MiniCPM V2.5/2.6, contain both # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' # from being incorrectly identified as being present in # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight - if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".": + shard_pos = quant_param_name.find(shard_name) + can_correct_rename = (shard_pos > 0) and ( + quant_param_name[shard_pos - 1] == ".") + # If the quant_param_name is packed, it won't occur in the + # param_dict before renaming. + new_quant_param_name = quant_param_name.replace( + shard_name, weight_name) + need_rename = (quant_param_name not in param_dict) \ + and (new_quant_param_name in param_dict) + if can_correct_rename and need_rename: shard_index = index - quant_param_name = quant_param_name.replace( - shard_name, weight_name) + quant_param_name = new_quant_param_name break # Models like Clip/Siglip may skip some layers in initialization, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 7e37ce3086e6b..d5f9b4d19e5ca 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -41,7 +41,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -605,9 +605,50 @@ def forward( return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), + ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "rotary_pos_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".word_embeddings": ""}, ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -660,52 +701,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - # Merge two ColumnParallelLinear into one MergedColumnParallelLinear - merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = { - "transformer.vision.linear_proj.merged_proj.weight": { - "transformer.vision.linear_proj.gate_proj.weight": None, - "transformer.vision.linear_proj.dense_h_to_4h.weight": None, - } - } - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - is_weight_to_be_merge = False - for _, merged_weight_dict in merged_weights_dict.items(): - if name in merged_weight_dict: - assert merged_weight_dict[name] is None - merged_weight_dict[name] = loaded_weight - is_weight_to_be_merge = True - if is_weight_to_be_merge: - continue - if "rotary_pos_emb.inv_freq" in name: - continue - if "word_embeddings" in name: - name = name.replace(".word_embeddings", "") - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - for combined_name, merged_weight_dict in merged_weights_dict.items(): - if combined_name in params_dict: - param = params_dict[combined_name] - combined_weight = torch.cat(list(merged_weight_dict.values()), - dim=0) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, combined_weight) - loaded_params.add(combined_name) - return loaded_params + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) class ChatGLM(ChatGLMBaseModel): @@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel): class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal): + packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"], @@ -777,7 +776,7 @@ def __new__( ) -> None: config = vllm_config.model_config.hf_config # Initialize VL - if hasattr(config, "visual"): + if hasattr(config, "vision_config"): return ChatGLMV(vllm_config=vllm_config, prefix=prefix) # Initialize LLM else: diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py new file mode 100644 index 0000000000000..b93a68680375d --- /dev/null +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -0,0 +1,151 @@ +# Copyright 2024 The vLLM team. +# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama model for fairseq2 weights.""" + +from typing import Iterable, Set, Tuple + +import torch +from torch.nn import Parameter + +from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import set_weight_attrs +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import AutoWeightsLoader, WeightsMapper + + +class Fairseq2LlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + # For the model loader to read only the relevant checkpoint files + self.allow_patterns_overrides = [ + # either the full checkpoint + "model.pt", + # or the tp-sharded checkpoint of the current rank + f"model.{self.tp_rank}.pt", + ] + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + # fairseq2's serialization adds a wrapper to usual .pt state_dict's: + # { "model_key": my_model_name, "my_model_name": state_dict } + # which we first need to unpack + weights_wrapped = dict(weights) + weights = weights_wrapped[ + weights_wrapped["model_key"]].items() # type: ignore + + # remap keys + fs2_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder_frontend.embed.": "model.embed_tokens.", + "decoder.": "model.", + "final_proj.": "lm_head.", + }, + orig_to_new_substr={ + ".self_attn_layer_norm.": ".input_layernorm.", + ".ffn_layer_norm.": ".post_attention_layernorm.", + ".self_attn.output_proj.": ".self_attn.o_proj.", + ".ffn.gate_proj.": ".mlp.gate_proj.", + ".ffn.inner_proj.": ".mlp.up_proj.", + ".ffn.output_proj.": ".mlp.down_proj.", + ".layer_norm.": ".norm.", + }, + ) + weights = fs2_to_vllm_mapper.apply(weights) + + params = dict(self.named_parameters()) + + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights( + (self.reshape_fairseq2_weights(name, loaded_weight, params) + for name, loaded_weight in weights)) + + def flag_sharded_weights(self, params: dict[str, Parameter]): + """Sets the `is_sharded_weight` flag to True for all sharded weights""" + for name, param in params.items(): + modules = name.split(".") + if "norm" in name and len(param.size()) < 2: + # layer norms are not sharded + continue + elif any(emb in modules for emb in ["embed_tokens", "lm_head"]): + # for now we repeat embedding layers for compatibility + continue + else: + # all other layers are sharded + set_weight_attrs(param, {"is_sharded_weight": True}) + + def reshape_fairseq2_weights( + self, + name: str, + loaded_weight: torch.Tensor, + params: dict[str, Parameter], + ) -> Tuple[str, torch.Tensor]: + """Reshape fairseq2's weights.""" + + def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: + attn_in = self.config.head_dim * n_heads + # check for a sharded weight on dim 0 + if attn_in // self.tp_size == w.size()[0]: + attn_in //= self.tp_size + n_heads //= self.tp_size + attn_out = self.config.hidden_size + return (w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, + 2).reshape(attn_in, attn_out)) + + modules = name.split(".") + + # rotary embeds should be sliced + if "k_proj" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + + elif "q_proj" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + # We make the loaded weights compatible with both + # full checkpoints and tp sharded checkpoints. + # Embeddings are repeated to fit the vocab size. + # Other weights are flagged for the weight_loader calls. + if any(emb in modules for emb in ["embed_tokens", "lm_head"]): + # Embeddings are sharded on dim 0 + dim = 0 + # In fairseq2, vocab size has to be divisible by tp_size + # so we don't worry about padding + if self.tp_size > 1 and loaded_weight.shape[ + dim] < self.config.vocab_size: + assert loaded_weight.shape[ + dim] * self.tp_size == self.config.vocab_size, \ + "vocab_size should be divisible by tp_size." + repeats = [1] * len(loaded_weight.size()) + repeats[dim] = self.tp_size + # repeat to match vocab size and to be easily 'narrow'able + loaded_weight = loaded_weight.repeat(repeats) + set_weight_attrs(params[name], {"is_sharded_weight": False}) + # if embeddings are sharded, the rest is too + if "embed_tokens" in modules: + self.flag_sharded_weights(params) + + return name, loaded_weight diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 39a5736eb199b..51922e6f2d03d 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -42,7 +42,8 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: torch.Tensor Transformed tensor with shape (B, L, D) """ - images = images.to(self.proj.weight.device) + images = images.to(device=self.proj.weight.device, + dtype=self.proj.weight.dtype) x = self.proj(images) x = x.flatten(2).transpose(1, 2) cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 722fff98d5c19..6cceded43a79d 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -5,9 +5,11 @@ import torch import torch.nn as nn +from packaging.version import Version from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, SiglipVisionConfig) +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor @@ -716,6 +718,27 @@ def load_weights(self, weights: Iterable[Tuple[str, return loader.load_weights(weights) +class MantisProcessingInfo(LlavaProcessingInfo): + + def get_hf_processor(self): + hf_config = self.get_hf_config() + vision_info = self.get_vision_encoder_info() + + if Version(TRANSFORMERS_VERSION) < Version("4.48"): + # BUG: num_additional_image_tokens = 0 but treated as 1, + # so we set vision_feature_select_strategy to None to offset this + vision_feature_select_strategy = None + else: + # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150 + vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501 + + return self.ctx.get_hf_processor( + LlavaProcessor, + patch_size=vision_info.get_patch_size(), + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + class MantisMultiModalProcessor(LlavaMultiModalProcessor): def apply( @@ -794,7 +817,7 @@ def get_replacement_mantis(item_idx: int): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, - info=LlavaProcessingInfo, + info=MantisProcessingInfo, dummy_inputs=LlavaDummyInputsBuilder) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c9283e0c5ba20..6faa79f65d8de 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key == "pixel_values" and "images" not in modalities: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: modalities["images"] = self._parse_and_validate_image_input( **kwargs) - if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501 + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: modalities["videos"] = self._parse_and_validate_video_input( **kwargs) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index d015f60c6d065..82de1c3574090 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -256,7 +256,15 @@ def forward( return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) class Qwen2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 0dff9595c6c08..47d56175261e4 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -36,8 +36,9 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -153,29 +154,24 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, Any], ) -> BatchFeature: - mm_data = dict(mm_data) - audios = mm_data.pop("audios", []) - - if audios: - mm_data["audios"] = audios - - feature_extractor = self.info.get_feature_extractor(**mm_kwargs) - mm_kwargs = dict( - **mm_kwargs, - sampling_rate=feature_extractor.sampling_rate, - ) - else: - # NOTE: WhisperFeatureExtractor cannot handle empty list of audios - pass + # Text-only input not supported in composite processor + if not mm_data or not mm_data.get("audios", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) - processed_outputs = super()._call_hf_processor( + return super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, ) - return processed_outputs - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -192,8 +188,14 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.info.get_hf_config() - placeholder = hf_config.audio_token_index + processor = self.info.get_hf_processor() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + audio_bos_token = getattr(processor, "audio_bos_token", + "<|audio_bos|>") + audio_eos_token = getattr(processor, "audio_eos_token", + "<|audio_eos|>") feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") if feature_attention_mask is None: @@ -214,12 +216,16 @@ def get_replacement_qwen2_audio(item_idx: int): f"The audio {audio} (len={len(audio)}) is too short " "to be represented inside the model") - return [placeholder] * num_placeholders + return "".join([ + audio_bos_token, + audio_token * num_placeholders, + audio_eos_token, + ]) return [ PromptReplacement( modality="audio", - target=[placeholder], + target=audio_token, replacement=get_replacement_qwen2_audio, ) ] @@ -234,6 +240,26 @@ def _always_apply_prompt_replacements(self) -> bool: # tokens than the number of audio items) return not hasattr(self.info.get_hf_processor(), "audio_token") + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) + + # Only <|AUDIO|> tokens should be considered as placeholders, + # so we ignore the audio_bos_token and audio_eos_token + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"] + 1, + length=p["length"] - 2) for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result + @MULTIMODAL_REGISTRY.register_processor( Qwen2AudioMultiModalProcessor, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d00e5d362c8bc..34d5c8ad089a3 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -67,11 +67,15 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) from .vision import get_vit_attn_backend logger = init_logger(__name__) +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + # === Vision Inputs === # @@ -135,7 +139,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): - List[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). + (concatenation of all videos' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on @@ -611,6 +615,7 @@ def forward( # adapter x = self.merger(x) + return x def load_weights(self, weights: Iterable[Tuple[str, @@ -874,8 +879,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - - num_frames = max(max_total_frames // max(max_videos, 1), 1) + num_frames = min(max(max_total_frames // max(max_videos, 1), 1), + _MAX_FRAMES_PER_VIDEO) # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 if num_frames > 1 and num_frames % 2 == 1: @@ -955,13 +960,14 @@ def _get_prompt_replacements( "image": hf_processor.image_token, "video": hf_processor.video_token, } + merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) - num_tokens = grid_thw.prod() // merge_length + num_tokens = grid_thw.prod().item() // merge_length return placeholder[modality] * num_tokens return [ @@ -1047,11 +1053,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen2VLConfig = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - assert not cache_config.enable_prefix_caching, \ - "Qwen2-VL currently does not support prefix caching" self.config = config self.multimodal_config = multimodal_config @@ -1173,59 +1176,82 @@ def _parse_and_validate_video_input( video_embeds=video_embeds, video_grid_thw=video_grid_thw) - def _process_image_input(self, - image_input: Qwen2VLImageInputs) -> torch.Tensor: + def _process_image_input( + self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + if image_input["type"] == "image_embeds": - return image_input["image_embeds"].type(self.visual.dtype) + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, - grid_thw=image_input["image_grid_thw"]) - return image_embeds + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 - def _process_video_input(self, - video_input: Qwen2VLVideoInputs) -> torch.Tensor: if video_input["type"] == "video_embeds": - return video_input["video_embeds"].type(self.visual.dtype) + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, - grid_thw=video_input["video_grid_thw"]) - return video_embeds + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size - def _merge_multimodal_embeddings( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: torch.Tensor, - placeholder_token_id: int, - ) -> torch.Tensor: - mask = (input_ids == placeholder_token_id) - inputs_embeds[mask, :] = multimodal_embeddings - return inputs_embeds + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities def get_multimodal_embeddings( self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - if image_input is None and video_input is None: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: return None - # We make a tuple of each embedding with its modality string. This is a - # temporary workaround for models to handle mixed modalities when - # get_multimodal_embeddings and get_input_embeddings are called - # separately. - # TODO(ywang96): Add support for mixed-modality inference for v1. - multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - multimodal_embeddings.append((image_embeds, "image")) - if video_input is not None: - video_embeds = self._process_video_input(video_input) - multimodal_embeddings.append((video_embeds, "video")) + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings return multimodal_embeddings @@ -1237,21 +1263,9 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - for embeddings, modality in multimodal_embeddings: - if modality == "image": - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - embeddings, - placeholder_token_id=self.config.image_token_id, - ) - if modality == "video": - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - embeddings, - placeholder_token_id=self.config.video_token_id, - ) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a71f7f7029c7d..311f91472783b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -47,6 +47,7 @@ "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b1ac7c92a0be9..33517b0ce2a44 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -138,7 +138,7 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], ) -> BatchFeature: # Text-only input not supported in composite processor - if not mm_data: + if not mm_data or not mm_data.get("audios", []): prompt_ids = self.info.get_tokenizer().encode(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -147,13 +147,6 @@ def _call_hf_processor( audios = mm_data.pop("audios", []) assert isinstance(audios, list) - if not audios: - return super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - ) - feature_extractor = self.info.get_feature_extractor() mm_kwargs = dict( **mm_kwargs, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c97acffa1a719..f57dfded0a62f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -22,10 +22,10 @@ from vllm.logger import init_logger # yapf conflicts with isort for this block # yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, DeepseekVLV2Config, - EAGLEConfig, ExaoneConfig, - H2OVLChatConfig, +from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig, + Cohere2Config, DbrxConfig, + DeepseekVLV2Config, EAGLEConfig, + ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,6 +52,7 @@ } _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + "aria": AriaConfig, "chatglm": ChatGLMConfig, "cohere2": Cohere2Config, "dbrx": DbrxConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f065c56124605..807ef4fbfd0c0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,3 +1,4 @@ +from vllm.transformers_utils.configs.aria import AriaConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -23,6 +24,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ + "AriaConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py index d253da0d96a34..f4b531225b5d0 100644 --- a/vllm/transformers_utils/configs/aria.py +++ b/vllm/transformers_utils/configs/aria.py @@ -1,7 +1,32 @@ +# Copyright 2024 Rhymes AI. All rights reserved. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Mapping + +from transformers import PretrainedConfig from transformers.models.idefics2.configuration_idefics2 import ( Idefics2VisionConfig) from transformers.models.llama.configuration_llama import LlamaConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + class AriaVisionConfig(Idefics2VisionConfig): model_type = "aria_vision_model" @@ -45,3 +70,96 @@ def __init__( self.moe_num_experts = moe_num_experts self.moe_topk = moe_topk self.moe_num_shared_experts = moe_num_shared_experts + + +class AriaConfig(PretrainedConfig): + """ + Configuration class for Aria model. + This class handles the configuration for both vision and text components of + the Aria model, + as well as additional parameters for image token handling and projector + mapping. + + Args: + vision_config (AriaVisionConfig or dict): Configuration for the vision + component. + text_config (AriaMoELMConfig or dict): Configuration for the text + component. + projector_patch_to_query_dict (dict): Mapping of patch sizes to query + dimensions. + ignore_index (int): Index to ignore in loss calculation. + image_token_index (int): Index used to represent image tokens. + **kwargs: Additional keyword arguments passed to the parent class. + Attributes: + model_type (str): Type of the model, set to "aria". + is_composition (bool): Whether the model is a composition of multiple + components. + ignore_index (int): Index to ignore in loss calculation. + image_token_index (int): Index used to represent image tokens. + projector_patch_to_query_dict (dict): Mapping of patch sizes to query + dimensions. + vision_config (AriaVisionConfig): Configuration for the vision + component. + text_config (AriaMoELMConfig): Configuration for the text component. + """ + + model_type = "aria" + is_composition = False + + def __init__( + self, + vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008 + text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008 + projector_patch_to_query_dict: Mapping[int, int] = { + 1225: 128, + 4900: 256, + }, + ignore_index=-100, + image_token_index=32000, + tie_word_embeddings=False, + **kwargs, + ): + super().__init__(**kwargs) + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.tie_word_embeddings = tie_word_embeddings + attn_implementation = kwargs.pop("attn_implementation", None) + + # Set the default attention implementation to flash_attention_2 if not + # specified + self._attn_implementation = ("flash_attention_2" + if attn_implementation is None else + attn_implementation) + + # Convert the keys and values of projector_patch_to_query_dict to + # integers + # This ensures consistency even if they were provided as strings + self.projector_patch_to_query_dict = { + int(k): int(v) + for k, v in projector_patch_to_query_dict.items() + } + + if isinstance(vision_config, dict) and "model_type" in vision_config: + vision_config = AriaVisionConfig(**vision_config) + if attn_implementation is None: + vision_attn_implementation = "flash_attention_2" + elif attn_implementation == "sdpa": + logger.warning("SDPA is not supported for vit, using " + "flash_attention_2 instead") + vision_attn_implementation = "flash_attention_2" + else: + vision_attn_implementation = attn_implementation + vision_config._attn_implementation = vision_attn_implementation + + self.vision_config = vision_config + + if isinstance(text_config, dict) and "model_type" in text_config: + text_attn_implementation = ("sdpa" if attn_implementation is None + else attn_implementation) + text_config = AriaMoELMConfig(**text_config) + text_config._attn_implementation = text_attn_implementation + + self.text_config = text_config + + # This is needed for the static kv cache + self.num_hidden_layers = self.text_config.num_hidden_layers diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index a9deee881f41a..841df3994fba2 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -27,6 +27,17 @@ _GLOBAL_RUNTIME_DATA: Dict[str, Union[str, int, bool]] = {} +_USAGE_ENV_VARS_TO_COLLECT = [ + "VLLM_USE_MODELSCOPE", + "VLLM_USE_TRITON_FLASH_ATTN", + "VLLM_ATTENTION_BACKEND", + "VLLM_USE_FLASHINFER_SAMPLER", + "VLLM_PP_LAYER_PARTITION", + "VLLM_USE_TRITON_AWQ", + "VLLM_USE_V1", + "VLLM_ENABLE_V1_MULTIPROCESSING", +] + def set_runtime_usage_data(key: str, value: Union[str, int, bool]) -> None: """Set global usage data that will be sent with every usage heartbeat.""" @@ -122,6 +133,7 @@ def __init__(self) -> None: self.gpu_count: Optional[int] = None self.gpu_type: Optional[str] = None self.gpu_memory_per_device: Optional[int] = None + self.env_var_json: Optional[str] = None # vLLM Information self.model_architecture: Optional[str] = None @@ -176,6 +188,12 @@ def _report_usage_once(self, model_architecture: str, self.vllm_version = VLLM_VERSION self.model_architecture = model_architecture + # Environment variables + self.env_var_json = json.dumps({ + env_var: getattr(envs, env_var) + for env_var in _USAGE_ENV_VARS_TO_COLLECT + }) + # Metadata self.log_time = _get_current_timestamp_ns() self.source = envs.VLLM_USAGE_SOURCE diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22f0..28d8e39053874 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -30,6 +30,9 @@ class CachedRequestState: num_computed_tokens: int output_token_ids: List[int] + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aa63d9414c296..87a1cd7f9e627 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -14,6 +14,7 @@ from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType @@ -139,6 +140,32 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + # NOTE: `mrope_positions` is implemented as a permuted tensor to + # satisfy the following properties to allow `torch.compile` to work + # properly: + # - shape: (3, ) + # - stride: (1, 3) + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256 + + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + self.mrope_positions = torch.zeros((self.max_num_tokens, 3), + dtype=torch.int64, + device=self.device) + self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + self.mrope_positions = self.mrope_positions.permute((1, 0)) + self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0)) + self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -246,6 +273,35 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + ) + req_ids_to_add.append(req_id) # Update the cached states of the resumed requests. @@ -313,6 +369,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): arange, out=positions_np) + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + self._calc_mrope_positions(scheduler_output) + # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -359,8 +420,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) + if self.model_config.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( self.device, non_blocking=True) seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( @@ -472,6 +541,61 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + num_reqs = self.input_batch.num_reqs + for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + MRotaryEmbedding.get_next_input_positions_tensor( + req.mrope_position_delta, + context_len=num_computed_tokens + + prompt_part_len, + seq_len=num_computed_tokens + + prompt_part_len + + completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + def _prepare_sampling( self, scheduler_output: "SchedulerOutput", @@ -618,9 +742,12 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): + positions = self.mrope_positions[:, :num_input_tokens] \ + if self.model_config.uses_mrope \ + else self.positions[:num_input_tokens] hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=positions, kv_caches=self.kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds, @@ -707,9 +834,12 @@ def _dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None with set_forward_context(None, self.vllm_config): + positions = self.mrope_positions[:, :num_tokens] \ + if self.model_config.uses_mrope \ + else self.positions[:num_tokens] hidden_states = model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=positions, kv_caches=kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds,