Enable GitHub Actions static checks for habana_main #177

Merged: 4 commits, merged Aug 13, 2024

6 changes: 3 additions & 3 deletions .github/workflows/clang-format.yml
@@ -2,13 +2,13 @@ name: clang-format

 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main

 jobs:
   clang-format:

8 changes: 5 additions & 3 deletions .github/workflows/mypy.yaml
@@ -2,13 +2,13 @@ name: mypy

 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main

 jobs:
   ruff:
@@ -50,4 +50,6 @@ jobs:
 mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/usage --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
+mypy vllm/hpu --config-file pyproject.toml
+

6 changes: 3 additions & 3 deletions .github/workflows/ruff.yml
@@ -2,13 +2,13 @@ name: ruff

 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main

 jobs:
   ruff:

6 changes: 3 additions & 3 deletions .github/workflows/yapf.yml
@@ -2,13 +2,13 @@ name: yapf

 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 jobs:
   yapf:
     runs-on: ubuntu-latest

1 change: 1 addition & 0 deletions format.sh
@@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/usage --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
+mypy vllm/hpu --config-file pyproject.toml


 # If git diff returns a file that is in the skip list, the file may be checked anyway:

16 changes: 13 additions & 3 deletions vllm/hpu/ops.py
@@ -12,6 +12,16 @@
 import torch.nn.functional as F

 import vllm.hpu.utils as hpu_utils
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+HPUFusedRMSNorm = None
+try:
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    HPUFusedRMSNorm = FusedRMSNorm
+except ImportError:
+    logger.warning("Could not import HPU FusedRMSNorm kernel. "
+                   "vLLM will use forward_native implementation of RMSNorm.")

 PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')

@@ -52,8 +62,7 @@ def paged_attention_v1(query,
         keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
         mask = mask.unsqueeze(2)

-    attn_weights = [torch.matmul(query, k) for k in keys]
-    attn_weights = torch.cat(attn_weights, dim=-1)
+    attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
     if alibi_slopes is not None:
         attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
                                        -attn_weights.size(3):])
@@ -128,7 +137,8 @@ def prompt_attention(
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
         value = value.unflatten(1, (kv_heads, 1))
-        attn_bias = attn_bias.unsqueeze(2)
+        if attn_bias is not None:
+            attn_bias = attn_bias.unsqueeze(2)
     attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
     if attn_bias is not None:
         attn_weights.add_(attn_bias)

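The prompt_attention hunk above makes the bias reshape conditional, so callers that pass no attention bias no longer fail on unsqueezing None. Below is a minimal, self-contained sketch of the same guard pattern, runnable with plain PyTorch; the helper name and tensor shapes are invented for illustration and are not part of this PR.

from typing import Optional

import torch


def add_optional_bias(attn_weights: torch.Tensor,
                      attn_bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Apply an attention bias only when one was provided.
    if attn_bias is not None:
        # Insert the grouped-head axis so the bias broadcasts against
        # weights shaped (batch, kv_heads, q_per_kv, q_len, kv_len).
        attn_bias = attn_bias.unsqueeze(2)
        attn_weights = attn_weights + attn_bias
    return attn_weights


weights = torch.randn(2, 4, 1, 16, 16)
bias = torch.randn(2, 4, 16, 16)
print(add_optional_bias(weights, bias).shape)  # torch.Size([2, 4, 1, 16, 16])
print(add_optional_bias(weights, None).shape)  # unchanged, no error
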
12 changes: 1 addition & 11 deletions vllm/model_executor/layers/layernorm.py
@@ -6,19 +6,8 @@

 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_hpu

 logger = init_logger(__name__)
-if is_hpu():
-    try:
-        from habana_frameworks.torch.hpex.normalization import (FusedRMSNorm as
-                                                                 HPUFusedRMSNorm
-                                                                 )
-    except ImportError:
-        logger.warning(
-            "Could not import HPU FusedRMSNorm kernel. "
-            "vLLM will use forward_native implementation of RMSNorm.")
-        HPUFusedRMSNorm = None


 class RMSNorm(CustomOp):
@@ -86,6 +75,7 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm.hpu.ops import HPUFusedRMSNorm
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:

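Taken together, the ops.py and layernorm.py changes move the optional HPU kernel import behind a module-level handle and defer its lookup to call time, so layernorm.py no longer needs any HPU-specific imports at module scope. A condensed, standalone sketch of that dispatch pattern follows, assuming FusedRMSNorm exposes an apply(input, weight, eps) entry point; the rms_norm helper and the simplified pure-PyTorch fallback are illustrative, not code from this PR.

import torch

# Module-level handle, as in vllm/hpu/ops.py: stays None when the Habana
# stack is not installed, so importing this module never fails.
HPUFusedRMSNorm = None
try:
    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
    HPUFusedRMSNorm = FusedRMSNorm
except ImportError:
    pass


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    # Prefer the fused HPU kernel when it imported successfully (assumed
    # entry point: FusedRMSNorm.apply(input, weight, eps)).
    if HPUFusedRMSNorm is not None:
        return HPUFusedRMSNorm.apply(x, weight, eps)
    # Simplified pure-PyTorch fallback, analogous to RMSNorm.forward_native.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


x = torch.randn(2, 16, 64)
weight = torch.ones(64)
print(rms_norm(x, weight).shape)  # torch.Size([2, 16, 64])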