diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index e3e5a7d0fc5a5..9c65059c6b348 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -1,6 +1,10 @@
 """
 Test the piecewise compilation with a simple model, comparing the output
 with and without the piecewise compilation.
+
+This is a tractable model: the weights and computation are specially
+designed when the config `tractable_init` is set to True. Otherwise, the
+weights are initialized randomly with a fixed seed.
 """
 import os
 from dataclasses import dataclass
@@ -49,6 +53,12 @@ class LlamaConfig:
     mlp_size: int = 256
     vocab_size: int = 128
     num_layers: int = 2
+    init_value: float = 1.0
+    tractable_init: bool = False
+    random_seed: int = 0
+
+    def __post_init__(self):
+        assert self.mlp_size >= self.hidden_size
 
 
 class LlamaMLP(nn.Module):
@@ -66,10 +76,23 @@ def __init__(self, config: LlamaConfig) -> None:
             bias=False,
         )
 
-        self.gate_up_projection.weight.data.fill_(0.0)
-        self.down_projection.weight.data.fill_(0.0)
+        if config.tractable_init:
+            nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
+            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
+            nn.init.eye_(self.down_projection.weight.data)
+        else:
+            nn.init.xavier_normal_(self.gate_up_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
+            nn.init.xavier_normal_(self.down_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
 
     def forward(self, x):
+        # for tractable_init and positive input, this is
+        # essentially an elementwise square
         x = self.gate_up_projection(x)
         x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
             x[:, x.size(1) // 2:])
@@ -84,21 +107,39 @@ def __init__(self, config: LlamaConfig) -> None:
         self.qkv_projection = nn.Linear(
             in_features=config.hidden_size,
             out_features=config.hidden_size * 3,
+            bias=False,
         )
         self.output_projection = nn.Linear(
             in_features=config.hidden_size,
             out_features=config.hidden_size,
+            bias=False,
         )
 
-        self.qkv_projection.weight.data.fill_(0.0)
-        self.output_projection.weight.data.fill_(0.0)
+        if config.tractable_init:
+            nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
+            nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
+                                                         config.hidden_size])
+            nn.init.eye_(self.qkv_projection.weight.data[2 *
+                                                         config.hidden_size:])
+            nn.init.eye_(self.output_projection.weight.data)
+        else:
+            nn.init.xavier_normal_(self.qkv_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
+            nn.init.xavier_normal_(self.output_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        # for tractable_init, this is:
+        # output = (hidden_states * 3 + positions * 2)
         qkv = self.qkv_projection(hidden_states)
         hidden_size = qkv.size(-1) // 3
         q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
@@ -126,20 +167,29 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        For tractable computation:
+        - if residual is None, the outputs are:
+            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
+            - hidden_states = (residual + 1) ** 2
+        - if residual is not None, the outputs are:
+            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
+            - hidden_states = (residual + 1) ** 2
+        """  # noqa
         if residual is None:
             residual = hidden_states
-            hidden_states = hidden_states / 2
+            hidden_states = hidden_states + 1
         else:
             hidden_states = hidden_states + residual
             residual = hidden_states
-            hidden_states = hidden_states / 2
+            hidden_states = hidden_states + 1
 
         hidden_states = self.self_attention(positions=positions,
                                             hidden_states=hidden_states)
 
         hidden_states = hidden_states + residual
         residual = hidden_states
-        hidden_states = hidden_states / 2
+        hidden_states = hidden_states + 1
         hidden_states = self.mlp(hidden_states)
 
         return hidden_states, residual
@@ -156,7 +206,8 @@ def __init__(self, config: LlamaConfig) -> None:
         self.layers = nn.ModuleList(
             [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
 
-        self.embedding_tokens.weight.data.fill_(0.0)
+        # this is the initial value of the hidden states
+        self.embedding_tokens.weight.data.fill_(config.init_value)
 
     def forward(
         self,
@@ -170,6 +221,28 @@ def forward(
         return hidden_states
 
 
+def tractable_computation(input_ids: torch.Tensor,
+                          positions: torch.Tensor,
+                          config: LlamaConfig,
+                          init_value: float = 1.0) -> torch.Tensor:
+    hidden_states = torch.ones(input_ids.size(0),
+                               config.hidden_size,
+                               device=input_ids.device,
+                               dtype=input_ids.dtype) * init_value
+
+    # first layer
+    residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
+    hidden_states = (residual + 1)**2
+
+    # following layers
+    for _ in range(config.num_layers - 1):
+        hidden_states = hidden_states + residual
+        residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
+        hidden_states = (residual + 1)**2
+
+    return hidden_states
+
+
 @torch.inference_mode
 def run_model(llama_config,
               use_compile: bool,
@@ -213,7 +286,15 @@ def run_model(llama_config,
         del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
         set_compilation_config(None)
 
-    return output.cpu()
+    output = output.cpu()
+
+    if llama_config.tractable_init:
+        expected_output = tractable_computation(input_ids[:2], positions[:2],
+                                                llama_config).cpu()
+
+        assert torch.allclose(output, expected_output)
+    else:
+        return output.cpu()
 
 
 def test_toy_llama():
@@ -222,7 +303,13 @@ def test_toy_llama():
     llama_config = LlamaConfig(hidden_size=128,
                                mlp_size=256,
                                vocab_size=128,
-                               num_layers=2)
+                               num_layers=12)
+
+    tractable_config = LlamaConfig(hidden_size=128,
+                                   mlp_size=256,
+                                   vocab_size=128,
+                                   num_layers=2,
+                                   tractable_init=True)
 
     outputs = []
     with compilation_counter.expect(
@@ -233,6 +320,8 @@ def test_toy_llama():
             num_cudagraph_caputured=0,
     ):
         outputs.append(run_model(llama_config, use_compile=False))
+    run_model(tractable_config, use_compile=False)
+
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
             num_piecewise_graphs_seen=1,
@@ -242,6 +331,7 @@ def test_toy_llama():
             2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(run_model(llama_config, use_compile=True))
+    run_model(tractable_config, use_compile=True)
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -257,6 +347,7 @@ def test_toy_llama():
     ):
         outputs.append(
            run_model(llama_config, use_compile=True, split_attn=True))
+    run_model(tractable_config, use_compile=True, split_attn=True)
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
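For reference, the `tractable_init` path above is checkable by hand because the identity-initialized projections turn the MLP into an elementwise square for positive inputs (and the attention block into `hidden_states * 3 + positions * 2`), which is exactly what `tractable_computation` unrolls. A standalone sanity check of the MLP property, not part of the patch, assuming `mlp_size == hidden_size` and strictly positive activations as in the test:

    import torch
    import torch.nn as nn

    hidden_size = 4
    gate_up = nn.Linear(hidden_size, 2 * hidden_size, bias=False)
    # stack two identity matrices, mirroring the eye_ init in LlamaMLP
    nn.init.eye_(gate_up.weight.data[:hidden_size])
    nn.init.eye_(gate_up.weight.data[hidden_size:])

    x = torch.rand(3, hidden_size) + 1.0      # strictly positive input
    y = gate_up(x)                            # y == [x, x]
    out = y[:, :hidden_size] * torch.relu(y[:, hidden_size:])
    assert torch.allclose(out, x ** 2)        # gate * relu(up) == x**2

The identity `down_projection` then leaves this result unchanged, so the whole MLP reduces to an elementwise square.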
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 96ddcba467c5b..de32cabbe6d07 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -6,6 +6,7 @@
 import torch
 import torch.fx as fx
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
@@ -193,6 +194,7 @@ def wrap_inductor(graph,
 @dataclasses.dataclass
 class SplitItem:
     submod_name: str
+    graph_id: int
     is_splitting_graph: bool
     graph: fx.GraphModule
 
@@ -226,9 +228,7 @@ def split_graph(graph: fx.GraphModule,
 
     outputs = []
 
-    # sort the names to make sure the order is deterministic
     names = [name for (name, module) in split_gm.named_modules()]
-    names.sort()
 
     for name in names:
         if "." in name or name == "":
@@ -238,7 +238,11 @@ def split_graph(graph: fx.GraphModule,
         module = getattr(split_gm, name)
 
         graph_id = int(name.replace("submod_", ""))
-        outputs.append(SplitItem(name, graph_id in split_op_graphs, module))
+        outputs.append(
+            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
+
+    # sort by integer graph_id, rather than by string name
+    outputs.sort(key=lambda x: x.graph_id)
 
     return split_gm, outputs
 
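The sort change above is needed because `named_modules()` yields submodule names as strings, and lexicographic order diverges from numeric order once there are ten or more submodules ("submod_10" sorts before "submod_2"). A standalone illustration, not part of the patch:

    names = ["submod_0", "submod_1", "submod_2", "submod_10", "submod_11"]
    print(sorted(names))
    # ['submod_0', 'submod_1', 'submod_10', 'submod_11', 'submod_2']   <- string order
    print(sorted(names, key=lambda n: int(n.removeprefix("submod_"))))
    # ['submod_0', 'submod_1', 'submod_2', 'submod_10', 'submod_11']   <- graph_id order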
@@ -252,6 +256,11 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
 
     It runs the given graph with fake inputs, and compile some submodules
     specified by `compile_submod_names` with the given compilation configs.
+
+    NOTE: the order in `compile_submod_names` matters, because
+    it determines the order of the compiled piecewise graphs.
+    The first graph handles the logging, and the last graph
+    needs special handling for its cudagraph output.
     """
 
     def __init__(self, module: torch.fx.GraphModule,
@@ -263,7 +272,6 @@ def __init__(self, module: torch.fx.GraphModule,
         self.compile_submod_names = compile_submod_names
         self.compilation_configs = compilation_configs
         self.graph_pool = graph_pool
-        self.have_seen_first_graph = False
 
     def run(self, *args):
         fake_args = [
@@ -279,6 +287,7 @@ def call_module(self, target: torch.fx.node.Target,
         output = super().call_module(target, args, kwargs)
 
         if target in self.compile_submod_names:
+            index = self.compile_submod_names.index(target)
             submod = self.fetch_attr(target)
             sym_shape_indices = [
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
@@ -288,15 +297,14 @@
                 args,
                 self.compilation_configs.inductor_compile_config,
                 runtime_shape=None,
-                do_logging=not self.have_seen_first_graph,
+                do_logging=index == 0,
                 use_inductor=self.compilation_configs.use_inductor)
 
             self.module.__dict__[target] = PiecewiseBackend(
-                submod, self.compilation_configs, self.graph_pool,
-                not self.have_seen_first_graph, sym_shape_indices,
+                submod, self.compilation_configs, self.graph_pool, index,
+                len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape)
 
-            self.have_seen_first_graph = True
             compilation_counter.num_piecewise_capturable_graphs_seen += 1
 
         return output
@@ -352,8 +360,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             graph, self.compilation_configs.non_cudagraph_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
-        logger.debug("%s",
-                     lazy_format_graph_code("stiching module", self.split_gm))
+        logger.debug("%s", lazy_format_graph_code("before split", self.graph))
+        logger.debug("%s", lazy_format_graph_code("after split",
+                                                  self.split_gm))
 
         compilation_counter.num_piecewise_graphs_seen += len(
             self.piecewise_graphs)
@@ -385,12 +394,17 @@ class ConcreteSizeEntry:
     cudagraph: Optional[torch.cuda.CUDAGraph] = None
     output: Optional[Any] = None
 
+    # for cudagraph debugging, track the input addresses
+    # during capture, and check if they are the same during replay
+    input_addresses: Optional[List[int]] = None
+
 
 class PiecewiseBackend:
 
     def __init__(self, graph: fx.GraphModule,
                  compilation_configs: CompilationConfig, graph_pool: Any,
-                 is_first_graph: bool, sym_shape_indices: List[int],
+                 piecewise_compile_index: int, total_piecewise_compiles: int,
+                 sym_shape_indices: List[int],
                  compiled_graph_for_general_shape: Callable):
         """
         The backend for piecewise compilation.
@@ -408,7 +422,12 @@ def __init__(self, graph: fx.GraphModule,
         self.graph = graph
         self.compilation_configs = compilation_configs
         self.graph_pool = graph_pool
-        self.is_first_graph = is_first_graph
+        self.piecewise_compile_index = piecewise_compile_index
+        self.total_piecewise_compiles = total_piecewise_compiles
+
+        self.is_first_graph = piecewise_compile_index == 0
+        self.is_last_graph = (
+            piecewise_compile_index == total_piecewise_compiles - 1)
 
         self.compile_sizes: Set[int] = set(
             self.compilation_configs.compile_sizes)
@@ -422,6 +441,8 @@ def __init__(self, graph: fx.GraphModule,
 
         self.sym_shape_indices = sym_shape_indices
 
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
@@ -476,14 +497,45 @@ def __call__(self, *args) -> Any:
                 logger.info("Capturing a cudagraph for shape %s",
                             runtime_shape)
 
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
             cudagraph = torch.cuda.CUDAGraph()
+
+            # mind-exploding: carefully manage the reference and memory.
             with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                entry.output = weak_ref_tensors(entry.runnable(*args))
+                # `output` is managed by pytorch's cudagraph pool
+                output = entry.runnable(*args)
+                if self.is_last_graph:
+                    # by converting it to a weak ref,
+                    # the original `output` is released immediately,
+                    # which saves memory. It is only safe to do this for
+                    # the last graph, because the output of the last graph
+                    # will not be used by any other cuda graph.
+                    output = weak_ref_tensors(output)
+
+                # here we always store a weak ref to the output
+                # to save memory
+                entry.output = weak_ref_tensors(output)
+
+            entry.cudagraph = cudagraph
             compilation_counter.num_cudagraph_caputured += 1
 
-            entry.cudagraph = cudagraph
-            return entry.output
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check that the input addresses match the ones captured
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
 
         entry.cudagraph.replay()
         return entry.output
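For context on the new `input_addresses` bookkeeping: cudagraph replay re-executes the captured kernels on the exact memory addresses seen at capture time, so replay is only valid when the caller keeps reusing the same input buffers. A minimal standalone sketch of that contract with a plain `torch.cuda.CUDAGraph` (not vLLM code; requires a CUDA device, and `static_x` is just an illustrative fixed input buffer):

    import torch

    static_x = torch.zeros(8, device="cuda")        # fixed input buffer
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = static_x * 2                     # captured computation

    captured_addr = static_x.data_ptr()
    static_x.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
    assert static_x.data_ptr() == captured_addr     # same address -> replay is valid
    g.replay()
    # static_y now holds static_x * 2, recomputed by the replayed graph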