-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torch.compile] transparent compilation with more logging #12246
Merged
Merged
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
44a6d6d
log transformed bytecode
youkaichao 9e77f98
log computation graph
youkaichao 12dc138
reduce logging
youkaichao 95893aa
add start log
youkaichao c25cd1d
polish comments
youkaichao ed404d9
fix no cache dir case
youkaichao f4c434d
use debug
youkaichao fd753d1
use debug
youkaichao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,9 @@ | |
|
||
import vllm.envs as envs | ||
from vllm.config import CompilationLevel, get_current_vllm_config | ||
from vllm.logger import init_logger | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class TorchCompileWrapperWithCustomDispatcher: | ||
|
@@ -82,6 +85,23 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): | |
return | ||
|
||
self.compiled_codes.append(new_code) | ||
local_cache_dir = self.vllm_config.compilation_config.local_cache_dir | ||
decompiled_file = os.path.join(local_cache_dir, "transformed_code.py") | ||
if not os.path.exists(decompiled_file): | ||
try: | ||
# usually the decompilation will succeed for most models, as | ||
# we guarantee a full-graph compilation in Dynamo. | ||
# but there's no 100% guarantee, since decompliation is not a | ||
# reversible process. | ||
import depyf | ||
src = depyf.decompile(new_code) | ||
with open(decompiled_file, "w") as f: | ||
f.write(src) | ||
|
||
logger.info("Dynamo transformed code saved to %s", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same as above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed in fd753d1 |
||
decompiled_file) | ||
except Exception: | ||
pass | ||
|
||
if self.vllm_config.compilation_config.use_cudagraph and \ | ||
"update" in new_code.co_names: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we do not need logger.info here, as most users do not need to be aware of this step. The file structure in
local_cache_dir
can be explained in the document.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the suggestion! moved to
debug
instead.