[SW-216413] - Fix new executors shutdown and shutdown_inc flow #716

Merged (3 commits) on Jan 22, 2025
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py

@@ -407,7 +407,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'or equal to the number of GPUs available, "mp" will be used to '
             'keep processing on a single host. Otherwise, this will default '
             'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'and hpu only support Ray for distributed inference.')
+            'only support Ray for distributed inference.')

     parser.add_argument(
         '--worker-use-ray',
4 changes: 4 additions & 0 deletions vllm/executor/mp_distributed_executor.py

@@ -91,8 +91,12 @@ def _init_executor(self) -> None:
                 max_parallel_loading_workers)
         self.driver_exec_model = make_async(self.driver_worker.execute_model)
         self.pp_locks: Optional[List[asyncio.Lock]] = None
+        self.shutdown_workers = True

     def shutdown(self):
+        if getattr(self, 'shutdown_workers', False):
+            self._run_workers("shutdown")
+            self.shutdown_workers = False
         if (worker_monitor := getattr(self, "worker_monitor",
                                       None)) is not None:
             worker_monitor.close()
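Side note on the `getattr` guard above: it makes `shutdown()` idempotent and safe to call even if `_init_executor()` never ran (for example, when construction fails partway). A minimal, self-contained sketch of that pattern, where the class name and the `print` stand-in for the worker RPC are hypothetical:

```python
# Sketch of the idempotent-shutdown guard (hypothetical stand-ins).
class ExecutorSketch:
    def _init_executor(self) -> None:
        # The flag is only set once workers actually exist.
        self.shutdown_workers = True

    def _run_workers(self, method: str) -> None:
        print(f"broadcasting {method!r} to workers")  # stand-in for real RPC

    def shutdown(self) -> None:
        # getattr tolerates shutdown() before _init_executor();
        # clearing the flag makes repeated calls no-ops.
        if getattr(self, 'shutdown_workers', False):
            self._run_workers("shutdown")
            self.shutdown_workers = False

ex = ExecutorSketch()
ex.shutdown()        # no-op: executor was never initialized
ex._init_executor()
ex.shutdown()        # broadcasts "shutdown" exactly once
ex.shutdown()        # second call is a no-op
```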
57 changes: 0 additions & 57 deletions vllm/executor/multiproc_hpu_executor.py

This file was deleted.

10 changes: 10 additions & 0 deletions vllm/executor/ray_distributed_executor.py

@@ -98,7 +98,17 @@ def _init_executor(self) -> None:
         self.driver_exec_method = make_async(
             self.driver_worker.execute_method)

+        self.shutdown_workers = True
+        self.terminate_ray = True
+
     def shutdown(self) -> None:
+        if getattr(self, 'shutdown_workers', False):
+            self._run_workers("shutdown")
+            self.shutdown_workers = False
+        if getattr(self, 'terminate_ray', False):
+            for worker in self.workers:
+                worker.__ray_terminate__.remote()
+            self.terminate_ray = False
         if hasattr(self, "forward_dag") and self.forward_dag is not None:
             self.forward_dag.teardown()
             import ray
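For context on `worker.__ray_terminate__.remote()` above: it is Ray's built-in actor method for graceful termination, letting already-queued tasks drain before the actor process exits (unlike `ray.kill()`, which stops the actor immediately). A hedged, standalone sketch of the same two-phase teardown, per-worker cleanup first and actor termination second; it assumes a local Ray installation, and the `Worker` class is a hypothetical stand-in:

```python
import ray

# Hypothetical stand-in for a vLLM worker actor.
@ray.remote
class Worker:
    def shutdown(self) -> None:
        pass  # per-worker cleanup would go here

ray.init()
workers = [Worker.remote() for _ in range(2)]

# Phase 1: explicit per-worker cleanup, mirroring _run_workers("shutdown").
ray.get([w.shutdown.remote() for w in workers])

# Phase 2: graceful actor termination; queued tasks drain first.
for w in workers:
    w.__ray_terminate__.remote()

ray.shutdown()
```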
9 changes: 9 additions & 0 deletions vllm/executor/uniproc_executor.py

@@ -39,6 +39,8 @@ def _init_executor(self) -> None:
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")

+        self.shutdown_worker = True
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,

@@ -54,6 +56,11 @@ def check_health(self) -> None:
         # it's running.
         return

+    def shutdown(self):
+        if getattr(self, 'shutdown_worker', False):
+            self.collective_rpc("shutdown")
+            self.shutdown_worker = False
+

 UniProcExecutorAsync = UniProcExecutor

@@ -112,6 +119,8 @@ def _init_executor(self) -> None:
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")

+        self.shutdown_worker = True
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         Determine the number of available KV blocks.
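In the single-process executor, `collective_rpc` is what carries the `"shutdown"` string to the one local worker. A simplified sketch of that dispatch; the class names are hypothetical, and the real signature also accepts a `Callable` and a timeout:

```python
from typing import Any, List

# Hypothetical stand-in for the single local worker.
class WorkerSketch:
    def shutdown(self) -> None:
        print("worker cleanup")

class UniProcExecutorSketch:
    def __init__(self) -> None:
        self.driver_worker = WorkerSketch()
        self.shutdown_worker = True

    def collective_rpc(self, method: str, *args: Any) -> List[Any]:
        # Resolve the method name on the lone worker and call it;
        # results come back as a one-element list.
        return [getattr(self.driver_worker, method)(*args)]

    def shutdown(self) -> None:
        if getattr(self, 'shutdown_worker', False):
            self.collective_rpc("shutdown")  # reaches WorkerSketch.shutdown
            self.shutdown_worker = False

UniProcExecutorSketch().shutdown()  # prints "worker cleanup"
```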
2 changes: 1 addition & 1 deletion vllm/worker/hpu_worker.py

@@ -475,7 +475,7 @@ def list_prompt_adapters(self) -> Set[int]:
         raise NotImplementedError(
             "Prompt Adapter is not implemented for HPU backend.")

-    def shutdown_inc(self):
+    def shutdown(self):
         self.model_runner.shutdown_inc()

     @property
4 changes: 4 additions & 0 deletions vllm/worker/worker_base.py

@@ -96,6 +96,10 @@ def execute_model(
     ) -> Optional[List[SamplerOutput]]:
         raise NotImplementedError

+    def shutdown(self) -> None:
+        """Shutdown the worker."""
+        return
+
     @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in

Review comment on the new shutdown() method:

Why is this method needed here?

Author's reply (thread marked as resolved):

I want to be able to call the worker's shutdown function from the executors, but the executors are not device-specific. Therefore all workers need some generic "shutdown" function, and I just adjusted the HPU worker's shutdown to do the device-specific work instead.
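To make that reply concrete: the base class ships a safe no-op `shutdown()` that every executor can call blindly, and a device-specific worker overrides it, as the HPU worker does above with its INC teardown. A minimal sketch with hypothetical class names:

```python
# Hypothetical stand-in for the HPU model runner's INC teardown.
class ModelRunnerSketch:
    def shutdown_inc(self) -> None:
        print("releasing INC quantization state")

class WorkerBaseSketch:
    def shutdown(self) -> None:
        """Shutdown the worker."""
        return  # safe default: most workers need no extra cleanup

class HPUWorkerSketch(WorkerBaseSketch):
    def __init__(self) -> None:
        self.model_runner = ModelRunnerSketch()

    def shutdown(self) -> None:
        # Device-specific override, reached through the generic call path.
        self.model_runner.shutdown_inc()

for worker in (WorkerBaseSketch(), HPUWorkerSketch()):
    worker.shutdown()  # executors can call this uniformly on any worker
```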