From c81081fece240736e30e0f9a5ed82bb5b483c561 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 19:32:55 +0800
Subject: [PATCH] [torch.compile] transparent compilation with more logging
 (#12246)

Signed-off-by: youkaichao
---
 vllm/compilation/backends.py   | 32 +++++++++++++++++++++++++-------
 vllm/compilation/decorators.py |  2 ++
 vllm/compilation/wrapper.py    | 22 ++++++++++++++++++++++
 vllm/config.py                 |  1 +
 4 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 955c25f300512..b9f96c00284b9 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -524,6 +524,7 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
@@ -532,7 +533,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
             # 1. factors come from the vllm_config (it mainly summarizes how the
             #    model is created)
-            vllm_config = self.vllm_config
             config_hash = vllm_config.compute_hash()
 
             # 2. factors come from the code files that are traced by Dynamo (
@@ -556,20 +556,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             hash_key = hashlib.md5(
                 f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-        else:
-            cache_dir = self.compilation_config.cache_dir
+                envs.VLLM_CACHE_ROOT,
+                "torch_compile_cache",
+                hash_key,
+            )
+            self.compilation_config.cache_dir = cache_dir
+
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        local_cache_dir = os.path.join(
+            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        self.compilation_config.local_cache_dir = local_cache_dir
 
         disabled = envs.VLLM_DISABLE_COMPILE_CACHE
         self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            cache_dir, disabled=disabled)
+            local_cache_dir, disabled=disabled)
         if disabled:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
-                        cache_dir)
+                        local_cache_dir)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
@@ -609,6 +615,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
                                     self.vllm_config, self.graph_pool,
                                     self).run(*example_inputs)
 
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
+            # use `print_readable` because it can include submodules
+            src = "from __future__ import annotations\nimport torch\n" + \
+                self.split_gm.print_readable(print_output=False)
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+
+            logger.debug("Computation graph saved to %s", graph_path)
+
         self._called = True
 
         if not self.compilation_config.use_cudagraph or \
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 38f284794b8db..17eb0592ced6d 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -198,6 +198,8 @@ def __call__(self, *args, **kwargs):
                             f" {dims} for argument {k} with type {type(arg)}.")
         # here, it is the starting point of the `torch.compile` process
         start_monitoring_torch_compile(self.vllm_config)
+        logger.debug("Start compiling function %s",
+                     self.original_code_object)
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index e3260a10c02ae..58a8fa76f6ce2 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -9,6 +9,9 @@
 
 import vllm.envs as envs
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class TorchCompileWrapperWithCustomDispatcher:
@@ -82,6 +85,25 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
             return
 
         self.compiled_codes.append(new_code)
+        local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
+        if isinstance(local_cache_dir, str):
+            decompiled_file = os.path.join(local_cache_dir,
+                                           "transformed_code.py")
+            if not os.path.exists(decompiled_file):
+                try:
+                    # usually the decompilation will succeed for most models,
+                    # as we guarantee a full-graph compilation in Dynamo.
+                    # but there's no 100% guarantee, since decompilation is
+                    # not a reversible process.
+                    import depyf
+                    src = depyf.decompile(new_code)
+                    with open(decompiled_file, "w") as f:
+                        f.write(src)
+
+                    logger.debug("Dynamo transformed code saved to %s",
+                                 decompiled_file)
+                except Exception:
+                    pass
 
         if self.vllm_config.compilation_config.use_cudagraph and \
             "update" in new_code.co_names:
diff --git a/vllm/config.py b/vllm/config.py
index b0a92b2e21343..b8628db4d2b80 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2785,6 +2785,7 @@ def model_post_init(self, __context: Any) -> None:
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
+    local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
     # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
     # since we know all keys are in a range [0, max_capture_size],