[torch.compile] transparent compilation with more logging (#12246)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 21, 2025
1 parent a94eee4 commit c81081f
Showing 4 changed files with 50 additions and 7 deletions.
vllm/compilation/backends.py (25 additions, 7 deletions)
@@ -524,6 +524,7 @@ def configure_post_pass(self):

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affect the compilation. if none of the factors change,
@@ -532,7 +533,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

             # 1. factors come from the vllm_config (it mainly summarizes how the
             #    model is created)
-            vllm_config = self.vllm_config
             config_hash = vllm_config.compute_hash()
 
             # 2. factors come from the code files that are traced by Dynamo (
@@ -556,20 +556,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             hash_key = hashlib.md5(
                 f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-        else:
-            cache_dir = self.compilation_config.cache_dir
+                envs.VLLM_CACHE_ROOT,
+                "torch_compile_cache",
+                hash_key,
+            )
+            self.compilation_config.cache_dir = cache_dir
+
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        local_cache_dir = os.path.join(
+            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        self.compilation_config.local_cache_dir = local_cache_dir
 
         disabled = envs.VLLM_DISABLE_COMPILE_CACHE
         self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            cache_dir, disabled=disabled)
+            local_cache_dir, disabled=disabled)
         if disabled:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
-                        cache_dir)
+                        local_cache_dir)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
@@ -609,6 +615,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
                                     self.vllm_config, self.graph_pool,
                                     self).run(*example_inputs)
 
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
+            # use `print_readable` because it can include submodules
+            src = "from __future__ import annotations\nimport torch\n" + \
+                self.split_gm.print_readable(print_output=False)
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+
+        logger.debug("Computation graph saved to %s", graph_path)
+
         self._called = True
 
         if not self.compilation_config.use_cudagraph or \
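For readers who want to see the mechanism outside of vLLM, the following is a minimal standalone sketch (not part of the commit) of the two pieces the backends.py hunks add: a short cache key derived from a config hash and a code hash, and a readable dump of an FX graph the way the backend dumps self.split_gm. TinyModel, the dummy hash strings, and the /tmp output path are made up for illustration; only the md5 recipe and the print_readable call mirror the diff.

import hashlib
import os

import torch
from torch import fx


def make_cache_key(config_hash: str, code_hash: str) -> str:
    # same recipe as the diff: md5 over both hashes, truncated to 10 chars
    return hashlib.md5(f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# symbolic_trace gives us a GraphModule, standing in for self.split_gm
gm: fx.GraphModule = fx.symbolic_trace(TinyModel())

key = make_cache_key("dummy-config-hash", "dummy-code-hash")
out_dir = os.path.join("/tmp/torch_compile_cache", key, "rank_0")
os.makedirs(out_dir, exist_ok=True)

# print_readable(print_output=False) returns the module source as a string,
# including submodules, which is why the commit prefers it over str(gm.graph)
src = "from __future__ import annotations\nimport torch\n" + \
    gm.print_readable(print_output=False)
with open(os.path.join(out_dir, "computation_graph.py"), "w") as f:
    f.write(src)

Running this leaves a readable computation_graph.py under a hash-keyed, per-rank directory, which is the layout the real backend now uses for its debug artifacts.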
vllm/compilation/decorators.py (2 additions, 0 deletions)
@@ -198,6 +198,8 @@ def __call__(self, *args, **kwargs):
f" {dims} for argument {k} with type {type(arg)}.")
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
logger.debug("Start compiling function %s",
self.original_code_object)

# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
vllm/compilation/wrapper.py (22 additions, 0 deletions)
@@ -9,6 +9,9 @@

 import vllm.envs as envs
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class TorchCompileWrapperWithCustomDispatcher:
@@ -82,6 +85,25 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
             return
 
         self.compiled_codes.append(new_code)
+        local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
+        if isinstance(local_cache_dir, str):
+            decompiled_file = os.path.join(local_cache_dir,
+                                           "transformed_code.py")
+            if not os.path.exists(decompiled_file):
+                try:
+                    # usually the decompilation will succeed for most models,
+                    # as we guarantee a full-graph compilation in Dynamo.
+                    # but there's no 100% guarantee, since decompilation is
+                    # not a reversible process.
+                    import depyf
+                    src = depyf.decompile(new_code)
+                    with open(decompiled_file, "w") as f:
+                        f.write(src)
+
+                    logger.debug("Dynamo transformed code saved to %s",
+                                 decompiled_file)
+                except Exception:
+                    pass
 
         if self.vllm_config.compilation_config.use_cudagraph and \
             "update" in new_code.co_names:
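As a rough illustration of the new bytecode hook's use of depyf, here is a minimal sketch that decompiles a code object back to source and writes it out the way transformed_code.py is written. The toy function and the /tmp path are hypothetical; only the depyf.decompile call mirrors the diff, and, as the diff's comment notes, decompilation is best-effort, so failures are swallowed here as well.

import os

import depyf  # third-party decompiler used by the diff


def toy(x, y):
    # stands in for the Dynamo-transformed code object passed to the hook
    return x * y + 1


out_path = os.path.join("/tmp", "transformed_code.py")
try:
    # bytecode -> source is not a reversible process, so this can fail
    src = depyf.decompile(toy.__code__)
    with open(out_path, "w") as f:
        f.write(src)
except Exception:
    pass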
vllm/config.py (1 addition, 0 deletions)
@@ -2785,6 +2785,7 @@ def model_post_init(self, __context: Any) -> None:
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
+    local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
     # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
     # since we know all keys are in a range [0, max_capture_size],
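Putting the new local_cache_dir field together with the backends.py change, the per-rank layout on disk looks roughly like the sketch below. It assumes the default VLLM_CACHE_ROOT of ~/.cache/vllm and uses a hypothetical hash key and rank; the file names are the ones written by this commit.

import os

cache_root = os.path.expanduser("~/.cache/vllm")  # assumed default VLLM_CACHE_ROOT
hash_key = "0123456789"                           # hypothetical 10-char md5 prefix
rank = 0                                          # hypothetical rank

# cache_dir is shared across ranks; local_cache_dir is per rank
cache_dir = os.path.join(cache_root, "torch_compile_cache", hash_key)
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}")

for name in ("computation_graph.py", "transformed_code.py"):
    path = os.path.join(local_cache_dir, name)
    print(path, "exists:", os.path.exists(path))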
