diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 723495c700412..12256f921f69d 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -24,6 +24,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") + @pytest.fixture(autouse=True) def v1(run_with_both_engines): # Simple autouse wrapper to run both engines for each test @@ -31,6 +32,7 @@ def v1(run_with_both_engines): # test in a package pass + def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" llm = LLM("facebook/opt-125m") diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 2ff1a1ead99c1..c737cc98b3427 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,12 +1,14 @@ import os -import pickle +import struct +import sys import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory -from typing import List, Optional +from typing import List, Optional, Tuple from unittest.mock import patch +import msgspec import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -21,6 +23,13 @@ logger = init_logger(__name__) +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on python < 3.11.1, +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = False +if sys.version_info[:3] >= (3, 11, 1): + USE_SCHED_YIELD = True + class ShmRingBuffer: @@ -74,7 +83,7 @@ def __init__(self, NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. During creation, `name` is None and the buffer is created. We can pass the - created object to other processes by pickling it. The other processes will + created object to other processes by serializing it. The other processes will get the name of the shared memory and open it, so that they can access the same shared memory buffer. """# noqa @@ -114,6 +123,10 @@ def __init__(self, # and we should suppress the error pass + def handle(self): + return (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name) + def __reduce__(self): return ( self.__class__, @@ -147,13 +160,18 @@ class Handle: connect_ip: str local_reader_ranks: List[int] = field(default_factory=list) - buffer: Optional[ShmRingBuffer] = None + buffer_handle: Optional[Tuple[int, int, int, str]] = None local_subscribe_port: Optional[int] = None remote_subscribe_port: Optional[int] = None class MessageQueue: + # Use 4 bytes to store size of each message (we omit this for ZMQ). + # This is needed for decoding the message. 
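# --- Illustrative aside, not part of the patch: a minimal sketch of the
# 4-byte length-prefix framing described above, using the same struct +
# msgspec primitives. `frame`/`unframe` are hypothetical helper names.
import struct
import msgspec

SIZE_PREFIX_FORMAT = '!I'  # unsigned int, 4 bytes, network byte order
SIZE_PREFIX_LEN = struct.calcsize(SIZE_PREFIX_FORMAT)

def frame(obj) -> bytes:
    # msgpack carries no built-in length header, so the writer prepends one;
    # the reader then knows how many bytes of the (larger) shared-memory
    # chunk actually belong to the message.
    payload = msgspec.msgpack.Encoder().encode(obj)
    return struct.pack(SIZE_PREFIX_FORMAT, len(payload)) + payload

def unframe(buf, obj_type):
    # Read the length prefix, then decode exactly that many payload bytes.
    (size, ) = struct.unpack_from(SIZE_PREFIX_FORMAT, buf, 0)
    return msgspec.msgpack.Decoder(obj_type).decode(
        buf[SIZE_PREFIX_LEN:SIZE_PREFIX_LEN + size])

assert unframe(frame([1, 2, 3]), list) == [1, 2, 3]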
+ SIZE_PREFIX_FORMAT = '!I' # unsigned int, 4 bytes, network byte order + SIZE_PREFIX_LEN = struct.calcsize(SIZE_PREFIX_FORMAT) + def __init__( self, n_reader, # number of all readers @@ -228,7 +246,7 @@ def __init__( self.handle = Handle( connect_ip=connect_ip, local_reader_ranks=local_reader_ranks, - buffer=self.buffer, + buffer_handle=self.buffer.handle(), local_subscribe_port=local_subscribe_port, remote_subscribe_port=remote_subscribe_port, ) @@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": context = Context() if rank in handle.local_reader_ranks: - assert handle.buffer is not None - self.buffer = handle.buffer + assert handle.buffer_handle is not None + self.buffer = ShmRingBuffer(*handle.buffer_handle) self.current_idx = 0 self.local_reader_rank = handle.local_reader_ranks.index(rank) self._is_local_reader = True @@ -329,7 +347,10 @@ def acquire_write(self): # we need to wait until it is read by all readers # Release the processor to other threads - os.sched_yield() + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(1e-5) # if we wait for a long time, we should warn the user if (time.monotonic() - start_time > @@ -383,7 +404,10 @@ def acquire_read(self): # we need to wait until it is written # Release the processor to other threads - os.sched_yield() + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) # if we wait for a long time, we should warn the user if (time.monotonic() - start_time > @@ -408,34 +432,49 @@ def acquire_read(self): def enqueue(self, obj): assert self._is_writer, "Only writers can enqueue" - serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + + encoder = msgspec.msgpack.Encoder() + serialized_obj = encoder.encode(obj) + size_to_write = self.SIZE_PREFIX_LEN + len(serialized_obj) + if self.n_local_reader > 0: - if len(serialized_obj) >= self.buffer.max_chunk_bytes: + if size_to_write >= self.buffer.max_chunk_bytes: with self.acquire_write() as buf: buf[0] = 1 # overflow self.local_socket.send(serialized_obj) else: with self.acquire_write() as buf: buf[0] = 0 # not overflow - buf[1:len(serialized_obj) + 1] = serialized_obj + obj_offset = 1 + self.SIZE_PREFIX_LEN + + # Write size prefix + buf[1:obj_offset] = struct.pack(self.SIZE_PREFIX_FORMAT, + len(serialized_obj)) + + buf[obj_offset:obj_offset + + len(serialized_obj)] = serialized_obj if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self): + def dequeue(self, obj_type): + decoder = msgspec.msgpack.Decoder(obj_type) + if self._is_local_reader: with self.acquire_read() as buf: overflow = buf[0] == 1 if not overflow: - # no need to know the size of serialized object - # pickle format contains the size information internally - # see https://docs.python.org/3/library/pickle.html - obj = pickle.loads(buf[1:]) + obj_offset = 1 + self.SIZE_PREFIX_LEN + size_bytes = buf[1:obj_offset] + msg_size = struct.unpack(self.SIZE_PREFIX_FORMAT, + size_bytes)[0] + + obj = decoder.decode(buf[obj_offset:obj_offset + msg_size]) if overflow: recv = self.local_socket.recv() - obj = pickle.loads(recv) + obj = decoder.decode(recv) elif self._is_remote_reader: recv = self.remote_socket.recv() - obj = pickle.loads(recv) + obj = decoder.decode(recv) else: raise RuntimeError("Only readers can dequeue") return obj @@ -445,7 +484,7 @@ def broadcast_object(self, obj=None): self.enqueue(obj) return obj else: - return self.dequeue() + return self.dequeue(obj) @staticmethod def create_from_process_group(pg: ProcessGroup, diff --git 
a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 87ade377266a2..c4d6b2df3a68d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1066,11 +1066,13 @@ def initialize_model_parallel( group_ranks.append(ranks) # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=True, - group_name="tp") + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + #TODO: this is not getting cleaned up. + use_message_queue_broadcaster=False, + group_name="tp") # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // diff --git a/vllm/utils.py b/vllm/utils.py index 1b02cbff79f78..016fc75abbff5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -19,11 +19,12 @@ import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task from collections.abc import Mapping +from contextlib import contextmanager from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, - Type, TypeVar, Union, overload) + Hashable, Iterator, List, Literal, Optional, OrderedDict, + Set, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -32,6 +33,7 @@ import torch import torch.types import yaml +import zmq from packaging.version import Version from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never @@ -515,6 +517,30 @@ def get_open_zmq_ipc_path() -> str: return f"ipc://{base_rpc_path}/{uuid4()}" +@contextmanager +def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]: + """Context manager for use """ + + ctx = zmq.Context() + try: + socket = ctx.socket(type) + + if type == zmq.constants.PULL: + socket.connect(path) + elif type == zmq.constants.PUSH: + socket.bind(path) + else: + raise ValueError(f"Unknown Socket Type: {type}") + + yield socket + + except KeyboardInterrupt: + logger.debug("Worker had Keyboard Interrupt.") + + finally: + ctx.destroy(linger=0) + + def get_open_port() -> int: port = envs.VLLM_PORT if port is not None: diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index c3cf950133564..0fc8ae2250238 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,10 +1,11 @@ from collections import deque from dataclasses import dataclass -from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.base import PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -12,10 +13,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -if TYPE_CHECKING: - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.base import PlaceholderRange - logger = init_logger(__name__) @@ -508,8 +505,8 @@ class NewRequestData: req_id: str prompt_token_ids: List[int] prompt: Optional[str] - mm_inputs: 
List["MultiModalKwargs"] - mm_positions: List["PlaceholderRange"] + mm_inputs: List[MultiModalKwargs] + mm_positions: List[PlaceholderRange] sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py new file mode 100644 index 0000000000000..9acdf6b9c2bc8 --- /dev/null +++ b/vllm/v1/core/scheduler_output.py @@ -0,0 +1,24 @@ +from enum import Enum, auto +from typing import Optional + +import msgspec + +from vllm.v1.core.scheduler import SchedulerOutput + + +#TODO: Move this file +class ExecutorMsgType(Enum): + TOIL = auto() + TERMINATE = auto() + + +class ExecutorMsg(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False): + """A directive from the core process to its worker processes. + + Wraps SchedulerOutput with a message type to distinguish between + regular work assignments and termination orders.""" + message_type: ExecutorMsgType + payload: Optional[SchedulerOutput] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 35ed131d50de9..70c37ceaa360b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -2,10 +2,9 @@ import queue import threading import time -from contextlib import contextmanager from multiprocessing.process import BaseProcess from multiprocessing.sharedctypes import Synchronized -from typing import Any, Iterator, List, Tuple, Type, Union +from typing import List, Tuple, Type, Union import zmq import zmq.asyncio @@ -14,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext +from vllm.utils import make_zmq_socket from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) @@ -126,6 +126,9 @@ def step(self) -> List[EngineCoreOutput]: scheduler_output, output) return engine_core_outputs + def shutdown(self): + self.model_executor.shutdown() + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -162,32 +165,9 @@ def __init__( daemon=True).start() # Send Readiness signal to EngineClient. - with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket: + with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket: ready_socket.send_string(EngineCoreProc.READY_STR) - @contextmanager - def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]: - """Context manager for use """ - - ctx = zmq.Context() - try: - socket = ctx.socket(type) - - if type == zmq.constants.PULL: - socket.connect(path) - elif type == zmq.constants.PUSH: - socket.bind(path) - else: - raise ValueError(f"Unknown Socket Type: {type}") - - yield socket - - except KeyboardInterrupt: - logger.debug("EngineCore had Keyboard Interrupt.") - - finally: - ctx.destroy(linger=0) - @staticmethod def wait_for_startup( proc: BaseProcess, @@ -329,7 +309,7 @@ def process_input_socket(self, input_path: str): decoder_add_req = PickleEncoder() decoder_abort_req = PickleEncoder() - with self.make_socket(input_path, zmq.constants.PULL) as socket: + with self.make_zmq_socket(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) @@ -355,7 +335,7 @@ def process_output_socket(self, output_path: str): # Reuse send buffer. 
buffer = bytearray() - with self.make_socket(output_path, zmq.constants.PUSH) as socket: + with self.make_zmq_socket(output_path, zmq.constants.PUSH) as socket: while True: engine_core_outputs = self.output_queue.get() outputs = EngineCoreOutputs(outputs=engine_core_outputs) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 09801e20e16ca..fbdbcd9777557 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -95,6 +95,9 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self.engine_core.abort_requests(request_ids) + def shutdown(self): + self.engine_core.shutdown() + class MPClient(EngineCoreClient): """ diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c71680f1055e0..c02f27ac3f714 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Mapping, Optional, Type, Union +from typing import Dict, List, Mapping, Optional, Union from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -17,7 +17,6 @@ from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor -from vllm.v1.executor.gpu_executor import GPUExecutor logger = init_logger(__name__) @@ -113,7 +112,7 @@ def _get_executor_cls(cls, vllm_config: VllmConfig): return executor_class def stop_remote_worker_execution_loop(self) -> None: - raise NotImplementedError("TP not implemented yet.") + self.engine_core.shutdown() def get_num_unfinished_requests(self) -> int: return self.detokenizer.get_num_unfinished_requests() diff --git a/vllm/v1/executor/multiproc_gpu_executor.py b/vllm/v1/executor/multiproc_gpu_executor.py index 0b07368cc561c..a6c89a9c0683e 100644 --- a/vllm/v1/executor/multiproc_gpu_executor.py +++ b/vllm/v1/executor/multiproc_gpu_executor.py @@ -1,19 +1,18 @@ import os -from functools import partial -from typing import Any, List, Optional, Tuple +from concurrent.futures import ThreadPoolExecutor +from typing import List, Tuple import torch from vllm.config import VllmConfig from vllm.distributed.device_communicators.shm_broadcast import MessageQueue -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.triton_utils import maybe_set_triton_cache_manager -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, +from vllm.utils import (get_distributed_init_method, get_open_port, get_vllm_instance_id) +from vllm.v1.core.scheduler_output import ExecutorMsg, ExecutorMsgType from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.gpu_worker import MultiprocessingWorker +from vllm.v1.worker.gpu_worker import WorkerProc, WorkerProcHandle logger = init_logger(__name__) @@ -21,6 +20,9 @@ class MultiprocessingGPUExecutor: def __init__(self, vllm_config: VllmConfig) -> None: + # Store early so we can count on using it at shutdown + self._TERMINATE_VALUE = ExecutorMsgType.TERMINATE.value + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -74,100 +76,50 @@ def __init__(self, vllm_config: VllmConfig) -> None: distributed_init_method = get_distributed_init_method( "127.0.0.1", get_open_port()) + # Initialize worker and set up message queues for SchedulerOutputs + # and ModelRunnerOutputs + 
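# --- Illustrative aside, not part of the patch: a minimal sketch of the
# handle-based MessageQueue handshake used here. The writer constructs the
# queue, ships export_handle() to the reader process, the reader rebuilds it
# with create_from_handle, and both sides call wait_until_ready before use.
# `reader_main` is a hypothetical function name, and constructor arguments
# beyond (n_reader, n_local_reader) are assumed to take their defaults.
import multiprocessing

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

def reader_main(handle):
    mq = MessageQueue.create_from_handle(handle, rank=0)
    mq.wait_until_ready()        # barrier: all readers plus the writer
    assert mq.dequeue(list) == [1, 2, 3]

if __name__ == "__main__":
    mq = MessageQueue(1, 1)      # 1 reader total, 1 local reader
    proc = multiprocessing.get_context("spawn").Process(
        target=reader_main, args=(mq.export_handle(), ))
    proc.start()
    mq.wait_until_ready()
    mq.enqueue([1, 2, 3])
    proc.join()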
self.scheduler_output_mq = MessageQueue(world_size, world_size) + scheduler_output_handle = self.scheduler_output_mq.export_handle() + # Create workers - self.workers: List[ProcessWorkerWrapper] = [] - result_handler = ResultHandler() + self.workers: List[WorkerProcHandle] = [] for rank in range(world_size): - worker = ProcessWorkerWrapper( - result_handler, - partial( - self._create_worker, - **dict( - rank=rank, - local_rank=rank, - distributed_init_method=distributed_init_method, - ))) + worker = WorkerProc.make_worker_process(vllm_config, rank, rank, + distributed_init_method, + scheduler_output_handle) self.workers.append(worker) - self.worker_monitor = WorkerMonitor(self.workers, result_handler) - result_handler.start() - self.worker_monitor.start() - - self._run_workers("initialize") - self._run_workers("load_model") - - # Initialize worker and set up message queues for SchedulerOutputs - # and ModelRunnerOutputs - self.scheduler_output_sender = MessageQueue(world_size, world_size) - model_output_receiver_handle = self._run_workers( - "initialize_message_queues", - self.scheduler_output_sender.export_handle())[0] - self.model_output_receiver = MessageQueue.create_from_handle( - model_output_receiver_handle, 0) - - # Message queues are not valid until all readers and writers call - # wait_until_ready() - wait_futures = self._run_workers("finish_message_queue_initialization", - run_async=True) - self.scheduler_output_sender.wait_until_ready() - self.model_output_receiver.wait_until_ready() - for output in wait_futures: - output.get() - - # Flag that's set if workers are waiting in the main execution loop + model_output_mq_handle = self.workers[0].model_output_mq_handle + self.model_output_mq = MessageQueue.create_from_handle( + model_output_mq_handle, 0) self.workers_in_busy_loop = False - def _create_worker( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None - ) -> MultiprocessingWorker: - """Return worker init args for a given rank.""" - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ['NCCL_CUMEM_ENABLE'] = '0' - - if distributed_init_method is None: - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - - return MultiprocessingWorker( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - ) - - def _run_workers( - self, - method: str, - *args, - run_async: bool = False, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - run_async: If True the method will be run asynchronously and return - a list of futures rather than blocking on the results. 
- """ - - worker_outputs = [ - worker.execute_method(method, *args, **kwargs) - for worker in self.workers - ] + def start_workers(self): + for w in self.workers: + w.start_busy_loop() + self.scheduler_output_mq.wait_until_ready() + self.model_output_mq.wait_until_ready() + self.workers_in_busy_loop = True + + def run_on_workers(self, fn: str, *args): + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(getattr(type(w), fn), w, *args) + for w in self.workers + ] + result = [f.result() for f in futures] # Wait for all to complete + return result - if run_async: - return worker_outputs - else: - return [output.get() for output in worker_outputs] + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Initialize the KV caches by invoking the underlying worker.""" + self.run_on_workers('initialize_cache', num_gpu_blocks) def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks") + num_blocks = self.run_on_workers('determine_num_available_blocks') # Since we use a shared centralized controller, we take the minimum # number of blocks across all workers to make sure all the memory @@ -177,29 +129,29 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: return num_gpu_blocks, num_cpu_blocks - def initialize_cache(self, num_gpu_blocks: int) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# GPU blocks: %d", num_gpu_blocks) - self._run_workers("initialize_cache", num_gpu_blocks) - self._run_workers("compile_or_warm_up_model") - def execute_model( self, scheduler_output, ) -> ModelRunnerOutput: - # TODO: Find a better way to start this loop if not self.workers_in_busy_loop: - self._run_workers("execute_model_busy_loop", run_async=True) - self.workers_in_busy_loop = True + self.start_workers() - self.scheduler_output_sender.enqueue(scheduler_output) - model_output = self.model_output_receiver.dequeue() + self.scheduler_output_mq.enqueue( + ExecutorMsg(ExecutorMsgType.TOIL.value, scheduler_output)) + model_output = self.model_output_mq.dequeue(ModelRunnerOutput) return model_output + def shutdown(self): + """Properly shut down the executor and its workers""" + termination_msg = ExecutorMsg(self._TERMINATE_VALUE, None) + self.scheduler_output_mq.enqueue(termination_msg) + + # Shutdown the worker processes if needed. + self.run_on_workers('terminate') + + def __del__(self): + self.shutdown() + def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 212d44401744e..faa735cac45e6 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,8 +1,12 @@ from dataclasses import dataclass -from typing import Dict, List, Optional +import enum +from typing import Dict, List, Optional, Tuple +import msgspec import torch +from vllm.distributed.device_communicators.shm_broadcast import Handle + @dataclass class SamplerOutput: @@ -20,10 +24,12 @@ class SamplerOutput: prompt_logprobs: Optional[torch.Tensor] -# ModelRunnerOutput is pickeled and sent to the scheduler process. 
+# ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use List instead. -@dataclass -class ModelRunnerOutput: +class ModelRunnerOutput(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False): # [num_reqs] req_ids: List[str] @@ -37,3 +43,36 @@ class ModelRunnerOutput: logprob_token_ids_cpu: Optional[torch.Tensor] # [num_reqs, max_num_logprobs + 1] logprobs_cpu: Optional[torch.Tensor] + + +# Below are data structures used for serializing initiailization-related +# data structures to send between workers and the core engine process +class NumBlocksMsg(msgspec.Struct): + num_blocks: Tuple[int, int] + + +class NumGPUBlocks(msgspec.Struct): + num_gpu_blocks: int + + +class ShmHandleMsg(msgspec.Struct): + handle: Handle + + +class WorkerInitRequestType(enum.Enum): + """ + Request types defined as hex byte strings, so it can be sent over sockets + without separate encoding step. + """ + DETERMINE_NUM_BLOCKS = b'\x00' + INIT_CACHE = b'\x01' + BEGIN_MODEL_EXECUTION = b'\x02' + + +class WorkerInitOutputType(enum.Enum): + """ + Request types defined as hex byte strings, so it can be sent over sockets + without separate encoding step. + """ + NUM_BLOCKS = b'\x00' + MODEL_OUTPUT_MSG_QUEUE = b'\x01' diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0ebdada9d51bf..d3aff39089132 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,25 +1,40 @@ """A GPU worker class.""" import gc +import multiprocessing import os +import time +from dataclasses import dataclass +from multiprocessing.process import BaseProcess from typing import TYPE_CHECKING, Optional, Tuple +import msgspec import torch import torch.distributed +import zmq from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size -from vllm.v1.outputs import ModelRunnerOutput +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, + get_open_zmq_ipc_path, make_zmq_socket) +from vllm.v1.core.scheduler_output import ExecutorMsg, ExecutorMsgType +from vllm.v1.outputs import (ModelRunnerOutput, NumBlocksMsg, NumGPUBlocks, + ShmHandleMsg, WorkerInitOutputType, + WorkerInitRequestType) from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) +POLLING_TIMEOUT_MS = 5000 +POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +LOGGING_TIME_S = 5000 + if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -173,8 +188,63 @@ def execute_model( return output -# Wraps Worker for the multiprocessing, multi-gpu case. 
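# --- Illustrative aside, not part of the patch: the WorkerInitRequestType /
# WorkerInitOutputType byte strings defined above travel as the first frame of
# a ZMQ multipart message, with the msgpack payload in the second frame, so no
# extra serialization step is needed for the type tag. Single-process sketch
# over a plain PUSH/PULL pair (the patch's make_zmq_socket helper wraps the
# same bind/connect pattern); the ipc path and NumBlocks struct are stand-ins.
from typing import Tuple

import msgspec
import zmq

class NumBlocks(msgspec.Struct):
    num_blocks: Tuple[int, int]  # (num_gpu_blocks, num_cpu_blocks)

NUM_BLOCKS_MSG = b'\x00'
path = "ipc:///tmp/worker_init_demo"

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)
push.bind(path)
pull = ctx.socket(zmq.PULL)
pull.connect(path)

# Sender: request type in the first frame, msgpack payload in the second.
payload = msgspec.msgpack.Encoder().encode(NumBlocks((8192, 1024)))
push.send_multipart((NUM_BLOCKS_MSG, payload))

# Receiver: dispatch on the type frame, then decode the data frame.
type_frame, data_frame = pull.recv_multipart(copy=False)
assert bytes(type_frame.buffer) == NUM_BLOCKS_MSG
assert msgspec.msgpack.Decoder(NumBlocks).decode(
    data_frame.buffer).num_blocks == (8192, 1024)

ctx.destroy(linger=0)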
-class MultiprocessingWorker: +@dataclass +class WorkerProcHandle: + proc: BaseProcess + initialization_input_path: str + initialization_output_path: str + model_output_mq_handle: Optional[Handle] + + def determine_num_available_blocks(self) -> Tuple[int, int]: + with make_zmq_socket(self.initialization_output_path, + zmq.constants.PUSH) as send_socket, \ + make_zmq_socket(self.initialization_input_path, + zmq.constants.PULL) as recv_socket: + + send_socket.send_multipart( + (WorkerInitRequestType.DETERMINE_NUM_BLOCKS.value, )) + type_frame, data_frame = recv_socket.recv_multipart(copy=False) + + request_type = type_frame.buffer + request_data = data_frame.buffer + + if request_type == WorkerInitOutputType.NUM_BLOCKS.value: + decoder = msgspec.msgpack.Decoder(NumBlocksMsg) + num_blocks = decoder.decode(request_data).num_blocks + return num_blocks + else: + raise ValueError(f"Unknown RequestType: {request_type}") + + def initialize_cache(self, num_gpu_blocks: int) -> int: + with make_zmq_socket(self.initialization_output_path, + zmq.constants.PUSH) as socket: + encoder = msgspec.msgpack.Encoder() + msg = encoder.encode(NumGPUBlocks(num_gpu_blocks)) + socket.send_multipart( + (WorkerInitRequestType.INIT_CACHE.value, msg)) + + def start_busy_loop(self) -> None: + with make_zmq_socket(self.initialization_output_path, + zmq.constants.PUSH) as socket: + socket.send_multipart( + (WorkerInitRequestType.BEGIN_MODEL_EXECUTION.value, )) + + def terminate(self) -> None: + self.proc.terminate() + start_time = time.time() + + while time.time() - start_time < 5: + if not self.proc.is_alive(): + return # Process terminated successfully + time.sleep(0.1) # Short sleep to avoid CPU spinning + + self.proc.kill() + + +class WorkerProc: + """Wrapper that runs one Worker in a separate process.""" + + READY_STR = "READY" def __init__( self, @@ -182,31 +252,198 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, + input_shm_handle: Handle, + initialization_input_path: str, + initialization_output_path: str, + ready_path: str, ): + self.rank = rank self.worker = Worker(vllm_config, local_rank, rank, distributed_init_method) - def initialize_message_queues(self, scheduler_output_receiver_handle): # Initialize MessageQueue for receiving SchedulerOutput - # Add 1 rank to account for driver process self.scheduler_output_receiver = MessageQueue.create_from_handle( - scheduler_output_receiver_handle, self.worker.rank) - - # Initialize group coordinator for sending the ModelRunnerOutput - # to the driver process - if self.worker.rank == 0: - self.model_output_sender = MessageQueue(1, 1) - return self.model_output_sender.export_handle() + input_shm_handle, self.worker.rank) + + # Send Readiness signal to EngineCore process. 
+ logger.info("sending ready.") + with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket: + ready_socket.send_string(WorkerProc.READY_STR) + + # Worker 0 initializes a message queue for sending the model output + if self.rank == 0: + self.model_output_mq = MessageQueue(1, 1) + output_mq_handle = self.model_output_mq.export_handle() + with make_zmq_socket(initialization_output_path, + zmq.constants.PUSH) as socket: + encoder = msgspec.msgpack.Encoder() + msg = encoder.encode(ShmHandleMsg(output_mq_handle)) + socket.send_multipart( + (WorkerInitOutputType.MODEL_OUTPUT_MSG_QUEUE.value, msg)) else: - self.model_output_sender = None - return None + self.model_output_mq = None - # Message queues are not valid until all readers and writers call - # wait_until_ready() - def finish_message_queue_initialization(self): - self.scheduler_output_receiver.wait_until_ready() - if self.worker.rank == 0: - self.model_output_sender.wait_until_ready() + logger.info("initializing and loading model.") + self.worker.initialize() + self.worker.load_model() + + # TODO: WHY is this needed? + def __del__(self): + if hasattr(self, "model_output_mq"): + del self.model_output_mq + + @staticmethod + def make_worker_process( + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle, # Receive SchedulerOutput + ) -> WorkerProcHandle: + # The current process might have CUDA context, + # so we need to spawn a new process. + # NOTE(rob): this is a problem for using EngineCoreProc w/ + # LLM, since we need a if __name__ == "__main__" guard. + + # TODO(tms): fix before landing + context = multiprocessing.get_context("fork") + + # ZMQ paths to send back and forth to worker process + # Used for initialization. + initialization_input_path = get_open_zmq_ipc_path() + initialization_output_path = get_open_zmq_ipc_path() + ready_path = get_open_zmq_ipc_path() + + process_kwargs = { + "vllm_config": vllm_config, + "local_rank": local_rank, + "rank": rank, + "distributed_init_method": distributed_init_method, + "input_shm_handle": input_shm_handle, + "ready_path": ready_path, + "initialization_input_path": initialization_output_path, + "initialization_output_path": initialization_input_path, + } + # Run EngineCore busy loop in background process. + proc = context.Process(target=WorkerProc.run_worker, + kwargs=process_kwargs) + proc.start() + + # Wait for startup + WorkerProc.wait_for_startup(proc, ready_path) + + # Read Shm MessageQueue from rank 0 + if rank == 0: + model_output_mq_handle = WorkerProc.read_model_output_mq_handle( + initialization_input_path) + else: + model_output_mq_handle = None + + return WorkerProcHandle(proc, initialization_input_path, + initialization_output_path, + model_output_mq_handle) + + @staticmethod + def run_worker(*args, **kwargs): + """Launch Worker busy loop in background process.""" + + try: + worker = WorkerProc(*args, **kwargs) + worker.model_initialization_loop( + kwargs["initialization_input_path"], + kwargs["initialization_output_path"]) + + worker.execute_model_busy_loop() + + except KeyboardInterrupt: + logger.debug("Worker interrupted.") + + except BaseException as e: + logger.exception(e) + raise e + finally: + # TODO: Why is this del needed? + del worker + exit(0) + + @staticmethod + def wait_for_startup( + proc: BaseProcess, + ready_path: str, + ) -> None: + """Wait until the Worker is ready.""" + with make_zmq_socket(ready_path, zmq.constants.PULL) as socket: + + # Wait for Worker to send Worker.READY_STR. 
+ while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for WorkerProc to startup.") + + if not proc.is_alive(): + raise RuntimeError("WorkerProc failed to start.") + + message = socket.recv_string() + assert message == WorkerProc.READY_STR + + @staticmethod + def read_model_output_mq_handle(init_input_path: str, ) -> Handle: + with make_zmq_socket(init_input_path, + zmq.constants.PULL) as recv_socket: + type_frame, data_frame = recv_socket.recv_multipart(copy=False) + request_type = type_frame.buffer + request_data = data_frame.buffer + + if (request_type == + WorkerInitOutputType.MODEL_OUTPUT_MSG_QUEUE.value): + decoder = msgspec.msgpack.Decoder(ShmHandleMsg) + handle = decoder.decode(request_data).handle + return handle + else: + raise ValueError(f"Unknown RequestType: {request_type}") + + # Busy loop used for initializing Multiprocessing Workers + def model_initialization_loop(self, init_input_path, init_output_path): + # Msgpack serialization encoding. + encoder = msgspec.msgpack.Encoder() + # Reuse send buffer. + buffer = bytearray() + + with make_zmq_socket(init_output_path, + zmq.constants.PUSH) as send_socket, \ + make_zmq_socket(init_input_path, + zmq.constants.PULL) as recv_socket: + while True: + # (RequestType, RequestData) + thing = recv_socket.recv_multipart(copy=False) + request_type = thing[0].buffer + + # Deserialize the request data. + if (request_type == + WorkerInitRequestType.DETERMINE_NUM_BLOCKS.value): + num_blocks = self.worker.determine_num_available_blocks() + output = NumBlocksMsg(num_blocks) + encoder.encode_into(output, buffer) + send_socket.send_multipart( + (WorkerInitOutputType.NUM_BLOCKS.value, buffer), + copy=False) + elif request_type == WorkerInitRequestType.INIT_CACHE.value: + request_data = thing[1].buffer + decoder = msgspec.msgpack.Decoder(NumGPUBlocks) + num_gpu_blocks = decoder.decode( + request_data).num_gpu_blocks + self.worker.initialize_cache(num_gpu_blocks) + self.worker.compile_or_warm_up_model() + elif (request_type == + WorkerInitRequestType.BEGIN_MODEL_EXECUTION.value): + # Make sure message queues are ready. 
+ self.scheduler_output_receiver.wait_until_ready() + + if self.model_output_mq is not None: + self.model_output_mq.wait_until_ready() + + # Exit initialization loop to begin model execution loop + return + else: + raise ValueError(f"Unknown RequestType: {request_type}") # Main busy loop for Multiprocessing Workers def execute_model_busy_loop(self): @@ -229,29 +466,20 @@ def execute_model_busy_loop(self): ) as p: while True: - scheduler_output = self.scheduler_output_receiver.dequeue() - output = self.worker.execute_model(scheduler_output) - if self.worker.rank == 0: - self.model_output_sender.enqueue(output) + msg = self.scheduler_output_receiver.dequeue(ExecutorMsg) + + if msg.message_type == ExecutorMsgType.TERMINATE: + return + elif msg.message_type == ExecutorMsgType.TOIL: + output = self.worker.execute_model(msg.payload) + if self.worker.rank == 0: + self.model_output_mq.enqueue(output) + else: + raise ValueError( + f"Unknown RequestType: {msg.message_type}") p.step() - # Wrapper methods defined here - def initialize(self): - self.worker.initialize() - - def load_model(self): - self.worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - return self.worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int) -> None: - self.worker.initialize_cache(num_gpu_blocks) - - def compile_or_warm_up_model(self) -> None: - self.worker.compile_or_warm_up_model() - def init_worker_distributed_environment( parallel_config: ParallelConfig,
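# --- Illustrative aside, not part of the patch: a self-contained round trip of
# an ExecutorMsg-style control message through msgspec, mirroring the
# TERMINATE / TOIL dispatch in the busy loop above. `FakeSchedulerOutput`
# stands in for the real SchedulerOutput payload; annotating message_type with
# the Enum and comparing enum members on the receive side is one consistent
# convention; whichever convention is chosen, the enqueue and dequeue sides
# must agree on it.
from enum import Enum, auto
from typing import Optional

import msgspec

class MsgType(Enum):
    TOIL = auto()
    TERMINATE = auto()

class FakeSchedulerOutput(msgspec.Struct, array_like=True):
    step: int

class ControlMsg(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    message_type: MsgType
    payload: Optional[FakeSchedulerOutput] = None

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(ControlMsg)

# A work assignment carries a payload ...
msg = decoder.decode(
    encoder.encode(ControlMsg(MsgType.TOIL, FakeSchedulerOutput(step=7))))
assert msg.message_type is MsgType.TOIL and msg.payload.step == 7

# ... while a termination order does not.
msg = decoder.decode(encoder.encode(ControlMsg(MsgType.TERMINATE)))
assert msg.message_type is MsgType.TERMINATE and msg.payload is None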