From 5ece4ca5254c7c3d5de4f5827c5346f13271894b Mon Sep 17 00:00:00 2001
From: Shucai Xiao
Date: Mon, 13 May 2024 17:57:35 +0000
Subject: [PATCH] address review comments and code cleanup

---
 xformers/benchmarks/benchmark_attn_decoding.py | 5 +++--
 xformers/ops/fmha/triton_splitk.py             | 8 ++------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/xformers/benchmarks/benchmark_attn_decoding.py b/xformers/benchmarks/benchmark_attn_decoding.py
index 1ba5100662..de528f3145 100644
--- a/xformers/benchmarks/benchmark_attn_decoding.py
+++ b/xformers/benchmarks/benchmark_attn_decoding.py
@@ -35,7 +35,7 @@
     # for hkv in (1, 2)
 ] + [
     # dict(B=i, Mq=1, Mkv=8448, Hq=8, Hkv=1, K=128, attn_bias_type=xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask) for i in [2, 4, 8, 16, 32, 64]
-    dict(B=i, Mq=1, Mkv=4097, Hq=8, Hkv=2, K=128, attn_bias_type=None) for i in [2, 4, 8, 16, 32, 64, 128]
+    dict(B=i, Mq=1, Mkv=4097, Hq=8, Hkv=1, K=128, attn_bias_type=None) for i in [2, 4, 8, 16, 32, 64, 128]
 ]
 
 
@@ -347,7 +347,8 @@ def test_flash_attention_decoder(name, case):
     assert name in ["ck-decoder", "ck_splitK", "ck", "triton_splitK", "triton_int4KV"]
     decoder_output,ctx = decoder.OP.apply(baseline.get_inputs(), False)
     s = decoder_output.shape
-    decoder_output = decoder_output.reshape([s[0], s[1], -1, s[4]])
+    if name in ["ck-decoder", "ck_splitK"]:
+        decoder_output = decoder_output.reshape([s[0], s[1], -1, s[4]])
     decoder_output = decoder_output.transpose(2, 1).contiguous()
 
     torch.testing.assert_close(decoder_output, baseline_out, atol=1e-3, rtol=0)
diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py
index a510f222e4..436dc60bf3 100644
--- a/xformers/ops/fmha/triton_splitk.py
+++ b/xformers/ops/fmha/triton_splitk.py
@@ -1168,7 +1168,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
     @classmethod
     def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
         """Heuristic for the number of splits"""
-        print(f"B = {B}, G = {G}, H = {H}, Mk = {Mk}")
         bh = max(B * H, 1)  # NOTE: Handle B*h=0 case
         split_k = max(Mk + bh - 1, 1024) // bh
         if torch.version.hip:
@@ -1186,9 +1185,6 @@ def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
             if chunk_size < split_size:
                 split_k += 1
 
-            # split_size = (split_size + max_chunk_size - 1) // max_chunk_size * max_chunk_size
-            # split_k = (Mk + split_size - 1) // split_size
-
             split_k_upper_bound = 512
         else:
             max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
@@ -1341,8 +1337,8 @@ def grid(META):
 
         split_size = (Mk + split_k - 1) // split_k
 
-        # align split_size to the multiple of 64
-        split_size = (split_size + 63) // 64 * 64
+        # # align split_size to the multiple of 64
+        # split_size = (split_size + cls.BLOCK_N) // cls.BLOCK_N * cls.BLOCK_N
         print(f"split_k = {split_k}, split_size = {split_size}, num_tiles = {B * G * H * split_k}")
         use_seq_len = seq_len is not None