From 58d3396dd01e2dd68dd2acbf101a83e2f5e93bbd Mon Sep 17 00:00:00 2001 From: Kris Hung Date: Wed, 15 May 2024 15:01:30 -0700 Subject: [PATCH] Remove the dependency on CUDA driver (#7224) * Remove cuda dependency * Remove unused include * Add comments --- src/CMakeLists.txt | 1 - src/shared_memory_manager.cc | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 53c8add989..783275d8d7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -180,7 +180,6 @@ if(${TRITON_ENABLE_GPU}) main PRIVATE CUDA::cudart - -lcuda ) endif() # TRITON_ENABLE_GPU diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 8101a2e236..1f4a77e887 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -266,14 +266,42 @@ OpenCudaIPCRegion( return nullptr; } +// Using `cudaGetDriverEntryPoint` from CUDA runtime API to get CUDA driver +// entry point. This approach is used to avoid linking against CUDA driver +// library so that when Triton is built with GPU support, it can still be run on +// CPU-only environments. +TRITONSERVER_Error* +GetCudaDriverEntryPoint(const char* name, void** func_ptr) +{ + cudaError_t err = cudaGetDriverEntryPoint(name, func_ptr, cudaEnableDefault); + if (err != cudaSuccess) { + LOG_ERROR << "Failed to get CUDA driver entry point for " << name << ": " + << cudaGetErrorString(err); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("Failed to get CUDA driver entry point").c_str()); + } + return nullptr; +} + TRITONSERVER_Error* GetCudaSharedMemoryRegionSize(CUdeviceptr data_ptr, size_t& shm_region_size) { + void* cu_mem_get_address_range = nullptr; + void* cu_get_error_string = nullptr; + RETURN_IF_ERR(GetCudaDriverEntryPoint( + "cuMemGetAddressRange", &cu_mem_get_address_range)); + RETURN_IF_ERR( + GetCudaDriverEntryPoint("cuGetErrorString", &cu_get_error_string)); + CUdeviceptr* base = nullptr; - CUresult result = cuMemGetAddressRange(base, &shm_region_size, data_ptr); + CUresult result = (( + CUresult(*)(CUdeviceptr*, size_t*, CUdeviceptr))cu_mem_get_address_range)( + base, &shm_region_size, data_ptr); if (result != CUDA_SUCCESS) { const char* errorString; - if (cuGetErrorString(result, &errorString) != CUDA_SUCCESS) { + if (((CUresult(*)(CUresult, const char**))cu_get_error_string)( + result, &errorString) != CUDA_SUCCESS) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA error string"); }