[core] LLM.collective_rpc interface and RLHF example (vllm-project#12084)

Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 16, 2025
1 parent bf53e0c commit 92e793d
Showing 6 changed files with 270 additions and 35 deletions.
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -126,11 +126,15 @@ steps:
- tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
- examples/offline_inference/rlhf.py
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py

- label: Metrics, Tracing Test # 10min
num_gpus: 2
191 changes: 191 additions & 0 deletions examples/offline_inference/rlhf.py
@@ -0,0 +1,191 @@
"""
a simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design that, training processes and inference processes
are different, and they live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams, configure_as_vllm_process
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker


def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


class MyWorker(Worker):
"""
The `MyWorker` class inherits from `Worker` to provide custom functions.
For simplicity, we define the `MyWorker` class in this self-contained
script. Normally, we should define the `MyWorker` class in a separate
file and pass the qualified name of the class to the `worker_cls`
parameter.
"""

def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)

def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())

self.model_runner.model.load_weights(weights=[(name, weight)])

del weight

def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated


class MyLLM(LLM):

def __init__(self, *args, **kwargs):
        # a hack to make the script work:
        # stop Ray from manipulating CUDA_VISIBLE_DEVICES
        # at the top level
del os.environ["CUDA_VISIBLE_DEVICES"]
super().__init__(*args, **kwargs)


"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.
It is important for all the processes outside of vLLM to call
`configure_as_vllm_process` to set some common environment variables
the same as vLLM workers.
"""
configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
"""
Start the inference process, here we use vLLM to hold a model on GPU 1 and
GPU 2. For the details on how to use ray, please refer to the ray
documentation https://docs.ray.io/en/latest/ .
"""
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
"""
launch the vLLM inference engine.
here we use `enforce_eager` to reduce the start time.
"""
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
tensor_parallel_size=2,
distributed_executor_backend="ray",
)

# Generate texts from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")

# set up the communication between the training process
# and the inference engine.
master_address = get_ip()
master_port = get_open_port()

handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
ray.get(handle)

# simulate training, modify the weights of the model.
for name, p in train_model.named_parameters():
p.data.zero_()

# sync the weights from the training process to the inference engine.
for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight",
args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)

# check if the weights are updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# use the updated model to generate texts; they will be nonsense
# because the weights are all zeros.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
39 changes: 39 additions & 0 deletions vllm/__init__.py
@@ -17,6 +17,44 @@

from .version import __version__, __version_tuple__


def configure_as_vllm_process():
"""
set some common config/environment variables that should be set
for all processes created by vllm and all processes
that interact with vllm workers.
"""
import os

import torch

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

from vllm.platforms import current_platform

if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
elif current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'


__all__ = [
"__version__",
"__version_tuple__",
@@ -42,4 +80,5 @@
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
"configure_as_vllm_process",
]
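
As a usage note: `configure_as_vllm_process` is meant to be called near the start of any external process (such as a training script) that will exchange data with vLLM workers, before heavy CUDA or NCCL work, so that the common settings match those of the workers. A minimal sketch mirroring the call pattern of the RLHF example above; the model name and device are placeholders, not prescribed by this commit.

from transformers import AutoModelForCausalLM

from vllm import configure_as_vllm_process

# Align the common environment/config settings with the vLLM workers
# before doing any CUDA or NCCL work in this (training) process.
configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
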
25 changes: 25 additions & 0 deletions vllm/entrypoints/llm.py
@@ -4,6 +4,7 @@
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)

import cloudpickle
from tqdm import tqdm
from typing_extensions import deprecated

@@ -186,6 +187,13 @@ def __init__(
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

if "worker_cls" in kwargs:
worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not a qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
@@ -455,6 +463,23 @@ def generate(
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)

def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Run a method on all workers, with homogeneous arguments.
The main extension point for the LLM entrypoint.
Users can provide custom worker class through `worker_cls`
argument, and implement new methods in the worker class.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)

def beam_search(
self,
prompts: List[Union[TokensPrompt, TextPrompt]],
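To illustrate how the `collective_rpc` entrypoint added above composes with a custom `worker_cls`, here is a minimal sketch that is not taken from the commit: a hypothetical `EchoWorker` adds a method, the class object is passed directly to `LLM` (relying on the cloudpickle path added in `__init__`), and `collective_rpc` invokes the method on every worker. The model name, worker class, and method are placeholders.

from vllm import LLM
from vllm.worker.worker import Worker


class EchoWorker(Worker):
    # hypothetical worker extension; this method runs inside every worker
    def echo(self, message: str) -> str:
        return f"worker received: {message}"


# Passing the class object (rather than a qualified string name) relies on
# the cloudpickle serialization added to LLM.__init__ above.
llm = LLM(model="facebook/opt-125m",
          worker_cls=EchoWorker,
          enforce_eager=True)

# One return value per worker; use this only for small control messages and
# move bulk data (e.g. weights) over a separate data-plane channel.
print(llm.collective_rpc("echo", args=("hello",)))
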
31 changes: 0 additions & 31 deletions vllm/plugins/__init__.py
@@ -1,9 +1,6 @@
import logging
import os
from typing import Callable, Dict

import torch

import vllm.envs as envs

logger = logging.getLogger(__name__)
@@ -50,34 +47,6 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
"""

# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

from vllm.platforms import current_platform

if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
if current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'

global plugins_loaded
if plugins_loaded:
return
15 changes: 11 additions & 4 deletions vllm/worker/worker_base.py
@@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import cloudpickle
import torch

from vllm.config import ObservabilityConfig, VllmConfig
@@ -521,14 +522,20 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
kwargs = all_kwargs[self.rpc_rank]
enable_trace_function_call_for_thread(self.vllm_config)

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
from vllm import configure_as_vllm_process
configure_as_vllm_process()

from vllm.plugins import load_general_plugins
load_general_plugins()

worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(**kwargs)
assert self.worker is not None

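The two branches above mirror the serialization done in `LLM.__init__`: a string stays a qualified name and is resolved with `resolve_obj_by_qualname`, while a class object arrives as cloudpickle bytes. A tiny sketch of that round-trip follows; the class here is a stand-in, not part of the commit.

import cloudpickle


class MyCustomWorker:  # stand-in for a user-defined Worker subclass
    pass


# llm.py side: the class object is pickled before being stored in
# parallel_config.worker_cls.
payload = cloudpickle.dumps(MyCustomWorker)
assert isinstance(payload, bytes)

# worker_base.py side: the bytes are unpickled back into a class, which is
# then instantiated with the per-rank kwargs.
worker_class = cloudpickle.loads(payload)
print(worker_class.__name__)  # MyCustomWorker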
