Merge pull request instructlab#2114 from makelinux/ensure_server
refactor: move ensure_server() to vllm.py
mergify[bot] authored Aug 28, 2024
2 parents f64fb3a + ead1d4e commit 2f9cacd
Showing 3 changed files with 61 additions and 79 deletions.
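
In short, the module-level helper that run_detached() used to call with its full configuration becomes a private method on the vLLM backend object, which reads host, port, and retry settings from self. A rough before/after sketch of the call site, assembled from the diff below (surrounding code omitted):

# Before: module-level helper in backends.py, configured entirely through arguments
_, vllm_server_process, api_base = ensure_server(
    backend=VLLM,
    api_base=self.api_base,
    http_client=http_client,
    host=self.host,
    port=self.port,
    background=background,
    foreground_allowed=foreground_allowed,
    server_process_func=self.create_server_process,
    max_startup_attempts=self.max_startup_attempts,
)

# After: private method on the vLLM backend, pulling those values from self
vllm_server_process, api_base = self._ensure_server(
    http_client=http_client,
    background=background,
    foreground_allowed=foreground_allowed,
)
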
src/instructlab/model/backends/backends.py (71 changes: 1 addition & 70 deletions)
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from time import monotonic, sleep, time
from time import monotonic, sleep
from types import FrameType
from typing import Optional, Tuple
import abc
@@ -26,9 +26,7 @@
import uvicorn

# Local
from ...client import check_api_base
from ...configuration import _serve as serve_config
from ...configuration import get_api_base
from ...utils import split_hostport
from .common import CHAT_TEMPLATE_AUTO, LLAMA_CPP, VLLM

@@ -410,73 +408,6 @@ def wait_for_stable_vram_cuda(timeout: int) -> Tuple[bool, bool]:
logger.debug("Could not free cuda cache: %s", e)


def ensure_server(
    backend: str,
    api_base: str,
    http_client=None,
    host="localhost",
    port=8000,
    background=True,
    foreground_allowed=False,
    server_process_func=None,
    max_startup_attempts=None,
) -> Tuple[
    Optional[multiprocessing.Process], Optional[subprocess.Popen], Optional[str]
]:
    """Checks if server is running, if not starts one as a subprocess. Returns the server process
    and the URL where it's available."""

    logger.info(f"Trying to connect to model server at {api_base}")
    if check_api_base(api_base, http_client):
        return (None, None, api_base)
    port = free_tcp_ipv4_port(host)
    logger.debug(f"Using available port {port} for temporary model serving.")

    host_port = f"{host}:{port}"
    temp_api_base = get_api_base(host_port)
    vllm_server_process = None

    if backend == VLLM:
        # TODO: resolve how the hostname is getting passed around the class and this function
        vllm_server_process = server_process_func(port, background)
        logger.info("Starting a temporary vLLM server at %s", temp_api_base)
        count = 0
        # Each call to check_api_base takes >2s + 2s sleep
        # Default to 120 if not specified (~8 mins of wait time)
        vllm_startup_max_attempts = max_startup_attempts or 120
        start_time_secs = time()
        while count < vllm_startup_max_attempts:
            count += 1
            # Check if the process is still alive
            if vllm_server_process.poll():
                if foreground_allowed and background:
                    raise ServerException(
                        "vLLM failed to start. Retry with --enable-serving-output to learn more about the failure."
                    )
                raise ServerException("vLLM failed to start.")
            logger.info(
                "Waiting for the vLLM server to start at %s, this might take a moment... Attempt: %s/%s",
                temp_api_base,
                count,
                vllm_startup_max_attempts,
            )
            if check_api_base(temp_api_base, http_client):
                logger.info("vLLM engine successfully started at %s", temp_api_base)
                break
            if count == vllm_startup_max_attempts:
                logger.info(
                    "Gave up waiting for vLLM server to start at %s after %s attempts",
                    temp_api_base,
                    vllm_startup_max_attempts,
                )
                duration = round(time() - start_time_secs, 1)
                shutdown_process(vllm_server_process, 20)
                # pylint: disable=raise-missing-from
                raise ServerException(f"vLLM failed to start up in {duration} seconds")
            sleep(2)
    return (None, vllm_server_process, temp_api_base)


def free_tcp_ipv4_port(host: str) -> int:
"""Ask the OS for a random, ephemeral, and bindable TCP/IPv4 port
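
For reference, the "~8 mins of wait time" mentioned in the startup comment above is just the default attempt budget multiplied by the per-attempt cost the comment states (a check_api_base() probe of roughly two seconds plus an explicit two-second sleep):

# Back-of-the-envelope check of the wait-time comment (assumes ~2 s per probe, as the comment states)
attempts = 120                      # default when max_startup_attempts is not set
seconds_per_attempt = 2 + 2         # ~2 s check_api_base() + sleep(2)
total_seconds = attempts * seconds_per_attempt
print(total_seconds, "s =", total_seconds / 60, "min")   # 480 s = 8.0 min
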
src/instructlab/model/backends/vllm.py (67 changes: 59 additions & 8 deletions)
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from typing import Optional, Tuple
import json
import logging
import os
@@ -15,12 +16,13 @@
import httpx

# Local
from ...client import check_api_base
from ...configuration import get_api_base
from .backends import (
    BackendServer,
    Closeable,
    ServerException,
    ensure_server,
    free_tcp_ipv4_port,
    safe_close_all,
    shutdown_process,
)
@@ -93,6 +95,61 @@ def create_server_process(self, port: int, background: bool) -> subprocess.Popen
        self.register_resources(files)
        return server_process

    def _ensure_server(
        self,
        http_client=None,
        background=True,
        foreground_allowed=False,
    ) -> Tuple[Optional[subprocess.Popen], Optional[str]]:
        """Checks if server is running, if not starts one as a subprocess. Returns the server process
        and the URL where it's available."""

        logger.info(f"Trying to connect to model server at {self.api_base}")
        if check_api_base(self.api_base, http_client):
            return (None, self.api_base)
        port = free_tcp_ipv4_port(self.host)
        logger.debug(f"Using available port {port} for temporary model serving.")

        host_port = f"{self.host}:{port}"
        temp_api_base = get_api_base(host_port)
        vllm_server_process = self.create_server_process(port, background)
        logger.info("Starting a temporary vLLM server at %s", temp_api_base)
        count = 0
        # Each call to check_api_base takes >2s + 2s sleep
        # Default to 120 if not specified (~8 mins of wait time)
        vllm_startup_max_attempts = self.max_startup_attempts or 120
        start_time_secs = time.time()
        while count < vllm_startup_max_attempts:
            count += 1
            # Check if the process is still alive
            if vllm_server_process.poll():
                if foreground_allowed and background:
                    raise ServerException(
                        "vLLM failed to start. Retry with --enable-serving-output to learn more about the failure."
                    )
                raise ServerException("vLLM failed to start.")
            logger.info(
                "Waiting for the vLLM server to start at %s, this might take a moment... Attempt: %s/%s",
                temp_api_base,
                count,
                vllm_startup_max_attempts,
            )
            if check_api_base(temp_api_base, http_client):
                logger.info("vLLM engine successfully started at %s", temp_api_base)
                break
            if count == vllm_startup_max_attempts:
                logger.info(
                    "Gave up waiting for vLLM server to start at %s after %s attempts",
                    temp_api_base,
                    vllm_startup_max_attempts,
                )
                duration = round(time.time() - start_time_secs, 1)
                shutdown_process(vllm_server_process, 20)
                # pylint: disable=raise-missing-from
                raise ServerException(f"vLLM failed to start up in {duration} seconds")
            time.sleep(2)
        return (vllm_server_process, temp_api_base)

    def run_detached(
        self,
        http_client: httpx.Client | None = None,
@@ -102,16 +159,10 @@ def run_detached(
    ) -> str:
        for i in range(max_startup_retries + 1):
            try:
                _, vllm_server_process, api_base = ensure_server(
                    backend=VLLM,
                    api_base=self.api_base,
                vllm_server_process, api_base = self._ensure_server(
                    http_client=http_client,
                    host=self.host,
                    port=self.port,
                    background=background,
                    foreground_allowed=foreground_allowed,
                    server_process_func=self.create_server_process,
                    max_startup_attempts=self.max_startup_attempts,
                )
                self.process = vllm_server_process or self.process
                self.api_base = api_base or self.api_base
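
A detail worth knowing when reading the startup loop above: subprocess.Popen.poll() returns None while the child process is still running and its exit code once it has terminated, so the "if vllm_server_process.poll():" guard only fires after the vLLM process has died with a nonzero exit status. A minimal illustration (assumes a POSIX system with a sleep binary on PATH):

import subprocess

proc = subprocess.Popen(["sleep", "1"])
print(proc.poll())   # None -> still running, so the guard is falsy
proc.wait()
print(proc.poll())   # 0 -> exited cleanly; note this is also falsy, so only a
                     # nonzero exit code would trip the "vLLM failed to start" branch
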
tests/common.py (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ def setup_gpus_config(section_path="serve", gpus=None, tps=None, vllm_args=lambd
    return _CFG_FILE_NAME


@mock.patch("instructlab.model.backends.backends.check_api_base", return_value=False)
@mock.patch("instructlab.model.backends.vllm.check_api_base", return_value=False)
# ^ mimic server *not* running already
@mock.patch(
"instructlab.model.backends.backends.determine_backend",
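
The one-line test change above follows from how unittest.mock resolves patch targets: mock.patch() replaces an attribute in the namespace of the module that looks the name up at call time, so once the server-startup logic and its check_api_base call live in vllm.py, that is the module whose check_api_base must be patched. A minimal, self-contained illustration of the same mechanism using the standard library (the instructlab paths in the diff are the real targets; json here is just a stand-in):

from unittest import mock
import json

# mock.patch swaps out the attribute on the named module for the duration of the block,
# so anything that calls json.loads(...) through the json module sees the replacement.
with mock.patch("json.loads", return_value={"patched": True}):
    print(json.loads("{}"))   # {'patched': True}
print(json.loads("{}"))       # {} -- the original function is restored afterwards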
