Skip to content

Commit

Permalink
[2/N] executor pass the complete config to worker/modelrunner (#9938)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
  • Loading branch information
youkaichao and njhill authored Nov 2, 2024
1 parent 1d4cfe2 commit e893795
Show file tree
Hide file tree
Showing 44 changed files with 250 additions and 580 deletions.
8 changes: 1 addition & 7 deletions tests/lora/test_long_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init):
enable_lora=True)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
vllm_config=engine_config,
is_driver_worker=True,
)
model_runner.load_model()
Expand Down
12 changes: 8 additions & 4 deletions tests/lora/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from unittest.mock import patch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
worker = Worker(
vllm_config = VllmConfig(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
task="auto",
Expand All @@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files):
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32),
)
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_device()
Expand Down
7 changes: 1 addition & 6 deletions tests/spec_decode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T],
get_ip(), get_open_port())

worker = cls(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
Expand Down
9 changes: 1 addition & 8 deletions tests/worker/test_encoder_decoder_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args,
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = EncoderDecoderModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
vllm_config=engine_config,
is_driver_worker=True,
)
return model_runner
Expand Down
10 changes: 1 addition & 9 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
observability_config=engine_config.observability_config,
vllm_config=engine_config,
is_driver_worker=True,
)
return model_runner
Expand Down
7 changes: 1 addition & 6 deletions tests/worker/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@ def test_gpu_memory_profiling():
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
Expand Down
7 changes: 1 addition & 6 deletions tests/worker/test_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@ def test_swap() -> None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
Expand Down
24 changes: 9 additions & 15 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
import json
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)

Expand Down Expand Up @@ -1941,9 +1941,9 @@ def __post_init__(self):
f"installed. Original error:\n{otel_import_error_traceback}")


@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
@dataclass
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""

Expand All @@ -1953,11 +1953,11 @@ class EngineConfig:
scheduler_config: SchedulerConfig
device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None
decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None

def __post_init__(self):
"""Verify configs are valid & consistent with each other.
Expand All @@ -1975,9 +1975,3 @@ def __post_init__(self):
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)

def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
"""
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
13 changes: 7 additions & 6 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
Expand Down Expand Up @@ -955,7 +956,7 @@ def create_load_config(self) -> LoadConfig:
ignore_patterns=self.ignore_patterns,
)

def create_engine_config(self) -> EngineConfig:
def create_engine_config(self) -> VllmConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"
Expand Down Expand Up @@ -1167,7 +1168,7 @@ def create_engine_config(self) -> EngineConfig:
or "all" in detailed_trace_modules,
)

return EngineConfig(
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
Expand Down
8 changes: 4 additions & 4 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from weakref import ReferenceType

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
Expand Down Expand Up @@ -604,7 +604,7 @@ def __del__(self):

@classmethod
def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):
Expand Down Expand Up @@ -663,7 +663,7 @@ def _get_executor_cls(
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
Expand Down
9 changes: 5 additions & 4 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from typing_extensions import TypeIs, TypeVar

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
Expand Down Expand Up @@ -219,7 +220,7 @@ def validate_outputs(

def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
Expand Down Expand Up @@ -500,7 +501,7 @@ def _initialize_kv_caches(self) -> None:

@classmethod
def _get_executor_cls(cls,
engine_config: EngineConfig) -> Type[ExecutorBase]:
engine_config: VllmConfig) -> Type[ExecutorBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
Expand Down
4 changes: 2 additions & 2 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from zmq.asyncio import Socket

from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
Expand Down Expand Up @@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy
"""

def __init__(self, ipc_path: str, engine_config: EngineConfig,
def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
Expand Down
9 changes: 1 addition & 8 deletions vllm/executor/cpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,11 @@ def _create_worker(
assert self.distributed_init_method is not None

kwargs = dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0,
)
wrapper.init_worker(**kwargs)
Expand Down
4 changes: 2 additions & 2 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from vllm.config import EngineConfig
from vllm.config import VllmConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
Expand All @@ -20,7 +20,7 @@ class ExecutorBase(ABC):

def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
Expand Down
11 changes: 1 addition & 10 deletions vllm/executor/gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,12 @@ def _get_worker_kwargs(
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
)

def _get_worker_module_and_class(
Expand Down
6 changes: 1 addition & 5 deletions vllm/executor/neuron_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ def _init_worker(self):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method)
Expand Down
8 changes: 1 addition & 7 deletions vllm/executor/openvino_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,10 @@ def _init_worker(self):
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
Expand Down
7 changes: 1 addition & 6 deletions vllm/executor/tpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ def _get_worker_kwargs(
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
Expand Down
Loading

0 comments on commit e893795

Please sign in to comment.