From fee4a6b1c155b8ada489dc39740d51502df2a68b Mon Sep 17 00:00:00 2001 From: Ivan Komarov Date: Thu, 8 Feb 2024 21:17:25 +0000 Subject: [PATCH] Properly track aborts and add a test for that --- src/device/mixed_precision_reduce_scatter.h | 18 ++++++++++++++++-- tests/test_rs.cu | 16 ++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/device/mixed_precision_reduce_scatter.h b/src/device/mixed_precision_reduce_scatter.h index 5c54579c3..e06daa566 100644 --- a/src/device/mixed_precision_reduce_scatter.h +++ b/src/device/mixed_precision_reduce_scatter.h @@ -204,6 +204,7 @@ class MixedPrecisionReduceScatterPrims { RoleWaitSend = 0x08, RolePostSend = 0x10, RolePostRecv = 0x20, + Aborted = 0x40, OffsFifoEnabled = 0x80, SizesFifoEnabled = 0x100, ThreadsSynced = 0x800; @@ -235,7 +236,8 @@ class MixedPrecisionReduceScatterPrims { waitPeer(sliceSize); subBarrier(); - // For simplicity, don't track aborts here. + // `prims_simple.h` tries to avoid doing unnecessary reduceCopy() if we are already aborted, + // but we don't really mind doing some extra work. auto& group = ncclShmem.groups[0]; reduceCopyMixedPrecision< ncclCollUnroll(), RedOp, T, TInput, @@ -340,9 +342,21 @@ class MixedPrecisionReduceScatterPrims { __device__ __forceinline__ void waitPeer(int nelts) { const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; if (flags & (Recv * RoleWaitRecv | Send * RoleWaitSend)) { + int spins = 0; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { - // For simplicity, don't track aborts here. connStepCache = ld_volatile_global(connStepPtr); + // Check for kernel abort. 
+ spins++; + if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { + if (*ncclShmem.comm.abortFlag) { + flags |= Aborted; + ncclShmem.aborted = 1; + } + spins = 0; + } + if (flags & Aborted) { + break; + } } if (isSendNotRecv && (flags & SizesFifoEnabled)) { diff --git a/tests/test_rs.cu b/tests/test_rs.cu index 2918384d0..5c937cef1 100644 --- a/tests/test_rs.cu +++ b/tests/test_rs.cu @@ -1,5 +1,6 @@ // nvcc -O2 -I /usr/local/mpi/include -I ~/nccl/build/include -L /usr/local/mpi/lib -L ~/nccl/build/lib -gencode=arch=compute_80,code=sm_80 -o test_rs test_rs.cu -lnccl -lmpi #include +#include <optional> #include #include #include @@ -47,7 +48,7 @@ constexpr size_t WARMUP_ITERS = 10; constexpr size_t ITERS = 100; std::vector<__nv_bfloat16> DoReduceScatter(__nv_bfloat16* sendbuff, size_t elemCount, size_t recvCount, ncclComm_t comm, cudaStream_t stream, - ReduceScatterMode mode, int rank, float* elapsedMs) { + ReduceScatterMode mode, int rank, float* elapsedMs, std::optional<ReduceScatterMode> abortMode) { __nv_bfloat16* recvbuff = sendbuff + rank * recvCount; float* tmpbuff = {}; @@ -93,6 +94,12 @@ std::vector<__nv_bfloat16> DoReduceScatter(__nv_bfloat16* sendbuff, size_t elemC CudaOrDie(cudaGetLastError(), "fp32 -> bf16 conversion"); } + if (abortMode.has_value() && mode == *abortMode) { + NcclOrDie(ncclCommAbort(comm), "abort NCCL communicator"); + std::cerr << "Aborted NCCL communicator on rank " << rank << std::endl; + exit(42); + } + CudaOrDie(cudaStreamSynchronize(stream), "synchronize GPU"); std::vector<__nv_bfloat16> result(recvCount); @@ -161,6 +168,11 @@ int main(int argc, char** argv) { CudaOrDie(cudaStreamCreate(&stream), "create GPU stream"); CudaOrDie(cudaMalloc(&sendbuff, elemCount * sizeof(__nv_bfloat16)), "allocate GPU buffer"); + std::optional<ReduceScatterMode> abortMode; + if (const auto* abortModeEnv = getenv("ABORT_MODE")) { + abortMode = static_cast<ReduceScatterMode>(std::stoi(abortModeEnv)); + } + for (size_t rawMode = 0; rawMode < static_cast<size_t>(ReduceScatterMode::ModeCount); ++rawMode) { const 
auto mode = static_cast<ReduceScatterMode>(rawMode); @@ -173,7 +185,7 @@ cudaMemcpy(sendbuff, gradients.data(), elemCount * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice), "copy to GPU" ); - const auto result = DoReduceScatter(sendbuff, elemCount, recvCount, comm, stream, mode, rank, &elapsedMs); + const auto result = DoReduceScatter(sendbuff, elemCount, recvCount, comm, stream, mode, rank, &elapsedMs, abortMode); if (iter < WARMUP_ITERS) { continue;