Skip to content

Commit

Permalink
address review comments and code cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
scxiao committed May 13, 2024
1 parent 4381bb7 commit 5ece4ca
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
5 changes: 3 additions & 2 deletions xformers/benchmarks/benchmark_attn_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# for hkv in (1, 2)
] + [
# dict(B=i, Mq=1, Mkv=8448, Hq=8, Hkv=1, K=128, attn_bias_type=xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask) for i in [2, 4, 8, 16, 32, 64]
- dict(B=i, Mq=1, Mkv=4097, Hq=8, Hkv=2, K=128, attn_bias_type=None) for i in [2, 4, 8, 16, 32, 64, 128]
+ dict(B=i, Mq=1, Mkv=4097, Hq=8, Hkv=1, K=128, attn_bias_type=None) for i in [2, 4, 8, 16, 32, 64, 128]
]


Expand Down Expand Up @@ -347,7 +347,8 @@ def test_flash_attention_decoder(name, case):
assert name in ["ck-decoder", "ck_splitK", "ck", "triton_splitK", "triton_int4KV"]
decoder_output,ctx = decoder.OP.apply(baseline.get_inputs(), False)
s = decoder_output.shape
- decoder_output = decoder_output.reshape([s[0], s[1], -1, s[4]])
+ if name in ["ck-decoder", "ck_splitK"]:
+ decoder_output = decoder_output.reshape([s[0], s[1], -1, s[4]])
decoder_output = decoder_output.transpose(2, 1).contiguous()

torch.testing.assert_close(decoder_output, baseline_out, atol=1e-3, rtol=0)
Expand Down
8 changes: 2 additions & 6 deletions xformers/ops/fmha/triton_splitk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
@classmethod
def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
- print(f"B = {B}, G = {G}, H = {H}, Mk = {Mk}")
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
split_k = max(Mk + bh - 1, 1024) // bh
if torch.version.hip:
Expand All @@ -1186,9 +1185,6 @@ def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
if chunk_size < split_size:
split_k += 1

- # split_size = (split_size + max_chunk_size - 1) // max_chunk_size * max_chunk_size
- # split_k = (Mk + split_size - 1) // split_size

split_k_upper_bound = 512
else:
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
Expand Down Expand Up @@ -1341,8 +1337,8 @@ def grid(META):

split_size = (Mk + split_k - 1) // split_k

- # align split_size to the multiple of 64
- split_size = (split_size + 63) // 64 * 64
+ # # align split_size to the multiple of 64
+ # split_size = (split_size + cls.BLOCK_N) // cls.BLOCK_N * cls.BLOCK_N
print(f"split_k = {split_k}, split_size = {split_size}, num_tiles = {B * G * H * split_k}")

use_seq_len = seq_len is not None
Expand Down

0 comments on commit 5ece4ca

Please sign in to comment.