Fix up torchrun; TP = 8 still a little wonky
mawong-amd committed Apr 10, 2024
1 parent 0ca1c60 commit 93ff699
Showing 3 changed files with 19 additions and 19 deletions.
vllm/engine/llm_engine.py (2 changes: 1 addition & 1 deletion)

@@ -221,7 +221,7 @@ def from_engine_args(
             initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutor
             executor_class = RayGPUExecutor
-        elif parallel_config.worker_use_torchrun:
+        elif engine_config.parallel_config.worker_use_torchrun:
             from vllm.executor.torchrun_gpu_executor import TorchrunGPUExecutor
             executor_class = TorchrunGPUExecutor
         else:
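The one-line change above is a scoping fix: inside from_engine_args only engine_config is in scope, so the torchrun branch has to reach the ParallelConfig through it, exactly as the Ray branch does; the bare parallel_config name looks like a leftover from an earlier version of this method. A minimal sketch of the selection logic after the fix (not the vLLM source; the helper name _select_executor_class and the plain GPUExecutor fallback are assumptions, since the diff cuts off after the else:):

def _select_executor_class(engine_config):
    # Read the ParallelConfig off the engine config; referencing a bare
    # `parallel_config` name here is what the commit corrects.
    parallel_config = engine_config.parallel_config
    if parallel_config.worker_use_ray:
        from vllm.executor.ray_gpu_executor import RayGPUExecutor
        return RayGPUExecutor
    elif parallel_config.worker_use_torchrun:
        from vllm.executor.torchrun_gpu_executor import TorchrunGPUExecutor
        return TorchrunGPUExecutor
    else:
        # Assumed fallback: single-process GPU execution.
        from vllm.executor.gpu_executor import GPUExecutor
        return GPUExecutor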
vllm/executor/torchrun_gpu_executor.py (9 changes: 6 additions & 3 deletions)

@@ -2,7 +2,8 @@
 from typing import Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
@@ -32,12 +33,13 @@ def __init__(
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig]
     ) -> None:
         self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
         self.is_driver_worker = self.local_rank == 0
         super().__init__(model_config, cache_config, parallel_config,
                          scheduler_config, device_config, lora_config,
-                         vision_language_config)
+                         vision_language_config, speculative_config)

     def _init_worker(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -54,11 +56,12 @@ def _init_worker(self):
             self.parallel_config,
             self.scheduler_config,
             self.device_config,
+            self.cache_config,
             local_rank=self.local_rank,
             rank=self.local_rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
             vision_language_config=self.vision_language_config,
             is_driver_worker=self.is_driver_worker,
         )
         self.driver_worker.init_device()
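Under torchrun, the executor identifies itself purely from the environment: LOCAL_RANK picks the device and local rank 0 acts as the driver worker, while the constructor now threads speculative_config through to the base GPUExecutor and hands the Worker a cache_config instead of a bare kv_cache_dtype, matching the updated signatures. The snippet below is illustrative only and not part of the commit (serve.py is a placeholder script name); it just prints the variables each process sees when launched with, e.g., torchrun --standalone --nproc_per_node=8 serve.py:

import os

# torchrun exports these for every process it spawns; TorchrunGPUExecutor's
# __init__ reads LOCAL_RANK and treats local rank 0 as the driver.
local_rank = int(os.getenv("LOCAL_RANK", "0"))   # index of this process on the node
rank = int(os.getenv("RANK", str(local_rank)))   # global rank across all nodes
world_size = int(os.getenv("WORLD_SIZE", "1"))   # total number of processes
is_driver_worker = local_rank == 0               # mirrors the check in __init__

print(f"rank={rank} local_rank={local_rank} "
      f"world_size={world_size} is_driver={is_driver_worker}")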
vllm/worker/worker.py (27 changes: 12 additions & 15 deletions)

@@ -268,8 +268,12 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
-    init_distributed_environment(parallel_config.world_size, rank,
-                                 distributed_init_method, local_rank)
+    if not parallel_config.worker_use_torchrun:
+        init_distributed_environment(parallel_config.world_size, rank,
+                                     distributed_init_method, local_rank)
+    else:
+        init_distributed_environment(parallel_config.world_size, -1,
+                                     "env://", local_rank)

     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
@@ -281,19 +285,12 @@
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        if parallel_config.worker_use_torchrun:
-            pynccl_utils.init_process_group(
-                world_size=parallel_config.world_size,
-                rank=rank,
-                init_method="env://",
-            )
-        else:
-            pynccl_utils.init_process_group(
-                world_size=parallel_config.world_size,
-                local_rank=local_rank,
-                rank=rank,
-                init_method=distributed_init_method,
-            )
+        pynccl_utils.init_process_group(
+            world_size=parallel_config.world_size,
+            local_rank=local_rank,
+            rank=rank,
+            init_method=distributed_init_method,
+        )

     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
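With torchrun owning the rendezvous, init_worker_distributed_environment now passes rank -1 and init_method "env://", so the global rank, world size and master address are taken from the RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT variables that torchrun exports; the pynccl group, by contrast, is now always initialized from the explicit distributed_init_method. As a rough sketch of what the env:// branch amounts to, assuming vLLM's init_distributed_environment wrapper ultimately forwards these arguments to torch.distributed.init_process_group:

import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(
        backend="nccl",        # NCCL (RCCL on ROCm) for GPU collectives
        init_method="env://",  # read rendezvous info from torchrun's env vars
        rank=-1,               # -1: take RANK from the environment
        world_size=-1,         # -1: take WORLD_SIZE from the environment
                               # (the diff passes parallel_config.world_size here)
    )

print(f"initialized rank {dist.get_rank()} of {dist.get_world_size()}")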
