diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 58682111ee3bd..b138b181cd0cf 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -47,6 +47,9 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
+// clean up this line later
+#define __HIP__MI300_MI250__
+
 #if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
 #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
@@ -102,10 +105,6 @@ __device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) {
   auto dat1 = loadnt(addr_alias + 1);
   auto dat2 = loadnt(addr_alias + 2);
   auto dat3 = loadnt(addr_alias + 3);
-  //auto dat0 = *(addr_alias);
-  //auto dat1 = *(addr_alias+1);
-  //auto dat2 = *(addr_alias+2);
-  //auto dat3 = *(addr_alias+3);
   auto res = make_float4(dat0,dat1,dat2,dat3);
   return *reinterpret_cast<_B16x8*>(&res);
 }
@@ -188,23 +187,7 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
     __hip_bfloat16 b;
   } t16;
   _B16x4 ret;
-#if 0
-  #pragma unroll
-  for (int i = 0; i < 4; i++) {
-    t16.f = (_Float16)inp[i];
-    ret[i] = t16.u;
-  }
-  return ret;
-#else
   if constexpr (std::is_same<T, _Float16>::value) {
-#if 0
-  #pragma unroll
-  for (int i = 0; i < 4; i++) {
-    t16.f = (_Float16)inp[i];
-    ret[i] = t16.u;
-  }
-  return ret;
-#else
     union h2cvt {
       __half2 h2[2];
      _B16x4 b16x4;
@@ -212,7 +195,6 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
     u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1]));
     u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3]));
     return u.b16x4;
-#endif
   } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
   #pragma unroll
     for (int i = 0; i < 4; i++) {
@@ -230,7 +212,6 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
   } else {
     static_assert(false, "unsupported 16b dtype");
   }
-#endif
 }
 
 template <typename T>
@@ -242,27 +223,7 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
     __hip_bfloat16 b;
   } t1, t2, res;
   _B16x4 ret;
-#if 0
-  #pragma unroll
-  for (int i = 0; i < 4; i++) {
-    t1.u = inp1[i];
-    t2.u = inp2[i];
-    res.f = t1.f + t2.f;
-    ret[i] = res.u;
-  }
-  return ret;
-#else
   if constexpr (std::is_same<T, _Float16>::value) {
-#if 0
-  #pragma unroll
-  for (int i = 0; i < 4; i++) {
-    t1.u = inp1[i];
-    t2.u = inp2[i];
-    res.f = t1.f + t2.f;
-    ret[i] = res.u;
-  }
-  return ret;
-#else
     union h2cvt {
       _B16x4 b16x4;
       __half2 h2[2];
@@ -272,7 +233,6 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
     s.h2[0] = u1.h2[0] + u2.h2[0];
     s.h2[1] = u1.h2[1] + u2.h2[1];
     return s.b16x4;
-#endif
   } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
   #pragma unroll
     for (int i = 0; i < 4; i++) {
@@ -284,10 +244,6 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
       u2.i32 = uint32_t(inp2[i])<<16;
       s.f32 = u1.f32 + u2.f32;
       ret[i] = uint16_t(s.i32>>16);
-      //t1.u = inp1[i];
-      //t2.u = inp2[i];
-      //res.b = t1.b + t2.b;
-      //ret[i] = res.u;
     }
     return ret;
   } else {
@@ -416,16 +372,6 @@ __device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) {
 
 template <typename T>
 __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
-#if 0
-  union {
-    floatx4 f32x4[2];
-    vllm::Float8_ f32x8;
-    _B8x8 b8x8[2];
-  } tmpf8;
-  tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input));
-  //tmpf8.b8x8[0] = input;
-  //tmpf8.b8x8[1] = input;
-#endif
   union {
     _B8x8 b8x8;
     _B8x4 b8x4[2];
@@ -435,8 +381,6 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
   for (int i=0; i<2; i++) {
     ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) );
   }
