Roll own initialization loop
tlrmchlsmth committed Nov 22, 2024
1 parent 4ca2afe commit 34ca6bb
Showing 12 changed files with 505 additions and 214 deletions.
2 changes: 2 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
@@ -24,13 +24,15 @@

TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
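For context, a `run_with_both_engines`-style fixture is typically a parametrized fixture that flips the engine-selection flag before each test runs. A minimal sketch of how such a fixture could be written (hypothetical body; the real fixture lives in vLLM's conftest.py):

```python
import pytest

# Hedged sketch: parametrize each test over the V1 flag and export it via
# the environment before the engine is constructed. The fixture mechanics
# shown here are an assumption for illustration.
@pytest.fixture(params=[False, True], ids=["v0", "v1"])
def run_with_both_engines(request, monkeypatch):
    monkeypatch.setenv("VLLM_USE_V1", "1" if request.param else "0")
    yield
```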
79 changes: 59 additions & 20 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,12 +1,14 @@
import os
-import pickle
+import struct
+import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
-from typing import List, Optional
+from typing import List, Optional, Tuple
from unittest.mock import patch

+import msgspec
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -21,6 +23,13 @@

logger = init_logger(__name__)

+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on python < 3.11.1,
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = False
+if sys.version_info[:3] >= (3, 11, 1):
+USE_SCHED_YIELD = True
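The version gate exists because CPython began releasing the GIL inside `os.sched_yield()` only in 3.11.1; on older interpreters a spinning thread would never let same-process readers run. A self-contained sketch of the resulting wait-loop pattern (illustrative helper, not part of the diff):

```python
import os
import sys
import time

USE_SCHED_YIELD = sys.version_info[:3] >= (3, 11, 1)

def spin_until(predicate, warn_after_s: float = 60.0) -> None:
    """Poll predicate() in a tight loop, yielding the CPU between checks."""
    start = time.monotonic()
    while not predicate():
        if USE_SCHED_YIELD:
            os.sched_yield()  # tight iteration; releases the GIL on 3.11.1+
        else:
            time.sleep(0)     # pre-3.11.1 fallback: sleep always drops the GIL
        if time.monotonic() - start > warn_after_s:
            print(f"still waiting after {warn_after_s}s")
            start = time.monotonic()

# Example: wait until a deadline has passed.
deadline = time.monotonic() + 0.01
spin_until(lambda: time.monotonic() >= deadline)
```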


class ShmRingBuffer:

@@ -74,7 +83,7 @@ def __init__(self,
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
During creation, `name` is None and the buffer is created. We can pass the
-created object to other processes by pickling it. The other processes will
+created object to other processes by serializing it. The other processes will
get the name of the shared memory and open it, so that they can access the
same shared memory buffer.
"""# noqa
@@ -114,6 +123,10 @@ def __init__(self,
# and we should suppress the error
pass

+def handle(self):
+return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
+self.shared_memory.name)

def __reduce__(self):
return (
self.__class__,
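Shipping a `(n_reader, max_chunk_bytes, max_chunks, name)` tuple instead of the pickled buffer object keeps OS resources out of the serialized payload: the receiver re-opens the shared-memory segment by name. A standard-library sketch of that reattach pattern (names are illustrative):

```python
from multiprocessing import shared_memory

# Creator: allocate the segment and export a plain, picklable handle.
seg = shared_memory.SharedMemory(create=True, size=1 << 20)
handle = (1, 1 << 20, 10, seg.name)  # (n_reader, max_chunk_bytes, max_chunks, name)

# Receiver (normally another process): reattach purely by name.
n_reader, max_chunk_bytes, max_chunks, name = handle
view = shared_memory.SharedMemory(name=name)
view.buf[0] = 42
assert seg.buf[0] == 42  # both views map the same physical memory

view.close()
seg.close()
seg.unlink()
```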
@@ -147,13 +160,18 @@ class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)

-buffer: Optional[ShmRingBuffer] = None
+buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None


class MessageQueue:

+# Use 4 bytes to store size of each message (we omit this for ZMQ).
+# This is needed for decoding the message.
+SIZE_PREFIX_FORMAT = '!I' # unsigned int, 4 bytes, network byte order
+SIZE_PREFIX_LEN = struct.calcsize(SIZE_PREFIX_FORMAT)

def __init__(
self,
n_reader, # number of all readers
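The 4-byte prefix defined above is classic length-prefix framing: ZMQ delivers discrete messages, but a shared-memory chunk is a fixed-size byte range, so the reader needs an explicit size to know where the encoded object ends. A quick round trip of the `'!I'` format:

```python
import struct

SIZE_PREFIX_FORMAT = '!I'  # unsigned 32-bit int, network (big-endian) byte order
SIZE_PREFIX_LEN = struct.calcsize(SIZE_PREFIX_FORMAT)  # == 4

payload = b"hello"
framed = struct.pack(SIZE_PREFIX_FORMAT, len(payload)) + payload

# Reader side: peel off the prefix, then slice out exactly that many bytes.
(size,) = struct.unpack(SIZE_PREFIX_FORMAT, framed[:SIZE_PREFIX_LEN])
assert framed[SIZE_PREFIX_LEN:SIZE_PREFIX_LEN + size] == payload
```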
@@ -228,7 +246,7 @@ def __init__(
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
-buffer=self.buffer,
+buffer_handle=self.buffer.handle(),
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
@@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
context = Context()

if rank in handle.local_reader_ranks:
-assert handle.buffer is not None
-self.buffer = handle.buffer
+assert handle.buffer_handle is not None
+self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
@@ -329,7 +347,10 @@ def acquire_write(self):
# we need to wait until it is read by all readers

# Release the processor to other threads
-os.sched_yield()
+if USE_SCHED_YIELD:
+os.sched_yield()
+else:
+time.sleep(1e-5)

# if we wait for a long time, we should warn the user
if (time.monotonic() - start_time >
@@ -383,7 +404,10 @@ def acquire_read(self):
# we need to wait until it is written

# Release the processor to other threads
-os.sched_yield()
+if USE_SCHED_YIELD:
+os.sched_yield()
+else:
+time.sleep(0)

# if we wait for a long time, we should warn the user
if (time.monotonic() - start_time >
@@ -408,34 +432,49 @@ def acquire_read(self):

def enqueue(self, obj):
assert self._is_writer, "Only writers can enqueue"
-serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

+encoder = msgspec.msgpack.Encoder()
+serialized_obj = encoder.encode(obj)
+size_to_write = self.SIZE_PREFIX_LEN + len(serialized_obj)

if self.n_local_reader > 0:
-if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+if size_to_write >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
buf[0] = 0 # not overflow
-buf[1:len(serialized_obj) + 1] = serialized_obj
+obj_offset = 1 + self.SIZE_PREFIX_LEN
+
+# Write size prefix
+buf[1:obj_offset] = struct.pack(self.SIZE_PREFIX_FORMAT,
+len(serialized_obj))
+
+buf[obj_offset:obj_offset +
+len(serialized_obj)] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)

-def dequeue(self):
+def dequeue(self, obj_type):
+decoder = msgspec.msgpack.Decoder(obj_type)

if self._is_local_reader:
with self.acquire_read() as buf:
overflow = buf[0] == 1
if not overflow:
-# no need to know the size of serialized object
-# pickle format contains the size information internally
-# see https://docs.python.org/3/library/pickle.html
-obj = pickle.loads(buf[1:])
+obj_offset = 1 + self.SIZE_PREFIX_LEN
+size_bytes = buf[1:obj_offset]
+msg_size = struct.unpack(self.SIZE_PREFIX_FORMAT,
+size_bytes)[0]
+
+obj = decoder.decode(buf[obj_offset:obj_offset + msg_size])
if overflow:
recv = self.local_socket.recv()
-obj = pickle.loads(recv)
+obj = decoder.decode(recv)
elif self._is_remote_reader:
recv = self.remote_socket.recv()
-obj = pickle.loads(recv)
+obj = decoder.decode(recv)
else:
raise RuntimeError("Only readers can dequeue")
return obj
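The move from pickle to msgspec is what makes both changes above necessary: msgpack data does not carry its own length when embedded in a larger buffer, and msgspec decodes into a type the caller supplies up front rather than whatever class was pickled. A round-trip sketch of that typed encode/decode (illustrative struct):

```python
import msgspec

class Job(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    req_id: str
    token_ids: list

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(Job)  # target type fixed at construction

wire = encoder.encode(Job(req_id="r1", token_ids=[1, 2, 3]))
job = decoder.decode(wire)
assert job.req_id == "r1" and job.token_ids == [1, 2, 3]
```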
@@ -445,7 +484,7 @@ def broadcast_object(self, obj=None):
self.enqueue(obj)
return obj
else:
-return self.dequeue()
+return self.dequeue(obj)

@staticmethod
def create_from_process_group(pg: ProcessGroup,
12 changes: 7 additions & 5 deletions vllm/distributed/parallel_state.py
@@ -1066,11 +1066,13 @@ def initialize_model_parallel(
group_ranks.append(ranks)

# message queue broadcaster is only used in tensor model parallel group
-_TP = init_model_parallel_group(group_ranks,
-get_world_group().local_rank,
-backend,
-use_message_queue_broadcaster=True,
-group_name="tp")
+_TP = init_model_parallel_group(
+group_ranks,
+get_world_group().local_rank,
+backend,
+#TODO: this is not getting cleaned up.
+use_message_queue_broadcaster=False,
+group_name="tp")

# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
30 changes: 28 additions & 2 deletions vllm/utils.py
@@ -19,11 +19,12 @@
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
+from contextlib import contextmanager
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
-Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
-Type, TypeVar, Union, overload)
+Hashable, Iterator, List, Literal, Optional, OrderedDict,
+Set, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4

import numpy as np
@@ -32,6 +33,7 @@
import torch
import torch.types
import yaml
+import zmq
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
@@ -515,6 +517,30 @@ def get_open_zmq_ipc_path() -> str:
return f"ipc://{base_rpc_path}/{uuid4()}"


+@contextmanager
+def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
+"""Context manager for a ZMQ socket."""
+
+ctx = zmq.Context()
+try:
+socket = ctx.socket(type)
+
+if type == zmq.constants.PULL:
+socket.connect(path)
+elif type == zmq.constants.PUSH:
+socket.bind(path)
+else:
+raise ValueError(f"Unknown Socket Type: {type}")
+
+yield socket
+
+except KeyboardInterrupt:
+logger.debug("Worker had Keyboard Interrupt.")
+
+finally:
+ctx.destroy(linger=0)
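A usage sketch for the helper above (assumes pyzmq; note the convention baked into the branches: PUSH binds, PULL connects):

```python
import zmq

path = "ipc:///tmp/example-zmq-pipe"  # illustrative IPC path

with make_zmq_socket(path, zmq.PUSH) as push:
    with make_zmq_socket(path, zmq.PULL) as pull:
        push.send(b"ping")
        assert pull.recv() == b"ping"
```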


def get_open_port() -> int:
port = envs.VLLM_PORT
if port is not None:
13 changes: 5 additions & 8 deletions vllm/v1/core/scheduler.py
@@ -1,21 +1,18 @@
from collections import deque
from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
-Tuple, Union)
+from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
+from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

-if TYPE_CHECKING:
-from vllm.multimodal import MultiModalKwargs
-from vllm.multimodal.base import PlaceholderRange

logger = init_logger(__name__)


@@ -508,8 +505,8 @@ class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
-mm_inputs: List["MultiModalKwargs"]
-mm_positions: List["PlaceholderRange"]
+mm_inputs: List[MultiModalKwargs]
+mm_positions: List[PlaceholderRange]
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
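Moving `MultiModalKwargs` and `PlaceholderRange` out of the `TYPE_CHECKING` block (and unquoting the annotations) matters once `NewRequestData` is serialized: quoted names satisfy a static type checker, but anything that resolves annotations at runtime, such as a msgspec decoder, needs the real classes importable. A minimal illustration of the failure mode:

```python
from typing import List, get_type_hints

class Sketch:
    mm_inputs: List["MultiModalKwargs"]  # quoted, and the name is never imported

try:
    get_type_hints(Sketch)  # what a runtime serializer effectively does
except NameError as exc:
    print(f"unresolvable annotation: {exc}")
```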
24 changes: 24 additions & 0 deletions vllm/v1/core/scheduler_output.py
@@ -0,0 +1,24 @@
from enum import Enum, auto
from typing import Optional

import msgspec

from vllm.v1.core.scheduler import SchedulerOutput


#TODO: Move this file
class ExecutorMsgType(Enum):
TOIL = auto()
TERMINATE = auto()


class ExecutorMsg(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
"""A directive from the core process to its worker processes.
Wraps SchedulerOutput with a message type to distinguish between
regular work assignments and termination orders."""
message_type: ExecutorMsgType
payload: Optional[SchedulerOutput]
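A sketch of how these control messages would be produced and consumed; `SchedulerOutput` is stubbed with a plain struct here so the example stands alone:

```python
from enum import Enum, auto
from typing import Optional

import msgspec

class ExecutorMsgType(Enum):
    TOIL = auto()
    TERMINATE = auto()

class SchedulerOutputStub(msgspec.Struct, array_like=True):
    num_scheduled_tokens: int = 0

class ExecutorMsg(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    message_type: ExecutorMsgType
    payload: Optional[SchedulerOutputStub]

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(ExecutorMsg)

# Core process side: a work assignment, then a shutdown directive.
work = encoder.encode(ExecutorMsg(ExecutorMsgType.TOIL, SchedulerOutputStub(8)))
stop = encoder.encode(ExecutorMsg(ExecutorMsgType.TERMINATE, None))

# Worker side: dispatch on the message type.
for wire in (work, stop):
    msg = decoder.decode(wire)
    if msg.message_type is ExecutorMsgType.TERMINATE:
        print("terminating worker loop")
    else:
        print("running model with", msg.payload)
```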