Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral #12303

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
is_prompt: bool
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class HPUPagedAttentionMetadata:
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]


class HPUPagedAttention:
Expand Down
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ def __post_init__(self) -> None:
raise ValueError(f"worker-use-ray can't be used with "
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
ray_only_devices = ["tpu", "hpu"]
ray_only_devices = ["tpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
and self.world_size > 1):
Expand Down
18 changes: 17 additions & 1 deletion vllm/executor/multiproc_worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

import torch

from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context, run_method

Expand Down Expand Up @@ -284,7 +286,21 @@ def set_multiprocessing_worker_envs(parallel_config):
process before worker processes are created"""

_check_multiproc_method()

if (current_platform.is_hpu()
and parallel_config.distributed_executor_backend == 'mp'
and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) is not None:
logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork might "
"cause application hangs on exit. Using "
"VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
"as it was explicitly requested.")
else:
logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork might "
"cause application hangs on exit. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"To override that behavior, please set "
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
Expand Down
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,26 @@ def forward_cpu(
num_expert_group,
)

def forward_hpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
**kwargs,
):
assert not use_grouped_topk, "use_grouped_topk must be False on HPU"
assert num_expert_group is None, ('num_expert_group is '
'not supported on HPU')
assert topk_group is None, "topk_group is not supported on HPU"
if layer is not None:
return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k)

def forward_tpu(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -281,6 +301,9 @@ def __init__(
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
if current_platform.is_hpu():
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
Expand Down Expand Up @@ -495,4 +496,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if current_platform.is_hpu():
torch.hpu.synchronize()
return loaded_params
5 changes: 5 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,11 @@ def reset(self):
self._index = 0


@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
Expand Down
Loading
Loading