From 065ed53f9380a269f71da8d5fd230db9f817d26f Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Tue, 14 Jan 2025 10:37:32 +0800
Subject: [PATCH] Do not raise error if _Backend is not found

Signed-off-by: wangxiyuan
---
 tests/kernels/test_attention_selector.py      |  7 ++++---
 .../dummy_attention_backend.py                |  8 ++++++++
 .../vllm_add_dummy_platform/dummy_platform.py |  4 ++++
 tests/plugins_tests/test_platform_plugins.py  | 14 +++++++++++++
 vllm/attention/layer.py                       |  8 ++++----
 vllm/attention/selector.py                    | 20 ++++++++++---------
 6 files changed, 45 insertions(+), 16 deletions(-)
 create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py

diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index a08c874407e3f..192906645c8fe 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -94,7 +94,8 @@ def test_flash_attn(monkeypatch):
 
 
 def test_invalid_env(monkeypatch):
-    """Throw an exception if the backend name is invalid."""
+    """Ignore an invalid env variable and fall back to the default backend."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
-    with pytest.raises(ValueError):
-        get_attn_backend(16, torch.float16, None, 16, False)
+    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() == "FLASH_ATTN"
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
new file mode 100644
index 0000000000000..5634be3c8d882
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
@@ -0,0 +1,8 @@
+from vllm.attention.backends.flash_attn import FlashAttentionBackend
+
+
+class DummyAttentionBackend(FlashAttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "Dummy_Backend"
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
index fde93142f1103..84721d5971ccf 100644
--- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -3,3 +3,7 @@
 
 class DummyPlatform(CudaPlatform):
     device_name = "DummyDevice"
+
+    def get_attn_backend_cls(self, backend_name, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa: E501
diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py
index 69698b34c71a3..661aa5f649ab9 100644
--- a/tests/plugins_tests/test_platform_plugins.py
+++ b/tests/plugins_tests/test_platform_plugins.py
@@ -1,3 +1,10 @@
+import torch
+
+from tests.kernels.utils import override_backend_env_variable
+from vllm.attention.selector import get_attn_backend
+from vllm.utils import STR_INVALID_VAL
+
+
 def test_platform_plugins():
     # simulate workload by running an example
     import runpy
@@ -14,3 +21,10 @@
         f"Expected DummyDevice, got {current_platform.device_name}, "
         "possibly because current_platform is imported before the plugin"
         f" is loaded. The first import:\n{_init_trace}")
+
+
+def test_oot_attention_backend(monkeypatch):
+    # ignore the backend env variable if it is set
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+    assert backend.get_name() == "Dummy_Backend"
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a283e87d84070..9b03fd73fe690 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -190,11 +190,11 @@ def __init__(
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
-        attn_backend = backend_name_to_enum(attn_backend.get_name())
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
+        backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
-        self.attn_backend = attn_backend if attn_backend in {
+        self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA
 
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 0ff007c87b1c9..81ea6eefb5410 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -14,16 +14,18 @@
 logger = init_logger(__name__)
 
 
-def backend_name_to_enum(backend_name: str) -> _Backend:
-    assert backend_name is not None
-
-    backend_members = _Backend.__members__
-    if backend_name not in backend_members:
-        raise ValueError(f"Invalid attention backend '{backend_name}'. "
-                         f"Available backends: {', '.join(backend_members)} "
-                         "(case-sensitive).")
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
 
-    return _Backend[backend_name]
+    Returns:
+    * _Backend: enum value if backend_name is a valid in-tree type
+    * None: if backend_name is not a valid in-tree type, e.g. when an
+      out-of-tree platform plugin provides its own backend.
+    """
+    assert backend_name is not None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None
 
 
 def get_env_variable_attn_backend() -> Optional[_Backend]:
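
A minimal usage sketch (not part of the diff above): with this change, backend_name_to_enum() returns None for any name that is not an in-tree _Backend member instead of raising ValueError, so a backend name supplied by an out-of-tree platform plugin (such as "Dummy_Backend" in the test plugin) no longer aborts backend selection. The snippet below only illustrates the resulting caller-side pattern; the fallback choice mirrors the vllm/attention/layer.py hunk, and the import paths are assumed to match the ones those modules already use.

    from vllm.attention.selector import backend_name_to_enum
    from vllm.platforms import _Backend

    # "Dummy_Backend" is registered by the dummy plugin above and is not an
    # in-tree _Backend member, so the helper now returns None instead of
    # raising ValueError.
    backend = backend_name_to_enum("Dummy_Backend")
    if backend is None:
        # Unknown or out-of-tree name: fall back to a safe default rather
        # than failing, mirroring the TORCH_SDPA fallback used in layer.py.
        backend = _Backend.TORCH_SDPA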