Skip to content

Commit

Permalink
[release/2.2] cudagraph explicit sync only after capture_begin() (#1492)
Browse files Browse the repository at this point in the history
* cudagraph explicit sync only after capture_begin

* use 'capture_dev_=-1' as not initialized value

* use named constant instead of magic '-1' value
  • Loading branch information
dnikolaev-amd authored Jul 29, 2024
1 parent 23381c9 commit eb433b9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 5 additions & 2 deletions aten/src/ATen/cuda/CUDAGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,11 @@ CUDAGraph::~CUDAGraph() {
// hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2.
// We need to ensure all async opreations finish before deleting the object.
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
AT_CUDA_CHECK(cudaSetDevice(capture_dev_));
AT_CUDA_CHECK(cudaDeviceSynchronize());
if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id
{
AT_CUDA_CHECK(cudaSetDevice(capture_dev_));
AT_CUDA_CHECK(cudaDeviceSynchronize());
}
#endif
}

Expand Down
4 changes: 3 additions & 1 deletion aten/src/ATen/cuda/CUDAGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
// in a capture to run on the same device, but this is a limitation of CUDAGraph,
// not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
// captures if needed.
int capture_dev_;
// init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor
static constexpr int UNDEFINED_DEVICE = -1;
int capture_dev_ = UNDEFINED_DEVICE;

// RNG state trackers
at::Tensor seed_extragraph_;
Expand Down

0 comments on commit eb433b9

Please sign in to comment.