[V1] Multiprocessing Tensor Parallel Support for v1 #9856

Merged: 68 commits, Dec 10, 2024
Commits (68)
5ad9c60
initial v1 tp support
tlrmchlsmth Nov 22, 2024
49869fa
V1 TP with zmq-based bootstrapping
tlrmchlsmth Nov 22, 2024
71e08aa
improve check for USE_SCHED_YIELD
tlrmchlsmth Nov 22, 2024
4930246
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 22, 2024
3ea0cae
fixup
tlrmchlsmth Nov 22, 2024
d4b55ae
workers must be daemonic
tlrmchlsmth Nov 22, 2024
feeed73
We can now terminate properly
tlrmchlsmth Nov 22, 2024
e3c9c5c
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 25, 2024
254714d
fixes from merge
tlrmchlsmth Nov 25, 2024
10a627e
Fixup termination
tlrmchlsmth Nov 25, 2024
d95c01e
Appease mypy
tlrmchlsmth Nov 25, 2024
c08bae4
Allow shm_broadcast to enqueue by either pickle or msgpack
tlrmchlsmth Nov 25, 2024
2392755
Switch back to pickle for shm_broadcast serialization
tlrmchlsmth Nov 26, 2024
bf3705c
Finish msgpack -> pickle
tlrmchlsmth Nov 26, 2024
d4ea706
wrap sched_yield and time.sleep in a fn
tlrmchlsmth Nov 26, 2024
2174a5b
Review comments
tlrmchlsmth Nov 26, 2024
25270ab
Rename executors to uniproc and multiproc
tlrmchlsmth Nov 26, 2024
9322db5
more review comments
tlrmchlsmth Nov 26, 2024
b5bac31
format
tlrmchlsmth Nov 26, 2024
c4fcfce
hacky hacky hacky cleanup
tlrmchlsmth Nov 26, 2024
bedd593
Fix spawn vs fork issue using approach from #8823
tlrmchlsmth Nov 26, 2024
c03ef6d
skip non-distributed tests in test_basic_correctness to see what happens
tlrmchlsmth Nov 26, 2024
8d9d557
fix async_llm
tlrmchlsmth Nov 27, 2024
5f3a570
format
tlrmchlsmth Nov 27, 2024
b59babc
Fixes for testing
tlrmchlsmth Nov 27, 2024
66116c7
Abstract executor class for typing
tlrmchlsmth Nov 27, 2024
eaeebc3
remove enforce_eager, format
tlrmchlsmth Nov 27, 2024
6d53d6e
remove stop_remote_worker_execution_loop
tlrmchlsmth Nov 27, 2024
a7025fb
Remove profiling
tlrmchlsmth Nov 27, 2024
6a3f2da
ExecutorMsg -> WorkerExecRequest
tlrmchlsmth Nov 27, 2024
d4e3813
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 27, 2024
52ef894
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 2, 2024
9f9883e
ensure_termination
tlrmchlsmth Dec 2, 2024
1990433
Move ensure_termination to executor to avoid futures
tlrmchlsmth Dec 2, 2024
f8a1b9b
minor updates
tlrmchlsmth Dec 2, 2024
963c97f
call destroy_distributed_environment atexit
tlrmchlsmth Dec 2, 2024
0678911
more graceful shutdown
tlrmchlsmth Dec 3, 2024
3d71b53
Simplify worker termination
tlrmchlsmth Dec 3, 2024
88c9c7b
atexit -> weakref.finalize
tlrmchlsmth Dec 3, 2024
ab7cb89
minor cleanup
tlrmchlsmth Dec 3, 2024
024bcad
core client cleanup rework
tlrmchlsmth Dec 3, 2024
d77bab5
poke CI
tlrmchlsmth Dec 3, 2024
24ffb8a
fix V1 test, temporarily delete some noisy log statements
tlrmchlsmth Dec 3, 2024
be4260f
nccl/issues/1234
tlrmchlsmth Dec 4, 2024
cb4b363
Cleanup, _add_prefix
tlrmchlsmth Dec 4, 2024
365ea06
fixup noise a bit
tlrmchlsmth Dec 4, 2024
c94e11b
tweaks
tlrmchlsmth Dec 5, 2024
2a36db7
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 5, 2024
536e5f2
back to atexit
tlrmchlsmth Dec 5, 2024
998eb1d
Clean up process termination
tlrmchlsmth Dec 5, 2024
ebb2544
robcomments
tlrmchlsmth Dec 5, 2024
c81b7f5
format
tlrmchlsmth Dec 5, 2024
0817336
client now kills workers directly to avoid zombies
tlrmchlsmth Dec 6, 2024
f10e5e8
remote rpc
tlrmchlsmth Dec 6, 2024
e49b071
use WorkerWrapperBase
tlrmchlsmth Dec 6, 2024
661278f
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 6, 2024
fce9696
de-duplicate env setup
tlrmchlsmth Dec 6, 2024
c61a3e0
Use collective_rpc for initialization
tlrmchlsmth Dec 7, 2024
8bb2430
add RPCParams for readability
tlrmchlsmth Dec 7, 2024
5271ec6
fixup
tlrmchlsmth Dec 7, 2024
50a12bc
Merge branch 'main' into tms/v1_tp: instance id
tlrmchlsmth Dec 7, 2024
edab869
review comments
tlrmchlsmth Dec 9, 2024
e0aea84
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 9, 2024
ce08cb2
profile
tlrmchlsmth Dec 9, 2024
65b79c4
move vllm envs import to work with run_with_both_engines
tlrmchlsmth Dec 9, 2024
143ed09
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 9, 2024
ab6bf27
review comments.
tlrmchlsmth Dec 9, 2024
819b229
collective rpc function signature sanity
tlrmchlsmth Dec 10, 2024
16 changes: 16 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
@@ -26,6 +26,14 @@
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
@@ -36,6 +44,7 @@ def test_vllm_gc_ed():
assert weak_llm() is None


@pytest.mark.skip_v1
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("dtype", ["half"])
@@ -118,6 +127,11 @@ def test_models_distributed(
if attention_backend:
os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend

# Import VLLM_USE_V1 dynamically to handle patching
from vllm.envs import VLLM_USE_V1
Member:

What does this dynamic patching mean? envs.VLLM_USE_V1 should read the latest env var value.

Collaborator Author:

I moved the import here so that the VLLM_USE_V1 check picks up the patched value when we're using the run_with_both_engines pytest fixture during testing.

vllm/tests/conftest.py, lines 112 to 126 in bf0e382:

@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")
if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield

Please LMK if you have a better idea!
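For context, a minimal standalone sketch of why the late import matters (illustrative code, not part of the PR): unittest.mock.patch replaces the attribute on the vllm.envs module object, so a from-import executed at module load time captures the unpatched value, while an import inside the test body re-reads the patched attribute.

# Minimal sketch: module-level vs. function-level import under mock.patch.
# Assumes vLLM is installed; the behavior itself is plain Python semantics.
from unittest.mock import patch

from vllm.envs import VLLM_USE_V1 as IMPORTED_EARLY  # bound once, at import time


def read_flag_both_ways():
    with patch("vllm.envs.VLLM_USE_V1", True):
        # Importing inside the patch re-reads the module attribute,
        # so this name sees the patched value.
        from vllm.envs import VLLM_USE_V1
        return IMPORTED_EARLY, VLLM_USE_V1  # e.g. (False, True) if V1 is off by default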

if VLLM_USE_V1 and distributed_executor_backend != "mp":
pytest.skip(f"Skip {distributed_executor_backend} for V1")

dtype = "half"
max_tokens = 5

@@ -143,6 +157,7 @@
)


@pytest.mark.skip_v1
Collaborator:

Why is this skipped?

Collaborator Author:

This test fails on V1, but I don't know why. It's not related to this PR, since it isn't running TP and it fails on current main (it was just enabled in #10864).

def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
@@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
os.remove(filename)


@pytest.mark.skip_v1
def test_failure_with_async_out_proc(vllm_runner) -> None:

filename = None
5 changes: 0 additions & 5 deletions vllm/compilation/backends.py
@@ -247,11 +247,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.splitting_ops)

from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
logger.debug("%s", lazy_format_graph_code("after split",
self.split_gm))

compilation_counter.num_piecewise_graphs_seen += len(
self.piecewise_graphs)
submod_names_to_compile = [
36 changes: 27 additions & 9 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,10 +1,11 @@
import os
import pickle
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from typing import List, Optional, Tuple
from unittest.mock import patch

import torch
@@ -21,6 +22,20 @@

logger = init_logger(__name__)

# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
or (sys.version_info[:2] == (3, 10)
and sys.version_info[2] >= 8))


def sched_yield():
if USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(0)


class ShmRingBuffer:

@@ -114,11 +129,14 @@ def __init__(self,
# and we should suppress the error
pass

def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)

def __reduce__(self):
return (
self.__class__,
(self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name),
self.handle(),
)

def __del__(self):
@@ -147,7 +165,7 @@ class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)

buffer: Optional[ShmRingBuffer] = None
buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None

@@ -228,7 +246,7 @@ def __init__(
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
buffer_handle=self.buffer.handle(),
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
@@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
context = Context()

if rank in handle.local_reader_ranks:
assert handle.buffer is not None
self.buffer = handle.buffer
assert handle.buffer_handle is not None
self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
@@ -329,7 +347,7 @@ def acquire_write(self):
# we need to wait until it is read by all readers

# Release the processor to other threads
os.sched_yield()
sched_yield()

# if we wait for a long time, we should warn the user
if (time.monotonic() - start_time >
@@ -383,7 +401,7 @@ def acquire_read(self):
# we need to wait until it is written

# Release the processor to other threads
os.sched_yield()
sched_yield()

# if we wait for a long time, we should warn the user
if (time.monotonic() - start_time >
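For readers skimming the buffer -> buffer_handle change above: the point is that pickling a ShmRingBuffer should send only a small, plain handle, so the receiving worker re-attaches to the existing shared memory segment by name instead of serializing the buffer object itself. A simplified sketch of the pattern (class and field names are illustrative, not the actual vLLM implementation):

# Simplified sketch of handle-based pickling for a shared-memory object.
from multiprocessing import shared_memory
from typing import Optional


class RingBufferSketch:

    def __init__(self, size: int, name: Optional[str] = None):
        if name is None:
            # Creator side: allocate a new shared memory segment.
            self.shared_memory = shared_memory.SharedMemory(create=True, size=size)
        else:
            # Worker side: attach to the existing segment by name.
            self.shared_memory = shared_memory.SharedMemory(name=name)
        self.size = size

    def handle(self):
        # Only plain, picklable values: enough for a peer to re-attach.
        return (self.size, self.shared_memory.name)

    def __reduce__(self):
        # Unpickling calls RingBufferSketch(*handle()), which takes the
        # attach-by-name branch in __init__ instead of copying any data.
        return (self.__class__, self.handle())

This mirrors how create_from_handle rebuilds the buffer with ShmRingBuffer(*handle.buffer_handle) in the diff above.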
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
@@ -42,7 +42,9 @@ def __init__(self,
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu()
# note: we import VLLM_USE_V1 dynamically to handle patching
from vllm.envs import VLLM_USE_V1
self.use_gather = not current_platform.is_tpu() and not VLLM_USE_V1

def forward(
self,
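For context on the use_gather flag toggled above: with a gather, only the destination rank ends up holding the full logits, while an all-gather leaves a full copy on every rank; the changed line makes V1, like TPU, skip the gather path. A hedged sketch of the two collectives using torch.distributed (shapes and the concat dimension are illustrative, not vLLM's actual tensor-parallel logits code):

# Hedged sketch: gather vs. all-gather of per-rank logits shards.
# Assumes torch.distributed is already initialized; illustrative only.
from typing import Optional

import torch
import torch.distributed as dist


def combine_logits(local_logits: torch.Tensor, use_gather: bool) -> Optional[torch.Tensor]:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if use_gather:
        # gather: only dst=0 receives all shards; other ranks return None.
        shards = ([torch.empty_like(local_logits) for _ in range(world_size)]
                  if rank == 0 else None)
        dist.gather(local_logits, gather_list=shards, dst=0)
        return torch.cat(shards, dim=-1) if shards is not None else None
    # all-gather: every rank receives all shards.
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits)
    return torch.cat(shards, dim=-1)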
4 changes: 3 additions & 1 deletion vllm/v1/core/scheduler.py
@@ -5,6 +5,8 @@

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -383,7 +385,7 @@ def update_from_output(
model_runner_output: "ModelRunnerOutput",
) -> List[EngineCoreOutput]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
sampled_token_ids = model_runner_output.sampled_token_ids_cpu
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
engine_core_outputs: List[EngineCoreOutput] = []
18 changes: 14 additions & 4 deletions vllm/v1/engine/async_llm.py
@@ -20,7 +20,7 @@
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.executor.abstract import Executor

logger = init_logger(__name__)

@@ -30,7 +30,7 @@ class AsyncLLM(EngineClient):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
executor_class: Type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@@ -114,14 +114,24 @@ def from_engine_args(
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""

self.engine_core.shutdown()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()

if handler := getattr(self, "output_handler", None):
handler.cancel()

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
return GPUExecutor
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
else:
assert (distributed_executor_backend is None)
from vllm.v1.executor.uniproc_executor import UniprocExecutor
executor_class = UniprocExecutor
return executor_class

async def add_request(
self,
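To see how the new _get_executor_cls selection is exercised end to end, a hedged usage sketch (the engine arguments below exist in vLLM, but treat the exact invocation and defaults as assumptions rather than documented usage from this PR):

# Hedged usage sketch: running V1 tensor parallelism with the "mp" backend.
import os

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vLLM

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",          # small model used elsewhere in these tests
    tensor_parallel_size=2,             # two tensor-parallel workers
    distributed_executor_backend="mp",  # V1 then picks MultiprocExecutor
)
print(llm.generate("Hello, my name is"))

When distributed_executor_backend is None, _get_executor_cls above instead returns UniprocExecutor, the single-process path.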