Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Multiprocessing Tensor Parallel Support for v1 #9856

Merged
merged 68 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
5ad9c60
initial v1 tp support
tlrmchlsmth Nov 22, 2024
49869fa
V1 TP with zmq-based boostrapping
tlrmchlsmth Nov 22, 2024
71e08aa
improve check for USE_SCHED_YIELD
tlrmchlsmth Nov 22, 2024
4930246
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 22, 2024
3ea0cae
fixup
tlrmchlsmth Nov 22, 2024
d4b55ae
workers must be daemonic
tlrmchlsmth Nov 22, 2024
feeed73
We can now terminate properly
tlrmchlsmth Nov 22, 2024
e3c9c5c
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 25, 2024
254714d
fixes from merge
tlrmchlsmth Nov 25, 2024
10a627e
Fixup termination
tlrmchlsmth Nov 25, 2024
d95c01e
Appease mypy
tlrmchlsmth Nov 25, 2024
c08bae4
Allow shm_broadcast to enqueue by eithr pickle or msgpack
tlrmchlsmth Nov 25, 2024
2392755
Switch back to pickle for shm_broadcast serialization
tlrmchlsmth Nov 26, 2024
bf3705c
Finish msgpack -> pickle
tlrmchlsmth Nov 26, 2024
d4ea706
wrap sched_yield and time.sleep in a fn
tlrmchlsmth Nov 26, 2024
2174a5b
Review comments
tlrmchlsmth Nov 26, 2024
25270ab
Rename executors to uniproc and multiproc
tlrmchlsmth Nov 26, 2024
9322db5
more review comments
tlrmchlsmth Nov 26, 2024
b5bac31
format
tlrmchlsmth Nov 26, 2024
c4fcfce
hacky hacky hacky cleanup
tlrmchlsmth Nov 26, 2024
bedd593
Fix spawn vs fork issue using approach from #8823
tlrmchlsmth Nov 26, 2024
c03ef6d
skip non-distributed tests in test_basic_correctness to see what happens
tlrmchlsmth Nov 26, 2024
8d9d557
fix async_llm
tlrmchlsmth Nov 27, 2024
5f3a570
format
tlrmchlsmth Nov 27, 2024
b59babc
Fixes for testing
tlrmchlsmth Nov 27, 2024
66116c7
Abstract executor class for typing
tlrmchlsmth Nov 27, 2024
eaeebc3
remove enforce_eager, format
tlrmchlsmth Nov 27, 2024
6d53d6e
remove stop_remote_worker_execution_loop
tlrmchlsmth Nov 27, 2024
a7025fb
Remove profiling
tlrmchlsmth Nov 27, 2024
6a3f2da
ExecutorMsg -> WorkerExecRequest
tlrmchlsmth Nov 27, 2024
d4e3813
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Nov 27, 2024
52ef894
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 2, 2024
9f9883e
ensure_termination
tlrmchlsmth Dec 2, 2024
1990433
Move ensure_termination to executor to avoid futures
tlrmchlsmth Dec 2, 2024
f8a1b9b
minor updates
tlrmchlsmth Dec 2, 2024
963c97f
call destroy_distributed_environment atexit
tlrmchlsmth Dec 2, 2024
0678911
more graceful shutdown
tlrmchlsmth Dec 3, 2024
3d71b53
Simplify worker termination
tlrmchlsmth Dec 3, 2024
88c9c7b
atexit -> weakref.finalize
tlrmchlsmth Dec 3, 2024
ab7cb89
minor cleanup
tlrmchlsmth Dec 3, 2024
024bcad
core client cleanup rework
tlrmchlsmth Dec 3, 2024
d77bab5
poke CI
tlrmchlsmth Dec 3, 2024
24ffb8a
fix V1 test, temporarily delete some noisy log statements
tlrmchlsmth Dec 3, 2024
be4260f
nccl/issues/1234
tlrmchlsmth Dec 4, 2024
cb4b363
Cleanup, _add_prefix
tlrmchlsmth Dec 4, 2024
365ea06
fixup noise a bit
tlrmchlsmth Dec 4, 2024
c94e11b
tweaks
tlrmchlsmth Dec 5, 2024
2a36db7
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 5, 2024
536e5f2
back to atexit
tlrmchlsmth Dec 5, 2024
998eb1d
Clean up process termination
tlrmchlsmth Dec 5, 2024
ebb2544
robcomments
tlrmchlsmth Dec 5, 2024
c81b7f5
format
tlrmchlsmth Dec 5, 2024
0817336
client now kills workers directly to avoid zombies
tlrmchlsmth Dec 6, 2024
f10e5e8
remote rpc
tlrmchlsmth Dec 6, 2024
e49b071
use WorkerWrapperBase
tlrmchlsmth Dec 6, 2024
661278f
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 6, 2024
fce9696
de-duplicate env setup
tlrmchlsmth Dec 6, 2024
c61a3e0
Use collective_rpc for initialization
tlrmchlsmth Dec 7, 2024
8bb2430
add RPCParams for readability
tlrmchlsmth Dec 7, 2024
5271ec6
fixup
tlrmchlsmth Dec 7, 2024
50a12bc
Merge branch 'main' into tms/v1_tp: instance id
tlrmchlsmth Dec 7, 2024
edab869
review comments
tlrmchlsmth Dec 9, 2024
e0aea84
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 9, 2024
ce08cb2
profile
tlrmchlsmth Dec 9, 2024
65b79c4
move vllm envs import to work with run_with_both_engines
tlrmchlsmth Dec 9, 2024
143ed09
Merge branch 'main' into tms/v1_tp
tlrmchlsmth Dec 9, 2024
ab6bf27
review comments.
tlrmchlsmth Dec 9, 2024
819b229
collective rpc function signature sanity
tlrmchlsmth Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
Expand All @@ -36,6 +44,7 @@ def test_vllm_gc_ed():
assert weak_llm() is None


@pytest.mark.skip_v1
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("dtype", ["half"])
Expand Down Expand Up @@ -118,6 +127,11 @@ def test_models_distributed(
if attention_backend:
os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend

# Import VLLM_USE_V1 dynamically to handle patching
from vllm.envs import VLLM_USE_V1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this dynamic patching mean? envs.VLLM_USE_V1 should read the latest env var value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the import here to get the VLLM_USE_V1 check using when we are using the run_with_both_engines pytest fixture during testing.

vllm/tests/conftest.py

Lines 112 to 126 in bf0e382

@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")
if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield

Please LMK if you have a better idea!

if VLLM_USE_V1 and distributed_executor_backend != "mp":
pytest.skip(f"Skip {distributed_executor_backend} for V1")

dtype = "half"
max_tokens = 5

Expand All @@ -143,6 +157,7 @@ def test_models_distributed(
)


@pytest.mark.skip_v1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this skipped?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test fails on V1 but I don't know why. It's not related to this PR as it's not running TP and fails on current main (just enabled it on #10864)

def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
Expand All @@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
os.remove(filename)


@pytest.mark.skip_v1
def test_failure_with_async_out_proc(vllm_runner) -> None:

filename = None
Expand Down
76 changes: 52 additions & 24 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import pickle
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from typing import List, Optional, Tuple
from unittest.mock import patch

import torch
Expand All @@ -21,6 +22,20 @@

logger = init_logger(__name__)

# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
or (sys.version_info[:2] == (3, 10)
and sys.version_info[2] >= 8))


def sched_yield():
if USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(0)


class ShmRingBuffer:

Expand Down Expand Up @@ -114,11 +129,14 @@ def __init__(self,
# and we should suppress the error
pass

def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)

def __reduce__(self):
return (
self.__class__,
(self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name),
self.handle(),
)

def __del__(self):
Expand Down Expand Up @@ -147,7 +165,7 @@ class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)

buffer: Optional[ShmRingBuffer] = None
buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None

Expand Down Expand Up @@ -228,7 +246,7 @@ def __init__(
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
buffer_handle=self.buffer.handle(),
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
Expand All @@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
context = Context()

if rank in handle.local_reader_ranks:
assert handle.buffer is not None
self.buffer = handle.buffer
assert handle.buffer_handle is not None
self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
Expand Down Expand Up @@ -314,7 +332,7 @@ def wait_until_ready(self):
assert recv == b"READY"

@contextmanager
def acquire_write(self):
def acquire_write(self, timeout: Optional[float] = None):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
Expand All @@ -329,16 +347,20 @@ def acquire_write(self):
# we need to wait until it is read by all readers

# Release the processor to other threads
os.sched_yield()
sched_yield()

# if we wait for a long time, we should warn the user
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1

# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
raise TimeoutError

continue
# found a block that is either
# (1) not written
Expand All @@ -365,7 +387,7 @@ def acquire_write(self):
break

@contextmanager
def acquire_read(self):
def acquire_read(self, timeout: Optional[float] = None):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
Expand All @@ -383,16 +405,20 @@ def acquire_read(self):
# we need to wait until it is written

# Release the processor to other threads
os.sched_yield()
sched_yield()

# if we wait for a long time, we should warn the user
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1

# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
raise TimeoutError

continue
# found a block that is not read by this reader
# let caller read from the buffer
Expand All @@ -406,24 +432,26 @@ def acquire_read(self):
1) % self.buffer.max_chunks
break

def enqueue(self, obj):
def enqueue(self, obj, timeout: Optional[float] = None):
""" Write to message queue with optional timeout (in seconds) """
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
with self.acquire_write(timeout) as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)

def dequeue(self):
def dequeue(self, timeout: Optional[float] = None):
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read() as buf:
with self.acquire_read(timeout) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
Expand Down
47 changes: 6 additions & 41 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,19 @@
from functools import partial
from typing import Any, List, Optional

import torch

from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.executor.multiproc_worker_utils import (
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, make_async,
get_distributed_init_method, get_open_port, make_async,
update_environment_variables)

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager

logger = init_logger(__name__)


Expand All @@ -37,30 +31,8 @@ def _init_executor(self) -> None:
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)

# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and world_size > 1:
maybe_set_triton_cache_manager()
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)

# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
Expand Down Expand Up @@ -122,13 +94,6 @@ def _check_executor_parameters(self):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
Expand Down
42 changes: 42 additions & 0 deletions vllm/executor/multiproc_worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,15 @@
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import cuda_is_initialized

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager

logger = init_logger(__name__)

Expand Down Expand Up @@ -270,3 +277,38 @@ def write_with_prefix(s: str):
def get_mp_context():
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)


def set_multiprocessing_worker_envs(parallel_config):
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""

if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)

# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and parallel_config.world_size > 1:
maybe_set_triton_cache_manager()
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(self,
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu()
# note: we import VLLM_USE_V1 dynamically to handle patching
from vllm.envs import VLLM_USE_V1
tlrmchlsmth marked this conversation as resolved.
Show resolved Hide resolved
self.use_gather = not current_platform.is_tpu() and not VLLM_USE_V1

def forward(
self,
Expand Down
Loading
Loading