Skip to content

Commit

Permalink
Do not raise error if _Backend is not found
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Jan 14, 2025
1 parent 8a1f938 commit 850bce0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
8 changes: 4 additions & 4 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def __init__(
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
attn_backend = backend_name_to_enum(attn_backend.get_name())
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
attn_backend = _Backend.XFORMERS
backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
backend = _Backend.XFORMERS

self.attn_backend = attn_backend if attn_backend in {
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS
} else _Backend.TORCH_SDPA

Expand Down
9 changes: 5 additions & 4 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
logger = init_logger(__name__)


def backend_name_to_enum(backend_name: str) -> _Backend:
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
assert backend_name is not None

backend_members = _Backend.__members__
if backend_name not in backend_members:
raise ValueError(f"Invalid attention backend '{backend_name}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
logger.warning(
"Invalid attention backend %s. Available backends: {%s} "
"(case-sensitive).", backend_name, ', '.join(backend_members))
return None

return _Backend[backend_name]

Expand Down

0 comments on commit 850bce0

Please sign in to comment.