diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index ccb9f68124b5e..a2646aaa7493b 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -359,8 +359,11 @@ CUDAGraph::~CUDAGraph() { // hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2. // We need to ensure all async opreations finish before deleting the object. #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) - AT_CUDA_CHECK(cudaSetDevice(capture_dev_)); - AT_CUDA_CHECK(cudaDeviceSynchronize()); + if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id + { + AT_CUDA_CHECK(cudaSetDevice(capture_dev_)); + AT_CUDA_CHECK(cudaDeviceSynchronize()); + } #endif } diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index 804067560a6ea..dbeb976692881 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -75,7 +75,9 @@ struct TORCH_CUDA_CPP_API CUDAGraph { // in a capture to run on the same device, but this is a limitation of CUDAGraph, // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device // captures if needed. - int capture_dev_; + // init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor + static constexpr int UNDEFINED_DEVICE = -1; + int capture_dev_ = UNDEFINED_DEVICE; // RNG state trackers at::Tensor seed_extragraph_;