Add output for Attention Backend
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan committed Jan 13, 2025
1 parent d14e98d commit 443e427
Showing 15 changed files with 61 additions and 5 deletions.
4 changes: 4 additions & 0 deletions tests/worker/test_model_input.py
@@ -60,6 +60,10 @@ def copy_blocks(
     ) -> None:
         pass
 
+    @staticmethod
+    def use_output():
+        pass
+
 
 def test_model_runner_input():
     sampling_metadata = SamplingMetadata(
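Note: because `use_output` is declared with `@abstractmethod` on the base class (see `vllm/attention/backends/abstract.py` below), a concrete subclass has to override it before it can be instantiated, which is presumably why the test helper above gains a stub. A minimal, self-contained illustration of that Python behavior; the class names here are hypothetical and not taken from the test file:

from abc import ABC, abstractmethod


class FakeBackend(ABC):
    """Hypothetical stand-in for an abstract backend interface."""

    @staticmethod
    @abstractmethod
    def use_output() -> bool:
        raise NotImplementedError


class IncompleteBackend(FakeBackend):
    """Does not override use_output, so it remains abstract."""


class StubBackend(FakeBackend):
    """Overrides use_output, so it can be instantiated."""

    @staticmethod
    def use_output() -> bool:
        return False


try:
    IncompleteBackend()
except TypeError as err:
    print(err)  # Can't instantiate abstract class IncompleteBackend ...

print(StubBackend().use_output())  # False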
8 changes: 8 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -93,6 +93,14 @@ def copy_blocks(
     ) -> None:
         raise NotImplementedError
 
+    @staticmethod
+    @abstractmethod
+    def use_output() -> bool:
+        # For some attention backends, we allocate an output tensor before
+        # calling the custom op. When piecewise cudagraph is enabled, this
+        # makes sure the output tensor is allocated inside the cudagraph.
+        raise NotImplementedError
+
     def advance_step(self, model_input: "ModelRunnerInputBase",
                      sampled_token_ids: Optional[torch.Tensor],
                      block_size: int, num_seqs: int, num_queries: int) -> None:
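Since the hook is abstract, every in-tree backend now supplies an implementation (see the files below), and out-of-tree backends have to do the same. A minimal sketch of what that could look like for a hypothetical custom backend; the class name and the comments are illustrative and not part of this commit:

from vllm.attention.backends.abstract import AttentionBackend


class MyCustomAttentionBackend(AttentionBackend):
    """Hypothetical out-of-tree backend; only the new hook is shown."""

    @staticmethod
    def use_output() -> bool:
        # Return True only if this backend expects the caller to
        # pre-allocate the attention output tensor, so that under
        # piecewise cudagraph the allocation happens inside the graph.
        return False

    # The other abstract methods shown in abstract.py (e.g. copy_blocks)
    # would still need real implementations before this class is usable.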
4 changes: 4 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -133,6 +133,10 @@ def copy_blocks(
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return True
+
 
 @dataclass
 class BlocksparseFlashAttentionMetadata(AttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -87,6 +87,10 @@ def copy_blocks(
 
         ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return True
+
 
 @dataclass
 class FlashAttentionMetadata(AttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/flashinfer.py
@@ -97,6 +97,10 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
         else:
             raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 class FlashInferState(AttentionState):
 
4 changes: 4 additions & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -63,6 +63,10 @@ def copy_blocks(
     ) -> None:
         HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/ipex_attn.py
@@ -62,6 +62,10 @@ def copy_blocks(
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
         ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/openvino.py
@@ -87,6 +87,10 @@ def copy_blocks(
                 copy_cache_block(key_cache, key_cache, src, dst)
                 copy_cache_block(value_cache, value_cache, src, dst)
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class OpenVINOAttentionMetadata:
4 changes: 4 additions & 0 deletions vllm/attention/backends/pallas.py
@@ -57,6 +57,10 @@ def copy_blocks(
             torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
             v_cache[:, dst_indices] = v_cache[:, src_indices]
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class PallasMetadata(AttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/placeholder_attn.py
@@ -65,6 +65,10 @@ def copy_blocks(
     ) -> None:
         return
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class PlaceholderAttentionMetadata(AttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -74,6 +74,10 @@ def copy_blocks(
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -67,6 +67,10 @@ def copy_blocks(
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
4 changes: 4 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -69,6 +69,10 @@ def copy_blocks(
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
+    @staticmethod
+    def use_output():
+        return False
+
 
 @dataclass
 class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
6 changes: 1 addition & 5 deletions vllm/attention/layer.py
@@ -110,11 +110,7 @@ def __init__(
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        # For some attention backends, we allocate an output tensor before
-        # calling the custom op. When piecewise cudagraph is enabled, this
-        # makes sure the output tensor is allocated inside the cudagraph.
-        self.use_output = self.backend == _Backend.FLASH_ATTN or \
-            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+        self.use_output = attn_backend.use_output()
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
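The hunk above only changes how the flag is computed: instead of a hardcoded `_Backend` check, the layer now asks the backend class. How the flag is consumed downstream is not part of this commit. As a rough sketch of the idea described by the comment in abstract.py (hypothetical names and call signatures, not the actual forward code in vllm/attention/layer.py):

import torch


def run_attention(query: torch.Tensor, attn_impl, use_output: bool) -> torch.Tensor:
    """Sketch: pre-allocate the output only when the backend asks for it.

    `attn_impl` stands in for a backend-specific implementation object and
    its forward signature is assumed, not taken from vLLM.
    """
    if use_output:
        # Allocating here means the buffer is created inside the region
        # captured by the piecewise cudagraph.
        output = torch.empty_like(query)
        attn_impl.forward(query, output=output)  # hypothetical signature
        return output
    # Backends that return False allocate or return their own output.
    return attn_impl.forward(query)  # hypothetical signature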
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -46,6 +46,10 @@ def get_kv_cache_shape(
     def use_cascade_attention(*args, **kwargs) -> bool:
         return use_cascade_attention(*args, **kwargs)
 
+    @staticmethod
+    def use_output():
+        return True
+
 
 @dataclass
 class FlashAttentionMetadata:
