Add fake HPU mode
kzawora-intel committed Aug 13, 2024
1 parent dcc878b commit afffe33
Showing 6 changed files with 78 additions and 20 deletions.
5 changes: 4 additions & 1 deletion vllm/hpu/cache_ops.py
@@ -5,7 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
 import torch
 
 
8 changes: 6 additions & 2 deletions vllm/hpu/ops.py
@@ -7,14 +7,18 @@
 import os
 from typing import Optional
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
+
 import torch
 import torch.nn.functional as F
 
 import vllm.hpu.utils as hpu_utils
 from vllm.logger import init_logger
 
-logger = init_logger()
+logger = init_logger(__name__)
 HPUFusedRMSNorm = None
 try:
     from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
5 changes: 4 additions & 1 deletion vllm/hpu/utils.py
@@ -7,7 +7,10 @@
 
 from functools import wraps
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
 
 
 def with_mark_steps(fn):
25 changes: 25 additions & 0 deletions vllm/utils.py
@@ -207,10 +207,29 @@ def is_neuron() -> bool:
 
 @lru_cache(maxsize=None)
 def is_hpu() -> bool:
+    return _is_habana_frameworks_installed() or _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return not _is_habana_frameworks_installed() and _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def _is_habana_frameworks_installed() -> bool:
     from importlib import util
     return util.find_spec('habana_frameworks') is not None
 
 
 @lru_cache(maxsize=None)
+def _is_built_for_hpu() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "gaudi" in version("vllm")
+    except PackageNotFoundError:
+        return False
+
+
+@lru_cache(maxsize=None)
 def is_tpu() -> bool:
     try:
@@ -623,18 +642,24 @@ def __init__(self, device=None):
 
     @staticmethod
    def current_device_memory_usage() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory
 
     @staticmethod
     def current_free_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory
 
     @staticmethod
     def total_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory
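For context, a minimal sketch of how the new detection helpers compose (not part of this commit; it assumes a vLLM build whose package version contains "gaudi" installed on a machine without habana_frameworks, which is exactly the case the fake HPU mode targets):

import torch

from vllm.utils import is_fake_hpu, is_hpu

assert is_hpu()       # the Gaudi build is detected via the package version...
assert is_fake_hpu()  # ...but habana_frameworks is absent, so HPU is faked

# Mirror the fallbacks introduced in this commit: run on CPU, skip htorch calls.
device = torch.device('cpu' if is_fake_hpu() else 'hpu')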
32 changes: 22 additions & 10 deletions vllm/worker/habana_model_runner.py
@@ -9,12 +9,17 @@
 import math
 import operator
 import os
-import time
+import time
 from enum import IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
-import habana_frameworks.torch as htorch
+from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu,
+                        is_pin_memory_available, make_tensor_with_pad)
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
+
 import torch
 
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -31,8 +36,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (HabanaMemoryProfiler, format_bytes,
-                        is_pin_memory_available, make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -151,7 +154,8 @@ class HpuModelAdapter():
 
     def __init__(self, model, enforce_eager):
         self.model = model
-        if not htorch.utils.internal.is_lazy() and not enforce_eager:
+        if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
+        ) and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
                                        dynamic=False)
@@ -380,7 +384,9 @@ def __init__(
                              if model_config is not None else None)
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
-
+        if is_fake_hpu():
+            device_config.device = torch.device('cpu')
+            device_config.device_type = 'cpu'
         self.device = self.device_config.device
         self.enforce_eager = self.model_config.enforce_eager
         self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -1048,7 +1054,8 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt,
             self.create_dummy_seq_group_metadata(i, seq_len, is_prompt)
             for i in range(batch_size)
         ]
-        torch.hpu.synchronize()
+        if not is_fake_hpu():
+            torch.hpu.synchronize()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
             self.execute_model(inputs, kv_caches)
@@ -1220,6 +1227,8 @@ def mem_margin(self, value):
 
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
+    if is_fake_hpu():
+        return HpuModelAdapter(*args, **kwargs)
     return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
         *args, **
         kwargs)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
@@ -1403,7 +1412,8 @@ def execute_model(
         if multi_modal_input is not None:
             execute_model_kwargs.update(multi_modal_input)
 
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         if self.is_driver_worker:
             model_event_name = ("model_"
                                 f"{'prompt' if is_prompt else 'decode'}_"
@@ -1428,7 +1438,8 @@
             sampling_metadata.selected_token_indices = None
         logits = self.model.compute_logits(hidden_states,
                                            sampling_metadata)
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
             return []
@@ -1444,7 +1455,8 @@
             sampling_metadata=sampling_metadata,
         )
         output.outputs = output.outputs[:real_batch_size]
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
23 changes: 17 additions & 6 deletions vllm/worker/habana_worker.py
@@ -6,7 +6,11 @@
 import os
 from typing import List, Optional, Set, Tuple
 
-import habana_frameworks.torch as htorch  # noqa:F401
+from vllm.utils import HabanaMemoryProfiler, format_bytes, is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch  # noqa:F401
+
 import torch
 import torch.distributed
 
@@ -21,7 +25,6 @@
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import HabanaMemoryProfiler, format_bytes
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.habana_model_runner import HabanaModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
@@ -95,6 +98,8 @@ def init_device(self) -> None:
         if self.device_config.device.type == "hpu":
             self.device = torch.device("hpu")
             torch.hpu.set_device(self.device)
+        elif self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -126,6 +131,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        if is_fake_hpu():
+            return 128, 0
         with HabanaMemoryProfiler() as m:
             self.model_runner.profile_run()
             torch.hpu.synchronize()
@@ -184,7 +191,8 @@ def initialize_cache(self, num_gpu_blocks: int,
 
         with HabanaMemoryProfiler() as m:
             self._init_cache_engine()
-            torch.hpu.synchronize()
+            if not is_fake_hpu():
+                torch.hpu.synchronize()
         msg = ("Initializing cache engine "
                f"took {m.get_summary_string()}")
         logger.info(msg)
@@ -311,11 +319,12 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    backend = 'hccl' if not is_fake_hpu() else 'gloo'
     init_distributed_environment(parallel_config.world_size,
                                  rank,
                                  distributed_init_method,
                                  local_rank,
-                                 backend='hccl')
+                                 backend=backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
@@ -332,15 +341,17 @@ def init_worker_distributed_environment(
             "distributed_init_method must be set if torch.distributed "
             "is not already initialized")
     else:
+        backend = 'hccl' if not is_fake_hpu() else 'gloo'
         torch.distributed.init_process_group(
-            backend="hccl",
+            backend=backend,
             world_size=parallel_config.world_size,
             rank=rank,
             init_method=distributed_init_method,
         )
 
     # A small all_reduce for warmup & checking conformance.
-    dummy_tensor_hpu = torch.ones(1).to('hpu')
+    device = 'hpu' if not is_fake_hpu() else 'cpu'
+    dummy_tensor_hpu = torch.ones(1).to(device)
     torch.distributed.all_reduce(dummy_tensor_hpu)
     assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
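For reference, a hypothetical single-process check of the warm-up all_reduce path on the fallback backend (not part of this commit; it assumes a CPU-only machine where is_fake_hpu() returns True and a free local TCP port):

import torch
import torch.distributed as dist

from vllm.utils import is_fake_hpu

# Same backend/device selection as the worker code above.
backend = 'hccl' if not is_fake_hpu() else 'gloo'
dist.init_process_group(backend=backend,
                        init_method='tcp://127.0.0.1:29500',
                        world_size=1,
                        rank=0)

device = 'hpu' if not is_fake_hpu() else 'cpu'
dummy = torch.ones(1).to(device)
dist.all_reduce(dummy)
assert dummy.item() == dist.get_world_size()
dist.destroy_process_group()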