-  //ret.xy[0] = from_floatx4(tmpf8.f32x4[0]);
-  //ret.xy[1] = from_floatx4(tmpf8.f32x4[1]);
   return ret;
 }
 ///////////////////////////////////////
@@ -543,31 +487,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
     kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
   }
 
-#if 0 //fetch Q into registers
-
-  const int local_qhead_idx = lane16id % GQA_RATIO;
-  const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-  const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
-  const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
-
-  if (lane16id < GQA_RATIO) {
-    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
-      const scalar_t* q_ptr2 = q_ptr + qkhe_depth * QKHE_PER_FETCH;
-      for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
-        const scalar_t* q_fetch_ptr = q_ptr2 + qkratio * CONTIGUOUS_SCALAR_ELEMS_16B;
-        const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
-        Qlocal[qkhe_depth][qkratio] = *q_fetch_ptr_16B;
-      }
-    }
-  } else {
-    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
-      for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
-        Qlocal[qkhe_depth][qkratio].xy[0] = {0};
-        Qlocal[qkhe_depth][qkratio].xy[1] = {0};
-      }
-    }
-  }
-#else //fetch Q in shared
+  //fetch Q in shared
   const int local_qhead_idx = 4 * warpid + rowid;
   const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
   const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
@@ -599,7 +519,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
       }
     }
   }
-#endif
 
   constexpr int KX = 16 / sizeof(cache_t);
   const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;
@@ -675,7 +594,9 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
   }
 
   floatx4 dout[TLOOP];
-#if 1 //Q stored in registers
+
+  //Q stored in registers
+
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
     dout[token_depth] = {0};
     for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
@@ -702,37 +623,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
     dout[token_depth] *= scale2;
   }
-#else //Q in shared
-  _B16x4 tmpQ[QKHELOOP][2];
-  for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
-    tmpQ[qkhe_depth][0] = shared_logits[qkhe_depth][rowid][lane16id][0];
-    tmpQ[qkhe_depth][1] = shared_logits[qkhe_depth][rowid][lane16id][1];
-  }
-
-  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
-    dout[token_depth] = {0};
-    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
-      //for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
-      for (int i=0; i<2; i++) {
-        dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i],
-            tmpQ[qkhe_depth][i], //shared_logits[qkhe_depth][rowid][lane16id][i],
-            dout[token_depth]);
-      }
-      //}
-    }
-    dout[token_depth] *= scale;
-  }
-#endif
-
-#if 0 //DEBUG ONLY qk * scale
-  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
-    auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4;
-    auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2);
-    auto tmp = from_floatx4(dout[token_depth]);
-    *qkout_write_ptr = tmp;
-  }
-#endif
-
   float qk_max = -FLT_MAX;
   float exp_sum = 0.0f;
@@ -776,17 +666,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
     shared_mem[exp_sum_offset] = exp_sum;
   }
 
-#if 0 //DEBUG ONLY
-  //scalar_t* qkout_ptr = out +
-  //    seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE;
-  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
-    //auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4;
-    //auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2);
-    auto tmp = from_floatx4(dout[token_depth]);
-    shared_tokens[warpid][token_depth][lane16id][rowid] = tmp;
-    //*qkout_write_ptr = tmp;
-  }
-#endif
   __syncthreads();
 
   float partition_qk_max = -FLT_MAX;
@@ -809,7 +688,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
 
   __syncthreads(); //new
 
-  //__shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE/4 + 1];
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
     dout[token_depth] *= inv_sum_scale;
     if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
@@ -828,40 +706,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
 
   __syncthreads();
 
