[torch.compile] transparent compilation with more logging #12246

Merged (8 commits) on Jan 21, 2025
Changes from 5 commits
32 changes: 25 additions & 7 deletions vllm/compilation/backends.py
@@ -524,6 +524,7 @@ def configure_post_pass(self):

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

vllm_config = self.vllm_config
if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affect the compilation. if none of the factors change,
@@ -532,7 +533,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
vllm_config = self.vllm_config
config_hash = vllm_config.compute_hash()

# 2. factors come from the code files that are traced by Dynamo (
@@ -556,20 +556,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
hash_key = hashlib.md5(
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
else:
cache_dir = self.compilation_config.cache_dir
envs.VLLM_CACHE_ROOT,
"torch_compile_cache",
hash_key,
)
self.compilation_config.cache_dir = cache_dir

cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)
local_cache_dir = os.path.join(
cache_dir, f"rank_{vllm_config.parallel_config.rank}")
self.compilation_config.local_cache_dir = local_cache_dir

disabled = envs.VLLM_DISABLE_COMPILE_CACHE
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
cache_dir, disabled=disabled)
local_cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info("Using cache directory: %s for vLLM's torch.compile",
cache_dir)
local_cache_dir)

# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
@@ -609,6 +615,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
self.vllm_config, self.graph_pool,
self).run(*example_inputs)

graph_path = os.path.join(local_cache_dir, "computation_graph.py")
if not os.path.exists(graph_path):
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
# use `print_readable` because it can include submodules
src = "from __future__ import annotations\nimport torch\n" + \
self.split_gm.print_readable(print_output=False)
src = src.replace("<lambda>", "GraphModule")
with open(graph_path, "w") as f:
f.write(src)

logger.info("Computation graph saved to %s", graph_path)

Collaborator:
I think we do not need logger.info here, as most users do not need to be aware of this step. The file structure in local_cache_dir can be explained in the document.

Member Author:
thanks for the suggestion! moved to debug instead.


self._called = True

if not self.compilation_config.use_cudagraph or \
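
For context on the new dump step above, here is a minimal, self-contained sketch (not vLLM code; ToyModel and out_path are illustrative) of the mechanism backends.py relies on to produce computation_graph.py: fx.GraphModule.print_readable(print_output=False) returns the generated source, including submodules, so it can be written to a file and inspected directly.

import os

import torch
from torch import fx


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# symbolic_trace yields an fx.GraphModule, playing the role of split_gm above
gm: fx.GraphModule = fx.symbolic_trace(ToyModel())

out_path = os.path.join("/tmp", "computation_graph.py")
src = ("from __future__ import annotations\nimport torch\n"
       + gm.print_readable(print_output=False))
# fx sometimes names the generated class "<lambda>"; rename it so the dump
# stays valid, importable Python
src = src.replace("<lambda>", "GraphModule")
with open(out_path, "w") as f:
    f.write(src)

Reading the written file shows exactly what torch.compile sees, which is the transparency this PR is after.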
2 changes: 2 additions & 0 deletions vllm/compilation/decorators.py
@@ -198,6 +198,8 @@ def __call__(self, *args, **kwargs):
f" {dims} for argument {k} with type {type(arg)}.")
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
logger.info("Start compiling function %s",
self.original_code_object)

# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
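
As a quick illustration of what the new log line surfaces (a standalone sketch, not vLLM code): logging a code object prints its repr, which identifies the function name, source file, and line being compiled.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example")


def forward(x):
    return x


# Prints something like:
#   INFO:example:Start compiling function <code object forward at 0x7f..., file "...", line 8>
logger.info("Start compiling function %s", forward.__code__)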
20 changes: 20 additions & 0 deletions vllm/compilation/wrapper.py
@@ -9,6 +9,9 @@

import vllm.envs as envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.logger import init_logger

logger = init_logger(__name__)


class TorchCompileWrapperWithCustomDispatcher:
@@ -82,6 +85,23 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
return

self.compiled_codes.append(new_code)
local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
decompiled_file = os.path.join(local_cache_dir, "transformed_code.py")
if not os.path.exists(decompiled_file):
try:
# usually the decompilation will succeed for most models, as
# we guarantee a full-graph compilation in Dynamo.
# but there's no 100% guarantee, since decompilation is not a
# reversible process.
import depyf
src = depyf.decompile(new_code)
with open(decompiled_file, "w") as f:
f.write(src)

logger.info("Dynamo transformed code saved to %s",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in fd753d1

decompiled_file)
except Exception:
pass

if self.vllm_config.compilation_config.use_cudagraph and \
"update" in new_code.co_names:
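
For reference, a minimal standalone sketch of the depyf decompilation used in bytecode_hook above (the add function is illustrative; as the code comment notes, decompilation is best-effort, hence the broad try/except):

import depyf


def add(a, b):
    return a + b


try:
    # depyf.decompile turns a code object back into approximate source text
    src = depyf.decompile(add.__code__)
    print(src)
except Exception:
    # decompilation is not guaranteed to succeed
    pass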
1 change: 1 addition & 0 deletions vllm/config.py
@@ -2785,6 +2785,7 @@ def model_post_init(self, __context: Any) -> None:
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
max_capture_size: int = PrivateAttr
local_cache_dir: str = PrivateAttr # local cache dir for each rank
# optimization:
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
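
Taken together, these changes give each rank its own artifact directory under the shared torch.compile cache, which is what the new local_cache_dir field tracks. Assuming the default VLLM_CACHE_ROOT (typically ~/.cache/vllm) and a two-rank run, the resulting layout looks roughly like this (illustrative, not an exhaustive listing; the inductor hash cache also lives in each rank directory):

~/.cache/vllm/torch_compile_cache/<hash_key>/
    rank_0/
        computation_graph.py   # readable fx graph, written by backends.py
        transformed_code.py    # depyf-decompiled Dynamo bytecode, from wrapper.py
    rank_1/
        ...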