Skip to content

Commit

Permalink
[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth authored Dec 10, 2024
1 parent bc192a2 commit 28b3a1c
Show file tree
Hide file tree
Showing 21 changed files with 733 additions and 146 deletions.
16 changes: 16 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse shim so that each test in this module runs with both engines.

    All of the work is done by requesting ``run_with_both_engines``; this
    fixture has no body of its own. It can be promoted to conftest.py to
    apply to every test in a package.
    """


def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
Expand All @@ -36,6 +44,7 @@ def test_vllm_gc_ed():
assert weak_llm() is None


@pytest.mark.skip_v1
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("dtype", ["half"])
Expand Down Expand Up @@ -118,6 +127,11 @@ def test_models_distributed(
if attention_backend:
os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend

# Import VLLM_USE_V1 dynamically to handle patching
from vllm.envs import VLLM_USE_V1
if VLLM_USE_V1 and distributed_executor_backend != "mp":
pytest.skip(f"Skip {distributed_executor_backend} for V1")

dtype = "half"
max_tokens = 5

Expand All @@ -143,6 +157,7 @@ def test_models_distributed(
)


@pytest.mark.skip_v1
def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
Expand All @@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
os.remove(filename)


@pytest.mark.skip_v1
def test_failure_with_async_out_proc(vllm_runner) -> None:

filename = None
Expand Down
11 changes: 5 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
TypedDict, TypeVar, Union)
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -110,7 +109,7 @@ def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:


@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
def run_with_both_engines(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
Expand All @@ -119,11 +118,11 @@ def run_with_both_engines(request):
if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
monkeypatch.setenv('VLLM_USE_V1', '0')

yield


@pytest.fixture(autouse=True)
Expand Down
76 changes: 52 additions & 24 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import pickle
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from typing import List, Optional, Tuple
from unittest.mock import patch

import torch
Expand All @@ -21,6 +22,20 @@

logger = init_logger(__name__)

# We prefer os.sched_yield() because it results in tighter polling loops,
# measured to be around 3e-7 seconds. Two conditions must hold to use it:
#   * the platform actually provides os.sched_yield (it is only available on
#     some Unix platforms);
#   * the interpreter releases the GIL inside os.sched_yield(), which earlier
#     versions of Python do not do. The check below accepts 3.10.8+ on the
#     3.10 branch and 3.11.1 onwards.
# When either condition fails we fall back to time.sleep(0).
USE_SCHED_YIELD = hasattr(os, "sched_yield") and (
    sys.version_info[:3] >= (3, 11, 1)
    or (3, 10, 8) <= sys.version_info[:3] < (3, 11))


def sched_yield():
    """Release the processor so that other threads can run.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true and
    ``time.sleep(0)`` otherwise.
    """
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)


class ShmRingBuffer:

Expand Down Expand Up @@ -114,11 +129,14 @@ def __init__(self,
# and we should suppress the error
pass

def handle(self):
    """Return the constructor arguments that describe this ring buffer.

    The result is the tuple ``(n_reader, max_chunk_bytes, max_chunks,
    name)``, where ``name`` is taken from ``self.shared_memory.name``.
    """
    geometry = (self.n_reader, self.max_chunk_bytes, self.max_chunks)
    return (*geometry, self.shared_memory.name)

def __reduce__(self):
return (
self.__class__,
(self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name),
self.handle(),
)

def __del__(self):
Expand Down Expand Up @@ -147,7 +165,7 @@ class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)

buffer: Optional[ShmRingBuffer] = None
buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None

Expand Down Expand Up @@ -228,7 +246,7 @@ def __init__(
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
buffer_handle=self.buffer.handle(),
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
Expand All @@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
context = Context()

if rank in handle.local_reader_ranks:
assert handle.buffer is not None
self.buffer = handle.buffer
assert handle.buffer_handle is not None
self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
Expand Down Expand Up @@ -314,7 +332,7 @@ def wait_until_ready(self):
assert recv == b"READY"

@contextmanager
def acquire_write(self):
def acquire_write(self, timeout: Optional[float] = None):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
Expand All @@ -329,16 +347,20 @@ def acquire_write(self):
# we need to wait until it is read by all readers

# Release the processor to other threads
os.sched_yield()
sched_yield()

# if we wait for a long time, we should warn the user
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1

# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
raise TimeoutError

continue
# found a block that is either
# (1) not written
Expand All @@ -365,7 +387,7 @@ def acquire_write(self):
break

@contextmanager
def acquire_read(self):
def acquire_read(self, timeout: Optional[float] = None):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
Expand All @@ -383,16 +405,20 @@ def acquire_read(self):
# we need to wait until it is written

# Release the processor to other threads
os.sched_yield()
sched_yield()

# if we wait for a long time, we should warn the user
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1

# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
raise TimeoutError

continue
# found a block that is not read by this reader
# let caller read from the buffer
Expand All @@ -406,24 +432,26 @@ def acquire_read(self):
1) % self.buffer.max_chunks
break

def enqueue(self, obj):
def enqueue(self, obj, timeout: Optional[float] = None):
""" Write to message queue with optional timeout (in seconds) """
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
with self.acquire_write(timeout) as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)

def dequeue(self):
def dequeue(self, timeout: Optional[float] = None):
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read() as buf:
with self.acquire_read(timeout) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
Expand Down
47 changes: 6 additions & 41 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,19 @@
from functools import partial
from typing import Any, List, Optional

import torch

from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.executor.multiproc_worker_utils import (
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, make_async,
get_distributed_init_method, get_open_port, make_async,
update_environment_variables)

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager

logger = init_logger(__name__)


Expand All @@ -37,30 +31,8 @@ def _init_executor(self) -> None:
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)

# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and world_size > 1:
maybe_set_triton_cache_manager()
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)

# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
Expand Down Expand Up @@ -122,13 +94,6 @@ def _check_executor_parameters(self):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
Expand Down
42 changes: 42 additions & 0 deletions vllm/executor/multiproc_worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,15 @@
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import cuda_is_initialized

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager

logger = init_logger(__name__)

Expand Down Expand Up @@ -270,3 +277,38 @@ def write_with_prefix(s: str):
def get_mp_context():
    """Return the multiprocessing context for the configured start method.

    The start method is read from ``envs.VLLM_WORKER_MULTIPROC_METHOD``.
    """
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)


def set_multiprocessing_worker_envs(parallel_config):
    """Prepare process-wide settings for multiprocessing workers.

    Meant to be called by the parent process before the worker processes
    are created. It does three things:

    * switches ``VLLM_WORKER_MULTIPROC_METHOD`` to ``spawn`` when CUDA has
      already been initialized;
    * limits Torch to a single thread when ``OMP_NUM_THREADS`` is not set;
    * applies the Triton cache-manager workaround when Triton is available
      and ``parallel_config.world_size`` is greater than 1.
    """
    mp_method_var = "VLLM_WORKER_MULTIPROC_METHOD"
    if cuda_is_initialized() and os.environ.get(mp_method_var) != "spawn":
        logger.warning("CUDA was previously initialized. We must use "
                       "the `spawn` multiprocessing start method. Setting "
                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
        os.environ[mp_method_var] = "spawn"

    # Cap thread parallelism unless the user set OMP_NUM_THREADS themselves.
    #
    # One thread per core multiplied by one process per GPU leads to CPU
    # contention and can hurt performance; this is amplified inside
    # containers, where CPU limits can cause throttling.
    target_threads = 1
    if "OMP_NUM_THREADS" not in os.environ:
        torch_threads = torch.get_num_threads()
        if torch_threads > target_threads:
            logger.warning(
                "Reducing Torch parallelism from %d threads to %d to avoid "
                "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
                "external environment to tune this value as needed.",
                torch_threads, target_threads)
            os.environ["OMP_NUM_THREADS"] = str(target_threads)
            torch.set_num_threads(target_threads)

    # workaround for https://github.com/vllm-project/vllm/issues/6103
    if HAS_TRITON and parallel_config.world_size > 1:
        maybe_set_triton_cache_manager()
Loading

0 comments on commit 28b3a1c

Please sign in to comment.