Skip to content

Commit

Permalink
Support managed jobs dashboard (#46)
Browse files Browse the repository at this point in the history
* [perf] use uv for venv creation and pip install (#4414)

* Revert "remove `uv` from runtime setup due to azure installation issue (#4401)"

This reverts commit 0b20d56.

* on azure, use --prerelease=allow to install azure-cli

* use uv venv --seed

* fix backwards compatibility

* really fix backwards compatibility

* use uv to set up controller dependencies

* fix python 3.8

* lint

* add missing file

* update comment

* split out azure-cli dep

* fix lint for dependencies

* use runpy.run_path rather than modifying sys.path

* fix cloud dependency installation commands

* lint

* Update sky/utils/controller_utils.py

Co-authored-by: Zhanghao Wu <[email protected]>

---------

Co-authored-by: Zhanghao Wu <[email protected]>

* [Minor] README updates. (#4436)

* [Minor] README touches.

* update

* update

* basic impl

* fixes

* start dashboard when restarting the controller

* Fix dashboard support

* Fix

* format

* format

* format

* format

* format

* address comments

* format

* fix

---------

Co-authored-by: Christopher Cooper <[email protected]>
Co-authored-by: Zongheng Yang <[email protected]>
  • Loading branch information
3 people authored Dec 9, 2024
1 parent 7b24232 commit b46a994
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 73 deletions.
71 changes: 11 additions & 60 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@
import pathlib
import shlex
import shutil
import signal
import subprocess
import sys
import textwrap
import time
import traceback
import typing
from typing import Any, Dict, List, Optional, Tuple, Union
import webbrowser

import click
import colorama
Expand Down Expand Up @@ -76,6 +74,7 @@
from sky.utils import common_utils
from sky.utils import controller_utils
from sky.utils import dag_utils
from sky.utils import env_options
from sky.utils import log_utils
from sky.utils import registry
from sky.utils import resources_utils
Expand Down Expand Up @@ -1400,8 +1399,12 @@ def _get_managed_jobs(
f'Details: {common_utils.format_exception(e, use_bracket=True)}'
)
except Exception as e: # pylint: disable=broad-except
msg = ('Failed to query managed jobs: '
f'{common_utils.format_exception(e, use_bracket=True)}')
msg = ''
if env_options.Options.SHOW_DEBUG_INFO.get():
msg += traceback.format_exc()
msg += '\n'
msg += ('Failed to query managed jobs: '
f'{common_utils.format_exception(e, use_bracket=True)}')
else:
max_jobs_to_show = (_NUM_MANAGED_JOBS_TO_SHOW_IN_STATUS
if limit_num_jobs_to_show else None)
Expand Down Expand Up @@ -4071,62 +4074,10 @@ def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,


@jobs.command('dashboard', cls=_DocumentedCodeCommand)
@click.option(
'--port',
'-p',
default=None,
type=int,
required=False,
help=('Local port to use for the dashboard. If None, a free port is '
'automatically chosen.'))
@usage_lib.entrypoint
def jobs_dashboard(port: Optional[int]):
"""Opens a dashboard for managed jobs (needs controller to be UP)."""
# TODO(zongheng): ideally, the controller/dashboard server should expose the
# API perhaps via REST. Then here we would (1) not have to use SSH to try to
# see if the controller is UP first, which is slow; (2) not have to run SSH
# port forwarding first (we'd just launch a local dashboard which would make
# REST API calls to the controller dashboard server).
click.secho('Checking if jobs controller is up...', fg='yellow')
hint = ('Dashboard is not available if jobs controller is not up. Run a '
'managed job first.')
backend_utils.is_controller_accessible(
controller=controller_utils.Controllers.JOBS_CONTROLLER,
stopped_message=hint,
non_existent_message=hint,
exit_if_not_accessible=True)

# SSH forward a free local port to remote's dashboard port.
remote_port = constants.SPOT_DASHBOARD_REMOTE_PORT
if port is None:
free_port = common_utils.find_free_port(remote_port)
else:
free_port = port
ssh_command = (
f'ssh -qNL {free_port}:localhost:{remote_port} '
f'{controller_utils.Controllers.JOBS_CONTROLLER.value.cluster_name}')
click.echo('Forwarding port: ', nl=False)
click.secho(f'{ssh_command}', dim=True)

with subprocess.Popen(ssh_command, shell=True,
start_new_session=True) as ssh_process:
time.sleep(3) # Added delay for ssh_command to initialize.
webbrowser.open(f'http://localhost:{free_port}')
click.secho(
f'Dashboard is now available at: http://127.0.0.1:{free_port}',
fg='green')
try:
ssh_process.wait()
except KeyboardInterrupt:
# When user presses Ctrl-C in terminal, exits the previous ssh
# command so that <free local port> is freed up.
try:
os.killpg(os.getpgid(ssh_process.pid), signal.SIGTERM)
except ProcessLookupError:
# This happens if jobs controller is auto-stopped.
pass
finally:
click.echo('Exiting.')
def jobs_dashboard():
"""Opens a dashboard for managed jobs."""
managed_jobs.dashboard()


# TODO(zhwu): Backward compatibility for the old `sky spot launch` command.
Expand Down
2 changes: 2 additions & 0 deletions sky/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib

from sky.jobs.api.sdk import cancel
from sky.jobs.api.sdk import dashboard
from sky.jobs.api.sdk import launch
from sky.jobs.api.sdk import queue
from sky.jobs.api.sdk import tail_logs
Expand Down Expand Up @@ -30,6 +31,7 @@
'launch',
'queue',
'tail_logs',
'dashboard',
# utils
'ManagedJobCodeGen',
'format_job_table',
Expand Down
72 changes: 71 additions & 1 deletion sky/jobs/api/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""SDK functions for managed jobs."""
import os
import signal
import subprocess
import tempfile
import time
import typing
from typing import Any, Dict, List, Optional, Tuple, Union
import uuid
Expand All @@ -22,6 +25,7 @@
from sky.skylet import constants as skylet_constants
from sky.usage import usage_lib
from sky.utils import admin_policy_utils
from sky.utils import command_runner
from sky.utils import common_utils
from sky.utils import controller_utils
from sky.utils import dag_utils
Expand All @@ -35,6 +39,8 @@
import sky
from sky.backends import cloud_vm_ray_backend

logger = sky_logging.init_logger(__name__)


@timeline.event
@usage_lib.entrypoint
Expand Down Expand Up @@ -126,6 +132,7 @@ def launch(
'remote_user_config_path': remote_user_config_path,
'modified_catalogs':
service_catalog_common.get_modified_catalog_file_mounts(),
'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
**controller_utils.shared_controller_vars_to_fill(
controller_utils.Controllers.JOBS_CONTROLLER,
remote_user_config_path=remote_user_config_path,
Expand Down Expand Up @@ -261,7 +268,20 @@ def _maybe_restart_controller(
rich_utils.force_update_status(
ux_utils.spinner_message(f'{spinner_message} - restarting '
'controller'))
handle = core.start(jobs_controller_type.value.cluster_name)
handle = core.start(cluster_name=jobs_controller_type.value.cluster_name)
# Make sure the dashboard is running when the controller is restarted.
# We should not directly use execution.launch() and have the dashboard cmd
# in the task setup because since we are using detached_setup, it will
# become a job on controller which messes up the job IDs (we assume the
# job ID in controller's job queue is consistent with managed job IDs).
logger.info('Starting dashboard...')
runner = handle.get_command_runners()[0]
runner.run(
f'export '
f'{skylet_constants.USER_ID_ENV_VAR}={common_utils.get_user_hash()!r}; '
f'{managed_job_constants.DASHBOARD_SETUP_CMD}',
stream_logs=True,
)
controller_status = status_lib.ClusterStatus.UP
rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))

Expand Down Expand Up @@ -431,3 +451,53 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
job_name=name,
follow=follow,
controller=controller)


def start_dashboard_forwarding(refresh: bool = False) -> Tuple[int, int]:
"""Opens a dashboard for managed jobs (needs controller to be UP)."""
# TODO(zongheng): ideally, the controller/dashboard server should expose the
# API perhaps via REST. Then here we would (1) not have to use SSH to try to
# see if the controller is UP first, which is slow; (2) not have to run SSH
# port forwarding first (we'd just launch a local dashboard which would make
# REST API calls to the controller dashboard server).
logger.info('Starting dashboard')
hint = ('Dashboard is not available if jobs controller is not up. Run '
'a managed job first or run: sky jobs queue --refresh')
handle = _maybe_restart_controller(
refresh=refresh,
stopped_message=hint,
spinner_message='Checking jobs controller')

# SSH forward a free local port to remote's dashboard port.
remote_port = skylet_constants.SPOT_DASHBOARD_REMOTE_PORT
free_port = common_utils.find_free_port(remote_port)
runner = handle.get_command_runners()[0]
ssh_command = ' '.join(
runner.ssh_base_command(ssh_mode=command_runner.SshMode.INTERACTIVE,
port_forward=[(free_port, remote_port)],
connect_timeout=1))
ssh_command = (
f'{ssh_command} '
f'> ~/sky_logs/api_server/dashboard-{common_utils.get_user_hash()}.log '
'2>&1')
logger.info(f'Forwarding port: {colorama.Style.DIM}{ssh_command}'
f'{colorama.Style.RESET_ALL}')

ssh_process = subprocess.Popen(ssh_command,
shell=True,
start_new_session=True)
time.sleep(3) # Added delay for ssh_command to initialize.
logger.info(f'{colorama.Fore.GREEN}Dashboard is now available at: '
f'http://127.0.0.1:{free_port}{colorama.Style.RESET_ALL}')

return free_port, ssh_process.pid


def stop_dashboard_forwarding(pid: int) -> None:
# Exit the ssh command when the context manager is closed.
try:
os.killpg(os.getpgid(pid), signal.SIGTERM)
except ProcessLookupError:
# This happens if jobs controller is auto-stopped.
pass
logger.info('Forwarding port closed. Exiting.')
64 changes: 64 additions & 0 deletions sky/jobs/api/dashboard_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Persistent dashboard sessions."""
import pathlib
from typing import Tuple

import filelock

from sky.utils import db_utils


def create_dashboard_table(cursor, conn):
cursor.execute("""\
CREATE TABLE IF NOT EXISTS dashboard_sessions (
user_hash TEXT PRIMARY KEY,
port INTEGER,
pid INTEGER)""")
conn.commit()


def _get_db_path() -> str:
path = pathlib.Path('~/.sky/dashboard/sessions.db')
path = path.expanduser().absolute()
path.parent.mkdir(parents=True, exist_ok=True)
return str(path)


DB_PATH = _get_db_path()
db_utils.SQLiteConn(DB_PATH, create_dashboard_table)
LOCK_FILE_PATH = '~/.sky/dashboard/sessions-{user_hash}.lock'


def get_dashboard_session(user_hash: str) -> Tuple[int, int]:
"""Get the port and pid of the dashboard session for the user."""
with db_utils.safe_cursor(DB_PATH) as cursor:
cursor.execute(
'SELECT port, pid FROM dashboard_sessions WHERE user_hash=?',
(user_hash,))
result = cursor.fetchone()
if result is None:
return 0, 0
return result


def add_dashboard_session(user_hash: str, port: int, pid: int) -> None:
"""Add a dashboard session for the user."""
with db_utils.safe_cursor(DB_PATH) as cursor:
cursor.execute(
'INSERT OR REPLACE INTO dashboard_sessions (user_hash, port, pid) '
'VALUES (?, ?, ?)', (user_hash, port, pid))


def remove_dashboard_session(user_hash: str) -> None:
"""Remove the dashboard session for the user."""
with db_utils.safe_cursor(DB_PATH) as cursor:
cursor.execute('DELETE FROM dashboard_sessions WHERE user_hash=?',
(user_hash,))
lock_path = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
lock_path.unlink(missing_ok=True)


def get_dashboard_lock_for_user(user_hash: str) -> filelock.FileLock:
path = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
path = path.expanduser().absolute()
path.parent.mkdir(parents=True, exist_ok=True)
return filelock.FileLock(path)
84 changes: 84 additions & 0 deletions sky/jobs/api/rest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
"""REST API for managed jobs."""
import os

import fastapi
import httpx

from sky import sky_logging
from sky.api.requests import executor
from sky.api.requests import payloads
from sky.api.requests import requests
from sky.jobs.api import core
from sky.jobs.api import dashboard_utils
from sky.skylet import constants
from sky.utils import common
from sky.utils import common_utils

logger = sky_logging.init_logger(__name__)

router = fastapi.APIRouter()

Expand Down Expand Up @@ -60,3 +69,78 @@ async def logs(
func=core.tail_logs,
schedule_type=requests.ScheduleType.NON_BLOCKING,
)


@router.get('/dashboard')
async def dashboard(request: fastapi.Request,
user_hash: str) -> fastapi.Response:
# Find the port for the dashboard of the user
os.environ[constants.USER_ID_ENV_VAR] = user_hash
common.reload()
logger.info(f'Starting dashboard for user hash: {user_hash}')

body = payloads.RequestBody()
body.env_vars[constants.USER_ID_ENV_VAR] = user_hash
body.entrypoint_command = 'jobs/dashboard'
body.override_skypilot_config = {}

with dashboard_utils.get_dashboard_lock_for_user(user_hash):
max_retries = 3
for attempt in range(max_retries):
port, pid = dashboard_utils.get_dashboard_session(user_hash)
if port == 0 or attempt > 0:
# Let the client know that we are waiting for starting the
# dashboard.
try:
port, pid = core.start_dashboard_forwarding()
except Exception as e: # pylint: disable=broad-except
# We catch all exceptions to gracefully handle unknown
# errors and raise an HTTPException to the client.
msg = (
'Dashboard failed to start: '
f'{common_utils.format_exception(e, use_bracket=True)}')
logger.error(msg)
raise fastapi.HTTPException(status_code=503, detail=msg)
dashboard_utils.add_dashboard_session(user_hash, port, pid)

# Assuming the dashboard is forwarded to localhost on the API server
dashboard_url = f'http://localhost:{port}'
try:
# Ping the dashboard to check if it's still running
async with httpx.AsyncClient() as client:
response = await client.request('GET',
dashboard_url,
timeout=1)
break # Connection successful, proceed with the request
except Exception as e: # pylint: disable=broad-except
# We catch all exceptions to gracefully handle unknown
# errors and retry or raise an HTTPException to the client.
msg = (
f'Dashboard connection attempt {attempt + 1} failed with '
f'{common_utils.format_exception(e, use_bracket=True)}')
logger.info(msg)
if attempt == max_retries - 1:
raise fastapi.HTTPException(status_code=503, detail=msg)

# Create a client session to forward the request
try:
async with httpx.AsyncClient() as client:
# Make the request and get the response
response = await client.request(
method='GET',
url=f'{dashboard_url}',
headers=request.headers.raw,
)

# Create a new response with the content already read
content = await response.aread()
return fastapi.Response(
content=content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get('content-type'))
except Exception as e:
msg = (f'Failed to forward request to dashboard: '
f'{common_utils.format_exception(e, use_bracket=True)}')
logger.error(msg)
raise fastapi.HTTPException(status_code=502, detail=msg)
Loading

0 comments on commit b46a994

Please sign in to comment.