Skip to content

Commit

Permalink
Early task validation as a separate REST API (#42)
Browse files Browse the repository at this point in the history
* Call task validation at the start of sky launch/exec

* Create a new validate rest API

* add unit test

* remove double validation

* refactor validate()

* format

* refactor

* format

* nit

* address comments

* nit

* prevent catalog pulling on client side
  • Loading branch information
yika-luo authored Dec 9, 2024
1 parent 390eec6 commit 7b24232
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 67 deletions.
5 changes: 5 additions & 0 deletions sky/api/requests/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class CheckBody(RequestBody):
verbose: bool


class ValidateBody(RequestBody):
"""The request body for the validate endpoint."""
dag: str


class OptimizeBody(RequestBody):
"""The request body for the optimize endpoint."""
dag: str
Expand Down
21 changes: 21 additions & 0 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pathlib
import shutil
import sys
import tempfile
import time
from typing import AsyncGenerator, Deque, List, Optional
import uuid
Expand Down Expand Up @@ -37,6 +38,7 @@
from sky.skylet import constants
from sky.utils import common as common_lib
from sky.utils import common_utils
from sky.utils import dag_utils
from sky.utils import message_utils
from sky.utils import rich_utils
from sky.utils import status_lib
Expand Down Expand Up @@ -212,6 +214,25 @@ async def list_accelerator_counts(
)


@app.post('/validate')
async def validate(validate_body: payloads.ValidateBody):
# TODO(SKY-1035): validate if existing cluster satisfies the requested
# resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus
logger.info(f'Validating tasks: {validate_body.dag}')
with tempfile.NamedTemporaryFile(mode='w') as f:
f.write(validate_body.dag)
f.flush()
dag = dag_utils.load_chain_dag_from_yaml(f.name)
for task in dag.tasks:
# Will validate workdir and file_mounts in the backend, as those need
# to be validated after the files are uploaded to the SkyPilot server
# with `upload_mounts_to_api_server`.
task.validate_name()
task.validate_run()
for r in task.resources:
r.validate()


@app.post('/optimize')
async def optimize(optimize_body: payloads.OptimizeBody,
request: fastapi.Request):
Expand Down
24 changes: 23 additions & 1 deletion sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,27 @@ def optimize(
return api_common.get_request_id(response)


@usage_lib.entrypoint
@api_common.check_health
@annotations.public_api
def validate(dag: 'sky.Dag') -> str:
"""Validate the tasks.
The file paths (workdir and file_mounts) are validated on the client side
while the rest (e.g. resource) are validated on server side.
"""
for task in dag.tasks:
task.validate_workdir()
task.validate_file_mounts()
with tempfile.NamedTemporaryFile(mode='r') as f:
dag_utils.dump_chain_dag_to_yaml(dag, f.name)
dag_str = f.read()
body = payloads.ValidateBody(dag=dag_str)
response = requests.post(f'{api_common.get_server_url()}/validate',
json=json.loads(body.model_dump_json()))
return api_common.get_request_id(response)


@usage_lib.entrypoint
@api_common.check_health
@annotations.public_api
Expand Down Expand Up @@ -210,9 +231,9 @@ def launch(
# task, _ = backend_utils.check_can_clone_disk_and_override_task(
# clone_disk_from, cluster_name, task)
dag = dag_utils.convert_entrypoint_to_dag(task)
validate(dag)

confirm_shown = False

if need_confirmation:
cluster_status = None
request_id = status([cluster_name])
Expand Down Expand Up @@ -289,6 +310,7 @@ def exec( # pylint: disable=redefined-builtin
) -> str:
"""Execute a task."""
dag = dag_utils.convert_entrypoint_to_dag(task)
validate(dag)
dag = api_common.upload_mounts_to_api_server(dag, workdir_only=True)
with tempfile.NamedTemporaryFile(mode='r') as f:
dag_utils.dump_chain_dag_to_yaml(dag, f.name)
Expand Down
16 changes: 8 additions & 8 deletions sky/data/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,15 @@ def get_excluded_files(src_dir_path: str) -> List[str]:
skyignore_path = os.path.join(expand_src_dir_path,
constants.SKY_IGNORE_FILE)
if os.path.exists(skyignore_path):
logger.info(f' {colorama.Style.DIM}'
f'Excluded files to sync to cluster based on '
f'{constants.SKY_IGNORE_FILE}.'
f'{colorama.Style.RESET_ALL}')
logger.debug(f' {colorama.Style.DIM}'
f'Excluded files to sync to cluster based on '
f'{constants.SKY_IGNORE_FILE}.'
f'{colorama.Style.RESET_ALL}')
return get_excluded_files_from_skyignore(src_dir_path)
logger.info(f' {colorama.Style.DIM}'
f'Excluded files to sync to cluster based on '
f'{constants.GIT_IGNORE_FILE}.'
f'{colorama.Style.RESET_ALL}')
logger.debug(f' {colorama.Style.DIM}'
f'Excluded files to sync to cluster based on '
f'{constants.GIT_IGNORE_FILE}.'
f'{colorama.Style.RESET_ALL}')
return get_excluded_files_from_gitignore(src_dir_path)


Expand Down
1 change: 0 additions & 1 deletion sky/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def _execute(
"""

dag = dag_utils.convert_entrypoint_to_dag(entrypoint)
dag.validate()
for task in dag.tasks:
if task.storage_mounts is not None:
for storage in task.storage_mounts.values():
Expand Down
1 change: 1 addition & 0 deletions sky/jobs/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def launch(
from sky.api import sdk # pylint: disable=import-outside-toplevel

dag = dag_utils.convert_entrypoint_to_dag(task)
sdk.validate(dag)
if need_confirmation:
request_id = sdk.optimize(dag)
sdk.stream_and_get(request_id)
Expand Down
3 changes: 0 additions & 3 deletions sky/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,7 @@ def optimize(dag: 'dag_lib.Dag',
exceptions.NoCloudAccessError: if no public clouds are enabled.
"""
with rich_utils.safe_status(ux_utils.spinner_message('Optimizing')):
# TODO: should validate the dag here for faster failure
# of invalid task configs.
_check_specified_clouds(dag)

# This function is effectful: mutates every node in 'dag' by setting
# node.best_resources if it is None.
Optimizer._add_dummy_source_sink_nodes(dag)
Expand Down
5 changes: 3 additions & 2 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def __init__(
self._set_accelerators(accelerators, accelerator_args)

def validate(self):
# TODO: move these out of init to prevent repeated calls.
self._try_validate_and_set_region_zone()
self._try_validate_instance_type()
self._try_validate_cpus_mem()
Expand Down Expand Up @@ -1466,9 +1465,11 @@ def add_if_not_none(key, value):

add_if_not_none('cloud', str(self.cloud))
add_if_not_none('instance_type', self.instance_type)
# TODO(SKY-1048): do not call self.cpus or self.accelerators
# to prevent catalog pulling on client side.
add_if_not_none('cpus', self._cpus)
add_if_not_none('memory', self.memory)
add_if_not_none('accelerators', self.accelerators)
add_if_not_none('accelerators', self._accelerators)
add_if_not_none('accelerator_args', self.accelerator_args)

if self._use_spot_specified:
Expand Down
2 changes: 2 additions & 0 deletions sky/serve/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def up(task: Union['sky.Task', 'sky.Dag'],
# This is to avoid circular import.
from sky.api import sdk # pylint: disable=import-outside-toplevel
dag = dag_utils.convert_entrypoint_to_dag(task)
sdk.validate(dag)
request_id = sdk.optimize(dag)
sdk.stream_and_get(request_id)
if need_confirmation:
Expand Down Expand Up @@ -60,6 +61,7 @@ def update(task: Union['sky.Task', 'sky.Dag'],
# This is to avoid circular import.
from sky.api import sdk # pylint: disable=import-outside-toplevel
dag = dag_utils.convert_entrypoint_to_dag(task)
sdk.validate(dag)
request_id = sdk.optimize(dag)
sdk.stream_and_get(request_id)
if need_confirmation:
Expand Down
119 changes: 67 additions & 52 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,22 @@ def __init__(

def validate(self):
"""Checks if the Task fields are valid."""
self.validate_name()
self.validate_run()
self.validate_workdir()
self.validate_file_mounts()
for r in self.resources:
r.validate()

def validate_name(self):
"""Validates if the task name is valid."""
if not _is_valid_name(self.name):
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Invalid task name {self.name}. Valid name: '
f'{_VALID_NAME_DESCR}')

# Check self.run
def validate_run(self):
"""Validates if the run command is valid."""
if callable(self.run):
run_sig = inspect.signature(self.run)
# Check that run is a function with 2 arguments.
Expand Down Expand Up @@ -336,19 +346,63 @@ def validate(self):
f'a command generator ({CommandGen}). '
f'Got {type(self.run)}')

# Workdir.
if self.workdir is not None:
full_workdir = os.path.expanduser(self.workdir)
if not os.path.isdir(full_workdir):
# Symlink to a dir is legal (isdir() follows symlinks).
def validate_file_mounts(self):
"""Validates if file_mounts paths are valid.
Note: if this function is called on a remote SkyPilot server,
it must be after the client side has sync-ed all files to the
remote server.
"""
if self.file_mounts is None:
return
for target, source in self.file_mounts.items():
if target.endswith('/') or source.endswith('/'):
with ux_utils.print_exception_no_traceback():
raise ValueError(
'File mount paths cannot end with a slash '
'(try "/mydir: /mydir" or "/myfile: /myfile"). '
f'Found: target={target} source={source}')
if data_utils.is_cloud_store_url(target):
with ux_utils.print_exception_no_traceback():
raise ValueError(
'File mount destination paths cannot be cloud storage')
if not data_utils.is_cloud_store_url(source):
self.file_mounts[target] = os.path.expanduser(source)
if not os.path.exists(os.path.expanduser(
source)) and not source.startswith('skypilot:'):
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'File mount source {source!r} does not exist '
'locally. To fix: check if it exists, and correct '
'the path.')
# TODO(zhwu): /home/username/sky_workdir as the target path need
# to be filtered out as well.
if (target == constants.SKY_REMOTE_WORKDIR and
self.workdir is not None):
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Workdir must exist and must be a directory (or '
f'a symlink to a directory). {self.workdir} not found.')
f'Cannot use {constants.SKY_REMOTE_WORKDIR!r} as a '
'destination path of a file mount, as it will be used '
'by the workdir. If uploading a file/folder to the '
'workdir is needed, please specify the full path to '
'the file/folder.')

# Resources.
for r in self.resources:
r.validate()
def validate_workdir(self):
"""Validates if workdir path is valid.
Note: if this function is called on a remote SkyPilot server,
it must be after the client side has sync-ed all files to the
remote server.
"""
if self.workdir is None:
return
full_workdir = os.path.expanduser(self.workdir)
if not os.path.isdir(full_workdir):
# Symlink to a dir is legal (isdir() follows symlinks).
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Workdir must be a valid directory (or '
f'a symlink to a directory). {self.workdir} not found.')

@staticmethod
def from_yaml_config(
Expand Down Expand Up @@ -740,46 +794,7 @@ def set_file_mounts(self, file_mounts: Optional[Dict[str, str]]) -> 'Task':
Returns:
self: the current task, with file mounts set.
Raises:
ValueError: if input paths are invalid.
"""
if file_mounts is None:
self.file_mounts = None
return self
for target, source in file_mounts.items():
if target.endswith('/') or source.endswith('/'):
with ux_utils.print_exception_no_traceback():
raise ValueError(
'File mount paths cannot end with a slash '
'(try "/mydir: /mydir" or "/myfile: /myfile"). '
f'Found: target={target} source={source}')
if data_utils.is_cloud_store_url(target):
with ux_utils.print_exception_no_traceback():
raise ValueError(
'File mount destination paths cannot be cloud storage')
# if not data_utils.is_cloud_store_url(source):
# file_mounts[target] = os.path.expanduser(source)
# if (not os.path.exists(
# os.path.abspath(os.path.expanduser(source))) and
# not source.startswith('skypilot:')):
# with ux_utils.print_exception_no_traceback():
# raise ValueError(
# f'File mount source {source!r} does not exist '
# 'locally. To fix: check if it exists, and correct '
# 'the path.')
# TODO(zhwu): /home/username/sky_workdir as the target path need
# to be filtered out as well.
if (target == constants.SKY_REMOTE_WORKDIR and
self.workdir is not None):
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Cannot use {constants.SKY_REMOTE_WORKDIR!r} as a '
'destination path of a file mount, as it will be used '
'by the workdir. If uploading a file/folder to the '
'workdir is needed, please specify the full path to '
'the file/folder.')

self.file_mounts = file_mounts
return self

Expand Down Expand Up @@ -815,8 +830,8 @@ def update_file_mounts(self, file_mounts: Dict[str, str]) -> 'Task':
self.file_mounts = {}
assert self.file_mounts is not None
self.file_mounts.update(file_mounts)
# For validation logic:
return self.set_file_mounts(self.file_mounts)
self.validate_file_mounts()
return self

def set_storage_mounts(
self,
Expand Down
Loading

0 comments on commit 7b24232

Please sign in to comment.