[Misc] Pass attention to impl backend (#12218)
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan authored Jan 20, 2025
1 parent 5f0ec39 commit 86bfb6d
Showing 12 changed files with 86 additions and 78 deletions.
23 changes: 19 additions & 4 deletions vllm/attention/backends/abstract.py
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
Protocol, Set, Tuple, Type, TypeVar)

import torch

@@ -223,6 +223,22 @@ def build(self, seq_lens: List[int], query_lens: List[int],
raise NotImplementedError


class AttentionLayer(Protocol):

_k_scale: float
_v_scale: float

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...


class AttentionImpl(ABC, Generic[T]):

@abstractmethod
@@ -244,13 +260,12 @@ def __init__(
@abstractmethod
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
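
The net effect of the interface change above: AttentionImpl.forward no longer takes loose k_scale/v_scale arguments; it receives the owning attention layer as its first argument and reads layer._k_scale / layer._v_scale from it, with AttentionLayer defined as a structural Protocol so any object exposing those attributes (and a matching forward) qualifies. A minimal sketch of a custom backend written against the new signature (the class itself, the tensor shapes, and the plain-SDPA math are illustrative assumptions, not part of this commit):

from typing import Optional

import torch

from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                              AttentionMetadata)


class NaiveSDPABackendImpl(AttentionImpl):
    """Illustrative impl conforming to the post-#12218 interface."""

    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(
        self,
        layer: AttentionLayer,  # new first argument: the owning layer
        query: torch.Tensor,    # assumed [num_tokens, num_heads * head_size]
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The fp8 KV-cache scales now come from the layer, not the call site.
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
            "fp8 KV cache is not handled in this sketch.")
        num_tokens = query.shape[0]
        # Fold heads into the batch dimension and run plain SDPA
        # (no paging, no KV-cache reuse -- illustration only).
        q = query.view(num_tokens, self.num_heads, self.head_size).transpose(0, 1)
        k = key.view(num_tokens, self.num_heads, self.head_size).transpose(0, 1)
        v = value.view(num_tokens, self.num_heads, self.head_size).transpose(0, 1)
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, scale=self.scale)
        return out.transpose(0, 1).reshape(num_tokens, -1)
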
12 changes: 6 additions & 6 deletions vllm/attention/backends/blocksparse_attn.py
@@ -4,6 +4,7 @@
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
@@ -358,13 +359,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +401,8 @@ def forward(
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +439,8 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride,
10 changes: 5 additions & 5 deletions vllm/attention/backends/flash_attn.py
@@ -8,6 +8,7 @@

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
@@ -634,13 +635,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -657,7 +657,7 @@ def forward(
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

assert output is not None, "Output tensor must be provided."
@@ -709,8 +709,8 @@ def forward(
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

(num_prefill_query_tokens, num_prefill_kv_tokens,
16 changes: 8 additions & 8 deletions vllm/attention/backends/flashinfer.py
@@ -23,6 +23,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
@@ -792,13 +793,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:

@@ -826,8 +826,8 @@ def forward(
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +886,8 @@ def forward(
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
@@ -897,8 +897,8 @@
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)

if prefill_output is None and decode_output is not None:
4 changes: 2 additions & 2 deletions vllm/attention/backends/hpu_attn.py
@@ -11,6 +11,7 @@
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
@@ -152,13 +153,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
18 changes: 9 additions & 9 deletions vllm/attention/backends/ipex_attn.py
@@ -7,6 +7,7 @@

from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ def split_kv_cache(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ def forward(
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ def forward(
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
# Run PagedAttention V2.
@@ -329,8 +329,8 @@ def forward(
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

# Reshape the output tensor.
6 changes: 3 additions & 3 deletions vllm/attention/backends/pallas.py
@@ -5,6 +5,7 @@
import torch_xla.experimental.custom_kernel # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState

@@ -150,13 +151,12 @@ def __init__(

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@@ -173,7 +173,7 @@ def forward(
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
20 changes: 10 additions & 10 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -7,6 +7,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
@@ -414,13 +415,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +458,8 @@ def forward(
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +567,8 @@ def forward(
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

if decode_meta := attn_metadata.decode_metadata:
Expand Down Expand Up @@ -613,8 +613,8 @@ def forward(
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +628,8 @@
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)

# Reshape the output tensor.
(Diffs for the remaining 4 changed files are collapsed and not shown here.)
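
On the caller side, the owning layer simply hands itself to the impl; because AttentionLayer is a Protocol, any module exposing _k_scale and _v_scale (with a matching forward) satisfies it structurally. A simplified, hypothetical caller sketch, not the actual vllm/attention/layer.py; the constructor arguments are assumed for illustration:

import torch


class MyAttention(torch.nn.Module):
    """Hypothetical layer owning an AttentionImpl; satisfies AttentionLayer."""

    def __init__(self, impl, k_scale: float = 1.0, v_scale: float = 1.0):
        super().__init__()
        self.impl = impl
        self._k_scale = k_scale  # read by the backend as layer._k_scale
        self._v_scale = v_scale  # read by the backend as layer._v_scale

    def forward(self, query, key, value, kv_cache, attn_metadata,
                output=None):
        # Previously the scales were forwarded explicitly as k_scale/v_scale
        # arguments; now the layer itself is passed and the impl reads them.
        return self.impl.forward(self, query, key, value, kv_cache,
                                 attn_metadata, output=output)
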
