Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support fully transparent sleep mode #11743

Merged
merged 79 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
814095e
add code
youkaichao Jan 4, 2025
d559772
add basic tests
youkaichao Jan 5, 2025
5189a29
add basic tests
youkaichao Jan 5, 2025
d6c1bb9
fix tests
youkaichao Jan 5, 2025
d00a99f
fix tests
youkaichao Jan 5, 2025
e18b239
add cudagraph tests
youkaichao Jan 5, 2025
31bc20e
add test code
youkaichao Jan 5, 2025
69262bb
enable sleeping mode for user
youkaichao Jan 5, 2025
88bec78
add end to end experiments
youkaichao Jan 5, 2025
c3d845e
fix
youkaichao Jan 5, 2025
921b848
update
youkaichao Jan 5, 2025
09d624c
add tests
youkaichao Jan 5, 2025
59fbf5c
pin version
youkaichao Jan 5, 2025
39b6fa5
avoid interference
youkaichao Jan 5, 2025
8245114
Merge branch 'main' into cumem
youkaichao Jan 18, 2025
1c3fed0
update
youkaichao Jan 18, 2025
c5b207d
update
youkaichao Jan 18, 2025
873d853
add in executor base
youkaichao Jan 18, 2025
1e5798a
reduce diff
youkaichao Jan 18, 2025
2719274
format
youkaichao Jan 18, 2025
20fbbc3
also support v1
youkaichao Jan 18, 2025
20d6876
fix linter
youkaichao Jan 18, 2025
d8c9874
add csrc
youkaichao Jan 18, 2025
f2539d3
try to update cmake
youkaichao Jan 18, 2025
ae3ddd9
try to update cmake
youkaichao Jan 18, 2025
2f16a8a
try to update cmake
youkaichao Jan 18, 2025
0384305
full extern c
youkaichao Jan 18, 2025
c928912
fix iostream
youkaichao Jan 18, 2025
2287a4f
use cxx
youkaichao Jan 18, 2025
d707238
use abi
youkaichao Jan 18, 2025
25ba886
use abi
youkaichao Jan 18, 2025
ac1beff
use abi
youkaichao Jan 18, 2025
7146925
add so to precompiled list
youkaichao Jan 18, 2025
962ee15
port files
youkaichao Jan 18, 2025
6b783e0
fix dependency
youkaichao Jan 18, 2025
1ff95be
add libs
youkaichao Jan 18, 2025
be845df
fix stream
youkaichao Jan 18, 2025
ae8c52e
comment
youkaichao Jan 18, 2025
de75d23
add comments
youkaichao Jan 18, 2025
b6f227c
update links
youkaichao Jan 19, 2025
d426272
consider rocm
youkaichao Jan 19, 2025
8d37273
use tag
youkaichao Jan 19, 2025
44cf2db
update tests
youkaichao Jan 19, 2025
f9d3983
cmake comments
youkaichao Jan 19, 2025
6b23d17
rename to sleep mode
youkaichao Jan 19, 2025
1b33768
msg
youkaichao Jan 19, 2025
1ad644f
fix
youkaichao Jan 19, 2025
397630e
remove tests
youkaichao Jan 19, 2025
5936493
fix
youkaichao Jan 19, 2025
2827e02
comment
youkaichao Jan 19, 2025
94c4e83
add logging
youkaichao Jan 19, 2025
4c28f27
Merge branch 'main' into cumem
youkaichao Jan 19, 2025
6f48e8a
fix initialize_cache
youkaichao Jan 19, 2025
66ff900
fix load_model
youkaichao Jan 19, 2025
7bee39d
fix
youkaichao Jan 19, 2025
7763332
fix comments
youkaichao Jan 19, 2025
a45dcd9
use ModuleNotFoundError
youkaichao Jan 20, 2025
60a2f50
fix get_current_usage
youkaichao Jan 20, 2025
1d7edcf
add doc string for functions
youkaichao Jan 20, 2025
d172402
add comments
youkaichao Jan 20, 2025
95da432
tuple of str
youkaichao Jan 20, 2025
9c94517
comments
youkaichao Jan 20, 2025
f4cc888
Merge branch 'main' into cumem
youkaichao Jan 20, 2025
4d6177a
fix?
youkaichao Jan 20, 2025
a1c5634
fix?
youkaichao Jan 20, 2025
13b2213
Merge branch 'main' into cumem
youkaichao Jan 20, 2025
b371bf3
disable level 2 with prefix caching
youkaichao Jan 20, 2025
5ff5423
Merge branch 'main' into cumem
youkaichao Jan 22, 2025
cbdbcea
use ValueError
youkaichao Jan 22, 2025
4388f0f
doc string for sleep
youkaichao Jan 22, 2025
d1991e5
polish type assert
youkaichao Jan 22, 2025
daf3169
doc string for use_memory_pool
youkaichao Jan 22, 2025
23ee3ad
polish type assert
youkaichao Jan 22, 2025
a61f473
docstring for sleep
youkaichao Jan 22, 2025
a626d63
error for prefix caching
youkaichao Jan 22, 2025
7414e0c
format
youkaichao Jan 22, 2025
d378a08
format
youkaichao Jan 22, 2025
900c257
use found_line
youkaichao Jan 22, 2025
53bce8a
robust tests
youkaichao Jan 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ steps:
- tests/basic_correctness/test_basic_correctness
- tests/basic_correctness/test_cpu_offload
- tests/basic_correctness/test_preemption
- tests/basic_correctness/test_cumem.py
commands:
- pytest -v -s basic_correctness/test_cumem.py
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
Expand Down
25 changes: 25 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
# Define other extension targets
#

#
# cumem_allocator extension
#

# Source file(s) for the cumem_allocator Python extension.
set(VLLM_CUMEM_EXT_SRC
    "csrc/cumem_allocator.cpp")

# NOTE(review): this call runs unconditionally, outside the CUDA-only guard
# below, and targets a plain .cpp source built with LANGUAGE CXX — confirm
# whether gencode flags are actually needed here or if this is vestigial.
set_gencode_flags_for_srcs(
  SRCS "${VLLM_CUMEM_EXT_SRC}"
  CUDA_ARCHS "${CUDA_ARCHS}")

if(VLLM_GPU_LANG STREQUAL "CUDA")
  message(STATUS "Enabling cumem allocator extension.")
  # link against cuda driver library
  # (libcuda, the driver API: cuMemCreate / cuMemMap / etc., not the runtime)
  list(APPEND CUMEM_LIBS cuda)
  define_gpu_extension_target(
    cumem_allocator
    DESTINATION vllm
    LANGUAGE CXX      # compiled as plain C++, no device code
    SOURCES ${VLLM_CUMEM_EXT_SRC}
    LIBRARIES ${CUMEM_LIBS}
    USE_SABI 3.8      # build against the Python 3.8 stable ABI (abi3 wheel)
    WITH_SOABI)
endif()

#
# _C extension
#
Expand Down
310 changes: 310 additions & 0 deletions csrc/cumem_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
// A CUDAPluggableAllocator based on cumem* APIs.
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
// need to be unsigned long long
#include <iostream>

extern "C" {

#define PY_SSIZE_T_CLEAN
#include <Python.h>

#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>

#define CUDA_CHECK(condition) \
do { \
CUresult error = condition; \
if (error != 0) { \
char* error_string; \
cuGetErrorString(error, (const char**)&error_string); \
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl; \
} \
} while (0)

// Global references to Python callables
// NOTE: this is borrowed reference, so we don't need to DECREF them.
// This brings the limitation that the allocator needs to be singleton.
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;

// ---------------------------------------------------------------------------
// Helper functions:

void ensure_context(unsigned long long device) {
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
// Ensure device context.
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}
}

// Back an already-reserved virtual address range `d_mem` with physical
// memory: create a new allocation in *p_memHandle, map it at d_mem, and
// grant read/write access from `device`. `size` must match the size used
// for the address reservation (assumed granularity-aligned by the caller —
// see my_malloc).
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
                    CUmemGenericAllocationHandle* p_memHandle) {
  ensure_context(device);  // driver API calls require a current context
  // Define memory allocation properties
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;
  prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;

  // Allocate memory using cuMemCreate
  CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
  CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));

  // Mappings carry no access rights by default; grant RW from this device.
  CUmemAccessDesc accessDesc = {};
  accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  accessDesc.location.id = device;
  accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

  CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
  // std::cout << "create_and_map: device=" << device << ", size=" << size << ",
  // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
}

// Inverse of create_and_map(): unmap the range and release the physical
// allocation. The virtual address reservation itself is NOT freed here, so
// the range can be re-mapped later; the final cuMemAddressFree happens in
// my_free.
void unmap_and_release(unsigned long long device, ssize_t size,
                       CUdeviceptr d_mem,
                       CUmemGenericAllocationHandle* p_memHandle) {
  // std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
  // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
  ensure_context(device);  // driver API calls require a current context
  CUDA_CHECK(cuMemUnmap(d_mem, size));
  CUDA_CHECK(cuMemRelease(*p_memHandle));
}

// Pack four unsigned 64-bit integers into a new 4-tuple of Python ints.
// Returns a new reference, or NULL (with a Python exception set) on failure.
PyObject* create_tuple_from_c_integers(unsigned long long a,
                                       unsigned long long b,
                                       unsigned long long c,
                                       unsigned long long d) {
  // Py_BuildValue with the "K" format code converts each unsigned long long
  // and assembles the tuple, cleaning up correctly if any intermediate
  // conversion fails. The previous hand-rolled PyTuple_New/PyTuple_SetItem
  // version never checked whether PyLong_FromUnsignedLongLong returned NULL,
  // so an allocation failure could produce a tuple with a NULL slot.
  return Py_BuildValue("(KKKK)", a, b, c, d);
}

