diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ed8c15358830c..daec46760117d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -76,7 +76,9 @@ steps: - tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_cpu_offload - tests/basic_correctness/test_preemption + - tests/basic_correctness/test_cumem.py commands: + - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py diff --git a/CMakeLists.txt b/CMakeLists.txt index f4b9c3ec9c14f..0945905104f32 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,6 +181,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") # Define other extension targets # +# +# cumem_allocator extension +# + +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling cumem allocator extension.") + # link against cuda driver library + list(APPEND CUMEM_LIBS cuda) + define_gpu_extension_target( + cumem_allocator + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_CUMEM_EXT_SRC} + LIBRARIES ${CUMEM_LIBS} + USE_SABI 3.8 + WITH_SOABI) +endif() + # # _C extension # diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp new file mode 100644 index 0000000000000..e8555d853b7ac --- /dev/null +++ b/csrc/cumem_allocator.cpp @@ -0,0 +1,310 @@ +// A CUDAPluggableAllocator based on cumem* APIs. +// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* +// need to be unsigned long long +#include + +extern "C" { + +#define PY_SSIZE_T_CLEAN +#include + +#include +#include +#include + +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; \ + } \ + } while (0) + +// Global references to Python callables +// NOTE: this is borrowed reference, so we don't need to DECREF them. +// This brings the limitation that the allocator needs to be singleton. +static PyObject* g_python_malloc_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; + +// --------------------------------------------------------------------------- +// Helper functions: + +void ensure_context(unsigned long long device) { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + ensure_context(device); + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Allocate memory using cuMemCreate + CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); + + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + // std::cout << "create_and_map: device=" << device << ", size=" << size << ", + // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; +} + +void unmap_and_release(unsigned long long device, ssize_t size, + CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + // std::cout << "unmap_and_release: device=" << device << ", size=" << size << + // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; + ensure_context(device); + CUDA_CHECK(cuMemUnmap(d_mem, size)); + CUDA_CHECK(cuMemRelease(*p_memHandle)); +} + +PyObject* create_tuple_from_c_integers(unsigned long long a, + unsigned long long b, + unsigned long long c, + unsigned long long d) { + // Create a new tuple of size 4 + PyObject* tuple = PyTuple_New(4); + if (!tuple) { + return NULL; // Return NULL on failure + } + + // Convert integers to Python objects and set them in the tuple + PyTuple_SetItem( + tuple, 0, + PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. + + return tuple; // Return the created tuple +} + +// --------------------------------------------------------------------------- +// Our exported C functions that call Python: + +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h +void* my_malloc(ssize_t size, int device, CUstream stream) { + ensure_context(device); + + // first allocation, align the size, and reserve an address, and also allocate + // a CUmemGenericAllocationHandle + + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Check if the allocation is supported + size_t granularity; + CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; + + CUdeviceptr d_mem; + CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); + + // allocate the CUmemGenericAllocationHandle + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)malloc( + sizeof(CUmemGenericAllocationHandle)); + + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* arg_tuple = create_tuple_from_c_integers( + (unsigned long long)device, (unsigned long long)alignedSize, + (unsigned long long)d_mem, (unsigned long long)p_memHandle); + + // Call g_python_malloc_callback + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + + if (!py_result) { + PyErr_Print(); + PyGILState_Release(gstate); + return nullptr; + } + + PyGILState_Release(gstate); + + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; +} + +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h +void my_free(void* ptr, ssize_t size, int device, CUstream stream) { + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* py_ptr = + PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, + &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; + } + + PyGILState_Release(gstate); + + // recv_size == size + // recv_device == device + + // Free memory + + CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); + + // free address and the handle + CUDA_CHECK(cuMemAddressFree(d_mem, size)); + free(p_memHandle); +} + +// --------------------------------------------------------------------------- +// Python extension boilerplate: + +// Python-exposed function: init_module(python_malloc, python_free) +static PyObject* py_init_module(PyObject* self, PyObject* args) { + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does not handle GC of these objects, so they must be kept alive + // outside of this module. + g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; +} + +static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyObject* python_create_and_map(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyMethodDef module_methods[] = { + {"init_module", (PyCFunction)py_init_module, METH_VARARGS, + "Initialize module with python_malloc and python_free callables."}, + {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, + "Create and map memory on the device."}, + {"python_unmap_and_release", (PyCFunction)python_unmap_and_release, + METH_VARARGS, "Unmap and release memory on the device."}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef cumem_allocator_module = { + PyModuleDef_HEAD_INIT, "cumem_allocator", + "cumem-based allocator for CUDAPluggableAllocator", -1, module_methods}; + +PyMODINIT_FUNC PyInit_cumem_allocator(void) { + // Initialize the module + PyObject* module = PyModule_Create(&cumem_allocator_module); + if (!module) { + return NULL; + } + return module; +} +} // extern "C" diff --git a/setup.py b/setup.py index 978625a069778..8801705f26dad 100644 --- a/setup.py +++ b/setup.py @@ -301,6 +301,7 @@ def run(self) -> None: "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", "vllm/vllm_flash_attn/__init__.py", + "vllm/cumem_allocator.abi3.so", # "vllm/_version.py", # not available in nightly wheels yet ] file_members = filter(lambda x: x.filename in files_to_copy, @@ -594,6 +595,7 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c")) + ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py new file mode 100644 index 0000000000000..53f4ef08f36a2 --- /dev/null +++ b/tests/basic_correctness/test_cumem.py @@ -0,0 +1,112 @@ +import torch + +from vllm import LLM, SamplingParams +from vllm.device_allocator.cumem import CuMemAllocator +from vllm.utils import GiB_bytes + +from ..utils import fork_new_process_for_each_test + + +@fork_new_process_for_each_test +def test_basic_cumem(): + # some tensors from default memory pool + shape = (1024, 1024) + x = torch.empty(shape, device='cuda') + x.zero_() + + # some tensors from custom memory pool + allocator = CuMemAllocator.get_instance() + with allocator.use_memory_pool(): + # custom memory pool + y = torch.empty(shape, device='cuda') + y.zero_() + y += 1 + z = torch.empty(shape, device='cuda') + z.zero_() + z += 2 + + # they can be used together + output = x + y + z + assert torch.allclose(output, torch.ones_like(output) * 3) + + free_bytes = torch.cuda.mem_get_info()[0] + allocator.sleep() + free_bytes_after_sleep = torch.cuda.mem_get_info()[0] + assert free_bytes_after_sleep > free_bytes + allocator.wake_up() + + # they can be used together + output = x + y + z + assert torch.allclose(output, torch.ones_like(output) * 3) + + +@fork_new_process_for_each_test +def test_cumem_with_cudagraph(): + allocator = CuMemAllocator.get_instance() + with allocator.use_memory_pool(): + weight = torch.eye(1024, device='cuda') + with allocator.use_memory_pool(tag="discard"): + cache = torch.empty(1024, 1024, device='cuda') + + def model(x): + out = x @ weight + cache[:out.size(0)].copy_(out) + return out + 1 + + x = torch.empty(128, 1024, device='cuda') + + # warmup + model(x) + + # capture cudagraph + model_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(model_graph): + y = model(x) + + free_bytes = torch.cuda.mem_get_info()[0] + allocator.sleep() + free_bytes_after_sleep = torch.cuda.mem_get_info()[0] + assert free_bytes_after_sleep > free_bytes + allocator.wake_up() + + # after waking up, the content in the weight tensor + # should be restored, but the content in the cache tensor + # should be discarded + + # this operation is also compatible with cudagraph + + x.random_() + model_graph.replay() + + # cache content is as expected + assert torch.allclose(x, cache[:x.size(0)]) + + # output content is as expected + assert torch.allclose(y, x + 1) + + +@fork_new_process_for_each_test +def test_end_to_end(): + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage, + # which is difficult to measure in the test. therefore, we only + # test sleep level 1 here. + llm.sleep(level=1) + + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + # now the memory usage is mostly cudagraph memory pool, + # and it should be less than the model weights (1B model, 2GiB weights) + assert used_bytes < 2 * GiB_bytes + + llm.wake_up() + output2 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text diff --git a/vllm/config.py b/vllm/config.py index b8628db4d2b80..69577505fc9bd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -195,40 +195,43 @@ def compute_hash(self) -> str: factors.append(self.rope_theta) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__(self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: Optional[str] = None) -> None: + def __init__( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None, + enable_sleep_mode: bool = False, + ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -277,6 +280,12 @@ def __init__(self, self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init + self.enable_sleep_mode = enable_sleep_mode + + from vllm.platforms import current_platform + + if self.enable_sleep_mode and not current_platform.is_cuda(): + raise ValueError("Sleep mode is only supported on CUDA devices.") hf_config = get_config(self.model, trust_remote_code, revision, code_revision, config_format) @@ -348,7 +357,6 @@ def __init__(self, self.is_hybrid = self._init_is_hybrid() self.has_inner_state = self._init_has_inner_state() - from vllm.platforms import current_platform if current_platform.is_neuron(): self.override_neuron_config = override_neuron_config else: diff --git a/vllm/device_allocator/__init__.py b/vllm/device_allocator/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py new file mode 100644 index 0000000000000..a43418dbb3b46 --- /dev/null +++ b/vllm/device_allocator/cumem.py @@ -0,0 +1,254 @@ +# cumem-based pytorch pluggable allocator to implement sleep mode. +# other approaches tried but failed: +# - cuda-python package binding +# - custom libcuda driver ctypes wrapper +# both of them failed because of cuda context mismatch. +# not sure why, they are created from a different context. +# the only successful approach is to call cuda driver API in C. +import dataclasses +from contextlib import contextmanager +from typing import Callable, Dict, Optional, Tuple, Union + +import torch + +from vllm.utils import is_pin_memory_available + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +cumem_available = False +try: + from vllm.cumem_allocator import (init_module, python_create_and_map, + python_unmap_and_release) + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib_name = find_loaded_library("cumem_allocator") + libcudart = CudaRTLibrary() + cumem_available = True +except ModuleNotFoundError: + # rocm platform does not support cumem allocator + init_module = None + python_create_and_map = None + python_unmap_and_release = None + CudaRTLibrary = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = Tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[int], + int], python_free_func: Callable[[int, int], + None] +) -> torch.cuda.memory.CUDAPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.cuda.memory.CUDAPluggableAllocator( + lib_name, 'my_malloc', 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[int], int], + python_free_func: Callable[[int, int], None]) -> None: + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) + with torch.cuda.memory.use_mem_pool(mem_pool): + yield mem_pool + + +class CuMemAllocator: + """ + A singleton class that manages a memory pool for CUDA tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance: "CuMemAllocator" = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CuMemAllocator": + """ + CuMemAllocator is a singleton class. + We cannot call the constructor directly. + Call this method to get the instance. + """ + assert cumem_available, "cumem allocator is not available" + if CuMemAllocator.instance is None: + CuMemAllocator.instance = CuMemAllocator() + return CuMemAllocator.instance + + def __init__(self): + self.pointer_to_data: Dict[int, AllocationData] = {} + self.current_tag: str = CuMemAllocator.default_tag + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) + return + + def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep( + self, + offload_tags: Optional[Union[Tuple[str, ...], + str]] = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CuMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + + assert isinstance(offload_tags, tuple) + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + def wake_up(self): + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory.""" + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. + All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. + + :param tag: The tag of the memory allocation. If None, the default tag + will be used. + """ + if tag is None: + tag = CuMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback): + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4f4c9558d056..ba58614bf8f95 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -197,6 +197,7 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None generation_config: Optional[str] = None + enable_sleep_mode: bool = False def __post_init__(self): if not self.tokenizer: @@ -955,6 +956,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "loaded from model. If set to a folder path, the generation config " "will be loaded from the specified folder path.") + parser.add_argument("--enable-sleep-mode", + action="store_true", + default=False, + help="Enable sleep mode for the engine. " + "(only cuda platform is supported)") + return parser @classmethod @@ -999,7 +1006,9 @@ def create_model_config(self) -> ModelConfig: override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, - generation_config=self.generation_config) + generation_config=self.generation_config, + enable_sleep_mode=self.enable_sleep_mode, + ) def create_load_config(self) -> LoadConfig: return LoadConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dc6c9c1778dd7..5271027d41863 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1818,6 +1818,16 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.model_executor.stop_profile() + def sleep(self, level: int = 1) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.sleep(level=level) + + def wake_up(self) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.wake_up() + def check_health(self) -> None: if self.tokenizer: self.tokenizer.check_health() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 27386daa4bbc9..04056f37f851b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1132,6 +1132,29 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() + def sleep(self, level: int = 1): + """ + Put the engine to sleep. The engine should not process any requests. + The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. + + :param level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. The content of kv cache is + forgotten. Level 1 sleep is good for sleeping and waking up the + engine to run the same model again. The model weights are backed + up in CPU memory. Please make sure there's enough CPU memory to + store the model weights. Level 2 sleep will discard both the model + weights and the kv cache. The content of both the model weights + and kv cache is forgotten. Level 2 sleep is good for sleeping and + waking up the engine to run a different model or update the model, + where previous model weights are not needed. It reduces CPU memory + pressure. + """ + self.llm_engine.sleep(level=level) + + def wake_up(self): + self.llm_engine.wake_up() + # LEGACY def _convert_v1_inputs( self, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 859e105f15d97..069be05eafa50 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -193,6 +193,17 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.collective_rpc("stop_profile") + def sleep(self, level: int = 1): + if self.cache_config.enable_prefix_caching: + # TODO: support sleep with prefix caching + # by resetting the prefix cache state, + # after https://github.com/vllm-project/vllm/pull/12284 + raise ValueError("Cannot sleep when prefix caching is enabled.") + self.collective_rpc("sleep", kwargs=dict(level=level)) + + def wake_up(self): + self.collective_rpc("wake_up") + def save_sharded_state( self, path: str, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index bd40112aea5e8..e69b6c90454c4 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -9,12 +9,14 @@ import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig +from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.utils import GiB_bytes from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput @@ -77,6 +79,23 @@ def __init__( else: self.profiler = None + def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) + + def wake_up(self) -> None: + allocator = CuMemAllocator.get_instance() + allocator.wake_up() + def init_device(self): if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -110,7 +129,17 @@ def init_device(self): self.model_runner = GPUModelRunner(self.vllm_config, self.device) def load_model(self) -> None: - self.model_runner.load_model() + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag="weights") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.load_model() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -167,7 +196,14 @@ def get_kv_cache_spec(self) -> KVCacheSpec: def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - self.model_runner.initialize_kv_cache(kv_cache_config) + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 29d62ddda3dc0..c95d2a9a69ff3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_kv_transfer_initialized, ensure_model_parallel_initialized, init_distributed_environment, @@ -120,6 +121,23 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() + def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) + + def wake_up(self) -> None: + allocator = CuMemAllocator.get_instance() + allocator.wake_up() + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -151,7 +169,17 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - self.model_runner.load_model() + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag="weights") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.load_model() def save_sharded_state( self, @@ -270,7 +298,14 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self._init_cache_engine() + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self._init_cache_engine() self._warm_up_model() def _init_cache_engine(self):