From b7843c93060c7684c37f6117261ab7a1e0df6d05 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 23 Dec 2024 22:17:46 +0000 Subject: [PATCH] update --- vllm/v1/engine/__init__.py | 6 +++--- vllm/v1/engine/async_llm.py | 22 ++++++++++------------ vllm/v1/engine/core.py | 17 ++++++----------- vllm/v1/engine/detokenizer.py | 26 ++++++++++++++------------ vllm/v1/engine/llm_engine.py | 6 +++--- vllm/v1/utils.py | 19 ++++++++----------- 6 files changed, 44 insertions(+), 52 deletions(-) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index a99f8a617fd8f..0e104118c9ff9 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -63,10 +63,10 @@ class EngineCoreOutputs( outputs: List[EngineCoreOutput] -class EngineRequestType(enum.Enum): +class EngineRequestType(enum.Enum): """ Request types defined as hex byte strings, so it can be sent over sockets without separate encoding step. - """ + """ FROM_ENGINE_CORE = b'\x00' - FROM_ENGINE = b'\x01' + FROM_ENGINE = b'\x01' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4e791e8f06565..7a682f79e7972 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -20,7 +20,8 @@ import zmq.asyncio import pickle -from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Type, Union +from typing import (Any, AsyncGenerator, Dict, List, Mapping, Optional, Type, + Union) from vllm.config import ModelConfig, VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -93,15 +94,14 @@ def __init__( to_detokenizer_path = get_open_zmq_ipc_path() to_engine_core_path = get_open_zmq_ipc_path() to_llm_engine_path = get_open_zmq_ipc_path() - # Detokenizer IPC. self.ctx = zmq.asyncio.Context(io_threads=2) - self.from_detokenizer = make_zmq_socket( - self.ctx, to_llm_engine_path, zmq.PULL) - self.to_detokenizer = make_zmq_socket( - self.ctx, to_detokenizer_path, zmq.PUSH) - + self.from_detokenizer = make_zmq_socket(self.ctx, to_llm_engine_path, + zmq.PULL) + self.to_detokenizer = make_zmq_socket(self.ctx, to_detokenizer_path, + zmq.PUSH) + # Detokenizer (background process). self.detokenizer_client = MPDetokenizerClient( output_path=to_llm_engine_path, @@ -162,7 +162,7 @@ def shutdown(self): if ctx := getattr(self, "ctx", None): ctx.destroy(linger=0) - + if output_handler := getattr(self, "output_hander", None): output_handler.cancel() @@ -278,13 +278,12 @@ async def generate( yield out # Client request cancellation is handled through calling - # task.cancel() on generate(). Calling self.abort() forwards the + # task.cancel() on generate(). Calling self.abort() forwards the # cancellation to the EngineCore and Detokenizer. except asyncio.CancelledError: await self.abort(request_id) raise - async def output_handler_loop(self): """Background loop: pulls from Detokenizer and push to Queues.""" @@ -300,7 +299,6 @@ async def output_handler_loop(self): # are still flowing, so we just ignore. if out.request_id in self.rid_to_queue: self.rid_to_queue[out.request_id].put_nowait(out) - async def abort(self, request_id: str): """Abort request if the client cancels the request.""" @@ -314,7 +312,7 @@ async def abort(self, request_id: str): if self.log_requests: logger.info("Aborted %s.", request_id) - + async def send_to_detokenizer(self, object: Any): """Send object to Detokenizer with a FROM_ENGINE flag.""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e4339222f7539..a06fef170e8d5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -11,7 +11,6 @@ from multiprocessing.connection import Connection from vllm.config import CacheConfig, VllmConfig -from vllm.executor.multiproc_worker_utils import get_mp_context from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -108,8 +107,8 @@ def add_request(self, request: EngineRequest): def abort_requests(self, request_ids: List[str]): """Abort requests from the scheduler.""" - # TODO: The scheduler doesn't really need to know the - # specific finish reason, TBD whether we propagate that + # TODO: The scheduler doesn't really need to know the + # specific finish reason, TBD whether we propagate that # (i.e. client-aborted vs stop criteria met). self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) @@ -150,7 +149,7 @@ def __init__( super().__init__(vllm_config, executor_class, usage_context) # Background Threads and Queues for IO. These enable us to - # overlap ZMQ IO with GPU since they release the GIL and + # overlap ZMQ IO with GPU since they release the GIL and # some serialization/deserialization with the model forward. # Threads handle Socket <-> Queues and busy_loop uses Queues. self.input_queue: queue.Queue[EngineRequestUnion] = queue.Queue() @@ -165,7 +164,6 @@ def __init__( # Send Readiness signal to EngineClient. ready_pipe.send({"status": "READY"}) - @staticmethod def run_engine_core(*args, **kwargs): """Launch EngineCore busy loop in background process.""" @@ -226,7 +224,7 @@ def run_busy_loop(self): except BaseException: raise - # 2) Handle any new inputs. + # 2) Handle any new client requests (Abort or Add). while not self.input_queue.empty(): req = self.input_queue.get_nowait() self._handle_client_request(req) @@ -295,11 +293,8 @@ def process_output_socket(self, output_path: str): class MPEngineCoreClient(MPBackgroundProcess): """Client for multi-proc EngineCore.""" - def __init__(self, - input_path: str, - output_path: str, - vllm_config: VllmConfig, - executor_class: Type[Executor], + def __init__(self, input_path: str, output_path: str, + vllm_config: VllmConfig, executor_class: Type[Executor], usage_context: UsageContext): super().__init__() diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 870c4a7501a36..384e1a69170af 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -5,7 +5,7 @@ import signal from dataclasses import dataclass from multiprocessing.connection import Connection -from typing import Dict, Iterable, List, Optional, Tuple,Union +from typing import Dict, Iterable, List, Optional, Tuple, Union from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger @@ -23,6 +23,7 @@ POLLING_TIMEOUT_MS = 5000 + @dataclass class IncrementalDetokenizer: @@ -90,7 +91,8 @@ def from_new_request( # NOTE(Nick): could we take ownership of it though? token_ids=request.prompt_token_ids.copy(), stop=stops, - include_stop_str_in_output=sampling_params.include_stop_str_in_output, + include_stop_str_in_output=sampling_params. + include_stop_str_in_output, prefix_offset=prefix_offset, read_offset=read_offset, skip_special_tokens=sampling_params.skip_special_tokens, @@ -247,7 +249,8 @@ def add_request( self.request_states[request.request_id] = request_state def step( - self, encore_core_outputs: EngineCoreOutputs, + self, + encore_core_outputs: EngineCoreOutputs, ) -> Tuple[List[RequestOutput], List[str]]: """Update state and make RequestOutputs for the LLMEngine.""" @@ -283,6 +286,7 @@ def step( # Return to EngineClient. return request_outputs, requests_to_abort + class DetokenizerProc(Detokenizer): """ZMQ-wrapper for running Detokenizer in background process.""" @@ -304,7 +308,6 @@ def __init__( # Send Readiness signal to DetokenizerClient. ready_pipe.send({"status": "READY"}) - @staticmethod def run_detokenizer(*args, **kwargs): """Launch Detokenizer busy loop in background process.""" @@ -336,7 +339,7 @@ def signal_handler(signum, frame): except Exception: traceback = get_exception_traceback() - logger.error(f"Detokenizer hit an exception: {traceback}") + logger.error("Detokenizer hit an exception: %s", traceback) parent_process.send_signal(signal.SIGQUIT) finally: @@ -344,7 +347,7 @@ def signal_handler(signum, frame): detokenizer = None def _handle_from_llm_engine( - self, + self, request_bytes: bytes, to_engine_core: zmq.Socket, ) -> None: @@ -361,7 +364,7 @@ def _handle_from_llm_engine( # Forward to EngineCore. to_engine_core.send(request_bytes) - + def _handle_from_engine_core( self, output_bytes: bytes, @@ -382,9 +385,7 @@ def _handle_from_engine_core( # Abort requests that finished due to stop strings. if len(requests_to_abort) > 0: - to_engine_core.send_pyobj( - EngineAbortRequest(requests_to_abort)) - + to_engine_core.send_pyobj(EngineAbortRequest(requests_to_abort)) def run_busy_loop(self): """Core busy loop of the Detokenizer.""" @@ -395,7 +396,8 @@ def run_busy_loop(self): try: input_socket = make_zmq_socket(ctx, self.input_path, zmq.PULL) to_llm_engine = make_zmq_socket(ctx, self.output_path, zmq.PUSH) - to_engine_core = make_zmq_socket(ctx, self.to_engine_core_path, zmq.PUSH) + to_engine_core = make_zmq_socket(ctx, self.to_engine_core_path, + zmq.PUSH) while True: (msg_type, msg_bytes) = input_socket.recv_multipart() @@ -423,7 +425,7 @@ def run_busy_loop(self): class MPDetokenizerClient(MPBackgroundProcess): """Client for multi-proc Detokenizer.""" - + def __init__(self, input_path: str, output_path: str, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index caef661320cb2..be660a4023b30 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -86,7 +86,7 @@ def __init__( executor_class=executor_class, usage_context=usage_context, ) - + else: # Detokenizer (in process). self.detokenizer = Detokenizer( @@ -190,7 +190,7 @@ def add_request( self.engine_core.add_request(engine_request) def step(self) -> List[RequestOutput]: - + if self.multiprocess_mode: # Get next output from the Detokenizer. return self.detokenizer_client.output_socket.recv_pyobj() @@ -203,7 +203,7 @@ def step(self) -> List[RequestOutput]: # Abort any requests that hit a stop string. if requests_to_abort: self.abort_request(requests_to_abort) - + return request_outputs def get_model_config(self): diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 508474ea53f57..fde4601361256 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -4,7 +4,7 @@ from multiprocessing.process import BaseProcess from collections.abc import Sequence from contextlib import contextmanager -from typing import (Any, Generic, Dict, Iterator, List, Optional, TypeVar, +from typing import (Any, Generic, Dict, Iterator, List, Optional, TypeVar, Union, Callable, overload) import zmq @@ -179,23 +179,20 @@ def wait_for_startup( context = get_mp_context() reader, writer = context.Pipe(duplex=False) - assert ("ready_pipe" not in process_kwargs and - "input_path" not in process_kwargs and - "output_path" not in process_kwargs) + assert ("ready_pipe" not in process_kwargs + and "input_path" not in process_kwargs + and "output_path" not in process_kwargs) process_kwargs["ready_pipe"] = writer process_kwargs["input_path"] = input_path process_kwargs["output_path"] = output_path # Run Detokenizer busy loop in background process. - proc = context.Process(target=target_fn, - kwargs=process_kwargs) + proc = context.Process(target=target_fn, kwargs=process_kwargs) proc.start() - + # Wait for startup. if reader.recv()["status"] != "READY": - raise RuntimeError( - f"{process_name} initalization failed. " - "See root cause above." - ) + raise RuntimeError(f"{process_name} initalization failed. " + "See root cause above.") return BackgroundProcHandle(proc, input_path, output_path)