-#if 0 //DEBUG ONLY
-  scalar_t* qkout_ptr = out +
-      seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE;
-  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
-    auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4;
-    auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2);
-    //dout[token_depth] *= inv_sum_scale[warpid];
-    //auto tmp = from_floatx4(dout[token_depth]);
-    auto tmp = shared_tokens[warpid][token_depth][lane16id][rowid];
-    *qkout_write_ptr = tmp;
-  }
-#endif
-#if 0
-  if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
-    for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
-      for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
-        for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
-          _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth];
-          _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
-          for (int j=0; j<2; j++) {
-            _B8x8 Vtmp8x8 = Vtmp8x16.xy[j];
-            _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8);
-            for (int i=0; i<2; i++) {
-              const int offset = 4*rowid + 2*j + i;
-              const int offset1 = offset % 4;
-              const int offset2 = offset / 4;
-              tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i],
-                  shared_logits[vtoken_depth][offset2][lane16id][offset1],
-                  tmp_out);
-            }
-          }
-        }
-      }
-#endif
   _B16x4 outelems[VHELOOP];
   _B16x4 S_local[VTLOOP][2][2];
   if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
@@ -897,16 +741,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
           const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i;
           const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP
           const int offset2 = offset / 4;
-#if 0
-          //if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows
-          tmp_out = gcn_mfma16x16x16_instr(shared_logits[vtoken_depth][offset2][lane16id][offset1],
-              Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
-#else
+          //if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows
           tmp_out = gcn_mfma16x16x16_instr(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
               shared_logits[vtoken_depth][offset2][lane16id][offset1],
               tmp_out);
-#endif
         }
       }
     } else {
@@ -923,8 +762,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
             tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i],
                 S_local[vtoken_depth][j][i],
                 tmp_out);
-                //shared_logits[vtoken_depth][offset2][lane16id][offset1],
-                //tmp_out);
           }
         }
       }
@@ -937,7 +774,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
     outelems[vhe_depth] = from_floatx4(tmp_out);
   }
-#if 1
 
   __syncthreads();
 
   for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
@@ -975,85 +811,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1
     }
   }
-#endif
-
-#if 0
-  //if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows
-  const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
-  scalar_t* out_ptr = out +
-      seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE;
-
-  const int vhe_offset = warpid * 16 + lane16id;
-
-  #pragma unroll
-  for (int i=0; i<4; i++) {
-    const int local_head_idx = 4*rowid + i;
-    if (local_head_idx < GQA_RATIO) {
-      const int out_head_idx = wg_start_head_idx + local_head_idx;
-      scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
-      for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
-        const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset;
-        scalar_t* out_ptr3 = out_ptr2 + vhead_elem;
-        bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr3);
-        *out_ptr_b16 = outelems[vhe_depth][i];
-      }
-    }
-  }
-#endif
-#if 0
-  //if output format is 16 qheads across 16 lanes, 16 he spread across 4 rows
-  if (lane16id < GQA_RATIO) {
-    const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
-    scalar_t* out_ptr = out +
-        seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE;
-    const int local_head_idx = lane16id;
-    const int out_head_idx = wg_start_head_idx + local_head_idx;
-    scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
-    const int vhe_offset = warpid * 16 + rowid * 4;
-    for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
-      const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset;
-      scalar_t* out_ptr3 = out_ptr2 + vhead_elem;
-      _B16x4* out_ptr_B16x4 = reinterpret_cast<_B16x4*>(out_ptr3);
-      *out_ptr_B16x4 = outelems[vhe_depth];
-    }
-  }
-#endif
-#if 0 //DEBUG ONLY
-  floatx4 partition_out[VHELOOP];
-  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
-    partition_out[vhe_depth] = {0};
-    for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
-      partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth];
-    }
-  }
-#endif
-#if 0 //DEBUG ONLY
-  floatx4 partition_out[VHELOOP];
-  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
-    partition_out[vhe_depth] = {0};
-    for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
-      partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth];
-    }
-  }
-#endif
-#if 0 //DEBUG ONLY
-  if (laneid < GQA_RATIO) {
-    auto* exp_sums_ptr = exp_sums + seq_idx * 8 * max_num_partitions + partition_idx;
-    floatx4 tmp = {0};
-    //for (int t=0; t(from_floatx4(tmp), shared_tokens[warpid][lane4id][lane16id][rowid]);
-
-    float2 tmpf = *reinterpret_cast<float2*>(&tmp16);
-    *exp_sums_ptr = laneid%2 == 0 ? tmpf.x : tmpf.y;
-  }
-#endif
 }
 /////////////////////////////////////////////////////////////
 // grid (num_seqs, num_partitions,num_heads/gqa_ratio)
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 09901d77be79f..745bc9b080751 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -18,11 +18,12 @@
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
-#MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-MAX_SEQ_LEN = 32768
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
+
 # There may not be enough gpu memory due to large NUM_BLOCKS.
 # Reduce NUM_BLOCKS when it happens.
