From e6c4086b1a6347afd8f690cd5391e89340727fb1 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 13 Aug 2024 12:29:53 +0300 Subject: [PATCH] adjust format.sh & vllm.hpu.ops --- format.sh | 1 + vllm/hpu/ops.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/format.sh b/format.sh index 5ad6d6f2938bb..fbfc27a68bb3d 100755 --- a/format.sh +++ b/format.sh @@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml +mypy vllm/hpu --config-file pyproject.toml # If git diff returns a file that is in the skip list, the file may be checked anyway: diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 3748eb3544dd1..8ae292b5413aa 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -52,8 +52,7 @@ def paged_attention_v1(query, keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) - attn_weights = [torch.matmul(query, k) for k in keys] - attn_weights = torch.cat(attn_weights, dim=-1) + attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1) if alibi_slopes is not None: attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, -attn_weights.size(3):]) @@ -128,7 +127,8 @@ def prompt_attention( query = query.unflatten(1, (kv_heads, -1)) key = key.unflatten(1, (kv_heads, 1)) value = value.unflatten(1, (kv_heads, 1)) - attn_bias = attn_bias.unsqueeze(2) + if attn_bias is not None: + attn_bias = attn_bias.unsqueeze(2) attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias)