diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index ab372038676e6..6103fb34445ad 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -26,6 +26,7 @@ def main(args: argparse.Namespace):
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
         trust_remote_code=args.trust_remote_code,
+        use_cuda_graph=args.use_cuda_graph
     )

     sampling_params = SamplingParams(
@@ -68,7 +69,8 @@ def run_to_completion(profile: bool = False):
     for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
         latencies.append(run_to_completion(profile=False))
     print(latencies)
-    print(f'Avg latency: {np.mean(latencies)} seconds')
+    if torch.distributed.get_rank() == 0:
+        print(f'Avg latency: {np.mean(latencies)} seconds')

     del llm
     ray.shutdown()
@@ -90,6 +92,12 @@ def run_to_completion(profile: bool = False):
                         help='Number of iterations to run.')
     parser.add_argument('--trust-remote-code',
                         action='store_true',
                         help='trust remote code from huggingface')
+    parser.add_argument(
+        '--use-cuda-graph',
+        action="store_true",
+        default=False,
+        help="Whether to use CUDA graph for inference",
+    )
     args = parser.parse_args()
     main(args)
     #time.sleep(30) #add sleep for profiling
diff --git a/docker/fa_arch.patch b/docker/fa_arch.patch
index db2a95ae7f19d..f584860211edc 100644
--- a/docker/fa_arch.patch
+++ b/docker/fa_arch.patch
@@ -6,7 +6,7 @@ index 7396633..2ca5d55 100644
          [
              "-O3",
              "-std=c++20",
--            "--offload-arch=gfx941",
+-            "--offload-arch=gfx942",
              "-DNDEBUG",
              "-U__CUDA_NO_HALF_OPERATORS__",
              "-U__CUDA_NO_HALF_CONVERSIONS__",
diff --git a/vllm/config.py b/vllm/config.py
index 2e8d58411181c..019575a200e1f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -30,7 +30,10 @@ class ModelConfig:
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
+        use_cuda_graph: Whether to capture and replay CUDA graphs.
         seed: Random seed for reproducibility.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
     """

     def __init__(
@@ -44,6 +47,8 @@ def __init__(
         use_dummy_weights: bool,
         dtype: str,
         seed: int,
+        max_model_len: Optional[int] = None,
+        use_cuda_graph: Optional[bool] = False,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -53,10 +58,13 @@ def __init__(
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
+        self.use_cuda_graph = use_cuda_graph

         self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
+        self.max_model_len = _get_and_verify_max_len(self.hf_config,
+                                                     max_model_len)

     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
@@ -117,6 +125,8 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
         return total_num_attention_heads // parallel_config.tensor_parallel_size

     def get_max_model_len(self) -> int:
+        if self.max_model_len is not None:
+            return self.max_model_len
         max_model_len = float("inf")
         possible_keys = [
             # OPT
@@ -214,7 +224,7 @@ def __init__(
         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.worker_use_ray = True
+            self.worker_use_ray = False
         self._verify_args()

     def _verify_args(self) -> None:
@@ -295,3 +305,58 @@ def _get_and_verify_dtype(
             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
             f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
+
+def _get_and_verify_max_len(
+    hf_config: PretrainedConfig,
+    max_model_len: Optional[int],
+) -> int:
+    """Get and verify the model's maximum length."""
+    derived_max_model_len = float("inf")
+    possible_keys = [
+        # OPT
+        "max_position_embeddings",
+        # GPT-2
+        "n_positions",
+        # MPT
+        "max_seq_len",
+        # Others
+        "max_sequence_length",
+        "max_seq_length",
+        "seq_len",
+    ]
+    for key in possible_keys:
+        max_len_key = getattr(hf_config, key, None)
+        if max_len_key is not None:
+            derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        if max_model_len is not None:
+            # If max_model_len is specified, we use it.
+            return max_model_len
+
+        default_max_len = 2048
+        logger.warning(
+            "The model's config.json does not contain any of the following "
+            "keys to determine the original maximum length of the model: "
+            f"{possible_keys}. Assuming the model's maximum length is "
+            f"{default_max_len}.")
+        derived_max_model_len = default_max_len
+
+    rope_scaling = getattr(hf_config, "rope_scaling", None)
+    if rope_scaling is not None:
+        assert "factor" in rope_scaling
+        scaling_factor = rope_scaling["factor"]
+        if rope_scaling["type"] == "yarn":
+            derived_max_model_len = rope_scaling[
+                "original_max_position_embeddings"]
+        derived_max_model_len *= scaling_factor
+
+    if max_model_len is None:
+        max_model_len = derived_max_model_len
+    elif max_model_len > derived_max_model_len:
+        raise ValueError(
+            f"User-specified max_model_len ({max_model_len}) is greater than "
+            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
+            " in model's config.json). This may lead to incorrect model "
+            "outputs or CUDA errors. Make sure the value is correct and "
+            "within the model context size.")
+    return int(max_model_len)
\ No newline at end of file
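
A rough, standalone sketch of how the `_get_and_verify_max_len` logic above resolves the context length. The config object and its values are hypothetical; this snippet is illustrative only and not part of the patch:

```python
# Hypothetical config: max_position_embeddings is matched via possible_keys,
# and a non-"yarn" rope_scaling entry multiplies the derived length.
from types import SimpleNamespace

hf_config = SimpleNamespace(
    max_position_embeddings=4096,
    rope_scaling={"type": "linear", "factor": 2.0},
)

derived_max_model_len = float(hf_config.max_position_embeddings)  # 4096
derived_max_model_len *= hf_config.rope_scaling["factor"]         # 8192

# max_model_len=None would fall back to the derived 8192; a user-supplied
# value above 8192 would trigger the ValueError in _get_and_verify_max_len.
print(int(derived_max_model_len))
```
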
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6587638b5e1c1..3922f0bb03893 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -19,15 +19,17 @@ class EngineArgs:
     use_dummy_weights: bool = False
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
     swap_space: int = 4  # GiB
-    gpu_memory_utilization: float = 0.9
+    gpu_memory_utilization: float = 0.6
     max_num_batched_tokens: int = 2560
     max_num_seqs: int = 256
     disable_log_stats: bool = False
+    use_cuda_graph: Optional[bool] = False

     def __post_init__(self):
         if self.tokenizer is None:
@@ -84,6 +86,11 @@ def add_cli_args(
                             'for FP32 and FP16 models, and BF16 precision '
                             'for BF16 models.')
         # Parallel arguments
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         parser.add_argument('--worker-use-ray',
                             action='store_true',
                             help='use Ray for distributed serving, will be '
@@ -130,6 +137,10 @@ def add_cli_args(
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
+        parser.add_argument('--use-cuda-graph',
+                            action="store_true",
+                            default=False,
+                            help="Whether to use CUDA graph for inference")
         return parser

     @classmethod
@@ -148,7 +159,7 @@ def create_engine_configs(
             self.tokenizer_mode, self.trust_remote_code, self.download_dir,
             self.use_np_weights, self.use_dummy_weights, self.dtype,
-            self.seed)
+            self.seed, self.max_model_len, self.use_cuda_graph)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 908d01d959fd8..6c88d1f90675e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -118,8 +118,8 @@ def _init_workers(self, distributed_init_method: str):
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel

-        assert self.parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
+        #assert self.parallel_config.world_size == 1, (
+        #    "Ray is required if parallel_config.world_size > 1.")

         self.workers: List[Worker] = []
         worker = Worker(
@@ -223,6 +223,21 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
                      log_stats=not engine_args.disable_log_stats)
         return engine

+    def warm_up_cuda_graph(self) -> None:
+        """Compiles a CUDA graph with batch size of max sequence num."""
+        if not self.model_config.use_cuda_graph:
+            return
+
+        for i in range(self.scheduler_config.max_num_seqs):
+            self.add_request(str(i), "a",
+                             SamplingParams(temperature=0.0, max_tokens=2))
+
+        self.step()
+        self.step()
+
+        for i in range(self.scheduler_config.max_num_seqs):
+            self.abort_request(str(i))
+
     def add_request(
         self,
         request_id: str,
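
For orientation, a hedged sketch of how the new `EngineArgs` fields reach `ModelConfig` and where `warm_up_cuda_graph` could be invoked. The patch itself does not show a call site for the warm-up, so the last two lines are an assumption, and the model name is a placeholder:

```python
# Sketch only: exercises create_engine_configs and the new warm-up hook.
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

engine_args = EngineArgs(model="facebook/opt-125m",  # placeholder model
                         max_model_len=2048,
                         use_cuda_graph=True)
model_config = engine_args.create_engine_configs()[0]  # ModelConfig is built first
assert model_config.use_cuda_graph
assert model_config.get_max_model_len() == 2048

engine = LLMEngine.from_engine_args(engine_args)  # assumed call site
engine.warm_up_cuda_graph()  # captures the graph at max_num_seqs batch size
```
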
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c9ab685255038..640e6e54db1fb 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -49,6 +49,7 @@ def __init__(
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
+        use_cuda_graph: Optional[bool] = False,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -61,6 +62,7 @@ def __init__(
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
+            use_cuda_graph=use_cuda_graph,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(engine_args)
diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index 6ffb2ad150144..e0179e7197cb1 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional

 import torch
 #from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ def __init__(
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        use_cuda_graph: Optional[bool] = False,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
@@ -37,6 +38,7 @@ def __init__(
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.use_cuda_graph = use_cuda_graph

         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
@@ -67,4 +69,5 @@ def __repr__(self) -> str:
                 f'max_context_len={self.max_context_len}), '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                 f'block_tables={self.block_tables}), '
-                f'slot_mapping={self.slot_mapping}')
+                f'slot_mapping={self.slot_mapping})'
+                f"use_cuda_graph={self.use_cuda_graph})")
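
A minimal usage sketch of the new entrypoint argument, mirroring what `benchmark_latency.py --use-cuda-graph` does; the model name and prompt are placeholders and the snippet is not part of the patch:

```python
from vllm import LLM, SamplingParams

# Prompt processing runs through the normal eager path; the single-token
# decode steps are the ones that can be served from a captured CUDA graph.
llm = LLM(model="facebook/opt-125m", use_cuda_graph=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```
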
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c59d4d58fc7ba..ba12638d24e7c 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -27,6 +27,9 @@
 """
 from typing import Dict, List, Optional, Tuple

+import math
+import time
+
 import torch
 from torch import nn
 from transformers import LlamaConfig
@@ -264,9 +267,11 @@ def forward(
                 input_metadata,
                 cache_event,
             )
+
         #hidden_states = residual + hidden_states
         #hidden_states = self.norm(hidden_states)
         hidden_states, residual = self.norm(residual,hidden_states)
+
         return hidden_states
@@ -283,6 +288,85 @@ def __init__(self, config):
                                     gather_output=False,
                                     perform_initialization=False)
         self.sampler = Sampler(config.vocab_size)
+        self._cuda_graph: Dict[int, torch.cuda.CUDAGraph] = {}
+        self._compiled_tensors: Dict[int, Tuple[torch.Tensor,
+                                                torch.Tensor, ], ] = {}
+        self._compiled_logits: Dict[int, torch.Tensor] = {}
+        self._compiled_input_metadata: Dict[int, InputMetadata] = {}
+        self._forward_time = 0
+
+    def _compile_for_batch_size(
+        self,
+        batch_size: int,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ):
+        max_batch_size = max(self._cuda_graph.keys(), default=None)
+        pool = (None if batch_size == max_batch_size or max_batch_size is None
+                else self._cuda_graph[max_batch_size].pool()
+                )  # reusing memory pool
+        print(">>> Entering _compile_for_batch_size")
+        self._cuda_graph[batch_size] = torch.cuda.CUDAGraph()
+
+        # The following fields are used in model forward pass
+        #   input_metadata.block_tables,   # shape[1] hardcoded to model_config.max_model_len
+        #   input_metadata.context_lens,
+        #   input_metadata.slot_mapping,
+        #   input_metadata.max_context_len,  # hardcoded to model_config.max_model_len
+
+        self._compiled_input_metadata[batch_size] = InputMetadata(
+            input_metadata.seq_groups,
+            input_metadata.seq_data,
+            input_metadata.prompt_lens,
+            input_metadata.slot_mapping[:batch_size].clone(),
+            input_metadata.context_lens[:batch_size].clone(),
+            input_metadata.max_context_len,
+            input_metadata.block_tables[:batch_size].clone(),
+            input_metadata.use_cuda_graph,
+        )
+
+        self._compiled_tensors[batch_size] = tuple([
+            input_ids[:batch_size].clone(),
+            positions[:batch_size].clone(),
+        ])
+
+        with torch.cuda.graph(self._cuda_graph[batch_size], pool=pool):
+            self._compiled_logits[batch_size] = self.model.forward(
+                *self._compiled_tensors[batch_size],
+                kv_caches=kv_caches,
+                input_metadata=self._compiled_input_metadata[batch_size],
+                cache_events=None,
+            )
+        print(">>> Exiting _compile_for_batch_size")
+
+    def compile_and_call_model(
+        self,
+        batch_size: int,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        if batch_size not in self._cuda_graph:
+            self._compile_for_batch_size(batch_size, input_ids, positions,
+                                         kv_caches, input_metadata,
+                                         cache_events)
+            torch.cuda.synchronize()
+        self._compiled_tensors[batch_size][0].copy_(input_ids)
+        self._compiled_tensors[batch_size][1].copy_(positions)
+        self._compiled_input_metadata[batch_size].block_tables.copy_(
+            input_metadata.block_tables)
+        self._compiled_input_metadata[batch_size].context_lens.copy_(
+            input_metadata.context_lens)
+        self._compiled_input_metadata[batch_size].slot_mapping.copy_(
+            input_metadata.slot_mapping)
+        self._cuda_graph[batch_size].replay()
+
+        return self._compiled_logits[batch_size]

     def forward(
         self,
@@ -292,10 +376,29 @@ def forward(
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+        global forward_time
+        start = time.time()
+        batch_size = input_metadata.block_tables.shape[0]
+        if input_metadata.num_prompt_tokens > 0:
+            forward_time = 0
+        if input_metadata.num_prompt_tokens > 0 or not input_metadata.use_cuda_graph:
+            hidden_states = self.model(input_ids, positions, kv_caches,
+                                       input_metadata, cache_events)
+        else:
+            # TODO: support cache_events
+            assert cache_events is None, "cache_events not supported yet"
+            hidden_states = self.compile_and_call_model(
+                batch_size, input_ids, positions, kv_caches, input_metadata,
+                cache_events)
+        torch.cuda.synchronize()
+        forward_time += time.time() - start
+        if input_metadata.num_prompt_tokens > 0:
+            self._sample_time = 0
+
+        start = time.time()
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    input_metadata)
+        self._sample_time += time.time() - start
         return next_tokens

     _column_parallel_weights = [
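
The capture/replay pattern used in `_compile_for_batch_size` and `compile_and_call_model` above, reduced to a self-contained toy example (the linear layer and shapes are made up, not taken from the patch). It shows the key constraint: inputs and outputs must live in fixed buffers that are refreshed with `copy_` before each `replay()`:

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")  # fixed batch size, like batch_size above

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: the tensor returned inside the capture region is also a static buffer.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay with fresh data: copy into the captured input, replay, read the output.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(static_output.shape)
```

This is also why the patch forces static shapes elsewhere (block tables padded to `max_model_len`, a fixed `max_context_len`): a captured graph can only be replayed if every tensor it touches keeps the same shape and address.
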
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 9175da444bed3..a8bdc7c3a95d3 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -232,11 +232,31 @@ def _prepare_inputs(
         positions_tensor = torch.cuda.LongTensor(input_positions)
         slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
         context_lens_tensor = torch.cuda.IntTensor(context_lens)
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables
-        ]
-        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
+        # padded_block_tables = [
+        #     _pad_to_max(block_table, max_num_blocks_per_seq)
+        #     for block_table in generation_block_tables
+        # ]
+        if self.model_config.use_cuda_graph:
+            # Make size of block_tables and max_context_len static
+            # TODO: how to dynamically change shared memory size for kernels with different max_context_len?
+            block_tables_tensor = torch.zeros(
+                (len(generation_block_tables),
+                 self.model_config.max_model_len),
+                dtype=torch.int,
+                device="cuda")
+            for i, block_table in enumerate(generation_block_tables):
+                tensor = torch.tensor(block_table)
+                block_tables_tensor[i, :len(block_table)] = tensor
+
+            max_context_len = self.model_config.max_model_len
+        else:
+            padded_block_tables = [
+                _pad_to_max(block_table, max_num_blocks_per_seq)
+                for block_table in generation_block_tables
+            ]
+            block_tables_tensor = torch.tensor(padded_block_tables,
+                                               dtype=torch.int,
+                                               device="cuda")

         seq_data: Dict[int, SequenceData] = {}
         for seq_group_metadata in seq_group_metadata_list:
@@ -250,6 +270,7 @@ def _prepare_inputs(
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            use_cuda_graph=self.model_config.use_cuda_graph,
         )
         return tokens_tensor, positions_tensor, input_metadata
@@ -323,8 +344,9 @@ def _init_distributed_environment(
     torch.distributed.init_process_group(
         backend="nccl",
         world_size=parallel_config.world_size,
-        rank=rank,
-        init_method=distributed_init_method,
+        #rank=rank,
+        #init_method=distributed_init_method,
+        init_method="env://", #distributed_init_method,
     )
     # pg_options=options,
     #)
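
Because the worker now calls `init_process_group` with `init_method="env://"`, the launcher has to supply the usual rendezvous environment variables (for example via `torchrun`). A minimal single-process sketch of that assumed launch setup, not part of the patch:

```python
import os
import torch

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

torch.distributed.init_process_group(backend="nccl",
                                     world_size=int(os.environ["WORLD_SIZE"]),
                                     init_method="env://")
torch.distributed.destroy_process_group()
```
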