-NUM_BLOCKS = 128*1024+4321  # Arbitrary values for testing
+NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 PARTITION_SIZE_ROCM = 256
 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
@@ -31,7 +32,8 @@
 ] if not current_platform.is_rocm() else [torch.half,torch.bfloat16]
 NUM_GEN_SEQS = [17]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
-NUM_HEADS = [(64, 8), (26,2), (16,1), (32,32)]  # Arbitrary values for testing
+# NUM_HEADS = [(64, 8), (26,2), (16,1), (32,32)]  # Arbitrary values for testing
+NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
@@ -40,10 +42,11 @@
 BLOCK_SIZES = [16]
 USE_ALIBI = [False]
+# USE_ALIBI = [True]
 KV_CACHE_DTYPE = ["auto","fp8"]
 SEEDS = [0]
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1)
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 
 REF_TENSOR = None
@@ -56,15 +59,10 @@ def ref_masked_attention(
     scale: float,
     attn_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    qkout = torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = scale * qkout
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     if attn_mask is not None:
         attn_weights = attn_weights + attn_mask.float()
     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    #print('>>> ref qkout shape',attn_weights.shape)
-    #print('>>> ref qkout',attn_weights)
-    #global REF_TENSOR
-    #REF_TENSOR = attn_weights
     out = torch.einsum("hqk,khd->qhd", attn_weights, value)
     return out
@@ -160,8 +158,6 @@ def test_paged_attention(
     num_query_heads, num_kv_heads = num_heads
     query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
     query.uniform_(-scale, scale)
-    #query = torch.ones_like(query)
-    query = torch.randn_like(query)
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
@@ -170,11 +166,8 @@ def test_paged_attention(
         alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
 
     seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
-    #seq_lens = [MAX_SEQ_LEN for _ in range(num_seqs)]
     seq_lens[-1] = MAX_SEQ_LEN
     max_seq_len = max(seq_lens)
-    #max_seq_len = 512
-    print('>>>', seq_lens, max_seq_len)
     seq_lens = torch.tensor(seq_lens, dtype=torch.int)
 
     # Create the block tables.
@@ -200,7 +193,7 @@ def test_paged_attention(
     #key_cache = torch.ones_like(key_cache)
 
     # Using default kv_scale
-    k_scale = v_scale = 0.1
+    k_scale = v_scale = 1.0
 
     # Call the paged attention kernel.
     output = torch.empty_like(query)
@@ -354,17 +347,7 @@ def test_paged_attention(
     #bf16 rounding is handled via truncation in new kernel, this increases error
     if dtype == torch.bfloat16:
         atol = 1e-3
-    #print('>>>tmpout shape', tmp_output.shape)
-    #print('>>>tmpout', tmp_output.view(8,1,256))
-    #global REF_TENSOR
-    #torch.testing.assert_close(tmp_output.view(8,1,256), REF_TENSOR, atol=atol, rtol=rtol)
-
-    #print('>>> ref out shape', ref_output.shape)
-    #print('>>> ref out', ref_output)
-    #print('>>> out shape', output.shape)
-    #print('>>> out', output)
-    #print('>>>', exp_sums)
-    #print('>>>', max_logits)
+
     torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)