
Commit

naive mha syntax error fix
root committed May 28, 2024
1 parent c8fb65e commit 385f2cd
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -155,7 +155,7 @@ def __init__(
             # AMD Radeon 7900 series (gfx1100) currently does not support
             # xFormers nor FlashAttention. As a temporary workaround, we use
             # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention()
+            self.attn_fuc = _naive_attention
             logger.info("Using naive attention in ROCmBackend")
         elif self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
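Editor's note, a minimal sketch (not taken from the repository) of why dropping the parentheses fixes the crash: _naive_attention has required parameters, so calling it with no arguments at construction time raises a TypeError, while storing the function object defers the call until forward() supplies the arguments. The class name below is a stand-in, not the backend's real class.

def _naive_attention(query, key, value, prompt_lens,
                     num_heads, num_kv_heads, head_size, scale):
    """Stand-in with the post-fix signature; body omitted."""

class NaiveImplSketch:
    def __init__(self):
        # Broken: `self.attn_fuc = _naive_attention()` calls the function
        # immediately with no arguments and raises TypeError in __init__.
        # Fixed: keep a reference to the callable; it is invoked later as
        # self.attn_fuc(query, key, value, prompt_lens, ...).
        self.attn_fuc = _naive_attention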
@@ -243,11 +243,14 @@ def forward(
                     key = self.repeat_kv(key, self.num_queries_per_kv)
                     value = self.repeat_kv(value, self.num_queries_per_kv)
                     output = self.attn_fuc(
-                        query,
-                        key,
-                        value,
-                        attn_metadata.prompt_lens,
-                        self.scale,
+                        query,
+                        key,
+                        value,
+                        attn_metadata.prompt_lens,
+                        self.num_heads,
+                        self.num_kv_heads,
+                        self.head_size,
+                        self.scale,
                     )
                 else:
                     output = self.attn_func(
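The call site above expands key and value with repeat_kv so they carry num_heads heads before the naive path runs, which is why the helper changed below reshapes them with num_heads rather than num_kv_heads (and why the num_kv_heads views stay commented out). A hedged sketch of that expansion, assuming the usual expand-and-reshape trick rather than the method's actual body:

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Assumed input: [num_tokens, num_kv_heads, head_size]
    # Output: [num_tokens, num_kv_heads * n_rep, head_size]
    num_tokens, num_kv_heads, head_size = x.shape
    return (x[:, :, None, :]
            .expand(num_tokens, num_kv_heads, n_rep, head_size)
            .reshape(num_tokens, num_kv_heads * n_rep, head_size))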
@@ -302,17 +305,28 @@ def _naive_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     prompt_lens: List[int],
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
     scale: float,
 ) -> torch.Tensor:
+    query = query.reshape(-1, num_heads, head_size)
+    #key = key.view(-1, num_kv_heads, head_size)
+    #value = value.view(-1, num_kv_heads, head_size)
+    key = key.reshape(-1, num_heads, head_size)
+    value = value.reshape(-1, num_heads, head_size)
     num_tokens = query.shape[0]
     output = torch.empty_like(query)
     start = 0
     for _, prompt_len in enumerate(prompt_lens):
         end = start + prompt_len
         out = _naive_masked_attention(
-            query[None, start:end],
-            key[None, start:end],
-            value[None, start:end],
+            #query[None, start:end],
+            #key[None, start:end],
+            #value[None, start:end],
+            query[start:end],
+            key[start:end],
+            value[start:end],
             scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
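Putting the changed pieces together, a hedged sketch of the naive prefill helper after this commit. The new parameters, the reshapes, and the per-prompt loop come from the diff above; the copy back into output and the final return are assumptions about the unchanged lines (the "Unnecessary copy" TODO hints at the former). It calls the masked-attention sketch shown after the next hunk.

from typing import List
import torch

def _naive_attention_sketch(
    query: torch.Tensor,    # assumed [num_tokens, num_heads * head_size], or already 3-D
    key: torch.Tensor,      # already expanded to num_heads heads via repeat_kv
    value: torch.Tensor,
    prompt_lens: List[int],
    num_heads: int,
    num_kv_heads: int,      # unused here: key/value were expanded at the call site
    head_size: int,
    scale: float,
) -> torch.Tensor:
    # Split the flattened hidden dimension into (num_heads, head_size);
    # a no-op if the tensors already arrive in that shape.
    query = query.reshape(-1, num_heads, head_size)
    key = key.reshape(-1, num_heads, head_size)
    value = value.reshape(-1, num_heads, head_size)
    output = torch.empty_like(query)
    start = 0
    for prompt_len in prompt_lens:
        end = start + prompt_len
        out = _naive_masked_attention_sketch(
            query[start:end],   # no leading None: the extra batch dim is gone
            key[start:end],
            value[start:end],
            scale,
        )
        # Assumption: copy each prompt's result back into the preallocated buffer.
        output[start:end].copy_(out)
        start = end
    # Assumption: flatten heads back into the hidden dimension for the caller.
    return output.reshape(output.shape[0], -1)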
@@ -332,6 +346,7 @@ def _naive_masked_attention(
     value: torch.Tensor,
     scale: float,
 ) -> torch.Tensor:
+
     seq_len, _, _ = query.shape
     attn_mask = torch.triu(torch.ones(seq_len,
                                       seq_len,
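Finally, a hedged sketch of the masked-attention helper itself, matching the three-dimensional [seq_len, num_heads, head_size] inputs and the triu mask visible above; the einsum/softmax body is an assumption, not copied from the file.

import torch

def _naive_masked_attention_sketch(
    query: torch.Tensor,    # [seq_len, num_heads, head_size]
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    seq_len, _, _ = query.shape
    # Causal mask: a 1 above the diagonal marks a disallowed (future) position.
    attn_mask = torch.triu(torch.ones(seq_len,
                                      seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min
    # Scaled dot-product attention per head, mask added before the softmax.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn_weights, value)

# Example shapes: 2 prompts of lengths 3 and 5, 8 heads of size 64.
if __name__ == "__main__":
    q = torch.randn(8, 8 * 64)
    k = torch.randn(8, 8 * 64)
    v = torch.randn(8, 8 * 64)
    out = _naive_attention_sketch(q, k, v, [3, 5], num_heads=8,
                                  num_kv_heads=8, head_size=64,
                                  scale=64 ** -0.5)
    print(out.shape)  # torch.Size([8, 512])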
