From 215e8e5519d9f365fe64ae3136006e5fa31c0361 Mon Sep 17 00:00:00 2001
From: Nir David
Date: Tue, 21 Jan 2025 19:33:35 +0200
Subject: [PATCH 1/3] [SW-216413] - Fix new executors shutdown and shutdown_inc flow

---
 vllm/engine/arg_utils.py                  |  2 +-
 vllm/executor/mp_distributed_executor.py  |  4 ++
 vllm/executor/multiproc_hpu_executor.py   | 57 -----------------------
 vllm/executor/ray_distributed_executor.py | 10 ++++
 vllm/executor/uniproc_executor.py         |  9 ++++
 vllm/worker/hpu_worker.py                 |  2 +-
 vllm/worker/worker_base.py                |  4 ++
 7 files changed, 29 insertions(+), 59 deletions(-)
 delete mode 100644 vllm/executor/multiproc_hpu_executor.py

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 09f89242cbdec..7aba440397790 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -407,7 +407,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'or equal to the number of GPUs available, "mp" will be used to '
             'keep processing on a single host. Otherwise, this will default '
             'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'and hpu only support Ray for distributed inference.')
+            'only support Ray for distributed inference.')
 
         parser.add_argument(
             '--worker-use-ray',
diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py
index 8ae88e646aad6..8da97df13190c 100644
--- a/vllm/executor/mp_distributed_executor.py
+++ b/vllm/executor/mp_distributed_executor.py
@@ -91,8 +91,12 @@ def _init_executor(self) -> None:
                           max_parallel_loading_workers)
         self.driver_exec_model = make_async(self.driver_worker.execute_model)
         self.pp_locks: Optional[List[asyncio.Lock]] = None
+        self.shutdown_workers = True
 
     def shutdown(self):
+        if getattr(self, 'shutdown_workers', False):
+            self._run_workers("shutdown")
+            self.shutdown_workers = False
         if (worker_monitor := getattr(self, "worker_monitor",
                                       None)) is not None:
             worker_monitor.close()
diff --git a/vllm/executor/multiproc_hpu_executor.py b/vllm/executor/multiproc_hpu_executor.py
deleted file mode 100644
index a82fff956738f..0000000000000
--- a/vllm/executor/multiproc_hpu_executor.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Callable, Optional, Tuple, Type
-
-import habana_frameworks.torch  # noqa: F401
-import torch
-
-from vllm.executor.multiproc_gpu_executor import (
-    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
-from vllm.logger import init_logger
-from vllm.utils import make_async
-from vllm.worker.worker_base import WorkerBase
-
-logger = init_logger(__name__)
-
-
-class MultiprocessingHPUExecutor(MultiprocessingGPUExecutor):
-    """Python multiprocessing-based multi-HPU executor"""
-
-    def _get_worker_module_and_class(
-            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
-        worker_class_fn = None
-        if self.scheduler_config.is_multi_step:
-            module_name = "vllm.worker.multi_step_hpu_worker"
-            class_name = "MultiStepHPUWorker"
-        elif self.speculative_config is not None:
-            module_name = "vllm.spec_decode.spec_decode_worker"
-            class_name = "create_spec_worker"
-        else:
-            module_name = "vllm.worker.hpu_worker"
-            class_name = "HPUWorker"
-        return (module_name, class_name, worker_class_fn)
-
-    def _check_executor_parameters(self):
-        world_size = self.parallel_config.world_size
-        tensor_parallel_size = self.parallel_config.tensor_parallel_size
-
-        hpu_device_count = torch.hpu.device_count()
-        assert tensor_parallel_size <= hpu_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local hpu count ({hpu_device_count})")
-
-        assert world_size <= hpu_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local hpu count ({hpu_device_count})")
-
-    def shutdown_inc(self):
-        self._run_workers("shutdown_inc")
-
-    def __del__(self):
-        self.shutdown()
-
-
-class MultiprocessingHPUExecutorAsync(MultiprocessingHPUExecutor,
-                                      MultiprocessingGPUExecutorAsync):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.driver_exec_model = make_async(self.driver_worker.execute_model)
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index 2afd99f99b353..d395887049a0d 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -98,7 +98,17 @@ def _init_executor(self) -> None:
             self.driver_exec_method = make_async(
                 self.driver_worker.execute_method)
 
+        self.shutdown_workers = True
+        self.terminate_ray = True
+
     def shutdown(self) -> None:
+        if getattr(self, 'shutdown_workers', False):
+            self._run_workers("shutdown")
+            self.shutdown_workers = False
+        if getattr(self, 'terminate_ray', False):
+            for worker in self.workers:
+                worker.__ray_terminate__.remote()
+            self.terminate_ray = False
         if hasattr(self, "forward_dag") and self.forward_dag is not None:
             self.forward_dag.teardown()
             import ray
diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py
index a5c4dcf0ec7f9..540206a15a360 100644
--- a/vllm/executor/uniproc_executor.py
+++ b/vllm/executor/uniproc_executor.py
@@ -39,6 +39,8 @@ def _init_executor(self) -> None:
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")
 
+        self.shutdown_worker = True
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
@@ -54,6 +56,11 @@ def check_health(self) -> None:
         # it's running.
         return
 
+    def shutdown(self):
+        if getattr(self, 'shutdown_worker', False):
+            self.collective_rpc("shutdown")
+            self.shutdown_worker = False
+
 
 UniProcExecutorAsync = UniProcExecutor
 
@@ -112,6 +119,8 @@ def _init_executor(self) -> None:
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")
 
+        self.shutdown_worker = True
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         Determine the number of available KV blocks.
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index a83039054fc78..969971f2e25cd 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -475,7 +475,7 @@ def list_prompt_adapters(self) -> Set[int]:
         raise NotImplementedError(
             "Prompt Adapter is not implemented for HPU backend.")
 
-    def shutdown_inc(self):
+    def shutdown(self):
         self.model_runner.shutdown_inc()
 
     @property
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index fb9919f7a7b6a..f434c7082bd2b 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -96,6 +96,10 @@ def execute_model(
     ) -> Optional[List[SamplerOutput]]:
         raise NotImplementedError
 
+    def shutdown(self) -> None:
+        """Shutdown the worker."""
+        return
+
     @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in

From 2f7f5c0d6bdc212d8eb7e838b2b0536eb555987d Mon Sep 17 00:00:00 2001
From: Michał Kuligowski
Date: Wed, 22 Jan 2025 10:28:31 +0100
Subject: [PATCH 2/3] Update requirements-hpu.txt

---
 requirements-hpu.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index ab4b823784bdc..1eace1a23f236 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@01090a8
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e1db53b

From 6b16c2b079a9ad3b4e34360aabbdd399a08f7281 Mon Sep 17 00:00:00 2001
From: Michał Kuligowski
Date: Wed, 22 Jan 2025 14:23:50 +0100
Subject: [PATCH 3/3] Update requirements-hpu.txt

---
 requirements-hpu.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index 1eace1a23f236..cefea150d7042 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e1db53b
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@d4f37bb
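
For context, the sketch below is a minimal, self-contained illustration of the shutdown flow that PATCH 1/3 introduces: the executor exposes a single shutdown() that forwards a "shutdown" call to every worker exactly once (guarded by a one-shot flag, since shutdown() can also be reached from __del__), and the worker's shutdown() delegates to the model runner's shutdown_inc(). ToyExecutor, ToyWorker and ToyModelRunner are hypothetical stand-ins, not vLLM classes; only the shutdown()/shutdown_inc() call chain and the guard pattern mirror the patch.

# Toy sketch of the one-shot executor -> worker -> model-runner shutdown chain.
# These classes are illustrative stand-ins, not vLLM APIs.

class ToyModelRunner:
    def shutdown_inc(self) -> None:
        # In the HPU worker this is where INC (quantization) state is finalized.
        print("model runner: finalizing INC state")


class ToyWorker:
    def __init__(self) -> None:
        self.model_runner = ToyModelRunner()

    def shutdown(self) -> None:
        # Mirrors the renamed HPUWorker.shutdown(), which delegates to shutdown_inc().
        self.model_runner.shutdown_inc()


class ToyExecutor:
    def __init__(self, num_workers: int = 2) -> None:
        self.workers = [ToyWorker() for _ in range(num_workers)]
        # One-shot guard, analogous to self.shutdown_workers in the executors above.
        self.shutdown_workers = True

    def shutdown(self) -> None:
        # Forward the shutdown to every worker exactly once, even if
        # shutdown() runs again later (for example from __del__).
        if getattr(self, "shutdown_workers", False):
            for worker in self.workers:
                worker.shutdown()
            self.shutdown_workers = False

    def __del__(self) -> None:
        self.shutdown()


if __name__ == "__main__":
    executor = ToyExecutor()
    executor.shutdown()   # workers shut down here
    executor.shutdown()   # second call is a no-op thanks to the guard

The guard matters because executors can be torn down both explicitly and by garbage collection; without it, the "shutdown" RPC would be sent to the workers twice.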