Skip to content

Commit

Permalink
collective rpc function signature sanity
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth committed Dec 10, 2024
1 parent ab6bf27 commit 819b229
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
8 changes: 6 additions & 2 deletions vllm/v1/executor/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

from vllm.config import VllmConfig
from vllm.v1.outputs import ModelRunnerOutput
Expand Down Expand Up @@ -40,5 +40,9 @@ def check_health(self) -> None:
raise NotImplementedError

@abstractmethod
def collective_rpc(self, method: str, *args, **kwargs) -> [Optional]:
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> []:
raise NotImplementedError
35 changes: 21 additions & 14 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing.process import BaseProcess
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import zmq

Expand Down Expand Up @@ -81,7 +81,7 @@ def initialize(self, num_gpu_blocks: int) -> None:
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_num_available_blocks(self) -> Tuple[int, int]:
Expand All @@ -99,23 +99,30 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:

return num_gpu_blocks, num_cpu_blocks

def collective_rpc(self, method: str, *args, **kwargs) -> [Optional]:
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> []:
"""
Execute an RPC call on workers.
Args:
method: Name of the Worker method to execute
*args: Arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Special kwargs:
- timeout: Optional[float] - Max time to wait for execution
Not passed into worker method
method: Name of the worker method to execute
        timeout: Maximum time in seconds to wait for execution. Raises a
            TimeoutError on timeout. None means wait indefinitely.
args: Positional arguments to pass to the worker method
kwargs: Keyword arguments to pass to the worker method
Returns:
List of results from each worker
"""
self.rpc_broadcast_mq.enqueue((method, args, kwargs))

timeout = kwargs.pop('timeout', None) # Default None if not present
start_time = time.monotonic()
kwargs = kwargs or {}

try:
self.rpc_broadcast_mq.enqueue((method, args, kwargs))

responses = [None] * self.world_size
for w in self.workers:
                dequeue_timeout = timeout - (time.monotonic() - start_time)
Expand Down Expand Up @@ -143,11 +150,11 @@ def execute_model(
scheduler_output,
) -> ModelRunnerOutput:
model_output = self.collective_rpc("execute_model",
scheduler_output)[0]
args=(scheduler_output, ))[0]
return model_output

def profile(self, is_start=True):
self.collective_rpc("profile", is_start)
self.collective_rpc("profile", args=(is_start, ))
return

def _ensure_worker_termination(self):
Expand Down

0 comments on commit 819b229

Please sign in to comment.