[V1] Fix Compilation config & Enable CUDA graph by default (vllm-project#10528)

Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Nov 21, 2024
1 parent 7560ae5 commit f9310cb
Showing 3 changed files with 62 additions and 42 deletions.
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -2370,7 +2370,7 @@ def __post_init__(self):

if self.compilation_config is None:
self.compilation_config = CompilationConfig()
if envs.VLLM_USE_V1:
if envs.VLLM_USE_V1 and not self.model_config.enforce_eager:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
@@ -2380,6 +2380,7 @@ def __post_init__(self):
self.compilation_config.use_inductor = True
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

current_platform.check_and_update_config(self)

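For context, the net effect of the vllm/config.py change can be summarized with the standalone sketch below. The CompilationLevel, PassConfig, and CompilationConfig classes here are simplified stand-ins for the real vLLM types, assumed only for illustration; the branching mirrors the patched __post_init__: with V1 enabled and eager mode not forced, compilation now defaults to the piecewise level so CUDA graphs are used by default.

# Illustrative sketch only: simplified stand-ins for vLLM's compilation config
# types, showing the defaults this commit applies under V1.
from dataclasses import dataclass, field
from enum import IntEnum


class CompilationLevel(IntEnum):
    NO_COMPILATION = 0
    PIECEWISE = 3


@dataclass
class PassConfig:
    enable_fusion: bool = True
    enable_reshape: bool = True


@dataclass
class CompilationConfig:
    level: int = CompilationLevel.NO_COMPILATION
    use_inductor: bool = False
    pass_config: PassConfig = field(default_factory=PassConfig)


def default_compilation_config(use_v1: bool, enforce_eager: bool) -> CompilationConfig:
    """Mirror the patched __post_init__: piecewise compilation by default on V1."""
    config = CompilationConfig()
    if use_v1 and not enforce_eager:
        # Enable inductor, disable the fusion/reshape passes, and select the
        # piecewise level (the line added by this commit).
        config.use_inductor = True
        config.pass_config.enable_fusion = False
        config.pass_config.enable_reshape = False
        config.level = CompilationLevel.PIECEWISE
    return config


# CUDA graphs on by default under V1, off when eager execution is forced.
assert default_compilation_config(use_v1=True, enforce_eager=False).level == CompilationLevel.PIECEWISE
assert default_compilation_config(use_v1=True, enforce_eager=True).level == CompilationLevel.NO_COMPILATION

In practice this means a V1 user who wants the previous eager behavior opts out by forcing eager mode (the enforce_eager model option) rather than by configuring compilation explicitly.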
62 changes: 34 additions & 28 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,3 +1,4 @@
import gc
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@@ -515,7 +516,25 @@ def load_model(self) -> None:
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))

def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
@torch.inference_mode()
def _dummy_run(
self,
model: nn.Module,
num_tokens: int,
kv_caches: List[torch.Tensor],
) -> torch.Tensor:
with set_forward_context(None):
hidden_states = model(
input_ids=None,
positions=self.positions[:num_tokens],
kv_caches=kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_tokens])
return hidden_states

def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
@@ -527,23 +546,17 @@ def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(input_ids=None,
positions=self.positions,
kv_caches=dummy_kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds)

@torch.inference_mode()
def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
self._dummy_run(self.model, self.max_num_tokens)
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
gc.collect()

@torch.inference_mode()
def capture_model(self) -> None:
if not self.use_cuda_graph:
logger.warning(
@@ -554,18 +567,11 @@ def capture_model(self) -> None:
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]

with set_forward_context(None):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
input_ids=None,
positions=self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_tokens],
)
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self._dummy_run(self.model, num_tokens, self.kv_caches)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
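The comment in capture_model ("Capture the large shapes first so that the smaller shapes can reuse the memory pool allocated for the large shapes") describes a general CUDA graph technique. Below is a generic PyTorch sketch of that idea, not the vLLM implementation (which routes capture through the piecewise compilation backend): every graph shares one pool from torch.cuda.graph_pool_handle(), and the largest batch size is captured first. The names model, batch_sizes, and hidden_size are assumed inputs for illustration.

# Generic PyTorch sketch: shared-pool CUDA graph capture, largest shape first.
import torch


def capture_graphs(model: torch.nn.Module, batch_sizes: list, hidden_size: int):
    device = torch.device("cuda")
    static_input = torch.zeros(max(batch_sizes), hidden_size, device=device)
    pool = torch.cuda.graph_pool_handle()  # one memory pool shared by all graphs

    # Warm up on a side stream before capture, as the CUDA graphs docs recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side), torch.no_grad():
        for bs in batch_sizes:
            model(static_input[:bs])
    torch.cuda.current_stream().wait_stream(side)

    graphs, outputs = {}, {}
    # Largest first: later (smaller) captures reuse blocks the pool already holds.
    for bs in sorted(batch_sizes, reverse=True):
        graph = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(graph, pool=pool):
            outputs[bs] = model(static_input[:bs])
        graphs[bs] = graph
    return static_input, graphs, outputs

To run a captured shape, copy fresh data into static_input[:bs], call graphs[bs].replay(), and read the result from outputs[bs]; this is the replay pattern the captured batch sizes above are meant to serve.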
39 changes: 26 additions & 13 deletions vllm/v1/worker/gpu_worker.py
@@ -105,35 +105,48 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

_, total_gpu_memory = torch.cuda.mem_get_info()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()

# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()

free_gpu_memory, _ = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
assert self.init_gpu_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")

# Get the peak memory allocation recorded by torch
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

# Check for any memory left around that may have been allocated on the
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch.cuda.empty_cache()
torch_allocated_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)

# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
cache_block_size = _get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
# if self.model_runner.lora_manager:
# self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, 0

def initialize_cache(self, num_gpu_blocks: int) -> None:
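The new profiling logic in determine_num_available_blocks boils down to the arithmetic below: the torch allocator's recorded peak, plus any GPU memory held outside torch (the diff's comment cites NCCL as an example), is charged against the gpu_memory_utilization budget, and the remainder is divided into KV-cache blocks. This is a hedged, standalone restatement for illustration; cache_block_size, the utilization fraction, and the initial free-memory reading are taken as given.

# Standalone sketch of the block-count arithmetic used above (illustrative only).
import torch


def estimate_num_gpu_blocks(init_free_gpu_memory: int,
                            gpu_memory_utilization: float,
                            cache_block_size: int) -> int:
    torch.cuda.synchronize()
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    # Profiling is only meaningful if memory was actually cleaned up beforehand.
    assert init_free_gpu_memory > free_gpu_memory, (
        "GPU memory was not properly cleaned up before profiling.")

    # Peak memory the torch allocator recorded during the profile run.
    peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

    # Account for GPU memory allocated outside torch (e.g. NCCL buffers):
    # total in-use memory minus what torch still holds after empty_cache().
    torch.cuda.empty_cache()
    torch_allocated = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    total_in_use = total_gpu_memory - torch.cuda.mem_get_info()[0]
    non_torch = total_in_use - torch_allocated
    if non_torch > 0:
        peak_memory += non_torch

    available_kv_cache_memory = (
        total_gpu_memory * gpu_memory_utilization - peak_memory)
    return max(int(available_kv_cache_memory // cache_block_size), 0)

Compared with the old free-memory delta, this accounting follows the diff's own comments: the allocator's peak statistic captures what torch used during the run, and the empty_cache()-based correction adds memory that was allocated on the GPU outside of torch.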
