Kernel-level profiling draft (see NVIDIA#1210 (comment))
dfyz committed May 21, 2024
1 parent fee4a6b commit 8990132
Showing 9 changed files with 250 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/device/all_gather.h
@@ -12,6 +12,12 @@ namespace {
template<typename T, typename RedOp, typename Proto>
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
const int tid = threadIdx.x;

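// Tag this kernel's step timings as coming from the all-gather ring (kernelType 1).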
if (tid == 0) {
ncclShmem.kernelType = 1;
}
__syncthreads();

const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
const int nChannels = args->nChannels;
14 changes: 13 additions & 1 deletion src/device/common.h
@@ -12,6 +12,12 @@
#include "op128.h"
#include "network/unpack/unpack_defs.h"

#include <cuda/std/chrono>

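// Device-side wall-clock timestamp in nanoseconds (libcu++ chrono). Step timings are
// stored as deltas against ncclShmem.startTiming, which is taken from this clock.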
__device__ __forceinline__ uint64_t getDeviceTimeNs() {
return cuda::std::chrono::high_resolution_clock::now().time_since_epoch().count();
}

#define COLL_UNROLL (ncclCollUnroll())

typedef void(*ncclDevFuncPtr_t)();
@@ -38,6 +44,8 @@ struct ncclShmemData {
alignas(16) union {
unpackShmem unpack;
} devicePlugin;
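// Kernel-level profiling state: startTiming is the timestamp taken at the start of
// ncclKernelMain that all step timings are measured against; kernelType identifies
// the profiled collective (0 = other, 1 = all-gather ring, 2 = reduce-scatter ring).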
alignas(16) uint64_t startTiming;
alignas(16) int kernelType;
};
static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "shmem.work needs to be 16B aligned");

@@ -162,7 +170,11 @@ __device__ void ncclKernelMain(struct ncclDevComm* comm, uint64_t channelMask, s
__syncthreads(); // publish ncclShmem.channelId
int channelId = ncclShmem.channelId;
/* set abort flag to 0 and initialize kernel-level profiling state */
if (tid == 0) {
ncclShmem.aborted = 0;
ncclShmem.startTiming = getDeviceTimeNs();
ncclShmem.kernelType = 0;
}

if (true) {
void *dst, *src;
36 changes: 36 additions & 0 deletions src/device/mixed_precision_reduce_scatter.h
@@ -252,6 +252,9 @@ class MixedPrecisionReduceScatterPrims {
postPeer<Send, Recv>(sliceSize > 0);
offset += sliceSize;
++slice;
if (shouldSaveTimings()) {
timingIndex += TIMINGS_COUNT;
}
} while (slice < SlicePerChunk && offset < nelem);
}

@@ -265,6 +268,9 @@ class MixedPrecisionReduceScatterPrims {
postPeer<Send, Recv>(sliceSize > 0);
offset += sliceSize;
++slice;
if (shouldSaveTimings()) {
timingIndex += TIMINGS_COUNT;
}
}
}

@@ -282,6 +288,7 @@ class MixedPrecisionReduceScatterPrims {
, stepSize(stepSize)
, nranks(nranks)
, size(size)
, timingIndex(-1)
{
constexpr int ThreadsPerSync = 8;

@@ -291,16 +298,20 @@ class MixedPrecisionReduceScatterPrims {
// Preserve the indexes from `prims_simple.h` for simplicity
if (tid == 0) {
flags |= RoleWaitRecv;
timingIndex = 1;
} else if (tid == 1) {
flags |= RoleInput;
} else if (tid == ThreadsPerSync) {
flags |= RoleWaitSend;
timingIndex = 1;
} else if (tid == ThreadsPerSync + 1) {
flags |= RoleOutput;
} else if (tid == nthreads - 2 * ThreadsPerSync) {
flags |= RolePostRecv;
timingIndex = 1;
} else if (tid == nthreads - ThreadsPerSync) {
flags |= RolePostSend;
timingIndex = 1;
}

loadConn(ring->prev, RolePostRecv, RoleWaitRecv, true /*isRecv*/);
@@ -323,12 +334,19 @@ class MixedPrecisionReduceScatterPrims {
if (flags & (RolePostSend | RolePostRecv)) {
auto& group = ncclShmem.groups[0];
((flags & RolePostSend) ? group.sendConns : group.recvConns)[0]->step = step;
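// Header word: kernel type (hard-coded to 2, i.e. reduce-scatter) in the high 32 bits,
// clamped final timing index in the low 32 bits.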
if (flags & RolePostSend && shouldSaveTimings()) {
ncclShmem.channel.stepTimings[0] = (2ULL << 32) | min(timingIndex, 1 + MAXSTEPTIMINGS*TIMINGS_COUNT);
}
}

barrier();
}

private:
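// Timings are recorded only when this channel has a stepTimings buffer and this
// thread was assigned a timing slot in the constructor.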
__forceinline__ __device__ bool shouldSaveTimings() {
return ncclShmem.channel.stepTimings != nullptr && timingIndex != -1;
}

__device__ void barrier() {
flags |= ThreadsSynced;
asm volatile("bar.sync %0, %1;" :: "r"(15), "r"(nthreads) : "memory");
@@ -359,6 +377,15 @@ class MixedPrecisionReduceScatterPrims {
}
}

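// Record this slice's step counter, byte count, and a timestamp once the peer wait
// has completed; same record layout as in prims_simple.h.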
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
const int isSend = (flags & RoleWaitSend) != 0;
ncclShmem.channel.stepTimings[timingIndex + isSend] = step;
if ((Recv && !Send) ? (flags & RoleWaitRecv) : (flags & RoleWaitSend)) {
ncclShmem.channel.stepTimings[timingIndex + 2] = nelts*sizeof(T);
}
ncclShmem.channel.stepTimings[timingIndex + isSend + 3] = getDeviceTimeNs() - ncclShmem.startTiming;
}

if (isSendNotRecv && (flags & SizesFifoEnabled)) {
connSizesFifoPtr[step % NCCL_STEPS] = nelts * sizeof(T);
}
@@ -382,11 +409,18 @@ class MixedPrecisionReduceScatterPrims {
template <int Send, int Recv>
__device__ void postPeer(bool dataStored) {
if (flags & (Recv * RolePostRecv | Send * RolePostSend)) {
const int isSend = (flags & RolePostSend) != 0;
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT && isSend) {
ncclShmem.channel.stepTimings[timingIndex + 7] = getDeviceTimeNs() - ncclShmem.startTiming;
}
step += StepPerSlice;
if (Send && (flags & RolePostSend) && dataStored) {
fence_acq_rel_sys();
}
st_relaxed_sys_global(connStepPtr, step);
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
ncclShmem.channel.stepTimings[timingIndex + isSend + 5] = getDeviceTimeNs() - ncclShmem.startTiming;
}
}
}

@@ -452,6 +486,8 @@ class MixedPrecisionReduceScatterPrims {
int volatile *connSizesFifoPtr;
uint64_t *connStepPtr;
uint64_t connStepCache;

int timingIndex; // next free slot in ncclShmem.channel.stepTimings; -1 if this thread does not record timings
};


65 changes: 59 additions & 6 deletions src/device/prims_simple.h
@@ -52,6 +52,7 @@ class Primitives<
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
void* mhandle;
void* netDeviceHandle;
int timingIndex; // next free slot in this channel's stepTimings buffer; -1 when this thread records no timings

// Don't use barrier 0 as it's used by the final sync
__device__ void barrier() {
@@ -125,6 +126,10 @@ class Primitives<
return ld_volatile_global(ptr);
}

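// Timings are recorded only for profiled kernel types (kernelType != 0), on channels
// that have a stepTimings buffer, and by threads that were assigned a timing slot.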
__forceinline__ __device__ bool shouldSaveTimings() {
return ncclShmem.kernelType != 0 && ncclShmem.channel.stepTimings != nullptr && timingIndex != -1;
}

template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
__device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
@@ -141,8 +146,18 @@ class Primitives<
}

if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
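// Timing record layout, relative to timingIndex (advanced by TIMINGS_COUNT per slice):
//   +0 / +1 : step counter observed by the recv / send waiter
//   +2      : slice size in bytes
//   +3 / +4 : recv / send timestamp once the peer wait has completed (ns since startTiming)
//   +5 / +6 : recv / send timestamp right after the step counter is posted (see postPeer)
//   +7      : send-side timestamp just before the step counter is posted (see postPeer)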
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
const int isSend = (flags & RoleWaitSend) != 0;
ncclShmem.channel.stepTimings[timingIndex + isSend] = step;
if ((Recv && !Send) ? (flags & RoleWaitRecv) : (flags & RoleWaitSend)) {
ncclShmem.channel.stepTimings[timingIndex + 2] = nelts*sizeof(T);
}
ncclShmem.channel.stepTimings[timingIndex + isSend + 3] = getDeviceTimeNs() - ncclShmem.startTiming;
}

if (isSendNotRecv && (flags & SizesFifoEnabled)) {
connSizesFifoPtr[step%NCCL_STEPS] = nelts*sizeof(T);
}

void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst)
: (ncclShmem.groups[group].srcs + Src);
@@ -178,9 +193,16 @@ class Primitives<
template<int Recv, int Send>
inline __device__ void postPeer(bool dataStored) {
if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
const int isSend = (flags & RolePostSend) != 0;
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT && isSend) {
ncclShmem.channel.stepTimings[timingIndex + 7] = getDeviceTimeNs() - ncclShmem.startTiming;
}
step += StepPerSlice;
if (Send && (flags & RolePostSend) && dataStored) fence_acq_rel_sys();
st_relaxed_sys_global(connStepPtr, step);
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
ncclShmem.channel.stepTimings[timingIndex + isSend + 5] = getDeviceTimeNs() - ncclShmem.startTiming;
}
}
}

@@ -280,6 +302,9 @@ class Primitives<
postPeer<Recv, Send>(0 < sliceSize);
offset += sliceSize;
slice += 1;
if (shouldSaveTimings()) {
timingIndex += TIMINGS_COUNT;
}
} while (slice < SlicePerChunk && offset < nelem);
}

@@ -298,6 +323,9 @@ class Primitives<
postPeer<Recv, Send>(0 < sliceSize);
offset += sliceSize;
slice += 1;
if (shouldSaveTimings()) {
timingIndex += TIMINGS_COUNT;
}
}
}

@@ -471,7 +499,8 @@ class Primitives<
uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr, int stepSize_=0
):
tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group),
stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_),
timingIndex(-1) {

// For send operations, we need an extra warp to overlap the threadfence and the copy
this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0);
@@ -489,15 +518,35 @@ class Primitives<
index = tid % ThreadPerSync;
flags = 0;
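// Only the lead thread (index == 0) of each wait/post role records step timings;
// slot 0 of stepTimings holds the header word, so records start at index 1.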
if (g == 0) {
if (index < nrecv) {
flags |= RoleWaitRecv;
if (index == 0) {
timingIndex = 1;
}
}
if (index == nrecv) flags |= RoleInput;
} else if (g == 1) {
if (index < nsend) {
flags |= RoleWaitSend;
if (index == 0) {
timingIndex = 1;
}
}
if (index == nsend) flags |= RoleOutput;
} else if (g == ng - 2) {
if (index < nrecv) {
flags |= RolePostRecv;
if (index == 0) {
timingIndex = 1;
}
}
} else if (g == ng - 1) {
if (index < nsend) {
flags |= RolePostSend;
if (index == 0) {
timingIndex = 1;
}
}
}

int peer = 0;
@@ -532,6 +581,10 @@ class Primitives<
if (flags & (RolePostSend|RolePostRecv)) {
auto *conns = (flags & RolePostSend) ? ncclShmem.groups[group].sendConns : ncclShmem.groups[group].recvConns;
conns[index]->step = step;

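// Header word: kernel type in the high 32 bits, clamped final timing index in the low 32 bits.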
if (flags & RolePostSend && shouldSaveTimings()) {
ncclShmem.channel.stepTimings[0] = (static_cast<uint64_t>(ncclShmem.kernelType) << 32) | min(timingIndex, 1 + MAXSTEPTIMINGS*TIMINGS_COUNT);
}
}

if ((flags & (AnyNetDeviceUnpack)) && (flags & (RoleWaitRecv))) {
6 changes: 6 additions & 0 deletions src/device/reduce_scatter.h
@@ -12,6 +12,12 @@ namespace {
template<typename T, typename RedOp, typename Proto>
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
const int tid = threadIdx.x;

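// Tag this kernel's step timings as coming from the reduce-scatter ring (kernelType 2).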
if (tid == 0) {
ncclShmem.kernelType = 2;
}
__syncthreads();

const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
const int nChannels = args->nChannels;
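
For orientation, a minimal host-side sketch of how a channel's stepTimings buffer filled in by the code above could be decoded. It assumes TIMINGS_COUNT == 8 (the eight slots each record uses above) and that the buffer has already been copied back to host memory; the real constant values, the buffer allocation, and the copy-back path are not part of this commit, and the helper name below is illustrative only.

#include <cstdint>
#include <cstdio>

// Assumed record width; the real TIMINGS_COUNT is defined elsewhere in the tree.
constexpr uint32_t kTimingsCount = 8;

// `timings` points at a host copy of one channel's stepTimings array.
void dumpStepTimings(const uint64_t* timings) {
  const uint32_t kernelType = static_cast<uint32_t>(timings[0] >> 32); // 0 = other, 1 = all-gather, 2 = reduce-scatter
  const uint32_t end = static_cast<uint32_t>(timings[0]);              // clamped next-free record slot
  std::printf("kernelType=%u\n", kernelType);
  for (uint32_t i = 1; i + kTimingsCount <= end; i += kTimingsCount) {
    std::printf("recv step %llu, send step %llu, %llu bytes, wait r/s %llu/%llu ns, post r/s %llu/%llu ns, pre-post s %llu ns\n",
                (unsigned long long)timings[i], (unsigned long long)timings[i + 1],
                (unsigned long long)timings[i + 2], (unsigned long long)timings[i + 3],
                (unsigned long long)timings[i + 4], (unsigned long long)timings[i + 5],
                (unsigned long long)timings[i + 6], (unsigned long long)timings[i + 7]);
  }
}

All timestamps are nanosecond offsets from ncclShmem.startTiming, so entries within one kernel launch are directly comparable.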