From 92e793d91a1a4e982662ecca0096e5edcafd21c6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 16 Jan 2025 20:19:52 +0800 Subject: [PATCH] [core] LLM.collective_rpc interface and RLHF example (#12084) Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 4 + examples/offline_inference/rlhf.py | 191 +++++++++++++++++++++++++++++ vllm/__init__.py | 39 ++++++ vllm/entrypoints/llm.py | 25 ++++ vllm/plugins/__init__.py | 31 ----- vllm/worker/worker_base.py | 15 ++- 6 files changed, 270 insertions(+), 35 deletions(-) create mode 100644 examples/offline_inference/rlhf.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 00fed96c1ac8c..7442de245bd80 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -126,11 +126,15 @@ steps: - tests/distributed - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile + - examples/offline_inference/rlhf.py commands: - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py + # TODO: create a dedicated test section for multi-GPU example tests + # when we have multiple distributed example tests + - python3 ../examples/offline_inference/rlhf.py - label: Metrics, Tracing Test # 10min num_gpus: 2 diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py new file mode 100644 index 0000000000000..3bc303dad277f --- /dev/null +++ b/examples/offline_inference/rlhf.py @@ -0,0 +1,191 @@ +""" +a simple demonstration of RLHF with vLLM, inspired by +the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF . +It follows the design that, training processes and inference processes +are different, and they live on different GPUs. +Training processes send prompts to inference processes to generate data, +and also synchronize the weights of the model by broadcasting the weights +from the training process to the inference process. +Note that this is a simple demonstration of one training instance and one +inference instance. In practice, there could be multiple training instances +and multiple inference instances. For the full implementation, please refer +to the OpenRLHF framework. +""" +import os + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM + +from vllm import LLM, SamplingParams, configure_as_vllm_process +from vllm.utils import get_ip, get_open_port +from vllm.worker.worker import Worker + + +def stateless_init_process_group(master_address, master_port, rank, world_size, + device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + pg = StatelessProcessGroup.create(host=master_address, + port=master_port, + rank=rank, + world_size=world_size) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl + + +class MyWorker(Worker): + """ + The `MyWorker` class inherits from `Worker` to provide custom functions. 
+ For simplicity, we define the `MyWorker` class in this self-contained + script. Normally, we should define the `MyWorker` class in a separate + file and pass the qualified name of the class to the `worker_cls` + parameter. + """ + + def init_weight_update_group(self, master_address, master_port, + rank_offset, world_size): + from vllm.distributed.parallel_state import get_world_group + rank = get_world_group().rank + rank_offset + self.model_update_group = stateless_init_process_group( + master_address, + master_port, + rank, + world_size, + self.device, + ) + + def update_weight(self, name, dtype, shape): + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast(weight, + src=0, + stream=torch.cuda.current_stream()) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated + + +class MyLLM(LLM): + + def __init__(self, *args, **kwargs): + # a hack to make the script work. + # stop ray from manipulating CUDA_VISIBLE_DEVICES + # at the top-level + del os.environ["CUDA_VISIBLE_DEVICES"] + super().__init__(*args, **kwargs) + + +""" +Start the training process, here we use huggingface transformers +as an example to hold a model on GPU 0. + +It is important for all the processes outside of vLLM to call +`configure_as_vllm_process` to set some common environment variables +the same as vLLM workers. +""" +configure_as_vllm_process() + +train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") +train_model.to("cuda:0") +""" +Start the inference process, here we use vLLM to hold a model on GPU 1 and +GPU 2. For the details on how to use ray, please refer to the ray +documentation https://docs.ray.io/en/latest/ . +""" +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" +ray.init() + +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, +) +""" +launch the vLLM inference engine. +here we use `enforce_eager` to reduce the start time. +""" +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, +)(MyLLM).remote( + model="facebook/opt-125m", + enforce_eager=True, + worker_cls=MyWorker, + tensor_parallel_size=2, + distributed_executor_backend="ray", +) + +# Generate texts from the prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") + +# set up the communication between the training process +# and the inference engine. 
+master_address = get_ip() +master_port = get_open_port() + +handle = llm.collective_rpc.remote("init_weight_update_group", + args=(master_address, master_port, 1, 3)) +model_update_group = stateless_init_process_group(master_address, master_port, + 0, 3, torch.device("cuda:0")) +ray.get(handle) + +# simulate training, modify the weights of the model. +for name, p in train_model.named_parameters(): + p.data.zero_() + +# sync weight from the training process to the inference engine. +for name, p in train_model.named_parameters(): + handle = llm.collective_rpc.remote("update_weight", + args=(name, p.dtype, p.shape)) + model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) + ray.get(handle) + +# check if the weights are updated. +assert all(ray.get(llm.collective_rpc.remote("check_weights_changed"))) + +# use the updated model to generate texts, they will be nonsense +# because the weights are all zeros. +outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) +for output in outputs_updated: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") diff --git a/vllm/__init__.py b/vllm/__init__.py index 45252b93e3d54..a533dba561c00 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -17,6 +17,44 @@ from .version import __version__, __version_tuple__ + +def configure_as_vllm_process(): + """ + set some common config/environment variables that should be set + for all processes created by vllm and all processes + that interact with vllm workers. + """ + import os + + import torch + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + # see https://github.com/vllm-project/vllm/issues/10480 + os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' + # see https://github.com/vllm-project/vllm/issues/10619 + torch._inductor.config.compile_threads = 1 + + from vllm.platforms import current_platform + + if current_platform.is_xpu(): + # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 + torch._dynamo.config.disable = True + elif current_platform.is_hpu(): + # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1) + # does not support torch.compile + # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for + # torch.compile support + is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1' + if is_lazy: + torch._dynamo.config.disable = True + # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only) + # requires enabling lazy collectives + # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 + os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' + + __all__ = [ "__version__", "__version_tuple__", @@ -42,4 +80,5 @@ "AsyncEngineArgs", "initialize_ray_cluster", "PoolingParams", + "configure_as_vllm_process", ] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index acb4db85632a8..b78d5c65a40f8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -4,6 +4,7 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union, cast, overload) +import cloudpickle from tqdm import tqdm from typing_extensions import deprecated @@ -186,6 +187,13 @@ def __init__( if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling 
issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + if compilation_config is not None: if isinstance(compilation_config, (int, dict)): compilation_config_instance = CompilationConfig.from_cli( @@ -455,6 +463,23 @@ def generate( outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, RequestOutput) + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + """ + Run a method on all workers, with homogeneous arguments. + The main extension point for the LLM entrypoint. + Users can provide custom worker class through `worker_cls` + argument, and implement new methods in the worker class. + Then, users can call the new methods through this API. + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + return self.llm_engine.model_executor.collective_rpc( + method, timeout, args, kwargs) + def beam_search( self, prompts: List[Union[TokensPrompt, TextPrompt]], diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index e5fa4f0e4a2f6..ff54174f634af 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,9 +1,6 @@ import logging -import os from typing import Callable, Dict -import torch - import vllm.envs as envs logger = logging.getLogger(__name__) @@ -50,34 +47,6 @@ def load_general_plugins(): processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ - - # all processes created by vllm will load plugins, - # and here we can inject some common environment variables - # for all processes. - - # see https://github.com/vllm-project/vllm/issues/10480 - os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' - # see https://github.com/vllm-project/vllm/issues/10619 - torch._inductor.config.compile_threads = 1 - - from vllm.platforms import current_platform - - if current_platform.is_xpu(): - # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 - torch._dynamo.config.disable = True - if current_platform.is_hpu(): - # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1) - # does not support torch.compile - # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for - # torch.compile support - is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1' - if is_lazy: - torch._dynamo.config.disable = True - # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only) - # requires enabling lazy collectives - # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 - os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' - global plugins_loaded if plugins_loaded: return diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d464b614b12f1..bced5b9f44228 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +import cloudpickle import torch from vllm.config import ObservabilityConfig, VllmConfig @@ -521,14 +522,20 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: kwargs = all_kwargs[self.rpc_rank] enable_trace_function_call_for_thread(self.vllm_config) - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ['NCCL_CUMEM_ENABLE'] = '0' + from vllm import configure_as_vllm_process + configure_as_vllm_process() from vllm.plugins 
import load_general_plugins load_general_plugins() - worker_class = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_cls) + if isinstance(self.vllm_config.parallel_config.worker_cls, str): + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls) + else: + assert isinstance(self.vllm_config.parallel_config.worker_cls, + bytes) + worker_class = cloudpickle.loads( + self.vllm_config.parallel_config.worker_cls) self.worker = worker_class(**kwargs) assert self.worker is not None
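
A minimal usage sketch of the API this patch introduces (not part of the patch itself): `LLM.__init__` now accepts a `worker_cls` that may be a class object (cloudpickled automatically) and `LLM.collective_rpc(method, timeout, args, kwargs)` fans a method call out to every worker and returns one result per worker. The `EchoWorker` class and its `report_device` method below are hypothetical names used only for illustration; the sketch assumes vLLM at this commit, two GPUs, and Ray installed, mirroring the configuration used in examples/offline_inference/rlhf.py. It is a sketch under those assumptions, not a definitive recipe.

from vllm import LLM, configure_as_vllm_process
from vllm.worker.worker import Worker


class EchoWorker(Worker):
    """Hypothetical custom worker exposing an extra method for collective_rpc."""

    def report_device(self, tag: str) -> str:
        from vllm.distributed.parallel_state import get_world_group
        # Runs on every worker; the per-worker return values are gathered
        # into the list returned by LLM.collective_rpc on the driver side.
        return f"{tag}: rank {get_world_group().rank} uses {self.device}"


if __name__ == "__main__":
    # Processes that interact with vLLM workers should align their environment
    # (e.g. NCCL_CUMEM_ENABLE) with the workers, per this patch.
    configure_as_vllm_process()

    llm = LLM(
        model="facebook/opt-125m",
        enforce_eager=True,
        worker_cls=EchoWorker,       # a class object; LLM.__init__ cloudpickles it
        tensor_parallel_size=2,      # two workers so the per-rank fan-out is visible
        distributed_executor_backend="ray",  # executor used by the patch's own example
    )

    # Control-plane call broadcast to all workers; returns a list with one
    # entry per worker. Heavy data should go over a separate data plane
    # (e.g. the StatelessProcessGroup/PyNcclCommunicator shown in rlhf.py).
    print(llm.collective_rpc("report_device", args=("hello",)))

As in the rlhf.py example, passing the class object is convenient for a self-contained script; in library code the qualified class name string remains the more robust choice, since WorkerWrapperBase.init_worker resolves either form.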