From afffe330716672a36af56d1853e65d9719a62449 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Tue, 13 Aug 2024 14:49:02 +0300
Subject: [PATCH] Add fake HPU mode

---
 vllm/hpu/cache_ops.py              |  5 ++++-
 vllm/hpu/ops.py                    |  8 ++++++--
 vllm/hpu/utils.py                  |  5 ++++-
 vllm/utils.py                      | 25 +++++++++++++++++++++++
 vllm/worker/habana_model_runner.py | 32 ++++++++++++++++++++----------
 vllm/worker/habana_worker.py       | 23 +++++++++++++++------
 6 files changed, 78 insertions(+), 20 deletions(-)

diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py
index 14824945aa53a..a69105e18c3bd 100644
--- a/vllm/hpu/cache_ops.py
+++ b/vllm/hpu/cache_ops.py
@@ -5,7 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
 import torch
 
 
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 7a40e6e720259..f2ea8202e0487 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -7,14 +7,18 @@
 import os
 from typing import Optional
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
+
 import torch
 import torch.nn.functional as F
 
 import vllm.hpu.utils as hpu_utils
 from vllm.logger import init_logger
 
-logger = init_logger()
+logger = init_logger(__name__)
 HPUFusedRMSNorm = None
 try:
     from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py
index b7b435c50c295..2092eb3b99ad8 100644
--- a/vllm/hpu/utils.py
+++ b/vllm/hpu/utils.py
@@ -7,7 +7,10 @@
 
 from functools import wraps
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
 
 
 def with_mark_steps(fn):
diff --git a/vllm/utils.py b/vllm/utils.py
index 8a1bc5de03eb7..ce6c0f621c263 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -207,10 +207,29 @@ def is_neuron() -> bool:
 
 @lru_cache(maxsize=None)
 def is_hpu() -> bool:
+    return _is_habana_frameworks_installed() or _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return not _is_habana_frameworks_installed() and _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def _is_habana_frameworks_installed() -> bool:
     from importlib import util
     return util.find_spec('habana_frameworks') is not None
 
 
+@lru_cache(maxsize=None)
+def _is_built_for_hpu() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "gaudi" in version("vllm")
+    except PackageNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def is_tpu() -> bool:
     try:
@@ -623,18 +642,24 @@ def __init__(self, device=None):
 
     @staticmethod
     def current_device_memory_usage() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory
 
     @staticmethod
     def current_free_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory
 
     @staticmethod
     def total_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory
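The new helpers in vllm/utils.py split HPU detection into two independent checks: whether the habana_frameworks package is importable, and whether the installed vllm wheel was built for Gaudi (its version string contains "gaudi"). is_hpu() is the OR of the two, while is_fake_hpu() is true only for a Gaudi build without the Habana stack, which is what drives the CPU fallbacks in the rest of the patch. Below is a minimal standalone sketch of that detection logic; the helper names are illustrative, not the patch's private functions.

```python
# Standalone sketch of the fake-HPU detection pattern (illustrative names).
from functools import lru_cache
from importlib import util
from importlib.metadata import PackageNotFoundError, version


@lru_cache(maxsize=None)
def habana_frameworks_installed() -> bool:
    # True when the Habana (SynapseAI) Python bindings are importable.
    return util.find_spec('habana_frameworks') is not None


@lru_cache(maxsize=None)
def built_for_hpu() -> bool:
    # True when the installed vllm wheel was built from the Gaudi fork.
    try:
        return "gaudi" in version("vllm")
    except PackageNotFoundError:
        return False


def is_fake_hpu() -> bool:
    # Gaudi build without the Habana stack -> run in fake-HPU (CPU) mode.
    return not habana_frameworks_installed() and built_for_hpu()


if __name__ == "__main__":
    print("fake HPU mode:", is_fake_hpu())
```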
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index d6a68ebc39eca..6d06ffbc00ba4 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -9,12 +9,17 @@
 import math
 import operator
 import os
 import time
 from enum import IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
-import habana_frameworks.torch as htorch
+from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu,
+                        is_pin_memory_available, make_tensor_with_pad)
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
+
 import torch
 
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -31,8 +36,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (HabanaMemoryProfiler, format_bytes,
-                        is_pin_memory_available, make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -151,7 +154,8 @@ class HpuModelAdapter():
 
     def __init__(self, model, enforce_eager):
         self.model = model
-        if not htorch.utils.internal.is_lazy() and not enforce_eager:
+        if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
+        ) and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
                                        dynamic=False)
@@ -380,7 +384,9 @@ def __init__(
                                if model_config is not None else None)
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
-
+        if is_fake_hpu():
+            device_config.device = torch.device('cpu')
+            device_config.device_type = 'cpu'
         self.device = self.device_config.device
         self.enforce_eager = self.model_config.enforce_eager
         self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -1048,7 +1054,8 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt,
             self.create_dummy_seq_group_metadata(i, seq_len, is_prompt)
             for i in range(batch_size)
         ]
-        torch.hpu.synchronize()
+        if not is_fake_hpu():
+            torch.hpu.synchronize()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
             self.execute_model(inputs, kv_caches)
@@ -1220,6 +1227,8 @@ def mem_margin(self, value):
 
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
+    if is_fake_hpu():
+        return HpuModelAdapter(*args, **kwargs)
     return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
         *args, **
         kwargs)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
@@ -1403,7 +1412,8 @@ def execute_model(
         if multi_modal_input is not None:
             execute_model_kwargs.update(multi_modal_input)
 
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         if self.is_driver_worker:
             model_event_name = ("model_"
                                 f"{'prompt' if is_prompt else 'decode'}_"
@@ -1428,7 +1438,8 @@ def execute_model(
             sampling_metadata.selected_token_indices = None
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
             return []
@@ -1444,7 +1455,8 @@ def execute_model(
             sampling_metadata=sampling_metadata,
         )
         output.outputs = output.outputs[:real_batch_size]
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
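Throughout habana_model_runner.py the patch applies the same pattern: every htorch call (mark_step, synchronize, HPU-graph wrapping) is guarded by is_fake_hpu(), and the device config is redirected to CPU. The sketch below shows that guard pattern in isolation; it uses a simple import probe and hypothetical helper names (fake_hpu, maybe_mark_step, maybe_synchronize) rather than the patch's is_fake_hpu() utility.

```python
# Illustrative guard pattern for htorch calls in fake-HPU (CPU-only) mode.
try:
    import habana_frameworks.torch as htorch  # present on real HPU setups
except ImportError:
    htorch = None  # Habana stack absent: behave as fake HPU

fake_hpu = htorch is None


def maybe_mark_step() -> None:
    # htorch.core.mark_step() flushes accumulated ops to the HPU graph;
    # in fake-HPU mode there is nothing to flush, so this is a no-op.
    if not fake_hpu:
        htorch.core.mark_step()


def maybe_synchronize() -> None:
    # torch.hpu only exists once habana_frameworks has been imported,
    # so the synchronize call must stay behind the same guard.
    if not fake_hpu:
        import torch
        torch.hpu.synchronize()
```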
diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py
index f3fdc4dcc63c6..d3df7c026a8d0 100644
--- a/vllm/worker/habana_worker.py
+++ b/vllm/worker/habana_worker.py
@@ -6,7 +6,11 @@
 import os
 from typing import List, Optional, Set, Tuple
 
-import habana_frameworks.torch as htorch  # noqa:F401
+from vllm.utils import HabanaMemoryProfiler, format_bytes, is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch  # noqa:F401
+
 import torch
 import torch.distributed
 
@@ -21,7 +25,6 @@
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import HabanaMemoryProfiler, format_bytes
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.habana_model_runner import HabanaModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
@@ -95,6 +98,8 @@ def init_device(self) -> None:
         if self.device_config.device.type == "hpu":
             self.device = torch.device("hpu")
             torch.hpu.set_device(self.device)
+        elif self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -126,6 +131,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        if is_fake_hpu():
+            return 128, 0
         with HabanaMemoryProfiler() as m:
             self.model_runner.profile_run()
             torch.hpu.synchronize()
@@ -184,7 +191,8 @@ def initialize_cache(self, num_gpu_blocks: int,
 
         with HabanaMemoryProfiler() as m:
             self._init_cache_engine()
-            torch.hpu.synchronize()
+            if not is_fake_hpu():
+                torch.hpu.synchronize()
         msg = ("Initializing cache engine "
                f"took {m.get_summary_string()}")
         logger.info(msg)
@@ -311,11 +319,12 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    backend = 'hccl' if not is_fake_hpu() else 'gloo'
     init_distributed_environment(parallel_config.world_size,
                                  rank,
                                  distributed_init_method,
                                  local_rank,
-                                 backend='hccl')
+                                 backend=backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
@@ -332,15 +341,17 @@ def init_worker_distributed_environment(
             "distributed_init_method must be set if torch.distributed "
             "is not already initialized")
     else:
+        backend = 'hccl' if not is_fake_hpu() else 'gloo'
         torch.distributed.init_process_group(
-            backend="hccl",
+            backend=backend,
             world_size=parallel_config.world_size,
             rank=rank,
             init_method=distributed_init_method,
         )
 
     # A small all_reduce for warmup & checking conformance.
-    dummy_tensor_hpu = torch.ones(1).to('hpu')
+    device = 'hpu' if not is_fake_hpu() else 'cpu'
+    dummy_tensor_hpu = torch.ones(1).to(device)
     torch.distributed.all_reduce(dummy_tensor_hpu)
     assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
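In fake-HPU mode the worker initializes torch.distributed with the gloo backend instead of hccl and runs the warmup all_reduce on a CPU tensor. The sketch below exercises that fallback path on its own; the single-process world size and the explicit MASTER_ADDR/MASTER_PORT values are assumptions for the example, not part of the patch.

```python
# Minimal sketch of the gloo fallback plus warmup all_reduce on CPU.
import os

import torch
import torch.distributed as dist

# env:// rendezvous needs these; values here are for a local, single-process run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

backend = "gloo"  # fake HPU: hccl is unavailable, fall back to gloo
dist.init_process_group(backend=backend, world_size=1, rank=0)

# Warmup & conformance check, mirroring the patch but on the CPU device.
dummy = torch.ones(1).to("cpu")
dist.all_reduce(dummy)
assert dummy.item() == dist.get_world_size()

dist.destroy_process_group()
```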