adjust format.sh & vllm.hpu.ops
kzawora-intel committed Aug 13, 2024
1 parent 300476a commit e6c4086
Showing 2 changed files with 4 additions and 3 deletions.
format.sh (1 addition, 0 deletions)
@@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/usage --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
+mypy vllm/hpu --config-file pyproject.toml
 
 
 # If git diff returns a file that is in the skip list, the file may be checked anyway:
vllm/hpu/ops.py (3 additions, 3 deletions)
@@ -52,8 +52,7 @@ def paged_attention_v1(query,
         keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
         mask = mask.unsqueeze(2)
 
-    attn_weights = [torch.matmul(query, k) for k in keys]
-    attn_weights = torch.cat(attn_weights, dim=-1)
+    attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
     if alibi_slopes is not None:
         attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
                                        -attn_weights.size(3):])
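
For reference, a minimal standalone sketch (not vllm code; the shapes are hypothetical and only mimic the query/key layout above) showing that the fused torch.cat expression is a pure refactor of the removed two-step version:

import torch

# hypothetical shapes: query of [batch, kv_heads, q_per_kv, 1, head_dim]
# against a list of per-block key tensors, as in the matmul/cat pattern above
query = torch.randn(2, 4, 1, 1, 64)
keys = [torch.randn(2, 4, 1, 64, 128) for _ in range(3)]

# removed two-step version
old = [torch.matmul(query, k) for k in keys]
old = torch.cat(old, dim=-1)

# new single-expression version from the hunk above
new = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)

assert torch.equal(old, new)  # identical result, one fewer intermediate name
print(new.shape)              # torch.Size([2, 4, 1, 1, 384])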
@@ -128,7 +127,8 @@ def prompt_attention(
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
         value = value.unflatten(1, (kv_heads, 1))
-        attn_bias = attn_bias.unsqueeze(2)
+        if attn_bias is not None:
+            attn_bias = attn_bias.unsqueeze(2)
     attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
     if attn_bias is not None:
         attn_weights.add_(attn_bias)
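
The added guard is the behavioral part of this commit: a None attn_bias reaching the branch above previously hit the unconditional .unsqueeze(2) and raised AttributeError, even though the later add_ already tolerates a missing bias. A hedged sketch of the pattern (hypothetical helper, not the actual vllm/hpu/ops.py function):

from typing import Optional

import torch

def adjust_bias_for_gqa(attn_bias: Optional[torch.Tensor],
                        heads_differ: bool) -> Optional[torch.Tensor]:
    # Mirrors the fixed branch: only unsqueeze when a bias was actually passed.
    if heads_differ:
        if attn_bias is not None:
            attn_bias = attn_bias.unsqueeze(2)
    return attn_bias

print(adjust_bias_for_gqa(None, heads_differ=True))              # None, no crash
print(adjust_bias_for_gqa(torch.zeros(2, 4, 8, 8), True).shape)  # torch.Size([2, 4, 1, 8, 8])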
