Skip to content

Commit

Permalink
Properly track aborts and add a test for that
Browse files Browse the repository at this point in the history
  • Loading branch information
dfyz committed May 21, 2024
1 parent eadad54 commit fee4a6b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
18 changes: 16 additions & 2 deletions src/device/mixed_precision_reduce_scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class MixedPrecisionReduceScatterPrims {
RoleWaitSend = 0x08,
RolePostSend = 0x10,
RolePostRecv = 0x20,
Aborted = 0x40,
OffsFifoEnabled = 0x80,
SizesFifoEnabled = 0x100,
ThreadsSynced = 0x800;
Expand Down Expand Up @@ -235,7 +236,8 @@ class MixedPrecisionReduceScatterPrims {
waitPeer<Send, Recv>(sliceSize);
subBarrier();

// For simplicity, this step does not check the abort state before reducing.
// `prims_simple.h` tries to avoid doing an unnecessary reduceCopy() if we are already aborted,
// but we don't really mind doing some extra work.
auto& group = ncclShmem.groups[0];
reduceCopyMixedPrecision<
ncclCollUnroll(), RedOp, T, TInput,
Expand Down Expand Up @@ -340,9 +342,21 @@ class MixedPrecisionReduceScatterPrims {
__device__ __forceinline__ void waitPeer(int nelts) {
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
if (flags & (Recv * RoleWaitRecv | Send * RoleWaitSend)) {
int spins = 0;
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
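// Refresh the cached peer step; the abort flag is polled below once every
// NCCL_SPINS_BEFORE_CHECK_ABORT spins so that an aborted kernel stops waiting.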
connStepCache = ld_volatile_global(connStepPtr);
// Check for kernel abort.
spins++;
if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
if (*ncclShmem.comm.abortFlag) {
flags |= Aborted;
ncclShmem.aborted = 1;
}
spins = 0;
}
if (flags & Aborted) {
break;
}
}

if (isSendNotRecv && (flags & SizesFifoEnabled)) {
Expand Down
16 changes: 14 additions & 2 deletions tests/test_rs.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// nvcc -O2 -I /usr/local/mpi/include -I ~/nccl/build/include -L /usr/local/mpi/lib -L ~/nccl/build/lib -gencode=arch=compute_80,code=sm_80 -o test_rs test_rs.cu -lnccl -lmpi
#include <iostream>
#include <optional>
#include <random>
#include <string>
#include <vector>
Expand Down Expand Up @@ -47,7 +48,7 @@ constexpr size_t WARMUP_ITERS = 10;
constexpr size_t ITERS = 100;

std::vector<__nv_bfloat16> DoReduceScatter(__nv_bfloat16* sendbuff, size_t elemCount, size_t recvCount, ncclComm_t comm, cudaStream_t stream,
ReduceScatterMode mode, int rank, float* elapsedMs) {
ReduceScatterMode mode, int rank, float* elapsedMs, std::optional<ReduceScatterMode> abortMode) {
__nv_bfloat16* recvbuff = sendbuff + rank * recvCount;
float* tmpbuff = {};

Expand Down Expand Up @@ -93,6 +94,12 @@ std::vector<__nv_bfloat16> DoReduceScatter(__nv_bfloat16* sendbuff, size_t elemC
CudaOrDie(cudaGetLastError(), "fp32 -> bf16 conversion");
}

if (abortMode.has_value() && mode == *abortMode) {
NcclOrDie(ncclCommAbort(comm), "abort NCCL communicator");
std::cerr << "Aborted NCCL communicator on rank " << rank << std::endl;
exit(42);
}

CudaOrDie(cudaStreamSynchronize(stream), "synchronize GPU");

std::vector<__nv_bfloat16> result(recvCount);
Expand Down Expand Up @@ -161,6 +168,11 @@ int main(int argc, char** argv) {
CudaOrDie(cudaStreamCreate(&stream), "create GPU stream");
CudaOrDie(cudaMalloc(&sendbuff, elemCount * sizeof(__nv_bfloat16)), "allocate GPU buffer");

std::optional<ReduceScatterMode> abortMode;
if (const auto* abortModeEnv = getenv("ABORT_MODE")) {
abortMode = static_cast<ReduceScatterMode>(std::stoi(abortModeEnv));
}

for (size_t rawMode = 0; rawMode < static_cast<size_t>(ReduceScatterMode::ModeCount); ++rawMode) {
const auto mode = static_cast<ReduceScatterMode>(rawMode);

Expand All @@ -173,7 +185,7 @@ int main(int argc, char** argv) {
cudaMemcpy(sendbuff, gradients.data(), elemCount * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice),
"copy to GPU"
);
const auto result = DoReduceScatter(sendbuff, elemCount, recvCount, comm, stream, mode, rank, &elapsedMs);
const auto result = DoReduceScatter(sendbuff, elemCount, recvCount, comm, stream, mode, rank, &elapsedMs, abortMode);

if (iter < WARMUP_ITERS) {
continue;
Expand Down

0 comments on commit fee4a6b

Please sign in to comment.