More efficient request process handling (#49)
* More efficient process handling

* format

* avoid sleep

* fix

* fix

* remove logger

* Add comments

* format

* Use executor instead

* format

* Add TODO

* Fix timeline

* fix comment

* fix
Michaelvll authored Dec 8, 2024
1 parent 5046dbb commit 390eec6
Showing 6 changed files with 102 additions and 47 deletions.
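
The heart of the change: the API server previously forked a fresh multiprocessing.Process for every request, paying roughly a second of process start-up each time; the executor now submits requests to a lazily populated concurrent.futures.ProcessPoolExecutor whose worker processes are reused. A minimal sketch of the before/after pattern — illustrative only, not the SkyPilot code itself; handle_request is a placeholder:

import concurrent.futures
import multiprocessing


def handle_request(request_id: str) -> None:
    # Placeholder for the per-request work (_wrapper in the real code).
    print(f'handling {request_id}')


def run_forked(request_id: str) -> None:
    # Before: fork a brand-new process per request (~1s start-up each time).
    proc = multiprocessing.Process(target=handle_request, args=(request_id,))
    proc.start()
    proc.join()


def run_pooled(executor: concurrent.futures.ProcessPoolExecutor,
               request_id: str) -> None:
    # After: submit to a pool; worker processes are created lazily on first
    # use and then reused, so the fork cost is paid once per worker.
    executor.submit(handle_request, request_id).result()


if __name__ == '__main__':
    run_forked('req-1')
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
        run_pooled(pool, 'req-2')
        run_pooled(pool, 'req-3')  # reuses the same worker process
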
100 changes: 65 additions & 35 deletions sky/api/requests/executor.py
@@ -1,9 +1,11 @@
"""Executor for the requests."""
import concurrent.futures
import enum
import functools
import multiprocessing
import os
import queue as queue_lib
import signal
import sys
import time
import traceback
@@ -21,6 +23,7 @@
from sky.usage import usage_lib
from sky.utils import common
from sky.utils import common_utils
from sky.utils import timeline
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
@@ -126,9 +129,14 @@ def _get_queue(schedule_type: requests.ScheduleType) -> RequestQueue:
    return RequestQueue(schedule_type, backend=queue_backend)


def _wrapper(request_id: str, ignore_return_value: bool):
def _wrapper(request_id: str, ignore_return_value: bool) -> None:
"""Wrapper for a request task."""

def sigterm_handler(signum, frame):
raise KeyboardInterrupt

signal.signal(signal.SIGTERM, sigterm_handler)

def redirect_output(file):
"""Redirect stdout and stderr to the log file."""
fd = file.fileno() # Get the file descriptor from the file object
Expand Down Expand Up @@ -177,7 +185,10 @@ def restore_output(original_stdout, original_stderr):
        with skypilot_config.override_skypilot_config(
                request_body.override_skypilot_config):
            return_value = func(**request_body.to_kwargs())
    except Exception as e:  # pylint: disable=broad-except
    except KeyboardInterrupt:
        logger.info(f'Request {request_id} aborted by user')
        return
    except (Exception, SystemExit) as e:  # pylint: disable=broad-except
        with ux_utils.enable_traceback():
            stacktrace = traceback.format_exc()
        setattr(e, 'stacktrace', stacktrace)
@@ -188,7 +199,7 @@ def restore_output(original_stdout, original_stderr):
            request_task.set_error(e)
        restore_output(original_stdout, original_stderr)
        logger.info(f'Request {request_id} failed due to {e}')
        return None
        return
    else:
        with requests.update_request(request_id) as request_task:
            assert request_task is not None, request_id
@@ -197,7 +208,10 @@ def restore_output(original_stdout, original_stderr):
            request_task.set_return_value(return_value)
        restore_output(original_stdout, original_stderr)
        logger.info(f'Request {request_id} finished')
        return return_value
    finally:
        # We need to call save_timeline() explicitly since atexit will not be
        # triggered, as multiple requests can share the same process.
        timeline.save_timeline()


def schedule_request(
@@ -230,38 +244,54 @@ def request_worker(worker_id: int, schedule_type: requests.ScheduleType):
    logger.info(f'Request worker {worker_id} -- started with pid '
                f'{multiprocessing.current_process().pid}')
    queue = _get_queue(schedule_type)
    while True:
        request = queue.get()
        if request is None:
            time.sleep(0.1)
            continue
        request_id, ignore_return_value = request
        request = requests.get_request(request_id)
        if request.status == requests.RequestStatus.ABORTED:
            continue
        logger.info(
            f'Request worker {worker_id} -- running request: {request_id}')
        # Start additional process to run the request, so that it can be aborted
        # when requested by a user.
        process = multiprocessing.Process(target=_wrapper,
                                          args=(request_id,
                                                ignore_return_value))
        process.start()

        if schedule_type == requests.ScheduleType.BLOCKING:
            # Wait for the request to finish.
            try:
                process.join()
            except Exception as e:  # pylint: disable=broad-except
                logger.error(
                    f'Request worker {worker_id} -- request {request_id} '
                    f'failed: {e}')
            logger.info(
                f'Request worker {worker_id} -- request {request_id} finished')
        else:
            # Non-blocking requests are handled by the non-blocking worker.
            logger.info(
                f'Request worker {worker_id} -- request {request_id} submitted')
    if schedule_type == requests.ScheduleType.BLOCKING:
        max_worker_size = 1
    else:
        max_worker_size = None
    # Use concurrent.futures.ProcessPoolExecutor instead of multiprocessing.Pool
    # because the former is more efficient, as it supports lazy creation of
    # worker processes.
    # We use an executor instead of individual multiprocessing.Process calls to
    # avoid the overhead of forking a new process for each request, which can
    # add about 1s of delay.
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=max_worker_size) as executor:
        while True:
            request = queue.get()
            if request is None:
                time.sleep(0.1)
                continue
            request_id, ignore_return_value = request
            request = requests.get_request(request_id)
            if request.status == requests.RequestStatus.ABORTED:
                continue
            logger.info(
                f'Request worker {worker_id} -- submitted request: {request_id}'
            )
            # Start an additional process to run the request, so that it can
            # be aborted when requested by a user.
            # TODO(zhwu): since the executor is reusing the request process,
            # multiple requests can share the same process pid, which may cause
            # issues with SkyPilot core functions if they rely on the exit of
            # the process, such as subprocess_daemon.py.
            future = executor.submit(_wrapper, request_id, ignore_return_value)

            if schedule_type == requests.ScheduleType.BLOCKING:
                # Wait for the request to finish.
                try:
                    future.result(timeout=None)
                except Exception as e:  # pylint: disable=broad-except
                    logger.error(
                        f'Request worker {worker_id} -- request {request_id} '
                        f'failed: {e}')
                logger.info(
                    f'Request worker {worker_id} -- request {request_id} '
                    'finished')
            else:
                # Non-blocking requests are handled by the non-blocking worker.
                logger.info(
                    f'Request worker {worker_id} -- request {request_id} '
                    'submitted')


def start(num_queue_workers: int = 1) -> List[multiprocessing.Process]:
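
The sigterm_handler installed in _wrapper above is what keeps reused pool workers abortable: SIGTERM is translated into KeyboardInterrupt, which unwinds only the in-flight request and leaves the worker process alive for the next one. A standalone sketch of the pattern, with run_task as an invented stand-in for _wrapper:

import signal


def run_task(task_fn, *args):
    # Translate SIGTERM into KeyboardInterrupt so an abort unwinds the
    # current task instead of killing the (reusable) worker process.
    def sigterm_handler(signum, frame):
        raise KeyboardInterrupt

    signal.signal(signal.SIGTERM, sigterm_handler)
    try:
        return task_fn(*args)
    except KeyboardInterrupt:
        # The task was aborted; the process itself keeps running and can
        # pick up the next task from the pool.
        print('task aborted')
        return None
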
10 changes: 7 additions & 3 deletions sky/api/requests/requests.py
@@ -6,6 +6,7 @@
import json
import os
import pathlib
import signal
import sqlite3
from typing import Any, Callable, Dict, List, Optional, Tuple

@@ -19,7 +20,6 @@
from sky.api.requests.serializers import encoders
from sky.utils import common_utils
from sky.utils import db_utils
from sky.utils import subprocess_utils

logger = sky_logging.init_logger(__name__)

@@ -249,8 +249,12 @@ def kill_requests(request_ids: List[str]):
            continue
        if request_record.pid is not None:
            logger.debug(f'Killing request process {request_record.pid}')
            subprocess_utils.kill_children_processes(
                parent_pids=[request_record.pid], force=True)
            # Use SIGTERM instead of SIGKILL:
            # - The executor can handle SIGTERM gracefully
            # - After SIGTERM, the executor can reuse the request process
            #   for other requests, avoiding the overhead of forking a new
            #   process for each request.
            os.kill(request_record.pid, signal.SIGTERM)
        request_record.status = RequestStatus.ABORTED


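
On the sending side, the switch from force-killing the process tree to a plain SIGTERM is what gives the worker's handler a chance to run. A hedged sketch of the distinction (abort_request is a hypothetical helper, not this repository's API):

import os
import signal


def abort_request(pid: int, force: bool = False) -> None:
    if force:
        # SIGKILL cannot be caught: the worker dies outright and the pool
        # must fork a replacement, paying the ~1s start-up cost again.
        os.kill(pid, signal.SIGKILL)
    else:
        # SIGTERM is catchable: the worker's handler turns it into a
        # KeyboardInterrupt, aborting only the in-flight request.
        os.kill(pid, signal.SIGTERM)
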
20 changes: 16 additions & 4 deletions sky/api/rest.py
@@ -586,7 +586,10 @@ async def get(request_id: str) -> requests_lib.RequestPayload:
                status_code=404, detail=f'Request {request_id} not found')
        if request_task.status > requests_lib.RequestStatus.RUNNING:
            return request_task.encode()
        await asyncio.sleep(1)
        # Sleep 0 to yield, so other coroutines can run. This busy waiting
        # loop is performance critical for short-running requests, so we do
        # not want to yield too long.
        await asyncio.sleep(0)


async def _yield_log_file_with_payloads_skipped(
@@ -629,7 +632,10 @@ async def log_streamer(request_id: Optional[str],
            yield status_msg.update(
                f'[dim]Waiting for {request_task.name} request: '
                f'{request_id}[/dim]')
            await asyncio.sleep(1)
            # Sleep 0 to yield, so other coroutines can run. This busy waiting
            # loop is performance critical for short-running requests, so we do
            # not want to yield too long.
            await asyncio.sleep(0)
            request_task = requests_lib.get_request(request_id)
        if show_request_waiting_spinner:
            yield status_msg.stop()
@@ -653,14 +659,20 @@ async def log_streamer(request_id: Optional[str],
                request_task = requests_lib.get_request(request_id)
                if request_task.status > requests_lib.RequestStatus.RUNNING:
                    break
                await asyncio.sleep(1)
                # Sleep 0 to yield, so other coroutines can run. This busy
                # waiting loop is performance critical for short-running
                # requests, so we do not want to yield too long.
                await asyncio.sleep(0)
                continue
            line_str = line.decode('utf-8')
            if plain_logs:
                is_payload, line_str = message_utils.decode_payload(
                    line_str, raise_for_mismatch=False)
                if is_payload:
                    await asyncio.sleep(0)  # Allow other tasks to run
                    # Sleep 0 to yield, so other coroutines can run. This busy
                    # waiting loop is performance critical for short-running
                    # requests, so we do not want to yield too long.
                    await asyncio.sleep(0)
                    continue
                line_str = common_utils.remove_color(line_str)
            yield line_str
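
The rest.py changes swap one-second polling sleeps for asyncio.sleep(0), which suspends a coroutine just long enough for other tasks on the event loop to run; short requests no longer pay up to a second of polling latency, at the cost of a busier poll loop. A self-contained illustration with invented timings:

import asyncio
import time


async def wait_for_flag(flag: dict, poll_sleep: float) -> float:
    # Poll a shared flag, yielding to the event loop on every iteration.
    start = time.monotonic()
    while not flag['done']:
        await asyncio.sleep(poll_sleep)
    return time.monotonic() - start


async def main():
    flag = {'done': False}

    async def finish_soon():
        await asyncio.sleep(0.05)
        flag['done'] = True

    asyncio.get_running_loop().create_task(finish_soon())
    # With sleep(0) the waiter notices almost immediately after 0.05s;
    # with sleep(1) the same wait would round up to a full second.
    elapsed = await wait_for_flag(flag, poll_sleep=0)
    print(f'noticed completion after {elapsed:.3f}s')


asyncio.run(main())
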
10 changes: 7 additions & 3 deletions sky/backends/backend_utils.py
@@ -2385,7 +2385,7 @@ def get_clusters(
    bright = colorama.Style.BRIGHT
    reset = colorama.Style.RESET_ALL

    def _update_record_with_credentials(
    def _update_record_with_credentials_and_resources_str(
            record: Optional[Dict[str, Any]]) -> None:
        """Add the credentials to the record.
@@ -2397,9 +2397,12 @@ def _update_record_with_credentials(
        handle = record['handle']
        if handle is None:
            return
        record['resources_str'] = resources_utils.get_readable_resources_repr(
            handle)
        credentials = ssh_credential_from_yaml(handle.cluster_yaml,
                                               handle.docker_user,
                                               handle.ssh_user)

        if not credentials:
            return
        ssh_private_key_path = credentials.get('ssh_private_key', None)
@@ -2433,9 +2436,10 @@ def _update_record_with_credentials(
        clusters_str = ', '.join(not_exist_cluster_names)
        logger.info(f'Cluster(s) not found: {bright}{clusters_str}{reset}.')
    records = new_records

    # Add auth_config to the records
    for record in records:
        _update_record_with_credentials(record)
        _update_record_with_credentials_and_resources_str(record)

    if refresh == common.StatusRefreshMode.NONE:
        return records
@@ -2459,7 +2463,7 @@ def _refresh_cluster(cluster_name):
                cluster_name,
                force_refresh_statuses=force_refresh_statuses,
                acquire_per_cluster_status_lock=True)
            _update_record_with_credentials(record)
            _update_record_with_credentials_and_resources_str(record)
        except (exceptions.ClusterStatusFetchingError,
                exceptions.CloudUserIdentityError,
                exceptions.ClusterOwnerIdentityMismatchError) as e:
2 changes: 2 additions & 0 deletions sky/utils/cli_utils/status_utils.py
@@ -218,6 +218,8 @@ def _get_status_colored(cluster_record: _ClusterRecord) -> str:


def _get_resources(cluster_record: _ClusterRecord) -> str:
    if 'resources_str' in cluster_record:
        return cluster_record['resources_str']
    handle = cluster_record['handle']
    if isinstance(handle, backends.LocalDockerResourceHandle):
        resources_str = 'docker'
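
Together, the backend_utils.py and status_utils.py changes apply a precompute-and-fall-back pattern: the human-readable resources string is derived once when the cluster record is fetched, and the status renderer uses the cached field when present. Sketched in isolation (derive_resources_str stands in for the real handle-based derivation):

from typing import Any, Dict


def derive_resources_str(handle: Any) -> str:
    # Stand-in for resources_utils.get_readable_resources_repr(handle).
    return f'resources({handle})'


def get_resources_str(record: Dict[str, Any]) -> str:
    # Prefer the value cached at record-creation time; fall back to
    # re-deriving it for records that predate the cached field.
    if 'resources_str' in record:
        return record['resources_str']
    return derive_resources_str(record['handle'])
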
7 changes: 5 additions & 2 deletions sky/utils/timeline.py
@@ -116,7 +116,10 @@ def wrapper(*args, **kwargs):
    return wrapper


def _save_timeline(file_path: str):
def save_timeline():
    file_path = os.environ.get('SKYPILOT_TIMELINE_FILE_PATH')
    if not file_path:
        return
    json_output = {
        'traceEvents': _events,
        'displayTimeUnit': 'ms',
@@ -130,4 +133,4 @@ def _save_timeline(file_path: str):


if os.environ.get('SKYPILOT_TIMELINE_FILE_PATH'):
    atexit.register(_save_timeline, os.environ['SKYPILOT_TIMELINE_FILE_PATH'])
    atexit.register(save_timeline)
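
Because a pooled worker now outlives any single request, atexit no longer fires at request boundaries; save_timeline() is therefore made public so the executor can flush events in its finally: block after every request, with atexit kept as a last flush at process exit. A condensed sketch of that arrangement, with the event buffer simplified:

import atexit
import json
import os

_events: list = []


def save_timeline() -> None:
    # Flush whatever has been collected so far; callable once per request.
    file_path = os.environ.get('SKYPILOT_TIMELINE_FILE_PATH')
    if not file_path:
        return
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump({'traceEvents': _events, 'displayTimeUnit': 'ms'}, f)


# atexit still covers the final flush when the worker process eventually
# exits, but per-request flushing must be explicit in a reused process.
if os.environ.get('SKYPILOT_TIMELINE_FILE_PATH'):
    atexit.register(save_timeline)
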
