
Commit 0aad2af: updated paged attention kernel
vllmellm committed Dec 27, 2024 (1 parent: e31e05f)
Showing 2 changed files with 88 additions and 10 deletions.
6 changes: 3 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -9,7 +9,7 @@
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        create_kv_caches_with_random)

-NUM_BLOCKS = 256 * 1024
+NUM_BLOCKS = 1024 * 1024
PARTITION_SIZE = 256

@@ -176,13 +176,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
-   run_benchmark(num_iters=3, profile=False)
+   run_benchmark(num_iters=500, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
-       latency = run_benchmark(num_iters=1000, profile=False)
+       latency = run_benchmark(num_iters=10000, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


92 changes: 85 additions & 7 deletions csrc/rocm/attention.cu
@@ -29,6 +29,11 @@
#define __HIP__MI300_MI250__
#endif

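// gfx940, gfx941 and gfx942 are the MI300-series targets; code that should
// only compile for MI300 can be gated on this macro.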
#if defined(__HIPCC__) && (defined(__gfx940__) || \
                           defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300__
#endif

#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
@@ -85,6 +90,27 @@ __device__ __forceinline__ void store(T value, T* addr) {
addr[0] = value;
}


// Non-temporal scalar load: hints that the loaded value will not be reused
// soon, so it may bypass the cache hierarchy.
template <typename T>
__device__ __forceinline__ T loadnt(T* addr) {
  return __builtin_nontemporal_load(addr);
}

// Non-temporal load of a full 16-byte _B16x8, assembled from four 4-byte
// non-temporal loads of the same memory.
__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) {
  auto addr_alias = reinterpret_cast<const float*>(addr);
  auto dat0 = loadnt(addr_alias);
  auto dat1 = loadnt(addr_alias + 1);
  auto dat2 = loadnt(addr_alias + 2);
  auto dat3 = loadnt(addr_alias + 3);
  // auto dat0 = *(addr_alias);
  // auto dat1 = *(addr_alias + 1);
  // auto dat2 = *(addr_alias + 2);
  // auto dat3 = *(addr_alias + 3);
  auto res = make_float4(dat0, dat1, dat2, dat3);
  return *reinterpret_cast<_B16x8*>(&res);
}
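For orientation only (not part of this commit), a minimal self-contained HIP sketch of the same non-temporal-load pattern; the Payload16 type and stream_copy kernel are hypothetical names invented for this example, not identifiers from attention.cu:

// Illustrative sketch, assuming a generic 16-byte, 16-byte-aligned payload.
#include <hip/hip_runtime.h>

struct alignas(16) Payload16 {
  float a, b, c, d;  // stand-in for a packed vector such as _B16x8
};

__global__ void stream_copy(const Payload16* __restrict__ src,
                            Payload16* __restrict__ dst, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Four 4-byte non-temporal loads followed by one regular 16-byte store,
    // mirroring load_ntmprl_16Byte above: the loads hint that the source data
    // will not be reused, so it need not displace useful cache lines.
    const float* p = reinterpret_cast<const float*>(src + i);
    const float4 v = make_float4(__builtin_nontemporal_load(p),
                                 __builtin_nontemporal_load(p + 1),
                                 __builtin_nontemporal_load(p + 2),
                                 __builtin_nontemporal_load(p + 3));
    *reinterpret_cast<float4*>(dst + i) = v;
  }
}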


template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
const _B16x4& inpB,
@@ -345,6 +371,49 @@ __device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) {
}
}

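// Round-trips the four packed floats through fp8 (pack with
// __builtin_amdgcn_cvt_pk_fp8_f32, unpack with __builtin_amdgcn_cvt_pk_f32_fp8)
// and then converts the result to the requested 16-bit type, so the returned
// values carry only fp8 precision.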
template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) {
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
    int32_t tmpf8;
    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false);
    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true);
    const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false);
    const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true);
    union h2cvt {
      _Half2 h2[2];
      _B16x4 b16x4;
    } u;
    u.h2[0] = __builtin_amdgcn_cvt_pkrtz(f0[0], f0[1]);
    u.h2[1] = __builtin_amdgcn_cvt_pkrtz(f1[0], f1[1]);
    return u.b16x4;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    int32_t tmpf8;
    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false);
    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true);
    const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false);
    const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true);
    floatx4 tmpf;
    tmpf[0] = f0[0];
    tmpf[1] = f0[1];
    tmpf[2] = f1[0];
    tmpf[3] = f1[1];
    for (int i = 0; i < 4; i++) {
      union fcvt {
        uint32_t i32;
        float f32;
      } u;
      u.f32 = tmpf[i];
      ret[i] = uint16_t(u.i32 >> 16);
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}



template <typename T>
__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
#if 0
@@ -743,7 +812,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
  // __shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE/4 + 1];
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    dout[token_depth] *= inv_sum_scale;
-   shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4<scalar_t>(dout[token_depth]);
+   if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
+     shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz<scalar_t>(dout[token_depth]);
+   } else {
+     shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz<scalar_t>(dout[token_depth]);
+   }
  }

if (threadIdx.x < GQA_RATIO) {
@@ -900,8 +973,8 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
*out_ptr_B16x8 = vout[h];
}
}

}

#endif

#if 0
@@ -954,6 +1027,15 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
}
}
#endif
#if 0  // DEBUG ONLY
  floatx4 partition_out[VHELOOP];
  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
    partition_out[vhe_depth] = {0};
    for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
      partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth];
    }
  }
#endif
#if 0 //DEBUG ONLY
if (laneid < GQA_RATIO) {
auto* exp_sums_ptr = exp_sums + seq_idx * 8 * max_num_partitions + partition_idx;
@@ -2256,7 +2338,7 @@ void paged_attention_custom_launcher(
// supported by graphing, not the actual max among all the sequences: in that
// case reduction kernel will still run but return immediately

-  //above optimization is not yet implemented in mfma16 kernel
+  //below optimization is not yet implemented in mfma16 kernel
//if (max_context_len > PARTITION_SIZE) {
dim3 reduce_grid(num_heads, num_seqs);
dim3 reduce_block(head_size);
@@ -2324,17 +2406,13 @@ void paged_attention_custom_launcher(
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T);
/*
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
uint8_t); \
} else { \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
}
*/
#endif
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
