diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c1b10b3cf8f58..990548c247822 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -414,6 +414,7 @@ def main(args: argparse.Namespace): for request in requests) total_output_tokens = sum(request.expected_output_len for request in requests) + total_input_tokens = total_num_tokens - total_output_tokens if is_multi_modal: print("\033[91mWARNING\033[0m: Multi-modal request detected. The " "following metrics are not accurate because image tokens are not" @@ -421,7 +422,9 @@ def main(args: argparse.Namespace): # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length. print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + f"{total_output_tokens / elapsed_time:.2f} output tokens/s, " + f"{total_input_tokens / len(requests)} input tokens/req, " + f"{(total_output_tokens) / len(requests)} output tokens/req, ") # Output JSON results if specified if args.output_json: diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index a61ec63a365b5..5c9bfa02a5b0f 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -8,7 +8,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine import EngineRequest from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core import EngineCore @@ -22,8 +22,8 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids -def make_request() -> EngineCoreRequest: - return EngineCoreRequest( +def make_request() -> EngineRequest: + return EngineRequest( request_id=uuid.uuid4(), prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 2f1cbec607a91..20db30e8b1223 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -10,7 +10,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine import EngineRequest from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import EngineCoreClient @@ -24,8 +24,8 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids -def make_request(params: SamplingParams) -> EngineCoreRequest: - return EngineCoreRequest( +def make_request(params: SamplingParams) -> EngineRequest: + return EngineRequest( request_id=str(uuid.uuid4()), prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 95da1c6e7b9bf..daefbff7e5178 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -21,7 +21,7 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -119,6 +119,8 @@ async def run_server(args: Namespace, logger.info("vLLM API server version %s", VLLM_VERSION) 
logger.info("args: %s", args) + set_ulimit() + app = await init_app(args, llm_engine) assert engine is not None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3e50613a73dd3..f92fa8235835d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -68,7 +68,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, - is_valid_ipv6_address) + is_valid_ipv6_address, kill_process_tree, set_ulimit) from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -585,12 +585,18 @@ async def authentication(request: Request, call_next): status_code=401) return await call_next(request) - @app.middleware("http") - async def add_request_id(request: Request, call_next): - request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex - response = await call_next(request) - response.headers["X-Request-Id"] = request_id - return response + if args.enable_request_id_headers: + logger.warning( + "CAUTION: Enabling X-Request-Id headers in the API Server. " + "This can harm performance at high QPS.") + + @app.middleware("http") + async def add_request_id(request: Request, call_next): + request_id = request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + response = await call_next(request) + response.headers["X-Request-Id"] = request_id + return response for middleware in args.middleware: module_path, object_name = middleware.rsplit(".", 1) @@ -721,12 +727,22 @@ async def run_server(args, **uvicorn_kwargs) -> None: sock_addr = (args.host or "", args.port) sock = create_server_socket(sock_addr) + # workaround to ensure user has enough fds available for uvicorn + ipc + set_ulimit() + def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing raise KeyboardInterrupt("terminated") signal.signal(signal.SIGTERM, signal_handler) + # The child processes will send SIGQUIT to this process when + # any error happens. This process then clean up the whole tree. + def sigquit_handler(signum, frame): + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + async with build_async_engine_client(args) as engine_client: app = build_app(args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 24c206a1261f2..908f8c3532c9e 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -196,7 +196,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action="store_true", help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") - + parser.add_argument( + "--enable-request-id-headers", + action="store_true", + help="If specified, API server will add X-Request-Id header to " + "responses. 
Caution: this hurts performance at high QPS.") parser.add_argument( "--enable-auto-tool-choice", action="store_true", diff --git a/vllm/utils.py b/vllm/utils.py index 49e532540d7ee..55c2b5191021d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -10,6 +10,7 @@ import inspect import ipaddress import os +import resource import signal import socket import subprocess @@ -17,6 +18,7 @@ import tempfile import threading import time +import traceback import uuid import warnings import weakref @@ -1613,6 +1615,28 @@ def resolve_obj_by_qualname(qualname: str) -> Any: return getattr(module, obj_name) +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, + (target_soft_limit, current_hard)) + except ValueError as e: + logger.warning( + "Found ulimit of %s and failed to automatically increase" + "with error %s. This can cause fd limit errors like" + "`OSError: [Errno 24] Too many open files`. Consider " + "increasing with ulimit -n", current_soft, e) + + +def get_exception_traceback(): + etype, value, tb = sys.exc_info() + err_str = "".join(traceback.format_exception(etype, value, tb)) + return err_str + + def kill_process_tree(pid: int): """ Kills all descendant processes of the given pid by sending SIGKILL. diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index cc0c7ea23469a..0e104118c9ff9 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -6,35 +6,14 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sampling_params import SamplingParams @dataclass -class DetokenizerRequest: - +class EngineRequest: request_id: str prompt: Optional[str] prompt_token_ids: List[int] - skip_special_tokens: bool - spaces_between_special_tokens: bool - output_kind: RequestOutputKind - - stop: List[str] - include_stop_str_in_output: bool - - -@dataclass -class EngineCoreRequest: - - # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, - # but this object is currently not playing well with msgspec - # due to circular imports and typing we have in data.py - - request_id: str - #NOTE(Nick): I don't think we need to pass prompt here since it should - # always be tokenized? - prompt: Optional[str] - prompt_token_ids: List[int] mm_inputs: Optional[List[Optional[MultiModalKwargs]]] mm_hashes: Optional[List[str]] mm_placeholders: Optional[MultiModalPlaceholderDict] @@ -44,6 +23,20 @@ class EngineCoreRequest: lora_request: Optional[LoRARequest] +@dataclass +class EngineAbortRequest: + request_ids: List[str] + + +@dataclass +class EngineProfileRequest: + is_start: bool + + +EngineRequestUnion = Union[EngineRequest, EngineAbortRequest, + EngineProfileRequest] + + class EngineCoreOutput( msgspec.Struct, array_like=True, # type: ignore[call-arg] @@ -70,19 +63,10 @@ class EngineCoreOutputs( outputs: List[EngineCoreOutput] -@dataclass -class EngineCoreProfile: - is_start: bool - - -class EngineCoreRequestType(enum.Enum): +class EngineRequestType(enum.Enum): + """ + Request types defined as hex byte strings, so it can be sent over sockets + without separate encoding step. """ - Request types defined as hex byte strings, so it can be sent over sockets - without separate encoding step. 
- """ - ADD = b'\x00' - ABORT = b'\x01' - PROFILE = b'\x02' - - -EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]] + FROM_ENGINE_CORE = b'\x00' + FROM_ENGINE = b'\x01' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index cfdbea8004c35..d82f278e36744 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,5 +1,27 @@ +# Copyright 2033-2024 The vLLM team. +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Inspired by https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/tokenizer_manager.py + import asyncio -from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union +import pickle +from typing import (Any, AsyncGenerator, Dict, List, Mapping, Optional, Type, + Union) + +import zmq +import zmq.asyncio from vllm.config import ModelConfig, VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -9,18 +31,20 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine.async_stream import AsyncStream -from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.engine.detokenizer import Detokenizer +from vllm.utils import get_open_zmq_ipc_path +from vllm.v1.engine import EngineAbortRequest, EngineRequestType +from vllm.v1.engine.core import MPEngineCoreClient +from vllm.v1.engine.detokenizer import MPDetokenizerClient from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import make_zmq_socket logger = init_logger(__name__) @@ -46,6 +70,9 @@ def __init__( self.stat_loggers = stat_loggers self.model_config = vllm_config.model_config + # RequestId -> OutputQueue. + self.rid_to_queue: Dict[str, asyncio.Queue[RequestOutput]] = {} + # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, @@ -54,12 +81,7 @@ def __init__( lora_config=vllm_config.lora_config) self.tokenizer.ping() - # Request streams (map of request_id -> AsyncStream). - self.request_streams: Dict[str, AsyncStream] = {} - # List of cancelled request ids to be aborted. - self.client_aborted_requests: List[str] = [] - - # Processor (converts Inputs --> EngineCoreRequests). + # Processor (in process). 
self.processor = Processor( model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, @@ -68,21 +90,36 @@ def __init__( input_registry=input_registry, ) - # Detokenizer (converts EngineCoreOutputs --> RequestOutput). - self.detokenizer = Detokenizer( + # IPC paths. + to_detokenizer_path = get_open_zmq_ipc_path() + to_engine_core_path = get_open_zmq_ipc_path() + to_llm_engine_path = get_open_zmq_ipc_path() + + # Detokenizer IPC. + self.ctx = zmq.asyncio.Context(io_threads=2) + self.from_detokenizer = make_zmq_socket(self.ctx, to_llm_engine_path, + zmq.constants.PULL) + self.to_detokenizer = make_zmq_socket(self.ctx, to_detokenizer_path, + zmq.constants.PUSH) + + # Detokenizer (background process). + self.detokenizer_client = MPDetokenizerClient( + output_path=to_llm_engine_path, + input_path=to_detokenizer_path, + to_engine_core_path=to_engine_core_path, tokenizer_name=vllm_config.model_config.tokenizer, tokenizer_mode=vllm_config.model_config.tokenizer_mode, trust_remote_code=vllm_config.model_config.trust_remote_code, revision=vllm_config.model_config.tokenizer_revision, ) - # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_client( + # EngineCore (background process). + self.engine_core_client = MPEngineCoreClient( + input_path=to_engine_core_path, + output_path=to_detokenizer_path, vllm_config=vllm_config, executor_class=executor_class, usage_context=usage_context, - multiprocess_mode=True, - asyncio_mode=True, ) self.output_handler: Optional[asyncio.Task] = None @@ -123,11 +160,17 @@ def from_engine_args( def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" - if engine_core := getattr(self, "engine_core", None): - engine_core.shutdown() + if ctx := getattr(self, "ctx", None): + ctx.destroy(linger=0) + + if output_handler := getattr(self, "output_hander", None): + output_handler.cancel() - if handler := getattr(self, "output_handler", None): - handler.cancel() + if engine_core_client := getattr(self, "engine_core_client", None): + engine_core_client.shutdown() + + if detokenizer_client := getattr(self, "detokenizer_client", None): + detokenizer_client.shutdown() @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: @@ -153,28 +196,23 @@ async def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + ) -> asyncio.Queue[RequestOutput]: """Add new request to the AsyncLLM.""" - if self.detokenizer.is_request_active(request_id): - raise ValueError(f"Request {request_id} already exists.") - - # 1) Create a new AsyncStream for the request. - stream = self._add_request_to_streams(request_id) - - # 2) Convert input --> DetokenizerRequest / EngineCoreRequest. - detokenizer_req, engine_core_req = self.processor.process_inputs( + # 1) Convert Input --> EngineRequest (Tokenize, MM, etc). + engine_request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, trace_headers, prompt_adapter_request, priority) - # 3) Add the request to Detokenizer (this process). - self.detokenizer.add_request(detokenizer_req) + # 2) Create Queue (output_handler() pushes, generate() pulls). + self.rid_to_queue[request_id] = asyncio.Queue() - # 4) Add the EngineCoreRequest to EngineCore (separate process). 
- await self.engine_core.add_request_async(engine_core_req) + # 3) Send to Detokenizer (which forwards to EngineCore). + # Note: we forward the request rather than sending to each + # process separately to avoid race conditions in Detokenizer). + await self._send_to_detokenizer(engine_request) - # 5) Return the generator. - return stream.generator() + return self.rid_to_queue[request_id] # TODO: we should support multiple prompts in one call, as you # can do with LLM.generate. So that for multi-prompt completion @@ -193,27 +231,27 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: """ Main function called by the API server to kick off a request - * 1) Making an AsyncStream corresponding to the Request. - # 2) Processing the Input. - * 3) Adding the Request to the Detokenizer. - * 4) Adding the Request to the EngineCore (separate process). + * 1) Make an output queue for the Request. + * 2) Processing the Input (e.g. Tokenizer, MM). + * 3) Adding the Request to Detokenizer + EngineCore. - A separate output_handler loop runs in a background AsyncIO task, - pulling outputs from EngineCore and putting them into the - per-request AsyncStream. + The output_handler() loop runs in a background task, pulling + from Detokenizer and pushing to the per request queue. - The caller of generate() iterates the returned AsyncGenerator, - returning the RequestOutput back to the caller. + The generate() pulls from the per request queue and yields + to the caller which iterates the AsyncGenerator. """ - # We start the output_handler on the first call to generate() so that - # we can call __init__ before the event loop starts, which enables us - # to handle startup failure gracefully in the OpenAI server. - if self.output_handler is None: - self.output_handler = asyncio.create_task( - self._run_output_handler()) - - async for output in await self.add_request( + try: + # Start output_handler on first request. + if not self.output_handler: + loop = asyncio.get_event_loop() + self.output_handler = loop.create_task( + self.output_handler_loop()) + + # Add to Detokenizer and EngineCore and makes queue + # to which the output_handler will push RequestOutputs. + q = await self.add_request( request_id, prompt, sampling_params, @@ -221,109 +259,65 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, - ): - yield output - - def _finish_stream(self, request_id: str): - stream = self.request_streams.pop(request_id, None) - if stream is not None: - stream.finish() - - def _add_request_to_streams( - self, - request_id: str, - ) -> AsyncStream: - - if request_id in self.request_streams: - raise ValueError(f"Request id {request_id} already running.") - - # Avoid streams having circular ref to parent AsyncLLM object. - aborted_reqs = self.client_aborted_requests - stream = AsyncStream(request_id, aborted_reqs.append) - self.request_streams[request_id] = stream - - if self.log_requests: - logger.info("Added request %s.", request_id) - - return stream - - async def _process_cancellations(self) -> None: - """ - Process requests cancelled from user disconnecting. - - When a client disconnects, AsyncStream._cancel() is called. - We passed a callback to AsyncStream(), which appends to - self.client_aborted_requests. - - As a result, if any requests are canceled from the user side - the request_id will show up in self.client_aborted_requests. - """ - - # Avoid streams having circular ref to parent AsyncLLM object. 
- if not self.client_aborted_requests: - return - reqs_to_abort = self.client_aborted_requests.copy() - self.client_aborted_requests.clear() - - # Remove from Detokenizer. - self.detokenizer.abort_requests(reqs_to_abort) - - # Remove from RequestStreams. - for request_id in reqs_to_abort: - if self.log_requests: - logger.info("User-cancelled request %s.", request_id) - self._finish_stream(request_id) - - # Remove from EngineCore. - await self.engine_core.abort_requests_async(reqs_to_abort) - - def _process_request_outputs(self, request_outputs: List[RequestOutput]): - """Process outputs by putting them into per-request AsyncStreams.""" - - for request_output in request_outputs: - request_id = request_output.request_id - assert request_id in self.request_streams + ) - # Each request in the API server pulls from the per-request stream. - stream = self.request_streams.get(request_id) - if stream is not None: - stream.put(request_output) - - # If finished, remove from the tracker. - if request_output.finished: - if self.log_requests: - logger.info("Finished request %s.", request_id) - self._finish_stream(request_id) - - async def _run_output_handler(self): - """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" - - try: + # The output_handler task pushes items into the queue. + # This task pulls from the queue and yields to caller. while True: - # 1) Pull EngineCoreOutput from the EngineCore. - outputs = await self.engine_core.get_output_async() - - # 2) Detokenize based on the output. - request_outputs, reqs_to_abort = self.detokenizer.step(outputs) + # Note: drain queue without await if possible (avoids + # task switching under load which helps performance). + out = q.get_nowait() if q.qsize() > 0 else await q.get() + + # Note: both Detokenizer and EngineCore handle their + # own request cleanup based on finished. + if out.finished: + del self.rid_to_queue[request_id] + yield out + break + + yield out + + # Client request cancellation is handled through calling + # task.cancel() on generate(). Calling self.abort() forwards the + # cancellation to the EngineCore and Detokenizer. + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def output_handler_loop(self): + """Background loop: pulls from Detokenizer and push to Queues.""" + + while True: + # Note: use socket directly to avoid calling await multiple + # times, which causes too much task switching at high QPS. + outputs: List[RequestOutput] = [] + outputs = await self.from_detokenizer.recv_pyobj() + + for out in outputs: + # Note: it is possible that a request was aborted + # due to client cancellation while EngineCoreOutputs + # are still flowing, so we just ignore. + if out.request_id in self.rid_to_queue: + self.rid_to_queue[out.request_id].put_nowait(out) + + async def abort(self, request_id: str): + """Abort request if the client cancels the request.""" + + # Send abort to Detokenizer (which will fwd to EngineCore). + await self._send_to_detokenizer(EngineAbortRequest([request_id])) + + # Remove from request output queues. + if request_id in self.rid_to_queue: + del self.rid_to_queue[request_id] - # 3) Put the RequestOutputs into the per-request AsyncStreams. - self._process_request_outputs(request_outputs) - - # 4) Abort any requests that finished due to stop strings. - await self.engine_core.abort_requests_async(reqs_to_abort) - - # 5) Abort any requests due to client cancellations. 
- await self._process_cancellations() - - except BaseException as e: - logger.error(e) - raise e + if self.log_requests: + logger.info("Aborted %s.", request_id) - # TODO: can we eliminate these? + async def _send_to_detokenizer(self, obj: Any): + """Send object to Detokenizer with a FROM_ENGINE flag.""" - async def abort(self, request_id: str) -> None: - # Note: Who Calls this? I dont think this is actually used. - raise ValueError("Not Supported on V1 yet.") + msg = (EngineRequestType.FROM_ENGINE.value, pickle.dumps(obj)) + await self.to_detokenizer.send_multipart(msg, copy=False) def encode( self, @@ -349,8 +343,7 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - assert lora_request is None - return self.detokenizer.tokenizer + return self.tokenizer.get_lora_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: return False @@ -366,10 +359,10 @@ async def check_health(self) -> None: logger.debug("Called check_health.") async def start_profile(self) -> None: - await self.engine_core.profile_async(True) + await self.engine_core_client.profile_async(True) async def stop_profile(self) -> None: - await self.engine_core.profile_async(False) + await self.engine_core_client.profile_async(False) @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py deleted file mode 100644 index 35449238c3259..0000000000000 --- a/vllm/v1/engine/async_stream.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from typing import Any, AsyncGenerator, Callable, Optional, Type, Union - -from vllm.outputs import PoolingRequestOutput, RequestOutput - - -class AsyncStream: - """A stream of RequestOutputs or PoolingRequestOutputs for a request - that can be iterated over asynchronously via an async generator.""" - - STOP_ITERATION = Exception() # Sentinel - - def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: - self.request_id = request_id - self._cancel = cancel - self._queue: asyncio.Queue = asyncio.Queue() - self._finished = False - - def put(self, item: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: - if not self._finished: - self._queue.put_nowait(item) - - def finish( - self, - exception: Optional[Union[BaseException, Type[BaseException]]] = None, - ) -> None: - if not self._finished: - self._finished = True - self._queue.put_nowait(exception if self._is_raisable(exception) - else AsyncStream.STOP_ITERATION) - - async def generator( - self - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: - finished = False - try: - while True: - result = await self._queue.get() - if self._is_raisable(result): - finished = True - if result == AsyncStream.STOP_ITERATION: - return - raise result - yield result - finally: - self._finished = True - if not finished: - self._cancel(self.request_id) - - @staticmethod - def _is_raisable(value: Any): - return isinstance(value, BaseException) or \ - (isinstance(value, type) and \ - issubclass(value, BaseException)) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 497d5db5b4c99..e4c587f1d7eb1 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,38 +1,37 @@ -import pickle import queue import signal import threading import time -from dataclasses import dataclass -from multiprocessing.process import BaseProcess +from multiprocessing.connection import Connection from typing import List, Tuple, Type +import psutil import zmq import zmq.asyncio from msgspec import msgpack 
from vllm.config import CacheConfig, VllmConfig -from vllm.executor.multiproc_worker_utils import get_mp_context from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.usage.usage_lib import UsageContext +from vllm.utils import get_exception_traceback from vllm.v1.core.scheduler import Scheduler -from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, - EngineCoreProfile, EngineCoreRequest, - EngineCoreRequestType, EngineCoreRequestUnion) +from vllm.v1.engine import (EngineAbortRequest, EngineCoreOutput, + EngineCoreOutputs, EngineProfileRequest, + EngineRequest, EngineRequestType, + EngineRequestUnion) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import PickleEncoder -from vllm.v1.utils import make_zmq_socket +from vllm.v1.utils import MPBackgroundProcess, zmq_socket_ctx from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -LOGGING_TIME_S = 5000 +LOGGING_TIME_S = 5 class EngineCore: @@ -89,7 +88,7 @@ def _initialize_kv_caches(self, "warmup model) took %.2f seconds"), elapsed) return num_gpu_blocks, num_cpu_blocks - def add_request(self, request: EngineCoreRequest): + def add_request(self, request: EngineRequest): """Add request to the scheduler.""" if request.mm_hashes is not None: @@ -103,7 +102,6 @@ def add_request(self, request: EngineCoreRequest): request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) - self.scheduler.add_request(req) def abort_requests(self, request_ids: List[str]): @@ -134,14 +132,6 @@ def profile(self, is_start: bool = True): self.model_executor.profile(is_start) -@dataclass -class EngineCoreProcHandle: - proc: BaseProcess - ready_path: str - input_path: str - output_path: str - - class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -154,16 +144,15 @@ def __init__( usage_context: UsageContext, input_path: str, output_path: str, - ready_path: str, + ready_pipe: Connection, ): super().__init__(vllm_config, executor_class, usage_context) # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue() + # overlap ZMQ IO with GPU since they release the GIL and + # some serialization/deserialization with the model forward. + # Threads handle Socket <-> Queues and busy_loop uses Queues. + self.input_queue: queue.Queue[EngineRequestUnion] = queue.Queue() self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue() threading.Thread(target=self.process_input_socket, args=(input_path, ), @@ -173,68 +162,7 @@ def __init__( daemon=True).start() # Send Readiness signal to EngineClient. 
- with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket: - ready_socket.send_string(EngineCoreProc.READY_STR) - - @staticmethod - def wait_for_startup( - proc: BaseProcess, - ready_path: str, - ) -> None: - """Wait until the EngineCore is ready.""" - - try: - sync_ctx = zmq.Context() # type: ignore[attr-defined] - socket = sync_ctx.socket(zmq.constants.PULL) - socket.connect(ready_path) - - # Wait for EngineCore to send EngineCoreProc.READY_STR. - while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for EngineCoreProc to startup.") - - if not proc.is_alive(): - raise RuntimeError("EngineCoreProc failed to start.") - - message = socket.recv_string() - assert message == EngineCoreProc.READY_STR - - except BaseException as e: - logger.exception(e) - raise e - - finally: - sync_ctx.destroy(linger=0) - - @staticmethod - def make_engine_core_process( - vllm_config: VllmConfig, - executor_class: Type[Executor], - usage_context: UsageContext, - input_path: str, - output_path: str, - ready_path: str, - ) -> EngineCoreProcHandle: - context = get_mp_context() - - process_kwargs = { - "input_path": input_path, - "output_path": output_path, - "ready_path": ready_path, - "vllm_config": vllm_config, - "executor_class": executor_class, - "usage_context": usage_context, - } - # Run EngineCore busy loop in background process. - proc = context.Process(target=EngineCoreProc.run_engine_core, - kwargs=process_kwargs) - proc.start() - - # Wait for startup - EngineCoreProc.wait_for_startup(proc, ready_path) - return EngineCoreProcHandle(proc=proc, - ready_path=ready_path, - input_path=input_path, - output_path=output_path) + ready_pipe.send({"status": "READY"}) @staticmethod def run_engine_core(*args, **kwargs): @@ -258,6 +186,8 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + parent_process = psutil.Process().parent() + engine_core = None try: engine_core = EngineCoreProc(*args, **kwargs) @@ -266,9 +196,10 @@ def signal_handler(signum, frame): except SystemExit: logger.debug("EngineCore interrupted.") - except BaseException as e: - logger.exception(e) - raise e + except Exception: + traceback = get_exception_traceback() + logger.error("EngineCore hit an exception: %s", traceback) + parent_process.send_signal(signal.SIGQUIT) finally: if engine_core is not None: @@ -302,7 +233,8 @@ def run_busy_loop(self): outputs = self.step() # 4) Put EngineCoreOutputs into the output queue. 
- self.output_queue.put_nowait(outputs) + if len(outputs) > 0: + self.output_queue.put_nowait(outputs) self._log_stats() @@ -320,43 +252,25 @@ def _log_stats(self): self._last_logging_time = now - def _handle_client_request(self, request: EngineCoreRequestUnion) -> None: - """Handle EngineCoreRequest or EngineCoreABORT from Client.""" + def _handle_client_request(self, request: EngineRequestUnion) -> None: + """Handle EngineRequest or EngineCoreABORT from Client.""" - if isinstance(request, EngineCoreRequest): + if isinstance(request, EngineRequest): self.add_request(request) - elif isinstance(request, EngineCoreProfile): + elif isinstance(request, EngineProfileRequest): self.model_executor.profile(request.is_start) + elif isinstance(request, EngineAbortRequest): + self.abort_requests(request.request_ids) else: - # TODO: make an EngineCoreAbort wrapper - assert isinstance(request, list) - self.abort_requests(request) + raise ValueError("Unknown request type: {request}") def process_input_socket(self, input_path: str): """Input socket IO thread.""" - # Msgpack serialization decoding. - decoder_add_req = PickleEncoder() - decoder_abort_req = PickleEncoder() - - with make_zmq_socket(input_path, zmq.constants.PULL) as socket: + with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: while True: - # (RequestType, RequestData) - type_frame, data_frame = socket.recv_multipart(copy=False) - request_type = type_frame.buffer - request_data = data_frame.buffer - - # Deserialize the request data. - if request_type == EngineCoreRequestType.ADD.value: - request = decoder_add_req.decode(request_data) - elif request_type == EngineCoreRequestType.ABORT.value: - request = decoder_abort_req.decode(request_data) - elif request_type == EngineCoreRequestType.PROFILE.value: - request = pickle.loads(request_data) - else: - raise ValueError(f"Unknown RequestType: {request_type}") - # Push to input queue for core busy loop. + request = socket.recv_pyobj() self.input_queue.put_nowait(request) def process_output_socket(self, output_path: str): @@ -367,9 +281,36 @@ def process_output_socket(self, output_path: str): # Reuse send buffer. buffer = bytearray() - with make_zmq_socket(output_path, zmq.constants.PUSH) as socket: + with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: while True: engine_core_outputs = self.output_queue.get() outputs = EngineCoreOutputs(outputs=engine_core_outputs) encoder.encode_into(outputs, buffer) - socket.send_multipart((buffer, ), copy=False) + msg = (EngineRequestType.FROM_ENGINE_CORE.value, buffer) + socket.send_multipart(msg, copy=False) + + +class MPEngineCoreClient(MPBackgroundProcess): + """Client for multi-proc EngineCore.""" + + def __init__(self, input_path: str, output_path: str, + vllm_config: VllmConfig, executor_class: Type[Executor], + usage_context: UsageContext): + + super().__init__() + + self.proc_handle = MPBackgroundProcess.wait_for_startup( + input_path=input_path, + output_path=output_path, + process_name="EngineCore", + target_fn=EngineCoreProc.run_engine_core, + process_kwargs={ + "vllm_config": vllm_config, + "executor_class": executor_class, + "usage_context": usage_context, + }, + ) + + async def profile_async(self, is_start: bool = True): + # TODO: enable this. 
+ pass diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py deleted file mode 100644 index d56fcbdb1e7c4..0000000000000 --- a/vllm/v1/engine/core_client.py +++ /dev/null @@ -1,253 +0,0 @@ -import os -import weakref -from typing import List, Optional - -import msgspec -import zmq -import zmq.asyncio - -from vllm.logger import init_logger -from vllm.utils import get_open_zmq_ipc_path, kill_process_tree -from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, - EngineCoreProfile, EngineCoreRequest, - EngineCoreRequestType, EngineCoreRequestUnion) -from vllm.v1.engine.core import (EngineCore, EngineCoreProc, - EngineCoreProcHandle) -from vllm.v1.serial_utils import PickleEncoder - -logger = init_logger(__name__) - - -class EngineCoreClient: - """ - EngineCoreClient: subclasses handle different methods for pushing - and pulling from the EngineCore for asyncio / multiprocessing. - - Subclasses: - * InprocClient: In process EngineCore (for V0-style LLMEngine use) - * SyncMPClient: ZMQ + background proc EngineCore (for LLM) - * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM) - """ - - @staticmethod - def make_client( - *args, - multiprocess_mode: bool, - asyncio_mode: bool, - **kwargs, - ) -> "EngineCoreClient": - - # TODO: support this for debugging purposes. - if asyncio_mode and not multiprocess_mode: - raise NotImplementedError( - "Running EngineCore in asyncio without multiprocessing " - "is not currently supported.") - - if multiprocess_mode and asyncio_mode: - return AsyncMPClient(*args, **kwargs) - - if multiprocess_mode and not asyncio_mode: - return SyncMPClient(*args, **kwargs) - - return InprocClient(*args, **kwargs) - - def shutdown(self): - pass - - def get_output(self) -> List[EngineCoreOutput]: - raise NotImplementedError - - def add_request(self, request: EngineCoreRequest) -> None: - raise NotImplementedError - - def profile(self, is_start: bool = True) -> None: - raise NotImplementedError - - def abort_requests(self, request_ids: List[str]) -> None: - raise NotImplementedError - - async def get_output_async(self) -> List[EngineCoreOutput]: - raise NotImplementedError - - async def add_request_async(self, request: EngineCoreRequest) -> None: - raise NotImplementedError - - async def profile_async(self, is_start: bool = True) -> None: - raise NotImplementedError - - async def abort_requests_async(self, request_ids: List[str]) -> None: - raise NotImplementedError - - -class InprocClient(EngineCoreClient): - """ - InprocClient: client for in-process EngineCore. Intended - for use in LLMEngine for V0-style add_request() and step() - EngineCore setup in this process (no busy loop). - - * pushes EngineCoreRequest directly into the EngineCore - * pulls EngineCoreOutputs by stepping the EngineCore - - TODO: support asyncio-mode for debugging. - """ - - def __init__(self, *args, **kwargs): - self.engine_core = EngineCore(*args, **kwargs) - - def get_output(self) -> List[EngineCoreOutput]: - return self.engine_core.step() - - def add_request(self, request: EngineCoreRequest) -> None: - self.engine_core.add_request(request) - - def abort_requests(self, request_ids: List[str]) -> None: - self.engine_core.abort_requests(request_ids) - - def shutdown(self): - self.engine_core.shutdown() - - def __del__(self): - self.shutdown() - - def profile(self, is_start: bool = True) -> None: - self.engine_core.profile(is_start) - - -class MPClient(EngineCoreClient): - """ - MPClient: base client for multi-proc EngineCore. 
- EngineCore runs in a background process busy loop, getting - new EngineCoreRequests and returning EngineCoreOutputs - - * pushes EngineCoreRequests via input_socket - * pulls EngineCoreOutputs via output_socket - - * AsyncMPClient subclass for AsyncLLM usage - * SyncMPClient subclass for LLM usage - """ - - def __init__( - self, - *args, - asyncio_mode: bool, - **kwargs, - ): - # Serialization setup. - self.encoder = PickleEncoder() - self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) - - # ZMQ setup. - if asyncio_mode: - self.ctx = zmq.asyncio.Context() - else: - self.ctx = zmq.Context() # type: ignore[attr-defined] - - # Path for IPC. - ready_path = get_open_zmq_ipc_path() - output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - - # Get output (EngineCoreOutput) from EngineCore. - self.output_socket = self.ctx.socket(zmq.constants.PULL) - self.output_socket.connect(output_path) - - # Send input (EngineCoreRequest) to EngineCore. - self.input_socket = self.ctx.socket(zmq.constants.PUSH) - self.input_socket.bind(input_path) - - # Start EngineCore in background process. - self.proc_handle: Optional[EngineCoreProcHandle] - self.proc_handle = EngineCoreProc.make_engine_core_process( - *args, - input_path= - input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords - output_path=output_path, # type: ignore[misc] - ready_path=ready_path, # type: ignore[misc] - **kwargs, - ) - self._finalizer = weakref.finalize(self, self.shutdown) - - def shutdown(self): - # Shut down the zmq context. - self.ctx.destroy(linger=0) - - if hasattr(self, "proc_handle") and self.proc_handle: - # Shutdown the process if needed. - if self.proc_handle.proc.is_alive(): - self.proc_handle.proc.terminate() - self.proc_handle.proc.join(5) - - if self.proc_handle.proc.is_alive(): - kill_process_tree(self.proc_handle.proc.pid) - - # Remove zmq ipc socket files - ipc_sockets = [ - self.proc_handle.ready_path, self.proc_handle.output_path, - self.proc_handle.input_path - ] - for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) - self.proc_handle = None - - -class SyncMPClient(MPClient): - """Synchronous client for multi-proc EngineCore.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, asyncio_mode=False, **kwargs) - - def get_output(self) -> List[EngineCoreOutput]: - - (frame, ) = self.output_socket.recv_multipart(copy=False) - engine_core_outputs = self.decoder.decode(frame.buffer).outputs - return engine_core_outputs - - def _send_input(self, request_type: EngineCoreRequestType, - request: EngineCoreRequestUnion) -> None: - - # (RequestType, SerializedRequest) - msg = (request_type.value, self.encoder.encode(request)) - self.input_socket.send_multipart(msg, copy=False) - - def add_request(self, request: EngineCoreRequest) -> None: - self._send_input(EngineCoreRequestType.ADD, request) - - def abort_requests(self, request_ids: List[str]) -> None: - self._send_input(EngineCoreRequestType.ABORT, request_ids) - - def profile(self, is_start: bool = True) -> None: - self._send_input(EngineCoreRequestType.PROFILE, - EngineCoreProfile(is_start)) - - -class AsyncMPClient(MPClient): - """Asyncio-compatible client for multi-proc EngineCore.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, asyncio_mode=True, **kwargs) - - async def get_output_async(self) -> List[EngineCoreOutput]: - - frames = await self.output_socket.recv_multipart(copy=False) - 
engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs - - return engine_core_outputs - - async def _send_input(self, request_type: EngineCoreRequestType, - request: EngineCoreRequestUnion) -> None: - - msg = (request_type.value, self.encoder.encode(request)) - await self.input_socket.send_multipart(msg, copy=False) - - async def add_request_async(self, request: EngineCoreRequest) -> None: - await self._send_input(EngineCoreRequestType.ADD, request) - - async def abort_requests_async(self, request_ids: List[str]) -> None: - if len(request_ids) > 0: - await self._send_input(EngineCoreRequestType.ABORT, request_ids) - - async def profile_async(self, is_start: bool = True) -> None: - await self._send_input(EngineCoreRequestType.PROFILE, - EngineCoreProfile(is_start)) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 02f34e2b54dd5..2d8724e687448 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,6 +1,13 @@ +import pickle +import signal from dataclasses import dataclass +from multiprocessing.connection import Connection from typing import Dict, Iterable, List, Optional, Tuple, Union +import msgspec +import psutil +import zmq + from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.outputs import RequestOutput @@ -8,10 +15,16 @@ from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput +from vllm.utils import get_exception_traceback +from vllm.v1.engine import (EngineAbortRequest, EngineCoreOutput, + EngineCoreOutputs, EngineRequest, + EngineRequestType) +from vllm.v1.utils import MPBackgroundProcess, make_zmq_socket logger = init_logger(__name__) +POLLING_TIMEOUT_MS = 5000 + @dataclass class IncrementalDetokenizer: @@ -55,19 +68,20 @@ def output_token_ids(self) -> List[int]: def from_new_request( cls, tokenizer: AnyTokenizer, - request: DetokenizerRequest, + request: EngineRequest, ) -> "IncrementalDetokenizer": + sampling_params = request.sampling_params tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( tokenizer=tokenizer, prompt_ids=request.prompt_token_ids, - skip_special_tokens=request.skip_special_tokens, + skip_special_tokens=sampling_params.skip_special_tokens, ) - stops = request.stop + stops = request.sampling_params.stop # Number of chars to hold back when stop strings are to be excluded # from streamed output. - if stops and not request.include_stop_str_in_output: + if stops and not sampling_params.include_stop_str_in_output: stop_buffer_length = max(len(s) for s in stops) - 1 else: stop_buffer_length = 0 @@ -79,13 +93,14 @@ def from_new_request( # NOTE(Nick): could we take ownership of it though? token_ids=request.prompt_token_ids.copy(), stop=stops, - include_stop_str_in_output=request.include_stop_str_in_output, + include_stop_str_in_output=sampling_params. + include_stop_str_in_output, prefix_offset=prefix_offset, read_offset=read_offset, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=request. + skip_special_tokens=sampling_params.skip_special_tokens, + spaces_between_special_tokens=sampling_params. 
spaces_between_special_tokens, - output_kind=request.output_kind, + output_kind=sampling_params.output_kind, request_id=request.request_id, prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, @@ -145,8 +160,6 @@ def add_tokens( finish_reason = "stop" # TODO: use constant stop_reason = stop_str - # TODO: handle stop_token_ids here too? - # 3) Update the RequestOutput object with the new text. finished = bool(finish_reason) if self.output_kind == RequestOutputKind.FINAL_ONLY \ @@ -227,7 +240,7 @@ def abort_requests( def add_request( self, - request: DetokenizerRequest, + request: EngineRequest, ): """Add new request to the Detokenizer.""" @@ -238,9 +251,10 @@ def add_request( self.request_states[request.request_id] = request_state def step( - self, encore_core_outputs: List[EngineCoreOutput] + self, + encore_core_outputs: List[EngineCoreOutput], ) -> Tuple[List[RequestOutput], List[str]]: - """Update state and request the RequestOutputs to the LLMEngine.""" + """Update state and make RequestOutputs for the LLMEngine.""" request_outputs: List[RequestOutput] = [] requests_to_abort: List[str] = [] @@ -265,8 +279,179 @@ def step( # Free completed requests. if request_output.finished: self.request_states.pop(request_id) + # If Request finished but EngineCore not finished, + # this was caused by a stop string + we need to send + # an abort signal to the EngineCore. if not engine_core_output.finished: requests_to_abort.append(request_id) # Return to EngineClient. return request_outputs, requests_to_abort + + +class DetokenizerProc(Detokenizer): + """ZMQ-wrapper for running Detokenizer in background process.""" + + def __init__( + self, + *args, + input_path: str, + output_path: str, + to_engine_core_path: str, + ready_pipe: Connection, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.input_path = input_path + self.output_path = output_path + self.to_engine_core_path = to_engine_core_path + + # Send Readiness signal to DetokenizerClient. + ready_pipe.send({"status": "READY"}) + + @staticmethod + def run_detokenizer(*args, **kwargs): + """Launch Detokenizer busy loop in background process.""" + + # Signal handler used for graceful termination. + # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the engine_core + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + parent_process = psutil.Process().parent() + + detokenizer = None + try: + detokenizer = DetokenizerProc(*args, **kwargs) + detokenizer.run_busy_loop() + + except SystemExit: + logger.debug("Detokenizer interrupted.") + + except Exception: + traceback = get_exception_traceback() + logger.error("Detokenizer hit an exception: %s", traceback) + parent_process.send_signal(signal.SIGQUIT) + + finally: + if detokenizer is not None: + detokenizer = None + + def _handle_from_llm_engine( + self, + request_bytes: bytes, + to_engine_core: zmq.Socket, # type: ignore[name-defined] + ) -> None: + """Handle EngineRequest from the LLMEngine.""" + + req = pickle.loads(request_bytes) + + if isinstance(req, EngineRequest): + self.add_request(req) + elif isinstance(req, EngineAbortRequest): + self.abort_requests(req.request_ids) + else: + raise ValueError(f"Unknown type: {req}") + + # Forward to EngineCore. 
+ to_engine_core.send(request_bytes) + + def _handle_from_engine_core( + self, + output_bytes: bytes, + to_engine_core: zmq.Socket, # type: ignore[name-defined] + to_llm_engine: zmq.Socket, # type: ignore[name-defined] + decoder: msgspec.msgpack.Decoder, + ) -> None: + """Handle Outputs from the EngineCore.""" + + # Deserialize the EngineOutput (use msgpack for performance). + outputs: List[EngineCoreOutput] = decoder.decode(output_bytes).outputs + + # Detokenize. + request_outputs, requests_to_abort = self.step(outputs) + + # Send request outputs back to LLMEngine. + to_llm_engine.send_pyobj(request_outputs) + + # Abort requests that finished due to stop strings. + if len(requests_to_abort) > 0: + to_engine_core.send_pyobj(EngineAbortRequest(requests_to_abort)) + + def run_busy_loop(self): + """Core busy loop of the Detokenizer.""" + + decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + + ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined] + try: + input_socket = make_zmq_socket(ctx, self.input_path, + zmq.constants.PULL) + to_llm_engine = make_zmq_socket(ctx, self.output_path, + zmq.constants.PUSH) + to_engine_core = make_zmq_socket(ctx, self.to_engine_core_path, + zmq.constants.PUSH) + + while True: + (msg_type, msg_bytes) = input_socket.recv_multipart() + + # Handle message from LLMEngine (Abort or New Request). + if msg_type == EngineRequestType.FROM_ENGINE.value: + self._handle_from_llm_engine(msg_bytes, to_engine_core) + + # Handle message from EngineCore (EngineCoreOutputs). + elif msg_type == EngineRequestType.FROM_ENGINE_CORE.value: + self._handle_from_engine_core( + output_bytes=msg_bytes, + to_engine_core=to_engine_core, + to_llm_engine=to_llm_engine, + decoder=decoder, + ) + else: + raise ValueError(f"Unknown Message Type: {msg_type}") + + except KeyboardInterrupt: + logger.debug("Got Keyboard Interrupt.") + + finally: + ctx.destroy(linger=0) + + +class MPDetokenizerClient(MPBackgroundProcess): + """Client for multi-proc Detokenizer.""" + + def __init__(self, + input_path: str, + output_path: str, + to_engine_core_path: str, + tokenizer_name: str, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + revision: Optional[str] = None): + + super().__init__() + + self.proc_handle = MPBackgroundProcess.wait_for_startup( + input_path=input_path, + output_path=output_path, + process_name="Detokenizer", + target_fn=DetokenizerProc.run_detokenizer, + process_kwargs={ + "to_engine_core_path": to_engine_core_path, + "tokenizer_name": tokenizer_name, + "tokenizer_mode": tokenizer_mode, + "trust_remote_code": trust_remote_code, + "revision": revision, + }, + ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 9ad51575b3cc3..536fdb28717b4 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,5 +1,7 @@ -from typing import Dict, List, Mapping, Optional, Type, Union +import pickle +from typing import Any, Dict, List, Mapping, Optional, Set, Type, Union +import zmq from typing_extensions import TypeVar from vllm.config import VllmConfig @@ -17,11 +19,14 @@ from vllm.transformers_utils.tokenizer_group import ( BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.engine.detokenizer import Detokenizer +from vllm.utils import get_open_zmq_ipc_path +from vllm.v1.engine import EngineRequestType +from vllm.v1.engine.core import EngineCore, MPEngineCoreClient +from vllm.v1.engine.detokenizer import 
Detokenizer, MPDetokenizerClient from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.executor.ray_utils import initialize_ray_cluster +from vllm.v1.utils import make_zmq_socket logger = init_logger(__name__) @@ -44,7 +49,7 @@ def __init__( multiprocess_mode: bool = False, ) -> None: - # TODO: Can we avoid this? + self.multiprocess_mode = multiprocess_mode self.model_config = vllm_config.model_config # Tokenizer (+ ensure liveness if running in another process). @@ -63,22 +68,59 @@ def __init__( input_registry=input_registry, mm_registry=mm_registry) - # Detokenizer (converts EngineCoreOutputs --> RequestOutput) - self.detokenizer = Detokenizer( - tokenizer_name=vllm_config.model_config.tokenizer, - tokenizer_mode=vllm_config.model_config.tokenizer_mode, - trust_remote_code=vllm_config.model_config.trust_remote_code, - revision=vllm_config.model_config.tokenizer_revision, - ) - - # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) - self.engine_core = EngineCoreClient.make_client( - vllm_config, - executor_class, - usage_context, - multiprocess_mode=multiprocess_mode, - asyncio_mode=False, - ) + if self.multiprocess_mode: + # Keep track of active requests. + self.running_requests: Set[str] = set() + + # IPC paths. + to_detokenizer_path = get_open_zmq_ipc_path() + to_engine_core_path = get_open_zmq_ipc_path() + to_llm_engine_path = get_open_zmq_ipc_path() + + # Detokenizer IPC. + self.ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined] + self.from_detokenizer = make_zmq_socket(self.ctx, + to_llm_engine_path, + zmq.constants.PULL) + self.to_detokenizer = make_zmq_socket(self.ctx, + to_detokenizer_path, + zmq.constants.PUSH) + + # Detokenizer (background process). + self.detokenizer_client = MPDetokenizerClient( + output_path=to_llm_engine_path, + input_path=to_detokenizer_path, + to_engine_core_path=to_engine_core_path, + tokenizer_name=vllm_config.model_config.tokenizer, + tokenizer_mode=vllm_config.model_config.tokenizer_mode, + trust_remote_code=vllm_config.model_config.trust_remote_code, + revision=vllm_config.model_config.tokenizer_revision, + ) + + # EngineCore (background process). + self.engine_core_client = MPEngineCoreClient( + input_path=to_engine_core_path, + output_path=to_detokenizer_path, + vllm_config=vllm_config, + executor_class=executor_class, + usage_context=usage_context, + ) + + else: + # Detokenizer (in process). + self.detokenizer = Detokenizer( + tokenizer_name=vllm_config.model_config.tokenizer, + tokenizer_mode=vllm_config.model_config.tokenizer_mode, + trust_remote_code=vllm_config.model_config.trust_remote_code, + revision=vllm_config.model_config.tokenizer_revision, + ) + + # EngineCore (in process). 
+ self.engine_core = EngineCore( + vllm_config=vllm_config, + executor_class=executor_class, + usage_context=usage_context, + ) @classmethod def from_engine_args( @@ -126,10 +168,13 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: return executor_class def get_num_unfinished_requests(self) -> int: - return self.detokenizer.get_num_unfinished_requests() + if self.multiprocess_mode: + return len(self.running_requests) + else: + return self.detokenizer.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: - return self.detokenizer.has_unfinished_requests() + return self.get_num_unfinished_requests() > 0 @classmethod def validate_outputs(cls, outputs, output_type): @@ -138,6 +183,7 @@ def validate_outputs(cls, outputs, output_type): def abort_request(self, request_ids: List[str]) -> None: """Remove request_ids from EngineCore and Detokenizer.""" + assert not self.multiprocess_mode self.engine_core.abort_requests(request_ids) self.detokenizer.abort_requests(request_ids) @@ -153,33 +199,52 @@ def add_request( priority: int = 0, ) -> None: - # 1) Process raw inputs into the request. - detokenizer_req, engine_core_req = self.processor.process_inputs( + # Process raw inputs into the request. + engine_request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, trace_headers, prompt_adapter_request, priority) - # 2) Add the request to Detokenizer. - self.detokenizer.add_request(detokenizer_req) - - # 3) Add the request to EngineCore. - self.engine_core.add_request(engine_core_req) + # Add processed input to system. + if self.multiprocess_mode: + assert engine_request.request_id not in self.running_requests + self.running_requests.add(engine_request.request_id) + # Send to Detokenizer (which forwards to EngineCore). + # Note: we forward the message rather than sending + # to each process separately to avoid race conditions. + self._send_to_detokenizer(engine_request) + else: + # Add directly to Detokenizer and EngineCore. + self.detokenizer.add_request(engine_request) + self.engine_core.add_request(engine_request) def step(self) -> List[RequestOutput]: + if self.multiprocess_mode: + # Get next output from the Detokenizer. + request_outputs: List[ + RequestOutput] = self.from_detokenizer.recv_pyobj() - # 1) Get EngineCoreOutput from the EngineCore. - engine_core_outputs = self.engine_core.get_output() + # Removed finished requests from the state tracker. + for out in request_outputs: + if out.finished: + self.running_requests.remove(out.request_id) - # 2) Detokenizer the EngineCoreOutput. - request_outputs, requests_to_abort = self.detokenizer.step( - engine_core_outputs) + else: + # Step EngineCore and Detokenizer. + engine_core_outputs = self.engine_core.step() + request_outputs, requests_to_abort = self.detokenizer.step( + engine_core_outputs) - # 3) Abort requests that finished due to stopping criteria. - if requests_to_abort: - self.abort_request(requests_to_abort) + # Abort any requests that hit a stop string. + if requests_to_abort: + self.abort_request(requests_to_abort) return request_outputs - # TODO(rob): Can we get rid of these? 
+    def _send_to_detokenizer(self, obj: Any):
+        """Send object to Detokenizer with a FROM_ENGINE flag."""
+
+        msg = (EngineRequestType.FROM_ENGINE.value, pickle.dumps(obj))
+        self.to_detokenizer.send_multipart(msg, copy=False)

     def get_model_config(self):
         return self.model_config

@@ -210,5 +275,8 @@ def __del__(self):
         self.shutdown()

     def shutdown(self):
-        if engine_core := getattr(self, "engine_core", None):
-            engine_core.shutdown()
+        if engine_core_client := getattr(self, "engine_core_client", None):
+            engine_core_client.shutdown()
+
+        if detokenizer_client := getattr(self, "detokenizer_client", None):
+            detokenizer_client.shutdown()
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 6ee8732bc902c..72d4a1ecf4511 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -1,5 +1,5 @@
 import time
-from typing import Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Union

 from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -13,7 +13,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
+from vllm.v1.engine import EngineRequest
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient

@@ -62,7 +62,7 @@ def process_inputs(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
+    ) -> EngineRequest:

         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Check max_logprobs
@@ -123,20 +123,8 @@ def process_inputs(
             decoder_inputs.multi_modal_data, mm_hashes,
             decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)

-        # Make Request for Detokenizer.
-        detokenizer_request = DetokenizerRequest(
-            request_id,
-            decoder_inputs.prompt,
-            decoder_inputs.prompt_token_ids,
-            sampling_params.skip_special_tokens,
-            sampling_params.spaces_between_special_tokens,
-            sampling_params.output_kind,
-            sampling_params.stop,
-            sampling_params.include_stop_str_in_output,
-        )
-
-        # Make Request for EngineCore.
-        engine_core_request = EngineCoreRequest(
+        engine_request = EngineRequest(
             request_id,
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
@@ -149,7 +137,7 @@ def process_inputs(
             lora_request,
         )

-        return detokenizer_request, engine_core_request
+        return engine_request

     def _validate_model_inputs(self, inputs: ProcessorInputs):
         if is_encoder_decoder_inputs(inputs):
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 128101aa6956d..78509b9cc6a08 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -23,7 +23,7 @@
     get_open_zmq_ipc_path)
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import make_zmq_socket
+from vllm.v1.utils import zmq_socket_ctx
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -250,7 +250,7 @@ def __init__(
         worker_response_mq_handle = self.worker_response_mq.export_handle()

         # Send Readiness signal to EngineCore process.
-        with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
+        with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
             payload = pickle.dumps(worker_response_mq_handle,
                                    protocol=pickle.HIGHEST_PROTOCOL)
             ready_socket.send_string(WorkerProc.READY_STR)
@@ -352,7 +352,7 @@ def wait_for_startup(
         ready_path: str,
     ) -> Optional[Handle]:
         """Wait until the Worker is ready."""

-        with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
+        with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
             # Wait for Worker to send READY.
             while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index f4783ae366ef0..c12ff3511e352 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -6,7 +6,7 @@
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import RequestMetrics
-from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine import EngineRequest
 from vllm.v1.utils import ConstantList

 if TYPE_CHECKING:
@@ -67,7 +67,7 @@ def __init__(
         self._kv_block_hashes: List[BlockHashType] = []

     @classmethod
-    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+    def from_engine_core_request(cls, request: EngineRequest) -> "Request":
         return cls(
             request_id=request.request_id,
             inputs=token_inputs(
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index e802c6439b740..49b0cf19fd851 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -1,11 +1,18 @@
+import os
+import weakref
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
-                    overload)
+from dataclasses import dataclass
+from multiprocessing.process import BaseProcess
+from typing import (Any, Callable, Dict, Generic, Iterator, List, Optional,
+                    TypeVar, Union, overload)

 import zmq
+import zmq.asyncio

+from vllm.executor.multiproc_worker_utils import get_mp_context
 from vllm.logger import init_logger
+from vllm.utils import kill_process_tree

 logger = init_logger(__name__)

@@ -77,27 +84,123 @@ def __len__(self):
         return len(self._x)


-@contextmanager
 def make_zmq_socket(
+    ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
+    path: str,
+    type: Any,
+) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
+    """Make a ZMQ socket with the proper bind/connect semantics."""
+
+    import psutil
+    mem = psutil.virtual_memory()
+
+    socket = ctx.socket(type)
+
+    # Calculate buffer size based on system memory
+    total_mem = mem.total / 1024**3
+    available_mem = mem.available / 1024**3
+    # For systems with substantial memory (>32GB total, >16GB available):
+    # - Set a large 0.5GB buffer to improve throughput
+    # For systems with less memory:
+    # - Use system default (-1) to avoid excessive memory consumption
+    if total_mem > 32 and available_mem > 16:
+        buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
+    else:
+        buf_size = -1  # Use system default buffer size
+
+    if type == zmq.constants.PULL:
+        socket.setsockopt(zmq.constants.RCVHWM, 0)
+        socket.setsockopt(zmq.constants.RCVBUF, buf_size)
+        socket.bind(path)
+    elif type == zmq.constants.PUSH:
+        socket.setsockopt(zmq.constants.SNDHWM, 0)
+        socket.setsockopt(zmq.constants.SNDBUF, buf_size)
+        socket.connect(path)
+    else:
+        raise ValueError(f"Unknown Socket Type: {type}")
+
+    return socket
+
+
+@contextmanager
+def zmq_socket_ctx(
         path: str,
         type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
     """Context manager for a ZMQ socket"""
socket""" - ctx = zmq.Context() # type: ignore[attr-defined] + ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined] try: - socket = ctx.socket(type) - - if type == zmq.constants.PULL: - socket.connect(path) - elif type == zmq.constants.PUSH: - socket.bind(path) - else: - raise ValueError(f"Unknown Socket Type: {type}") - - yield socket + yield make_zmq_socket(ctx, path, type) except KeyboardInterrupt: - logger.debug("Worker had Keyboard Interrupt.") + logger.debug("Got Keyboard Interrupt.") finally: ctx.destroy(linger=0) + + +@dataclass +class BackgroundProcHandle: + proc: BaseProcess + input_path: str + output_path: str + + def shutdown(self): + # Shutdown the process if needed. + if self.proc.is_alive(): + self.proc.terminate() + self.proc.join(5) + + if self.proc.is_alive(): + kill_process_tree(self.proc.pid) + + # Remove zmq ipc socket files + ipc_sockets = [self.output_path, self.input_path] + for ipc_socket in ipc_sockets: + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file) + + +class MPBackgroundProcess: + + def __init__(self): + self.proc_handle: Optional[BackgroundProcHandle] + self._finalizer = weakref.finalize(self, self.shutdown) + + def __del__(self): + self.shutdown() + + def shutdown(self): + if hasattr(self, "proc_handle") and self.proc_handle: + self.proc_handle.shutdown() + self.proc_handle = None + + @staticmethod + def wait_for_startup( + input_path: str, + output_path: str, + process_name: str, + target_fn: Callable, + process_kwargs: Dict[Any, Any], + ) -> BackgroundProcHandle: + context = get_mp_context() + reader, writer = context.Pipe(duplex=False) + + assert ("ready_pipe" not in process_kwargs + and "input_path" not in process_kwargs + and "output_path" not in process_kwargs) + process_kwargs["ready_pipe"] = writer + process_kwargs["input_path"] = input_path + process_kwargs["output_path"] = output_path + + # Run Detokenizer busy loop in background process. + proc = context.Process(target=target_fn, kwargs=process_kwargs) + proc.start() + + # Wait for startup. + if reader.recv()["status"] != "READY": + raise RuntimeError(f"{process_name} initialization failed. " + "See root cause above.") + + return BackgroundProcHandle(proc, input_path, output_path)