
Commit

formatted
robertgshaw2-redhat committed Jul 31, 2024
1 parent bd27519 commit 11d4de5
Showing 6 changed files with 71 additions and 71 deletions.
8 changes: 4 additions & 4 deletions vllm/engine/async_llm_engine.py
@@ -7,8 +7,8 @@
from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
@@ -924,7 +924,7 @@ async def get_model_config(self) -> ModelConfig:
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()

async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
@@ -948,7 +948,7 @@ async def get_scheduler_config(self) -> SchedulerConfig:
)
else:
return self.engine.get_scheduler_config()

async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
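For context, a minimal sketch (not part of this commit) of how a caller awaits the async config accessors these hunks touch; the engine construction and model name are illustrative assumptions.

```python
import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine


async def show_configs() -> None:
    # Illustrative engine; any served model would do.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # These are the async accessors whose surroundings this diff touches.
    model_config = await engine.get_model_config()
    parallel_config = await engine.get_parallel_config()
    print(model_config.model, parallel_config.tensor_parallel_size)


if __name__ == "__main__":
    asyncio.run(show_configs())
```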
11 changes: 5 additions & 6 deletions vllm/engine/llm_engine.py
@@ -40,8 +40,8 @@
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
_init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, _init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
@@ -486,8 +486,7 @@ def _init_tokenizer(self) -> BaseTokenizerGroup:
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config)
)
enable_lora=bool(self.lora_config))

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -751,12 +750,12 @@ def get_model_config(self) -> ModelConfig:

def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.get_parallel_config
return self.parallel_config

def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config

def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config
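A hypothetical check (not from this commit) of the accessor fix above: previously get_parallel_config() returned the bound method itself rather than the ParallelConfig, so a check like this would fail. The engine construction is illustrative.

```python
from vllm import EngineArgs, LLMEngine
from vllm.config import ParallelConfig

# Illustrative engine; any served model would do.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# With the fix, the accessor returns the ParallelConfig object rather than
# a bound method, so the assertion holds.
assert isinstance(engine.get_parallel_config(), ParallelConfig)
```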
7 changes: 2 additions & 5 deletions vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,6 @@
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
from transformers import AutoTokenizer

import vllm.envs as envs
from vllm.config import ModelConfig
@@ -116,10 +115,8 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]:

## Then build the client for the backend process
# TODO: figure out a way around passing the tokenizer
backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(
args.model),
port=port)
await backend.connect_to_server()
backend = RPCClient(port)
await backend.setup()

try:
yield backend
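A minimal sketch (not from this commit) of the new client lifecycle the server entrypoint now relies on, assuming an RPC server is already listening; the asyncio wrapper and port value are illustrative.

```python
import asyncio

from vllm.entrypoints.openai.rpc.client import RPCClient


async def main(port: int) -> None:
    # The tokenizer is no longer passed in; setup() now fetches the configs
    # over RPC and builds the tokenizer group itself.
    backend = RPCClient(port)
    await backend.setup()
    try:
        model_config = await backend.get_model_config()
        print(model_config.model)
    finally:
        backend.close()


if __name__ == "__main__":
    asyncio.run(main(5570))  # hypothetical port
```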
86 changes: 47 additions & 39 deletions vllm/entrypoints/openai/rpc/client.py
@@ -4,42 +4,46 @@
import zmq
import zmq.asyncio

from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig,
LoRAConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (
RPC_REQUEST_TYPE, VLLM_RPC_HEALTHY_STR, VLLM_RPC_SUCCESS_STR,
RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import _init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer_group import (
_init_tokenizer_from_configs)


class RPCClient:

def __init__(self, port: int):
self.context = zmq.asyncio.Context()
self.path = f"tcp://localhost:{port}"

async def setup(self):
"""Setup the client before it starts sending server requests."""

# Wait until server is ready.
await self.wait_for_server()

# Get the configs.
self.model_config = await self.get_model_config()
self.decoding_config = await self.get_decoding_config()
self.model_config = await self._get_model_config_rpc()
self.decoding_config = await self._get_decoding_config_rpc()

# Create the tokenizer group.
self.tokenizer_group = _init_tokenizer_from_configs(
# Note: this is a hack until we fully
self.tokenizer = _init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=(await self.get_scheduler_config),
parallel_config=(await self.get_parallel_config()),
enable_lora=bool(await self.get_lora_config),
scheduler_config=(await self._get_scheduler_config_rpc()),
parallel_config=(await self._get_parallel_config_rpc()),
enable_lora=bool(await self._get_lora_config_rpc()),
)


def close(self):
"""Destroy the ZeroMQ Context."""
self.context.destroy()
@@ -53,15 +57,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
socket = self.context.socket(zmq.constants.DEALER)
socket.connect(self.path)

# Ping RPCServer with GET_MODEL_CONFIG request.
# Ping RPCServer with a request.
await socket.send(pickle.dumps(request))

# Await the MODEL_CONFIG from the Server.
# Await the data from the Server.
data = pickle.loads(await socket.recv())

if not isinstance(data, expected_type):
socket.close()
raise ValueError(error_message)
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
else:
socket.close()
raise ValueError(error_message)

socket.close()

@@ -90,7 +97,13 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
return response

async def get_tokenizer(self, lora_request: LoRARequest):
await self.tokenizer.get_lora_tokenizer_async(lora_request)
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

async def get_decoding_config(self):
return self.decoding_config

async def get_model_config(self):
return self.model_config

async def is_tracing_enabled(self):
# TODO: what is this?
@@ -103,50 +116,45 @@ async def wait_for_server(self):
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")

async def get_model_config(self) -> ModelConfig:
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_MODEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server"
)
error_message="Could not get ModelConfig from RPC Server")

async def get_decoding_config(self):
async def _get_decoding_config_rpc(self) -> DecodingConfig:
"""Get DecodingConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_DECODING_CONFIG,
expected_type=ModelConfig,
error_message="Could not get DecodingConfig from RPC Server"
)
expected_type=DecodingConfig,
error_message="Could not get DecodingConfig from RPC Server")

async def get_parallel_config(self):
async def _get_parallel_config_rpc(self) -> ParallelConfig:
"""Get ParallelConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_PARALLEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server"
)

async def get_scheduler_config(self):
expected_type=ParallelConfig,
error_message="Could not get ModelConfig from RPC Server")

async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
"""Get SchedulerConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server"
)
error_message="Could not get SchedulerConfig from RPC Server")

async def get_lora_config(self):
async def _get_lora_config_rpc(self):
"""Get LoRAConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server"
)
error_message="Could not get LoRAConfig from RPC Server")

async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
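A stripped-down sketch (not from this commit) of the DEALER-socket request/reply pattern the _get_*_rpc helpers above share, including the tolerance for a None LoRAConfig; the function name and exact error handling here are illustrative, not the commit's code.

```python
import pickle

import zmq
import zmq.asyncio


async def fetch_config(path: str, request, expected_type,
                       error_message: str):
    # One short-lived DEALER socket per request, mirroring the client above.
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.constants.DEALER)
    socket.connect(path)  # e.g. "tcp://localhost:5570"
    try:
        await socket.send(pickle.dumps(request))
        data = pickle.loads(await socket.recv())
        # As in the diff, a LoRAConfig reply may legitimately be None when
        # LoRA is disabled; anything else of the wrong type is an error.
        if not isinstance(data, expected_type) and data is not None:
            raise ValueError(error_message)
        return data
    finally:
        socket.close()
        context.destroy()
```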
9 changes: 3 additions & 6 deletions vllm/entrypoints/openai/rpc/server.py
@@ -37,8 +37,6 @@ def cleanup(self):
self.socket.close()
self.context.destroy()



async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
@@ -54,9 +52,8 @@ async def get_decoding_config(self, identity):
await self.socket.send_multipart(
[identity,
pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)])

async def get_lora_config(self, identity):
"""Send the LoRAConfig"""
lora_config = await self.engine.get_lora_config()

await self.socket.send_multipart(
@@ -69,15 +66,15 @@ async def get_scheduler_config(self, identity):

await self.socket.send_multipart(
[identity,
pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)])
pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)])

async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()

await self.socket.send_multipart(
[identity,
pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)])
pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)])

async def do_log_stats(self, identity):
"""Log stats and confirm success."""
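Schematically, the ROUTER-side reply shape the handlers above share; this helper is illustrative only, not the server's actual dispatch code.

```python
import pickle

import zmq
import zmq.asyncio


async def reply_with_config(socket: zmq.asyncio.Socket, identity: bytes,
                            config) -> None:
    # Echo the requesting DEALER's identity frame so ZeroMQ routes the
    # reply back, then ship the (possibly None) config as a pickle.
    await socket.send_multipart(
        [identity, pickle.dumps(config, pickle.HIGHEST_PROTOCOL)])
```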
21 changes: 10 additions & 11 deletions vllm/transformers_utils/tokenizer_group/__init__.py
@@ -1,7 +1,7 @@
from typing import Optional, Type

from vllm.config import (TokenizerPoolConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
TokenizerPoolConfig)
from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
@@ -19,15 +19,14 @@ def _init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
enable_lora: bool):
init_kwargs = dict(
tokenizer_id=model_config.tokenizer,
enable_lora=enable_lora,
max_num_seqs=scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)

init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=enable_lora,
max_num_seqs=scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)

return get_tokenizer_group(parallel_config.tokenizer_pool_config,
**init_kwargs)

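A sketch (not from this commit) of how the shared helper is now invoked, mirroring the calls added in llm_engine.py and rpc/client.py; the engine construction is illustrative, and LoRA is simply disabled here.

```python
from vllm import EngineArgs, LLMEngine
from vllm.transformers_utils.tokenizer_group import _init_tokenizer_from_configs

# Illustrative source of the configs; the real callers pull them from the
# running engine or over RPC.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

tokenizer_group = _init_tokenizer_from_configs(
    model_config=engine.get_model_config(),
    scheduler_config=engine.get_scheduler_config(),
    parallel_config=engine.get_parallel_config(),
    enable_lora=False,  # or bool(lora_config) when LoRA is configured
)
tokenizer = tokenizer_group.get_lora_tokenizer(None)  # base tokenizer
```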
