Skip to content

Commit

Permalink
[Platform] Do not raise error if _Backend is not found (vllm-project#12023)
Browse files Browse the repository at this point in the history

Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
  • Loading branch information
2 people authored and jikunshang committed Jan 21, 2025
1 parent 8742287 commit 9436992
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 16 deletions.
11 changes: 8 additions & 3 deletions tests/kernels/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):


def test_invalid_env(monkeypatch):
    """Ignore the invalid env variable if it is set.

    The backend override env variable is set to an invalid value; backend
    selection is expected to proceed as if it were unset instead of raising.
    """
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
        backend = get_attn_backend(32, torch.float16, None, 16, False)
        assert backend.get_name() == "FLASH_ATTN"

        # The two calls differ only in the first argument (32 vs 16; presumably
        # the head size -- confirm against get_attn_backend's signature). With
        # 16 the backend is expected to fall back to XFORMERS.
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() == "XFORMERS"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from vllm.attention.backends.flash_attn import FlashAttentionBackend


class DummyAttentionBackend(FlashAttentionBackend):
    """FlashAttentionBackend subclass that only overrides the reported name."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "Dummy_Backend"
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@

class DummyPlatform(CudaPlatform):
    """CudaPlatform subclass with a dummy device name and a fixed backend."""

    device_name = "DummyDevice"

    def get_attn_backend_cls(self, backend_name, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1):
        """Return the dotted path of the dummy attention backend class.

        None of the arguments are inspected; the same path is returned for
        every call.
        """
        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
14 changes: 14 additions & 0 deletions tests/plugins_tests/test_platform_plugins.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import get_attn_backend
from vllm.utils import STR_INVALID_VAL


def test_platform_plugins():
# simulate workload by running an example
import runpy
Expand All @@ -14,3 +21,10 @@ def test_platform_plugins():
f"Expected DummyDevice, got {current_platform.device_name}, "
"possibly because current_platform is imported before the plugin"
f" is loaded. The first import:\n{_init_trace}")


def test_oot_attention_backend(monkeypatch):
    """Check that the selected attention backend reports "Dummy_Backend"."""
    # ignore the backend env variable if it is set
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
    # NOTE(review): the third argument here is torch.float16, while the other
    # get_attn_backend calls in test_attention_selector.py pass None -- confirm
    # this is intentional.
    selected = get_attn_backend(16, torch.float16, torch.float16, 16, False)
    assert selected.get_name() == "Dummy_Backend"
8 changes: 4 additions & 4 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def __init__(
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
attn_backend = backend_name_to_enum(attn_backend.get_name())
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
attn_backend = _Backend.XFORMERS
backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
backend = _Backend.XFORMERS

self.attn_backend = attn_backend if attn_backend in {
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS
} else _Backend.TORCH_SDPA

Expand Down
20 changes: 11 additions & 9 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
logger = init_logger(__name__)


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """Convert a string backend name to a _Backend enum value.

    Args:
        backend_name: Case-sensitive name of a _Backend member. Must not be
            None.

    Returns:
        * _Backend: enum value if backend_name is a valid in-tree type
        * None: otherwise it's an invalid in-tree type or an out-of-tree
          platform is loaded.
    """
    assert backend_name is not None
    # __members__ maps member names to members, so .get() performs the
    # "look up or return None" step in a single call.
    return _Backend.__members__.get(backend_name)


def get_env_variable_attn_backend() -> Optional[_Backend]:
Expand Down

0 comments on commit 9436992

Please sign in to comment.