-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Support fully transparent sleep mode (#11743)
Signed-off-by: youkaichao <[email protected]>
- Loading branch information
1 parent
4004f14
commit 68ad4e3
Showing
14 changed files
with
877 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
// A CUDAPluggableAllocator based on cumem* APIs. | ||
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* | ||
// need to be unsigned long long | ||
#include <iostream> | ||
|
||
extern "C" { | ||
|
||
#define PY_SSIZE_T_CLEAN | ||
#include <Python.h> | ||
|
||
#include <sys/types.h> | ||
#include <cuda_runtime_api.h> | ||
#include <cuda.h> | ||
|
||
#define CUDA_CHECK(condition) \ | ||
do { \ | ||
CUresult error = condition; \ | ||
if (error != 0) { \ | ||
char* error_string; \ | ||
cuGetErrorString(error, (const char**)&error_string); \ | ||
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \ | ||
<< __LINE__ << std::endl; \ | ||
} \ | ||
} while (0) | ||
|
||
// Global references to Python callables | ||
// NOTE: this is borrowed reference, so we don't need to DECREF them. | ||
// This brings the limitation that the allocator needs to be singleton. | ||
static PyObject* g_python_malloc_callback = nullptr; | ||
static PyObject* g_python_free_callback = nullptr; | ||
|
||
// --------------------------------------------------------------------------- | ||
// Helper functions: | ||
|
||
void ensure_context(unsigned long long device) { | ||
CUcontext pctx; | ||
CUDA_CHECK(cuCtxGetCurrent(&pctx)); | ||
if (!pctx) { | ||
// Ensure device context. | ||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); | ||
CUDA_CHECK(cuCtxSetCurrent(pctx)); | ||
} | ||
} | ||
|
||
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, | ||
CUmemGenericAllocationHandle* p_memHandle) { | ||
ensure_context(device); | ||
// Define memory allocation properties | ||
CUmemAllocationProp prop = {}; | ||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; | ||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; | ||
prop.location.id = device; | ||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; | ||
|
||
// Allocate memory using cuMemCreate | ||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); | ||
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); | ||
|
||
CUmemAccessDesc accessDesc = {}; | ||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; | ||
accessDesc.location.id = device; | ||
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; | ||
|
||
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); | ||
// std::cout << "create_and_map: device=" << device << ", size=" << size << ", | ||
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; | ||
} | ||
|
||
void unmap_and_release(unsigned long long device, ssize_t size, | ||
CUdeviceptr d_mem, | ||
CUmemGenericAllocationHandle* p_memHandle) { | ||
// std::cout << "unmap_and_release: device=" << device << ", size=" << size << | ||
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; | ||
ensure_context(device); | ||
CUDA_CHECK(cuMemUnmap(d_mem, size)); | ||
CUDA_CHECK(cuMemRelease(*p_memHandle)); | ||
} | ||
|
||
PyObject* create_tuple_from_c_integers(unsigned long long a, | ||
unsigned long long b, | ||
unsigned long long c, | ||
unsigned long long d) { | ||
// Create a new tuple of size 4 | ||
PyObject* tuple = PyTuple_New(4); | ||
if (!tuple) { | ||
return NULL; // Return NULL on failure | ||
} | ||
|
||
// Convert integers to Python objects and set them in the tuple | ||
PyTuple_SetItem( | ||
tuple, 0, | ||
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong | ||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); | ||
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); | ||
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); | ||
|
||
// Note: PyTuple_SetItem "steals" a reference to each object, | ||
// so we do not need to Py_DECREF the PyLong objects explicitly. | ||
|
||
return tuple; // Return the created tuple | ||
} | ||
|
||
// --------------------------------------------------------------------------- | ||
// Our exported C functions that call Python: | ||
|
||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h | ||
void* my_malloc(ssize_t size, int device, CUstream stream) { | ||
ensure_context(device); | ||
|
||
// first allocation, align the size, and reserve an address, and also allocate | ||
// a CUmemGenericAllocationHandle | ||
|
||
// Define memory allocation properties | ||
CUmemAllocationProp prop = {}; | ||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; | ||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; | ||
prop.location.id = device; | ||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; | ||
|
||
// Check if the allocation is supported | ||
size_t granularity; | ||
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, | ||
CU_MEM_ALLOC_GRANULARITY_MINIMUM)); | ||
|
||
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; | ||
|
||
CUdeviceptr d_mem; | ||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); | ||
|
||
// allocate the CUmemGenericAllocationHandle | ||
CUmemGenericAllocationHandle* p_memHandle = | ||
(CUmemGenericAllocationHandle*)malloc( | ||
sizeof(CUmemGenericAllocationHandle)); | ||
|
||
if (!g_python_malloc_callback) { | ||
std::cerr << "ERROR: g_python_malloc_callback not set.\n"; | ||
return nullptr; | ||
} | ||
|
||
// Acquire GIL (not in stable ABI officially, but often works) | ||
PyGILState_STATE gstate = PyGILState_Ensure(); | ||
|
||
PyObject* arg_tuple = create_tuple_from_c_integers( | ||
(unsigned long long)device, (unsigned long long)alignedSize, | ||
(unsigned long long)d_mem, (unsigned long long)p_memHandle); | ||
|
||
// Call g_python_malloc_callback | ||
PyObject* py_result = | ||
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); | ||
Py_DECREF(arg_tuple); | ||
|
||
if (!py_result) { | ||
PyErr_Print(); | ||
PyGILState_Release(gstate); | ||
return nullptr; | ||
} | ||
|
||
PyGILState_Release(gstate); | ||
|
||
// do the final mapping | ||
create_and_map(device, alignedSize, d_mem, p_memHandle); | ||
|
||
return (void*)d_mem; | ||
} | ||
|
||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h | ||
void my_free(void* ptr, ssize_t size, int device, CUstream stream) { | ||
// get memory handle from the pointer | ||
if (!g_python_free_callback) { | ||
std::cerr << "ERROR: g_python_free_callback not set.\n"; | ||
return; | ||
} | ||
|
||
// Acquire GIL (not in stable ABI officially, but often works) | ||
PyGILState_STATE gstate = PyGILState_Ensure(); | ||
|
||
PyObject* py_ptr = | ||
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr)); | ||
|
||
PyObject* py_result = | ||
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); | ||
|
||
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { | ||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); | ||
return; | ||
} | ||
|
||
unsigned long long recv_device, recv_size; | ||
unsigned long long recv_d_mem, recv_p_memHandle; | ||
// Unpack the tuple into four C integers | ||
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, | ||
&recv_d_mem, &recv_p_memHandle)) { | ||
// PyArg_ParseTuple sets an error if it fails | ||
return; | ||
} | ||
|
||
PyGILState_Release(gstate); | ||
|
||
// recv_size == size | ||
// recv_device == device | ||
|
||
// Free memory | ||
|
||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; | ||
CUmemGenericAllocationHandle* p_memHandle = | ||
(CUmemGenericAllocationHandle*)recv_p_memHandle; | ||
unmap_and_release(device, size, d_mem, p_memHandle); | ||
|
||
// free address and the handle | ||
CUDA_CHECK(cuMemAddressFree(d_mem, size)); | ||
free(p_memHandle); | ||
} | ||
|
||
// --------------------------------------------------------------------------- | ||
// Python extension boilerplate: | ||
|
||
// Python-exposed function: init_module(python_malloc, python_free) | ||
static PyObject* py_init_module(PyObject* self, PyObject* args) { | ||
PyObject* malloc_callback = nullptr; | ||
PyObject* free_callback = nullptr; | ||
|
||
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { | ||
return nullptr; | ||
} | ||
|
||
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { | ||
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); | ||
return nullptr; | ||
} | ||
|
||
// Save the Python callables | ||
// This module does not handle GC of these objects, so they must be kept alive | ||
// outside of this module. | ||
g_python_malloc_callback = malloc_callback; | ||
g_python_free_callback = free_callback; | ||
|
||
Py_RETURN_NONE; | ||
} | ||
|
||
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { | ||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { | ||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); | ||
return nullptr; | ||
} | ||
|
||
unsigned long long recv_device, recv_size; | ||
unsigned long long recv_d_mem, recv_p_memHandle; | ||
// Unpack the tuple into four C integers | ||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, | ||
&recv_p_memHandle)) { | ||
// PyArg_ParseTuple sets an error if it fails | ||
return nullptr; | ||
} | ||
|
||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; | ||
CUmemGenericAllocationHandle* p_memHandle = | ||
(CUmemGenericAllocationHandle*)recv_p_memHandle; | ||
|
||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); | ||
|
||
Py_RETURN_NONE; | ||
} | ||
|
||
static PyObject* python_create_and_map(PyObject* self, PyObject* args) { | ||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { | ||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); | ||
return nullptr; | ||
} | ||
|
||
unsigned long long recv_device, recv_size; | ||
unsigned long long recv_d_mem, recv_p_memHandle; | ||
// Unpack the tuple into four C integers | ||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, | ||
&recv_p_memHandle)) { | ||
// PyArg_ParseTuple sets an error if it fails | ||
return nullptr; | ||
} | ||
|
||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; | ||
CUmemGenericAllocationHandle* p_memHandle = | ||
(CUmemGenericAllocationHandle*)recv_p_memHandle; | ||
|
||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); | ||
|
||
Py_RETURN_NONE; | ||
} | ||
|
||
static PyMethodDef module_methods[] = { | ||
{"init_module", (PyCFunction)py_init_module, METH_VARARGS, | ||
"Initialize module with python_malloc and python_free callables."}, | ||
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, | ||
"Create and map memory on the device."}, | ||
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release, | ||
METH_VARARGS, "Unmap and release memory on the device."}, | ||
{NULL, NULL, 0, NULL} // sentinel | ||
}; | ||
|
||
static struct PyModuleDef cumem_allocator_module = { | ||
PyModuleDef_HEAD_INIT, "cumem_allocator", | ||
"cumem-based allocator for CUDAPluggableAllocator", -1, module_methods}; | ||
|
||
PyMODINIT_FUNC PyInit_cumem_allocator(void) { | ||
// Initialize the module | ||
PyObject* module = PyModule_Create(&cumem_allocator_module); | ||
if (!module) { | ||
return NULL; | ||
} | ||
return module; | ||
} | ||
} // extern "C" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.