Skip to content

Commit

Permalink
Remove the dependency on CUDA driver (#7224)
Browse files Browse the repository at this point in the history
* Remove cuda dependency

* Remove unused include

* Add comments
  • Loading branch information
krishung5 authored May 15, 2024
1 parent e787476 commit 58d3396
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ if(${TRITON_ENABLE_GPU})
main
PRIVATE
CUDA::cudart
-lcuda
)
endif() # TRITON_ENABLE_GPU

Expand Down
32 changes: 30 additions & 2 deletions src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,42 @@ OpenCudaIPCRegion(
return nullptr;
}

// Using `cudaGetDriverEntryPoint` from CUDA runtime API to get CUDA driver
// entry point. This approach is used to avoid linking against CUDA driver
// library so that when Triton is built with GPU support, it can still be run on
// CPU-only environments.
//
// \param name      Name of the CUDA driver API symbol to resolve
//                  (e.g. "cuMemGetAddressRange").
// \param func_ptr  On success, receives the address of the driver function.
// \return nullptr on success; otherwise a TRITONSERVER_ERROR_INTERNAL error
//         describing which symbol failed to resolve and why.
TRITONSERVER_Error*
GetCudaDriverEntryPoint(const char* name, void** func_ptr)
{
  cudaError_t err = cudaGetDriverEntryPoint(name, func_ptr, cudaEnableDefault);
  if (err != cudaSuccess) {
    LOG_ERROR << "Failed to get CUDA driver entry point for " << name << ": "
              << cudaGetErrorString(err);
    // Include the symbol name and the CUDA error description in the returned
    // error as well, so callers that only see the TRITONSERVER_Error (not the
    // server log) still get an actionable message.
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INTERNAL,
        (std::string("Failed to get CUDA driver entry point for '") + name +
         "': " + cudaGetErrorString(err))
            .c_str());
  }
  return nullptr;
}

TRITONSERVER_Error*
GetCudaSharedMemoryRegionSize(CUdeviceptr data_ptr, size_t& shm_region_size)
{
void* cu_mem_get_address_range = nullptr;
void* cu_get_error_string = nullptr;
RETURN_IF_ERR(GetCudaDriverEntryPoint(
"cuMemGetAddressRange", &cu_mem_get_address_range));
RETURN_IF_ERR(
GetCudaDriverEntryPoint("cuGetErrorString", &cu_get_error_string));

CUdeviceptr* base = nullptr;
CUresult result = cuMemGetAddressRange(base, &shm_region_size, data_ptr);
CUresult result = ((
CUresult(*)(CUdeviceptr*, size_t*, CUdeviceptr))cu_mem_get_address_range)(
base, &shm_region_size, data_ptr);
if (result != CUDA_SUCCESS) {
const char* errorString;
if (cuGetErrorString(result, &errorString) != CUDA_SUCCESS) {
if (((CUresult(*)(CUresult, const char**))cu_get_error_string)(
result, &errorString) != CUDA_SUCCESS) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA error string");
}
Expand Down

0 comments on commit 58d3396

Please sign in to comment.