// ---------------------------------------------------------------------------
// Our exported C functions that call Python:

// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
void* my_malloc(ssize_t size, int device, CUstream stream) {
ensure_context(device);

// first allocation, align the size, and reserve an address, and also allocate
// a CUmemGenericAllocationHandle

// Define memory allocation properties
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;

// Check if the allocation is supported
size_t granularity;
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
CU_MEM_ALLOC_GRANULARITY_MINIMUM));

size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;

CUdeviceptr d_mem;
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));

// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));

if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}

// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();

PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);

// Call g_python_malloc_callback
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);

if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}

PyGILState_Release(gstate);

// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);

return (void*)d_mem;
}

// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "ERROR: g_python_free_callback not set.\n";
return;
}

// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();

PyObject* py_ptr =
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));

PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);

if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return;
}

unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return;
}

PyGILState_Release(gstate);

// recv_size == size
// recv_device == device

// Free memory

CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);

// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, size));
free(p_memHandle);
}

// ---------------------------------------------------------------------------
// Python extension boilerplate:

// Python-exposed function: init_module(python_malloc, python_free)
// Registers the two Python callables invoked by my_malloc / my_free.
static PyObject* py_init_module(PyObject* self, PyObject* args) {
  PyObject* py_malloc = nullptr;
  PyObject* py_free = nullptr;

  if (!PyArg_ParseTuple(args, "OO", &py_malloc, &py_free)) {
    return nullptr;
  }

  const bool both_callable =
      PyCallable_Check(py_malloc) && PyCallable_Check(py_free);
  if (!both_callable) {
    PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
    return nullptr;
  }

  // Store the callables as borrowed references: this module does not manage
  // their lifetime, so the caller must keep them alive for as long as the
  // allocator is in use.
  g_python_malloc_callback = py_malloc;
  g_python_free_callback = py_free;

  Py_RETURN_NONE;
}

// Python entry point: unmap and release the allocation described by four
// integer arguments (device, size, d_mem, p_memHandle address).
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }

  unsigned long long arg_device = 0, arg_size = 0;
  unsigned long long arg_d_mem = 0, arg_handle_addr = 0;
  if (!PyArg_ParseTuple(args, "KKKK", &arg_device, &arg_size, &arg_d_mem,
                        &arg_handle_addr)) {
    return nullptr;  // PyArg_ParseTuple has already set the exception
  }

  // Reinterpret the integers as the device pointer and host-side handle
  // pointer that my_malloc originally handed to Python.
  unmap_and_release(arg_device, arg_size, (CUdeviceptr)arg_d_mem,
                    (CUmemGenericAllocationHandle*)arg_handle_addr);

  Py_RETURN_NONE;
}

// Python entry point: create and map physical memory for the reserved range
// described by four integer arguments (device, size, d_mem, p_memHandle
// address).
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }

  unsigned long long arg_device = 0, arg_size = 0;
  unsigned long long arg_d_mem = 0, arg_handle_addr = 0;
  if (!PyArg_ParseTuple(args, "KKKK", &arg_device, &arg_size, &arg_d_mem,
                        &arg_handle_addr)) {
    return nullptr;  // PyArg_ParseTuple has already set the exception
  }

  // Reinterpret the integers as the device pointer and host-side handle
  // pointer that my_malloc originally handed to Python.
  create_and_map(arg_device, arg_size, (CUdeviceptr)arg_d_mem,
                 (CUmemGenericAllocationHandle*)arg_handle_addr);

  Py_RETURN_NONE;
}

// Method table exposed to Python as module-level functions.
static PyMethodDef module_methods[] = {
    {"init_module", (PyCFunction)py_init_module, METH_VARARGS,
     "Initialize module with python_malloc and python_free callables."},
    {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
     "Create and map memory on the device."},
    {"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
     METH_VARARGS, "Unmap and release memory on the device."},
    {NULL, NULL, 0, NULL}  // sentinel
};

// Module definition. m_size of -1 means state lives in C globals (the two
// callback pointers above), so the module cannot be safely re-initialized
// in sub-interpreters — consistent with the singleton limitation noted at
// the top of the file.
static struct PyModuleDef cumem_allocator_module = {
    PyModuleDef_HEAD_INIT, "cumem_allocator",
    "cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};

// Module entry point; the interpreter looks this symbol up on import.
PyMODINIT_FUNC PyInit_cumem_allocator(void) {
  // Initialize the module
  PyObject* module = PyModule_Create(&cumem_allocator_module);
  if (!module) {
    return NULL;
  }
  return module;
}
} // extern "C"
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def run(self) -> None:
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/vllm_flash_attn/__init__.py",
"vllm/cumem_allocator.abi3.so",
# "vllm/_version.py", # not available in nightly wheels yet
]
file_members = filter(lambda x: x.filename in files_to_copy,
Expand Down Expand Up @@ -594,6 +595,7 @@ def _read_requirements(filename: str) -> List[str]:
if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
Expand Down
Loading
Loading