
[Core] Support fully transparent sleep mode #11743

Merged · youkaichao merged 79 commits into vllm-project:main from cumem on Jan 22, 2025

Conversation

@youkaichao (Member) commented Jan 5, 2025

There is a strong need from the RLHF community to put vLLM into a sleep mode (offload weights, discard KV cache); see #10714 and #11638.

I have implemented the core functionality in https://github.com/vllm-project/vllm_allocator_adaptor and added the integration in this PR.

Currently, we support:

  • offloading model weights
  • discarding the KV cache

NOTE: we do nothing for CUDA graph memory right now.

Because the underlying cumem API is very low-level, this PR is also compatible with CUDA graphs.
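As a usage sketch (hedged: `enable_sleep_mode`, `sleep()`, and `wake_up()` are the names suggested by this PR's description and may differ slightly in the merged API):

```python
from vllm import LLM

# Hedged sketch of the intended workflow; the flag and method names are taken
# from this PR's description and may differ in the final code.
llm = LLM(model="facebook/opt-125m", enable_sleep_mode=True)
print(llm.generate("San Francisco is a")[0].outputs[0].text)

llm.sleep(level=1)   # offload weights to CPU and discard the KV cache
# ... the freed GPU memory can now be used by another workload (e.g. an RLHF trainer) ...

llm.wake_up()        # bring the weights back and re-allocate the KV cache
print(llm.generate("San Francisco is a")[0].outputs[0].text)
```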

With this PR, when sleep mode is enabled:

the current vLLM instance can use total_gpu_memory (79.22GiB) x gpu_memory_utilization (0.90) = 71.29GiB
model weights take 2.32GiB; non_torch_memory takes 0.67GiB; PyTorch activation peak memory takes 1.20GiB; the rest of the memory reserved for KV Cache is 67.11GiB.

Free memory before sleep: 8.34 GiB
Free memory after sleep: 78.29 GiB
(~70 GiB of memory released, which is the sum of the model weights and the KV cache size)

Why is the free memory after sleep not 80 GiB? Because the CUDA graph memory pool is not released, and there are other costs such as the CUDA context.
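As a quick sanity check, the numbers above are consistent:

```python
# Reconciling the memory accounting quoted above (all values in GiB).
total, util = 79.22, 0.90
usable = total * util                                  # ~71.29
weights, non_torch, activation = 2.32, 0.67, 1.20
kv_cache = usable - weights - non_torch - activation   # ~67.10

released = weights + kv_cache                          # ~69.42, i.e. roughly 70
measured = 78.29 - 8.34                                # = 69.95 actually freed
# The remaining gap to 80 GiB is the CUDA graph pool, the CUDA context, and
# other overhead that sleep mode does not touch.
```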

TODO:

  • move https://github.com/vllm-project/vllm_allocator_adaptor into the vllm repo (needs someone familiar with CMake; note that I don't use any PyTorch ops in that file, so the binding should be easy; a rough sketch of the Python-side registration is shown after this list)
  • expose the interface in the API server, similar to the profiling endpoints (can be a follow-up PR, and we need to discuss how to expose these endpoints; they should live under some dev mode, with explicit opt-in)
  • drop the existing block manager's content when prefix caching is enabled (on second thought, I think this is out of scope for the current PR, but we do need a way to expose this API interface)
  • add the feature to V1
  • check whether any features are incompatible with this sleep mode. I expect it to just work on CUDA platforms, but that needs to be verified; if it does, we can enable it by default. (NOTE: I found some bugs in PyTorch's pluggable allocator, see "empty_cache does not work for CUDAPluggableAllocator + MemPool" pytorch/pytorch#145168. Therefore we should not enable it by default.)
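For context on the binding work in the first item: on the Python side, registering a custom CUDA allocator with PyTorch looks roughly like the following. This is only a sketch; the shared-library and symbol names are placeholders, and the real adaptor lives in the vllm_allocator_adaptor repo.

```python
import torch

# Hypothetical library/symbol names for illustration; the actual ones come from
# vllm_allocator_adaptor. The .so must export C functions with these signatures:
#   void* my_malloc(ssize_t size, int device, cudaStream_t stream);
#   void  my_free(void* ptr, ssize_t size, int device, cudaStream_t stream);
allocator = torch.cuda.memory.CUDAPluggableAllocator(
    "libvllm_allocator_adaptor.so", "my_malloc", "my_free"
)

# Swapping the process-wide allocator must happen before the first CUDA allocation.
torch.cuda.memory.change_current_allocator(allocator)
```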

github-actions bot commented Jan 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao (Member, Author)

Thanks to @cennn, who helped me a lot along the way.

@youkaichao (Member, Author)

This PR benefits a lot from pytorch/pytorch#131152 and pytorch/pytorch#124807.
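For readers unfamiliar with those PyTorch changes: together they make it possible to route only selected allocations through a custom allocator by attaching it to a memory pool, instead of replacing the process-wide allocator. A hedged sketch, assuming a recent PyTorch build that exposes `MemPool`/`use_mem_pool` and reusing the `allocator` object from the earlier sketch:

```python
import torch

# Hedged sketch: only allocations made inside the context go through the custom
# allocator; everything else (e.g. CUDA graph memory) keeps using the default
# caching allocator. Assumes a PyTorch build with MemPool / use_mem_pool and the
# CUDAPluggableAllocator `allocator` created in the earlier sketch.
pool = torch.cuda.memory.MemPool(allocator.allocator())
with torch.cuda.memory.use_mem_pool(pool):
    weights = torch.empty(1 << 20, device="cuda")  # routed through the custom allocator
# Later, unmapping the pool's physical memory (via the cumem driver API) releases
# exactly these allocations while leaving their virtual addresses intact.
```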

@youkaichao (Member, Author)

TODO: in distributed inference, there is also NCCL memory to consider. We need to check how much memory it takes and whether we need to do anything to release it (which might be quite difficult, as NCCL is quite a black box).
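A rough way to estimate the NCCL share (hedged sketch; the delta also includes any other non-torch allocations made around initialization, so treat it as an upper bound):

```python
import torch
import torch.distributed as dist

# Rough estimate of GPU memory held outside the torch allocator (NCCL buffers,
# communicator state, ...): compare device-wide free memory around NCCL setup.
# Assumes the usual torchrun/env:// setup (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).
torch.cuda.init()
free_before, _ = torch.cuda.mem_get_info()

dist.init_process_group(backend="nccl")
dist.all_reduce(torch.ones(1, device="cuda"))  # warm up the communicator
torch.cuda.synchronize()

free_after, _ = torch.cuda.mem_get_info()
print(f"~{(free_before - free_after) / 2**30:.2f} GiB taken around NCCL init")
```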

@youkaichao youkaichao marked this pull request as draft January 18, 2025 12:34
mergify bot commented Jan 18, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @youkaichao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 18, 2025
@mergify mergify bot removed the needs-rebase label Jan 18, 2025
@mergify mergify bot removed the needs-rebase label Jan 20, 2025
@comaniac (Collaborator) commented Jan 20, 2025

Did a quick pass and overall LGTM. A high-level question: would there be a correctness issue with prefix caching?

If prefix caching is enabled and we sleep with level 2 (discard the weights), there will be a correctness issue. That's why I'm asking for help to reset the prefix-caching state.

Makes sense. Will provide the reset capability for prefix caching.

BTW, in the meantime, please raise an error in this PR if sleep mode level 2 is used while prefix caching is enabled, so that it can be unblocked.

@youkaichao (Member, Author)

> BTW, in the meantime, please raise an error in this PR if sleep mode level 2 is used while prefix caching is enabled, so that it can be unblocked.

@comaniac fixed in b371bf3.
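For reference, the guard amounts to something like the following hypothetical sketch (the names and placement are illustrative; see b371bf3 for the actual check):

```python
def check_sleep_mode_compatible(level: int, enable_prefix_caching: bool) -> None:
    # Hypothetical validation sketch: level 2 discards the model weights, so
    # prefix-cache entries computed with the old weights would be silently stale
    # after waking up with different weights.
    if level == 2 and enable_prefix_caching:
        raise ValueError(
            "Sleep mode level 2 is not compatible with prefix caching; "
            "disable prefix caching or use sleep level 1."
        )
```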

mergify bot commented Jan 21, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @youkaichao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 21, 2025
@comaniac (Collaborator) left a comment

Otherwise LGTM. Approving to unblock first, but others' comments are welcome.

Review comments (now resolved, on outdated code) were left on: vllm/config.py, vllm/device_allocator/cumem.py (several), vllm/entrypoints/llm.py, vllm/executor/executor_base.py
@WoosukKwon (Collaborator) left a comment

LGTM from skimming through the PR. Since this PR is quite isolated, I don't have any concerns about merging it.

Will take a closer look in a few days once I have more bandwidth.

@youkaichao (Member, Author)

> Otherwise LGTM. Approving to unblock first, but others' comments are welcome.

@comaniac thanks for the very detailed comments! Looking forward to #12284 being merged after this PR.

@youkaichao youkaichao merged commit 68ad4e3 into vllm-project:main Jan 22, 2025
6 of 9 checks passed
@youkaichao youkaichao deleted the cumem branch January 22, 2025 06:39