[Kernel] Enable custom AR on ROCm
wenkaidu committed Jun 10, 2024
1 parent 95b3acc commit 190a636
Showing 8 changed files with 207 additions and 41 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -222,6 +222,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

endif()

if(VLLM_GPU_LANG STREQUAL "HIP")
list(APPEND VLLM_EXT_SRC
"csrc/custom_all_reduce.cu")
endif()

define_gpu_extension_target(
_C
DESTINATION vllm
37 changes: 37 additions & 0 deletions csrc/custom_all_reduce.cu
@@ -145,3 +145,40 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}

#ifdef USE_ROCM

void free_meta_buffer(void *buffer) {
hipFree(buffer);
}

std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp) {
std::vector<uint8_t> data_handle(sizeof(cudaIpcMemHandle_t), 0);
CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)data_handle.data(), inp.data_ptr()));
return data_handle;
}

torch::Tensor allocate_meta_buffer(int size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void *buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions().dtype(torch::kI8).device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}

std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(cudaDeviceGetPCIBusId((char *)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size()-1); // remove trailing NULL
return bdf;
}

#endif
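
These helpers form the host side of the ROCm signal-buffer exchange: allocate_meta_buffer returns an uncached device allocation wrapped in an int8 tensor, get_meta_buffer_ipc_handle exports that allocation's IPC handle as raw bytes, and get_device_bdf reports the device's PCI bus/device/function string. As a rough sketch only (not code from this commit), a peer process that has received the handle bytes could map the buffer into its own address space along these lines; how the bytes are shipped between ranks is assumed to be handled elsewhere.

#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Sketch: map the exporting rank's uncached signal buffer into this process.
void* open_peer_meta_buffer(const std::vector<uint8_t>& handle_bytes) {
  hipIpcMemHandle_t handle;
  std::memcpy(&handle, handle_bytes.data(), sizeof(handle));
  void* peer_ptr = nullptr;
  if (hipIpcOpenMemHandle(&peer_ptr, handle, hipIpcMemLazyEnablePeerAccess) !=
      hipSuccess)
    throw std::runtime_error("hipIpcOpenMemHandle failed");
  return peer_ptr;  // release later with hipIpcCloseMemHandle(peer_ptr)
}
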
53 changes: 51 additions & 2 deletions csrc/custom_all_reduce.cuh
@@ -1,7 +1,12 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#else
#include <cuda_bf16.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>

@@ -29,9 +34,14 @@ constexpr int kMaxBlocks = 64;
struct Signal {
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
};

struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
#ifdef USE_ROCM
struct __align__(16) RankData { const void * ptrs[8]; };
#else
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
#endif

struct __align__(16) RankSignals { volatile Signal* signals[8]; };

@@ -130,6 +140,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
#ifdef USE_ROCM
uint32_t flag = self_sg->_flag[blockIdx.x]+1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED);
// wait until all ranks have posted at least this flag
while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED) < flag)
;
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -140,6 +164,7 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
while (!self_sg->start[blockIdx.x][threadIdx.x]);
}
__syncthreads();
#endif
}

// This function is meant to be used as the second or the final synchronization
@@ -148,6 +173,25 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
#ifdef USE_ROCM
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
// testing. It might be the case that the hardware provides a stronger
// guarantee than the memory model.
uint32_t flag = self_sg->_flag[blockIdx.x]+1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
// wait until all ranks have posted at least this flag
while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) < flag)
;
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
@@ -164,6 +208,7 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
while (!self_sg->end[blockIdx.x][threadIdx.x]);
}
if constexpr (!final_sync) __syncthreads();
#endif
}

template <typename P, int ngpus, typename A>
@@ -324,7 +369,11 @@ class CustomAllreduce {
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
#ifdef USE_ROCM
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
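
On ROCm, start_sync and end_sync above replace the write-one-then-reset handshake of the CUDA path with a per-block flag that only ever increases: each barrier posts flag = previous flag + 1 to every rank and spins until every rank has posted at least that value, so the signal slots never need to be cleared between barriers. A minimal CPU analogue of that monotonic-epoch idea (a sketch with std::atomic standing in for the GCC atomic builtins, not the device code itself):

#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

constexpr int kRanks = 4;
// signals[receiver][sender]: the last epoch `sender` has posted to `receiver`.
std::atomic<uint32_t> signals[kRanks][kRanks];

void barrier(int rank, uint32_t& epoch) {
  uint32_t flag = epoch + 1;
  // Post the new epoch to this rank's slot on every peer (one store per peer).
  for (int peer = 0; peer < kRanks; ++peer)
    signals[peer][rank].store(flag, std::memory_order_release);
  // Spin until every peer has reached at least the same epoch; no reset needed.
  for (int peer = 0; peer < kRanks; ++peer)
    while (signals[rank][peer].load(std::memory_order_acquire) < flag) {}
  epoch = flag;
}

int main() {
  std::vector<std::thread> workers;
  for (int r = 0; r < kRanks; ++r)
    workers.emplace_back([r] {
      uint32_t epoch = 0;
      for (int i = 0; i < 1000; ++i) barrier(r, epoch);
    });
  for (auto& w : workers) w.join();
}

The release/acquire pair here mirrors the non-final end_sync orderings in the diff, which must also publish the preceding reduction writes; the final barrier can stay relaxed.
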
25 changes: 23 additions & 2 deletions csrc/custom_all_reduce_test.cu
@@ -20,9 +20,16 @@
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#include "rccl/rccl.h"
#include "custom_all_reduce_hip.cuh"
#else
#include "nccl.h"
#include "custom_all_reduce.cuh"
#endif

#define MPICHECK(cmd) \
do { \
@@ -44,7 +51,16 @@
} while (0)

__global__ void dummy_kernel() {
#ifdef USE_ROCM
for (int i = 0; i < 100; i++) {
uint64_t start = wall_clock64();
uint64_t cycles_elapsed;
do { cycles_elapsed = wall_clock64() - start; }
while (cycles_elapsed < 100);
}
#else
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
#endif
}

template <typename T>
@@ -114,8 +130,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
* registration, they are allocated and registered together in the test for
* convenience.
*/
#ifdef USE_ROCM
CUDACHECK(
hipExtMallocWithFlags((void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal), hipDeviceMallocUncached));
#else
CUDACHECK(
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
#endif
CUDACHECK(
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
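
Because __nanosleep is unavailable on ROCm, the dummy_kernel above spins on wall_clock64(), a free-running device counter. A hypothetical helper (not in the commit) that packages that pattern into an approximate device-side sleep could look like this; ticks_per_us is an assumed parameter that the host would derive from the device's wall-clock rate.

#include <hip/hip_runtime.h>
#include <cstdint>

// Sketch: busy-wait for roughly `us` microseconds on the wall clock counter.
__device__ void approx_sleep_us(uint64_t us, uint64_t ticks_per_us) {
  uint64_t start = static_cast<uint64_t>(wall_clock64());
  uint64_t target = us * ticks_per_us;
  while (static_cast<uint64_t>(wall_clock64()) - start < target) {
    // Pure spin: unlike __nanosleep, the wavefront is never descheduled.
  }
}

__global__ void rocm_dummy_kernel(uint64_t ticks_per_us) {
  approx_sleep_us(100 * 1000, ticks_per_us);  // about 100 ms, like the CUDA path
}
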
5 changes: 4 additions & 1 deletion csrc/ops.h
@@ -131,7 +131,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
@@ -151,4 +150,8 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#ifdef USE_ROCM
torch::Tensor allocate_meta_buffer(int size);
std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp);
std::vector<uint8_t> get_device_bdf(int dev);
#endif
8 changes: 7 additions & 1 deletion csrc/pybind.cpp
@@ -98,7 +98,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
@@ -112,5 +111,12 @@
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#ifdef USE_ROCM
custom_ar.def("allocate_meta_buffer", &allocate_meta_buffer,
"allocate_meta_buffer");
custom_ar.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle,
"get_meta_buffer_ipc_handle");
custom_ar.def("get_device_bdf", &get_device_bdf,
"get_device_bdf");
#endif
}
9 changes: 2 additions & 7 deletions vllm/config.py
@@ -621,13 +621,8 @@ def _verify_args(self) -> None:
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp' or 'torchrun'.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
if (not self.disable_custom_all_reduce and self.world_size > 1
and self.pipeline_parallel_size > 1):
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "