Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/habana_main' into private/kzawora/test_pr
Browse files: browse the repository at this point in the history
  • Loading branch information
kzawora-intel committed Aug 13, 2024
2 parents e52c0ec + d291910 commit dcc878b
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 26 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/clang-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: clang-format

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main

jobs:
clang-format:
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: mypy

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main

jobs:
ruff:
Expand Down Expand Up @@ -50,4 +50,6 @@ jobs:
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/hpu --config-file pyproject.toml
6 changes: 3 additions & 3 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: ruff

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main

jobs:
ruff:
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/yapf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ name: yapf

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
# but only for the habana_main branch
push:
branches:
- main
- habana_main
pull_request:
branches:
- main
- habana_main
jobs:
yapf:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/hpu --config-file pyproject.toml


# If git diff returns a file that is in the skip list, the file may be checked anyway:
Expand Down
16 changes: 13 additions & 3 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
import torch.nn.functional as F

import vllm.hpu.utils as hpu_utils
from vllm.logger import init_logger

# Name the logger after this module, consistent with the
# init_logger(__name__) call in vllm/model_executor/layers/layernorm.py.
logger = init_logger(__name__)

# Optional fused RMSNorm kernel. It stays None when the habana_frameworks
# import fails, so importers can test for None and use the native RMSNorm
# path instead.
HPUFusedRMSNorm = None
try:
    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
    HPUFusedRMSNorm = FusedRMSNorm
except ImportError:
    # Best effort: a missing kernel is reported once at import time and is
    # not treated as a fatal error.
    logger.warning("Could not import HPU FusedRMSNorm kernel. "
                   "vLLM will use forward_native implementation of RMSNorm.")

PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')

Expand Down Expand Up @@ -52,8 +62,7 @@ def paged_attention_v1(query,
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
mask = mask.unsqueeze(2)

attn_weights = [torch.matmul(query, k) for k in keys]
attn_weights = torch.cat(attn_weights, dim=-1)
attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
if alibi_slopes is not None:
attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
-attn_weights.size(3):])
Expand Down Expand Up @@ -128,7 +137,8 @@ def prompt_attention(
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
attn_bias = attn_bias.unsqueeze(2)
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
Expand Down
12 changes: 1 addition & 11 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,8 @@

from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.utils import is_hpu

logger = init_logger(__name__)
if is_hpu():
try:
from habana_frameworks.torch.hpex.normalization import (FusedRMSNorm as
HPUFusedRMSNorm
)
except ImportError:
logger.warning(
"Could not import HPU FusedRMSNorm kernel. "
"vLLM will use forward_native implementation of RMSNorm.")
HPUFusedRMSNorm = None


class RMSNorm(CustomOp):
Expand Down Expand Up @@ -86,6 +75,7 @@ def forward_hpu(
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm.hpu.ops import HPUFusedRMSNorm
if HPUFusedRMSNorm is None:
return self.forward_native(x, residual)
if residual is not None:
Expand Down

0 comments on commit dcc878b

Please sign in to comment.