Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat committed Dec 23, 2024
1 parent 12df407 commit b7843c9
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 52 deletions.
6 changes: 3 additions & 3 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ class EngineCoreOutputs(
outputs: List[EngineCoreOutput]


class EngineRequestType(enum.Enum):
class EngineRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
without separate encoding step.
"""
"""
FROM_ENGINE_CORE = b'\x00'
FROM_ENGINE = b'\x01'
FROM_ENGINE = b'\x01'
22 changes: 10 additions & 12 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import zmq.asyncio
import pickle

from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
from typing import (Any, AsyncGenerator, Dict, List, Mapping, Optional, Type,
Union)

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
Expand Down Expand Up @@ -93,15 +94,14 @@ def __init__(
to_detokenizer_path = get_open_zmq_ipc_path()
to_engine_core_path = get_open_zmq_ipc_path()
to_llm_engine_path = get_open_zmq_ipc_path()


# Detokenizer IPC.
self.ctx = zmq.asyncio.Context(io_threads=2)
self.from_detokenizer = make_zmq_socket(
self.ctx, to_llm_engine_path, zmq.PULL)
self.to_detokenizer = make_zmq_socket(
self.ctx, to_detokenizer_path, zmq.PUSH)
self.from_detokenizer = make_zmq_socket(self.ctx, to_llm_engine_path,
zmq.PULL)
self.to_detokenizer = make_zmq_socket(self.ctx, to_detokenizer_path,
zmq.PUSH)

# Detokenizer (background process).
self.detokenizer_client = MPDetokenizerClient(
output_path=to_llm_engine_path,
Expand Down Expand Up @@ -162,7 +162,7 @@ def shutdown(self):

if ctx := getattr(self, "ctx", None):
ctx.destroy(linger=0)

if output_handler := getattr(self, "output_hander", None):
output_handler.cancel()

Expand Down Expand Up @@ -278,13 +278,12 @@ async def generate(
yield out

# Client request cancellation is handled through calling
# task.cancel() on generate(). Calling self.abort() forwards the
# task.cancel() on generate(). Calling self.abort() forwards the
# cancellation to the EngineCore and Detokenizer.
except asyncio.CancelledError:
await self.abort(request_id)
raise


async def output_handler_loop(self):
"""Background loop: pulls from Detokenizer and push to Queues."""

Expand All @@ -300,7 +299,6 @@ async def output_handler_loop(self):
# are still flowing, so we just ignore.
if out.request_id in self.rid_to_queue:
self.rid_to_queue[out.request_id].put_nowait(out)


async def abort(self, request_id: str):
"""Abort request if the client cancels the request."""
Expand All @@ -314,7 +312,7 @@ async def abort(self, request_id: str):

if self.log_requests:
logger.info("Aborted %s.", request_id)

async def send_to_detokenizer(self, object: Any):
"""Send object to Detokenizer with a FROM_ENGINE flag."""

Expand Down
17 changes: 6 additions & 11 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from multiprocessing.connection import Connection

from vllm.config import CacheConfig, VllmConfig
from vllm.executor.multiproc_worker_utils import get_mp_context
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
Expand Down Expand Up @@ -108,8 +107,8 @@ def add_request(self, request: EngineRequest):
def abort_requests(self, request_ids: List[str]):
"""Abort requests from the scheduler."""

# TODO: The scheduler doesn't really need to know the
# specific finish reason, TBD whether we propagate that
# TODO: The scheduler doesn't really need to know the
# specific finish reason, TBD whether we propagate that
# (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)
Expand Down Expand Up @@ -150,7 +149,7 @@ def __init__(
super().__init__(vllm_config, executor_class, usage_context)

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ IO with GPU since they release the GIL and
# overlap ZMQ IO with GPU since they release the GIL and
# some serialization/deserialization with the model forward.
# Threads handle Socket <-> Queues and busy_loop uses Queues.
self.input_queue: queue.Queue[EngineRequestUnion] = queue.Queue()
Expand All @@ -165,7 +164,6 @@ def __init__(
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})


@staticmethod
def run_engine_core(*args, **kwargs):
"""Launch EngineCore busy loop in background process."""
Expand Down Expand Up @@ -226,7 +224,7 @@ def run_busy_loop(self):
except BaseException:
raise

# 2) Handle any new inputs.
# 2) Handle any new client requests (Abort or Add).
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(req)
Expand Down Expand Up @@ -295,11 +293,8 @@ def process_output_socket(self, output_path: str):
class MPEngineCoreClient(MPBackgroundProcess):
"""Client for multi-proc EngineCore."""

def __init__(self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
executor_class: Type[Executor],
def __init__(self, input_path: str, output_path: str,
vllm_config: VllmConfig, executor_class: Type[Executor],
usage_context: UsageContext):

super().__init__()
Expand Down
26 changes: 14 additions & 12 deletions vllm/v1/engine/detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import signal
from dataclasses import dataclass
from multiprocessing.connection import Connection
from typing import Dict, Iterable, List, Optional, Tuple,Union
from typing import Dict, Iterable, List, Optional, Tuple, Union

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
Expand All @@ -23,6 +23,7 @@

POLLING_TIMEOUT_MS = 5000


@dataclass
class IncrementalDetokenizer:

Expand Down Expand Up @@ -90,7 +91,8 @@ def from_new_request(
# NOTE(Nick): could we take ownership of it though?
token_ids=request.prompt_token_ids.copy(),
stop=stops,
include_stop_str_in_output=sampling_params.include_stop_str_in_output,
include_stop_str_in_output=sampling_params.
include_stop_str_in_output,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=sampling_params.skip_special_tokens,
Expand Down Expand Up @@ -247,7 +249,8 @@ def add_request(
self.request_states[request.request_id] = request_state

def step(
self, encore_core_outputs: EngineCoreOutputs,
self,
encore_core_outputs: EngineCoreOutputs,
) -> Tuple[List[RequestOutput], List[str]]:
"""Update state and make RequestOutputs for the LLMEngine."""

Expand Down Expand Up @@ -283,6 +286,7 @@ def step(
# Return to EngineClient.
return request_outputs, requests_to_abort


class DetokenizerProc(Detokenizer):
"""ZMQ-wrapper for running Detokenizer in background process."""

Expand All @@ -304,7 +308,6 @@ def __init__(
# Send Readiness signal to DetokenizerClient.
ready_pipe.send({"status": "READY"})


@staticmethod
def run_detokenizer(*args, **kwargs):
"""Launch Detokenizer busy loop in background process."""
Expand Down Expand Up @@ -336,15 +339,15 @@ def signal_handler(signum, frame):

except Exception:
traceback = get_exception_traceback()
logger.error(f"Detokenizer hit an exception: {traceback}")
logger.error("Detokenizer hit an exception: %s", traceback)
parent_process.send_signal(signal.SIGQUIT)

finally:
if detokenizer is not None:
detokenizer = None

def _handle_from_llm_engine(
self,
self,
request_bytes: bytes,
to_engine_core: zmq.Socket,
) -> None:
Expand All @@ -361,7 +364,7 @@ def _handle_from_llm_engine(

# Forward to EngineCore.
to_engine_core.send(request_bytes)

def _handle_from_engine_core(
self,
output_bytes: bytes,
Expand All @@ -382,9 +385,7 @@ def _handle_from_engine_core(

# Abort requests that finished due to stop strings.
if len(requests_to_abort) > 0:
to_engine_core.send_pyobj(
EngineAbortRequest(requests_to_abort))

to_engine_core.send_pyobj(EngineAbortRequest(requests_to_abort))

def run_busy_loop(self):
"""Core busy loop of the Detokenizer."""
Expand All @@ -395,7 +396,8 @@ def run_busy_loop(self):
try:
input_socket = make_zmq_socket(ctx, self.input_path, zmq.PULL)
to_llm_engine = make_zmq_socket(ctx, self.output_path, zmq.PUSH)
to_engine_core = make_zmq_socket(ctx, self.to_engine_core_path, zmq.PUSH)
to_engine_core = make_zmq_socket(ctx, self.to_engine_core_path,
zmq.PUSH)
while True:
(msg_type, msg_bytes) = input_socket.recv_multipart()

Expand Down Expand Up @@ -423,7 +425,7 @@ def run_busy_loop(self):

class MPDetokenizerClient(MPBackgroundProcess):
"""Client for multi-proc Detokenizer."""

def __init__(self,
input_path: str,
output_path: str,
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
executor_class=executor_class,
usage_context=usage_context,
)

else:
# Detokenizer (in process).
self.detokenizer = Detokenizer(
Expand Down Expand Up @@ -190,7 +190,7 @@ def add_request(
self.engine_core.add_request(engine_request)

def step(self) -> List[RequestOutput]:

if self.multiprocess_mode:
# Get next output from the Detokenizer.
return self.detokenizer_client.output_socket.recv_pyobj()

Check failure on line 196 in vllm/v1/engine/llm_engine.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

"MPDetokenizerClient" has no attribute "output_socket" [attr-defined]

Check failure on line 196 in vllm/v1/engine/llm_engine.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

"MPDetokenizerClient" has no attribute "output_socket" [attr-defined]

Check failure on line 196 in vllm/v1/engine/llm_engine.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

"MPDetokenizerClient" has no attribute "output_socket" [attr-defined]

Check failure on line 196 in vllm/v1/engine/llm_engine.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

"MPDetokenizerClient" has no attribute "output_socket" [attr-defined]
Expand All @@ -203,7 +203,7 @@ def step(self) -> List[RequestOutput]:
# Abort any requests that hit a stop string.
if requests_to_abort:
self.abort_request(requests_to_abort)

return request_outputs

def get_model_config(self):
Expand Down
19 changes: 8 additions & 11 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from multiprocessing.process import BaseProcess
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (Any, Generic, Dict, Iterator, List, Optional, TypeVar,
from typing import (Any, Generic, Dict, Iterator, List, Optional, TypeVar,
Union, Callable, overload)

import zmq
Expand Down Expand Up @@ -179,23 +179,20 @@ def wait_for_startup(
context = get_mp_context()
reader, writer = context.Pipe(duplex=False)

assert ("ready_pipe" not in process_kwargs and
"input_path" not in process_kwargs and
"output_path" not in process_kwargs)
assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path

# Run Detokenizer busy loop in background process.
proc = context.Process(target=target_fn,
kwargs=process_kwargs)
proc = context.Process(target=target_fn, kwargs=process_kwargs)
proc.start()

# Wait for startup.
if reader.recv()["status"] != "READY":
raise RuntimeError(
f"{process_name} initalization failed. "
"See root cause above."
)
raise RuntimeError(f"{process_name} initalization failed. "
"See root cause above.")

return BackgroundProcHandle(proc, input_path, output_path)

Check failure on line 198 in vllm/v1/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible return value type (got "BackgroundProcHandle", expected "MPBackgroundProcess") [return-value]

Check failure on line 198 in vllm/v1/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible return value type (got "BackgroundProcHandle", expected "MPBackgroundProcess") [return-value]

Check failure on line 198 in vllm/v1/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible return value type (got "BackgroundProcHandle", expected "MPBackgroundProcess") [return-value]

Check failure on line 198 in vllm/v1/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible return value type (got "BackgroundProcHandle", expected "MPBackgroundProcess") [return-value]

0 comments on commit b7843c9

Please sign in to comment.