[V1] Multiprocessing Tensor Parallel Support for v1 #9856

Merged
merged 68 commits into main from tms/v1_tp on Dec 10, 2024

Changes from 12 commits

Commits (68)
5ad9c60
initial v1 tp support
tlrmchlsmth Nov 22, 2024
49869fa
V1 TP with zmq-based bootstrapping
tlrmchlsmth Nov 22, 2024
71e08aa
improve check for USE_SCHED_YIELD
tlrmchlsmth Nov 22, 2024
4930246
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 22, 2024
3ea0cae
fixup
tlrmchlsmth Nov 22, 2024
d4b55ae
workers must be daemonic
tlrmchlsmth Nov 22, 2024
feeed73
We can now terminate properly
tlrmchlsmth Nov 22, 2024
e3c9c5c
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 25, 2024
254714d
fixes from merge
tlrmchlsmth Nov 25, 2024
10a627e
Fixup termination
tlrmchlsmth Nov 25, 2024
d95c01e
Appease mypy
tlrmchlsmth Nov 25, 2024
c08bae4
Allow shm_broadcast to enqueue by either pickle or msgpack
tlrmchlsmth Nov 25, 2024
2392755
Switch back to pickle for shm_broadcast serialization
tlrmchlsmth Nov 26, 2024
bf3705c
Finish msgpack -> pickle
tlrmchlsmth Nov 26, 2024
d4ea706
wrap sched_yield and time.sleep in a fn
tlrmchlsmth Nov 26, 2024
2174a5b
Review comments
tlrmchlsmth Nov 26, 2024
25270ab
Rename executors to uniproc and multiproc
tlrmchlsmth Nov 26, 2024
9322db5
more review comments
tlrmchlsmth Nov 26, 2024
b5bac31
format
tlrmchlsmth Nov 26, 2024
c4fcfce
hacky hacky hacky cleanup
tlrmchlsmth Nov 26, 2024
bedd593
Fix spawn vs fork issue using approach from #8823
tlrmchlsmth Nov 26, 2024
c03ef6d
skip non-distributed tests in test_basic_correctness to see what happens
tlrmchlsmth Nov 26, 2024
8d9d557
fix async_llm
tlrmchlsmth Nov 27, 2024
5f3a570
format
tlrmchlsmth Nov 27, 2024
b59babc
Fixes for testing
tlrmchlsmth Nov 27, 2024
66116c7
Abstract executor class for typing
tlrmchlsmth Nov 27, 2024
eaeebc3
remove enforce_eager, format
tlrmchlsmth Nov 27, 2024
6d53d6e
remove stop_remote_worker_execution_loop
tlrmchlsmth Nov 27, 2024
a7025fb
Remove profiling
tlrmchlsmth Nov 27, 2024
6a3f2da
ExecutorMsg -> WorkerExecRequest
tlrmchlsmth Nov 27, 2024
d4e3813
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 27, 2024
52ef894
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 2, 2024
9f9883e
ensure_termination
tlrmchlsmth Dec 2, 2024
1990433
Move ensure_termination to executor to avoid futures
tlrmchlsmth Dec 2, 2024
f8a1b9b
minor updates
tlrmchlsmth Dec 2, 2024
963c97f
call destroy_distributed_environment atexit
tlrmchlsmth Dec 2, 2024
0678911
more graceful shutdown
tlrmchlsmth Dec 3, 2024
3d71b53
Simplify worker termination
tlrmchlsmth Dec 3, 2024
88c9c7b
atexit -> weakref.finalize
tlrmchlsmth Dec 3, 2024
ab7cb89
minor cleanup
tlrmchlsmth Dec 3, 2024
024bcad
core client cleanup rework
tlrmchlsmth Dec 3, 2024
d77bab5
poke CI
tlrmchlsmth Dec 3, 2024
24ffb8a
fix V1 test, temporarily delete some noisy log statements
tlrmchlsmth Dec 3, 2024
be4260f
nccl/issues/1234
tlrmchlsmth Dec 4, 2024
cb4b363
Cleanup, _add_prefix
tlrmchlsmth Dec 4, 2024
365ea06
fixup noise a bit
tlrmchlsmth Dec 4, 2024
c94e11b
tweaks
tlrmchlsmth Dec 5, 2024
2a36db7
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 5, 2024
536e5f2
back to atexit
tlrmchlsmth Dec 5, 2024
998eb1d
Clean up process termination
tlrmchlsmth Dec 5, 2024
ebb2544
robcomments
tlrmchlsmth Dec 5, 2024
c81b7f5
format
tlrmchlsmth Dec 5, 2024
0817336
client now kills workers directly to avoid zombies
tlrmchlsmth Dec 6, 2024
f10e5e8
remote rpc
tlrmchlsmth Dec 6, 2024
e49b071
use WorkerWrapperBase
tlrmchlsmth Dec 6, 2024
661278f
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 6, 2024
fce9696
de-duplicate env setup
tlrmchlsmth Dec 6, 2024
c61a3e0
Use collective_rpc for initialization
tlrmchlsmth Dec 7, 2024
8bb2430
add RPCParams for readability
tlrmchlsmth Dec 7, 2024
5271ec6
fixup
tlrmchlsmth Dec 7, 2024
50a12bc
Merge branch 'main' into tms/v1_tp: instance id
tlrmchlsmth Dec 7, 2024
edab869
review comments
tlrmchlsmth Dec 9, 2024
e0aea84
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 9, 2024
ce08cb2
profile
tlrmchlsmth Dec 9, 2024
65b79c4
move vllm envs import to work with run_with_both_engines
tlrmchlsmth Dec 9, 2024
143ed09
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 9, 2024
ab6bf27
review comments.
tlrmchlsmth Dec 9, 2024
819b229
collective rpc function signature sanity
tlrmchlsmth Dec 10, 2024
10 changes: 10 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
Expand Up @@ -26,6 +26,14 @@
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
Expand Down Expand Up @@ -143,6 +151,7 @@ def test_models_distributed(
)


@pytest.mark.skip_v1
Collaborator: Why is this skipped?

Collaborator (author): This test fails on V1, but I don't know why. It's not related to this PR, as it's not running TP, and it fails on current main (just enabled it in #10864).

def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
Expand All @@ -169,6 +178,7 @@ def test_model_with_failure(vllm_runner) -> None:
os.remove(filename)


@pytest.mark.skip_v1
def test_failure_with_async_out_proc(vllm_runner) -> None:

filename = None
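Note: the autouse fixture added at the top of this file mentions that it "can be promoted up to conftest.py to run for every test in a package." A minimal sketch of what that promotion could look like; the file placement is hypothetical, and `run_with_both_engines` is assumed to be the existing engine-parametrizing fixture used above:

```python
# tests/basic_correctness/conftest.py (hypothetical placement)
import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Run every test in this package against both the V0 and V1 engines."""
    # `run_with_both_engines` is the engine-parametrizing fixture already
    # relied on by the per-file autouse wrapper added in this diff.
    pass
```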
95 changes: 87 additions & 8 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,12 +1,15 @@
import os
import pickle
import struct
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from typing import List, Optional, Tuple
from unittest.mock import patch

import msgspec
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
Expand All @@ -21,6 +24,13 @@

logger = init_logger(__name__)

# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
or (sys.version_info[:2] == (3, 10)
and sys.version_info[2] >= 8))


class ShmRingBuffer:

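Note: a minimal sketch (not part of the diff) of the fallback that `USE_SCHED_YIELD` selects between in the polling loops further down; a later commit (d4ea706, "wrap sched_yield and time.sleep in a fn") factors this into a helper, and the function name used here is illustrative only:

```python
import os
import time


def yield_cpu(use_sched_yield: bool, sleep_seconds: float = 1e-5) -> None:
    """Release the CPU from inside a busy-wait polling loop.

    os.sched_yield() gives the tightest polling (measured around 3e-7 s per
    iteration), but it only helps when the interpreter releases the GIL
    around the call; on older Python versions we fall back to a short
    time.sleep() so other threads can make progress.
    """
    if use_sched_yield:
        os.sched_yield()
    else:
        time.sleep(sleep_seconds)
```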
Expand Down Expand Up @@ -74,7 +84,7 @@ def __init__(self,
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

During creation, `name` is None and the buffer is created. We can pass the
created object to other processes by pickling it. The other processes will
created object to other processes by serializing it. The other processes will
get the name of the shared memory and open it, so that they can access the
same shared memory buffer.
"""# noqa
Expand Down Expand Up @@ -114,6 +124,10 @@ def __init__(self,
# and we should suppress the error
pass

def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)

def __reduce__(self):
return (
self.__class__,
Expand Down Expand Up @@ -147,13 +161,19 @@ class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)

buffer: Optional[ShmRingBuffer] = None
buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None


class MessageQueue:

# For msgpack serialization, we use 4 bytes to store the size of each
# message, as we need the size of the encoded message while decoding.
# This is not needed for zmq or pickle.
SIZE_PREFIX_FORMAT = '!I' # unsigned int, 4 bytes, network byte order
SIZE_PREFIX_LEN = struct.calcsize(SIZE_PREFIX_FORMAT)

def __init__(
self,
n_reader, # number of all readers
Expand Down Expand Up @@ -228,7 +248,7 @@ def __init__(
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
buffer_handle=self.buffer.handle(),
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
Expand All @@ -247,8 +267,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
context = Context()

if rank in handle.local_reader_ranks:
assert handle.buffer is not None
self.buffer = handle.buffer
assert handle.buffer_handle is not None
self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
Expand Down Expand Up @@ -329,7 +349,10 @@ def acquire_write(self):
# we need to wait until it is read by all readers

# Release the processor to other threads
os.sched_yield()
if USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(1e-5)

# if we wait for a long time, we should warn the user
if (time.monotonic() - start_time >
Expand Down Expand Up @@ -383,7 +406,10 @@ def acquire_read(self):
# we need to wait until it is written

# Release the processor to other threads
os.sched_yield()
if USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(0)

# if we wait for a long time, we should warn the user
if (time.monotonic() - start_time >
Expand All @@ -407,6 +433,7 @@ def acquire_read(self):
break

def enqueue(self, obj):
"""Enqueue obj using pickle serialization"""
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
Expand All @@ -422,6 +449,7 @@ def enqueue(self, obj):
self.remote_socket.send(serialized_obj)

def dequeue(self):
"""Dequeue obj using pickle serialization"""
if self._is_local_reader:
with self.acquire_read() as buf:
overflow = buf[0] == 1
Expand All @@ -440,6 +468,57 @@ def dequeue(self):
raise RuntimeError("Only readers can dequeue")
return obj

def enqueue_via_msgpack(self, obj: msgspec.Struct):
"""Enqueue obj using msgpack serialization"""
assert self._is_writer, "Only writers can enqueue"

encoder = msgspec.msgpack.Encoder()
serialized_obj = encoder.encode(obj)
size_to_write = self.SIZE_PREFIX_LEN + len(serialized_obj)

if self.n_local_reader > 0:
if size_to_write >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
buf[0] = 0 # not overflow
obj_offset = 1 + self.SIZE_PREFIX_LEN

# Write size prefix
buf[1:obj_offset] = struct.pack(self.SIZE_PREFIX_FORMAT,
len(serialized_obj))

buf[obj_offset:obj_offset +
len(serialized_obj)] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)

def dequeue_via_msgpack(self, obj_type):
"""Dequeue obj using msgpack serialization"""
decoder = msgspec.msgpack.Decoder(obj_type)

if self._is_local_reader:
with self.acquire_read() as buf:
overflow = buf[0] == 1
if not overflow:
obj_offset = 1 + self.SIZE_PREFIX_LEN
size_bytes = buf[1:obj_offset]
msg_size = struct.unpack(self.SIZE_PREFIX_FORMAT,
size_bytes)[0]

obj = decoder.decode(buf[obj_offset:obj_offset + msg_size])
if overflow:
recv = self.local_socket.recv()
obj = decoder.decode(recv)
elif self._is_remote_reader:
recv = self.remote_socket.recv()
obj = decoder.decode(recv)
else:
raise RuntimeError("Only readers can dequeue")
return obj

def broadcast_object(self, obj=None):
if self._is_writer:
self.enqueue(obj)
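Note: the 4-byte size prefix that `enqueue_via_msgpack`/`dequeue_via_msgpack` introduce can be exercised in isolation. A self-contained sketch under stated assumptions; the `Example` struct and the 64-byte buffer are made up for illustration, and only the framing mirrors the code above:

```python
import struct

import msgspec

SIZE_PREFIX_FORMAT = '!I'  # unsigned int, 4 bytes, network byte order
SIZE_PREFIX_LEN = struct.calcsize(SIZE_PREFIX_FORMAT)


class Example(msgspec.Struct, array_like=True):
    request_id: str
    num_tokens: int


buf = bytearray(64)  # stand-in for one shared-memory chunk
payload = msgspec.msgpack.Encoder().encode(Example("req-0", 16))

# Writer side: [overflow flag][4-byte length][msgpack payload]
buf[0] = 0  # not overflow
obj_offset = 1 + SIZE_PREFIX_LEN
buf[1:obj_offset] = struct.pack(SIZE_PREFIX_FORMAT, len(payload))
buf[obj_offset:obj_offset + len(payload)] = payload

# Reader side: recover the length first, then decode exactly that many bytes
msg_size = struct.unpack(SIZE_PREFIX_FORMAT, buf[1:obj_offset])[0]
decoded = msgspec.msgpack.Decoder(Example).decode(
    bytes(buf[obj_offset:obj_offset + msg_size]))
assert decoded == Example("req-0", 16)
```

The prefix is what lets the reader decode from a fixed-size chunk without knowing the message length in advance; zmq and pickle do not need it because they already delimit messages.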
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
Expand Up @@ -7,6 +7,7 @@

from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.envs import VLLM_USE_V1
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self,
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu()
self.use_gather = not current_platform.is_tpu() and not VLLM_USE_V1

def forward(
self,
15 changes: 6 additions & 9 deletions vllm/v1/core/scheduler.py
@@ -1,21 +1,18 @@
from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange

logger = init_logger(__name__)


Expand Down Expand Up @@ -382,7 +379,7 @@ def update_from_output(
model_runner_output: "ModelRunnerOutput",
) -> List[EngineCoreOutput]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
sampled_token_ids = model_runner_output.sampled_token_ids_cpu
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
engine_core_outputs: List[EngineCoreOutput] = []
Expand Down Expand Up @@ -508,8 +505,8 @@ class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_positions: List["PlaceholderRange"]
mm_inputs: List[MultiModalKwargs]
mm_positions: List[PlaceholderRange]
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
24 changes: 24 additions & 0 deletions vllm/v1/core/scheduler_output.py
@@ -0,0 +1,24 @@
from enum import Enum, auto
from typing import Optional

import msgspec

from vllm.v1.core.scheduler import SchedulerOutput


#TODO: Move this file
class ExecutorMsgType(Enum):
TOIL = auto()
TERMINATE = auto()


class ExecutorMsg(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
"""A directive from the core process to its worker processes.

Wraps SchedulerOutput with a message type to distinguish between
regular work assignments and termination orders."""
message_type: ExecutorMsgType
payload: Optional[SchedulerOutput]
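Note: a minimal usage sketch for the new message type; it is illustrative rather than code from the PR, and `scheduler_output` is assumed to be the `SchedulerOutput` produced by the scheduler's `schedule()` call:

```python
from vllm.v1.core.scheduler_output import ExecutorMsg, ExecutorMsgType

# Regular work assignment sent from the core process to the workers.
work_msg = ExecutorMsg(ExecutorMsgType.TOIL, scheduler_output)

# Termination order; the payload is unused.
stop_msg = ExecutorMsg(ExecutorMsgType.TERMINATE, None)
```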
38 changes: 9 additions & 29 deletions vllm/v1/engine/core.py
Expand Up @@ -3,10 +3,9 @@
import queue
import threading
import time
from contextlib import contextmanager
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import Synchronized
from typing import Any, Iterator, List, Tuple, Type, Union
from typing import List, Tuple, Type, Union

import zmq
import zmq.asyncio
Expand All @@ -23,6 +22,7 @@
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
Expand Down Expand Up @@ -128,8 +128,11 @@ def step(self) -> List[EngineCoreOutput]:
scheduler_output, output)
return engine_core_outputs

def shutdown(self):
self.model_executor.shutdown()

def profile(self, is_start=True):
self.model_executor.worker.profile(is_start)
self.model_executor.profile(is_start)


class EngineCoreProc(EngineCore):
Expand Down Expand Up @@ -167,32 +170,9 @@ def __init__(
daemon=True).start()

# Send Readiness signal to EngineClient.
with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket:
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(EngineCoreProc.READY_STR)

@contextmanager
def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
"""Context manager for use """

ctx = zmq.Context()
try:
socket = ctx.socket(type)

if type == zmq.constants.PULL:
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")

yield socket

except KeyboardInterrupt:
logger.debug("EngineCore had Keyboard Interrupt.")

finally:
ctx.destroy(linger=0)

@staticmethod
def wait_for_startup(
proc: BaseProcess,
Expand Down Expand Up @@ -337,7 +317,7 @@ def process_input_socket(self, input_path: str):
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()

with self.make_socket(input_path, zmq.constants.PULL) as socket:
with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
Expand Down Expand Up @@ -365,7 +345,7 @@ def process_output_socket(self, output_path: str):
# Reuse send buffer.
buffer = bytearray()

with self.make_socket(output_path, zmq.constants.PUSH) as socket:
with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
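Note: the removed `make_socket` context manager moves to `vllm/v1/utils.py` as `make_zmq_socket`, imported near the top of this file. Below is a sketch consistent with the removed code, for reference only; the final implementation in utils.py may differ in detail:

```python
from contextlib import contextmanager
from typing import Any, Iterator

import zmq


@contextmanager
def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
    """Bind a PUSH socket or connect a PULL socket, and always clean up."""
    ctx = zmq.Context()
    try:
        socket = ctx.socket(type)
        if type == zmq.constants.PULL:
            socket.connect(path)
        elif type == zmq.constants.PUSH:
            socket.bind(path)
        else:
            raise ValueError(f"Unknown Socket Type: {type}")
        yield socket
    except KeyboardInterrupt:
        # The removed EngineCore.make_socket logged and swallowed this.
        pass
    finally:
        ctx.destroy(linger=0)
```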