diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index 189ffe4672d8a..813de437b611e 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -9,7 +9,7 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)
 
-NUM_BLOCKS = 256 * 1024
+NUM_BLOCKS = 1024 * 1024
 PARTITION_SIZE = 256
 
@@ -176,13 +176,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     # Warmup.
     print("Warming up...")
     run_benchmark = run_cuda_benchmark
-    run_benchmark(num_iters=3, profile=False)
+    run_benchmark(num_iters=500, profile=False)
 
     # Benchmark.
     if do_profile:
         latency = run_benchmark(num_iters=1, profile=True)
     else:
-        latency = run_benchmark(num_iters=1000, profile=False)
+        latency = run_benchmark(num_iters=10000, profile=False)
     print(f"Kernel running time: {latency * 1000000:.3f} us")
 
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 9d347053aac7b..58682111ee3bd 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -29,6 +29,11 @@
   #define __HIP__MI300_MI250__
 #endif
 
+#if defined(__HIPCC__) && (defined(__gfx940__) || \
+                           defined(__gfx941__) || defined(__gfx942__))
+  #define __HIP__MI300__
+#endif
+
 #if defined(NDEBUG)
   #undef NDEBUG
   #include <assert.h>
@@ -85,6 +90,27 @@ __device__ __forceinline__ void store(T value, T* addr) {
   addr[0] = value;
 }
 
+template <typename T>
+__device__ __forceinline__ T loadnt(T* addr) {
+  return __builtin_nontemporal_load(addr);
+}
+
+__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) {
+  auto addr_alias = reinterpret_cast<const float*>(addr);
+  auto dat0 = loadnt(addr_alias);
+  auto dat1 = loadnt(addr_alias + 1);
+  auto dat2 = loadnt(addr_alias + 2);
+  auto dat3 = loadnt(addr_alias + 3);
+  //auto dat0 = *(addr_alias);
+  //auto dat1 = *(addr_alias+1);
+  //auto dat2 = *(addr_alias+2);
+  //auto dat3 = *(addr_alias+3);
+  auto res = make_float4(dat0,dat1,dat2,dat3);
+  return *reinterpret_cast<_B16x8*>(&res);
+}
+
 template <typename T, int absz, int cbid, int blgp>
 __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
                                                   const _B16x4& inpB,
@@ -345,6 +371,49 @@ __device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) {
   }
 }
 
+template <typename T>
+__device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) {
+  _B16x4 ret;
+  if constexpr (std::is_same<T, _Half>::value) {
+    int32_t tmpf8;
+    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false);
+    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true);
+    const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false);
+    const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true);
+    union h2cvt {
+      _Half2 h2[2];
+      _B16x4 b16x4;
+    } u;
+    u.h2[0] = __builtin_amdgcn_cvt_pkrtz(f0[0],f0[1]);
+    u.h2[1] = __builtin_amdgcn_cvt_pkrtz(f1[0],f1[1]);
+    return u.b16x4;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    int32_t tmpf8;
+    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false);
+    tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true);
+    const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false);
+    const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true);
+    floatx4 tmpf;
+    tmpf[0] = f0[0];
+    tmpf[1] = f0[1];
+    tmpf[2] = f1[0];
+    tmpf[3] = f1[1];
+    for (int i = 0; i < 4; i++) {
+      union fcvt {
+        uint32_t i32;
+        float f32;
+      } u;
+      u.f32 = tmpf[i];
+      ret[i] = uint16_t(u.i32 >> 16);
+    }
+    return ret;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
 template <typename T>
 __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
 #if 0
@@ -743,7 +812,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
   //__shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE/4 + 1];
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
     dout[token_depth] *= inv_sum_scale;
-    shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4(dout[token_depth]);
+    if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
+      shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]);
+    } else {
+      shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]);
+    }
   }
 
   if (threadIdx.x < GQA_RATIO) {
@@ -900,8 +973,8 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
        *out_ptr_B16x8 = vout[h];
      }
    }
-  }
+
 #endif
 
 #if 0
@@ -954,6 +1027,15 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
     }
   }
 #endif
+#if 0 //DEBUG ONLY
+  floatx4 partition_out[VHELOOP];
+  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
+    partition_out[vhe_depth] = {0};
+    for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
+      partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth];
+    }
+  }
+#endif
 #if 0 //DEBUG ONLY
   if (laneid < GQA_RATIO) {
     auto* exp_sums_ptr = exp_sums + seq_idx * 8 * max_num_partitions + partition_idx;
@@ -2256,7 +2338,7 @@ void paged_attention_custom_launcher(
   // supported by graphing, not the actual max among all the sequences: in that
   // case reduction kernel will still run but return immediately
-  //above optimization is not yet implemented in mfma16 kernel
+  //below optimization is not yet implemented in mfma16 kernel
   //if (max_context_len > PARTITION_SIZE) {
   dim3 reduce_grid(num_heads, num_seqs);
   dim3 reduce_block(head_size);
@@ -2324,9 +2406,6 @@ void paged_attention_custom_launcher(
      CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
    }
 #else
-  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
-    CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T);
-/*
  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
    if (fp8_out_scale) { \
      CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
@@ -2334,7 +2413,6 @@ void paged_attention_custom_launcher(
    } else { \
      CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
    }
-*/
 #endif
 
 #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
   switch (block_size) { \