diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index b0c23fee5b373..f272c8e3acf93 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,10 +6,10 @@ from vllm import _custom_ops as ops from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) + create_kv_caches_with_random, is_hip) NUM_BLOCKS = 1024 * 1024 -PARTITION_SIZE = 256 +PARTITION_SIZE = 512 @torch.inference_mode() @@ -80,9 +80,9 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - if not args.custom_paged_attn: + if is_hip() and not args.custom_paged_attn: global PARTITION_SIZE - PARTITION_SIZE = 512 + PARTITION_SIZE = 1024 num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 46831b506aff3..eb865652791cb 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -31,8 +31,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256 - ] if not is_hip() else [64, 80, 96, 112, 128] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] @@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize( + "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -137,7 +137,8 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if kv_cache_dtype == "fp8" and head_size % 16: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): pytest.skip() random.seed(seed) torch.random.manual_seed(seed) @@ -208,7 +209,9 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0])) - elif version == "v2": + elif version in ("v2", "rocm"): + if is_hip(): + PARTITION_SIZE = 1024 if version == "v2" else 512 num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -221,32 +224,62 @@ def test_paged_attention( dtype=torch.float32, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, key_cache, - value_cache, num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0])) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._rocm_C.paged_attention, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == 64)) else: raise AssertionError(f"Unknown version: {version}") @@ -330,165 +363,6 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -@pytest.mark.parametrize("version", ["rocm"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not is_hip(), reason="only for rocm") -def test_paged_attention_rocm( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - #context_lens = [8192 for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) - #print('>>> ctx lens', context_lens) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # TODO(charlifu) enable fp8 kv cache - # Using default kv_scale - # kv_scale = 1.0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - PARTITION_SIZE_ROCM = 256 - num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // - PARTITION_SIZE_ROCM) - assert PARTITION_SIZE_ROCM % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - if version == "rocm": - ops.paged_attention_rocm( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - ) - else: - raise AssertionError(f"Unknown version: {version}") - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - alibi_slopes, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. - atol, rtol = 1e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 2e-4, 1e-5 - if use_alibi: - if dtype == torch.half: - atol, rtol = 5e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) - - # TODO(woosuk): Add tests for USE_ALIBI=True. @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -496,7 +370,8 @@ def test_paged_attention_rocm( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(is_hip(), reason="skip for rocm") +@pytest.mark.skipif(is_hip(), + reason="Xformers backend is not supported on ROCm.") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py deleted file mode 100644 index 65cfbb9d9872e..0000000000000 --- a/tests/kernels/test_attention_custom.py +++ /dev/null @@ -1,239 +0,0 @@ -import random -from typing import Optional, Tuple - -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.utils import is_hip - -from .allclose_default import get_default_atol, get_default_rtol - -MAX_SEQ_LEN = 32 * 1024 -# There may not be enough gpu memory due to large NUM_BLOCKS. -# Reduce NUM_BLOCKS when it happens. -NUM_BLOCKS = 128 * 1024 + 4321 # Arbitrary values for testing -PARTITION_SIZE = 256 -DTYPES = [torch.bfloat16, torch.half] -NUM_GEN_SEQS = [1, 17] # Arbitrary values for testing -NUM_HEADS = [(8 * x, 8) for x in range(1, 17)] # Arbitrary values for testing - -HEAD_SIZES = [64, 128] -BLOCK_SIZES = [16, 32] -USE_ALIBI = [True, False] -KV_CACHE_DTYPE = ["auto"] -SEEDS = [37] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) -] - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - -def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], -) -> None: - num_query_heads = query.shape[1] - num_kv_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - block_size = value_cache.shape[3] - num_seqs = query.shape[0] - - block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() - for i in range(num_seqs): - q = query[i].unsqueeze(0) - block_table = block_tables[i] - context_len = int(context_lens[i]) - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) - keys.append(k) - - v = value_cache[block_number, :, :, block_offset] - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - alibi_bias = None - if alibi_slopes is not None: - # Create the ALiBi bias used in the paged attention kernel. - position_ids = torch.arange(context_len).int() - alibi_bias = (position_ids - context_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) - - -@pytest.mark.parametrize("version", ["custom"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_paged_attention( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - #context_lens = [8192 for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) - #print('>>> ctx lens', context_lens) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # Using default kv_scale - k_scale = v_scale = 1.0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) - assert PARTITION_SIZE % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_rocm(output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, - block_tables, context_lens, block_size, - max_context_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale) - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - alibi_slopes, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. - atol, rtol = 1e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 2e-4, 1e-5 - if use_alibi: - if dtype == torch.half: - atol, rtol = 5e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 7357508751ae1..5563f69469669 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -20,7 +20,7 @@ # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing -PARTITION_SIZE = 512 +PARTITION_SIZE = 512 if not is_hip() else 1024 DTYPES = [torch.half, torch.bfloat16] NUM_GEN_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c0e72cfd55b6f..6aa0d51f127ae 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -929,8 +929,8 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: int, - v_scale: int, + k_scale: float, + v_scale: float, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9dd74081390a7..dc419bf4a1718 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -19,7 +19,7 @@ logger = init_logger(__name__) -_PARTITION_SIZE = 256 +_PARTITION_SIZE_ROCM = 512 ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName @@ -549,9 +549,10 @@ def forward( decode_meta.max_decode_seq_len) if use_custom: max_seq_len = decode_meta.max_decode_seq_len - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - assert _PARTITION_SIZE % block_size == 0 + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype, @@ -629,7 +630,6 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, max_seq_len: int) -> bool: # rocm custom page attention not support on navi (gfx1*) return (envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN and not ON_NAVI - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) + and qtype in (torch.half, torch.bfloat16) + and head_size in (64, 128) and block_size in (16, 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 92023d5b75f5a..540ab2fedd239 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -5,12 +5,13 @@ from vllm import _custom_ops as ops from vllm.triton_utils import HAS_TRITON +from vllm.utils import is_hip if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 +_PARTITION_SIZE = 512 if not is_hip() else 1024 @dataclass