diff --git a/.buildkite/generate_pipeline.py b/.buildkite/generate_pipeline.py
new file mode 100644
index 00000000000..99f29ee258a
--- /dev/null
+++ b/.buildkite/generate_pipeline.py
@@ -0,0 +1,252 @@
+"""
+This script generates Buildkite pipelines from the smoke test files.
+
+Two pipelines are generated:
+
+tests/smoke_tests
+├── test_*.py -> release pipeline
+└── test_quick_tests_core.py -> quick tests to run on a PR before merging
+
+1. Release pipeline: generates smoke tests for all supported clouds and runs
+   all of them by default.
+2. Pre-merge pipeline: generates smoke tests for all supported clouds, but
+   the author should specify which clouds to run by setting env variables in
+   the step.
+
+Run `PYTHONPATH=$(pwd)/tests:$PYTHONPATH python .buildkite/generate_pipeline.py`
+to generate the pipelines for testing. The CI runs this script as a pre-step
+and uses the generated pipelines to run the tests.
+
+We currently only have credentials for aws/azure/gcp/kubernetes (see
+CLOUD_QUEUE_MAP and SERVE_CLOUD_QUEUE_MAP), so smoke tests are generated only
+for those clouds; other clouds are not supported yet and no smoke tests are
+generated for them.
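+
+For illustration, each generated step roughly corresponds to a dict like the
+following (values here are made up; the real label/command depend on the test
+file and its pytest marks):
+
+    {
+        'label': 'test_minimal on aws',
+        'command': 'pytest tests/smoke_tests/test_basic.py::test_minimal --aws',
+        'agents': {'queue': 'generic_cloud'},
+        'if': 'build.env("aws") == "1"',
+    }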
+"""
+
+import ast
+import os
+import random
+from typing import Any, Dict, List, Optional, Tuple
+
+from conftest import cloud_to_pytest_keyword
+from conftest import default_clouds_to_run
+import yaml
+
+DEFAULT_CLOUDS_TO_RUN = default_clouds_to_run
+PYTEST_TO_CLOUD_KEYWORD = {v: k for k, v in cloud_to_pytest_keyword.items()}
+
+QUEUE_GENERIC_CLOUD = 'generic_cloud'
+QUEUE_GENERIC_CLOUD_SERVE = 'generic_cloud_serve'
+QUEUE_KUBERNETES = 'kubernetes'
+QUEUE_KUBERNETES_SERVE = 'kubernetes_serve'
+# Only aws, gcp, azure, and kubernetes are supported for now.
+# Other clouds do not have credentials.
+CLOUD_QUEUE_MAP = {
+ 'aws': QUEUE_GENERIC_CLOUD,
+ 'gcp': QUEUE_GENERIC_CLOUD,
+ 'azure': QUEUE_GENERIC_CLOUD,
+ 'kubernetes': QUEUE_KUBERNETES
+}
+# Serve tests run long, and different test steps usually require locks.
+# They are highly likely to fail if multiple serve tests run concurrently,
+# so we use a separate queue that runs only one test at a time.
+SERVE_CLOUD_QUEUE_MAP = {
+ 'aws': QUEUE_GENERIC_CLOUD_SERVE,
+ 'gcp': QUEUE_GENERIC_CLOUD_SERVE,
+ 'azure': QUEUE_GENERIC_CLOUD_SERVE,
+ 'kubernetes': QUEUE_KUBERNETES_SERVE
+}
+
+GENERATED_FILE_HEAD = ('# This is an auto-generated Buildkite pipeline by '
+                       '.buildkite/generate_pipeline.py. Please do not '
+                       'edit it directly.\n')
+
+
+def _get_full_decorator_path(decorator: ast.AST) -> str:
+ """Recursively get the full path of a decorator."""
+ if isinstance(decorator, ast.Attribute):
+ return f'{_get_full_decorator_path(decorator.value)}.{decorator.attr}'
+ elif isinstance(decorator, ast.Name):
+ return decorator.id
+ elif isinstance(decorator, ast.Call):
+ return _get_full_decorator_path(decorator.func)
+ raise ValueError(f'Unknown decorator type: {type(decorator)}')
+
+
+def _extract_marked_tests(
+        file_path: str) -> Dict[str, Tuple[List[str], List[str]]]:
+ """Extract test functions and filter clouds using pytest.mark
+ from a Python test file.
+
+    We split each test_function_{cloud} into its own pipeline step to
+    maximize test parallelism via the Buildkite CI job queue. This also
+    lets us visualize test results and rerun failures at the granularity
+    of each test_function_{cloud}.
+
+    If we instead made, e.g., `pytest --serve` a single job, it could
+    contain dozens of test functions and run for hours. That would make it
+    hard to visualize the results and rerun failures, and the parallelism
+    would be controlled by pytest instead of the Buildkite job queue.
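+
+    Returns a dict mapping each test function name to a tuple of
+    (clouds to run, matching Buildkite queues), e.g. (illustrative):
+
+        {'test_minimal': (['aws', 'kubernetes'],
+                          ['generic_cloud', 'kubernetes'])}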
+ """
+ with open(file_path, 'r', encoding='utf-8') as file:
+ tree = ast.parse(file.read(), filename=file_path)
+
+ for node in ast.walk(tree):
+ for child in ast.iter_child_nodes(node):
+ setattr(child, 'parent', node)
+
+ function_cloud_map = {}
+ for node in ast.walk(tree):
+ if isinstance(node, ast.FunctionDef) and node.name.startswith('test_'):
+ class_name = None
+ if hasattr(node, 'parent') and isinstance(node.parent,
+ ast.ClassDef):
+ class_name = node.parent.name
+
+ clouds_to_include = []
+ clouds_to_exclude = []
+ is_serve_test = False
+ for decorator in node.decorator_list:
+ if isinstance(decorator, ast.Call):
+                    # Only decorators without arguments (plain pytest marks)
+                    # are relevant for extracting clouds, so we skip calls.
+ continue
+ full_path = _get_full_decorator_path(decorator)
+ if full_path.startswith('pytest.mark.'):
+ assert isinstance(decorator, ast.Attribute)
+ suffix = decorator.attr
+ if suffix.startswith('no_'):
+ clouds_to_exclude.append(suffix[3:])
+ else:
+ if suffix == 'serve':
+ is_serve_test = True
+ continue
+ if suffix not in PYTEST_TO_CLOUD_KEYWORD:
+ # This mark does not specify a cloud, so we skip it.
+ continue
+ clouds_to_include.append(
+ PYTEST_TO_CLOUD_KEYWORD[suffix])
+ clouds_to_include = (clouds_to_include if clouds_to_include else
+ DEFAULT_CLOUDS_TO_RUN)
+ clouds_to_include = [
+ cloud for cloud in clouds_to_include
+ if cloud not in clouds_to_exclude
+ ]
+            cloud_queue_map = (SERVE_CLOUD_QUEUE_MAP
+                               if is_serve_test else CLOUD_QUEUE_MAP)
+ final_clouds_to_include = [
+ cloud for cloud in clouds_to_include if cloud in cloud_queue_map
+ ]
+ if clouds_to_include and not final_clouds_to_include:
+ print(f'Warning: {file_path}:{node.name} '
+ f'is marked to run on {clouds_to_include}, '
+ f'but we do not have credentials for those clouds. '
+ f'Skipped.')
+ continue
+ if clouds_to_include != final_clouds_to_include:
+ excluded_clouds = set(clouds_to_include) - set(
+ final_clouds_to_include)
+ print(
+ f'Warning: {file_path}:{node.name} '
+ f'is marked to run on {clouds_to_include}, '
+                    f'but we only have credentials for '
+                    f'{final_clouds_to_include}. '
+                    f'Clouds {excluded_clouds} are skipped.')
+ function_name = (f'{class_name}::{node.name}'
+ if class_name else node.name)
+ function_cloud_map[function_name] = (final_clouds_to_include, [
+ cloud_queue_map[cloud] for cloud in final_clouds_to_include
+ ])
+ return function_cloud_map
+
+
+def _generate_pipeline(test_file: str) -> Dict[str, Any]:
+ """Generate a Buildkite pipeline from test files."""
+ steps = []
+ function_cloud_map = _extract_marked_tests(test_file)
+ for test_function, clouds_and_queues in function_cloud_map.items():
+ for cloud, queue in zip(*clouds_and_queues):
+ step = {
+ 'label': f'{test_function} on {cloud}',
+ 'command': f'pytest {test_file}::{test_function} --{cloud}',
+ 'agents': {
+                    # Use a separate agent pool (queue) for each cloud, since
+                    # clouds require different amounts of resources and
+                    # concurrency control.
+ 'queue': queue
+ },
+ 'if': f'build.env("{cloud}") == "1"'
+ }
+ steps.append(step)
+ return {'steps': steps}
+
+
+def _dump_pipeline_to_file(yaml_file_path: str,
+ pipelines: List[Dict[str, Any]],
+ extra_env: Optional[Dict[str, str]] = None):
+ default_env = {'LOG_TO_STDOUT': '1', 'PYTHONPATH': '${PYTHONPATH}:$(pwd)'}
+ if extra_env:
+ default_env.update(extra_env)
+ with open(yaml_file_path, 'w', encoding='utf-8') as file:
+ file.write(GENERATED_FILE_HEAD)
+ all_steps = []
+ for pipeline in pipelines:
+ all_steps.extend(pipeline['steps'])
+        # Shuffle the steps to reduce flakiness: consecutive runs of the same
+        # kind of test may fail because they require locks on the same
+        # resources.
+ random.shuffle(all_steps)
+ final_pipeline = {'steps': all_steps, 'env': default_env}
+ yaml.dump(final_pipeline, file, default_flow_style=False)
+
+
+def _convert_release(test_files: List[str]):
+ yaml_file_path = '.buildkite/pipeline_smoke_tests_release.yaml'
+ output_file_pipelines = []
+ for test_file in test_files:
+ print(f'Converting {test_file} to {yaml_file_path}')
+ pipeline = _generate_pipeline(test_file)
+ output_file_pipelines.append(pipeline)
+ print(f'Converted {test_file} to {yaml_file_path}\n\n')
+ # Enable all clouds by default for release pipeline.
+ _dump_pipeline_to_file(yaml_file_path,
+ output_file_pipelines,
+ extra_env={cloud: '1' for cloud in CLOUD_QUEUE_MAP})
+
+
+def _convert_quick_tests_core(test_files: List[str]):
+ yaml_file_path = '.buildkite/pipeline_smoke_tests_quick_tests_core.yaml'
+ output_file_pipelines = []
+ for test_file in test_files:
+ print(f'Converting {test_file} to {yaml_file_path}')
+        # For pre-merge, we want to enable all clouds by default for each
+        # test function and let the author control which clouds to run via
+        # parameters.
+ pipeline = _generate_pipeline(test_file)
+ pipeline['steps'].append({
+ 'label': 'Backward compatibility test',
+ 'command': 'bash tests/backward_compatibility_tests.sh',
+ 'agents': {
+ 'queue': 'back_compat'
+ }
+ })
+ output_file_pipelines.append(pipeline)
+ print(f'Converted {test_file} to {yaml_file_path}\n\n')
+ _dump_pipeline_to_file(yaml_file_path,
+ output_file_pipelines,
+ extra_env={'SKYPILOT_SUPPRESS_SENSITIVE_LOG': '1'})
+
+
+def main():
+ test_files = os.listdir('tests/smoke_tests')
+ release_files = []
+ quick_tests_core_files = []
+ for test_file in test_files:
+ if not test_file.startswith('test_'):
+ continue
+ test_file_path = os.path.join('tests/smoke_tests', test_file)
+ if "test_quick_tests_core" in test_file:
+ quick_tests_core_files.append(test_file_path)
+ else:
+ release_files.append(test_file_path)
+
+ _convert_release(release_files)
+ _convert_quick_tests_core(quick_tests_core_files)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index db40b03b5fa..81f794dac24 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,7 +24,6 @@ repos:
args:
- "--sg=build/**" # Matches "${ISORT_YAPF_EXCLUDES[@]}"
- "--sg=sky/skylet/providers/ibm/**"
- files: "^(sky|tests|examples|llm|docs)/.*" # Only match these directories
# Second isort command
- id: isort
name: isort (IBM specific)
@@ -56,8 +55,8 @@ repos:
hooks:
- id: yapf
name: yapf
- exclude: (build/.*|sky/skylet/providers/ibm/.*) # Matches exclusions from the script
- args: ['--recursive', '--parallel'] # Only necessary flags
+ exclude: (sky/skylet/providers/ibm/.*) # Matches exclusions from the script
+ args: ['--recursive', '--parallel', '--in-place'] # Only necessary flags
additional_dependencies: [toml==0.10.2]
- repo: https://github.com/pylint-dev/pylint
diff --git a/README.md b/README.md
index f29b57be9ca..1ed99325df5 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
-
+
@@ -43,7 +43,7 @@
Archived
- [Jul 2024] [**Finetune**](./llm/llama-3_1-finetuning/) and [**serve**](./llm/llama-3_1/) **Llama 3.1** on your infra
-- [Apr 2024] Serve and finetune [**Llama 3**](https://skypilot.readthedocs.io/en/latest/gallery/llms/llama-3.html) on any cloud or Kubernetes: [**example**](./llm/llama-3/)
+- [Apr 2024] Serve and finetune [**Llama 3**](https://docs.skypilot.co/en/latest/gallery/llms/llama-3.html) on any cloud or Kubernetes: [**example**](./llm/llama-3/)
- [Mar 2024] Serve and deploy [**Databricks DBRX**](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) on your infra: [**example**](./llm/dbrx/)
- [Feb 2024] Speed up your LLM deployments with [**SGLang**](https://github.com/sgl-project/sglang) for 5x throughput on SkyServe: [**example**](./llm/sglang/)
- [Dec 2023] Using [**LoRAX**](https://github.com/predibase/lorax) to serve 1000s of finetuned LLMs on a single instance in the cloud: [**example**](./llm/lorax/)
@@ -60,17 +60,17 @@
SkyPilot is a framework for running AI and batch workloads on any infra, offering unified execution, high cost savings, and high GPU availability.
SkyPilot **abstracts away infra burdens**:
-- Launch [dev clusters](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html), [jobs](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html), and [serving](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html) on any infra
+- Launch [dev clusters](https://docs.skypilot.co/en/latest/examples/interactive-development.html), [jobs](https://docs.skypilot.co/en/latest/examples/managed-jobs.html), and [serving](https://docs.skypilot.co/en/latest/serving/sky-serve.html) on any infra
- Easy job management: queue, run, and auto-recover many jobs
SkyPilot **supports multiple clusters, clouds, and hardware** ([the Sky](https://arxiv.org/abs/2205.07147)):
- Bring your reserved GPUs, Kubernetes clusters, or 12+ clouds
-- [Flexible provisioning](https://skypilot.readthedocs.io/en/latest/examples/auto-failover.html) of GPUs, TPUs, CPUs, with auto-retry
+- [Flexible provisioning](https://docs.skypilot.co/en/latest/examples/auto-failover.html) of GPUs, TPUs, CPUs, with auto-retry
SkyPilot **cuts your cloud costs & maximizes GPU availability**:
-* [Autostop](https://skypilot.readthedocs.io/en/latest/reference/auto-stop.html): automatic cleanup of idle resources
-* [Managed Spot](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html): 3-6x cost savings using spot instances, with preemption auto-recovery
-* [Optimizer](https://skypilot.readthedocs.io/en/latest/examples/auto-failover.html): 2x cost savings by auto-picking the cheapest & most available infra
+* [Autostop](https://docs.skypilot.co/en/latest/reference/auto-stop.html): automatic cleanup of idle resources
+* [Managed Spot](https://docs.skypilot.co/en/latest/examples/managed-jobs.html): 3-6x cost savings using spot instances, with preemption auto-recovery
+* [Optimizer](https://docs.skypilot.co/en/latest/examples/auto-failover.html): 2x cost savings by auto-picking the cheapest & most available infra
SkyPilot supports your existing GPU, TPU, and CPU workloads, with no code changes.
@@ -79,13 +79,13 @@ Install with pip:
# Choose your clouds:
pip install -U "skypilot[kubernetes,aws,gcp,azure,oci,lambda,runpod,fluidstack,paperspace,cudo,ibm,scp]"
```
-To get the latest features and fixes, use the nightly build or [install from source](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html):
+To get the latest features and fixes, use the nightly build or [install from source](https://docs.skypilot.co/en/latest/getting-started/installation.html):
```bash
# Choose your clouds:
pip install "skypilot-nightly[kubernetes,aws,gcp,azure,oci,lambda,runpod,fluidstack,paperspace,cudo,ibm,scp]"
```
-[Current supported infra](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html) (Kubernetes; AWS, GCP, Azure, OCI, Lambda Cloud, Fluidstack, RunPod, Cudo, Paperspace, Cloudflare, Samsung, IBM, VMware vSphere):
+[Current supported infra](https://docs.skypilot.co/en/latest/getting-started/installation.html) (Kubernetes; AWS, GCP, Azure, OCI, Lambda Cloud, Fluidstack, RunPod, Cudo, Paperspace, Cloudflare, Samsung, IBM, VMware vSphere):
@@ -95,16 +95,16 @@ pip install "skypilot-nightly[kubernetes,aws,gcp,azure,oci,lambda,runpod,fluidst
## Getting Started
-You can find our documentation [here](https://skypilot.readthedocs.io/en/latest/).
-- [Installation](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)
-- [Quickstart](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html)
-- [CLI reference](https://skypilot.readthedocs.io/en/latest/reference/cli.html)
+You can find our documentation [here](https://docs.skypilot.co/).
+- [Installation](https://docs.skypilot.co/en/latest/getting-started/installation.html)
+- [Quickstart](https://docs.skypilot.co/en/latest/getting-started/quickstart.html)
+- [CLI reference](https://docs.skypilot.co/en/latest/reference/cli.html)
## SkyPilot in 1 Minute
A SkyPilot task specifies: resource requirements, data to be synced, setup commands, and the task commands.
-Once written in this [**unified interface**](https://skypilot.readthedocs.io/en/latest/reference/yaml-spec.html) (YAML or Python API), the task can be launched on any available cloud. This avoids vendor lock-in, and allows easily moving jobs to a different provider.
+Once written in this [**unified interface**](https://docs.skypilot.co/en/latest/reference/yaml-spec.html) (YAML or Python API), the task can be launched on any available cloud. This avoids vendor lock-in, and allows easily moving jobs to a different provider.
Paste the following into a file `my_task.yaml`:
@@ -135,7 +135,7 @@ Prepare the workdir by cloning:
git clone https://github.com/pytorch/examples.git ~/torch_examples
```
-Launch with `sky launch` (note: [access to GPU instances](https://skypilot.readthedocs.io/en/latest/cloud-setup/quota.html) is needed for this example):
+Launch with `sky launch` (note: [access to GPU instances](https://docs.skypilot.co/en/latest/cloud-setup/quota.html) is needed for this example):
```bash
sky launch my_task.yaml
```
@@ -152,10 +152,10 @@ SkyPilot then performs the heavy-lifting for you, including:
-Refer to [Quickstart](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html) to get started with SkyPilot.
+Refer to [Quickstart](https://docs.skypilot.co/en/latest/getting-started/quickstart.html) to get started with SkyPilot.
## More Information
-To learn more, see [Concept: Sky Computing](https://docs.skypilot.co/en/latest/sky-computing.html), [SkyPilot docs](https://skypilot.readthedocs.io/en/latest/), and [SkyPilot blog](https://blog.skypilot.co/).
+To learn more, see [Concept: Sky Computing](https://docs.skypilot.co/en/latest/sky-computing.html), [SkyPilot docs](https://docs.skypilot.co/en/latest/), and [SkyPilot blog](https://blog.skypilot.co/).
Runnable examples:
diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt
index 2ab8ba30887..32e5b9dafb5 100644
--- a/docs/requirements-docs.txt
+++ b/docs/requirements-docs.txt
@@ -11,6 +11,7 @@ sphinx-autobuild==2021.3.14
sphinx-autodoc-typehints==1.25.2
sphinx-book-theme==1.1.0
sphinx-togglebutton==0.3.2
+sphinx-notfound-page==1.0.4
sphinxcontrib-applehelp==1.0.7
sphinxcontrib-devhelp==1.0.5
sphinxcontrib-googleanalytics==0.4
diff --git a/docs/source/_static/SkyPilot_wide_dark.svg b/docs/source/_static/SkyPilot_wide_dark.svg
index 6be00d9e591..cb2f742ab98 100644
--- a/docs/source/_static/SkyPilot_wide_dark.svg
+++ b/docs/source/_static/SkyPilot_wide_dark.svg
@@ -1,64 +1,54 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/source/_static/SkyPilot_wide_light.svg b/docs/source/_static/SkyPilot_wide_light.svg
index 0b2eaae8538..71945c0f927 100644
--- a/docs/source/_static/SkyPilot_wide_light.svg
+++ b/docs/source/_static/SkyPilot_wide_light.svg
@@ -1,64 +1,55 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css
index d5bbdd6cb51..aae9defea90 100644
--- a/docs/source/_static/custom.css
+++ b/docs/source/_static/custom.css
@@ -27,6 +27,7 @@ html[data-theme="light"] {
--pst-color-primary: #176de8;
--pst-color-secondary: var(--pst-color-primary);
--pst-color-text-base: #4c4c4d;
+ --logo-text-color: #0E2E65;
}
html[data-theme="dark"] {
@@ -34,6 +35,7 @@ html[data-theme="dark"] {
--pst-color-primary: #176de8;
--pst-color-secondary: var(--pst-color-primary);
--pst-color-text-base: #d8d8d8;
+ --logo-text-color: #D8D8D8;
.bd-sidebar::-webkit-scrollbar {
width: 6px;
diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js
index 06f492ef7b4..0c127a8fc46 100644
--- a/docs/source/_static/custom.js
+++ b/docs/source/_static/custom.js
@@ -1,16 +1,17 @@
-document.addEventListener('DOMContentLoaded', function () {
- var script = document.createElement('script');
- script.src = 'https://widget.kapa.ai/kapa-widget.bundle.js';
- script.setAttribute('data-website-id', '4223d017-a3d2-4b92-b191-ea4d425a23c3');
- script.setAttribute('data-project-name', 'SkyPilot');
- script.setAttribute('data-project-color', '#4C4C4D');
- script.setAttribute('data-project-logo', 'https://avatars.githubusercontent.com/u/109387420?s=100&v=4');
- script.setAttribute('data-modal-disclaimer', 'Results are automatically generated and may be inaccurate or contain inappropriate information. Do not include any sensitive information in your query.\n**To get further assistance, you can chat directly with the development team** by joining the [SkyPilot Slack](https://slack.skypilot.co/).');
- script.setAttribute('data-modal-title', 'SkyPilot Docs AI - Ask a Question.');
- script.setAttribute('data-button-position-bottom', '100px');
- script.async = true;
- document.head.appendChild(script);
-});
+// As of 2025-01-01, Kapa seems to be having issues loading on some ISPs, including Comcast. Uncomment once resolved.
+// document.addEventListener('DOMContentLoaded', function () {
+// var script = document.createElement('script');
+// script.src = 'https://widget.kapa.ai/kapa-widget.bundle.js';
+// script.setAttribute('data-website-id', '4223d017-a3d2-4b92-b191-ea4d425a23c3');
+// script.setAttribute('data-project-name', 'SkyPilot');
+// script.setAttribute('data-project-color', '#4C4C4D');
+// script.setAttribute('data-project-logo', 'https://avatars.githubusercontent.com/u/109387420?s=100&v=4');
+// script.setAttribute('data-modal-disclaimer', 'Results are automatically generated and may be inaccurate or contain inappropriate information. Do not include any sensitive information in your query.\n**To get further assistance, you can chat directly with the development team** by joining the [SkyPilot Slack](https://slack.skypilot.co/).');
+// script.setAttribute('data-modal-title', 'SkyPilot Docs AI - Ask a Question.');
+// script.setAttribute('data-button-position-bottom', '100px');
+// script.async = true;
+// document.head.appendChild(script);
+// });
(function(h,o,t,j,a,r){
h.hj=h.hj||function(){(h.hj.q=h.hj.q||[]).push(arguments)};
@@ -25,11 +26,7 @@ document.addEventListener('DOMContentLoaded', function () {
document.addEventListener('DOMContentLoaded', () => {
// New items:
const newItems = [
- { selector: '.toctree-l1 > a', text: 'Managed Jobs' },
- { selector: '.toctree-l1 > a', text: 'Pixtral (Mistral AI)' },
{ selector: '.toctree-l1 > a', text: 'Many Parallel Jobs' },
- { selector: '.toctree-l1 > a', text: 'Reserved, Capacity Blocks, DWS' },
- { selector: '.toctree-l1 > a', text: 'Llama 3.2 (Meta)' },
{ selector: '.toctree-l1 > a', text: 'Admin Policy Enforcement' },
{ selector: '.toctree-l1 > a', text: 'Using Existing Machines' },
{ selector: '.toctree-l1 > a', text: 'Concept: Sky Computing' },
diff --git a/docs/source/_templates/navbar-skypilot-logo.html b/docs/source/_templates/navbar-skypilot-logo.html
index 0323953acde..1692f1f2a5d 100644
--- a/docs/source/_templates/navbar-skypilot-logo.html
+++ b/docs/source/_templates/navbar-skypilot-logo.html
@@ -9,5 +9,59 @@
{#- Logo HTML and image #}
- {{ theme_logo["svg"] }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/source/conf.py b/docs/source/conf.py
index a8ce3270e88..3c0b62c9947 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -41,6 +41,7 @@
'sphinxemoji.sphinxemoji',
'sphinx_design',
'myst_parser',
+ 'notfound.extension',
]
intersphinx_mapping = {
diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst
index 61c33b5c43e..2cd99b6c24b 100644
--- a/docs/source/examples/managed-jobs.rst
+++ b/docs/source/examples/managed-jobs.rst
@@ -152,6 +152,7 @@ The :code:`MOUNT` mode in :ref:`SkyPilot bucket mounting ` ensures
Note that the application code should save program checkpoints periodically and reload those states when the job is restarted.
This is typically achieved by reloading the latest checkpoint at the beginning of your program.
+
.. _spot-jobs-end-to-end:
An End-to-End Example
@@ -455,6 +456,46 @@ especially useful when there are many in-progress jobs to monitor, which the
terminal-based CLI may need more than one page to display.
+.. _intermediate-bucket:
+
+Intermediate storage for files
+------------------------------
+
+For managed jobs, SkyPilot requires an intermediate bucket to store files used in the task, such as local file mounts, temporary files, and the workdir.
+If you do not configure a bucket, SkyPilot will automatically create a temporary bucket named :code:`skypilot-filemounts-{username}-{run_id}` for each job launch. SkyPilot automatically deletes the bucket after the job completes.
+
+Alternatively, you can pre-provision a bucket and use it as the intermediate bucket for storing files by setting :code:`jobs.bucket` in :code:`~/.sky/config.yaml`:
+
+.. code-block:: yaml
+
+ # ~/.sky/config.yaml
+ jobs:
+ bucket: s3://my-bucket # Supports s3://, gs://, https://.blob.core.windows.net/, r2://, cos:///
+
+
+If you choose to specify a bucket, ensure that the bucket already exists and that you have the necessary permissions.
+
+When using a pre-provisioned intermediate bucket with :code:`jobs.bucket`, SkyPilot creates job-specific directories under the bucket root to store files. They are organized in the following structure:
+
+.. code-block:: text
+
+ # cloud bucket, s3://my-bucket/ for example
+ my-bucket/
+ ├── job-15891b25/ # Job-specific directory
+ │ ├── local-file-mounts/ # Files from local file mounts
+ │ ├── tmp-files/ # Temporary files
+ │ └── workdir/ # Files from workdir
+ └── job-cae228be/ # Another job's directory
+ ├── local-file-mounts/
+ ├── tmp-files/
+ └── workdir/
+
+When using a custom bucket (:code:`jobs.bucket`), the job-specific directories (e.g., :code:`job-15891b25/`) created by SkyPilot are removed when the job completes.
+
+.. tip::
+ Multiple users can share the same intermediate bucket. Each user's jobs will have their own unique job-specific directories, ensuring that files are kept separate and organized.
+
+
Concept: Jobs Controller
------------------------
@@ -499,10 +540,9 @@ To achieve the above, you can specify custom configs in :code:`~/.sky/config.yam
# Specify the disk_size in GB of the jobs controller.
disk_size: 100
-The :code:`resources` field has the same spec as a normal SkyPilot job; see `here `__.
+The :code:`resources` field has the same spec as a normal SkyPilot job; see `here `__.
.. note::
These settings will not take effect if you have an existing controller (either
stopped or live). For them to take effect, tear down the existing controller
first, which requires all in-progress jobs to finish or be canceled.
-
diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst
index deb2307b67b..1d36b5ef6b8 100644
--- a/docs/source/getting-started/installation.rst
+++ b/docs/source/getting-started/installation.rst
@@ -59,6 +59,7 @@ Install SkyPilot using pip:
pip install "skypilot-nightly[runpod]"
pip install "skypilot-nightly[fluidstack]"
pip install "skypilot-nightly[paperspace]"
+ pip install "skypilot-nightly[do]"
pip install "skypilot-nightly[cudo]"
pip install "skypilot-nightly[ibm]"
pip install "skypilot-nightly[scp]"
diff --git a/docs/source/getting-started/tutorial.rst b/docs/source/getting-started/tutorial.rst
index 175f1391a6d..9b067be2876 100644
--- a/docs/source/getting-started/tutorial.rst
+++ b/docs/source/getting-started/tutorial.rst
@@ -2,19 +2,20 @@
Tutorial: AI Training
======================
-This example uses SkyPilot to train a Transformer-based language model from HuggingFace.
+This example uses SkyPilot to train a GPT-like model (inspired by Karpathy's `minGPT `_) with Distributed Data Parallel (DDP) in PyTorch.
-First, define a :ref:`task YAML ` with the resource requirements, the setup commands,
+We define a :ref:`task YAML ` with the resource requirements, the setup commands,
and the commands to run:
.. code-block:: yaml
- # dnn.yaml
+ # train.yaml
- name: huggingface
+ name: minGPT-ddp
resources:
- accelerators: V100:4
+ cpus: 4+
+ accelerators: L4:4 # Or A100:8, H100:8
# Optional: upload a working directory to remote ~/sky_workdir.
# Commands in "setup" and "run" will be executed under it.
@@ -30,26 +31,21 @@ and the commands to run:
# ~/.netrc: ~/.netrc
setup: |
- set -e # Exit if any command failed.
- git clone https://github.com/huggingface/transformers/ || true
- cd transformers
- pip install .
- cd examples/pytorch/text-classification
- pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+ git clone --depth 1 https://github.com/pytorch/examples || true
+ cd examples
+ git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp
+ # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
+ uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113
run: |
- set -e # Exit if any command failed.
- cd transformers/examples/pytorch/text-classification
- python run_glue.py \
- --model_name_or_path bert-base-cased \
- --dataset_name imdb \
- --do_train \
- --max_seq_length 128 \
- --per_device_train_batch_size 32 \
- --learning_rate 2e-5 \
- --max_steps 50 \
- --output_dir /tmp/imdb/ --overwrite_output_dir \
- --fp16
+ cd examples/mingpt
+ export LOGLEVEL=INFO
+
+ echo "Starting minGPT-ddp training"
+
+ torchrun \
+ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
+ main.py
.. tip::
@@ -57,11 +53,15 @@ and the commands to run:
learn about how to use them to mount local dirs/files or object store buckets
(S3, GCS, R2) into your cluster, see :ref:`sync-code-artifacts`.
+.. tip::
+
+ The ``SKYPILOT_NUM_GPUS_PER_NODE`` environment variable is automatically set by SkyPilot to the number of GPUs per node. See :ref:`env-vars` for more.
+
Then, launch training:
.. code-block:: console
- $ sky launch -c lm-cluster dnn.yaml
+ $ sky launch -c mingpt train.yaml
This will provision the cheapest cluster with the required resources, execute the setup
commands, then execute the run commands.
diff --git a/docs/source/images/skypilot-wide-dark-1k.png b/docs/source/images/skypilot-wide-dark-1k.png
index 057b6a0ae97..b6ed7caec6f 100644
Binary files a/docs/source/images/skypilot-wide-dark-1k.png and b/docs/source/images/skypilot-wide-dark-1k.png differ
diff --git a/docs/source/images/skypilot-wide-light-1k.png b/docs/source/images/skypilot-wide-light-1k.png
index 7af87ad2864..178c6553dd3 100644
Binary files a/docs/source/images/skypilot-wide-light-1k.png and b/docs/source/images/skypilot-wide-light-1k.png differ
diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst
index 10a6b1bc90f..a0a335077f4 100644
--- a/docs/source/reference/config.rst
+++ b/docs/source/reference/config.rst
@@ -31,8 +31,12 @@ Available fields and semantics:
#
# These take effects only when a managed jobs controller does not already exist.
#
- # Ref: https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#customizing-job-controller-resources
+ # Ref: https://docs.skypilot.co/en/latest/examples/managed-jobs.html#customizing-job-controller-resources
jobs:
+ # Bucket to store managed jobs mount files and tmp files. Bucket must already exist.
+ # Optional. If not set, SkyPilot will create a new bucket for each managed job launch.
+ # Supports s3://, gs://, https://.blob.core.windows.net/, r2://, cos:///
+ bucket: s3://my-bucket/
controller:
resources: # same spec as 'resources' in a task YAML
cloud: gcp
@@ -487,13 +491,13 @@ Available fields and semantics:
# This must be either: 'loadbalancer', 'ingress' or 'podip'.
#
# loadbalancer: Creates services of type `LoadBalancer` to expose ports.
- # See https://skypilot.readthedocs.io/en/latest/reference/kubernetes/kubernetes-setup.html#loadbalancer-service.
+ # See https://docs.skypilot.co/en/latest/reference/kubernetes/kubernetes-setup.html#loadbalancer-service.
# This mode is supported out of the box on most cloud managed Kubernetes
# environments (e.g., GKE, EKS).
#
# ingress: Creates an ingress and a ClusterIP service for each port opened.
# Requires an Nginx ingress controller to be configured on the Kubernetes cluster.
- # Refer to https://skypilot.readthedocs.io/en/latest/reference/kubernetes/kubernetes-setup.html#nginx-ingress
+ # Refer to https://docs.skypilot.co/en/latest/reference/kubernetes/kubernetes-setup.html#nginx-ingress
# for details on deploying the NGINX ingress controller.
#
# podip: Directly returns the IP address of the pod. This mode does not
@@ -522,7 +526,7 @@ Available fields and semantics:
#
# : The name of a service account to use for all Kubernetes pods.
# This service account must exist in the user's namespace and have all
- # necessary permissions. Refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html
+ # necessary permissions. Refer to https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/kubernetes.html
# for details on the roles required by the service account.
#
# Using SERVICE_ACCOUNT or a custom service account only affects Kubernetes
@@ -590,7 +594,7 @@ Available fields and semantics:
# gke: uses cloud.google.com/gke-accelerator label to identify GPUs on nodes
# karpenter: uses karpenter.k8s.aws/instance-gpu-name label to identify GPUs on nodes
# generic: uses skypilot.co/accelerator labels to identify GPUs on nodes
- # Refer to https://skypilot.readthedocs.io/en/latest/reference/kubernetes/kubernetes-setup.html#setting-up-gpu-support
+ # Refer to https://docs.skypilot.co/en/latest/reference/kubernetes/kubernetes-setup.html#setting-up-gpu-support
# for more details on setting up labels for GPU support.
#
# Default: null (no autoscaler, autodetect label format for GPU nodes)
@@ -633,20 +637,30 @@ Available fields and semantics:
# Advanced OCI configurations (optional).
oci:
# A dict mapping region names to region-specific configurations, or
- # `default` for the default configuration.
+ # `default` for the default/global configuration.
default:
- # The OCID of the profile to use for launching instances (optional).
- oci_config_profile: DEFAULT
- # The OCID of the compartment to use for launching instances (optional).
+ # The profile name in ~/.oci/config to use for launching instances. If not
+ # set, the one named DEFAULT will be used (optional).
+ oci_config_profile: SKY_PROVISION_PROFILE
+ # The OCID of the compartment to use for launching instances. If not set,
+ # the root compartment will be used (optional).
compartment_ocid: ocid1.compartment.oc1..aaaaaaaahr7aicqtodxmcfor6pbqn3hvsngpftozyxzqw36gj4kh3w3kkj4q
- # The image tag to use for launching general instances (optional).
- image_tag_general: skypilot:cpu-ubuntu-2004
- # The image tag to use for launching GPU instances (optional).
- image_tag_gpu: skypilot:gpu-ubuntu-2004
-
+ # The default image tag to use for launching general instances (CPU) if the
+ # image_id parameter is not specified. If not set, the default is
+ # skypilot:cpu-ubuntu-2204 (optional).
+ image_tag_general: skypilot:cpu-oraclelinux8
+ # The default image tag to use for launching GPU instances if the image_id
+ # parameter is not specified. If not set, the default is
+ # skypilot:gpu-ubuntu-2204 (optional).
+ image_tag_gpu: skypilot:gpu-oraclelinux8
+
+ # Region-specific configurations
ap-seoul-1:
+ # The OCID of the VCN to use for instances (optional).
+ vcn_ocid: ocid1.vcn.oc1.ap-seoul-1.amaaaaaaak7gbriarkfs2ssus5mh347ktmi3xa72tadajep6asio3ubqgarq
# The OCID of the subnet to use for instances (optional).
vcn_subnet: ocid1.subnet.oc1.ap-seoul-1.aaaaaaaa5c6wndifsij6yfyfehmi3tazn6mvhhiewqmajzcrlryurnl7nuja
us-ashburn-1:
+ vcn_ocid: ocid1.vcn.oc1.ap-seoul-1.amaaaaaaak7gbriarkfs2ssus5mh347ktmi3xa72tadajep6asio3ubqgarq
vcn_subnet: ocid1.subnet.oc1.iad.aaaaaaaafbj7i3aqc4ofjaapa5edakde6g4ea2yaslcsay32cthp7qo55pxa
diff --git a/docs/source/reference/kubernetes/index.rst b/docs/source/reference/kubernetes/index.rst
index 89a57862c88..639b5b633ed 100644
--- a/docs/source/reference/kubernetes/index.rst
+++ b/docs/source/reference/kubernetes/index.rst
@@ -39,7 +39,7 @@ Why use SkyPilot on Kubernetes?
.. grid-item-card:: 🖼 Run popular models on Kubernetes
:text-align: center
- Train and serve `Llama-3 `_, `Mixtral `_, and more on your Kubernetes with ready-to-use recipes from the :ref:`AI gallery `.
+ Train and serve `Llama-3 `_, `Mixtral `_, and more on your Kubernetes with ready-to-use recipes from the :ref:`AI gallery `.
.. tab-item:: For Infrastructure Admins
diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst
index e4bbb2c8915..3323559bb36 100644
--- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst
+++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst
@@ -258,6 +258,67 @@ After launching the cluster with :code:`sky launch -c myclus task.yaml`, you can
To learn more about opening ports in SkyPilot tasks, see :ref:`Opening Ports `.
+Customizing SkyPilot pods
+-------------------------
+
+You can override the pod configuration used by SkyPilot by setting the :code:`pod_config` key in :code:`~/.sky/config.yaml`.
+The value of :code:`pod_config` should be a dictionary that follows the `Kubernetes Pod API `_. This will apply to all pods created by SkyPilot.
+
+For example, to set custom environment variables and use GPUDirect RDMA, you can add the following to your :code:`~/.sky/config.yaml` file:
+
+.. code-block:: yaml
+
+ # ~/.sky/config.yaml
+ kubernetes:
+ pod_config:
+ spec:
+ containers:
+ - env: # Custom environment variables to set in pod
+ - name: MY_ENV_VAR
+ value: MY_ENV_VALUE
+ resources: # Custom resources for GPUDirect RDMA
+ requests:
+ rdma/rdma_shared_device_a: 1
+ limits:
+ rdma/rdma_shared_device_a: 1
+
+
+Similarly, you can attach `Kubernetes volumes `_ (e.g., an `NFS volume `_) directly to your SkyPilot pods:
+
+.. code-block:: yaml
+
+ # ~/.sky/config.yaml
+ kubernetes:
+ pod_config:
+ spec:
+ containers:
+ - volumeMounts: # Custom volume mounts for the pod
+ - mountPath: /data
+ name: nfs-volume
+ volumes:
+ - name: nfs-volume
+ nfs: # Alternatively, use hostPath if your NFS is directly attached to the nodes
+ server: nfs.example.com
+ path: /nfs
+
+
+.. tip::
+
+ As an alternative to setting ``pod_config`` globally, you can also set it on a per-task basis directly in your task YAML with the ``config_overrides`` :ref:`field `.
+
+ .. code-block:: yaml
+
+ # task.yaml
+ run: |
+ python myscript.py
+
+ # Set pod_config for this task
+ experimental:
+ config_overrides:
+ pod_config:
+ ...
+
+
FAQs
----
@@ -293,38 +354,6 @@ FAQs
You can use your existing observability tools to filter resources with the label :code:`parent=skypilot` (:code:`kubectl get pods -l 'parent=skypilot'`). As an example, follow the instructions :ref:`here ` to deploy the Kubernetes Dashboard on your cluster.
-* **How can I specify custom configuration for the pods created by SkyPilot?**
-
- You can override the pod configuration used by SkyPilot by setting the :code:`pod_config` key in :code:`~/.sky/config.yaml`.
- The value of :code:`pod_config` should be a dictionary that follows the `Kubernetes Pod API `_.
-
- For example, to set custom environment variables and attach a volume on your pods, you can add the following to your :code:`~/.sky/config.yaml` file:
-
- .. code-block:: yaml
-
- kubernetes:
- pod_config:
- spec:
- containers:
- - env:
- - name: MY_ENV_VAR
- value: MY_ENV_VALUE
- volumeMounts: # Custom volume mounts for the pod
- - mountPath: /foo
- name: example-volume
- resources: # Custom resource requests and limits
- requests:
- rdma/rdma_shared_device_a: 1
- limits:
- rdma/rdma_shared_device_a: 1
- volumes:
- - name: example-volume
- hostPath:
- path: /tmp
- type: Directory
-
- For more details refer to :ref:`config-yaml`.
-
* **I am using a custom image. How can I speed up the pod startup time?**
You can pre-install SkyPilot dependencies in your custom image to speed up the pod startup time. Simply add these lines at the end of your Dockerfile:
diff --git a/docs/source/reference/storage.rst b/docs/source/reference/storage.rst
index 3c54680e79b..16f87c1ce2f 100644
--- a/docs/source/reference/storage.rst
+++ b/docs/source/reference/storage.rst
@@ -3,7 +3,7 @@
Cloud Object Storage
====================
-SkyPilot tasks can access data from buckets in cloud object storages such as AWS S3, Google Cloud Storage (GCS), Cloudflare R2 or IBM COS.
+SkyPilot tasks can access data from buckets in cloud object storages such as AWS S3, Google Cloud Storage (GCS), Cloudflare R2, OCI Object Storage or IBM COS.
Buckets are made available to each task at a local path on the remote VM, so
the task can access bucket objects as if they were local files.
@@ -28,7 +28,7 @@ Object storages are specified using the :code:`file_mounts` field in a SkyPilot
# Mount an existing S3 bucket
file_mounts:
/my_data:
- source: s3://my-bucket/ # or gs://, https://.blob.core.windows.net/, r2://, cos:///
+ source: s3://my-bucket/ # or gs://, https://.blob.core.windows.net/, r2://, cos:///, oci://
mode: MOUNT # Optional: either MOUNT or COPY. Defaults to MOUNT.
This will `mount `__ the contents of the bucket at ``s3://my-bucket/`` to the remote VM at ``/my_data``.
@@ -45,7 +45,7 @@ Object storages are specified using the :code:`file_mounts` field in a SkyPilot
file_mounts:
/my_data:
name: my-sky-bucket
- store: gcs # Optional: either of s3, gcs, azure, r2, ibm
+ store: gcs # Optional: either of s3, gcs, azure, r2, ibm, oci
SkyPilot will create an empty GCS bucket called ``my-sky-bucket`` and mount it at ``/my_data``.
This bucket can be used to write checkpoints, logs or other outputs directly to the cloud.
@@ -68,7 +68,7 @@ Object storages are specified using the :code:`file_mounts` field in a SkyPilot
/my_data:
name: my-sky-bucket
source: ~/dataset # Optional: path to local data to upload to the bucket
- store: s3 # Optional: either of s3, gcs, azure, r2, ibm
+ store: s3 # Optional: either of s3, gcs, azure, r2, ibm, oci
mode: MOUNT # Optional: either MOUNT or COPY. Defaults to MOUNT.
SkyPilot will create a S3 bucket called ``my-sky-bucket`` and upload the
@@ -290,12 +290,13 @@ Storage YAML reference
- https://.blob.core.windows.net/
- r2://
- cos:///
+ - oci://
If the source is local, data is uploaded to the cloud to an appropriate
- bucket (s3, gcs, azure, r2, or ibm). If source is bucket URI,
+ bucket (s3, gcs, azure, r2, oci, or ibm). If source is bucket URI,
the data is copied or mounted directly (see mode flag below).
- store: str; either of 's3', 'gcs', 'azure', 'r2', 'ibm'
+ store: str; either of 's3', 'gcs', 'azure', 'r2', 'ibm', 'oci'
If you wish to force sky.Storage to be backed by a specific cloud object
storage, you can specify it here. If not specified, SkyPilot chooses the
appropriate object storage based on the source path and task's cloud provider.
diff --git a/docs/source/reference/yaml-spec.rst b/docs/source/reference/yaml-spec.rst
index 455ee5909c9..d2f0506993a 100644
--- a/docs/source/reference/yaml-spec.rst
+++ b/docs/source/reference/yaml-spec.rst
@@ -23,7 +23,7 @@ Available fields:
# which `sky` is called.
#
# To exclude files from syncing, see
- # https://skypilot.readthedocs.io/en/latest/examples/syncing-code-artifacts.html#exclude-uploading-files
+ # https://docs.skypilot.co/en/latest/examples/syncing-code-artifacts.html#exclude-uploading-files
workdir: ~/my-task-code
# Number of nodes (optional; defaults to 1) to launch including the head node.
@@ -176,9 +176,9 @@ Available fields:
# tpu_vm: True # True to use TPU VM (the default); False to use TPU node.
# Custom image id (optional, advanced). The image id used to boot the
- # instances. Only supported for AWS and GCP (for non-docker image). If not
- # specified, SkyPilot will use the default debian-based image suitable for
- # machine learning tasks.
+ # instances. Only supported for AWS, GCP, OCI and IBM (for non-docker image).
+ # If not specified, SkyPilot will use the default debian-based image
+ # suitable for machine learning tasks.
#
# Docker support
# You can specify docker image to use by setting the image_id to
@@ -204,7 +204,7 @@ Available fields:
# image_id:
# us-east-1: ami-0729d913a335efca7
# us-west-2: ami-050814f384259894c
- image_id: ami-0868a20f5a3bf9702
+ #
# GCP
# To find GCP images: https://cloud.google.com/compute/docs/images
# image_id: projects/deeplearning-platform-release/global/images/common-cpu-v20230615-debian-11-py310
@@ -215,6 +215,24 @@ Available fields:
# To find Azure images: https://docs.microsoft.com/en-us/azure/virtual-machines/linux/cli-ps-findimage
# image_id: microsoft-dsvm:ubuntu-2004:2004:21.11.04
#
+ # OCI
+ # To find OCI images: https://docs.oracle.com/en-us/iaas/images
+ # You can choose the image with OS version from the following image tags
+ # provided by SkyPilot:
+ # image_id: skypilot:gpu-ubuntu-2204
+ # image_id: skypilot:gpu-ubuntu-2004
+ # image_id: skypilot:gpu-oraclelinux9
+ # image_id: skypilot:gpu-oraclelinux8
+ # image_id: skypilot:cpu-ubuntu-2204
+ # image_id: skypilot:cpu-ubuntu-2004
+ # image_id: skypilot:cpu-oraclelinux9
+ # image_id: skypilot:cpu-oraclelinux8
+ #
+ # It is also possible to specify your custom image's OCID with OS type,
+ # for example:
+ # image_id: ocid1.image.oc1.us-sanjose-1.aaaaaaaaywwfvy67wwe7f24juvjwhyjn3u7g7s3wzkhduxcbewzaeki2nt5q:oraclelinux
+ # image_id: ocid1.image.oc1.us-sanjose-1.aaaaaaaa5tnuiqevhoyfnaa5pqeiwjv6w5vf6w4q2hpj3atyvu3yd6rhlhyq:ubuntu
+ #
# IBM
# Create a private VPC image and paste its ID in the following format:
# image_id:
@@ -224,6 +242,7 @@ Available fields:
# https://www.ibm.com/cloud/blog/use-ibm-packer-plugin-to-create-custom-images-on-ibm-cloud-vpc-infrastructure
# To use a more limited but easier to manage tool:
# https://github.com/IBM/vpc-img-inst
+ image_id: ami-0868a20f5a3bf9702
# Labels to apply to the instances (optional).
#
@@ -307,7 +326,7 @@ Available fields:
/datasets-storage:
name: sky-dataset # Name of storage, optional when source is bucket URI
source: /local/path/datasets # Source path, can be local or bucket URI. Optional, do not specify to create an empty bucket.
- store: s3 # Could be either 's3', 'gcs', 'azure', 'r2', or 'ibm'; default: None. Optional.
+ store: s3 # Could be either 's3', 'gcs', 'azure', 'r2', 'oci', or 'ibm'; default: None. Optional.
persistent: True # Defaults to True; can be set to false to delete bucket after cluster is downed. Optional.
mode: MOUNT # Either MOUNT or COPY. Defaults to MOUNT. Optional.
@@ -357,7 +376,7 @@ In additional to the above fields, SkyPilot also supports the following experime
#
# The following fields can be overridden. Please refer to docs of Advanced
# Configuration for more details of those fields:
- # https://skypilot.readthedocs.io/en/latest/reference/config.html
+ # https://docs.skypilot.co/en/latest/reference/config.html
config_overrides:
docker:
run_options: ...
diff --git a/docs/source/reservations/existing-machines.rst b/docs/source/reservations/existing-machines.rst
index 10962ecd639..717043bfd25 100644
--- a/docs/source/reservations/existing-machines.rst
+++ b/docs/source/reservations/existing-machines.rst
@@ -42,7 +42,7 @@ Prerequisites
**Local machine (typically your laptop):**
* `kubectl `_
-* `SkyPilot `_
+* `SkyPilot `_
**Remote machines (your cluster, optionally with GPUs):**
diff --git a/docs/source/running-jobs/distributed-jobs.rst b/docs/source/running-jobs/distributed-jobs.rst
index f6c8cba9c9d..7c3421aa276 100644
--- a/docs/source/running-jobs/distributed-jobs.rst
+++ b/docs/source/running-jobs/distributed-jobs.rst
@@ -6,39 +6,40 @@ Distributed Multi-Node Jobs
SkyPilot supports multi-node cluster
provisioning and distributed execution on many nodes.
-For example, here is a simple PyTorch Distributed training example:
+For example, the following trains a GPT-like model (inspired by Karpathy's `minGPT `_) across 2 nodes with Distributed Data Parallel (DDP) in PyTorch.
.. code-block:: yaml
- :emphasize-lines: 6-6,21-21,23-26
+ :emphasize-lines: 6,19,23-24,26
- name: resnet-distributed-app
+ name: minGPT-ddp
- resources:
- accelerators: A100:8
+ resources:
+ accelerators: A100:8
- num_nodes: 2
+ num_nodes: 2
- setup: |
- pip3 install --upgrade pip
- git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet
- cd pytorch-distributed-resnet
- # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
- pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
- mkdir -p data && mkdir -p saved_models && cd data && \
- wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
- tar -xvzf cifar-10-python.tar.gz
+ setup: |
+ git clone --depth 1 https://github.com/pytorch/examples || true
+ cd examples
+ git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp
+ # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
+ uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113
- run: |
- cd pytorch-distributed-resnet
+ run: |
+ cd examples/mingpt
+ export LOGLEVEL=INFO
+
+ MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ echo "Starting distributed training, head node: $MASTER_ADDR"
- MASTER_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
- torchrun \
+ torchrun \
--nnodes=$SKYPILOT_NUM_NODES \
- --master_addr=$MASTER_ADDR \
--nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
- --node_rank=$SKYPILOT_NODE_RANK \
- --master_port=12375 \
- resnet_ddp.py --num_epochs 20
+ --master_addr=$MASTER_ADDR \
+ --node_rank=${SKYPILOT_NODE_RANK} \
+ --master_port=8008 \
+ main.py
+
In the above,
@@ -55,6 +56,7 @@ In the above,
ulimit -n 65535
+You can find more `distributed training examples `_ (including `using the rdzv backend for PyTorch `_) in our `GitHub repository `_.
Environment variables
-----------------------------------------
diff --git a/docs/source/serving/sky-serve.rst b/docs/source/serving/sky-serve.rst
index c00fa427bd6..693102c0550 100644
--- a/docs/source/serving/sky-serve.rst
+++ b/docs/source/serving/sky-serve.rst
@@ -242,6 +242,9 @@ Under the hood, :code:`sky serve up`:
#. Meanwhile, the controller provisions replica VMs which later run the services;
#. Once any replica is ready, the requests sent to the Service Endpoint will be distributed to one of the endpoint replicas.
+.. note::
+  SkyServe uses least-load load balancing to distribute traffic across the replicas. It keeps track of the number of requests each replica has handled and routes the next request to the replica with the least load.
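+
+  As a minimal illustrative sketch of this policy (not SkyServe's actual
+  implementation), the next replica could be picked as:
+
+  .. code-block:: python
+
+     # in_flight maps each replica to the number of requests it is handling.
+     def pick_replica(in_flight: dict) -> str:
+         # Route to the replica with the least load.
+         return min(in_flight, key=in_flight.get)
+
+     pick_replica({'replica-1': 3, 'replica-2': 1, 'replica-3': 2})  # 'replica-2'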
+
After the controller is provisioned, you'll see the following in :code:`sky serve status` output:
.. image:: ../images/sky-serve-status-output-provisioning.png
@@ -515,7 +518,7 @@ To achieve the above, you can specify custom configs in :code:`~/.sky/config.yam
# Specify the disk_size in GB of the SkyServe controller.
disk_size: 1024
-The :code:`resources` field has the same spec as a normal SkyPilot job; see `here `__.
+The :code:`resources` field has the same spec as a normal SkyPilot job; see `here `__.
.. note::
These settings will not take effect if you have an existing controller (either
diff --git a/examples/airflow/shared_state/README.md b/examples/airflow/shared_state/README.md
index 5f39471351a..917a45862a7 100644
--- a/examples/airflow/shared_state/README.md
+++ b/examples/airflow/shared_state/README.md
@@ -12,7 +12,7 @@ In this guide, we demonstrate how some simple SkyPilot operations, such as launc
* Airflow installed on a [Kubernetes cluster](https://airflow.apache.org/docs/helm-chart/stable/index.html) or [locally](https://airflow.apache.org/docs/apache-airflow/stable/start.html) (`SequentialExecutor`)
* A Kubernetes cluster to run tasks on. We'll use GKE in this example.
- * You can use our guide on [setting up a Kubernetes cluster](https://skypilot.readthedocs.io/en/latest/reference/kubernetes/kubernetes-setup.html).
+ * You can use our guide on [setting up a Kubernetes cluster](https://docs.skypilot.co/en/latest/reference/kubernetes/kubernetes-setup.html).
* A persistent volume storage class should be available that supports at least `ReadWriteOnce` access mode. GKE has this supported by default.
## Preparing the Kubernetes Cluster
@@ -39,7 +39,7 @@ In this guide, we demonstrate how some simple SkyPilot operations, such as launc
name: sky-airflow-sa
namespace: default
roleRef:
- # For minimal permissions, refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html
+ # For minimal permissions, refer to https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/kubernetes.html
kind: ClusterRole
name: cluster-admin
apiGroup: rbac.authorization.k8s.io
@@ -163,7 +163,7 @@ with DAG(dag_id='sky_k8s_example',
## Tips
1. **Persistent Volume**: If you have many concurrent tasks, you may want to use a storage class that supports [`ReadWriteMany`](https://kubernetes.io/docs/concepts/storage/persistent-volumes/#access-modes) access mode.
-2. **Cloud credentials**: If you wish to run tasks on different clouds, you can configure cloud credentials in Kubernetes secrets and mount them in the Sky pod defined in the DAG. See [SkyPilot docs on setting up cloud credentials](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#cloud-account-setup) for more on how to configure credentials in the pod.
+2. **Cloud credentials**: If you wish to run tasks on different clouds, you can configure cloud credentials in Kubernetes secrets and mount them in the Sky pod defined in the DAG. See [SkyPilot docs on setting up cloud credentials](https://docs.skypilot.co/en/latest/getting-started/installation.html#cloud-account-setup) for more on how to configure credentials in the pod.
3. **Logging**: All SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI. You can also write logs to a file and read them in subsequent tasks.
4. **XComs for shared state**: Airflow also provides [XComs](https://airflow.apache.org/docs/apache-airflow/stable/concepts/xcoms.html) for cross-task communication. [`sky_k8s_example_xcoms.py`](sky_k8s_example_xcoms.py) demonstrates how to use XComs to share state between tasks.
diff --git a/examples/airflow/training_workflow/README.md b/examples/airflow/training_workflow/README.md
index dad08d8d3b0..71cb10bef50 100644
--- a/examples/airflow/training_workflow/README.md
+++ b/examples/airflow/training_workflow/README.md
@@ -7,7 +7,7 @@ In this guide, we show how a training workflow involving data preprocessing, tra
-**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#job-pipelines) for more information.
+**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://docs.skypilot.co/en/latest/examples/managed-jobs.html#job-pipelines) for more information.
## Why use SkyPilot with Airflow?
In AI workflows, **the transition from development to production is hard**.
@@ -24,7 +24,7 @@ production Airflow cluster. Behind the scenes, SkyPilot handles environment setu
Here's how you can use SkyPilot to take your dev workflows to production in Airflow:
1. **Define and test your workflow as SkyPilot tasks**.
- - Use `sky launch` and [Sky VSCode integration](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html#dev-vscode) to run, debug and iterate on your code.
+ - Use `sky launch` and [Sky VSCode integration](https://docs.skypilot.co/en/latest/examples/interactive-development.html#dev-vscode) to run, debug and iterate on your code.
2. **Orchestrate SkyPilot tasks in Airflow** by invoking `sky launch` on their YAMLs as a task in the Airflow DAG.
- Airflow does the scheduling, logging, and monitoring, while SkyPilot handles the infra setup and task execution.
@@ -34,7 +34,7 @@ Here's how you can use SkyPilot to take your dev workflows to production in Airf
* Airflow installed on a [Kubernetes cluster](https://airflow.apache.org/docs/helm-chart/stable/index.html) or [locally](https://airflow.apache.org/docs/apache-airflow/stable/start.html) (`SequentialExecutor`)
* A Kubernetes cluster to run tasks on. We'll use GKE in this example.
* A Google cloud account with GCS access to store the data for task.
- * Follow [SkyPilot instructions](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#google-cloud-platform-gcp) to set up Google Cloud credentials.
+ * Follow [SkyPilot instructions](https://docs.skypilot.co/en/latest/getting-started/installation.html#google-cloud-platform-gcp) to set up Google Cloud credentials.
## Preparing the Kubernetes Cluster
@@ -60,7 +60,7 @@ Here's how you can use SkyPilot to take your dev workflows to production in Airf
name: sky-airflow-sa
namespace: default
roleRef:
- # For minimal permissions, refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html
+ # For minimal permissions, refer to https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/kubernetes.html
kind: ClusterRole
name: cluster-admin
apiGroup: rbac.authorization.k8s.io
@@ -103,7 +103,7 @@ The train and eval step can be run in a similar way:
sky launch -c train --env DATA_BUCKET_URL=gs:// train.yaml
```
-Hint: You can use `ssh` and VSCode to [interactively develop](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html) and debug the tasks.
+Hint: You can use `ssh` and VSCode to [interactively develop](https://docs.skypilot.co/en/latest/examples/interactive-development.html) and debug the tasks.
Note: `eval` can be optionally run on the same cluster as `train` with `sky exec`. Refer to the `shared_state` airflow example on how to do this.
diff --git a/examples/cog/README.md b/examples/cog/README.md
index b2193e2e18f..97d886e2d2c 100644
--- a/examples/cog/README.md
+++ b/examples/cog/README.md
@@ -17,7 +17,7 @@ curl http://$IP:5000/predictions -X POST \
```
## Scale up the deployment using SkyServe
-We can use SkyServe (`sky serve`) to scale up the deployment to multiple instances, while enjoying load balancing, autoscaling, and other [SkyServe features](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html).
+We can use SkyServe (`sky serve`) to scale up the deployment to multiple instances, while enjoying load balancing, autoscaling, and other [SkyServe features](https://docs.skypilot.co/en/latest/serving/sky-serve.html).
```console
sky serve up -n cog ./sky.yaml
```
diff --git a/examples/distributed-pytorch/README.md b/examples/distributed-pytorch/README.md
new file mode 100644
index 00000000000..6c2f7092269
--- /dev/null
+++ b/examples/distributed-pytorch/README.md
@@ -0,0 +1,81 @@
+# Distributed Training with PyTorch
+
+This example demonstrates how to run distributed training with PyTorch using SkyPilot.
+
+**The example is based on [PyTorch's official minGPT example](https://github.com/pytorch/examples/tree/main/distributed/minGPT-ddp)**
+
+
+## Overview
+
+There are two ways to run distributed training with PyTorch:
+
+1. Using normal `torchrun`
+2. Using the `rdzv` backend
+
+The main difference between the two for fixed-size distributed training is that the `rdzv` backend automatically handles the rank for each node, while `torchrun` requires the rank to be set manually.
+
+SkyPilot offers convenient built-in environment variables to help you start distributed training easily.
+
+### Using normal `torchrun`
+
+
+The following command will spawn 2 nodes, each with an L4 GPU:
+```
+sky launch -c train train.yaml
+```
+
+In [train.yaml](./train.yaml), we use `torchrun` to launch the training and set the arguments for distributed training using [environment variables](https://docs.skypilot.co/en/latest/running-jobs/environment-variables.html#skypilot-environment-variables) provided by SkyPilot.
+
+```yaml
+run: |
+ cd examples/mingpt
+ MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ torchrun \
+ --nnodes=$SKYPILOT_NUM_NODES \
+ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
+ --master_addr=$MASTER_ADDR \
+ --master_port=8008 \
+ --node_rank=${SKYPILOT_NODE_RANK} \
+ main.py
+```
+
+
+
+### Using `rdzv` backend
+
+`rdzv` is an alternative backend for distributed training:
+
+```
+sky launch -c train-rdzv train-rdzv.yaml
+```
+
+In [train-rdzv.yaml](./train-rdzv.yaml), we use `torchrun` to launch the training and set the arguments for distributed training using [environment variables](https://docs.skypilot.co/en/latest/running-jobs/environment-variables.html#skypilot-environment-variables) provided by SkyPilot.
+
+```yaml
+run: |
+ cd examples/mingpt
+ MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ echo "Starting distributed training, head node: $MASTER_ADDR"
+
+ torchrun \
+ --nnodes=$SKYPILOT_NUM_NODES \
+ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
+ --rdzv_backend=c10d \
+ --rdzv_endpoint=$MASTER_ADDR:29500 \
+ --rdzv_id $SKYPILOT_TASK_ID \
+ main.py
+```
+
+
+## Scale up
+
+If you would like to scale up the training, you can simply change the resource requirements, and SkyPilot's built-in environment variables will be set automatically.
+
+For example, the following command will spawn 4 nodes with 4 L4 GPUs each.
+
+```
+sky launch -c train train.yaml --num-nodes 4 --gpus L4:4 --cpus 8+
+```
+
+We also increase `--cpus` to 8+ so that performance is not bottlenecked by the CPU.
+
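+Equivalently, you can bake the larger shape into the task YAML itself. Below is a minimal sketch (not a full file) of how the `resources` and `num_nodes` fields in [train.yaml](./train.yaml) could look after scaling up to match the command above:
+
+```yaml
+resources:
+  cpus: 8+             # more CPUs so data loading does not bottleneck the GPUs
+  accelerators: L4:4   # 4 L4 GPUs per node
+
+num_nodes: 4           # 4 nodes in total
+```
+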
diff --git a/examples/distributed-pytorch/train-rdzv.yaml b/examples/distributed-pytorch/train-rdzv.yaml
new file mode 100644
index 00000000000..3bcd63dde4c
--- /dev/null
+++ b/examples/distributed-pytorch/train-rdzv.yaml
@@ -0,0 +1,29 @@
+name: minGPT-ddp-rdzv
+
+resources:
+ cpus: 4+
+ accelerators: L4
+
+num_nodes: 2
+
+setup: |
+ git clone --depth 1 https://github.com/pytorch/examples || true
+ cd examples
+ git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp
+ # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
+ uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113
+
+run: |
+ cd examples/mingpt
+ export LOGLEVEL=INFO
+
+ MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ echo "Starting distributed training, head node: $MASTER_ADDR"
+
+ torchrun \
+ --nnodes=$SKYPILOT_NUM_NODES \
+ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
+ --rdzv_backend=c10d \
+ --rdzv_endpoint=$MASTER_ADDR:29500 \
+ --rdzv_id $SKYPILOT_TASK_ID \
+ main.py
diff --git a/examples/distributed-pytorch/train.yaml b/examples/distributed-pytorch/train.yaml
new file mode 100644
index 00000000000..b45941e1485
--- /dev/null
+++ b/examples/distributed-pytorch/train.yaml
@@ -0,0 +1,29 @@
+name: minGPT-ddp
+
+resources:
+ cpus: 4+
+ accelerators: L4
+
+num_nodes: 2
+
+setup: |
+ git clone --depth 1 https://github.com/pytorch/examples || true
+ cd examples
+ git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp
+ # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
+ uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113
+
+run: |
+ cd examples/mingpt
+ export LOGLEVEL=INFO
+
+ MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ echo "Starting distributed training, head node: $MASTER_ADDR"
+
+ torchrun \
+ --nnodes=$SKYPILOT_NUM_NODES \
+ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
+ --master_addr=$MASTER_ADDR \
+ --master_port=8008 \
+ --node_rank=${SKYPILOT_NODE_RANK} \
+ main.py
diff --git a/examples/k8s_cloud_deploy/README.md b/examples/k8s_cloud_deploy/README.md
index 5ba42cbe836..9b0d46249d4 100644
--- a/examples/k8s_cloud_deploy/README.md
+++ b/examples/k8s_cloud_deploy/README.md
@@ -56,11 +56,11 @@ NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS
## Run AI workloads on your Kubernetes cluster with SkyPilot
### Development clusters
-To launch a [GPU enabled development cluster](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html), run `sky launch -c mycluster --cloud kubernetes --gpus A10:1`.
+To launch a [GPU enabled development cluster](https://docs.skypilot.co/en/latest/examples/interactive-development.html), run `sky launch -c mycluster --cloud kubernetes --gpus A10:1`.
SkyPilot will setup SSH config for you.
-* [SSH access](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html#ssh): `ssh mycluster`
-* [VSCode remote development](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html#vscode): `code --remote ssh-remote+mycluster "/"`
+* [SSH access](https://docs.skypilot.co/en/latest/examples/interactive-development.html#ssh): `ssh mycluster`
+* [VSCode remote development](https://docs.skypilot.co/en/latest/examples/interactive-development.html#vscode): `code --remote ssh-remote+mycluster "/"`
### Jobs
@@ -87,7 +87,7 @@ sky-cmd-1-2ea4-head 1/1 Running 0 8m36s
sky-jobs-controller-2ea485ea-2ea4-head 1/1 Running 0 10m
```
-Refer to [SkyPilot docs](https://skypilot.readthedocs.io/) for more.
+Refer to [SkyPilot docs](https://docs.skypilot.co/) for more.
## Teardown
To teardown the Kubernetes cluster, run:
diff --git a/examples/local_docker/docker_in_docker.yaml b/examples/local_docker/docker_in_docker.yaml
deleted file mode 100644
index bdb6ed70ecf..00000000000
--- a/examples/local_docker/docker_in_docker.yaml
+++ /dev/null
@@ -1,19 +0,0 @@
-# Runs a docker container as a SkyPilot task.
-#
-# This demo can be run using the --docker flag, demonstrating the
-# docker-in-docker (dind) capabilities of SkyPilot docker mode.
-#
-# Usage:
-# sky launch --docker -c dind docker_in_docker.yaml
-# sky down dind
-
-name: dind
-
-resources:
- cloud: aws
-
-setup: |
- echo "No setup required!"
-
-run: |
- docker run --rm hello-world
diff --git a/examples/local_docker/ping.py b/examples/local_docker/ping.py
deleted file mode 100644
index c3a90c62243..00000000000
--- a/examples/local_docker/ping.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""An example app which pings localhost.
-
-This script is designed to demonstrate the use of different backends with
-SkyPilot. It is useful to support a LocalDockerBackend that users can use to
-debug their programs even before they run them on the Sky.
-"""
-
-import sky
-
-# Set backend here. It can be either LocalDockerBackend or CloudVmRayBackend.
-backend = sky.backends.LocalDockerBackend(
-) # or sky.backends.CloudVmRayBackend()
-
-with sky.Dag() as dag:
- resources = sky.Resources(accelerators={'K80': 1})
- setup_commands = 'apt-get update && apt-get install -y iputils-ping'
- task = sky.Task(run='ping 127.0.0.1 -c 100',
- docker_image='ubuntu',
- setup=setup_commands,
- name='ping').set_resources(resources)
-
-sky.launch(dag, backend=backend)
diff --git a/examples/local_docker/ping.yaml b/examples/local_docker/ping.yaml
deleted file mode 100644
index 0d0efd12419..00000000000
--- a/examples/local_docker/ping.yaml
+++ /dev/null
@@ -1,19 +0,0 @@
-# A minimal ping example.
-#
-# Runs a task that pings localhost 100 times.
-#
-# Usage:
-# sky launch --docker -c ping ping.yaml
-# sky down ping
-
-name: ping
-
-resources:
- cloud: aws
-
-setup: |
- sudo apt-get update --allow-insecure-repositories
- sudo apt-get install -y iputils-ping
-
-run: |
- ping 127.0.0.1 -c 100
diff --git a/examples/oci/dataset-mount.yaml b/examples/oci/dataset-mount.yaml
new file mode 100644
index 00000000000..91bec9cda65
--- /dev/null
+++ b/examples/oci/dataset-mount.yaml
@@ -0,0 +1,35 @@
+name: cpu-task1
+
+resources:
+ cloud: oci
+ region: us-sanjose-1
+ cpus: 2
+ disk_size: 256
+ disk_tier: medium
+ use_spot: False
+
+file_mounts:
+ # Mount an existing oci bucket
+ /datasets-storage:
+ source: oci://skybucket
+ mode: MOUNT # Either MOUNT or COPY. Optional.
+
+# Working directory (optional) containing the project codebase.
+# Its contents are synced to ~/sky_workdir/ on the cluster.
+workdir: .
+
+num_nodes: 1
+
+# Typical use: pip install -r requirements.txt
+# Invoked under the workdir (i.e., can use its files).
+setup: |
+ echo "*** Running setup for the task. ***"
+
+# Typical use: make use of resources, such as running training.
+# Invoked under the workdir (i.e., can use its files).
+run: |
+ echo "*** Running the task on OCI ***"
+ timestamp=$(date +%s)
+ ls -lthr /datasets-storage
+ echo "hi" >> /datasets-storage/foo.txt
+ ls -lthr /datasets-storage
diff --git a/examples/oci/dataset-upload-and-mount.yaml b/examples/oci/dataset-upload-and-mount.yaml
new file mode 100644
index 00000000000..13ddc4d2b35
--- /dev/null
+++ b/examples/oci/dataset-upload-and-mount.yaml
@@ -0,0 +1,47 @@
+name: cpu-task1
+
+resources:
+ cloud: oci
+ region: us-sanjose-1
+ cpus: 2
+ disk_size: 256
+ disk_tier: medium
+ use_spot: False
+
+file_mounts:
+ /datasets-storage:
+ name: skybucket # Name of storage, optional when source is bucket URI
+ source: ['./examples/oci'] # Source path, can be local or bucket URL. Optional, do not specify to create an empty bucket.
+ store: oci # E.g 'oci', 's3', 'gcs'...; default: None. Optional.
+ persistent: True # Defaults to True; can be set to false. Optional.
+ mode: MOUNT # Either MOUNT or COPY. Optional.
+
+ /datasets-storage2:
+ name: skybucket2 # Name of storage, optional when source is bucket URI
+ source: './examples/oci' # Source path, can be local or bucket URL. Optional, do not specify to create an empty bucket.
+ store: oci # E.g 'oci', 's3', 'gcs'...; default: None. Optional.
+ persistent: True # Defaults to True; can be set to false. Optional.
+ mode: MOUNT # Either MOUNT or COPY. Optional.
+
+# Working directory (optional) containing the project codebase.
+# Its contents are synced to ~/sky_workdir/ on the cluster.
+workdir: .
+
+num_nodes: 1
+
+# Typical use: pip install -r requirements.txt
+# Invoked under the workdir (i.e., can use its files).
+setup: |
+ echo "*** Running setup for the task. ***"
+
+# Typical use: make use of resources, such as running training.
+# Invoked under the workdir (i.e., can use its files).
+run: |
+ echo "*** Running the task on OCI ***"
+ ls -lthr /datasets-storage
+ echo "hi" >> /datasets-storage/foo.txt
+ ls -lthr /datasets-storage
+
+ ls -lthr /datasets-storage2
+ echo "hi" >> /datasets-storage2/foo2.txt
+ ls -lthr /datasets-storage2
diff --git a/examples/oci/gpu-oraclelinux9.yaml b/examples/oci/gpu-oraclelinux9.yaml
new file mode 100644
index 00000000000..cc7b05ea0fc
--- /dev/null
+++ b/examples/oci/gpu-oraclelinux9.yaml
@@ -0,0 +1,33 @@
+name: gpu-task
+
+resources:
+ # Optional; if left out, automatically pick the cheapest cloud.
+ cloud: oci
+
+ accelerators: A10:1
+
+ disk_size: 1024
+
+ disk_tier: high
+
+ image_id: skypilot:gpu-oraclelinux9
+
+
+# Working directory (optional) containing the project codebase.
+# Its contents are synced to ~/sky_workdir/ on the cluster.
+workdir: .
+
+num_nodes: 1
+
+# Typical use: pip install -r requirements.txt
+# Invoked under the workdir (i.e., can use its files).
+setup: |
+ echo "*** Running setup. ***"
+
+# Typical use: make use of resources, such as running training.
+# Invoked under the workdir (i.e., can use its files).
+run: |
+ echo "*** Running the task on OCI ***"
+ echo "hello, world"
+ nvidia-smi
+ echo "The task is completed."
diff --git a/examples/oci/gpu-ubuntu-2204.yaml b/examples/oci/gpu-ubuntu-2204.yaml
new file mode 100644
index 00000000000..e0012a31a1a
--- /dev/null
+++ b/examples/oci/gpu-ubuntu-2204.yaml
@@ -0,0 +1,33 @@
+name: gpu-task
+
+resources:
+ # Optional; if left out, automatically pick the cheapest cloud.
+ cloud: oci
+
+ accelerators: A10:1
+
+ disk_size: 1024
+
+ disk_tier: high
+
+ image_id: skypilot:gpu-ubuntu-2204
+
+
+# Working directory (optional) containing the project codebase.
+# Its contents are synced to ~/sky_workdir/ on the cluster.
+workdir: .
+
+num_nodes: 1
+
+# Typical use: pip install -r requirements.txt
+# Invoked under the workdir (i.e., can use its files).
+setup: |
+ echo "*** Running setup. ***"
+
+# Typical use: make use of resources, such as running training.
+# Invoked under the workdir (i.e., can use its files).
+run: |
+ echo "*** Running the task on OCI ***"
+ echo "hello, world"
+ nvidia-smi
+ echo "The task is completed."
diff --git a/examples/oci/oci-mounts.yaml b/examples/oci/oci-mounts.yaml
new file mode 100644
index 00000000000..6fd2aaf16eb
--- /dev/null
+++ b/examples/oci/oci-mounts.yaml
@@ -0,0 +1,26 @@
+resources:
+ cloud: oci
+
+file_mounts:
+ ~/tmpfile: ~/tmpfile
+ ~/a/b/c/tmpfile: ~/tmpfile
+ /tmp/workdir: ~/tmp-workdir
+
+ /mydir:
+ name: skybucket
+ source: ['~/tmp-workdir']
+ store: oci
+ mode: MOUNT
+
+setup: |
+ echo "*** Setup ***"
+
+run: |
+ echo "*** Run ***"
+
+ ls -lthr ~/tmpfile
+ ls -lthr ~/a/b/c
+ echo hi >> /tmp/workdir/new_file
+ ls -lthr /tmp/workdir
+
+ ls -lthr /mydir
diff --git a/examples/serve/minimal.yaml b/examples/serve/minimal.yaml
new file mode 100644
index 00000000000..c925d26f5d1
--- /dev/null
+++ b/examples/serve/minimal.yaml
@@ -0,0 +1,11 @@
+# A minimal example of a serve application.
+
+service:
+ readiness_probe: /
+ replicas: 1
+
+resources:
+ ports: 8080
+ cpus: 2+
+
+run: python3 -m http.server 8080
diff --git a/examples/spot/lightning_cifar10/train.py b/examples/spot/lightning_cifar10/train.py
index 0df6f18484b..14901e635ef 100644
--- a/examples/spot/lightning_cifar10/train.py
+++ b/examples/spot/lightning_cifar10/train.py
@@ -163,7 +163,7 @@ def main():
)
model_ckpts = glob.glob(argv.root_dir + "/*.ckpt")
- if argv.resume and len(model_ckpts) > 0:
+ if argv.resume and model_ckpts:
latest_ckpt = max(model_ckpts, key=os.path.getctime)
trainer.fit(model, cifar10_dm, ckpt_path=latest_ckpt)
else:
diff --git a/examples/stable_diffusion/README.md b/examples/stable_diffusion/README.md
index 2a4383f1347..56af44df91e 100644
--- a/examples/stable_diffusion/README.md
+++ b/examples/stable_diffusion/README.md
@@ -1,6 +1,6 @@
## Setup
-1. Install skypilot package by following these [instructions](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
+1. Install skypilot package by following these [instructions](https://docs.skypilot.co/en/latest/getting-started/installation.html).
2. Run `git clone https://github.com/skypilot-org/skypilot.git && cd examples/stable_diffusion`
diff --git a/examples/stable_diffusion/pushing_docker_image.md b/examples/stable_diffusion/pushing_docker_image.md
index 80b285fa832..0585d566543 100644
--- a/examples/stable_diffusion/pushing_docker_image.md
+++ b/examples/stable_diffusion/pushing_docker_image.md
@@ -1,6 +1,6 @@
## GCR
-1. Install skypilot package by following these [instructions](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
+1. Install skypilot package by following these [instructions](https://docs.skypilot.co/en/latest/getting-started/installation.html).
2. Run `git clone https://github.com/skypilot-org/skypilot.git `.
diff --git a/llm/codellama/README.md b/llm/codellama/README.md
index f145fd062ff..54019bd6d2a 100644
--- a/llm/codellama/README.md
+++ b/llm/codellama/README.md
@@ -38,7 +38,7 @@ The followings are the demos of Code Llama 70B hosted by SkyPilot Serve (aka Sky
## Running your own Code Llama with SkyPilot
-After [installing SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), run your own Code Llama on vLLM with SkyPilot in 1-click:
+After [installing SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html), run your own Code Llama on vLLM with SkyPilot in 1-click:
1. Start serving Code Llama 70B on a single instance with any available GPU in the list specified in [endpoint.yaml](https://github.com/skypilot-org/skypilot/tree/master/llm/codellama/endpoint.yaml) with a vLLM powered OpenAI-compatible endpoint:
```console
@@ -100,7 +100,7 @@ This returns the following completion:
## Scale up the service with SkyServe
-1. With [SkyServe](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html), a serving library built on top of SkyPilot, scaling up the Code Llama service is as simple as running:
+1. With [SkyServe](https://docs.skypilot.co/en/latest/serving/sky-serve.html), a serving library built on top of SkyPilot, scaling up the Code Llama service is as simple as running:
```bash
sky serve up -n code-llama ./endpoint.yaml
```
diff --git a/llm/dbrx/README.md b/llm/dbrx/README.md
index 3011af9d4e6..2845634b287 100644
--- a/llm/dbrx/README.md
+++ b/llm/dbrx/README.md
@@ -11,7 +11,7 @@ In this recipe, you will serve `databricks/dbrx-instruct` on your own infra --
## Prerequisites
- Go to the [HuggingFace model page](https://huggingface.co/databricks/dbrx-instruct) and request access to the model `databricks/dbrx-instruct`.
-- Check that you have installed SkyPilot ([docs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)).
+- Check that you have installed SkyPilot ([docs](https://docs.skypilot.co/en/latest/getting-started/installation.html)).
- Check that `sky check` shows clouds or Kubernetes are enabled.
## SkyPilot YAML
@@ -278,6 +278,6 @@ To shut down all resources:
sky serve down dbrx
```
-See more details in [SkyServe docs](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html).
+See more details in [SkyServe docs](https://docs.skypilot.co/en/latest/serving/sky-serve.html).
diff --git a/llm/gemma/README.md b/llm/gemma/README.md
index ef5027b2807..7296f7c7e31 100644
--- a/llm/gemma/README.md
+++ b/llm/gemma/README.md
@@ -24,7 +24,7 @@ Generate a read-only access token on huggingface [here](https://huggingface.co/s
```bash
pip install "skypilot-nightly[all]"
```
-For detailed installation instructions, please refer to the [installation guide](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
+For detailed installation instructions, please refer to the [installation guide](https://docs.skypilot.co/en/latest/getting-started/installation.html).
### Host on a Single Instance
diff --git a/llm/gpt-2/README.md b/llm/gpt-2/README.md
index 10fa2cf6998..b8e656e2353 100644
--- a/llm/gpt-2/README.md
+++ b/llm/gpt-2/README.md
@@ -13,7 +13,7 @@ pip install "skypilot-nightly[aws,gcp,azure,kubernetes,lambda,fluidstack]" # Cho
```bash
sky check
```
-Please check the instructions for enabling clouds at [SkyPilot doc](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
+Please check the instructions for enabling clouds at [SkyPilot doc](https://docs.skypilot.co/en/latest/getting-started/installation.html).
3. Download the YAML for starting the training:
```bash
diff --git a/llm/llama-3/README.md b/llm/llama-3/README.md
index 8ffcb3087a9..c4cf9066f63 100644
--- a/llm/llama-3/README.md
+++ b/llm/llama-3/README.md
@@ -29,7 +29,7 @@
## Prerequisites
- Go to the [HuggingFace model page](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) and request access to the model `meta-llama/Meta-Llama-3-70B-Instruct`.
-- Check that you have installed SkyPilot ([docs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)).
+- Check that you have installed SkyPilot ([docs](https://docs.skypilot.co/en/latest/getting-started/installation.html)).
- Check that `sky check` shows clouds or Kubernetes are enabled.
## SkyPilot YAML
@@ -326,7 +326,7 @@ To shut down all resources:
sky serve down llama3
```
-See more details in [SkyServe docs](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html).
+See more details in [SkyServe docs](https://docs.skypilot.co/en/latest/serving/sky-serve.html).
### **Optional**: Connect a GUI to your Llama-3 endpoint
@@ -349,4 +349,4 @@ sky launch -c llama3-gui ./gui.yaml --env ENDPOINT=$(sky serve status --endpoint
## Finetuning Llama-3
-You can finetune Llama-3 on your own data. We have an tutorial for finetunning Llama-2 for Vicuna on SkyPilot, which can be adapted for Llama-3. You can find the tutorial [here](https://skypilot.readthedocs.io/en/latest/gallery/tutorials/finetuning.html) and a detailed blog post [here](https://blog.skypilot.co/finetuning-llama2-operational-guide/).
+You can finetune Llama-3 on your own data. We have a tutorial for finetuning Llama-2 for Vicuna on SkyPilot, which can be adapted for Llama-3. You can find the tutorial [here](https://docs.skypilot.co/en/latest/gallery/tutorials/finetuning.html) and a detailed blog post [here](https://blog.skypilot.co/finetuning-llama2-operational-guide/).
diff --git a/llm/llama-3_1-finetuning/readme.md b/llm/llama-3_1-finetuning/readme.md
index 935dccde84e..ddc2b9e2463 100644
--- a/llm/llama-3_1-finetuning/readme.md
+++ b/llm/llama-3_1-finetuning/readme.md
@@ -7,10 +7,10 @@
On July 23, 2024, Meta released the [Llama 3.1 model family](https://ai.meta.com/blog/meta-llama-3-1/), including a 405B parameter model in both base model and instruction-tuned forms. Llama 3.1 405B became _the first open LLM that closely rivals top proprietary models_ like GPT-4o and Claude 3.5 Sonnet.
-This guide shows how to use [SkyPilot](https://github.com/skypilot-org/skypilot) and [torchtune](https://pytorch.org/torchtune/stable/index.html) to **finetune Llama 3.1 on your own data and infra**. Everything is packaged in a simple [SkyPilot YAML](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html), that can be launched with one command on your infra:
+This guide shows how to use [SkyPilot](https://github.com/skypilot-org/skypilot) and [torchtune](https://pytorch.org/torchtune/stable/index.html) to **finetune Llama 3.1 on your own data and infra**. Everything is packaged in a simple [SkyPilot YAML](https://docs.skypilot.co/en/latest/getting-started/quickstart.html), that can be launched with one command on your infra:
- Local GPU workstation
- Kubernetes cluster
-- Cloud accounts ([12 clouds supported](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html))
+- Cloud accounts ([12 clouds supported](https://docs.skypilot.co/en/latest/getting-started/installation.html))
@@ -233,7 +233,7 @@ export HF_TOKEN="xxxx"
```bash
pip install skypilot-nightly[aws,gcp,kubernetes]
# or other clouds (12 clouds + kubernetes supported) you have setup
-# See: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html
+# See: https://docs.skypilot.co/en/latest/getting-started/installation.html
```
5. Check your infra setup:
@@ -262,6 +262,6 @@ sky check
## What's next
* [AI on Kubernetes Without the Pain](https://blog.skypilot.co/ai-on-kubernetes/)
-* [SkyPilot AI Gallery](https://skypilot.readthedocs.io/en/latest/gallery/index.html)
-* [SkyPilot Docs](https://skypilot.readthedocs.io/en/latest/docs/index.html)
+* [SkyPilot AI Gallery](https://docs.skypilot.co/en/latest/gallery/index.html)
+* [SkyPilot Docs](https://docs.skypilot.co)
* [SkyPilot GitHub](https://github.com/skypilot-org/skypilot)
diff --git a/llm/llama-3_1/README.md b/llm/llama-3_1/README.md
index 6cfeb8dc5f9..2634811d8a1 100644
--- a/llm/llama-3_1/README.md
+++ b/llm/llama-3_1/README.md
@@ -13,7 +13,7 @@ This guide walks through how to serve Llama 3.1 models **completely on your infr
- Local GPU workstation
- Kubernetes cluster
-- Cloud accounts ([12 clouds supported](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html))
+- Cloud accounts ([12 clouds supported](https://docs.skypilot.co/en/latest/getting-started/installation.html))
SkyPilot will be used as the unified framework to launch serving on any (or multiple) infra that you bring.
@@ -64,7 +64,7 @@ sky check kubernetes
sky check
```
-See [docs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html) for details.
+See [docs](https://docs.skypilot.co/en/latest/getting-started/installation.html) for details.
### Step 1: Get a GPU dev node (pod or VM)
@@ -155,7 +155,7 @@ Now that we verified the model is working, let's package it for hands-free deplo
Whichever infra you use for GPUs, SkyPilot abstracts away the mundane infra tasks (e.g., setting up services on K8s, opening up ports for cloud VMs), making AI models super easy to deploy via one command.
-[Deploying via SkyPilot](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html) has several key benefits:
+[Deploying via SkyPilot](https://docs.skypilot.co/en/latest/serving/sky-serve.html) has several key benefits:
- Control node & replicas completely stay in your infra
- Automatic load-balancing across multiple replicas
- Automatic recovery of replicas
@@ -296,7 +296,7 @@ curl -L http://$ENDPOINT/v1/chat/completions \
🎉 **Congratulations!** You are now serving a Llama 3.1 8B model across two replicas. To recap, all model replicas **stay in your own private infrastructure** and SkyPilot ensures they are **healthy and available**.
-Details on autoscaling, rolling updates, and more in [SkyServe docs](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html).
+Details on autoscaling, rolling updates, and more in [SkyServe docs](https://docs.skypilot.co/en/latest/serving/sky-serve.html).
When you are done, shut down all resources:
diff --git a/llm/llama-3_2/README.md b/llm/llama-3_2/README.md
index 987dc0d90c5..f6c2a54ce6a 100644
--- a/llm/llama-3_2/README.md
+++ b/llm/llama-3_2/README.md
@@ -26,7 +26,7 @@
## Prerequisites
- Go to the [HuggingFace model page](https://huggingface.co/meta-llama/) and request access to the model [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) and [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision).
-- Check that you have installed SkyPilot ([docs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)).
+- Check that you have installed SkyPilot ([docs](https://docs.skypilot.co/en/latest/getting-started/installation.html)).
- Check that `sky check` shows clouds or Kubernetes are enabled.
## SkyPilot YAML
@@ -346,7 +346,7 @@ To shut down all resources:
sky serve down llama3
```
-See more details in [SkyServe docs](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html).
+See more details in [SkyServe docs](https://docs.skypilot.co/en/latest/serving/sky-serve.html).
## Developing and Finetuning Llama 3 series
diff --git a/llm/llama-chatbots/README.md b/llm/llama-chatbots/README.md
index 418d3d39d15..272cc24d288 100644
--- a/llm/llama-chatbots/README.md
+++ b/llm/llama-chatbots/README.md
@@ -17,12 +17,12 @@ It will automatically perform the following:
[**LLaMA**](https://github.com/facebookresearch/llama) is a set of Large Language Models (LLMs) recently released by Meta. Trained on more than 1 trillion tokens from public datasets, LLaMA achieves high quality and is space-efficient. You can [fill out a form to request access from Meta](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) to download the open model weights.
In the steps below we assume either (1) you have an unexpired download URL, or (2) the weights have been downloaded and stored on the local machine.
-[**SkyPilot**](https://github.com/skypilot-org/skypilot) is an open-source framework from UC Berkeley for seamlessly running machine learning on any cloud. With a simple CLI, users can easily launch many clusters and jobs, while substantially lowering their cloud bills. Currently, [Lambda Labs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#lambda-cloud) (low-cost GPU cloud), [AWS](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#aws), [GCP](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#gcp), and [Azure](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#azure) are supported. See [docs](https://skypilot.readthedocs.io/en/latest/) to learn more.
+[**SkyPilot**](https://github.com/skypilot-org/skypilot) is an open-source framework from UC Berkeley for seamlessly running machine learning on any cloud. With a simple CLI, users can easily launch many clusters and jobs, while substantially lowering their cloud bills. Currently, [Lambda Labs](https://docs.skypilot.co/en/latest/getting-started/installation.html#lambda-cloud) (low-cost GPU cloud), [AWS](https://docs.skypilot.co/en/latest/getting-started/installation.html#aws), [GCP](https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp), and [Azure](https://docs.skypilot.co/en/latest/getting-started/installation.html#azure) are supported. See [docs](https://docs.skypilot.co/en/latest/) to learn more.
## Steps
All YAML files used below live in [the SkyPilot repo](https://github.com/skypilot-org/skypilot/tree/master/llm/llama-chatbots), and the chatbot code is [here](https://github.com/skypilot-org/sky-llama).
-0. Install SkyPilot and [check that cloud credentials exist](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#cloud-account-setup):
+0. Install SkyPilot and [check that cloud credentials exist](https://docs.skypilot.co/en/latest/getting-started/installation.html#cloud-account-setup):
```bash
pip install "skypilot[aws,gcp,azure,lambda]" # pick your clouds
sky check
@@ -120,7 +120,7 @@ sky launch llama-30b.yaml -c llama-30b -s --env LLAMA_URL=$LLAMA_URL
sky launch llama-65b.yaml -c llama-65b -s --env LLAMA_URL=$LLAMA_URL
```
-To see details about these flags, see [CLI docs](https://skypilot.readthedocs.io/en/latest/reference/cli.html) or run `sky launch -h`.
+To see details about these flags, see [CLI docs](https://docs.skypilot.co/en/latest/reference/cli.html) or run `sky launch -h`.
## Cleaning up
When you are done, you can stop or tear down the cluster:
@@ -140,7 +140,7 @@ When you are done, you can stop or tear down the cluster:
```
**To see your clusters**, run `sky status`, which is a single pane of glass for all your clusters across regions/clouds.
-To learn more about various SkyPilot commands, see [Quickstart](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html).
+To learn more about various SkyPilot commands, see [Quickstart](https://docs.skypilot.co/en/latest/getting-started/quickstart.html).
## Why SkyPilot?
@@ -166,12 +166,12 @@ SkyPilot's `sky launch` command makes this entirely automatic. It performs *auto
- low-cost GPU cloud (Lambda; >3x cheaper than AWS/Azure/GCP)
- spot instances (>3x cheaper than on-demand)
- automatically choosing the cheapest cloud/region/zone
-- auto-stopping & auto-termination of instances ([docs](https://skypilot.readthedocs.io/en/latest/reference/auto-stop.html))
+- auto-stopping & auto-termination of instances ([docs](https://docs.skypilot.co/en/latest/reference/auto-stop.html))
## Recap
Congratulations! You have used SkyPilot to launch a LLaMA-based chatbot on the cloud with just one command. The system automatically handles setting up instances and it offers cloud portability, higher GPU availability, and cost reduction.
-LLaMA chatbots are just one example app. To leverage these benefits for your own ML projects on the cloud, we recommend the [Quickstart guide](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html).
+LLaMA chatbots are just one example app. To leverage these benefits for your own ML projects on the cloud, we recommend the [Quickstart guide](https://docs.skypilot.co/en/latest/getting-started/quickstart.html).
*Feedback or questions? Want to run other LLM models?* Feel free to drop a note to the SkyPilot team on [GitHub](https://github.com/skypilot-org/skypilot/) or [Slack](http://slack.skypilot.co/) and we're happy to chat!
diff --git a/llm/localgpt/README.md b/llm/localgpt/README.md
index 17b3332ee30..c52f1b08851 100644
--- a/llm/localgpt/README.md
+++ b/llm/localgpt/README.md
@@ -13,7 +13,7 @@ Install SkyPilot and check your setup of cloud credentials:
pip install git+https://github.com/skypilot-org/skypilot.git
sky check
```
-See [docs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html) for more.
+See [docs](https://docs.skypilot.co/en/latest/getting-started/installation.html) for more.
Once you are done, we will use [SkyPilot YAML for localGPT](https://github.com/skypilot-org/skypilot/tree/master/llm/localgpt/localgpt.yaml) to define our task and run it.
diff --git a/llm/lorax/README.md b/llm/lorax/README.md
index edd153d45f1..b1d5def6e78 100644
--- a/llm/lorax/README.md
+++ b/llm/lorax/README.md
@@ -40,7 +40,7 @@ sky launch -c lorax-cluster lorax.yaml
By default, this config will deploy `Mistral-7B-Instruct`, but this can be overridden by running `sky launch` with the argument `--env MODEL_ID=`.
-**NOTE:** This config will launch the instance on a public IP. It's highly recommended to secure the instance within a private subnet. See the [Advanced Configurations](https://skypilot.readthedocs.io/en/latest/reference/config.html#config-yaml) section of the SkyPilot docs for options to run within VPC and setup private IPs.
+**NOTE:** This config will launch the instance on a public IP. It's highly recommended to secure the instance within a private subnet. See the [Advanced Configurations](https://docs.skypilot.co/en/latest/reference/config.html#config-yaml) section of the SkyPilot docs for options to run within a VPC and set up private IPs.
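+If you do want to keep the endpoint off the public internet, the relevant knobs live in `~/.sky/config.yaml`. A minimal sketch for AWS is shown below; `my-private-vpc` is a hypothetical VPC name, and the full set of fields (including SSH proxying to reach private instances) is described in the Advanced Configurations docs linked above:
+
+```yaml
+# ~/.sky/config.yaml (sketch)
+aws:
+  vpc_name: my-private-vpc   # launch instances only in this VPC
+  use_internal_ips: true     # use private IPs instead of public ones
+  # Typically an ssh_proxy_command (e.g. via a jump host) is also needed
+  # so that SkyPilot can reach the private instances.
+```
+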
## Prompt LoRAX w/ base model
diff --git a/llm/mixtral/README.md b/llm/mixtral/README.md
index 0bddb77c665..8456dbb5fcf 100644
--- a/llm/mixtral/README.md
+++ b/llm/mixtral/README.md
@@ -15,7 +15,7 @@ SkyPilot can help you serve Mixtral by automatically finding available resources
sky launch -c mixtral ./serve.yaml
```
-Note that we specify the following resources, so that SkyPilot will automatically find any of the available GPUs specified by automatically [failover](https://skypilot.readthedocs.io/en/latest/examples/auto-failover.html) through all the candidates (in the order of the prices):
+Note that we specify the following resources, so that SkyPilot will automatically find any of the specified GPUs that are available by [failing over](https://docs.skypilot.co/en/latest/examples/auto-failover.html) through all the candidates (in order of their prices):
```yaml
resources:
@@ -82,7 +82,7 @@ curl http://$IP:8000/v1/chat/completions \
## 2. Serve with multiple instances
-When scaling up is required, [SkyServe](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html) is the library built on top of SkyPilot, which can help you scale up the serving with multiple instances, while still providing a single endpoint. To serve Mixtral with multiple instances, run the following command:
+When scaling up is required, [SkyServe](https://docs.skypilot.co/en/latest/serving/sky-serve.html), a library built on top of SkyPilot, can help you scale up serving with multiple instances while still providing a single endpoint. To serve Mixtral with multiple instances, run the following command:
```bash
sky serve up -n mixtral ./serve.yaml
diff --git a/llm/ollama/README.md b/llm/ollama/README.md
index 16a8a9ea8e4..2d15b598381 100644
--- a/llm/ollama/README.md
+++ b/llm/ollama/README.md
@@ -17,7 +17,7 @@ To get started, install the latest version of SkyPilot:
pip install "skypilot-nightly[all]"
```
-For detailed installation instructions, please refer to the [installation guide](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
+For detailed installation instructions, please refer to the [installation guide](https://docs.skypilot.co/en/latest/getting-started/installation.html).
Once installed, run `sky check` to verify you have cloud access.
@@ -296,4 +296,4 @@ To shut down all resources:
sky serve down ollama
```
-See more details in [SkyServe docs](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html).
+See more details in [SkyServe docs](https://docs.skypilot.co/en/latest/serving/sky-serve.html).
diff --git a/llm/ollama/ollama.yaml b/llm/ollama/ollama.yaml
index 851dfe45dee..ed37c0ceb1b 100644
--- a/llm/ollama/ollama.yaml
+++ b/llm/ollama/ollama.yaml
@@ -47,13 +47,9 @@ service:
setup: |
# Install Ollama
- if [ "$(uname -m)" == "aarch64" ]; then
- # For apple silicon support
- sudo curl -L https://ollama.com/download/ollama-linux-arm64 -o /usr/bin/ollama
- else
- sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama
- fi
- sudo chmod +x /usr/bin/ollama
+ # official installation reference: https://ollama.com/download/linux
+ curl -fsSL https://ollama.com/install.sh | sh
+ sudo chmod +x /usr/local/bin/ollama
# Start `ollama serve` and capture PID to kill it after pull is done
ollama serve &
diff --git a/llm/pixtral/README.md b/llm/pixtral/README.md
index fccde1de7ad..987769c892a 100644
--- a/llm/pixtral/README.md
+++ b/llm/pixtral/README.md
@@ -57,7 +57,7 @@ This guide shows how to use run and deploy this multimodal model on your own clo
pip install 'skypilot[all]'
sky check
```
-Detailed instructions for installation and cloud setup [here](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
+Detailed instructions for installation and cloud setup [here](https://docs.skypilot.co/en/latest/getting-started/installation.html).
2. Launch the model on any cloud or Kubernetes:
```bash
@@ -150,7 +150,7 @@ These descriptions should give you a clear picture of the scenes depicted in the
## Scale Up Pixtral Endpoint as a Service
-1. Start a service with [SkyServe](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html):
+1. Start a service with [SkyServe](https://docs.skypilot.co/en/latest/serving/sky-serve.html):
```bash
sky serve up -n pixtral pixtral.yaml
```
diff --git a/llm/qwen/README.md b/llm/qwen/README.md
index 6846fc71f2f..d4c73edb842 100644
--- a/llm/qwen/README.md
+++ b/llm/qwen/README.md
@@ -27,7 +27,7 @@ As of Jun 2024, Qwen1.5-110B-Chat is ranked higher than GPT-4-0613 on the [LMSYS
## Running your own Qwen with SkyPilot
-After [installing SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), run your own Qwen model on vLLM with SkyPilot in 1-click:
+After [installing SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html), run your own Qwen model on vLLM with SkyPilot in 1-click:
1. Start serving Qwen 110B on a single instance with any available GPU in the list specified in [qwen15-110b.yaml](https://github.com/skypilot-org/skypilot/blob/master/llm/qwen/qwen15-110b.yaml) with a vLLM powered OpenAI-compatible endpoint (You can also switch to [qwen25-72b.yaml](https://github.com/skypilot-org/skypilot/blob/master/llm/qwen/qwen25-72b.yaml) or [qwen25-7b.yaml](https://github.com/skypilot-org/skypilot/blob/master/llm/qwen/qwen25-7b.yaml) for a smaller model):
@@ -98,7 +98,7 @@ curl http://$ENDPOINT/v1/chat/completions \
## Scale up the service with SkyServe
-1. With [SkyPilot Serving](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html), a serving library built on top of SkyPilot, scaling up the Qwen service is as simple as running:
+1. With [SkyPilot Serving](https://docs.skypilot.co/en/latest/serving/sky-serve.html), a serving library built on top of SkyPilot, scaling up the Qwen service is as simple as running:
```bash
sky serve up -n qwen ./qwen25-72b.yaml
```
diff --git a/llm/sglang/README.md b/llm/sglang/README.md
index 7d41b8fc168..f6bac3c71ad 100644
--- a/llm/sglang/README.md
+++ b/llm/sglang/README.md
@@ -21,7 +21,7 @@ sky check
```
## Serving vision-language model LLaVA with SGLang for more traffic using SkyServe
-1. Create a [`SkyServe Service YAML`](https://skypilot.readthedocs.io/en/latest/serving/service-yaml-spec.html) with a `service` section:
+1. Create a [`SkyServe Service YAML`](https://docs.skypilot.co/en/latest/serving/service-yaml-spec.html) with a `service` section:
```yaml
service:
@@ -33,7 +33,7 @@ service:
The entire Service YAML can be found here: [llava.yaml](https://github.com/skypilot-org/skypilot/tree/master/llm/sglang/llava.yaml).
-2. Start serving by using [SkyServe](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html) CLI:
+2. Start serving by using [SkyServe](https://docs.skypilot.co/en/latest/serving/sky-serve.html) CLI:
```bash
sky serve up -n sglang-llava llava.yaml
```
@@ -117,7 +117,7 @@ You should get a similar response as the following:
## Serving Llama-2 with SGLang for more traffic using SkyServe
1. The process is the same as serving LLaVA, but with the model path changed to Llama-2. Below are example commands for reference.
-2. Start serving by using [SkyServe](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html) CLI:
+2. Start serving by using [SkyServe](https://docs.skypilot.co/en/latest/serving/sky-serve.html) CLI:
```bash
sky serve up -n sglang-llama2 llama2.yaml --env HF_TOKEN=
```
diff --git a/llm/tabby/README.md b/llm/tabby/README.md
index 569b64538c1..9aa4ca4c803 100644
--- a/llm/tabby/README.md
+++ b/llm/tabby/README.md
@@ -17,13 +17,13 @@ This post shows how to use SkyPilot to host an ai coding assistant with just one
- OpenAPI interface, easy to integrate with existing infrastructure (e.g Cloud IDE).
- Supports consumer-grade GPUs.
-[**SkyPilot**](https://github.com/skypilot-org/skypilot) is an open-source framework from UC Berkeley for seamlessly running machine learning on any cloud. With a simple CLI, users can easily launch many clusters and jobs, while substantially lowering their cloud bills. Currently, [AWS](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#aws), [GCP](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#gcp), [Azure](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#azure), [Lambda Cloud](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#lambda-cloud), [IBM](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#ibm), [Oracle Cloud Infrastructure (OCI)](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#oracle-cloud-infrastructure-oci), [Cloudflare R2](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#cloudflare-r2) and [Samsung Cloud Platform (SCP)](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#samsung-cloud-platform-scp) are supported. See [docs](https://skypilot.readthedocs.io/en/latest/) to learn more.
+[**SkyPilot**](https://github.com/skypilot-org/skypilot) is an open-source framework from UC Berkeley for seamlessly running machine learning on any cloud. With a simple CLI, users can easily launch many clusters and jobs, while substantially lowering their cloud bills. Currently, [AWS](https://docs.skypilot.co/en/latest/getting-started/installation.html#aws), [GCP](https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp), [Azure](https://docs.skypilot.co/en/latest/getting-started/installation.html#azure), [Lambda Cloud](https://docs.skypilot.co/en/latest/getting-started/installation.html#lambda-cloud), [IBM](https://docs.skypilot.co/en/latest/getting-started/installation.html#ibm), [Oracle Cloud Infrastructure (OCI)](https://docs.skypilot.co/en/latest/getting-started/installation.html#oracle-cloud-infrastructure-oci), [Cloudflare R2](https://docs.skypilot.co/en/latest/getting-started/installation.html#cloudflare-r2) and [Samsung Cloud Platform (SCP)](https://docs.skypilot.co/en/latest/getting-started/installation.html#samsung-cloud-platform-scp) are supported. See [docs](https://docs.skypilot.co/en/latest/) to learn more.
## Steps
All YAML files used below live in [the SkyPilot repo](https://github.com/skypilot-org/skypilot/tree/master/llm/tabby).
-1. Install SkyPilot and [check that cloud credentials exist](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#cloud-account-setup):
+1. Install SkyPilot and [check that cloud credentials exist](https://docs.skypilot.co/en/latest/getting-started/installation.html#cloud-account-setup):
```bash
# pip install skypilot
@@ -94,4 +94,4 @@ When you are done, you can stop or tear down the cluster:
```
**To see your clusters**, run `sky status`, which is a single pane of glass for all your clusters across regions/clouds.
-To learn more about various SkyPilot commands, see [Quickstart](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html).
+To learn more about various SkyPilot commands, see [Quickstart](https://docs.skypilot.co/en/latest/getting-started/quickstart.html).
diff --git a/llm/vicuna-llama-2/README.md b/llm/vicuna-llama-2/README.md
index e392b231e64..31d78a243cb 100644
--- a/llm/vicuna-llama-2/README.md
+++ b/llm/vicuna-llama-2/README.md
@@ -120,7 +120,7 @@ sky launch --no-use-spot ...
### Reducing costs by 3x with spot instances
-[SkyPilot Managed Jobs](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html) is a library built on top of SkyPilot that helps users run jobs on spot instances without worrying about interruptions. That is the tool used by the LMSYS organization to train the first version of Vicuna (more details can be found in their [launch blog post](https://lmsys.org/blog/2023-03-30-vicuna/) and [example](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna)). With this, the training cost can be reduced from $1000 to **\$300**.
+[SkyPilot Managed Jobs](https://docs.skypilot.co/en/latest/examples/managed-jobs.html) is a library built on top of SkyPilot that helps users run jobs on spot instances without worrying about interruptions. That is the tool used by the LMSYS organization to train the first version of Vicuna (more details can be found in their [launch blog post](https://lmsys.org/blog/2023-03-30-vicuna/) and [example](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna)). With this, the training cost can be reduced from $1000 to **\$300**.
To use SkyPilot Managed Spot Jobs, you can simply replace `sky launch` with `sky jobs launch` in the above command:
diff --git a/llm/vicuna/README.md b/llm/vicuna/README.md
index 6d9f46127d4..b8c6ab100d8 100644
--- a/llm/vicuna/README.md
+++ b/llm/vicuna/README.md
@@ -4,7 +4,7 @@
-This README contains instructions to run and train Vicuna, an open-source LLM chatbot with quality comparable to ChatGPT. The Vicuna release was trained using SkyPilot on [cloud spot instances](https://skypilot.readthedocs.io/en/latest/examples/spot-jobs.html), with a cost of ~$300.
+This README contains instructions to run and train Vicuna, an open-source LLM chatbot with quality comparable to ChatGPT. The Vicuna release was trained using SkyPilot on [cloud spot instances](https://docs.skypilot.co/en/latest/examples/spot-jobs.html), with a cost of ~$300.
* [Blog post](https://lmsys.org/blog/2023-03-30-vicuna/)
* [Demo](https://chat.lmsys.org/)
diff --git a/llm/vllm/README.md b/llm/vllm/README.md
index 78617f3746d..c150ae46e2d 100644
--- a/llm/vllm/README.md
+++ b/llm/vllm/README.md
@@ -112,7 +112,7 @@ curl http://$IP:8000/v1/chat/completions \
## Serving Llama-2 with vLLM for more traffic using SkyServe
To scale up the model serving for more traffic, we introduced SkyServe to enable a user to easily deploy multiple replica of the model:
-1. Adding an `service` section in the above `serve-openai-api.yaml` file to make it an [`SkyServe Service YAML`](https://skypilot.readthedocs.io/en/latest/serving/service-yaml-spec.html):
+1. Adding a `service` section to the above `serve-openai-api.yaml` file to make it a [`SkyServe Service YAML`](https://docs.skypilot.co/en/latest/serving/service-yaml-spec.html):
```yaml
# The newly-added `service` section to the `serve-openai-api.yaml` file.
@@ -125,7 +125,7 @@ service:
The entire Service YAML can be found here: [service.yaml](https://github.com/skypilot-org/skypilot/tree/master/llm/vllm/service.yaml).
-2. Start serving by using [SkyServe](https://skypilot.readthedocs.io/en/latest/serving/sky-serve.html) CLI:
+2. Start serving by using [SkyServe](https://docs.skypilot.co/en/latest/serving/sky-serve.html) CLI:
```bash
sky serve up -n vllm-llama2 service.yaml
```
diff --git a/llm/yi/README.md b/llm/yi/README.md
index 1353320aa9f..b9d5c4a761d 100644
--- a/llm/yi/README.md
+++ b/llm/yi/README.md
@@ -19,7 +19,7 @@
## Running Yi model with SkyPilot
-After [installing SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), run your own Yi model on vLLM with SkyPilot in 1-click:
+After [installing SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html), run your own Yi model on vLLM with SkyPilot in 1-click:
1. Start serving Yi-1.5 34B on a single instance with any available GPU in the list specified in [yi15-34b.yaml](https://github.com/skypilot-org/skypilot/blob/master/llm/yi/yi15-34b.yaml) with a vLLM powered OpenAI-compatible endpoint (You can also switch to [yicoder-9b.yaml](https://github.com/skypilot-org/skypilot/blob/master/llm/yi/yicoder-9b.yaml) or [other model](https://github.com/skypilot-org/skypilot/tree/master/llm/yi) for a smaller model):
diff --git a/sky/adaptors/cloudflare.py b/sky/adaptors/cloudflare.py
index 864248614f3..e9c5613c97e 100644
--- a/sky/adaptors/cloudflare.py
+++ b/sky/adaptors/cloudflare.py
@@ -177,7 +177,7 @@ def check_credentials() -> Tuple[bool, Optional[str]]:
hints += f'\n{_INDENT_PREFIX} $ mkdir -p ~/.cloudflare'
hints += f'\n{_INDENT_PREFIX} $ echo > ~/.cloudflare/accountid' # pylint: disable=line-too-long
hints += f'\n{_INDENT_PREFIX}For more info: '
- hints += 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#cloudflare-r2' # pylint: disable=line-too-long
+ hints += 'https://docs.skypilot.co/en/latest/getting-started/installation.html#cloudflare-r2' # pylint: disable=line-too-long
return (False, hints) if hints else (True, hints)
diff --git a/sky/adaptors/do.py b/sky/adaptors/do.py
new file mode 100644
index 00000000000..d619efebc1c
--- /dev/null
+++ b/sky/adaptors/do.py
@@ -0,0 +1,20 @@
+"""Digital Ocean cloud adaptors"""
+
+# pylint: disable=import-outside-toplevel
+
+from sky.adaptors import common
+
+_IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for DO. '
+ 'Try pip install "skypilot[do]"')
+pydo = common.LazyImport('pydo', import_error_message=_IMPORT_ERROR_MESSAGE)
+azure = common.LazyImport('azure', import_error_message=_IMPORT_ERROR_MESSAGE)
+_LAZY_MODULES = (pydo, azure)
+
+
+# `pydo` inherits Azure exceptions. See:
+# https://github.com/digitalocean/pydo/blob/7b01498d99eb0d3a772366b642e5fab3d6fc6aa2/examples/poc_droplets_volumes_sshkeys.py#L6
+@common.load_lazy_modules(modules=_LAZY_MODULES)
+def exceptions():
+ """Azure exceptions."""
+ from azure.core import exceptions as azure_exceptions
+ return azure_exceptions
diff --git a/sky/adaptors/oci.py b/sky/adaptors/oci.py
index 7a5fafa854a..31712de414f 100644
--- a/sky/adaptors/oci.py
+++ b/sky/adaptors/oci.py
@@ -1,8 +1,17 @@
"""Oracle OCI cloud adaptor"""
+import functools
+import logging
import os
from sky.adaptors import common
+from sky.clouds.utils import oci_utils
+
+# Suppress OCI circuit breaker logging before the lazy import, because
+# the oci modules print additional messages during import, i.e., the
+# set_logger in the LazyImport, which runs after the import, will not
+# take effect.
+logging.getLogger('oci.circuit_breaker').setLevel(logging.WARNING)
CONFIG_PATH = '~/.oci/config'
ENV_VAR_OCI_CONFIG = 'OCI_CONFIG'
@@ -23,10 +32,16 @@ def get_config_file() -> str:
def get_oci_config(region=None, profile='DEFAULT'):
conf_file_path = get_config_file()
+ if not profile or profile == 'DEFAULT':
+ config_profile = oci_utils.oci_config.get_profile()
+ else:
+ config_profile = profile
+
oci_config = oci.config.from_file(file_location=conf_file_path,
- profile_name=profile)
+ profile_name=config_profile)
if region is not None:
oci_config['region'] = region
+
return oci_config
@@ -47,6 +62,29 @@ def get_identity_client(region=None, profile='DEFAULT'):
return oci.identity.IdentityClient(get_oci_config(region, profile))
+def get_object_storage_client(region=None, profile='DEFAULT'):
+ return oci.object_storage.ObjectStorageClient(
+ get_oci_config(region, profile))
+
+
def service_exception():
"""OCI service exception."""
return oci.exceptions.ServiceError
+
+
+def with_oci_env(f):
+
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ # pylint: disable=line-too-long
+ enter_env_cmds = [
+ 'conda info --envs | grep "sky-oci-cli-env" || conda create -n sky-oci-cli-env python=3.10 -y',
+ '. $(conda info --base 2> /dev/null)/etc/profile.d/conda.sh > /dev/null 2>&1 || true',
+ 'conda activate sky-oci-cli-env', 'pip install oci-cli',
+ 'export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True'
+ ]
+ operation_cmd = [f(*args, **kwargs)]
+ leave_env_cmds = ['conda deactivate']
+ return ' && '.join(enter_env_cmds + operation_cmd + leave_env_cmds)
+
+ return wrapper
diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index 62a5aa778a6..4b55bfa7f49 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -171,6 +171,16 @@
('available_node_types', 'ray.head.default', 'node_config',
'azure_arm_parameters', 'cloudInitSetupCommands'),
]
+# These keys are expected to change when provisioning on an existing cluster,
+# but they don't actually represent a change that requires re-provisioning the
+# cluster. If the cluster yaml is the same except for these keys, we can safely
+# skip reprovisioning. See _deterministic_cluster_yaml_hash.
+_RAY_YAML_KEYS_TO_REMOVE_FOR_HASH = [
+ # On first launch, availability_zones will include all possible zones. Once
+ # the cluster exists, it will only include the zone that the cluster is
+ # actually in.
+ ('provider', 'availability_zone'),
+]
def is_ip(s: str) -> bool:
@@ -463,6 +473,42 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
return common_utils.dump_yaml_str(new_config)
+def get_expirable_clouds(
+ enabled_clouds: Sequence[clouds.Cloud]) -> List[clouds.Cloud]:
+    """Returns a list of clouds that use expirable local credentials.
+
+    Each cloud in the provided sequence is checked for whether it uses local
+    credentials and whether those credentials can expire. Clouds meeting both
+    conditions are included in the returned list.
+
+    Args:
+        enabled_clouds: A sequence of cloud objects to check.
+
+    Returns:
+        A list of cloud objects that use local credentials and whose
+        credentials can expire.
+    """
+ expirable_clouds = []
+ local_credentials_value = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value
+ for cloud in enabled_clouds:
+ remote_identities = skypilot_config.get_nested(
+ (str(cloud).lower(), 'remote_identity'), None)
+ if remote_identities is None:
+ remote_identities = schemas.get_default_remote_identity(
+ str(cloud).lower())
+
+ local_credential_expiring = cloud.can_credential_expire()
+ if isinstance(remote_identities, str):
+ if remote_identities == local_credentials_value and local_credential_expiring:
+ expirable_clouds.append(cloud)
+ elif isinstance(remote_identities, list):
+ for profile in remote_identities:
+ if list(profile.values(
+ ))[0] == local_credentials_value and local_credential_expiring:
+ expirable_clouds.append(cloud)
+ break
+ return expirable_clouds
+
+
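A minimal sketch of the intended call site, mirroring the check added to cloud_vm_ray_backend.py later in this patch:

    from sky import check as sky_check
    from sky.backends import backend_utils

    # Warn when a controller would run on clouds whose local credentials can
    # expire mid-job (see provision_with_retries below).
    enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh()
    expirable = backend_utils.get_expirable_clouds(enabled_clouds)
    if expirable:
        print(f'Warning: credentials for {expirable} may expire.')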
# TODO: too many things happening here - leaky abstraction. Refactor.
@timeline.event
def write_cluster_config(
@@ -740,6 +786,13 @@ def write_cluster_config(
tmp_yaml_path,
cluster_config_overrides=to_provision.cluster_config_overrides)
kubernetes_utils.combine_metadata_fields(tmp_yaml_path)
+ yaml_obj = common_utils.read_yaml(tmp_yaml_path)
+ pod_config = yaml_obj['available_node_types']['ray_head_default'][
+ 'node_config']
+ valid, message = kubernetes_utils.check_pod_config(pod_config)
+ if not valid:
+ raise exceptions.InvalidCloudConfigs(
+ f'Invalid pod_config. Details: {message}')
if dryrun:
# If dryrun, return the unfinished tmp yaml path.
@@ -814,6 +867,7 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
clouds.Cudo,
clouds.Paperspace,
clouds.Azure,
+ clouds.DO,
)):
config = auth.configure_ssh_info(config)
elif isinstance(cloud, clouds.GCP):
@@ -833,10 +887,6 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
common_utils.dump_yaml(cluster_config_file, config)
-def get_run_timestamp() -> str:
- return 'sky-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
-
-
def get_timestamp_from_run_timestamp(run_timestamp: str) -> float:
return datetime.strptime(
run_timestamp.partition('-')[2], '%Y-%m-%d-%H-%M-%S-%f').timestamp()
@@ -911,7 +961,7 @@ def _deterministic_cluster_yaml_hash(yaml_path: str) -> str:
yaml file and all the files in the file mounts, then hash the byte sequence.
The format of the byte sequence is:
- 32 bytes - sha256 hash of the yaml file
+ 32 bytes - sha256 hash of the yaml
for each file mount:
file mount remote destination (UTF-8), \0
if the file mount source is a file:
@@ -935,14 +985,29 @@ def _deterministic_cluster_yaml_hash(yaml_path: str) -> str:
we construct it incrementally by using hash.update() to add new bytes.
"""
+ # Load the yaml contents so that we can directly remove keys.
+ yaml_config = common_utils.read_yaml(yaml_path)
+ for key_list in _RAY_YAML_KEYS_TO_REMOVE_FOR_HASH:
+ dict_to_remove_from = yaml_config
+ found_key = True
+ for key in key_list[:-1]:
+ if (not isinstance(dict_to_remove_from, dict) or
+ key not in dict_to_remove_from):
+ found_key = False
+ break
+ dict_to_remove_from = dict_to_remove_from[key]
+ if found_key and key_list[-1] in dict_to_remove_from:
+ dict_to_remove_from.pop(key_list[-1])
+
def _hash_file(path: str) -> bytes:
return common_utils.hash_file(path, 'sha256').digest()
config_hash = hashlib.sha256()
- config_hash.update(_hash_file(yaml_path))
+ yaml_hash = hashlib.sha256(
+ common_utils.dump_yaml_str(yaml_config).encode('utf-8'))
+ config_hash.update(yaml_hash.digest())
- yaml_config = common_utils.read_yaml(yaml_path)
file_mounts = yaml_config.get('file_mounts', {})
# Remove the file mounts added by the newline.
if '' in file_mounts:
@@ -950,6 +1015,11 @@ def _hash_file(path: str) -> bytes:
file_mounts.pop('')
for dst, src in sorted(file_mounts.items()):
+ if src == yaml_path:
+ # Skip the yaml file itself. We have already hashed a modified
+ # version of it. The file may include fields we don't want to hash.
+ continue
+
expanded_src = os.path.expanduser(src)
config_hash.update(dst.encode('utf-8') + b'\0')
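To make the key-stripping step concrete, here is a standalone sketch of the same logic applied to a toy cluster config (the values are made up; the real code uses common_utils.read_yaml/dump_yaml_str rather than PyYAML directly):

    import hashlib

    import yaml

    _KEYS_TO_REMOVE = [('provider', 'availability_zone')]

    config = {
        'cluster_name': 'toy-cluster',
        'provider': {
            'region': 'us-east-1',
            'availability_zone': 'us-east-1a',
        },
    }

    for key_list in _KEYS_TO_REMOVE:
        d = config
        found = True
        for key in key_list[:-1]:
            if not isinstance(d, dict) or key not in d:
                found = False
                break
            d = d[key]
        if found and key_list[-1] in d:
            d.pop(key_list[-1])

    # The zone no longer affects the digest, so a cluster whose zone list was
    # narrowed after the first launch hashes the same as before.
    digest = hashlib.sha256(yaml.safe_dump(config).encode('utf-8')).hexdigest()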
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 518389a5b9d..3cd8116688d 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -25,6 +25,7 @@
import sky
from sky import backends
+from sky import check as sky_check
from sky import cloud_stores
from sky import clouds
from sky import exceptions
@@ -182,6 +183,7 @@ def _get_cluster_config_template(cloud):
clouds.SCP: 'scp-ray.yml.j2',
clouds.OCI: 'oci-ray.yml.j2',
clouds.Paperspace: 'paperspace-ray.yml.j2',
+ clouds.DO: 'do-ray.yml.j2',
clouds.RunPod: 'runpod-ray.yml.j2',
clouds.Kubernetes: 'kubernetes-ray.yml.j2',
clouds.Vsphere: 'vsphere-ray.yml.j2',
@@ -1096,7 +1098,7 @@ def _gcp_handler(blocked_resources: Set['resources_lib.Resources'],
'having the required permissions and the user '
'account does not have enough permission to '
'update it. Please contact your administrator and '
- 'check out: https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html\n' # pylint: disable=line-too-long
+ 'check out: https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/gcp.html\n' # pylint: disable=line-too-long
f'Details: {message}')
_add_to_blocked_resources(
blocked_resources,
@@ -1393,8 +1395,7 @@ def _retry_zones(
f'in {to_provision.cloud}. '
f'{colorama.Style.RESET_ALL}'
f'To request quotas, check the instruction: '
- f'https://skypilot.readthedocs.io/en/latest/cloud-setup/quota.html.' # pylint: disable=line-too-long
- )
+ f'https://docs.skypilot.co/en/latest/cloud-setup/quota.html.')
for zones in self._yield_zones(to_provision, num_nodes, cluster_name,
prev_cluster_status,
@@ -2004,6 +2005,22 @@ def provision_with_retries(
skip_unnecessary_provisioning else None)
failover_history: List[Exception] = list()
+ # If the user is using local credentials which may expire, the
+ # controller may leak resources if the credentials expire while a job
+ # is running. Here we check the enabled clouds and expiring credentials
+ # and raise a warning to the user.
+ if task.is_controller_task():
+ enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh()
+ expirable_clouds = backend_utils.get_expirable_clouds(
+ enabled_clouds)
+
+ if len(expirable_clouds) > 0:
+ warnings = (f'\033[93mWarning: Credentials used for '
+ f'{expirable_clouds} may expire. Clusters may be '
+ f'leaked if the credentials expire while jobs '
+ f'are running. It is recommended to use credentials'
+ f' that never expire or a service account.\033[0m')
+ logger.warning(warnings)
# Retrying launchable resources.
while True:
@@ -2632,7 +2649,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
ResourceHandle = CloudVmRayResourceHandle # pylint: disable=invalid-name
def __init__(self):
- self.run_timestamp = backend_utils.get_run_timestamp()
+ self.run_timestamp = sky_logging.get_run_timestamp()
# NOTE: do not expanduser() here, as this '~/...' path is used for
# remote as well to be expanded on the remote side.
self.log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
@@ -2659,7 +2676,7 @@ def register_info(self, **kwargs) -> None:
self._optimize_target) or common.OptimizeTarget.COST
self._requested_features = kwargs.pop('requested_features',
self._requested_features)
- assert len(kwargs) == 0, f'Unexpected kwargs: {kwargs}'
+ assert not kwargs, f'Unexpected kwargs: {kwargs}'
def check_resources_fit_cluster(
self,
diff --git a/sky/backends/wheel_utils.py b/sky/backends/wheel_utils.py
index ed580569e0b..805117ee2a3 100644
--- a/sky/backends/wheel_utils.py
+++ b/sky/backends/wheel_utils.py
@@ -153,7 +153,10 @@ def _get_latest_modification_time(path: pathlib.Path) -> float:
if not path.exists():
return -1.
try:
- return max(os.path.getmtime(root) for root, _, _ in os.walk(path))
+ return max(
+ os.path.getmtime(os.path.join(root, f))
+ for root, dirs, files in os.walk(path)
+ for f in (*dirs, *files))
except ValueError:
return -1.
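The change matters because os.walk's roots only reflect directory mtimes, and editing a file in place does not update its parent directory's mtime, so wheel rebuilds could be skipped after a source edit. A small self-contained illustration of that filesystem behavior (temporary paths; assumes a typical POSIX filesystem):

    import os
    import pathlib
    import tempfile
    import time

    with tempfile.TemporaryDirectory() as tmp:
        src = pathlib.Path(tmp) / 'module.py'
        src.write_text('x = 1\n')
        dir_mtime = os.path.getmtime(tmp)
        time.sleep(1.1)
        src.write_text('x = 2\n')  # In-place edit: no directory entry changes.
        assert os.path.getmtime(tmp) == dir_mtime
        assert os.path.getmtime(src) > dir_mtime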
diff --git a/sky/benchmark/benchmark_utils.py b/sky/benchmark/benchmark_utils.py
index 24cb3cbbe13..ceb76e69cb1 100644
--- a/sky/benchmark/benchmark_utils.py
+++ b/sky/benchmark/benchmark_utils.py
@@ -537,7 +537,7 @@ def launch_benchmark_clusters(benchmark: str, clusters: List[str],
for yaml_fd, cluster in zip(yaml_fds, clusters)]
# Save stdout/stderr from cluster launches.
- run_timestamp = backend_utils.get_run_timestamp()
+ run_timestamp = sky_logging.get_run_timestamp()
log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp)
log_dir = os.path.expanduser(log_dir)
logger.info(
diff --git a/sky/check.py b/sky/check.py
index 8989110b253..848bf874001 100644
--- a/sky/check.py
+++ b/sky/check.py
@@ -132,7 +132,7 @@ def get_all_clouds():
'\nNote: The following clouds were disabled because they were not '
'included in allowed_clouds in ~/.sky/config.yaml: '
f'{", ".join([c for c in disallowed_cloud_names])}')
- if len(all_enabled_clouds) == 0:
+ if not all_enabled_clouds:
echo(
click.style(
'No cloud is enabled. SkyPilot will not be able to run any '
@@ -151,7 +151,7 @@ def get_all_clouds():
dim=True) + click.style(f'sky check{clouds_arg}', bold=True) +
'\n' + click.style(
'If any problems remain, refer to detailed docs at: '
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html', # pylint: disable=line-too-long
dim=True))
if disallowed_clouds_hint:
diff --git a/sky/cli.py b/sky/cli.py
index b0323601305..0c17247a4e0 100644
--- a/sky/cli.py
+++ b/sky/cli.py
@@ -180,7 +180,7 @@ def _get_glob_storages(storages: List[str]) -> List[str]:
glob_storages = []
for storage_object in storages:
glob_storage = global_user_state.get_glob_storage_name(storage_object)
- if len(glob_storage) == 0:
+ if not glob_storage:
click.echo(f'Storage {storage_object} not found.')
glob_storages.extend(glob_storage)
return list(set(glob_storages))
@@ -837,7 +837,7 @@ class _NaturalOrderGroup(click.Group):
Reference: https://github.com/pallets/click/issues/513
"""
- def list_commands(self, ctx):
+ def list_commands(self, ctx): # pylint: disable=unused-argument
return self.commands.keys()
@usage_lib.entrypoint('sky.api.cli', fallback=True)
@@ -1009,7 +1009,10 @@ def cli():
'backend_name',
flag_value=backends.LocalDockerBackend.NAME,
default=False,
- help='If used, runs locally inside a docker container.')
+ hidden=True,
+ help=('(Deprecated) Local docker support is deprecated. '
+ 'To run locally, create a local Kubernetes cluster with '
+ '``sky local up``.'))
@_add_click_options(_TASK_OPTIONS_WITH_NAME + _EXTRA_RESOURCES_OPTIONS +
_COMMON_OPTIONS)
@click.option(
@@ -1158,6 +1161,11 @@ def launch(
backend: backends.Backend
if backend_name == backends.LocalDockerBackend.NAME:
backend = backends.LocalDockerBackend()
+ click.secho(
+ 'WARNING: LocalDockerBackend is deprecated and will be '
+ 'removed in a future release. To run locally, create a local '
+ 'Kubernetes cluster with `sky local up`.',
+ fg='yellow')
elif backend_name == backends.CloudVmRayBackend.NAME:
backend = backends.CloudVmRayBackend()
else:
@@ -1493,7 +1501,7 @@ def _get_services(service_names: Optional[List[str]],
if len(service_records) != 1:
plural = 's' if len(service_records) > 1 else ''
service_num = (str(len(service_records))
- if len(service_records) > 0 else 'No')
+ if service_records else 'No')
raise click.UsageError(
f'{service_num} service{plural} found. Please specify '
'an existing service to show its endpoint. Usage: '
@@ -1724,8 +1732,7 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
if len(clusters) != 1:
with ux_utils.print_exception_no_traceback():
plural = 's' if len(clusters) > 1 else ''
- cluster_num = (str(len(clusters))
- if len(clusters) > 0 else 'No')
+ cluster_num = (str(len(clusters)) if clusters else 'No')
cause = 'a single' if len(clusters) > 1 else 'an existing'
raise ValueError(
_STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format(
@@ -1753,9 +1760,8 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
with ux_utils.print_exception_no_traceback():
plural = 's' if len(cluster_records) > 1 else ''
cluster_num = (str(len(cluster_records))
- if len(cluster_records) > 0 else
- f'{clusters[0]!r}')
- verb = 'found' if len(cluster_records) > 0 else 'not found'
+ if cluster_records else f'{clusters[0]!r}')
+ verb = 'found' if cluster_records else 'not found'
cause = 'a single' if len(clusters) > 1 else 'an existing'
raise ValueError(
_STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format(
@@ -2581,7 +2587,7 @@ def start(
'(see `sky status`), or the -a/--all flag.')
if all:
- if len(clusters) > 0:
+ if clusters:
click.echo('Both --all and cluster(s) specified for sky start. '
'Letting --all take effect.')
@@ -2939,7 +2945,7 @@ def _down_or_stop_clusters(
operation = f'{verb} auto{option_str} on'
names = list(names)
- if len(names) > 0:
+ if names:
controllers = [
name for name in names
if controller_utils.Controllers.from_name(name) is not None
@@ -2954,7 +2960,7 @@ def _down_or_stop_clusters(
# Make sure the controllers are explicitly specified without other
# normal clusters.
if controllers:
- if len(names) != 0:
+ if names:
names_str = ', '.join(map(repr, names))
raise click.UsageError(
f'{operation} controller(s) '
@@ -3008,7 +3014,7 @@ def _down_or_stop_clusters(
if apply_to_all or all_users:
all_clusters = _get_cluster_records_and_set_ssh_config(
clusters=None, all_users=all_users)
- if len(names) > 0:
+ if names:
click.echo(
f'Both --all and cluster(s) specified for `sky {command}`. '
'Letting --all take effect.')
@@ -3027,7 +3033,7 @@ def _down_or_stop_clusters(
click.echo('Cluster(s) not found (tip: see `sky status`).')
return
- if not no_confirm and len(clusters) > 0:
+ if not no_confirm and clusters:
cluster_str = 'clusters' if len(clusters) > 1 else 'cluster'
cluster_list = ', '.join(clusters)
click.confirm(
@@ -3439,7 +3445,7 @@ def _output():
for tpu in service_catalog.get_tpus():
if tpu in result:
tpu_table.add_row([tpu, _list_to_str(result.pop(tpu))])
- if len(tpu_table.get_string()) > 0:
+ if tpu_table.get_string():
yield '\n\n'
yield from tpu_table.get_string()
@@ -3551,7 +3557,7 @@ def _output():
yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Cloud GPUs{colorama.Style.RESET_ALL}\n')
- if len(result) == 0:
+ if not result:
quantity_str = (f' with requested quantity {quantity}'
if quantity else '')
cloud_str = f' on {cloud_obj}.' if cloud_name else ' in cloud catalogs.'
@@ -3683,7 +3689,7 @@ def storage_delete(names: List[str], all: bool, yes: bool, async_call: bool): #
# Delete all storage objects.
sky storage delete -a
"""
- if sum([len(names) > 0, all]) != 1:
+ if sum([bool(names), all]) != 1:
raise click.UsageError('Either --all or a name must be specified.')
if all:
storages = sdk.get(sdk.storage_ls())
@@ -3765,15 +3771,12 @@ def jobs():
default=False,
required=False,
help='Skip confirmation prompt.')
-# TODO(cooperc): remove this flag once --fast can robustly detect cluster
-# yaml config changes
+# TODO(cooperc): remove this flag before releasing 0.8.0
@click.option('--fast',
default=False,
is_flag=True,
- help='[Experimental] Launch the job faster by skipping '
- 'controller initialization steps. If you update SkyPilot or '
- 'your local cloud credentials, they will not be reflected until '
- 'you run `sky jobs launch` at least once without this flag.')
+ help=('[Deprecated] Does nothing. Previous flag behavior is now '
+ 'enabled by default.'))
@timeline.event
@usage_lib.entrypoint
def jobs_launch(
@@ -3798,7 +3801,7 @@ def jobs_launch(
disk_tier: Optional[str],
ports: Tuple[str],
detach_run: bool,
- retry_until_up: bool,
+ retry_until_up: Optional[bool],
yes: bool,
fast: bool,
async_call: bool,
@@ -3857,6 +3860,16 @@ def jobs_launch(
else:
retry_until_up = True
+    # Deprecation. The default behavior is fast, and the flag will be removed.
+    # The flag was not present in 0.7.x (only in nightly), so we will remove it
+    # before 0.8.0 so that it never enters a stable release.
+ if fast:
+ click.secho(
+ 'Flag --fast is deprecated, as the behavior is now default. The '
+ 'flag will be removed soon. Please do not use it, so that you '
+ 'avoid "No such option" errors.',
+ fg='yellow')
+
if not isinstance(task_or_dag, sky.Dag):
assert isinstance(task_or_dag, sky.Task), task_or_dag
with sky.Dag() as dag:
@@ -4022,8 +4035,8 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
$ sky jobs cancel 1 2 3
"""
job_id_str = ','.join(map(str, job_ids))
- if sum([len(job_ids) > 0, name is not None, all]) != 1:
- argument_str = f'--job-ids {job_id_str}' if len(job_ids) > 0 else ''
+ if sum([bool(job_ids), name is not None, all]) != 1:
+ argument_str = f'--job-ids {job_id_str}' if job_ids else ''
argument_str += f' --name {name}' if name is not None else ''
argument_str += ' --all' if all else ''
raise click.UsageError(
@@ -4601,9 +4614,9 @@ def serve_down(
# Forcefully tear down a specific replica, even in failed status.
sky serve down my-service --replica-id 1 --purge
"""
- if sum([len(service_names) > 0, all]) != 1:
- argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len(
- service_names) > 0 else ''
+ if sum([bool(service_names), all]) != 1:
+ argument_str = (f'SERVICE_NAMES={",".join(service_names)}'
+ if service_names else '')
argument_str += ' --all' if all else ''
raise click.UsageError(
'Can only specify one of SERVICE_NAMES or --all. '
@@ -4977,7 +4990,7 @@ def benchmark_launch(
if idle_minutes_to_autostop is None:
idle_minutes_to_autostop = 5
commandline_args['idle-minutes-to-autostop'] = idle_minutes_to_autostop
- if len(env) > 0:
+ if env:
commandline_args['env'] = [f'{k}={v}' for k, v in env]
# Launch the benchmarking clusters in detach mode in parallel.
@@ -5253,7 +5266,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool],
raise click.BadParameter(
'Either specify benchmarks or use --all to delete all benchmarks.')
to_delete = []
- if len(benchmarks) > 0:
+ if benchmarks:
for benchmark in benchmarks:
record = benchmark_state.get_benchmark_from_name(benchmark)
if record is None:
@@ -5262,7 +5275,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool],
to_delete.append(record)
if all:
to_delete = benchmark_state.get_benchmarks()
- if len(benchmarks) > 0:
+ if benchmarks:
print('Both --all and benchmark(s) specified '
'for sky bench delete. Letting --all take effect.')
@@ -5368,7 +5381,7 @@ def _deploy_local_cluster(gpus: bool):
run_command = shlex.split(run_command)
# Setup logging paths
- run_timestamp = backend_utils.get_run_timestamp()
+ run_timestamp = sky_logging.get_run_timestamp()
log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp,
'local_up.log')
tail_cmd = 'tail -n100 -f ' + log_path
@@ -5486,7 +5499,7 @@ def _deploy_remote_cluster(ip_file: str, ssh_user: str, ssh_key_path: str,
deploy_command = shlex.split(deploy_command)
# Setup logging paths
- run_timestamp = backend_utils.get_run_timestamp()
+ run_timestamp = sky_logging.get_run_timestamp()
log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp,
'local_up.log')
tail_cmd = 'tail -n100 -f ' + log_path
diff --git a/sky/cloud_stores.py b/sky/cloud_stores.py
index 9893cc9eecf..48b46d6ad2c 100644
--- a/sky/cloud_stores.py
+++ b/sky/cloud_stores.py
@@ -7,6 +7,7 @@
* Better interface.
* Better implementation (e.g., fsspec, smart_open, using each cloud's SDK).
"""
+import os
import shlex
import subprocess
import time
@@ -18,6 +19,7 @@
from sky.adaptors import azure
from sky.adaptors import cloudflare
from sky.adaptors import ibm
+from sky.adaptors import oci
from sky.clouds import gcp
from sky.data import data_utils
from sky.data.data_utils import Rclone
@@ -114,8 +116,16 @@ class GcsCloudStorage(CloudStorage):
@property
def _gsutil_command(self):
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
- return (f'{alias_gen}; GOOGLE_APPLICATION_CREDENTIALS='
- f'{gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH} {gsutil_alias}')
+ return (
+ f'{alias_gen}; GOOGLE_APPLICATION_CREDENTIALS='
+ f'{gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH}; '
+            # Explicitly activate the service account. Unlike the gcp packages
+            # and other GCP commands, gsutil does not automatically pick up the
+            # default credentials when they point to a service account key.
+ 'gcloud auth activate-service-account '
+ '--key-file=$GOOGLE_APPLICATION_CREDENTIALS '
+ '2> /dev/null || true; '
+ f'{gsutil_alias}')
def is_directory(self, url: str) -> bool:
"""Returns whether 'url' is a directory.
@@ -136,7 +146,7 @@ def is_directory(self, url: str) -> bool:
# If is a bucket root, then we only need `gsutil` to succeed
# to make sure the bucket exists. It is already a directory.
_, key = data_utils.split_gcs_path(url)
- if len(key) == 0:
+ if not key:
return True
# Otherwise, gsutil ls -d url will return:
# --> url.rstrip('/') if url is not a directory
@@ -476,6 +486,64 @@ def make_sync_file_command(self, source: str, destination: str) -> str:
return self.make_sync_dir_command(source, destination)
+class OciCloudStorage(CloudStorage):
+ """OCI Cloud Storage."""
+
+ def is_directory(self, url: str) -> bool:
+ """Returns whether OCI 'url' is a directory.
+ In cloud object stores, a "directory" refers to a regular object whose
+ name is a prefix of other objects.
+ """
+ bucket_name, path = data_utils.split_oci_path(url)
+
+ client = oci.get_object_storage_client()
+ namespace = client.get_namespace(
+ compartment_id=oci.get_oci_config()['tenancy']).data
+
+ objects = client.list_objects(namespace_name=namespace,
+ bucket_name=bucket_name,
+ prefix=path).data.objects
+
+ if len(objects) == 0:
+            # A directory with no items
+ return True
+
+ if len(objects) > 1:
+            # A directory with more than 1 item
+ return True
+
+ object_name = objects[0].name
+ if path.endswith(object_name):
+ # An object path
+ return False
+
+ # A directory with only 1 item
+ return True
+
+ @oci.with_oci_env
+ def make_sync_dir_command(self, source: str, destination: str) -> str:
+ """Downloads using OCI CLI."""
+ bucket_name, path = data_utils.split_oci_path(source)
+
+ download_via_ocicli = (f'oci os object sync --no-follow-symlinks '
+ f'--bucket-name {bucket_name} '
+ f'--prefix "{path}" --dest-dir "{destination}"')
+
+ return download_via_ocicli
+
+ @oci.with_oci_env
+ def make_sync_file_command(self, source: str, destination: str) -> str:
+ """Downloads a file using OCI CLI."""
+ bucket_name, path = data_utils.split_oci_path(source)
+ filename = os.path.basename(path)
+ destination = os.path.join(destination, filename)
+
+ download_via_ocicli = (f'oci os object get --bucket-name {bucket_name} '
+ f'--name "{path}" --file "{destination}"')
+
+ return download_via_ocicli
+
+
def get_storage_from_path(url: str) -> CloudStorage:
"""Returns a CloudStorage by identifying the scheme:// in a URL."""
result = urllib.parse.urlsplit(url)
@@ -491,6 +559,7 @@ def get_storage_from_path(url: str) -> CloudStorage:
's3': S3CloudStorage(),
'r2': R2CloudStorage(),
'cos': IBMCosCloudStorage(),
+ 'oci': OciCloudStorage(),
# TODO: This is a hack, as Azure URL starts with https://, we should
# refactor the registry to be able to take regex, so that Azure blob can
# be identified with `https://(.*?)\.blob\.core\.windows\.net`
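With the 'oci' scheme registered above, resolving a data source and building its download command looks roughly like this (the bucket and prefix are hypothetical; the returned string is intended to run on the remote VM):

    from sky import cloud_stores

    store = cloud_stores.get_storage_from_path('oci://my-bucket/datasets/train')
    # store is an OciCloudStorage instance; with_oci_env (sky/adaptors/oci.py)
    # wraps the returned command with the conda environment setup shown above.
    cmd = store.make_sync_dir_command('oci://my-bucket/datasets/train',
                                      '~/datasets/train')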
diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py
index 862abd34f65..ef42970c264 100644
--- a/sky/clouds/__init__.py
+++ b/sky/clouds/__init__.py
@@ -14,6 +14,7 @@
from sky.clouds.aws import AWS
from sky.clouds.azure import Azure
from sky.clouds.cudo import Cudo
+from sky.clouds.do import DO
from sky.clouds.fluidstack import Fluidstack
from sky.clouds.gcp import GCP
from sky.clouds.ibm import IBM
@@ -33,6 +34,7 @@
'Cudo',
'GCP',
'Lambda',
+ 'DO',
'Paperspace',
'SCP',
'RunPod',
diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py
index 6943227f009..6646c11708b 100644
--- a/sky/clouds/aws.py
+++ b/sky/clouds/aws.py
@@ -2,6 +2,8 @@
import enum
import fnmatch
import functools
+import hashlib
+import json
import os
import re
import subprocess
@@ -16,6 +18,7 @@
from sky import skypilot_config
from sky.adaptors import aws
from sky.clouds import service_catalog
+from sky.clouds.service_catalog import common as catalog_common
from sky.clouds.utils import aws_utils
from sky.skylet import constants
from sky.utils import common_utils
@@ -101,6 +104,24 @@ class AWSIdentityType(enum.Enum):
# region us-east-1 config-file ~/.aws/config
SHARED_CREDENTIALS_FILE = 'shared-credentials-file'
+ def can_credential_expire(self) -> bool:
+ """Check if the AWS identity type can expire.
+
+        SSO, IAM_ROLE and CONTAINER_ROLE are temporary credentials that are
+        refreshed automatically. ENV and SHARED_CREDENTIALS_FILE may hold
+        short-lived credentials that are not refreshed automatically.
+ IAM ROLE:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+ SSO/Container-role refresh token:
+ https://docs.aws.amazon.com/solutions/latest/dea-api/auth-refreshtoken.html
+ """
+ # TODO(hong): Add a CLI based check for the expiration of the temporary
+ # credentials
+ expirable_types = {
+ AWSIdentityType.ENV, AWSIdentityType.SHARED_CREDENTIALS_FILE
+ }
+ return self in expirable_types
+
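Given the expirable_types set above, the per-identity-type behavior works out to the following (a quick illustration; the values follow directly from the method):

    from sky.clouds.aws import AWSIdentityType

    assert AWSIdentityType.ENV.can_credential_expire()
    assert AWSIdentityType.SHARED_CREDENTIALS_FILE.can_credential_expire()
    assert not AWSIdentityType.SSO.can_credential_expire()
    assert not AWSIdentityType.IAM_ROLE.can_credential_expire()
    assert not AWSIdentityType.CONTAINER_ROLE.can_credential_expire()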
@registry.CLOUD_REGISTRY.register
class AWS(clouds.Cloud):
@@ -618,21 +639,17 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
'Failed to fetch the availability zones for the account '
f'{identity_str}. It is likely due to permission issues, please'
' check the minimal permission required for AWS: '
- 'https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/aws.html' # pylint: disable=
+ 'https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/aws.html' # pylint: disable=
f'\n{cls._INDENT_PREFIX}Details: '
f'{common_utils.format_exception(e, use_bracket=True)}')
return True, hints
@classmethod
def _current_identity_type(cls) -> Optional[AWSIdentityType]:
- proc = subprocess.run('aws configure list',
- shell=True,
- check=False,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- if proc.returncode != 0:
+ stdout = cls._aws_configure_list()
+ if stdout is None:
return None
- stdout = proc.stdout.decode()
+ output = stdout.decode()
# We determine the identity type by looking at the output of
# `aws configure list`. The output looks like:
@@ -647,10 +664,10 @@ def _current_identity_type(cls) -> Optional[AWSIdentityType]:
def _is_access_key_of_type(type_str: str) -> bool:
# The dot (.) does not match line separators.
- results = re.findall(fr'access_key.*{type_str}', stdout)
+ results = re.findall(fr'access_key.*{type_str}', output)
if len(results) > 1:
raise RuntimeError(
- f'Unexpected `aws configure list` output:\n{stdout}')
+ f'Unexpected `aws configure list` output:\n{output}')
return len(results) == 1
if _is_access_key_of_type(AWSIdentityType.SSO.value):
@@ -665,37 +682,20 @@ def _is_access_key_of_type(type_str: str) -> bool:
return AWSIdentityType.SHARED_CREDENTIALS_FILE
@classmethod
- @functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
- def get_user_identities(cls) -> Optional[List[List[str]]]:
- """Returns a [UserId, Account] list that uniquely identifies the user.
-
- These fields come from `aws sts get-caller-identity`. We permit the same
- actual user to:
-
- - switch between different root accounts (after which both elements
- of the list will be different) and have their clusters owned by
- each account be protected; or
-
- - within the same root account, switch between different IAM
- users, and treat [user_id=1234, account=A] and
- [user_id=4567, account=A] to be the *same*. Namely, switching
- between these IAM roles within the same root account will cause
- the first element of the returned list to differ, and will allow
- the same actual user to continue to interact with their clusters.
- Note: this is not 100% safe, since the IAM users can have very
- specific permissions, that disallow them to access the clusters
- but it is a reasonable compromise as that could be rare.
-
- Returns:
- A list of strings that uniquely identifies the user on this cloud.
- For identity check, we will fallback through the list of strings
- until we find a match, and print a warning if we fail for the
- first string.
+ @functools.lru_cache(maxsize=1)
+ def _aws_configure_list(cls) -> Optional[bytes]:
+ proc = subprocess.run('aws configure list',
+ shell=True,
+ check=False,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ if proc.returncode != 0:
+ return None
+ return proc.stdout
- Raises:
- exceptions.CloudUserIdentityError: if the user identity cannot be
- retrieved.
- """
+ @classmethod
+ @functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
+ def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]:
try:
sts = aws.client('sts')
# The caller identity contains 3 fields: UserId, Account, Arn.
@@ -774,6 +774,72 @@ def get_user_identities(cls) -> Optional[List[List[str]]]:
# automatic switching for AWS. Currently we only support one identity.
return [user_ids]
+ @classmethod
+ @functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
+ def get_user_identities(cls) -> Optional[List[List[str]]]:
+ """Returns a [UserId, Account] list that uniquely identifies the user.
+
+        These fields come from `aws sts get-caller-identity` and are cached
+        on disk, keyed by the output of `aws configure list`. The identity is
+        assumed to be stable for the duration of the `sky` process: modifying
+        the credentials while the `sky` process is running will not affect
+        the identity returned by this function.
+
+ We permit the same actual user to:
+
+ - switch between different root accounts (after which both elements
+ of the list will be different) and have their clusters owned by
+ each account be protected; or
+
+ - within the same root account, switch between different IAM
+ users, and treat [user_id=1234, account=A] and
+ [user_id=4567, account=A] to be the *same*. Namely, switching
+ between these IAM roles within the same root account will cause
+ the first element of the returned list to differ, and will allow
+ the same actual user to continue to interact with their clusters.
+ Note: this is not 100% safe, since the IAM users can have very
+ specific permissions, that disallow them to access the clusters
+ but it is a reasonable compromise as that could be rare.
+
+ Returns:
+ A list of strings that uniquely identifies the user on this cloud.
+ For identity check, we will fallback through the list of strings
+ until we find a match, and print a warning if we fail for the
+ first string.
+
+ Raises:
+ exceptions.CloudUserIdentityError: if the user identity cannot be
+ retrieved.
+ """
+ stdout = cls._aws_configure_list()
+ if stdout is None:
+ # `aws configure list` is not available, possible reasons:
+ # - awscli is not installed but credentials are valid, e.g. run from
+ # an EC2 instance with IAM role
+ # - aws credentials are not set, proceed anyway to get unified error
+ # message for users
+ return cls._sts_get_caller_identity()
+ config_hash = hashlib.md5(stdout).hexdigest()[:8]
+        # Getting the AWS identity costs ~1s, so we cache the result, using
+        # the output of `aws configure list` as the cache key. Different
+        # `aws configure list` outputs can map to the same AWS identity; we
+        # assume the output is stable in practice, so the number of cache
+        # files stays limited.
+        # TODO(aylei): consider using a more stable cache key and evaluate
+        # eviction.
+ cache_path = catalog_common.get_catalog_path(
+ f'aws/.cache/user-identity-{config_hash}.txt')
+ if os.path.exists(cache_path):
+ try:
+ with open(cache_path, 'r', encoding='utf-8') as f:
+ return json.loads(f.read())
+ except json.JSONDecodeError:
+ # cache is invalid, ignore it and fetch identity again
+ pass
+
+ result = cls._sts_get_caller_identity()
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(json.dumps(result))
+ return result
+
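The caching added above is a small on-disk memoization keyed by a hash of the `aws configure list` output. A stripped-down sketch of the same pattern in isolation (the function name and file layout here are illustrative, not the real SkyPilot helpers):

    import hashlib
    import json
    import os
    from typing import Any, Callable


    def cached_json(cache_dir: str, key_bytes: bytes,
                    compute: Callable[[], Any]) -> Any:
        """Return a JSON value cached under md5(key_bytes); recompute on miss."""
        os.makedirs(cache_dir, exist_ok=True)
        digest = hashlib.md5(key_bytes).hexdigest()[:8]
        path = os.path.join(cache_dir, f'user-identity-{digest}.json')
        if os.path.exists(path):
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except json.JSONDecodeError:
                pass  # Corrupt cache entry: fall through and recompute.
        value = compute()
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(value, f)
        return value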
@classmethod
def get_active_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_active_user_identity()
@@ -813,6 +879,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
if os.path.exists(os.path.expanduser(f'~/.aws/{filename}'))
}
+ @functools.lru_cache(maxsize=1)
+ def can_credential_expire(self) -> bool:
+ identity_type = self._current_identity_type()
+ return identity_type is not None and identity_type.can_credential_expire(
+ )
+
def instance_type_exists(self, instance_type):
return service_catalog.instance_type_exists(instance_type, clouds='aws')
diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py
index 697451371f5..2fa08c5aa46 100644
--- a/sky/clouds/cloud.py
+++ b/sky/clouds/cloud.py
@@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
"""
raise NotImplementedError
+ def can_credential_expire(self) -> bool:
+ """Returns whether the cloud credential can expire."""
+ return False
+
@classmethod
def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
"""Check the image size from the cloud.
diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py
index 801768f7db0..b080f552b7c 100644
--- a/sky/clouds/cudo.py
+++ b/sky/clouds/cudo.py
@@ -43,8 +43,7 @@ class Cudo(clouds.Cloud):
f'{_INDENT_PREFIX} $ cudoctl init\n'
f'{_INDENT_PREFIX}For more info: '
# pylint: disable=line-too-long
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html'
- )
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html')
_PROJECT_HINT = (
'Create a project and then set it as the default project,:\n'
@@ -52,8 +51,7 @@ class Cudo(clouds.Cloud):
f'{_INDENT_PREFIX} $ cudoctl init\n'
f'{_INDENT_PREFIX}For more info: '
# pylint: disable=line-too-long
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html'
- )
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html')
_CLOUD_UNSUPPORTED_FEATURES = {
clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.',
diff --git a/sky/clouds/do.py b/sky/clouds/do.py
new file mode 100644
index 00000000000..cbad7d488a8
--- /dev/null
+++ b/sky/clouds/do.py
@@ -0,0 +1,304 @@
+""" Digital Ocean Cloud. """
+
+import json
+import typing
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+from sky import clouds
+from sky.adaptors import do
+from sky.clouds import service_catalog
+from sky.provision.do import utils as do_utils
+from sky.utils import registry
+from sky.utils import resources_utils
+
+if typing.TYPE_CHECKING:
+ from sky import resources as resources_lib
+
+_CREDENTIAL_FILE = 'config.yaml'
+
+
+@registry.CLOUD_REGISTRY.register(aliases=['digitalocean'])
+class DO(clouds.Cloud):
+ """Digital Ocean Cloud"""
+
+ _REPR = 'DO'
+ _CLOUD_UNSUPPORTED_FEATURES = {
+ clouds.CloudImplementationFeatures.CLONE_DISK_FROM_CLUSTER:
+ 'Migrating '
+ f'disk is not supported in {_REPR}.',
+ clouds.CloudImplementationFeatures.SPOT_INSTANCE:
+ 'Spot instances are '
+ f'not supported in {_REPR}.',
+ clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER:
+ 'Custom disk tiers'
+            f' are not supported in {_REPR}.',
+ }
+    # The DO maximum node name length is defined as <= 255:
+    # https://docs.digitalocean.com/reference/api/api-reference/#operation/droplets_create
+    # We use 255 - 8 = 247 characters, since our provisioner appends an
+    # additional `-worker` suffix.
+ _MAX_CLUSTER_NAME_LEN_LIMIT = 247
+ _regions: List[clouds.Region] = []
+
+ # Using the latest SkyPilot provisioner API to provision and check status.
+ PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
+ STATUS_VERSION = clouds.StatusVersion.SKYPILOT
+
+ @classmethod
+ def _unsupported_features_for_resources(
+ cls, resources: 'resources_lib.Resources'
+ ) -> Dict[clouds.CloudImplementationFeatures, str]:
+ """The features not supported based on the resources provided.
+
+ This method is used by check_features_are_supported() to check if the
+ cloud implementation supports all the requested features.
+
+ Returns:
+ A dict of {feature: reason} for the features not supported by the
+ cloud implementation.
+ """
+ del resources # unused
+ return cls._CLOUD_UNSUPPORTED_FEATURES
+
+ @classmethod
+ def _max_cluster_name_length(cls) -> Optional[int]:
+ return cls._MAX_CLUSTER_NAME_LEN_LIMIT
+
+ @classmethod
+ def regions_with_offering(
+ cls,
+ instance_type: str,
+ accelerators: Optional[Dict[str, int]],
+ use_spot: bool,
+ region: Optional[str],
+ zone: Optional[str],
+ ) -> List[clouds.Region]:
+ assert zone is None, 'DO does not support zones.'
+ del accelerators, zone # unused
+ if use_spot:
+ return []
+ regions = service_catalog.get_region_zones_for_instance_type(
+ instance_type, use_spot, 'DO')
+ if region is not None:
+ regions = [r for r in regions if r.name == region]
+ return regions
+
+ @classmethod
+ def get_vcpus_mem_from_instance_type(
+ cls,
+ instance_type: str,
+ ) -> Tuple[Optional[float], Optional[float]]:
+ return service_catalog.get_vcpus_mem_from_instance_type(instance_type,
+ clouds='DO')
+
+ @classmethod
+ def zones_provision_loop(
+ cls,
+ *,
+ region: str,
+ num_nodes: int,
+ instance_type: str,
+ accelerators: Optional[Dict[str, int]] = None,
+ use_spot: bool = False,
+ ) -> Iterator[None]:
+ del num_nodes # unused
+ regions = cls.regions_with_offering(instance_type,
+ accelerators,
+ use_spot,
+ region=region,
+ zone=None)
+ for r in regions:
+ assert r.zones is None, r
+ yield r.zones
+
+ def instance_type_to_hourly_cost(
+ self,
+ instance_type: str,
+ use_spot: bool,
+ region: Optional[str] = None,
+ zone: Optional[str] = None,
+ ) -> float:
+ return service_catalog.get_hourly_cost(
+ instance_type,
+ use_spot=use_spot,
+ region=region,
+ zone=zone,
+ clouds='DO',
+ )
+
+ def accelerators_to_hourly_cost(
+ self,
+ accelerators: Dict[str, int],
+ use_spot: bool,
+ region: Optional[str] = None,
+ zone: Optional[str] = None,
+ ) -> float:
+ """Returns the hourly cost of the accelerators, in dollars/hour."""
+        # The accelerator price is included in the instance price.
+ del accelerators, use_spot, region, zone # unused
+ return 0.0
+
+ def get_egress_cost(self, num_gigabytes: float) -> float:
+ return 0.0
+
+ def __repr__(self):
+ return self._REPR
+
+ @classmethod
+ def get_default_instance_type(
+ cls,
+ cpus: Optional[str] = None,
+ memory: Optional[str] = None,
+ disk_tier: Optional[resources_utils.DiskTier] = None,
+ ) -> Optional[str]:
+ """Returns the default instance type for DO."""
+ return service_catalog.get_default_instance_type(cpus=cpus,
+ memory=memory,
+ disk_tier=disk_tier,
+ clouds='DO')
+
+ @classmethod
+ def get_accelerators_from_instance_type(
+ cls, instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
+ return service_catalog.get_accelerators_from_instance_type(
+ instance_type, clouds='DO')
+
+ @classmethod
+ def get_zone_shell_cmd(cls) -> Optional[str]:
+ return None
+
+ def make_deploy_resources_variables(
+ self,
+ resources: 'resources_lib.Resources',
+ cluster_name: resources_utils.ClusterName,
+ region: 'clouds.Region',
+ zones: Optional[List['clouds.Zone']],
+ num_nodes: int,
+ dryrun: bool = False) -> Dict[str, Optional[str]]:
+ del zones, dryrun, cluster_name
+
+ r = resources
+ acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
+ if acc_dict is not None:
+ custom_resources = json.dumps(acc_dict, separators=(',', ':'))
+ else:
+ custom_resources = None
+ image_id = None
+ if (resources.image_id is not None and
+ resources.extract_docker_image() is None):
+ if None in resources.image_id:
+ image_id = resources.image_id[None]
+ else:
+ assert region.name in resources.image_id
+ image_id = resources.image_id[region.name]
+ return {
+ 'instance_type': resources.instance_type,
+ 'custom_resources': custom_resources,
+ 'region': region.name,
+ **({
+ 'image_id': image_id
+ } if image_id else {})
+ }
+
+ def _get_feasible_launchable_resources(
+ self, resources: 'resources_lib.Resources'
+ ) -> resources_utils.FeasibleResources:
+ """Returns a list of feasible resources for the given resources."""
+ if resources.use_spot:
+ # TODO: Add hints to all return values in this method to help
+ # users understand why the resources are not launchable.
+ return resources_utils.FeasibleResources([], [], None)
+ if resources.instance_type is not None:
+ assert resources.is_launchable(), resources
+ resources = resources.copy(accelerators=None)
+ return resources_utils.FeasibleResources([resources], [], None)
+
+ def _make(instance_list):
+ resource_list = []
+ for instance_type in instance_list:
+ r = resources.copy(
+ cloud=DO(),
+ instance_type=instance_type,
+ accelerators=None,
+ cpus=None,
+ )
+ resource_list.append(r)
+ return resource_list
+
+ # Currently, handle a filter on accelerators only.
+ accelerators = resources.accelerators
+ if accelerators is None:
+ # Return a default instance type
+ default_instance_type = DO.get_default_instance_type(
+ cpus=resources.cpus,
+ memory=resources.memory,
+ disk_tier=resources.disk_tier)
+ return resources_utils.FeasibleResources(
+ _make([default_instance_type]), [], None)
+
+ assert len(accelerators) == 1, resources
+ acc, acc_count = list(accelerators.items())[0]
+ (instance_list, fuzzy_candidate_list) = (
+ service_catalog.get_instance_type_for_accelerator(
+ acc,
+ acc_count,
+ use_spot=resources.use_spot,
+ cpus=resources.cpus,
+ memory=resources.memory,
+ region=resources.region,
+ zone=resources.zone,
+ clouds='DO',
+ ))
+ if instance_list is None:
+ return resources_utils.FeasibleResources([], fuzzy_candidate_list,
+ None)
+ return resources_utils.FeasibleResources(_make(instance_list),
+ fuzzy_candidate_list, None)
+
+ @classmethod
+ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
+ """Verify that the user has valid credentials for DO."""
+ try:
+            # Attempt to list droplets to verify the credentials.
+ do_utils.client().droplets.list()
+ except do.exceptions().HttpResponseError as err:
+ return False, str(err)
+ except do_utils.DigitalOceanError as err:
+ return False, str(err)
+
+ return True, None
+
+ def get_credential_file_mounts(self) -> Dict[str, str]:
+ try:
+ do_utils.client()
+ return {
+ f'~/.config/doctl/{_CREDENTIAL_FILE}': do_utils.CREDENTIALS_PATH
+ }
+ except do_utils.DigitalOceanError:
+ return {}
+
+ @classmethod
+ def get_current_user_identity(cls) -> Optional[List[str]]:
+        # NOTE: Used for very advanced SkyPilot functionality.
+        # This can be implemented later if desired.
+ return None
+
+ @classmethod
+ def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
+ del region
+ try:
+ response = do_utils.client().images.get(image_id=image_id)
+ return response['image']['size_gigabytes']
+ except do.exceptions().HttpResponseError as err:
+ raise do_utils.DigitalOceanError(
+ 'HTTP error while retrieving size of '
+                f'image_id {image_id}: {err.error.message}') from err
+ except KeyError as err:
+ raise do_utils.DigitalOceanError(
+ f'No image_id `{image_id}` found') from err
+
+ def instance_type_exists(self, instance_type: str) -> bool:
+ return service_catalog.instance_type_exists(instance_type, 'DO')
+
+ def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
+ return service_catalog.validate_region_zone(region, zone, clouds='DO')
diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py
index f988ec1d29e..c22bde48c52 100644
--- a/sky/clouds/gcp.py
+++ b/sky/clouds/gcp.py
@@ -133,6 +133,9 @@ class GCPIdentityType(enum.Enum):
SHARED_CREDENTIALS_FILE = ''
+ def can_credential_expire(self) -> bool:
+ return self == GCPIdentityType.SHARED_CREDENTIALS_FILE
+
@registry.CLOUD_REGISTRY.register
class GCP(clouds.Cloud):
@@ -168,7 +171,7 @@ class GCP(clouds.Cloud):
# ~/.config/gcloud/application_default_credentials.json.
f'{_INDENT_PREFIX} $ gcloud auth application-default login\n'
f'{_INDENT_PREFIX}For more info: '
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long
)
_APPLICATION_CREDENTIAL_HINT = (
'Run the following commands:\n'
@@ -176,7 +179,7 @@ class GCP(clouds.Cloud):
f'{_INDENT_PREFIX}Or set the environment variable GOOGLE_APPLICATION_CREDENTIALS '
'to the path of your service account key file.\n'
f'{_INDENT_PREFIX}For more info: '
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long
)
_SUPPORTED_DISK_TIERS = set(resources_utils.DiskTier)
@@ -831,13 +834,13 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
ret_permissions = request.execute().get('permissions', [])
diffs = set(gcp_minimal_permissions).difference(set(ret_permissions))
- if len(diffs) > 0:
+ if diffs:
identity_str = identity[0] if identity else None
return False, (
'The following permissions are not enabled for the current '
f'GCP identity ({identity_str}):\n '
f'{diffs}\n '
- 'For more details, visit: https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html') # pylint: disable=line-too-long
+ 'For more details, visit: https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/gcp.html') # pylint: disable=line-too-long
return True, None
def get_credential_file_mounts(self) -> Dict[str, str]:
@@ -864,6 +867,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
pass
return credentials
+ @functools.lru_cache(maxsize=1)
+ def can_credential_expire(self) -> bool:
+ identity_type = self._get_identity_type()
+ return identity_type is not None and identity_type.can_credential_expire(
+ )
+
@classmethod
def _get_identity_type(cls) -> Optional[GCPIdentityType]:
try:
diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py
index 9ffaa549312..aa2bb37bbe3 100644
--- a/sky/clouds/kubernetes.py
+++ b/sky/clouds/kubernetes.py
@@ -140,7 +140,7 @@ def _existing_allowed_contexts(cls) -> List[str]:
use the service account mounted in the pod.
"""
all_contexts = kubernetes_utils.get_all_kube_context_names()
- if len(all_contexts) == 0:
+ if not all_contexts:
return []
all_contexts = set(all_contexts)
@@ -396,7 +396,7 @@ def make_deploy_resources_variables(
tpu_requested = True
k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY
else:
- k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY
+ k8s_resource_key = kubernetes_utils.get_gpu_resource_key()
port_mode = network_utils.get_port_mode(None)
diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py
index 4ac423a455e..80247777a3a 100644
--- a/sky/clouds/oci.py
+++ b/sky/clouds/oci.py
@@ -233,6 +233,14 @@ def make_deploy_resources_variables(
listing_id = None
res_ver = None
+ os_type = None
+ if ':' in image_id:
+            # The OS type is provided in --image-id. This is the case where a
+            # custom image's OCID is given in the --image-id parameter:
+ # - ocid1.image...aaa:oraclelinux (os type is oraclelinux)
+ # - ocid1.image...aaa (OS not provided)
+ image_id, os_type = image_id.replace(' ', '').split(':')
+
cpus = resources.cpus
instance_type_arr = resources.instance_type.split(
oci_utils.oci_config.INSTANCE_TYPE_RES_SPERATOR)
@@ -298,15 +306,18 @@ def make_deploy_resources_variables(
cpus=None if cpus is None else float(cpus),
disk_tier=resources.disk_tier)
- image_str = self._get_image_str(image_id=resources.image_id,
- instance_type=resources.instance_type,
- region=region.name)
+ if os_type is None:
+ # OS type is not determined yet. So try to get it from vms.csv
+ image_str = self._get_image_str(
+ image_id=resources.image_id,
+ instance_type=resources.instance_type,
+ region=region.name)
- # pylint: disable=import-outside-toplevel
- from sky.clouds.service_catalog import oci_catalog
- os_type = oci_catalog.get_image_os_from_tag(tag=image_str,
- region=region.name)
- logger.debug(f'OS type for the image {image_str} is {os_type}')
+ # pylint: disable=import-outside-toplevel
+ from sky.clouds.service_catalog import oci_catalog
+ os_type = oci_catalog.get_image_os_from_tag(tag=image_str,
+ region=region.name)
+ logger.debug(f'OS type for the image {image_id} is {os_type}')
return {
'instance_type': instance_type,
@@ -391,7 +402,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
short_credential_help_str = (
'For more details, refer to: '
# pylint: disable=line-too-long
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#oracle-cloud-infrastructure-oci'
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html#oracle-cloud-infrastructure-oci'
)
credential_help_str = (
'To configure credentials, go to: '
diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py
index 85f1ed45bdb..cd8a32fafe8 100644
--- a/sky/clouds/paperspace.py
+++ b/sky/clouds/paperspace.py
@@ -259,7 +259,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
return False, (
'Failed to access Paperspace Cloud with credentials.\n '
'To configure credentials, follow the instructions at: '
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#paperspace\n '
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html#paperspace\n '
'Generate API key and create a json at `~/.paperspace/config.json` with \n '
' {"apiKey": "[YOUR API KEY]"}\n '
f'Reason: {str(e)}')
diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py
index 77de4fa9dfb..c3c7143462b 100644
--- a/sky/clouds/runpod.py
+++ b/sky/clouds/runpod.py
@@ -254,7 +254,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
' Credentials can be set up by running: \n'
f' $ pip install runpod \n'
f' $ runpod config\n'
- ' For more information, see https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#runpod' # pylint: disable=line-too-long
+ ' For more information, see https://docs.skypilot.co/en/latest/getting-started/installation.html#runpod' # pylint: disable=line-too-long
)
return True, None
diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py
index d62b886d000..196c0bc12b4 100644
--- a/sky/clouds/service_catalog/common.py
+++ b/sky/clouds/service_catalog/common.py
@@ -275,9 +275,10 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9)
candidate_loc = sorted(candidate_loc)
candidate_strs = ''
- if len(candidate_loc) > 0:
+ if candidate_loc:
candidate_strs = ', '.join(candidate_loc)
candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?'
+
return candidate_strs
def _get_all_supported_regions_str() -> str:
@@ -291,7 +292,7 @@ def _get_all_supported_regions_str() -> str:
filter_df = df
if region is not None:
filter_df = _filter_region_zone(filter_df, region, zone=None)
- if len(filter_df) == 0:
+ if filter_df.empty:
with ux_utils.print_exception_no_traceback():
error_msg = (f'Invalid region {region!r}')
candidate_strs = _get_candidate_str(
@@ -301,7 +302,7 @@ def _get_all_supported_regions_str() -> str:
faq_msg = (
'\nIf a region is not included in the following '
'list, please check the FAQ docs for how to fetch '
- 'its catalog info.\nhttps://skypilot.readthedocs.io'
+ 'its catalog info.\nhttps://docs.skypilot.co'
'/en/latest/reference/faq.html#advanced-how-to-'
'make-skypilot-use-all-global-regions')
error_msg += faq_msg + _get_all_supported_regions_str()
@@ -315,7 +316,7 @@ def _get_all_supported_regions_str() -> str:
if zone is not None:
maybe_region_df = filter_df
filter_df = filter_df[filter_df['AvailabilityZone'] == zone]
- if len(filter_df) == 0:
+ if filter_df.empty:
region_str = f' for region {region!r}' if region else ''
df = maybe_region_df if region else df
with ux_utils.print_exception_no_traceback():
@@ -383,7 +384,7 @@ def get_vcpus_mem_from_instance_type_impl(
instance_type: str,
) -> Tuple[Optional[float], Optional[float]]:
df = _get_instance_type(df, instance_type, None)
- if len(df) == 0:
+ if df.empty:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'No instance type {instance_type} found.')
assert len(set(df['vCPUs'])) == 1, ('Cannot determine the number of vCPUs '
@@ -489,7 +490,7 @@ def get_accelerators_from_instance_type_impl(
instance_type: str,
) -> Optional[Dict[str, Union[int, float]]]:
df = _get_instance_type(df, instance_type, None)
- if len(df) == 0:
+ if df.empty:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'No instance type {instance_type} found.')
row = df.iloc[0]
@@ -523,7 +524,7 @@ def get_instance_type_for_accelerator_impl(
result = df[(df['AcceleratorName'].str.fullmatch(acc_name, case=False)) &
(abs(df['AcceleratorCount'] - acc_count) <= 0.01)]
result = _filter_region_zone(result, region, zone)
- if len(result) == 0:
+ if result.empty:
fuzzy_result = df[
(df['AcceleratorName'].str.contains(acc_name, case=False)) &
(df['AcceleratorCount'] >= acc_count)]
@@ -532,7 +533,7 @@ def get_instance_type_for_accelerator_impl(
fuzzy_result = fuzzy_result[['AcceleratorName',
'AcceleratorCount']].drop_duplicates()
fuzzy_candidate_list = []
- if len(fuzzy_result) > 0:
+ if not fuzzy_result.empty:
for _, row in fuzzy_result.iterrows():
acc_cnt = float(row['AcceleratorCount'])
acc_count_display = (int(acc_cnt) if acc_cnt.is_integer() else
@@ -544,7 +545,7 @@ def get_instance_type_for_accelerator_impl(
result = _filter_with_cpus(result, cpus)
result = _filter_with_mem(result, memory)
result = _filter_region_zone(result, region, zone)
- if len(result) == 0:
+ if result.empty:
return ([], [])
# Current strategy: choose the cheapest instance
@@ -685,7 +686,7 @@ def get_image_id_from_tag_impl(df: 'pd.DataFrame', tag: str,
df = _filter_region_zone(df, region, zone=None)
assert len(df) <= 1, ('Multiple images found for tag '
f'{tag} in region {region}')
- if len(df) == 0:
+ if df.empty:
return None
image_id = df['ImageId'].iloc[0]
if pd.isna(image_id):
@@ -699,4 +700,4 @@ def is_image_tag_valid_impl(df: 'pd.DataFrame', tag: str,
df = df[df['Tag'] == tag]
df = _filter_region_zone(df, region, zone=None)
df = df.dropna(subset=['ImageId'])
- return len(df) > 0
+ return not df.empty
diff --git a/sky/clouds/service_catalog/constants.py b/sky/clouds/service_catalog/constants.py
index a5c96d49f11..36be51a48d0 100644
--- a/sky/clouds/service_catalog/constants.py
+++ b/sky/clouds/service_catalog/constants.py
@@ -1,7 +1,7 @@
"""Constants used for service catalog."""
HOSTED_CATALOG_DIR_URL = 'https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/common-acc/catalogs' # pylint: disable=line-too-long
-CATALOG_SCHEMA_VERSION = 'v5'
+CATALOG_SCHEMA_VERSION = 'v6'
CATALOG_DIR = '~/.sky/catalogs'
ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
'kubernetes', 'runpod', 'vsphere', 'cudo', 'fluidstack',
- 'paperspace')
+ 'paperspace', 'do')
diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py
index 4aef41f9c90..00768d5c6bb 100644
--- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py
+++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py
@@ -134,7 +134,7 @@ def get_pricing_df(region: Optional[str] = None) -> 'pd.DataFrame':
content_str = r.content.decode('ascii')
content = json.loads(content_str)
items = content.get('Items', [])
- if len(items) == 0:
+ if not items:
break
all_items += items
url = content.get('NextPageLink')
diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py
index e0ec7f66042..570bc773d2e 100644
--- a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py
+++ b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py
@@ -476,9 +476,6 @@ def _get_gpus_for_zone(zone: str) -> 'pd.DataFrame':
gpu_name = gpu_name.upper()
if 'H100-80GB' in gpu_name:
gpu_name = 'H100'
- if count != 8:
- # H100 only has 8 cards.
- continue
if 'H100-MEGA-80GB' in gpu_name:
gpu_name = 'H100-MEGA'
if count != 8:
diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py b/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py
index 216e8ed9b4f..c08a56955a0 100644
--- a/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py
+++ b/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py
@@ -534,7 +534,7 @@ def initialize_images_csv(csv_saving_path: str, vc_object,
gpu_name = tag_name.split('-')[1]
if gpu_name not in gpu_tags:
gpu_tags.append(gpu_name)
- if len(gpu_tags) > 0:
+ if gpu_tags:
gpu_tags_str = str(gpu_tags).replace('\'', '\"')
f.write(f'{item.id},{vcenter_name},{item_cpu},{item_memory}'
f',,,\'{gpu_tags_str}\'\n')
diff --git a/sky/clouds/service_catalog/do_catalog.py b/sky/clouds/service_catalog/do_catalog.py
new file mode 100644
index 00000000000..40a53aa6bc4
--- /dev/null
+++ b/sky/clouds/service_catalog/do_catalog.py
@@ -0,0 +1,111 @@
+"""Digital ocean service catalog.
+
+This module loads the service catalog file and can be used to
+query instance types and pricing information for digital ocean.
+"""
+
+import typing
+from typing import Dict, List, Optional, Tuple, Union
+
+from sky.clouds.service_catalog import common
+from sky.utils import ux_utils
+
+if typing.TYPE_CHECKING:
+ from sky.clouds import cloud
+
+_df = common.read_catalog('do/vms.csv')
+
+
+def instance_type_exists(instance_type: str) -> bool:
+ return common.instance_type_exists_impl(_df, instance_type)
+
+
+def validate_region_zone(
+ region: Optional[str],
+ zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
+ if zone is not None:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError('DO does not support zones.')
+ return common.validate_region_zone_impl('DO', _df, region, zone)
+
+
+def get_hourly_cost(
+ instance_type: str,
+ use_spot: bool = False,
+ region: Optional[str] = None,
+ zone: Optional[str] = None,
+) -> float:
+ """Returns the cost, or the cheapest cost among all zones for spot."""
+ if zone is not None:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError('DO does not support zones.')
+ return common.get_hourly_cost_impl(_df, instance_type, use_spot, region,
+ zone)
+
+
+def get_vcpus_mem_from_instance_type(
+ instance_type: str,) -> Tuple[Optional[float], Optional[float]]:
+ return common.get_vcpus_mem_from_instance_type_impl(_df, instance_type)
+
+
+def get_default_instance_type(
+ cpus: Optional[str] = None,
+ memory: Optional[str] = None,
+ disk_tier: Optional[str] = None,
+) -> Optional[str]:
+ # NOTE: After expanding catalog to multiple entries, you may
+ # want to specify a default instance type or family.
+ del disk_tier # unused
+ return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory)
+
+
+def get_accelerators_from_instance_type(
+ instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
+ return common.get_accelerators_from_instance_type_impl(_df, instance_type)
+
+
+def get_instance_type_for_accelerator(
+ acc_name: str,
+ acc_count: int,
+ cpus: Optional[str] = None,
+ memory: Optional[str] = None,
+ use_spot: bool = False,
+ region: Optional[str] = None,
+ zone: Optional[str] = None,
+) -> Tuple[Optional[List[str]], List[str]]:
+ """Returns a list of instance types that have the given accelerator."""
+ if zone is not None:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError('DO does not support zones.')
+ return common.get_instance_type_for_accelerator_impl(
+ df=_df,
+ acc_name=acc_name,
+ acc_count=acc_count,
+ cpus=cpus,
+ memory=memory,
+ use_spot=use_spot,
+ region=region,
+ zone=zone,
+ )
+
+
+def get_region_zones_for_instance_type(instance_type: str,
+ use_spot: bool) -> List['cloud.Region']:
+ df = _df[_df['InstanceType'] == instance_type]
+ return common.get_region_zones(df, use_spot)
+
+
+def list_accelerators(
+ gpus_only: bool,
+ name_filter: Optional[str],
+ region_filter: Optional[str],
+ quantity_filter: Optional[int],
+ case_sensitive: bool = True,
+ all_regions: bool = False,
+ require_price: bool = True,
+) -> Dict[str, List[common.InstanceTypeInfo]]:
+ """Returns all instance types in DO offering GPUs."""
+ del require_price # unused
+ return common.list_accelerators_impl('DO', _df, gpus_only, name_filter,
+ region_filter, quantity_filter,
+ case_sensitive, all_regions)
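
Illustrative usage of the new DO catalog module above, assuming a do/vms.csv
catalog has been fetched locally; the instance type and region strings below
are placeholders, not confirmed DO offerings:

    from sky.clouds.service_catalog import do_catalog

    # Placeholder instance type / region; real values come from do/vms.csv.
    if do_catalog.instance_type_exists('gpu-h100x1-80gb'):
        price = do_catalog.get_hourly_cost('gpu-h100x1-80gb',
                                           use_spot=False,
                                           region='nyc3')
        print(f'On-demand price: ${price:.2f}/hr')
    accs = do_catalog.list_accelerators(gpus_only=True,
                                        name_filter=None,
                                        region_filter=None,
                                        quantity_filter=None)
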
diff --git a/sky/clouds/service_catalog/gcp_catalog.py b/sky/clouds/service_catalog/gcp_catalog.py
index c9e15f602dc..a83e00d8196 100644
--- a/sky/clouds/service_catalog/gcp_catalog.py
+++ b/sky/clouds/service_catalog/gcp_catalog.py
@@ -97,6 +97,9 @@
8: ['g2-standard-96'],
},
'H100': {
+ 1: ['a3-highgpu-1g'],
+ 2: ['a3-highgpu-2g'],
+ 4: ['a3-highgpu-4g'],
8: ['a3-highgpu-8g'],
},
'H100-MEGA': {
@@ -289,7 +292,9 @@ def get_instance_type_for_accelerator(
if acc_name in _ACC_INSTANCE_TYPE_DICTS:
df = _df[_df['InstanceType'].notna()]
- instance_types = _ACC_INSTANCE_TYPE_DICTS[acc_name][acc_count]
+ instance_types = _ACC_INSTANCE_TYPE_DICTS[acc_name].get(acc_count, None)
+ if instance_types is None:
+ return None, []
df = df[df['InstanceType'].isin(instance_types)]
# Check the cpus and memory specified by the user.
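
A minimal sketch of the lookup behavior this hunk changes; the dict below
abbreviates _ACC_INSTANCE_TYPE_DICTS['H100'] and is illustrative only:

    h100_map = {1: ['a3-highgpu-1g'], 2: ['a3-highgpu-2g'],
                4: ['a3-highgpu-4g'], 8: ['a3-highgpu-8g']}
    # Before: h100_map[3] raised KeyError for an unsupported count.
    # After: .get() returns None and the caller returns (None, []).
    print(h100_map.get(3, None))  # None
    print(h100_map.get(4, None))  # ['a3-highgpu-4g']
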
diff --git a/sky/clouds/utils/oci_utils.py b/sky/clouds/utils/oci_utils.py
index fbd2888c708..f68e090e7ab 100644
--- a/sky/clouds/utils/oci_utils.py
+++ b/sky/clouds/utils/oci_utils.py
@@ -6,6 +6,12 @@
configuration.
- Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add the constant
SERVICE_PORT_RULE_TAG
+ - Hysun He (hysun.he@oracle.com) @ Jan.01, 2025: Set the default image
+ from ubuntu 20.04 to ubuntu 22.04, including:
+ - GPU: skypilot:gpu-ubuntu-2004 -> skypilot:gpu-ubuntu-2204
+ - CPU: skypilot:cpu-ubuntu-2004 -> skypilot:cpu-ubuntu-2204
+ - Hysun He (hysun.he@oracle.com) @ Jan.01, 2025: Support reusing an
+   existing VCN for SkyServe.
"""
import os
@@ -105,8 +111,15 @@ def get_compartment(cls, region):
('oci', region, 'compartment_ocid'), default_compartment_ocid)
return compartment
+ @classmethod
+ def get_vcn_ocid(cls, region):
+ # Will reuse the regional VCN if specified.
+ vcn = skypilot_config.get_nested(('oci', region, 'vcn_ocid'), None)
+ return vcn
+
@classmethod
def get_vcn_subnet(cls, region):
+ # Will reuse the subnet if specified.
vcn = skypilot_config.get_nested(('oci', region, 'vcn_subnet'), None)
return vcn
@@ -117,7 +130,7 @@ def get_default_gpu_image_tag(cls) -> str:
# the sky's user-config file (if not specified, use the hardcode one at
# last)
return skypilot_config.get_nested(('oci', 'default', 'image_tag_gpu'),
- 'skypilot:gpu-ubuntu-2004')
+ 'skypilot:gpu-ubuntu-2204')
@classmethod
def get_default_image_tag(cls) -> str:
@@ -125,7 +138,7 @@ def get_default_image_tag(cls) -> str:
# set the default image tag in the sky's user-config file. (if not
# specified, use the hardcode one at last)
return skypilot_config.get_nested(
- ('oci', 'default', 'image_tag_general'), 'skypilot:cpu-ubuntu-2004')
+ ('oci', 'default', 'image_tag_general'), 'skypilot:cpu-ubuntu-2204')
@classmethod
def get_sky_user_config_file(cls) -> str:
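
A hedged sketch of how the new VCN reuse hook is consumed; the region and the
printed message are placeholders, and the nested config key path mirrors
get_vcn_ocid() above:

    from sky import skypilot_config

    # Equivalent to what get_vcn_ocid('us-ashburn-1') reads: the OCID of an
    # existing VCN to reuse, or None to let SkyPilot create one.
    vcn = skypilot_config.get_nested(('oci', 'us-ashburn-1', 'vcn_ocid'), None)
    if vcn is not None:
        print(f'Reusing existing VCN: {vcn}')
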
diff --git a/sky/clouds/utils/scp_utils.py b/sky/clouds/utils/scp_utils.py
index 3e91e22e6d9..4efc79313c5 100644
--- a/sky/clouds/utils/scp_utils.py
+++ b/sky/clouds/utils/scp_utils.py
@@ -65,7 +65,7 @@ def __setitem__(self, instance_id: str, value: Optional[Dict[str,
if value is None:
if instance_id in metadata:
metadata.pop(instance_id) # del entry
- if len(metadata) == 0:
+ if not metadata:
if os.path.exists(self.path):
os.remove(self.path)
return
@@ -84,7 +84,7 @@ def refresh(self, instance_ids: List[str]) -> None:
for instance_id in list(metadata.keys()):
if instance_id not in instance_ids:
del metadata[instance_id]
- if len(metadata) == 0:
+ if not metadata:
os.remove(self.path)
return
with open(self.path, 'w', encoding='utf-8') as f:
@@ -410,7 +410,7 @@ def list_security_groups(self, vpc_id=None, sg_name=None):
parameter.append('vpcId=' + vpc_id)
if sg_name is not None:
parameter.append('securityGroupName=' + sg_name)
- if len(parameter) > 0:
+ if parameter:
url = url + '?' + '&'.join(parameter)
return self._get(url)
diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py
index 6cb6c0d93a8..243791cf578 100644
--- a/sky/clouds/vsphere.py
+++ b/sky/clouds/vsphere.py
@@ -267,7 +267,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
'Run the following commands:'
f'\n{cls._INDENT_PREFIX} $ pip install skypilot[vSphere]'
f'\n{cls._INDENT_PREFIX}Credentials may also need to be set. '
- 'For more details. See https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#vmware-vsphere' # pylint: disable=line-too-long
+ 'For more details. See https://docs.skypilot.co/en/latest/getting-started/installation.html#vmware-vsphere' # pylint: disable=line-too-long
f'{common_utils.format_exception(e, use_bracket=True)}')
required_keys = ['name', 'username', 'password', 'clusters']
diff --git a/sky/core.py b/sky/core.py
index 875f5c959bf..31789e4b770 100644
--- a/sky/core.py
+++ b/sky/core.py
@@ -842,7 +842,7 @@ def download_logs(
backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend), backend
- if job_ids is not None and len(job_ids) == 0:
+ if job_ids is not None and not job_ids:
return {}
usage_lib.record_cluster_name_for_current_operation(cluster_name)
@@ -891,7 +891,7 @@ def job_status(cluster_name: str,
f'of type {backend.__class__.__name__!r}.')
assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
- if job_ids is not None and len(job_ids) == 0:
+ if job_ids is not None and not job_ids:
return {}
sky_logging.print(f'{colorama.Fore.YELLOW}'
@@ -1043,7 +1043,7 @@ def local_up(gpus: bool = False) -> None:
run_command = shlex.split(run_command)
# Setup logging paths
- run_timestamp = backend_utils.get_run_timestamp()
+ run_timestamp = sky_logging.get_run_timestamp()
log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp,
'local_up.log')
tail_cmd = 'tail -n100 -f ' + log_path
@@ -1152,7 +1152,7 @@ def local_down() -> None:
run_command = shlex.split(down_script_path)
# Setup logging paths
- run_timestamp = backend_utils.get_run_timestamp()
+ run_timestamp = sky_logging.get_run_timestamp()
log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp,
'local_down.log')
tail_cmd = 'tail -n100 -f ' + log_path
diff --git a/sky/data/data_transfer.py b/sky/data/data_transfer.py
index 374871031cb..3ccc6f8fc0e 100644
--- a/sky/data/data_transfer.py
+++ b/sky/data/data_transfer.py
@@ -200,3 +200,40 @@ def _add_bucket_iam_member(bucket_name: str, role: str, member: str) -> None:
bucket.set_iam_policy(policy)
logger.debug(f'Added {member} with role {role} to {bucket_name}.')
+
+
+def s3_to_oci(s3_bucket_name: str, oci_bucket_name: str) -> None:
+ """Creates a one-time transfer from Amazon S3 to OCI Object Storage.
+ Args:
+ s3_bucket_name: str; Name of the Amazon S3 Bucket
+ oci_bucket_name: str; Name of the OCI Bucket
+ """
+ # TODO(HysunHe): Implement sync with other clouds (s3, gs)
+ raise NotImplementedError('Moving data directly from S3 to OCI bucket '
+ 'is currently not supported. Please specify '
+ 'a local source for the storage object.')
+
+
+def gcs_to_oci(gs_bucket_name: str, oci_bucket_name: str) -> None:
+ """Creates a one-time transfer from Google Cloud Storage to
+ OCI Object Storage.
+ Args:
+ gs_bucket_name: str; Name of the Google Cloud Storage Bucket
+ oci_bucket_name: str; Name of the OCI Bucket
+ """
+ # TODO(HysunHe): Implement sync with other clouds (s3, gs)
+ raise NotImplementedError('Moving data directly from GCS to OCI bucket '
+ 'is currently not supported. Please specify '
+ 'a local source for the storage object.')
+
+
+def r2_to_oci(r2_bucket_name: str, oci_bucket_name: str) -> None:
+ """Creates a one-time transfer from Cloudflare R2 to OCI Bucket.
+ Args:
+ r2_bucket_name: str; Name of the Cloudflare R2 Bucket
+ oci_bucket_name: str; Name of the OCI Bucket
+ """
+ raise NotImplementedError(
+ 'Moving data directly from Cloudflare R2 to OCI '
+ 'bucket is currently not supported. Please specify '
+ 'a local source for the storage object.')
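
The three OCI transfer helpers above are intentionally stubs for now; a short
illustration of the expected caller behavior (bucket names are placeholders):

    from sky.data import data_transfer

    try:
        data_transfer.s3_to_oci('my-s3-bucket', 'my-oci-bucket')
    except NotImplementedError:
        # Direct cloud-to-OCI transfer is not wired up yet, so callers are
        # expected to fall back to a local source for the storage object.
        pass
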
diff --git a/sky/data/data_utils.py b/sky/data/data_utils.py
index 0c8fd64ddea..e8dcaa83017 100644
--- a/sky/data/data_utils.py
+++ b/sky/data/data_utils.py
@@ -20,6 +20,7 @@
from sky.adaptors import cloudflare
from sky.adaptors import gcp
from sky.adaptors import ibm
+from sky.skylet import log_lib
from sky.utils import common_utils
from sky.utils import ux_utils
@@ -430,6 +431,7 @@ def _group_files_by_dir(
def parallel_upload(source_path_list: List[str],
filesync_command_generator: Callable[[str, List[str]], str],
dirsync_command_generator: Callable[[str, str], str],
+ log_path: str,
bucket_name: str,
access_denied_message: str,
create_dirs: bool = False,
@@ -445,6 +447,7 @@ def parallel_upload(source_path_list: List[str],
for a list of files belonging to the same dir.
dirsync_command_generator: Callable that generates rsync command
for a directory.
+ log_path: Path to the log file.
access_denied_message: Message to intercept from the underlying
upload utility when permissions are insufficient. Used in
exception handling.
@@ -477,7 +480,7 @@ def parallel_upload(source_path_list: List[str],
p.starmap(
run_upload_cli,
zip(commands, [access_denied_message] * len(commands),
- [bucket_name] * len(commands)))
+ [bucket_name] * len(commands), [log_path] * len(commands)))
def get_gsutil_command() -> Tuple[str, str]:
@@ -518,37 +521,31 @@ def get_gsutil_command() -> Tuple[str, str]:
return gsutil_alias, alias_gen
-def run_upload_cli(command: str, access_denied_message: str, bucket_name: str):
- # TODO(zhwu): Use log_lib.run_with_log() and redirect the output
- # to a log file.
- with subprocess.Popen(command,
- stderr=subprocess.PIPE,
- stdout=subprocess.DEVNULL,
- shell=True) as process:
- stderr = []
- assert process.stderr is not None # for mypy
- while True:
- line = process.stderr.readline()
- if not line:
- break
- str_line = line.decode('utf-8')
- stderr.append(str_line)
- if access_denied_message in str_line:
- process.kill()
- with ux_utils.print_exception_no_traceback():
- raise PermissionError(
- 'Failed to upload files to '
- 'the remote bucket. The bucket does not have '
- 'write permissions. It is possible that '
- 'the bucket is public.')
- returncode = process.wait()
- if returncode != 0:
- stderr_str = '\n'.join(stderr)
- with ux_utils.print_exception_no_traceback():
- logger.error(stderr_str)
- raise exceptions.StorageUploadError(
- f'Upload to bucket failed for store {bucket_name}. '
- 'Please check the logs.')
+def run_upload_cli(command: str, access_denied_message: str, bucket_name: str,
+ log_path: str):
+ returncode, stdout, stderr = log_lib.run_with_log(
+ command,
+ log_path,
+ shell=True,
+ require_outputs=True,
+        # We need to use bash as some of the cloud commands use bash syntax,
+        # such as [[ ... ]].
+ executable='/bin/bash')
+ if access_denied_message in stderr:
+ with ux_utils.print_exception_no_traceback():
+ raise PermissionError('Failed to upload files to '
+ 'the remote bucket. The bucket does not have '
+ 'write permissions. It is possible that '
+ 'the bucket is public.')
+ if returncode != 0:
+ with ux_utils.print_exception_no_traceback():
+ logger.error(stderr)
+ raise exceptions.StorageUploadError(
+ f'Upload to bucket failed for store {bucket_name}. '
+ f'Please check the logs: {log_path}')
+ if not stdout:
+ logger.debug('No file uploaded. This could be due to an error or '
+ 'because all files already exist on the cloud.')
def get_cos_regions() -> List[str]:
@@ -737,3 +734,14 @@ def _remove_bucket_profile_rclone(bucket_name: str,
lines_to_keep.append(line)
return lines_to_keep
+
+
+def split_oci_path(oci_path: str) -> Tuple[str, str]:
+ """Splits OCI Path into Bucket name and Relative Path to Bucket
+ Args:
+ oci_path: str; OCI Path, e.g. oci://imagenet/train/
+ """
+ path_parts = oci_path.replace('oci://', '').split('/')
+ bucket = path_parts.pop(0)
+ key = '/'.join(path_parts)
+ return bucket, key
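
A quick example of the helper above:

    bucket, key = split_oci_path('oci://imagenet/train/')
    # bucket == 'imagenet', key == 'train/'
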
diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py
index 090471dd06e..f00b4f3fc31 100644
--- a/sky/data/mounting_utils.py
+++ b/sky/data/mounting_utils.py
@@ -19,6 +19,7 @@
_BLOBFUSE_CACHE_ROOT_DIR = '~/.sky/blobfuse2_cache'
_BLOBFUSE_CACHE_DIR = ('~/.sky/blobfuse2_cache/'
'{storage_account_name}_{container_name}')
+RCLONE_VERSION = 'v1.68.2'
def get_s3_mount_install_cmd() -> str:
@@ -30,12 +31,19 @@ def get_s3_mount_install_cmd() -> str:
return install_cmd
-def get_s3_mount_cmd(bucket_name: str, mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_s3_mount_cmd(bucket_name: str,
+ mount_path: str,
+ _bucket_sub_path: Optional[str] = None) -> str:
"""Returns a command to mount an S3 bucket using goofys."""
+ if _bucket_sub_path is None:
+ _bucket_sub_path = ''
+ else:
+ _bucket_sub_path = f':{_bucket_sub_path}'
mount_cmd = ('goofys -o allow_other '
f'--stat-cache-ttl {_STAT_CACHE_TTL} '
f'--type-cache-ttl {_TYPE_CACHE_TTL} '
- f'{bucket_name} {mount_path}')
+ f'{bucket_name}{_bucket_sub_path} {mount_path}')
return mount_cmd
@@ -49,15 +57,20 @@ def get_gcs_mount_install_cmd() -> str:
return install_cmd
-def get_gcs_mount_cmd(bucket_name: str, mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_gcs_mount_cmd(bucket_name: str,
+ mount_path: str,
+ _bucket_sub_path: Optional[str] = None) -> str:
"""Returns a command to mount a GCS bucket using gcsfuse."""
-
+    bucket_sub_path_arg = (f'--only-dir {_bucket_sub_path} '
+                           if _bucket_sub_path else '')
mount_cmd = ('gcsfuse -o allow_other '
'--implicit-dirs '
f'--stat-cache-capacity {_STAT_CACHE_CAPACITY} '
f'--stat-cache-ttl {_STAT_CACHE_TTL} '
f'--type-cache-ttl {_TYPE_CACHE_TTL} '
f'--rename-dir-limit {_RENAME_DIR_LIMIT} '
+ f'{bucket_sub_path_arg}'
f'{bucket_name} {mount_path}')
return mount_cmd
@@ -78,10 +91,12 @@ def get_az_mount_install_cmd() -> str:
return install_cmd
+# pylint: disable=invalid-name
def get_az_mount_cmd(container_name: str,
storage_account_name: str,
mount_path: str,
- storage_account_key: Optional[str] = None) -> str:
+ storage_account_key: Optional[str] = None,
+ _bucket_sub_path: Optional[str] = None) -> str:
"""Returns a command to mount an AZ Container using blobfuse2.
Args:
@@ -90,6 +105,7 @@ def get_az_mount_cmd(container_name: str,
belongs to.
mount_path: Path where the container will be mounting.
storage_account_key: Access key for the given storage account.
+ _bucket_sub_path: Sub path of the mounting container.
Returns:
str: Command used to mount AZ container with blobfuse2.
@@ -107,25 +123,38 @@ def get_az_mount_cmd(container_name: str,
cache_path = _BLOBFUSE_CACHE_DIR.format(
storage_account_name=storage_account_name,
container_name=container_name)
+ if _bucket_sub_path is None:
+ bucket_sub_path_arg = ''
+ else:
+ bucket_sub_path_arg = f'--subdirectory={_bucket_sub_path}/ '
mount_cmd = (f'AZURE_STORAGE_ACCOUNT={storage_account_name} '
f'{key_env_var} '
f'blobfuse2 {mount_path} --allow-other --no-symlinks '
'-o umask=022 -o default_permissions '
f'--tmp-path {cache_path} '
+ f'{bucket_sub_path_arg}'
f'--container-name {container_name}')
return mount_cmd
-def get_r2_mount_cmd(r2_credentials_path: str, r2_profile_name: str,
- endpoint_url: str, bucket_name: str,
- mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_r2_mount_cmd(r2_credentials_path: str,
+ r2_profile_name: str,
+ endpoint_url: str,
+ bucket_name: str,
+ mount_path: str,
+ _bucket_sub_path: Optional[str] = None) -> str:
"""Returns a command to install R2 mount utility goofys."""
+ if _bucket_sub_path is None:
+ _bucket_sub_path = ''
+ else:
+ _bucket_sub_path = f':{_bucket_sub_path}'
mount_cmd = (f'AWS_SHARED_CREDENTIALS_FILE={r2_credentials_path} '
f'AWS_PROFILE={r2_profile_name} goofys -o allow_other '
f'--stat-cache-ttl {_STAT_CACHE_TTL} '
f'--type-cache-ttl {_TYPE_CACHE_TTL} '
f'--endpoint {endpoint_url} '
- f'{bucket_name} {mount_path}')
+ f'{bucket_name}{_bucket_sub_path} {mount_path}')
return mount_cmd
@@ -137,9 +166,12 @@ def get_cos_mount_install_cmd() -> str:
return install_cmd
-def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str,
- bucket_rclone_profile: str, bucket_name: str,
- mount_path: str) -> str:
+def get_cos_mount_cmd(rclone_config_data: str,
+ rclone_config_path: str,
+ bucket_rclone_profile: str,
+ bucket_name: str,
+ mount_path: str,
+ _bucket_sub_path: Optional[str] = None) -> str:
"""Returns a command to mount an IBM COS bucket using rclone."""
# creates a fusermount soft link on older (<22) Ubuntu systems for
# rclone's mount utility.
@@ -151,14 +183,60 @@ def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str,
'mkdir -p ~/.config/rclone/ && '
f'echo "{rclone_config_data}" >> '
f'{rclone_config_path}')
+    if _bucket_sub_path is None:
+        sub_path_arg = bucket_name
+    else:
+        sub_path_arg = f'{bucket_name}/{_bucket_sub_path}'
# --daemon will keep the mounting process running in the background.
mount_cmd = (f'{configure_rclone_profile} && '
'rclone mount '
- f'{bucket_rclone_profile}:{bucket_name} {mount_path} '
+ f'{bucket_rclone_profile}:{sub_path_arg} {mount_path} '
'--daemon')
return mount_cmd
+def get_rclone_install_cmd() -> str:
+    """Rclone installation for both apt-get and rpm based systems.
+    This command is shared across clouds.
+ """
+ # pylint: disable=line-too-long
+ install_cmd = (
+ f'(which dpkg > /dev/null 2>&1 && (which rclone > /dev/null || (cd ~ > /dev/null'
+ f' && curl -O https://downloads.rclone.org/{RCLONE_VERSION}/rclone-{RCLONE_VERSION}-linux-amd64.deb'
+ f' && sudo dpkg -i rclone-{RCLONE_VERSION}-linux-amd64.deb'
+ f' && rm -f rclone-{RCLONE_VERSION}-linux-amd64.deb)))'
+ f' || (which rclone > /dev/null || (cd ~ > /dev/null'
+ f' && curl -O https://downloads.rclone.org/{RCLONE_VERSION}/rclone-{RCLONE_VERSION}-linux-amd64.rpm'
+ f' && sudo yum --nogpgcheck install rclone-{RCLONE_VERSION}-linux-amd64.rpm -y'
+ f' && rm -f rclone-{RCLONE_VERSION}-linux-amd64.rpm))')
+ return install_cmd
+
+
+def get_oci_mount_cmd(mount_path: str, store_name: str, region: str,
+ namespace: str, compartment: str, config_file: str,
+ config_profile: str) -> str:
+    """OCI-specific rclone mount command for OCI Object Storage."""
+ # pylint: disable=line-too-long
+ mount_cmd = (
+ f'sudo chown -R `whoami` {mount_path}'
+ f' && rclone config create oos_{store_name} oracleobjectstorage'
+ f' provider user_principal_auth namespace {namespace}'
+ f' compartment {compartment} region {region}'
+ f' oci-config-file {config_file}'
+ f' oci-config-profile {config_profile}'
+ f' && sed -i "s/oci-config-file/config_file/g;'
+ f' s/oci-config-profile/config_profile/g" ~/.config/rclone/rclone.conf'
+ f' && ([ ! -f /bin/fusermount3 ] && sudo ln -s /bin/fusermount /bin/fusermount3 || true)'
+ f' && (grep -q {mount_path} /proc/mounts || rclone mount oos_{store_name}:{store_name} {mount_path} --daemon --allow-non-empty)'
+ )
+ return mount_cmd
+
+
+def get_rclone_version_check_cmd() -> str:
+    """Rclone version check. This command is shared across clouds."""
+ return f'rclone --version | grep -q {RCLONE_VERSION}'
+
+
def _get_mount_binary(mount_cmd: str) -> str:
"""Returns mounting binary in string given as the mount command.
@@ -210,7 +288,7 @@ def get_mounting_script(
script = textwrap.dedent(f"""
#!/usr/bin/env bash
set -e
-
+
{command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD}
MOUNT_PATH={mount_path}
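
Illustrative output of the sub-path plumbing above; the bucket name, mount
path, and sub path are placeholders, and the '...' elides the cache TTL flags:

    from sky.data import mounting_utils

    cmd = mounting_utils.get_s3_mount_cmd('my-bucket', '/mnt/data',
                                          _bucket_sub_path='team-a/run1')
    # -> 'goofys -o allow_other --stat-cache-ttl ... --type-cache-ttl ...
    #     my-bucket:team-a/run1 /mnt/data'
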
diff --git a/sky/data/storage.py b/sky/data/storage.py
index 0e35f0bd1c1..74d93cf1207 100644
--- a/sky/data/storage.py
+++ b/sky/data/storage.py
@@ -23,6 +23,7 @@
from sky.adaptors import cloudflare
from sky.adaptors import gcp
from sky.adaptors import ibm
+from sky.adaptors import oci
from sky.data import data_transfer
from sky.data import data_utils
from sky.data import mounting_utils
@@ -54,7 +55,9 @@
str(clouds.AWS()),
str(clouds.GCP()),
str(clouds.Azure()),
- str(clouds.IBM()), cloudflare.NAME
+ str(clouds.IBM()),
+ str(clouds.OCI()),
+ cloudflare.NAME,
]
# Maximum number of concurrent rsync upload processes
@@ -72,6 +75,8 @@
'Bucket {bucket_name!r} does not exist. '
'It may have been deleted externally.')
+_STORAGE_LOG_FILE_NAME = 'storage_sync.log'
+
def get_cached_enabled_storage_clouds_or_refresh(
raise_if_no_cloud_access: bool = False) -> List[str]:
@@ -113,6 +118,7 @@ class StoreType(enum.Enum):
AZURE = 'AZURE'
R2 = 'R2'
IBM = 'IBM'
+ OCI = 'OCI'
@classmethod
def from_cloud(cls, cloud: str) -> 'StoreType':
@@ -126,6 +132,8 @@ def from_cloud(cls, cloud: str) -> 'StoreType':
return StoreType.R2
elif cloud.lower() == str(clouds.Azure()).lower():
return StoreType.AZURE
+ elif cloud.lower() == str(clouds.OCI()).lower():
+ return StoreType.OCI
elif cloud.lower() == str(clouds.Lambda()).lower():
with ux_utils.print_exception_no_traceback():
raise ValueError('Lambda Cloud does not provide cloud storage.')
@@ -147,6 +155,8 @@ def from_store(cls, store: 'AbstractStore') -> 'StoreType':
return StoreType.R2
elif isinstance(store, IBMCosStore):
return StoreType.IBM
+ elif isinstance(store, OciStore):
+ return StoreType.OCI
else:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Unknown store type: {store}')
@@ -163,6 +173,8 @@ def store_prefix(self) -> str:
return 'r2://'
elif self == StoreType.IBM:
return 'cos://'
+ elif self == StoreType.OCI:
+ return 'oci://'
else:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Unknown store type: {self}')
@@ -188,6 +200,39 @@ def get_endpoint_url(cls, store: 'AbstractStore', path: str) -> str:
bucket_endpoint_url = f'{store_type.store_prefix()}{path}'
return bucket_endpoint_url
+ @classmethod
+ def get_fields_from_store_url(
+ cls, store_url: str
+ ) -> Tuple['StoreType', str, str, Optional[str], Optional[str]]:
+        """Returns the store type, bucket name, and sub path parsed from a
+        store URL, plus the storage account name and region if applicable.
+
+ Args:
+ store_url: str; The store URL.
+ """
+        # The full path of an IBM COS URL contains the region, and an Azure
+        # Blob Storage URL contains the storage account name; we need to pass
+        # this information to the store constructor.
+ storage_account_name = None
+ region = None
+ for store_type in StoreType:
+ if store_url.startswith(store_type.store_prefix()):
+ if store_type == StoreType.AZURE:
+ storage_account_name, bucket_name, sub_path = \
+ data_utils.split_az_path(store_url)
+ elif store_type == StoreType.IBM:
+ bucket_name, sub_path, region = data_utils.split_cos_path(
+ store_url)
+ elif store_type == StoreType.R2:
+ bucket_name, sub_path = data_utils.split_r2_path(store_url)
+ elif store_type == StoreType.GCS:
+ bucket_name, sub_path = data_utils.split_gcs_path(store_url)
+                elif store_type == StoreType.S3:
+                    bucket_name, sub_path = data_utils.split_s3_path(store_url)
+                elif store_type == StoreType.OCI:
+                    bucket_name, sub_path = data_utils.split_oci_path(
+                        store_url)
+ return store_type, bucket_name, \
+ sub_path, storage_account_name, region
+ raise ValueError(f'Unknown store URL: {store_url}')
+
class StorageMode(enum.Enum):
MOUNT = 'MOUNT'
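
A hedged example of the URL parsing added above, assuming
data_utils.split_gcs_path splits a gs:// URL into bucket and key as its name
suggests; the bucket and path below are placeholders:

    from sky.data.storage import StoreType

    st, bucket, sub_path, account, region = (
        StoreType.get_fields_from_store_url('gs://my-bucket/ckpts/run1'))
    # st == StoreType.GCS, bucket == 'my-bucket', sub_path == 'ckpts/run1',
    # account is None, region is None
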
@@ -214,25 +259,29 @@ def __init__(self,
name: str,
source: Optional[SourceType],
region: Optional[str] = None,
- is_sky_managed: Optional[bool] = None):
+ is_sky_managed: Optional[bool] = None,
+ _bucket_sub_path: Optional[str] = None):
self.name = name
self.source = source
self.region = region
self.is_sky_managed = is_sky_managed
+ self._bucket_sub_path = _bucket_sub_path
def __repr__(self):
return (f'StoreMetadata('
f'\n\tname={self.name},'
f'\n\tsource={self.source},'
f'\n\tregion={self.region},'
- f'\n\tis_sky_managed={self.is_sky_managed})')
+ f'\n\tis_sky_managed={self.is_sky_managed},'
+ f'\n\t_bucket_sub_path={self._bucket_sub_path})')
def __init__(self,
name: str,
source: Optional[SourceType],
region: Optional[str] = None,
is_sky_managed: Optional[bool] = None,
- sync_on_reconstruction: Optional[bool] = True):
+ sync_on_reconstruction: Optional[bool] = True,
+ _bucket_sub_path: Optional[str] = None): # pylint: disable=invalid-name
"""Initialize AbstractStore
Args:
@@ -246,7 +295,11 @@ def __init__(self,
there. This is set to false when the Storage object is created not
for direct use, e.g. for 'sky storage delete', or the storage is
being re-used, e.g., for `sky start` on a stopped cluster.
-
+ _bucket_sub_path: str; The prefix of the bucket directory to be
+ created in the store, e.g. if _bucket_sub_path=my-dir, the files
+                will be uploaded to s3://<bucket>/my-dir/.
+ This only works if source is a local directory.
+ # TODO(zpoint): Add support for non-local source.
Raises:
StorageBucketCreateError: If bucket creation fails
StorageBucketGetError: If fetching existing bucket fails
@@ -257,10 +310,29 @@ def __init__(self,
self.region = region
self.is_sky_managed = is_sky_managed
self.sync_on_reconstruction = sync_on_reconstruction
+
+ # To avoid mypy error
+ self._bucket_sub_path: Optional[str] = None
+ # Trigger the setter to strip any leading/trailing slashes.
+ self.bucket_sub_path = _bucket_sub_path
# Whether sky is responsible for the lifecycle of the Store.
self._validate()
self.initialize()
+ @property
+ def bucket_sub_path(self) -> Optional[str]:
+ """Get the bucket_sub_path."""
+ return self._bucket_sub_path
+
+ @bucket_sub_path.setter
+ # pylint: disable=invalid-name
+ def bucket_sub_path(self, bucket_sub_path: Optional[str]) -> None:
+ """Set the bucket_sub_path, stripping any leading/trailing slashes."""
+ if bucket_sub_path is not None:
+ self._bucket_sub_path = bucket_sub_path.strip('/')
+ else:
+ self._bucket_sub_path = None
+
@classmethod
def from_metadata(cls, metadata: StoreMetadata, **override_args):
"""Create a Store from a StoreMetadata object.
@@ -268,19 +340,26 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args):
Used when reconstructing Storage and Store objects from
global_user_state.
"""
- return cls(name=override_args.get('name', metadata.name),
- source=override_args.get('source', metadata.source),
- region=override_args.get('region', metadata.region),
- is_sky_managed=override_args.get('is_sky_managed',
- metadata.is_sky_managed),
- sync_on_reconstruction=override_args.get(
- 'sync_on_reconstruction', True))
+ return cls(
+ name=override_args.get('name', metadata.name),
+ source=override_args.get('source', metadata.source),
+ region=override_args.get('region', metadata.region),
+ is_sky_managed=override_args.get('is_sky_managed',
+ metadata.is_sky_managed),
+ sync_on_reconstruction=override_args.get('sync_on_reconstruction',
+ True),
+ # backward compatibility
+ _bucket_sub_path=override_args.get(
+ '_bucket_sub_path',
+ metadata._bucket_sub_path # pylint: disable=protected-access
+ ) if hasattr(metadata, '_bucket_sub_path') else None)
def get_metadata(self) -> StoreMetadata:
return self.StoreMetadata(name=self.name,
source=self.source,
region=self.region,
- is_sky_managed=self.is_sky_managed)
+ is_sky_managed=self.is_sky_managed,
+ _bucket_sub_path=self._bucket_sub_path)
def initialize(self):
"""Initializes the Store object on the cloud.
@@ -308,7 +387,11 @@ def upload(self) -> None:
raise NotImplementedError
def delete(self) -> None:
- """Removes the Storage object from the cloud."""
+ """Removes the Storage from the cloud."""
+ raise NotImplementedError
+
+ def _delete_sub_path(self) -> None:
+ """Removes objects from the sub path in the bucket."""
raise NotImplementedError
def get_handle(self) -> StorageHandle:
@@ -452,13 +535,19 @@ def remove_store(self, store: AbstractStore) -> None:
if storetype in self.sky_stores:
del self.sky_stores[storetype]
- def __init__(self,
- name: Optional[str] = None,
- source: Optional[SourceType] = None,
- stores: Optional[List[StoreType]] = None,
- persistent: Optional[bool] = True,
- mode: StorageMode = StorageMode.MOUNT,
- sync_on_reconstruction: bool = True) -> None:
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ source: Optional[SourceType] = None,
+ stores: Optional[List[StoreType]] = None,
+ persistent: Optional[bool] = True,
+ mode: StorageMode = StorageMode.MOUNT,
+ sync_on_reconstruction: bool = True,
+ # pylint: disable=invalid-name
+ _is_sky_managed: Optional[bool] = None,
+ # pylint: disable=invalid-name
+ _bucket_sub_path: Optional[str] = None
+ ) -> None:
"""Initializes a Storage object.
Three fields are required: the name of the storage, the source
@@ -496,6 +585,18 @@ def __init__(self,
there. This is set to false when the Storage object is created not
for direct use, e.g. for 'sky storage delete', or the storage is
being re-used, e.g., for `sky start` on a stopped cluster.
+ _is_sky_managed: Optional[bool]; Indicates if the storage is managed
+ by Sky. Without this argument, the controller's behavior differs
+ from the local machine. For example, if a bucket does not exist:
+ Local Machine (is_sky_managed=True) →
+ Controller (is_sky_managed=False).
+ With this argument, the controller aligns with the local machine,
+ ensuring it retains the is_sky_managed information from the YAML.
+ During teardown, if is_sky_managed is True, the controller should
+ delete the bucket. Otherwise, it might mistakenly delete only the
+ sub-path, assuming is_sky_managed is False.
+ _bucket_sub_path: Optional[str]; The subdirectory to use for the
+ storage object.
"""
self.name = name
self.source = source
@@ -507,6 +608,8 @@ def __init__(self,
for store in stores:
self.stores[store] = None
self.sync_on_reconstruction = sync_on_reconstruction
+ self._is_sky_managed = _is_sky_managed
+ self._bucket_sub_path = _bucket_sub_path
self._constructed = False
# TODO(romilb, zhwu): This is a workaround to support storage deletion
@@ -548,15 +651,9 @@ def construct(self):
# from existing ones
input_stores = self.stores
self.stores = {}
- sky_managed_stores = {
- t: s.get_metadata()
- for t, s in self.stores.items()
- if s.is_sky_managed
- }
self.handle = self.StorageMetadata(storage_name=self.name,
source=self.source,
- mode=self.mode,
- sky_stores=sky_managed_stores)
+ mode=self.mode)
for store in input_stores:
self.add_store(store)
@@ -575,6 +672,14 @@ def construct(self):
self.add_store(StoreType.R2)
elif self.source.startswith('cos://'):
self.add_store(StoreType.IBM)
+ elif self.source.startswith('oci://'):
+ self.add_store(StoreType.OCI)
+
+ def get_bucket_sub_path_prefix(self, blob_path: str) -> str:
+ """Adds the bucket sub path prefix to the blob path."""
+ if self._bucket_sub_path is not None:
+ return f'{blob_path}/{self._bucket_sub_path}'
+ return blob_path
@staticmethod
def _validate_source(
@@ -655,7 +760,7 @@ def _validate_local_source(local_source):
                         'using a bucket by writing <destination_path>: '
f'{source} in the file_mounts section of your YAML')
is_local_source = True
- elif split_path.scheme in ['s3', 'gs', 'https', 'r2', 'cos']:
+ elif split_path.scheme in ['s3', 'gs', 'https', 'r2', 'cos', 'oci']:
is_local_source = False
# Storage mounting does not support mounting specific files from
# cloud store - ensure path points to only a directory
@@ -679,7 +784,7 @@ def _validate_local_source(local_source):
with ux_utils.print_exception_no_traceback():
raise exceptions.StorageSourceError(
f'Supported paths: local, s3://, gs://, https://, '
- f'r2://, cos://. Got: {source}')
+ f'r2://, cos://, oci://. Got: {source}')
return source, is_local_source
def _validate_storage_spec(self, name: Optional[str]) -> None:
@@ -694,7 +799,7 @@ def validate_name(name):
"""
prefix = name.split('://')[0]
prefix = prefix.lower()
- if prefix in ['s3', 'gs', 'https', 'r2', 'cos']:
+ if prefix in ['s3', 'gs', 'https', 'r2', 'cos', 'oci']:
with ux_utils.print_exception_no_traceback():
raise exceptions.StorageNameError(
'Prefix detected: `name` cannot start with '
@@ -786,29 +891,40 @@ def _add_store_from_metadata(
store = S3Store.from_metadata(
s_metadata,
source=self.source,
- sync_on_reconstruction=self.sync_on_reconstruction)
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ _bucket_sub_path=self._bucket_sub_path)
elif s_type == StoreType.GCS:
store = GcsStore.from_metadata(
s_metadata,
source=self.source,
- sync_on_reconstruction=self.sync_on_reconstruction)
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ _bucket_sub_path=self._bucket_sub_path)
elif s_type == StoreType.AZURE:
assert isinstance(s_metadata,
AzureBlobStore.AzureBlobStoreMetadata)
store = AzureBlobStore.from_metadata(
s_metadata,
source=self.source,
- sync_on_reconstruction=self.sync_on_reconstruction)
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ _bucket_sub_path=self._bucket_sub_path)
elif s_type == StoreType.R2:
store = R2Store.from_metadata(
s_metadata,
source=self.source,
- sync_on_reconstruction=self.sync_on_reconstruction)
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ _bucket_sub_path=self._bucket_sub_path)
elif s_type == StoreType.IBM:
store = IBMCosStore.from_metadata(
s_metadata,
source=self.source,
- sync_on_reconstruction=self.sync_on_reconstruction)
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ _bucket_sub_path=self._bucket_sub_path)
+ elif s_type == StoreType.OCI:
+ store = OciStore.from_metadata(
+ s_metadata,
+ source=self.source,
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ _bucket_sub_path=self._bucket_sub_path)
else:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Unknown store type: {s_type}')
@@ -828,7 +944,6 @@ def _add_store_from_metadata(
'to be reconstructed while the corresponding '
'bucket was externally deleted.')
continue
-
self._add_store(store, is_reconstructed=True)
@classmethod
@@ -902,25 +1017,30 @@ def add_store(self,
store_cls = R2Store
elif store_type == StoreType.IBM:
store_cls = IBMCosStore
+ elif store_type == StoreType.OCI:
+ store_cls = OciStore
else:
with ux_utils.print_exception_no_traceback():
raise exceptions.StorageSpecError(
f'{store_type} not supported as a Store.')
-
- # Initialize store object and get/create bucket
try:
store = store_cls(
name=self.name,
source=self.source,
region=region,
- sync_on_reconstruction=self.sync_on_reconstruction)
+ sync_on_reconstruction=self.sync_on_reconstruction,
+ is_sky_managed=self._is_sky_managed,
+ _bucket_sub_path=self._bucket_sub_path)
except exceptions.StorageBucketCreateError:
# Creation failed, so this must be sky managed store. Add failure
# to state.
logger.error(f'Could not create {store_type} store '
f'with name {self.name}.')
- global_user_state.set_storage_status(self.name,
- StorageStatus.INIT_FAILED)
+ try:
+ global_user_state.set_storage_status(self.name,
+ StorageStatus.INIT_FAILED)
+ except ValueError as e:
+ logger.error(f'Error setting storage status: {e}')
raise
except exceptions.StorageBucketGetError:
# Bucket get failed, so this is not sky managed. Do not update state
@@ -1048,12 +1168,15 @@ def warn_for_git_dir(source: str):
def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage':
common_utils.validate_schema(config, schemas.get_storage_schema(),
'Invalid storage YAML: ')
-
name = config.pop('name', None)
source = config.pop('source', None)
store = config.pop('store', None)
mode_str = config.pop('mode', None)
force_delete = config.pop('_force_delete', None)
+ # pylint: disable=invalid-name
+ _is_sky_managed = config.pop('_is_sky_managed', None)
+ # pylint: disable=invalid-name
+ _bucket_sub_path = config.pop('_bucket_sub_path', None)
if force_delete is None:
force_delete = False
@@ -1078,13 +1201,15 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage':
source=source,
persistent=persistent,
mode=mode,
- stores=stores)
+ stores=stores,
+ _is_sky_managed=_is_sky_managed,
+ _bucket_sub_path=_bucket_sub_path)
# Add force deletion flag
storage_obj.force_delete = force_delete
return storage_obj
- def to_yaml_config(self) -> Dict[str, str]:
+ def to_yaml_config(self) -> Dict[str, Any]:
config = {}
def add_if_not_none(key: str, value: Optional[Any]):
@@ -1100,13 +1225,20 @@ def add_if_not_none(key: str, value: Optional[Any]):
add_if_not_none('source', self.source)
stores = None
- if len(self.stores) > 0:
+ is_sky_managed = self._is_sky_managed
+ if self.stores:
stores = ','.join([store.value for store in self.stores])
+ store = list(self.stores.values())[0]
+ if store is not None:
+ is_sky_managed = store.is_sky_managed
add_if_not_none('store', stores)
+ add_if_not_none('_is_sky_managed', is_sky_managed)
add_if_not_none('persistent', self.persistent)
add_if_not_none('mode', self.mode.value)
if self.force_delete:
config['_force_delete'] = True
+ if self._bucket_sub_path is not None:
+ config['_bucket_sub_path'] = self._bucket_sub_path
return config
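
With the changes above, a Storage configured with a sub path round-trips
through to_yaml_config() roughly as the dict below (all values are
placeholders):

    {'name': 'my-bucket', 'source': '/local/ckpts', 'store': 's3',
     '_is_sky_managed': False, 'persistent': True, 'mode': 'MOUNT',
     '_bucket_sub_path': 'team-a/run1'}
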
@@ -1128,7 +1260,8 @@ def __init__(self,
source: str,
region: Optional[str] = _DEFAULT_REGION,
is_sky_managed: Optional[bool] = None,
- sync_on_reconstruction: bool = True):
+ sync_on_reconstruction: bool = True,
+ _bucket_sub_path: Optional[str] = None):
self.client: 'boto3.client.Client'
self.bucket: 'StorageHandle'
# TODO(romilb): This is purely a stopgap fix for
@@ -1141,7 +1274,7 @@ def __init__(self,
f'{self._DEFAULT_REGION} for bucket {name!r}.')
region = self._DEFAULT_REGION
super().__init__(name, source, region, is_sky_managed,
- sync_on_reconstruction)
+ sync_on_reconstruction, _bucket_sub_path)
def _validate(self):
if self.source is not None and isinstance(self.source, str):
@@ -1180,6 +1313,9 @@ def _validate(self):
assert data_utils.verify_ibm_cos_bucket(self.name), (
f'Source specified as {self.source}, a COS bucket. ',
'COS Bucket should exist.')
+ elif self.source.startswith('oci://'):
+ raise NotImplementedError(
+ 'Moving data from OCI to S3 is currently not supported.')
# Validate name
self.name = self.validate_name(self.name)
@@ -1190,7 +1326,7 @@ def _validate(self):
'Storage \'store: s3\' specified, but ' \
'AWS access is disabled. To fix, enable '\
'AWS by running `sky check`. More info: '\
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
)
@classmethod
@@ -1291,6 +1427,8 @@ def upload(self):
self._transfer_to_s3()
elif self.source.startswith('r2://'):
self._transfer_to_s3()
+ elif self.source.startswith('oci://'):
+ self._transfer_to_s3()
else:
self.batch_aws_rsync([self.source])
except exceptions.StorageUploadError:
@@ -1300,6 +1438,9 @@ def upload(self):
f'Upload failed for store {self.name}') from e
def delete(self) -> None:
+ if self._bucket_sub_path is not None and not self.is_sky_managed:
+ return self._delete_sub_path()
+
deleted_by_skypilot = self._delete_s3_bucket(self.name)
if deleted_by_skypilot:
msg_str = f'Deleted S3 bucket {self.name}.'
@@ -1309,6 +1450,19 @@ def delete(self) -> None:
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')
+ def _delete_sub_path(self) -> None:
+ assert self._bucket_sub_path is not None, 'bucket_sub_path is not set'
+ deleted_by_skypilot = self._delete_s3_bucket_sub_path(
+ self.name, self._bucket_sub_path)
+ if deleted_by_skypilot:
+ msg_str = f'Removed objects from S3 bucket ' \
+ f'{self.name}/{self._bucket_sub_path}.'
+ else:
+ msg_str = f'Failed to remove objects from S3 bucket ' \
+ f'{self.name}/{self._bucket_sub_path}.'
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}'
+ f'{colorama.Style.RESET_ALL}')
+
def get_handle(self) -> StorageHandle:
return aws.resource('s3').Bucket(self.name)
@@ -1339,9 +1493,11 @@ def get_file_sync_command(base_dir_path, file_names):
for file_name in file_names
])
base_dir_path = shlex.quote(base_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" '
f'{includes} {base_dir_path} '
- f's3://{self.name}')
+ f's3://{self.name}{sub_path}')
return sync_command
def get_dir_sync_command(src_dir_path, dest_dir_name):
@@ -1353,9 +1509,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
for file_name in excluded_list
])
src_dir_path = shlex.quote(src_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} '
f'{src_dir_path} '
- f's3://{self.name}/{dest_dir_name}')
+ f's3://{self.name}{sub_path}/{dest_dir_name}')
return sync_command
# Generate message for upload
@@ -1364,17 +1522,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
else:
source_message = source_path_list[0]
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> s3://{self.name}/'
with rich_utils.safe_status(
- ux_utils.spinner_message(f'Syncing {source_message} -> '
- f's3://{self.name}/')):
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
data_utils.parallel_upload(
source_path_list,
get_file_sync_command,
get_dir_sync_command,
+ log_path,
self.name,
self._ACCESS_DENIED_MESSAGE,
create_dirs=create_dirs,
max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS)
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
def _transfer_to_s3(self) -> None:
assert isinstance(self.source, str), self.source
@@ -1466,7 +1631,8 @@ def mount_command(self, mount_path: str) -> str:
"""
install_cmd = mounting_utils.get_s3_mount_install_cmd()
mount_cmd = mounting_utils.get_s3_mount_cmd(self.bucket.name,
- mount_path)
+ mount_path,
+ self._bucket_sub_path)
return mounting_utils.get_mounting_command(mount_path, install_cmd,
mount_cmd)
@@ -1516,6 +1682,27 @@ def _create_s3_bucket(self,
) from e
return aws.resource('s3').Bucket(bucket_name)
+ def _execute_s3_remove_command(self, command: str, bucket_name: str,
+ hint_operating: str,
+ hint_failed: str) -> bool:
+ try:
+ with rich_utils.safe_status(
+ ux_utils.spinner_message(hint_operating)):
+ subprocess.check_output(command.split(' '),
+ stderr=subprocess.STDOUT)
+ except subprocess.CalledProcessError as e:
+ if 'NoSuchBucket' in e.output.decode('utf-8'):
+ logger.debug(
+ _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format(
+ bucket_name=bucket_name))
+ return False
+ else:
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketDeleteError(
+ f'{hint_failed}'
+ f'Detailed error: {e.output}')
+ return True
+
def _delete_s3_bucket(self, bucket_name: str) -> bool:
"""Deletes S3 bucket, including all objects in bucket
@@ -1533,29 +1720,28 @@ def _delete_s3_bucket(self, bucket_name: str) -> bool:
# The fastest way to delete is to run `aws s3 rb --force`,
# which removes the bucket by force.
remove_command = f'aws s3 rb s3://{bucket_name} --force'
- try:
- with rich_utils.safe_status(
- ux_utils.spinner_message(
- f'Deleting S3 bucket [green]{bucket_name}')):
- subprocess.check_output(remove_command.split(' '),
- stderr=subprocess.STDOUT)
- except subprocess.CalledProcessError as e:
- if 'NoSuchBucket' in e.output.decode('utf-8'):
- logger.debug(
- _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format(
- bucket_name=bucket_name))
- return False
- else:
- with ux_utils.print_exception_no_traceback():
- raise exceptions.StorageBucketDeleteError(
- f'Failed to delete S3 bucket {bucket_name}.'
- f'Detailed error: {e.output}')
+ success = self._execute_s3_remove_command(
+ remove_command, bucket_name,
+ f'Deleting S3 bucket [green]{bucket_name}[/]',
+ f'Failed to delete S3 bucket {bucket_name}.')
+ if not success:
+ return False
# Wait until bucket deletion propagates on AWS servers
while data_utils.verify_s3_bucket(bucket_name):
time.sleep(0.1)
return True
+ def _delete_s3_bucket_sub_path(self, bucket_name: str,
+ sub_path: str) -> bool:
+ """Deletes the sub path from the bucket."""
+ remove_command = f'aws s3 rm s3://{bucket_name}/{sub_path}/ --recursive'
+ return self._execute_s3_remove_command(
+ remove_command, bucket_name, f'Removing objects from S3 bucket '
+ f'[green]{bucket_name}/{sub_path}[/]',
+ f'Failed to remove objects from S3 bucket {bucket_name}/{sub_path}.'
+ )
+
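
For a store created with _bucket_sub_path='team-a/run1' (a placeholder), the
helper above runs roughly

    aws s3 rm s3://my-bucket/team-a/run1/ --recursive

so only objects under that prefix are removed and the bucket itself is left in
place.
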
class GcsStore(AbstractStore):
"""GcsStore inherits from Storage Object and represents the backend
@@ -1569,11 +1755,12 @@ def __init__(self,
source: str,
region: Optional[str] = 'us-central1',
is_sky_managed: Optional[bool] = None,
- sync_on_reconstruction: Optional[bool] = True):
+ sync_on_reconstruction: Optional[bool] = True,
+ _bucket_sub_path: Optional[str] = None):
self.client: 'storage.Client'
self.bucket: StorageHandle
super().__init__(name, source, region, is_sky_managed,
- sync_on_reconstruction)
+ sync_on_reconstruction, _bucket_sub_path)
def _validate(self):
if self.source is not None and isinstance(self.source, str):
@@ -1612,6 +1799,9 @@ def _validate(self):
assert data_utils.verify_ibm_cos_bucket(self.name), (
f'Source specified as {self.source}, a COS bucket. ',
'COS Bucket should exist.')
+ elif self.source.startswith('oci://'):
+ raise NotImplementedError(
+ 'Moving data from OCI to GCS is currently not supported.')
# Validate name
self.name = self.validate_name(self.name)
# Check if the storage is enabled
@@ -1621,7 +1811,7 @@ def _validate(self):
'Storage \'store: gcs\' specified, but '
'GCP access is disabled. To fix, enable '
'GCP by running `sky check`. '
- 'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.') # pylint: disable=line-too-long
+ 'More info: https://docs.skypilot.co/en/latest/getting-started/installation.html.') # pylint: disable=line-too-long
@classmethod
def validate_name(cls, name: str) -> str:
@@ -1720,6 +1910,8 @@ def upload(self):
self._transfer_to_gcs()
elif self.source.startswith('r2://'):
self._transfer_to_gcs()
+ elif self.source.startswith('oci://'):
+ self._transfer_to_gcs()
else:
# If a single directory is specified in source, upload
# contents to root of bucket by suffixing /*.
@@ -1731,6 +1923,9 @@ def upload(self):
f'Upload failed for store {self.name}') from e
def delete(self) -> None:
+ if self._bucket_sub_path is not None and not self.is_sky_managed:
+ return self._delete_sub_path()
+
deleted_by_skypilot = self._delete_gcs_bucket(self.name)
if deleted_by_skypilot:
msg_str = f'Deleted GCS bucket {self.name}.'
@@ -1740,6 +1935,19 @@ def delete(self) -> None:
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')
+ def _delete_sub_path(self) -> None:
+ assert self._bucket_sub_path is not None, 'bucket_sub_path is not set'
+ deleted_by_skypilot = self._delete_gcs_bucket(self.name,
+ self._bucket_sub_path)
+ if deleted_by_skypilot:
+ msg_str = f'Deleted objects in GCS bucket ' \
+ f'{self.name}/{self._bucket_sub_path}.'
+ else:
+ msg_str = f'GCS bucket {self.name} may have ' \
+ 'been deleted externally.'
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}'
+ f'{colorama.Style.RESET_ALL}')
+
def get_handle(self) -> StorageHandle:
return self.client.get_bucket(self.name)
@@ -1774,13 +1982,19 @@ def batch_gsutil_cp(self,
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
sync_command = (f'{alias_gen}; echo "{copy_list}" | {gsutil_alias} '
f'cp -e -n -r -I gs://{self.name}')
-
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> gs://{self.name}/'
with rich_utils.safe_status(
- ux_utils.spinner_message(f'Syncing {source_message} -> '
- f'gs://{self.name}/')):
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
data_utils.run_upload_cli(sync_command,
self._ACCESS_DENIED_MESSAGE,
- bucket_name=self.name)
+ bucket_name=self.name,
+ log_path=log_path)
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
def batch_gsutil_rsync(self,
source_path_list: List[Path],
@@ -1807,9 +2021,11 @@ def get_file_sync_command(base_dir_path, file_names):
sync_format = '|'.join(file_names)
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
base_dir_path = shlex.quote(base_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = (f'{alias_gen}; {gsutil_alias} '
f'rsync -e -x \'^(?!{sync_format}$).*\' '
- f'{base_dir_path} gs://{self.name}')
+ f'{base_dir_path} gs://{self.name}{sub_path}')
return sync_command
def get_dir_sync_command(src_dir_path, dest_dir_name):
@@ -1819,9 +2035,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
excludes = '|'.join(excluded_list)
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
src_dir_path = shlex.quote(src_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = (f'{alias_gen}; {gsutil_alias} '
f'rsync -e -r -x \'({excludes})\' {src_dir_path} '
- f'gs://{self.name}/{dest_dir_name}')
+ f'gs://{self.name}{sub_path}/{dest_dir_name}')
return sync_command
# Generate message for upload
@@ -1830,17 +2048,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
else:
source_message = source_path_list[0]
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> gs://{self.name}/'
with rich_utils.safe_status(
- ux_utils.spinner_message(f'Syncing {source_message} -> '
- f'gs://{self.name}/')):
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
data_utils.parallel_upload(
source_path_list,
get_file_sync_command,
get_dir_sync_command,
+ log_path,
self.name,
self._ACCESS_DENIED_MESSAGE,
create_dirs=create_dirs,
max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS)
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
def _transfer_to_gcs(self) -> None:
if isinstance(self.source, str) and self.source.startswith('s3://'):
@@ -1919,7 +2144,8 @@ def mount_command(self, mount_path: str) -> str:
"""
install_cmd = mounting_utils.get_gcs_mount_install_cmd()
mount_cmd = mounting_utils.get_gcs_mount_cmd(self.bucket.name,
- mount_path)
+ mount_path,
+ self._bucket_sub_path)
version_check_cmd = (
f'gcsfuse --version | grep -q {mounting_utils.GCSFUSE_VERSION}')
return mounting_utils.get_mounting_command(mount_path, install_cmd,
@@ -1959,19 +2185,33 @@ def _create_gcs_bucket(self,
f'{new_bucket.storage_class}{colorama.Style.RESET_ALL}')
return new_bucket
- def _delete_gcs_bucket(self, bucket_name: str) -> bool:
- """Deletes GCS bucket, including all objects in bucket
+ def _delete_gcs_bucket(
+ self,
+ bucket_name: str,
+ # pylint: disable=invalid-name
+ _bucket_sub_path: Optional[str] = None
+ ) -> bool:
+ """Deletes objects in GCS bucket
Args:
bucket_name: str; Name of bucket
+          _bucket_sub_path: str; Sub path in the bucket. If provided, only
+            objects under the sub path are deleted; otherwise the whole
+            bucket is deleted.
Returns:
bool; True if bucket was deleted, False if it was deleted externally.
"""
-
+ if _bucket_sub_path is not None:
+ command_suffix = f'/{_bucket_sub_path}'
+ hint_text = 'objects in '
+ else:
+ command_suffix = ''
+ hint_text = ''
with rich_utils.safe_status(
ux_utils.spinner_message(
- f'Deleting GCS bucket [green]{bucket_name}')):
+ f'Deleting {hint_text}GCS bucket '
+ f'[green]{bucket_name}{command_suffix}[/]')):
try:
self.client.get_bucket(bucket_name)
except gcp.forbidden_exception() as e:
@@ -1989,8 +2229,9 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool:
return False
try:
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
- remove_obj_command = (f'{alias_gen};{gsutil_alias} '
- f'rm -r gs://{bucket_name}')
+ remove_obj_command = (
+ f'{alias_gen};{gsutil_alias} '
+ f'rm -r gs://{bucket_name}{command_suffix}')
subprocess.check_output(remove_obj_command,
stderr=subprocess.STDOUT,
shell=True,
@@ -1999,7 +2240,8 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool:
except subprocess.CalledProcessError as e:
with ux_utils.print_exception_no_traceback():
raise exceptions.StorageBucketDeleteError(
- f'Failed to delete GCS bucket {bucket_name}.'
+ f'Failed to delete {hint_text}GCS bucket '
+ f'{bucket_name}{command_suffix}.'
f'Detailed error: {e.output}')
@@ -2051,7 +2293,8 @@ def __init__(self,
storage_account_name: str = '',
region: Optional[str] = 'eastus',
is_sky_managed: Optional[bool] = None,
- sync_on_reconstruction: bool = True):
+ sync_on_reconstruction: bool = True,
+ _bucket_sub_path: Optional[str] = None):
self.storage_client: 'storage.Client'
self.resource_client: 'storage.Client'
self.container_name: str
@@ -2063,7 +2306,7 @@ def __init__(self,
if region is None:
region = 'eastus'
super().__init__(name, source, region, is_sky_managed,
- sync_on_reconstruction)
+ sync_on_reconstruction, _bucket_sub_path)
@classmethod
def from_metadata(cls, metadata: AbstractStore.StoreMetadata,
@@ -2133,6 +2376,9 @@ def _validate(self):
assert data_utils.verify_ibm_cos_bucket(self.name), (
f'Source specified as {self.source}, a COS bucket. ',
'COS Bucket should exist.')
+ elif self.source.startswith('oci://'):
+ raise NotImplementedError(
+                'Moving data from OCI to Azure Blob is not supported.')
# Validate name
self.name = self.validate_name(self.name)
@@ -2143,7 +2389,7 @@ def _validate(self):
'Storage "store: azure" specified, but '
'Azure access is disabled. To fix, enable '
'Azure by running `sky check`. More info: '
- 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
)
@classmethod
@@ -2210,6 +2456,17 @@ def initialize(self):
"""
self.storage_client = data_utils.create_az_client('storage')
self.resource_client = data_utils.create_az_client('resource')
+ self._update_storage_account_name_and_resource()
+
+ self.container_name, is_new_bucket = self._get_bucket()
+ if self.is_sky_managed is None:
+ # If is_sky_managed is not specified, then this is a new storage
+ # object (i.e., did not exist in global_user_state) and we should
+ # set the is_sky_managed property.
+ # If is_sky_managed is specified, then we take no action.
+ self.is_sky_managed = is_new_bucket
+
+ def _update_storage_account_name_and_resource(self):
self.storage_account_name, self.resource_group_name = (
self._get_storage_account_and_resource_group())
@@ -2220,13 +2477,13 @@ def initialize(self):
self.storage_account_name, self.resource_group_name,
self.storage_client, self.resource_client)
- self.container_name, is_new_bucket = self._get_bucket()
- if self.is_sky_managed is None:
- # If is_sky_managed is not specified, then this is a new storage
- # object (i.e., did not exist in global_user_state) and we should
- # set the is_sky_managed property.
- # If is_sky_managed is specified, then we take no action.
- self.is_sky_managed = is_new_bucket
+ def update_storage_attributes(self, **kwargs: Dict[str, Any]):
+ assert 'storage_account_name' in kwargs, (
+ 'only storage_account_name supported')
+ assert isinstance(kwargs['storage_account_name'],
+ str), ('storage_account_name must be a string')
+ self.storage_account_name = kwargs['storage_account_name']
+ self._update_storage_account_name_and_resource()
@staticmethod
def get_default_storage_account_name(region: Optional[str]) -> str:
@@ -2485,6 +2742,8 @@ def upload(self):
raise NotImplementedError(error_message.format('R2'))
elif self.source.startswith('cos://'):
raise NotImplementedError(error_message.format('IBM COS'))
+ elif self.source.startswith('oci://'):
+ raise NotImplementedError(error_message.format('OCI'))
else:
self.batch_az_blob_sync([self.source])
except exceptions.StorageUploadError:
@@ -2495,6 +2754,9 @@ def upload(self):
def delete(self) -> None:
"""Deletes the storage."""
+ if self._bucket_sub_path is not None and not self.is_sky_managed:
+ return self._delete_sub_path()
+
deleted_by_skypilot = self._delete_az_bucket(self.name)
if deleted_by_skypilot:
msg_str = (f'Deleted AZ Container {self.name!r} under storage '
@@ -2505,6 +2767,32 @@ def delete(self) -> None:
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')
+ def _delete_sub_path(self) -> None:
+ assert self._bucket_sub_path is not None, 'bucket_sub_path is not set'
+ try:
+ container_url = data_utils.AZURE_CONTAINER_URL.format(
+ storage_account_name=self.storage_account_name,
+ container_name=self.name)
+ container_client = data_utils.create_az_client(
+ client_type='container',
+ container_url=container_url,
+ storage_account_name=self.storage_account_name,
+ resource_group_name=self.resource_group_name)
+ # List and delete blobs in the specified directory
+ blobs = container_client.list_blobs(
+ name_starts_with=self._bucket_sub_path + '/')
+ for blob in blobs:
+ container_client.delete_blob(blob.name)
+ logger.info(
+ f'Deleted objects from sub path {self._bucket_sub_path} '
+ f'in container {self.name}.')
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(
+ f'Failed to delete objects from sub path '
+ f'{self._bucket_sub_path} in container {self.name}. '
+ f'Details: {common_utils.format_exception(e, use_bracket=True)}'
+ )
+
def get_handle(self) -> StorageHandle:
"""Returns the Storage Handle object."""
return self.storage_client.blob_containers.get(
@@ -2531,13 +2819,15 @@ def get_file_sync_command(base_dir_path, file_names) -> str:
includes_list = ';'.join(file_names)
includes = f'--include-pattern "{includes_list}"'
base_dir_path = shlex.quote(base_dir_path)
+ container_path = (f'{self.container_name}/{self._bucket_sub_path}'
+ if self._bucket_sub_path else self.container_name)
sync_command = (f'az storage blob sync '
f'--account-name {self.storage_account_name} '
f'--account-key {self.storage_account_key} '
f'{includes} '
'--delete-destination false '
f'--source {base_dir_path} '
- f'--container {self.container_name}')
+ f'--container {container_path}')
return sync_command
def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
@@ -2548,8 +2838,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
[file_name.rstrip('*') for file_name in excluded_list])
excludes = f'--exclude-path "{excludes_list}"'
src_dir_path = shlex.quote(src_dir_path)
- container_path = (f'{self.container_name}/{dest_dir_name}'
- if dest_dir_name else self.container_name)
+ container_path = (f'{self.container_name}/{self._bucket_sub_path}'
+ if self._bucket_sub_path else
+ f'{self.container_name}')
+ if dest_dir_name:
+ container_path = f'{container_path}/{dest_dir_name}'
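+ # For illustration (hypothetical values): sub path 'outputs' and
+ # dest dir 'data' yield --container <container_name>/outputs/data
+ # in the sync command below.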
sync_command = (f'az storage blob sync '
f'--account-name {self.storage_account_name} '
f'--account-key {self.storage_account_key} '
@@ -2568,17 +2861,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
container_endpoint = data_utils.AZURE_CONTAINER_URL.format(
storage_account_name=self.storage_account_name,
container_name=self.name)
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> {container_endpoint}/'
with rich_utils.safe_status(
- ux_utils.spinner_message(
- f'Syncing {source_message} -> {container_endpoint}/')):
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
data_utils.parallel_upload(
source_path_list,
get_file_sync_command,
get_dir_sync_command,
+ log_path,
self.name,
self._ACCESS_DENIED_MESSAGE,
create_dirs=create_dirs,
max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS)
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
def _get_bucket(self) -> Tuple[str, bool]:
"""Obtains the AZ Container.
@@ -2665,6 +2965,7 @@ def _get_bucket(self) -> Tuple[str, bool]:
f'{self.storage_account_name!r}.'
'Details: '
f'{common_utils.format_exception(e, use_bracket=True)}')
+
# If the container cannot be found in both private and public settings,
# the container is to be created by Sky. However, creation is skipped
# if Store object is being reconstructed for deletion or re-mount with
@@ -2695,7 +2996,8 @@ def mount_command(self, mount_path: str) -> str:
mount_cmd = mounting_utils.get_az_mount_cmd(self.container_name,
self.storage_account_name,
mount_path,
- self.storage_account_key)
+ self.storage_account_key,
+ self._bucket_sub_path)
return mounting_utils.get_mounting_command(mount_path, install_cmd,
mount_cmd)
@@ -2794,11 +3096,12 @@ def __init__(self,
source: str,
region: Optional[str] = 'auto',
is_sky_managed: Optional[bool] = None,
- sync_on_reconstruction: Optional[bool] = True):
+ sync_on_reconstruction: Optional[bool] = True,
+ _bucket_sub_path: Optional[str] = None):
self.client: 'boto3.client.Client'
self.bucket: 'StorageHandle'
super().__init__(name, source, region, is_sky_managed,
- sync_on_reconstruction)
+ sync_on_reconstruction, _bucket_sub_path)
def _validate(self):
if self.source is not None and isinstance(self.source, str):
@@ -2837,6 +3140,10 @@ def _validate(self):
assert data_utils.verify_ibm_cos_bucket(self.name), (
f'Source specified as {self.source}, a COS bucket. ',
'COS Bucket should exist.')
+ elif self.source.startswith('oci://'):
+ raise NotImplementedError(
+ 'Moving data from OCI to R2 is currently not supported.')
+
# Validate name
self.name = S3Store.validate_name(self.name)
# Check if the storage is enabled
@@ -2846,7 +3153,7 @@ def _validate(self):
'Storage \'store: r2\' specified, but ' \
'Cloudflare R2 access is disabled. To fix, '\
'enable Cloudflare R2 by running `sky check`. '\
- 'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
+ 'More info: https://docs.skypilot.co/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
)
def initialize(self):
@@ -2888,6 +3195,8 @@ def upload(self):
self._transfer_to_r2()
elif self.source.startswith('r2://'):
pass
+ elif self.source.startswith('oci://'):
+ self._transfer_to_r2()
else:
self.batch_aws_rsync([self.source])
except exceptions.StorageUploadError:
@@ -2897,6 +3206,9 @@ def upload(self):
f'Upload failed for store {self.name}') from e
def delete(self) -> None:
+ if self._bucket_sub_path is not None and not self.is_sky_managed:
+ return self._delete_sub_path()
+
deleted_by_skypilot = self._delete_r2_bucket(self.name)
if deleted_by_skypilot:
msg_str = f'Deleted R2 bucket {self.name}.'
@@ -2906,6 +3218,19 @@ def delete(self) -> None:
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')
+ def _delete_sub_path(self) -> None:
+ assert self._bucket_sub_path is not None, 'bucket_sub_path is not set'
+ deleted_by_skypilot = self._delete_r2_bucket_sub_path(
+ self.name, self._bucket_sub_path)
+ if deleted_by_skypilot:
+ msg_str = f'Removed objects from R2 bucket ' \
+ f'{self.name}/{self._bucket_sub_path}.'
+ else:
+ msg_str = f'Failed to remove objects from R2 bucket ' \
+ f'{self.name}/{self._bucket_sub_path}.'
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}'
+ f'{colorama.Style.RESET_ALL}')
+
def get_handle(self) -> StorageHandle:
return cloudflare.resource('s3').Bucket(self.name)
@@ -2937,11 +3262,13 @@ def get_file_sync_command(base_dir_path, file_names):
])
endpoint_url = cloudflare.create_endpoint()
base_dir_path = shlex.quote(base_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = ('AWS_SHARED_CREDENTIALS_FILE='
f'{cloudflare.R2_CREDENTIALS_PATH} '
'aws s3 sync --no-follow-symlinks --exclude="*" '
f'{includes} {base_dir_path} '
- f's3://{self.name} '
+ f's3://{self.name}{sub_path} '
f'--endpoint {endpoint_url} '
f'--profile={cloudflare.R2_PROFILE_NAME}')
return sync_command
@@ -2956,11 +3283,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
])
endpoint_url = cloudflare.create_endpoint()
src_dir_path = shlex.quote(src_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = ('AWS_SHARED_CREDENTIALS_FILE='
f'{cloudflare.R2_CREDENTIALS_PATH} '
f'aws s3 sync --no-follow-symlinks {excludes} '
f'{src_dir_path} '
- f's3://{self.name}/{dest_dir_name} '
+ f's3://{self.name}{sub_path}/{dest_dir_name} '
f'--endpoint {endpoint_url} '
f'--profile={cloudflare.R2_PROFILE_NAME}')
return sync_command
@@ -2971,17 +3300,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
else:
source_message = source_path_list[0]
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> r2://{self.name}/'
with rich_utils.safe_status(
- ux_utils.spinner_message(
- f'Syncing {source_message} -> r2://{self.name}/')):
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
data_utils.parallel_upload(
source_path_list,
get_file_sync_command,
get_dir_sync_command,
+ log_path,
self.name,
self._ACCESS_DENIED_MESSAGE,
create_dirs=create_dirs,
max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS)
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
def _transfer_to_r2(self) -> None:
assert isinstance(self.source, str), self.source
@@ -3084,11 +3420,9 @@ def mount_command(self, mount_path: str) -> str:
endpoint_url = cloudflare.create_endpoint()
r2_credential_path = cloudflare.R2_CREDENTIALS_PATH
r2_profile_name = cloudflare.R2_PROFILE_NAME
- mount_cmd = mounting_utils.get_r2_mount_cmd(r2_credential_path,
- r2_profile_name,
- endpoint_url,
- self.bucket.name,
- mount_path)
+ mount_cmd = mounting_utils.get_r2_mount_cmd(
+ r2_credential_path, r2_profile_name, endpoint_url, self.bucket.name,
+ mount_path, self._bucket_sub_path)
return mounting_utils.get_mounting_command(mount_path, install_cmd,
mount_cmd)
@@ -3121,6 +3455,43 @@ def _create_r2_bucket(self,
f'{self.name} but failed.') from e
return cloudflare.resource('s3').Bucket(bucket_name)
+ def _execute_r2_remove_command(self, command: str, bucket_name: str,
+ hint_operating: str,
+ hint_failed: str) -> bool:
+ try:
+ with rich_utils.safe_status(
+ ux_utils.spinner_message(hint_operating)):
+ # Run through the shell: the command embeds an env assignment.
+ subprocess.check_output(command,
+ stderr=subprocess.STDOUT,
+ shell=True)
+ except subprocess.CalledProcessError as e:
+ if 'NoSuchBucket' in e.output.decode('utf-8'):
+ logger.debug(
+ _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format(
+ bucket_name=bucket_name))
+ return False
+ else:
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketDeleteError(
+ f'{hint_failed}'
+ f'Detailed error: {e.output}')
+ return True
+
+ def _delete_r2_bucket_sub_path(self, bucket_name: str,
+ sub_path: str) -> bool:
+ """Deletes the sub path from the bucket."""
+ endpoint_url = cloudflare.create_endpoint()
+ remove_command = (
+ f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} '
+ f'aws s3 rm s3://{bucket_name}/{sub_path}/ --recursive '
+ f'--endpoint {endpoint_url} '
+ f'--profile={cloudflare.R2_PROFILE_NAME}')
+ return self._execute_r2_remove_command(
+ remove_command, bucket_name,
+ f'Removing objects from R2 bucket {bucket_name}/{sub_path}',
+ f'Failed to remove objects from R2 bucket {bucket_name}/{sub_path}.'
+ )
+
def _delete_r2_bucket(self, bucket_name: str) -> bool:
"""Deletes R2 bucket, including all objects in bucket
@@ -3143,24 +3514,12 @@ def _delete_r2_bucket(self, bucket_name: str) -> bool:
f'aws s3 rb s3://{bucket_name} --force '
f'--endpoint {endpoint_url} '
f'--profile={cloudflare.R2_PROFILE_NAME}')
- try:
- with rich_utils.safe_status(
- ux_utils.spinner_message(
- f'Deleting R2 bucket {bucket_name}')):
- subprocess.check_output(remove_command,
- stderr=subprocess.STDOUT,
- shell=True)
- except subprocess.CalledProcessError as e:
- if 'NoSuchBucket' in e.output.decode('utf-8'):
- logger.debug(
- _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format(
- bucket_name=bucket_name))
- return False
- else:
- with ux_utils.print_exception_no_traceback():
- raise exceptions.StorageBucketDeleteError(
- f'Failed to delete R2 bucket {bucket_name}.'
- f'Detailed error: {e.output}')
+
+ success = self._execute_r2_remove_command(
+ remove_command, bucket_name, f'Deleting R2 bucket {bucket_name}',
+ f'Failed to delete R2 bucket {bucket_name}.')
+ if not success:
+ return False
# Wait until bucket deletion propagates on AWS servers
while data_utils.verify_r2_bucket(bucket_name):
@@ -3179,11 +3538,12 @@ def __init__(self,
source: str,
region: Optional[str] = 'us-east',
is_sky_managed: Optional[bool] = None,
- sync_on_reconstruction: bool = True):
+ sync_on_reconstruction: bool = True,
+ _bucket_sub_path: Optional[str] = None):
self.client: 'storage.Client'
self.bucket: 'StorageHandle'
super().__init__(name, source, region, is_sky_managed,
- sync_on_reconstruction)
+ sync_on_reconstruction, _bucket_sub_path)
self.bucket_rclone_profile = \
Rclone.generate_rclone_bucket_profile_name(
self.name, Rclone.RcloneClouds.IBM)
@@ -3328,10 +3688,22 @@ def upload(self):
f'Upload failed for store {self.name}') from e
def delete(self) -> None:
+ if self._bucket_sub_path is not None and not self.is_sky_managed:
+ return self._delete_sub_path()
+
self._delete_cos_bucket()
logger.info(f'{colorama.Fore.GREEN}Deleted COS bucket {self.name}.'
f'{colorama.Style.RESET_ALL}')
+ def _delete_sub_path(self) -> None:
+ assert self._bucket_sub_path is not None, 'bucket_sub_path is not set'
+ bucket = self.s3_resource.Bucket(self.name)
+ try:
+ self._delete_cos_bucket_objects(bucket, self._bucket_sub_path + '/')
+ except ibm.ibm_botocore.exceptions.ClientError as e:
+ if e.__class__.__name__ == 'NoSuchBucket':
+ logger.debug('bucket already removed')
+
def get_handle(self) -> StorageHandle:
return self.s3_resource.Bucket(self.name)
@@ -3372,10 +3744,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
# .git directory is excluded from the sync
# wrapping src_dir_path with "" to support path with spaces
src_dir_path = shlex.quote(src_dir_path)
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
sync_command = (
'rclone copy --exclude ".git/*" '
f'{src_dir_path} '
- f'{self.bucket_rclone_profile}:{self.name}/{dest_dir_name}')
+ f'{self.bucket_rclone_profile}:{self.name}{sub_path}'
+ f'/{dest_dir_name}')
return sync_command
def get_file_sync_command(base_dir_path, file_names) -> str:
@@ -3401,9 +3776,12 @@ def get_file_sync_command(base_dir_path, file_names) -> str:
for file_name in file_names
])
base_dir_path = shlex.quote(base_dir_path)
- sync_command = ('rclone copy '
- f'{includes} {base_dir_path} '
- f'{self.bucket_rclone_profile}:{self.name}')
+ sub_path = (f'/{self._bucket_sub_path}'
+ if self._bucket_sub_path else '')
+ sync_command = (
+ 'rclone copy '
+ f'{includes} {base_dir_path} '
+ f'{self.bucket_rclone_profile}:{self.name}{sub_path}')
return sync_command
# Generate message for upload
@@ -3412,17 +3790,24 @@ def get_file_sync_command(base_dir_path, file_names) -> str:
else:
source_message = source_path_list[0]
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> cos://{self.region}/{self.name}/'
with rich_utils.safe_status(
- ux_utils.spinner_message(f'Syncing {source_message} -> '
- f'cos://{self.region}/{self.name}/')):
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
data_utils.parallel_upload(
source_path_list,
get_file_sync_command,
get_dir_sync_command,
+ log_path,
self.name,
self._ACCESS_DENIED_MESSAGE,
create_dirs=create_dirs,
max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS)
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
def _get_bucket(self) -> Tuple[StorageHandle, bool]:
"""returns IBM COS bucket object if exists, otherwise creates it.
@@ -3481,6 +3866,7 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]:
Rclone.RcloneClouds.IBM,
self.region, # type: ignore
)
+
if not bucket_region and self.sync_on_reconstruction:
# bucket doesn't exist
return self._create_cos_bucket(self.name, self.region), True
@@ -3527,7 +3913,8 @@ def mount_command(self, mount_path: str) -> str:
Rclone.RCLONE_CONFIG_PATH,
self.bucket_rclone_profile,
self.bucket.name,
- mount_path)
+ mount_path,
+ self._bucket_sub_path)
return mounting_utils.get_mounting_command(mount_path, install_cmd,
mount_cmd)
@@ -3565,18 +3952,442 @@ def _create_cos_bucket(self,
return self.bucket
- def _delete_cos_bucket(self):
- bucket = self.s3_resource.Bucket(self.name)
- try:
- bucket_versioning = self.s3_resource.BucketVersioning(self.name)
- if bucket_versioning.status == 'Enabled':
+ def _delete_cos_bucket_objects(self,
+ bucket: Any,
+ prefix: Optional[str] = None):
+ bucket_versioning = self.s3_resource.BucketVersioning(bucket.name)
+ if bucket_versioning.status == 'Enabled':
+ if prefix is not None:
+ res = list(
+ bucket.object_versions.filter(Prefix=prefix).delete())
+ else:
res = list(bucket.object_versions.delete())
+ else:
+ if prefix is not None:
+ res = list(bucket.objects.filter(Prefix=prefix).delete())
else:
res = list(bucket.objects.delete())
- logger.debug(f'Deleted bucket\'s content:\n{res}')
+ logger.debug(f'Deleted bucket\'s content:\n{res}, prefix: {prefix}')
+
+ def _delete_cos_bucket(self):
+ bucket = self.s3_resource.Bucket(self.name)
+ try:
+ self._delete_cos_bucket_objects(bucket)
bucket.delete()
bucket.wait_until_not_exists()
except ibm.ibm_botocore.exceptions.ClientError as e:
if e.__class__.__name__ == 'NoSuchBucket':
logger.debug('bucket already removed')
Rclone.delete_rclone_bucket_profile(self.name, Rclone.RcloneClouds.IBM)
+
+
+class OciStore(AbstractStore):
+ """OciStore inherits from Storage Object and represents the backend
+ for OCI buckets.
+ """
+
+ _ACCESS_DENIED_MESSAGE = 'AccessDeniedException'
+
+ def __init__(self,
+ name: str,
+ source: str,
+ region: Optional[str] = None,
+ is_sky_managed: Optional[bool] = None,
+ sync_on_reconstruction: Optional[bool] = True,
+ _bucket_sub_path: Optional[str] = None):
+ self.client: Any
+ self.bucket: StorageHandle
+ self.oci_config_file: str
+ self.config_profile: str
+ self.compartment: str
+ self.namespace: str
+
+ # Bucket region should be consistent with the OCI config file
+ region = oci.get_oci_config()['region']
+
+ super().__init__(name, source, region, is_sky_managed,
+ sync_on_reconstruction, _bucket_sub_path)
+ # TODO(zpoint): add _bucket_sub_path to the sync/mount/delete commands
+
+ def _validate(self):
+ if self.source is not None and isinstance(self.source, str):
+ if self.source.startswith('oci://'):
+ assert self.name == data_utils.split_oci_path(self.source)[0], (
+ 'OCI Bucket is specified as path, the name should be '
+ 'the same as OCI bucket.')
+ elif not re.search(r'^\w+://', self.source):
+ # Treat it as local path.
+ pass
+ else:
+ raise NotImplementedError(
+ f'Moving data from {self.source} to OCI is not supported.')
+
+ # Validate name
+ self.name = self.validate_name(self.name)
+ # Check if the storage is enabled
+ if not _is_storage_cloud_enabled(str(clouds.OCI())):
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.ResourcesUnavailableError(
+ 'Storage \'store: oci\' specified, but ' \
+ 'OCI access is disabled. To fix, enable '\
+ 'OCI by running `sky check`. '\
+ 'More info: https://docs.skypilot.co/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
+ )
+
+ @classmethod
+ def validate_name(cls, name) -> str:
+ """Validates the name of the OCI store.
+
+ Source for rules: https://docs.oracle.com/en-us/iaas/Content/Object/Tasks/managingbuckets.htm#Managing_Buckets # pylint: disable=line-too-long
+ """
+
+ def _raise_no_traceback_name_error(err_str):
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageNameError(err_str)
+
+ if name is not None and isinstance(name, str):
+ # Check for overall length
+ if not 1 <= len(name) <= 256:
+ _raise_no_traceback_name_error(
+ f'Invalid store name: name {name} must contain 1-256 '
+ 'characters.')
+
+ # Check for valid characters and start/end with a number or letter
+ pattern = r'^[A-Za-z0-9-._]+$'
+ if not re.match(pattern, name):
+ _raise_no_traceback_name_error(
+ f'Invalid store name: name {name} can only contain '
+ 'upper or lower case letters, numeric characters, hyphens '
+ '(-), underscores (_), and dots (.). Spaces are not '
+ 'allowed. Names must start and end with a number or '
+ 'letter.')
+ else:
+ _raise_no_traceback_name_error('Store name must be specified.')
+ return name
+
+ def initialize(self):
+ """Initializes the OCI store object on the cloud.
+
+ Initialization involves fetching bucket if exists, or creating it if
+ it does not.
+
+ Raises:
+ StorageBucketCreateError: If bucket creation fails
+ StorageBucketGetError: If fetching existing bucket fails
+ StorageInitError: If general initialization fails.
+ """
+ # pylint: disable=import-outside-toplevel
+ from sky.clouds.utils import oci_utils
+ from sky.provision.oci.query_utils import query_helper
+
+ self.oci_config_file = oci.get_config_file()
+ self.config_profile = oci_utils.oci_config.get_profile()
+
+ ## pylint: disable=line-too-long
+ # What's a compartment? See https://docs.oracle.com/en/cloud/foundation/cloud_architecture/governance/compartments.html
+ self.compartment = query_helper.find_compartment(self.region)
+ self.client = oci.get_object_storage_client(region=self.region,
+ profile=self.config_profile)
+ self.namespace = self.client.get_namespace(
+ compartment_id=oci.get_oci_config()['tenancy']).data
+
+ self.bucket, is_new_bucket = self._get_bucket()
+ if self.is_sky_managed is None:
+ # If is_sky_managed is not specified, then this is a new storage
+ # object (i.e., did not exist in global_user_state) and we should
+ # set the is_sky_managed property.
+ # If is_sky_managed is specified, then we take no action.
+ self.is_sky_managed = is_new_bucket
+
+ def upload(self):
+ """Uploads source to store bucket.
+
+ Upload must be called by the Storage handler - it is not called on
+ Store initialization.
+
+ Raises:
+ StorageUploadError: if upload fails.
+ """
+ try:
+ if isinstance(self.source, list):
+ self.batch_oci_rsync(self.source, create_dirs=True)
+ elif self.source is not None:
+ if self.source.startswith('oci://'):
+ pass
+ else:
+ self.batch_oci_rsync([self.source])
+ except exceptions.StorageUploadError:
+ raise
+ except Exception as e:
+ raise exceptions.StorageUploadError(
+ f'Upload failed for store {self.name}') from e
+
+ def delete(self) -> None:
+ deleted_by_skypilot = self._delete_oci_bucket(self.name)
+ if deleted_by_skypilot:
+ msg_str = f'Deleted OCI bucket {self.name}.'
+ else:
+ msg_str = (f'OCI bucket {self.name} may have been deleted '
+ f'externally. Removing from local state.')
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}'
+ f'{colorama.Style.RESET_ALL}')
+
+ def get_handle(self) -> StorageHandle:
+ return self.client.get_bucket(namespace_name=self.namespace,
+ bucket_name=self.name).data
+
+ def batch_oci_rsync(self,
+ source_path_list: List[Path],
+ create_dirs: bool = False) -> None:
+ """Invokes oci sync to batch upload a list of local paths to Bucket
+
+ Use OCI bulk operation to batch process the file upload
+
+ Args:
+ source_path_list: List of paths to local files or directories
+ create_dirs: If the local_path is a directory and this is set to
+ False, the contents of the directory are directly uploaded to
+ root of the bucket. If the local_path is a directory and this is
+ set to True, the directory is created in the bucket root and
+ contents are uploaded to it.
+ """
+
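+ # The two generators below build `oci os object bulk-upload` commands;
+ # data_utils.parallel_upload executes them for file lists and
+ # directories respectively.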
+ @oci.with_oci_env
+ def get_file_sync_command(base_dir_path, file_names):
+ includes = ' '.join(
+ [f'--include "{file_name}"' for file_name in file_names])
+ sync_command = (
+ 'oci os object bulk-upload --no-follow-symlinks --overwrite '
+ f'--bucket-name {self.name} --namespace-name {self.namespace} '
+ f'--src-dir "{base_dir_path}" {includes}')
+
+ return sync_command
+
+ @oci.with_oci_env
+ def get_dir_sync_command(src_dir_path, dest_dir_name):
+ if dest_dir_name and not str(dest_dir_name).endswith('/'):
+ dest_dir_name = f'{dest_dir_name}/'
+
+ excluded_list = storage_utils.get_excluded_files(src_dir_path)
+ excluded_list.append('.git/*')
+ excludes = ' '.join([
+ f'--exclude {shlex.quote(file_name)}'
+ for file_name in excluded_list
+ ])
+
+ # we exclude .git directory from the sync
+ sync_command = (
+ 'oci os object bulk-upload --no-follow-symlinks --overwrite '
+ f'--bucket-name {self.name} --namespace-name {self.namespace} '
+ f'--object-prefix "{dest_dir_name}" --src-dir "{src_dir_path}" '
+ f'{excludes} ')
+
+ return sync_command
+
+ # Generate message for upload
+ if len(source_path_list) > 1:
+ source_message = f'{len(source_path_list)} paths'
+ else:
+ source_message = source_path_list[0]
+
+ log_path = sky_logging.generate_tmp_logging_file_path(
+ _STORAGE_LOG_FILE_NAME)
+ sync_path = f'{source_message} -> oci://{self.name}/'
+ with rich_utils.safe_status(
+ ux_utils.spinner_message(f'Syncing {sync_path}',
+ log_path=log_path)):
+ data_utils.parallel_upload(
+ source_path_list=source_path_list,
+ filesync_command_generator=get_file_sync_command,
+ dirsync_command_generator=get_dir_sync_command,
+ log_path=log_path,
+ bucket_name=self.name,
+ access_denied_message=self._ACCESS_DENIED_MESSAGE,
+ create_dirs=create_dirs,
+ max_concurrent_uploads=1)
+
+ logger.info(
+ ux_utils.finishing_message(f'Storage synced: {sync_path}',
+ log_path))
+
+ def _get_bucket(self) -> Tuple[StorageHandle, bool]:
+ """Obtains the OCI bucket.
+ If the bucket exists, this method will connect to the bucket.
+
+ If the bucket does not exist, there are three cases:
+ 1) Raise an error if the bucket source starts with oci://
+ 2) Return None if bucket has been externally deleted and
+ sync_on_reconstruction is False
+ 3) Create and return a new bucket otherwise
+
+ Return tuple (Bucket, Boolean): The first item is the bucket
+ json payload from the OCI API call; the second item indicates
+ whether this is a newly created bucket (True) or an existing
+ bucket (False).
+
+ Raises:
+ StorageBucketCreateError: If creating the bucket fails
+ StorageBucketGetError: If fetching a bucket fails
+ """
+ try:
+ get_bucket_response = self.client.get_bucket(
+ namespace_name=self.namespace, bucket_name=self.name)
+ bucket = get_bucket_response.data
+ return bucket, False
+ except oci.service_exception() as e:
+ if e.status == 404: # Not Found
+ if isinstance(self.source,
+ str) and self.source.startswith('oci://'):
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketGetError(
+ 'Attempted to connect to a non-existent bucket: '
+ f'{self.source}') from e
+ else:
+ # If bucket cannot be found (i.e., does not exist), it is
+ # to be created by Sky. However, creation is skipped if
+ # Store object is being reconstructed for deletion.
+ if self.sync_on_reconstruction:
+ bucket = self._create_oci_bucket(self.name)
+ return bucket, True
+ else:
+ return None, False
+ elif e.status == 401: # Unauthorized
+ # AccessDenied error for buckets that are private and not
+ # owned by user.
+ command = (
+ f'oci os object list --namespace-name {self.namespace} '
+ f'--bucket-name {self.name}')
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketGetError(
+ _BUCKET_FAIL_TO_CONNECT_MESSAGE.format(name=self.name) +
+ f' To debug, consider running `{command}`.') from e
+ else:
+ # Unknown / unexpected error. This can happen when the Object
+ # Storage service itself is not functioning normally (e.g., a
+ # maintenance event causes an internal server error or request
+ # timeout).
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketGetError(
+ f'Failed to connect to OCI bucket {self.name}') from e
+
+ def mount_command(self, mount_path: str) -> str:
+ """Returns the command to mount the bucket to the mount_path.
+
+ Uses Rclone to mount the bucket.
+
+ Args:
+ mount_path: str; Path to mount the bucket to.
+ """
+ install_cmd = mounting_utils.get_rclone_install_cmd()
+ mount_cmd = mounting_utils.get_oci_mount_cmd(
+ mount_path=mount_path,
+ store_name=self.name,
+ region=str(self.region),
+ namespace=self.namespace,
+ compartment=self.bucket.compartment_id,
+ config_file=self.oci_config_file,
+ config_profile=self.config_profile)
+ version_check_cmd = mounting_utils.get_rclone_version_check_cmd()
+
+ return mounting_utils.get_mounting_command(mount_path, install_cmd,
+ mount_cmd, version_check_cmd)
+
+ def _download_file(self, remote_path: str, local_path: str) -> None:
+ """Downloads file from remote to local on OCI bucket
+
+ Args:
+ remote_path: str; Remote path on OCI bucket
+ local_path: str; Local path on user's device
+ """
+ if remote_path.startswith(f'/{self.name}'):
+ # If the remote path is /bucket_name, we need to
+ # remove the leading /
+ remote_path = remote_path.lstrip('/')
+
+ filename = os.path.basename(remote_path)
+ if not local_path.endswith(filename):
+ local_path = os.path.join(local_path, filename)
+
+ @oci.with_oci_env
+ def get_file_download_command(remote_path, local_path):
+ download_command = (f'oci os object get --bucket-name {self.name} '
+ f'--namespace-name {self.namespace} '
+ f'--name {remote_path} --file {local_path}')
+
+ return download_command
+
+ download_command = get_file_download_command(remote_path, local_path)
+
+ try:
+ with rich_utils.safe_status(
+ f'[bold cyan]Downloading: {remote_path} -> {local_path}[/]'
+ ):
+ subprocess.check_output(download_command,
+ stderr=subprocess.STDOUT,
+ shell=True)
+ except subprocess.CalledProcessError as e:
+ logger.error(f'Download failed: {remote_path} -> {local_path}.\n'
+ f'Detail errors: {e.output}')
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketDeleteError(
+ f'Failed to download file {self.name}:{remote_path}.') from e
+
+ def _create_oci_bucket(self, bucket_name: str) -> StorageHandle:
+ """Creates OCI bucket with specific name in specific region
+
+ Args:
+ bucket_name: str; Name of the bucket (created in self.region)
+ """
+ logger.debug(f'_create_oci_bucket: {bucket_name}')
+ try:
+ create_bucket_response = self.client.create_bucket(
+ namespace_name=self.namespace,
+ create_bucket_details=oci.oci.object_storage.models.
+ CreateBucketDetails(
+ name=bucket_name,
+ compartment_id=self.compartment,
+ ))
+ bucket = create_bucket_response.data
+ return bucket
+ except oci.service_exception() as e:
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketCreateError(
+ f'Failed to create OCI bucket: {self.name}') from e
+
+ def _delete_oci_bucket(self, bucket_name: str) -> bool:
+ """Deletes OCI bucket, including all objects in bucket
+
+ Args:
+ bucket_name: str; Name of bucket
+
+ Returns:
+ bool; True if bucket was deleted, False if it was deleted externally.
+ """
+ logger.debug(f'_delete_oci_bucket: {bucket_name}')
+
+ @oci.with_oci_env
+ def get_bucket_delete_command(bucket_name):
+ remove_command = (f'oci os bucket delete --bucket-name '
+ f'{bucket_name} --empty --force')
+
+ return remove_command
+
+ remove_command = get_bucket_delete_command(bucket_name)
+
+ try:
+ with rich_utils.safe_status(
+ f'[bold cyan]Deleting OCI bucket {bucket_name}[/]'):
+ subprocess.check_output(remove_command.split(' '),
+ stderr=subprocess.STDOUT)
+ except subprocess.CalledProcessError as e:
+ if 'BucketNotFound' in e.output.decode('utf-8'):
+ logger.debug(
+ _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format(
+ bucket_name=bucket_name))
+ return False
+ else:
+ logger.error(e.output)
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.StorageBucketDeleteError(
+ f'Failed to delete OCI bucket {bucket_name}.')
+ return True
diff --git a/sky/jobs/api/core.py b/sky/jobs/api/core.py
index b966ae6fafd..ac57a65cba0 100644
--- a/sky/jobs/api/core.py
+++ b/sky/jobs/api/core.py
@@ -49,6 +49,7 @@ def launch(
name: Optional[str] = None,
stream_logs: bool = True,
retry_until_up: bool = False,
+ # TODO(cooperc): remove fast arg before 0.8.0
fast: bool = False,
) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
@@ -61,9 +62,8 @@ def launch(
managed job.
name: Name of the managed job.
detach_run: Whether to detach the run.
- fast: Whether to use sky.launch(fast=True) for the jobs controller. If
- True, the SkyPilot wheel and the cloud credentials may not be updated
- on the jobs controller.
+ fast: [Deprecated] Has no effect and will be removed soon; fast
+ mode is now always used, as it is safe.
Returns:
- Job ID for the managed job
@@ -372,8 +372,8 @@ def cancel(name: Optional[str] = None,
stopped_message='All managed jobs should have finished.')
job_id_str = ','.join(map(str, job_ids))
- if sum([len(job_ids) > 0, name is not None, all]) != 1:
- argument_str = f'job_ids={job_id_str}' if len(job_ids) > 0 else ''
+ if sum([bool(job_ids), name is not None, all]) != 1:
+ argument_str = f'job_ids={job_id_str}' if job_ids else ''
argument_str += f' name={name}' if name is not None else ''
argument_str += ' all' if all else ''
with ux_utils.print_exception_no_traceback():
diff --git a/sky/jobs/state.py b/sky/jobs/state.py
index 9a5ab4b3cad..31dcfcfd5eb 100644
--- a/sky/jobs/state.py
+++ b/sky/jobs/state.py
@@ -591,7 +591,7 @@ def get_latest_task_id_status(
If the job_id does not exist, (None, None) will be returned.
"""
id_statuses = _get_all_task_ids_statuses(job_id)
- if len(id_statuses) == 0:
+ if not id_statuses:
return None, None
task_id, status = id_statuses[-1]
for task_id, status in id_statuses:
@@ -617,7 +617,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
WHERE spot_job_id=(?)
ORDER BY task_id ASC""", (job_id,)).fetchall()
reason = [r[0] for r in reason if r[0] is not None]
- if len(reason) == 0:
+ if not reason:
return None
return reason[0]
diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index fceb00d1147..e8cbf7f3d37 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -229,11 +229,11 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
if job_ids is None:
job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None)
job_ids = list(set(job_ids))
- if len(job_ids) == 0:
+ if not job_ids:
return 'No job to cancel.'
job_id_str = ', '.join(map(str, job_ids))
logger.info(f'Cancelling jobs {job_id_str}.')
- cancelled_job_ids = []
+ cancelled_job_ids: List[int] = []
for job_id in job_ids:
# Check the status of the managed job status. If it is in
# terminal state, we can safely skip it.
@@ -263,7 +263,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
shutil.copy(str(signal_file), str(legacy_signal_file))
cancelled_job_ids.append(job_id)
- if len(cancelled_job_ids) == 0:
+ if not cancelled_job_ids:
return 'No job to cancel.'
identity_str = f'Job with ID {cancelled_job_ids[0]} is'
if len(cancelled_job_ids) > 1:
@@ -276,7 +276,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
def cancel_job_by_name(job_name: str) -> str:
"""Cancel a job by name."""
job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name)
- if len(job_ids) == 0:
+ if not job_ids:
return f'No running job found with name {job_name!r}.'
if len(job_ids) > 1:
return (f'{colorama.Fore.RED}Multiple running jobs found '
@@ -511,7 +511,7 @@ def stream_logs(job_id: Optional[int],
for job in managed_jobs
if job['job_name'] == job_name
}
- if len(managed_job_ids) == 0:
+ if not managed_job_ids:
return f'No managed job found with name {job_name!r}.'
if len(managed_job_ids) > 1:
job_ids_str = ', '.join(
@@ -537,7 +537,7 @@ def stream_logs(job_id: Optional[int],
if job_id is None:
assert job_name is not None
job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name)
- if len(job_ids) == 0:
+ if not job_ids:
return f'No running managed job found with name {job_name!r}.'
if len(job_ids) > 1:
raise ValueError(
diff --git a/sky/optimizer.py b/sky/optimizer.py
index b61492860c9..95bf6e70e98 100644
--- a/sky/optimizer.py
+++ b/sky/optimizer.py
@@ -181,7 +181,7 @@ def _remove_dummy_source_sink_nodes(dag: 'dag_lib.Dag'):
"""Removes special Source and Sink nodes."""
source = [t for t in dag.tasks if t.name == _DUMMY_SOURCE_NAME]
sink = [t for t in dag.tasks if t.name == _DUMMY_SINK_NAME]
- if len(source) == len(sink) == 0:
+ if not source and not sink:
return
assert len(source) == len(sink) == 1, dag.tasks
dag.remove(source[0])
@@ -1294,7 +1294,7 @@ def _fill_in_launchable_resources(
resources, num_nodes=task.num_nodes)
if feasible_resources.hint is not None:
hints[cloud] = feasible_resources.hint
- if len(feasible_resources.resources_list) > 0:
+ if feasible_resources.resources_list:
# Assume feasible_resources is sorted by prices. Guaranteed by
# the implementation of get_feasible_launchable_resources and
# the underlying service_catalog filtering
@@ -1306,7 +1306,7 @@ def _fill_in_launchable_resources(
else:
all_fuzzy_candidates.update(
feasible_resources.fuzzy_candidate_list)
- if len(launchable[resources]) == 0:
+ if not launchable[resources]:
clouds_str = str(clouds_list) if len(clouds_list) > 1 else str(
clouds_list[0])
num_node_str = ''
diff --git a/sky/provision/aws/config.py b/sky/provision/aws/config.py
index 6a8c77eafed..ffa87c3a011 100644
--- a/sky/provision/aws/config.py
+++ b/sky/provision/aws/config.py
@@ -279,7 +279,7 @@ def _has_igw_route(route_tables):
logger.debug(f'subnet {subnet_id} route tables: {route_tables}')
if _has_igw_route(route_tables):
return True
- if len(route_tables) > 0:
+ if route_tables:
return False
# Handle the case that a "main" route table is implicitly associated with
@@ -454,7 +454,7 @@ def _vpc_id_from_security_group_ids(ec2, sg_ids: List[str]) -> Any:
no_sg_msg = ('Failed to detect a security group with id equal to any of '
'the configured SecurityGroupIds.')
- assert len(vpc_ids) > 0, no_sg_msg
+ assert vpc_ids, no_sg_msg
return vpc_ids[0]
diff --git a/sky/provision/do/__init__.py b/sky/provision/do/__init__.py
new file mode 100644
index 00000000000..75502d3cb05
--- /dev/null
+++ b/sky/provision/do/__init__.py
@@ -0,0 +1,11 @@
+"""DO provisioner for SkyPilot."""
+
+from sky.provision.do.config import bootstrap_instances
+from sky.provision.do.instance import cleanup_ports
+from sky.provision.do.instance import get_cluster_info
+from sky.provision.do.instance import open_ports
+from sky.provision.do.instance import query_instances
+from sky.provision.do.instance import run_instances
+from sky.provision.do.instance import stop_instances
+from sky.provision.do.instance import terminate_instances
+from sky.provision.do.instance import wait_instances
diff --git a/sky/provision/do/config.py b/sky/provision/do/config.py
new file mode 100644
index 00000000000..0b10f7f7698
--- /dev/null
+++ b/sky/provision/do/config.py
@@ -0,0 +1,14 @@
+"""Paperspace configuration bootstrapping."""
+
+from sky import sky_logging
+from sky.provision import common
+
+logger = sky_logging.init_logger(__name__)
+
+
+def bootstrap_instances(
+ region: str, cluster_name: str,
+ config: common.ProvisionConfig) -> common.ProvisionConfig:
+ """Bootstraps instances for the given cluster."""
+ del region, cluster_name
+ return config
diff --git a/sky/provision/do/constants.py b/sky/provision/do/constants.py
new file mode 100644
index 00000000000..0010646f873
--- /dev/null
+++ b/sky/provision/do/constants.py
@@ -0,0 +1,10 @@
+"""DO cloud constants
+"""
+
+POLL_INTERVAL = 5
+WAIT_DELETE_VOLUMES = 5
+
+GPU_IMAGES = {
+ 'gpu-h100x1-80gb': 'gpu-h100x1-base',
+ 'gpu-h100x8-640gb': 'gpu-h100x8-base',
+}
diff --git a/sky/provision/do/instance.py b/sky/provision/do/instance.py
new file mode 100644
index 00000000000..763b7644890
--- /dev/null
+++ b/sky/provision/do/instance.py
@@ -0,0 +1,287 @@
+"""DigitalOcean instance provisioning."""
+
+import time
+from typing import Any, Dict, List, Optional
+import uuid
+
+from sky import sky_logging
+from sky.provision import common
+from sky.provision.do import constants
+from sky.provision.do import utils
+from sky.utils import status_lib
+
+# The maximum number of times to poll for the status of an operation
+MAX_POLLS = 60 // constants.POLL_INTERVAL
+# Stopping instances can take several minutes, so we increase the timeout
+MAX_POLLS_FOR_UP_OR_STOP = MAX_POLLS * 8
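+# With POLL_INTERVAL = 5 seconds, this is 12 polls (~1 minute) and 96 polls
+# (~8 minutes), respectively.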
+
+logger = sky_logging.init_logger(__name__)
+
+
+def _get_head_instance(
+ instances: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ for instance_name, instance_meta in instances.items():
+ if instance_name.endswith('-head'):
+ return instance_meta
+ return None
+
+
+def run_instances(region: str, cluster_name_on_cloud: str,
+ config: common.ProvisionConfig) -> common.ProvisionRecord:
+ """Runs instances for the given cluster."""
+
+ pending_status = ['new']
+ newly_started_instances = utils.filter_instances(cluster_name_on_cloud,
+ pending_status + ['off'])
+ while True:
+ instances = utils.filter_instances(cluster_name_on_cloud,
+ pending_status)
+ if not instances:
+ break
+ instance_statuses = [
+ instance['status'] for instance in instances.values()
+ ]
+ logger.info(f'Waiting for {len(instances)} instances to be ready: '
+ f'{instance_statuses}')
+ time.sleep(constants.POLL_INTERVAL)
+
+ exist_instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=pending_status +
+ ['active', 'off'])
+ if len(exist_instances) > config.count:
+ raise RuntimeError(
+ f'Cluster {cluster_name_on_cloud} already has '
+ f'{len(exist_instances)} nodes, but {config.count} are required.')
+
+ stopped_instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=['off'])
+ for instance in stopped_instances.values():
+ utils.start_instance(instance)
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
+ instances = utils.filter_instances(cluster_name_on_cloud, ['off'])
+ if len(instances) == 0:
+ break
+ num_stopped_instances = len(stopped_instances)
+ num_restarted_instances = num_stopped_instances - len(instances)
+ logger.info(
+ f'Waiting for {num_restarted_instances}/{num_stopped_instances} '
+ 'stopped instances to be restarted.')
+ time.sleep(constants.POLL_INTERVAL)
+ else:
+ msg = ('run_instances: Failed to restart all '
+ 'instances, possibly due to a capacity issue.')
+ logger.warning(msg)
+ raise RuntimeError(msg)
+
+ exist_instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=['active'])
+ head_instance = _get_head_instance(exist_instances)
+ to_start_count = config.count - len(exist_instances)
+ if to_start_count < 0:
+ raise RuntimeError(
+ f'Cluster {cluster_name_on_cloud} already has '
+ f'{len(exist_instances)} nodes, but {config.count} are required.')
+ if to_start_count == 0:
+ if head_instance is None:
+ head_instance = list(exist_instances.values())[0]
+ utils.rename_instance(
+ head_instance,
+ f'{cluster_name_on_cloud}-{uuid.uuid4().hex[:4]}-head')
+ assert head_instance is not None, ('`head_instance` should not be None')
+ logger.info(f'Cluster {cluster_name_on_cloud} already has '
+ f'{len(exist_instances)} nodes, no need to start more.')
+ return common.ProvisionRecord(
+ provider_name='do',
+ cluster_name=cluster_name_on_cloud,
+ region=region,
+ zone=None,
+ head_instance_id=head_instance['name'],
+ resumed_instance_ids=list(newly_started_instances.keys()),
+ created_instance_ids=[],
+ )
+
+ created_instances: List[Dict[str, Any]] = []
+ for _ in range(to_start_count):
+ instance_type = 'head' if head_instance is None else 'worker'
+ instance = utils.create_instance(
+ region=region,
+ cluster_name_on_cloud=cluster_name_on_cloud,
+ instance_type=instance_type,
+ config=config)
+ logger.info(f'Launched instance {instance["name"]}.')
+ created_instances.append(instance)
+ if head_instance is None:
+ head_instance = instance
+
+ # Wait for instances to be ready.
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
+ instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=['active'])
+ logger.info('Waiting for instances to be ready: '
+ f'({len(instances)}/{config.count}).')
+ if len(instances) == config.count:
+ break
+
+ time.sleep(constants.POLL_INTERVAL)
+ else:
+ # Failed to launch config.count of instances after max retries
+ msg = 'run_instances: Failed to create the instances'
+ logger.warning(msg)
+ raise RuntimeError(msg)
+ assert head_instance is not None, 'head_instance should not be None'
+ return common.ProvisionRecord(
+ provider_name='do',
+ cluster_name=cluster_name_on_cloud,
+ region=region,
+ zone=None,
+ head_instance_id=head_instance['name'],
+ resumed_instance_ids=list(stopped_instances.keys()),
+ created_instance_ids=[
+ instance['name'] for instance in created_instances
+ ],
+ )
+
+
+def wait_instances(region: str, cluster_name_on_cloud: str,
+ state: Optional[status_lib.ClusterStatus]) -> None:
+ del region, cluster_name_on_cloud, state # unused
+ # We already wait for the ready state in `run_instances`, so no
+ # further waiting is needed here.
+
+
+def stop_instances(
+ cluster_name_on_cloud: str,
+ provider_config: Optional[Dict[str, Any]] = None,
+ worker_only: bool = False,
+) -> None:
+ del provider_config # unused
+ all_instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=None)
+ num_instances = len(all_instances)
+
+ # Request a stop on all instances
+ for instance_name, instance_meta in all_instances.items():
+ if worker_only and instance_name.endswith('-head'):
+ num_instances -= 1
+ continue
+ utils.stop_instance(instance_meta)
+
+ # Wait for instances to stop
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
+ all_instances = utils.filter_instances(cluster_name_on_cloud, ['off'])
+ if len(all_instances) >= num_instances:
+ break
+ time.sleep(constants.POLL_INTERVAL)
+ else:
+ raise RuntimeError(f'Maximum number of polls: '
+ f'{MAX_POLLS_FOR_UP_OR_STOP} reached. '
+ f'Instance {all_instances} is still not in '
+ 'STOPPED status.')
+
+
+def terminate_instances(
+ cluster_name_on_cloud: str,
+ provider_config: Optional[Dict[str, Any]] = None,
+ worker_only: bool = False,
+) -> None:
+ """See sky/provision/__init__.py"""
+ del provider_config # unused
+ instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=None)
+ for instance_name, instance_meta in instances.items():
+ logger.debug(f'Terminating instance {instance_name}')
+ if worker_only and instance_name.endswith('-head'):
+ continue
+ utils.down_instance(instance_meta)
+
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
+ instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=None)
+ if len(instances) == 0 or (len(instances) <= 1 and worker_only):
+ break
+ time.sleep(constants.POLL_INTERVAL)
+ else:
+ msg = ('Failed to delete all instances')
+ logger.warning(msg)
+ raise RuntimeError(msg)
+
+
+def get_cluster_info(
+ region: str,
+ cluster_name_on_cloud: str,
+ provider_config: Optional[Dict[str, Any]] = None,
+) -> common.ClusterInfo:
+ del region # unused
+ running_instances = utils.filter_instances(cluster_name_on_cloud,
+ ['active'])
+ instances: Dict[str, List[common.InstanceInfo]] = {}
+ head_instance: Optional[str] = None
+ for instance_name, instance_meta in running_instances.items():
+ if instance_name.endswith('-head'):
+ head_instance = instance_name
+ for net in instance_meta['networks']['v4']:
+ if net['type'] == 'public':
+ instance_ip = net['ip_address']
+ break
+ instances[instance_name] = [
+ common.InstanceInfo(
+ instance_id=instance_meta['name'],
+ internal_ip=instance_ip,
+ external_ip=instance_ip,
+ ssh_port=22,
+ tags={},
+ )
+ ]
+
+ assert head_instance is not None, 'no head instance found'
+ return common.ClusterInfo(
+ instances=instances,
+ head_instance_id=head_instance,
+ provider_name='do',
+ provider_config=provider_config,
+ )
+
+
+def query_instances(
+ cluster_name_on_cloud: str,
+ provider_config: Optional[Dict[str, Any]] = None,
+ non_terminated_only: bool = True,
+) -> Dict[str, Optional[status_lib.ClusterStatus]]:
+ """See sky/provision/__init__.py"""
+ # Terminated instances are not returned by the API, making the
+ # `non_terminated_only` argument moot.
+ del non_terminated_only
+ assert provider_config is not None, (cluster_name_on_cloud, provider_config)
+ instances = utils.filter_instances(cluster_name_on_cloud,
+ status_filters=None)
+
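+ # DigitalOcean droplet statuses ('new', 'archive', 'active', 'off')
+ # are mapped onto SkyPilot cluster statuses.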
+ status_map = {
+ 'new': status_lib.ClusterStatus.INIT,
+ 'archive': status_lib.ClusterStatus.INIT,
+ 'active': status_lib.ClusterStatus.UP,
+ 'off': status_lib.ClusterStatus.STOPPED,
+ }
+ statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}
+ for instance_meta in instances.values():
+ status = status_map[instance_meta['status']]
+ statuses[instance_meta['name']] = status
+ return statuses
+
+
+def open_ports(
+ cluster_name_on_cloud: str,
+ ports: List[str],
+ provider_config: Optional[Dict[str, Any]] = None,
+) -> None:
+ """See sky/provision/__init__.py"""
+ logger.debug(
+ f'Skip opening ports {ports} for DigitalOcean instances, as all '
+ 'ports are open by default.')
+ del cluster_name_on_cloud, provider_config, ports
+
+
+def cleanup_ports(
+ cluster_name_on_cloud: str,
+ ports: List[str],
+ provider_config: Optional[Dict[str, Any]] = None,
+) -> None:
+ del cluster_name_on_cloud, provider_config, ports
diff --git a/sky/provision/do/utils.py b/sky/provision/do/utils.py
new file mode 100644
index 00000000000..f706007134e
--- /dev/null
+++ b/sky/provision/do/utils.py
@@ -0,0 +1,301 @@
+"""DigitalOcean API client wrapper for SkyPilot.
+
+Example usage of `pydo` client library was mostly taken from here:
+https://github.com/digitalocean/pydo/blob/main/examples/poc_droplets_volumes_sshkeys.py
+"""
+
+import copy
+import os
+from typing import Any, Dict, List, Optional
+import urllib
+import uuid
+
+from sky import sky_logging
+from sky.adaptors import do
+from sky.provision import common
+from sky.provision import constants as provision_constants
+from sky.provision.do import constants
+from sky.utils import common_utils
+
+logger = sky_logging.init_logger(__name__)
+
+POSSIBLE_CREDENTIALS_PATHS = [
+ os.path.expanduser(
+ '~/Library/Application Support/doctl/config.yaml'), # OS X
+ os.path.expanduser(
+ os.path.join(os.getenv('XDG_CONFIG_HOME', '~/.config/'),
+ 'doctl/config.yaml')), # Linux
+]
+INITIAL_BACKOFF_SECONDS = 10
+MAX_BACKOFF_FACTOR = 10
+MAX_ATTEMPTS = 6
+SSH_KEY_NAME_ON_DO = f'sky-key-{common_utils.get_user_hash()}'
+
+CREDENTIALS_PATH = '~/.config/doctl/config.yaml'
+_client = None
+_ssh_key_id = None
+
+
+class DigitalOceanError(Exception):
+ pass
+
+
+def _init_client():
+ global _client, CREDENTIALS_PATH
+ assert _client is None
+ CREDENTIALS_PATH = None
+ credentials_found = 0
+ for path in POSSIBLE_CREDENTIALS_PATHS:
+ if os.path.exists(path):
+ CREDENTIALS_PATH = path
+ credentials_found += 1
+ logger.debug(f'DigitalOcean credential path found at {path}')
+ if credentials_found > 1:
+ logger.debug('more than 1 credential file found')
+ if CREDENTIALS_PATH is None:
+ raise DigitalOceanError(
+ 'no credentials file found from '
+ f'the following paths {POSSIBLE_CREDENTIALS_PATHS}')
+
+ # attempt default context
+ credentials = common_utils.read_yaml(CREDENTIALS_PATH)
+ default_token = credentials.get('access-token', None)
+ if default_token is not None:
+ try:
+ test_client = do.pydo.Client(token=default_token)
+ test_client.droplets.list()
+ logger.debug('trying `default` context')
+ _client = test_client
+ return _client
+ except do.exceptions().HttpResponseError:
+ pass
+
+ auth_contexts = credentials.get('auth-contexts', None)
+ if auth_contexts is not None:
+ for context, api_token in auth_contexts.items():
+ try:
+ test_client = do.pydo.Client(token=api_token)
+ test_client.droplets.list()
+ logger.debug(f'using {context} context')
+ _client = test_client
+ break
+ except do.exceptions().HttpResponseError:
+ continue
+ else:
+ raise DigitalOceanError(
+ 'no valid API tokens found; try '
+ 'setting a new API token with `doctl auth init`')
+ return _client
+
+
+def client():
+ global _client
+ if _client is None:
+ _client = _init_client()
+ return _client
+
+
+def ssh_key_id(public_key: str):
+ global _ssh_key_id
+ if _ssh_key_id is None:
+ page = 1
+ paginated = True
+ while paginated:
+ try:
+ resp = client().ssh_keys.list(per_page=50, page=page)
+ for ssh_key in resp['ssh_keys']:
+ if ssh_key['public_key'] == public_key:
+ _ssh_key_id = ssh_key
+ return _ssh_key_id
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: '
+ f'{err.error.message}') from err
+
+ pages = resp['links']
+ if 'pages' in pages and 'next' in pages['pages']:
+ pages = pages['pages']
+ parsed_url = urllib.parse.urlparse(pages['next'])
+ page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
+ else:
+ paginated = False
+
+ request = {
+ 'public_key': public_key,
+ 'name': SSH_KEY_NAME_ON_DO,
+ }
+ _ssh_key_id = client().ssh_keys.create(body=request)['ssh_key']
+ return _ssh_key_id
+
+
+def _create_volume(request: Dict[str, Any]) -> Dict[str, Any]:
+ try:
+ resp = client().volumes.create(body=request)
+ volume = resp['volume']
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+ else:
+ return volume
+
+
+def _create_droplet(request: Dict[str, Any]) -> Dict[str, Any]:
+ try:
+ resp = client().droplets.create(body=request)
+ droplet_id = resp['droplet']['id']
+
+ get_resp = client().droplets.get(droplet_id)
+ droplet = get_resp['droplet']
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+ return droplet
+
+
+def create_instance(region: str, cluster_name_on_cloud: str, instance_type: str,
+ config: common.ProvisionConfig) -> Dict[str, Any]:
+ """Creates a instance and mounts the requested block storage
+
+ Args:
+ region (str): instance region
+ cluster_name_on_cloud (str): cluster name on the cloud
+ instance_type (str): 'head' or 'worker'
+ config (common.ProvisionConfig): provisioner configuration
+
+ Returns:
+ Dict[str, Any]: instance metadata
+ """
+ # sort tags by key to support deterministic unit test stubbing
+ tags = dict(sorted(copy.deepcopy(config.tags).items()))
+ tags = {
+ 'Name': cluster_name_on_cloud,
+ provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
+ provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
+ **tags
+ }
+ tags = [f'{key}:{value}' for key, value in tags.items()]
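+ # DO tags are flat strings, e.g. (hypothetical) 'Name:my-cluster'.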
+ default_image = constants.GPU_IMAGES.get(
+ config.node_config['InstanceType'],
+ 'gpu-h100x1-base',
+ )
+ image_id = config.node_config['ImageId']
+ image_id = image_id if image_id is not None else default_image
+ instance_name = (f'{cluster_name_on_cloud}-'
+ f'{uuid.uuid4().hex[:4]}-{instance_type}')
+ instance_request = {
+ 'name': instance_name,
+ 'region': region,
+ 'size': config.node_config['InstanceType'],
+ 'image': image_id,
+ 'ssh_keys': [
+ ssh_key_id(
+ config.authentication_config['ssh_public_key'])['fingerprint']
+ ],
+ 'tags': tags,
+ }
+ instance = _create_droplet(instance_request)
+
+ volume_request = {
+ 'size_gigabytes': config.node_config['DiskSize'],
+ 'name': instance_name,
+ 'region': region,
+ 'filesystem_type': 'ext4',
+ 'tags': tags
+ }
+ volume = _create_volume(volume_request)
+
+ attach_request = {'type': 'attach', 'droplet_id': instance['id']}
+ try:
+ client().volume_actions.post_by_id(volume['id'], attach_request)
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+ logger.debug(f'{instance_name} created')
+ return instance
+
+
+def start_instance(instance: Dict[str, Any]):
+ try:
+ client().droplet_actions.post(droplet_id=instance['id'],
+ body={'type': 'power_on'})
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+
+
+def stop_instance(instance: Dict[str, Any]):
+ try:
+ client().droplet_actions.post(
+ droplet_id=instance['id'],
+ body={'type': 'shutdown'},
+ )
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+
+
+def down_instance(instance: Dict[str, Any]):
+ # We use dangerous destroy to atomically delete
+ # block storage and instance for autodown
+ try:
+ client().droplets.destroy_with_associated_resources_dangerous(
+ droplet_id=instance['id'], x_dangerous=True)
+ except do.exceptions().HttpResponseError as err:
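+        # A destroy may already be running (e.g. a concurrent teardown);
+        # treat that case as success instead of raising.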
+ if 'a destroy is already in progress' in err.error.message:
+ return
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+
+
+def rename_instance(instance: Dict[str, Any], new_name: str):
+ try:
+ client().droplet_actions.rename(droplet=instance['id'],
+ body={
+ 'type': 'rename',
+ 'name': new_name
+ })
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+
+
+def filter_instances(
+ cluster_name_on_cloud: str,
+ status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
+ """Returns Dict mapping instance name
+ to instance metadata filtered by status
+ """
+
+ filtered_instances: Dict[str, Any] = {}
+ page = 1
+ paginated = True
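+    # Page through droplets tagged with this cluster name, 50 per page,
+    # until there is no next page.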
+ while paginated:
+ try:
+ resp = client().droplets.list(
+ tag_name=f'{provision_constants.TAG_SKYPILOT_CLUSTER_NAME}:'
+ f'{cluster_name_on_cloud}',
+ per_page=50,
+ page=page)
+ for instance in resp['droplets']:
+ if status_filters is None or instance[
+ 'status'] in status_filters:
+ filtered_instances[instance['name']] = instance
+ except do.exceptions().HttpResponseError as err:
+ raise DigitalOceanError(
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
+ ) from err
+
+ pages = resp['links']
+ if 'pages' in pages and 'next' in pages['pages']:
+ pages = pages['pages']
+ parsed_url = urllib.parse.urlparse(pages['next'])
+ page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
+ else:
+ paginated = False
+ return filtered_instances
diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py
index c55508ab41a..848c7a06983 100644
--- a/sky/provision/docker_utils.py
+++ b/sky/provision/docker_utils.py
@@ -338,14 +338,20 @@ def _check_docker_installed(self):
no_exist = 'NoExist'
# SkyPilot: Add the current user to the docker group first (if needed),
# before checking if docker is installed to avoid permission issues.
- cleaned_output = self._run(
- 'id -nG $USER | grep -qw docker || '
- 'sudo usermod -aG docker $USER > /dev/null 2>&1;'
- f'command -v {self.docker_cmd} || echo {no_exist!r}')
- if no_exist in cleaned_output or 'docker' not in cleaned_output:
- logger.error(
- f'{self.docker_cmd.capitalize()} not installed. Please use an '
- f'image with {self.docker_cmd.capitalize()} installed.')
+ docker_cmd = ('id -nG $USER | grep -qw docker || '
+ 'sudo usermod -aG docker $USER > /dev/null 2>&1;'
+ f'command -v {self.docker_cmd} || echo {no_exist!r}')
+ cleaned_output = self._run(docker_cmd)
+        timeout = 60 * 10  # 10-minute timeout
+ start = time.time()
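+        # Poll until the docker CLI becomes available, as the image's setup
+        # may still be installing it; give up after the timeout.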
+ while no_exist in cleaned_output or 'docker' not in cleaned_output:
+ if time.time() - start > timeout:
+ logger.error(
+ f'{self.docker_cmd.capitalize()} not installed. Please use '
+ f'an image with {self.docker_cmd.capitalize()} installed.')
+ return
+ time.sleep(5)
+ cleaned_output = self._run(docker_cmd)
def _check_container_status(self):
if self.initialized:
diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py
index a8292669a7c..a99267eb0b9 100644
--- a/sky/provision/gcp/config.py
+++ b/sky/provision/gcp/config.py
@@ -397,7 +397,7 @@ def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str,
operation = compute.networks().getEffectiveFirewalls(project=project_id,
network=vpc_name)
response = operation.execute()
- if len(response) == 0:
+ if not response:
return False
effective_rules = response['firewalls']
@@ -515,7 +515,7 @@ def _create_rules(project_id: str, compute, rules, vpc_name):
rule_list = _list_firewall_rules(project_id,
compute,
filter=f'(name={rule_name})')
- if len(rule_list) > 0:
+ if rule_list:
_delete_firewall_rule(project_id, compute, rule_name)
body = rule.copy()
@@ -624,7 +624,7 @@ def get_usable_vpc_and_subnet(
vpc_list = _list_vpcnets(project_id,
compute,
filter=f'name={constants.SKYPILOT_VPC_NAME}')
- if len(vpc_list) == 0:
+ if not vpc_list:
body = constants.VPC_TEMPLATE.copy()
body['name'] = body['name'].format(VPC_NAME=constants.SKYPILOT_VPC_NAME)
body['selfLink'] = body['selfLink'].format(
diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py
index 4f442709b0c..7b3fd4046b5 100644
--- a/sky/provision/gcp/constants.py
+++ b/sky/provision/gcp/constants.py
@@ -142,7 +142,7 @@
]
# A list of permissions required to run SkyPilot on GCP.
-# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html # pylint: disable=line-too-long
+# Keep this in sync with https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/gcp.html # pylint: disable=line-too-long
VM_MINIMAL_PERMISSIONS = [
'compute.disks.create',
'compute.disks.list',
diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py
index 370430720f0..0fe920be9d6 100644
--- a/sky/provision/kubernetes/config.py
+++ b/sky/provision/kubernetes/config.py
@@ -232,7 +232,7 @@ def _get_resource(container_resources: Dict[str, Any], resource_name: str,
# Look for keys containing the resource_name. For example,
# the key 'nvidia.com/gpu' contains the key 'gpu'.
matching_keys = [key for key in resources if resource_name in key.lower()]
- if len(matching_keys) == 0:
+ if not matching_keys:
return float('inf')
if len(matching_keys) > 1:
# Should have only one match -- mostly relevant for gpu.
@@ -265,7 +265,7 @@ def _configure_autoscaler_service_account(
field_selector = f'metadata.name={name}'
accounts = (kubernetes.core_api(context).list_namespaced_service_account(
namespace, field_selector=field_selector).items)
- if len(accounts) > 0:
+ if accounts:
assert len(accounts) == 1
# Nothing to check for equality and patch here,
# since the service_account.metadata.name is the only important
@@ -308,7 +308,7 @@ def _configure_autoscaler_role(namespace: str, context: Optional[str],
field_selector = f'metadata.name={name}'
roles = (kubernetes.auth_api(context).list_namespaced_role(
namespace, field_selector=field_selector).items)
- if len(roles) > 0:
+ if roles:
assert len(roles) == 1
existing_role = roles[0]
# Convert to k8s object to compare
@@ -374,7 +374,7 @@ def _configure_autoscaler_role_binding(
field_selector = f'metadata.name={name}'
role_bindings = (kubernetes.auth_api(context).list_namespaced_role_binding(
rb_namespace, field_selector=field_selector).items)
- if len(role_bindings) > 0:
+ if role_bindings:
assert len(role_bindings) == 1
existing_binding = role_bindings[0]
new_rb = kubernetes_utils.dict_to_k8s_object(binding, 'V1RoleBinding')
@@ -415,7 +415,7 @@ def _configure_autoscaler_cluster_role(namespace, context,
field_selector = f'metadata.name={name}'
cluster_roles = (kubernetes.auth_api(context).list_cluster_role(
field_selector=field_selector).items)
- if len(cluster_roles) > 0:
+ if cluster_roles:
assert len(cluster_roles) == 1
existing_cr = cluster_roles[0]
new_cr = kubernetes_utils.dict_to_k8s_object(role, 'V1ClusterRole')
@@ -460,7 +460,7 @@ def _configure_autoscaler_cluster_role_binding(
field_selector = f'metadata.name={name}'
cr_bindings = (kubernetes.auth_api(context).list_cluster_role_binding(
field_selector=field_selector).items)
- if len(cr_bindings) > 0:
+ if cr_bindings:
assert len(cr_bindings) == 1
existing_binding = cr_bindings[0]
new_binding = kubernetes_utils.dict_to_k8s_object(
@@ -639,7 +639,7 @@ def _configure_services(namespace: str, context: Optional[str],
field_selector = f'metadata.name={name}'
services = (kubernetes.core_api(context).list_namespaced_service(
namespace, field_selector=field_selector).items)
- if len(services) > 0:
+ if services:
assert len(services) == 1
existing_service = services[0]
# Convert to k8s object to compare
diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py
index 731e5afb275..6094ca7d350 100644
--- a/sky/provision/kubernetes/instance.py
+++ b/sky/provision/kubernetes/instance.py
@@ -180,6 +180,7 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
# case we will need to update this logic.
# TODO(Doyoung): Update the error message raised
# with the multi-host TPU support.
+ gpu_resource_key = kubernetes_utils.get_gpu_resource_key() # pylint: disable=line-too-long
if 'Insufficient google.com/tpu' in event_message:
extra_msg = (
f'Verify if '
@@ -192,14 +193,15 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
pod,
extra_msg,
details=event_message))
- elif (('Insufficient nvidia.com/gpu'
+ elif ((f'Insufficient {gpu_resource_key}'
in event_message) or
('didn\'t match Pod\'s node affinity/selector'
in event_message)):
extra_msg = (
- f'Verify if '
- f'{pod.spec.node_selector[label_key]}'
- ' is available in the cluster.')
+ f'Verify if any node matching label '
+ f'{pod.spec.node_selector[label_key]} and '
+ f'sufficient resource {gpu_resource_key} '
+ f'is available in the cluster.')
raise config_lib.KubernetesError(
_lack_resource_msg('GPU',
pod,
@@ -722,13 +724,13 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
'Continuing without using nvidia RuntimeClass.\n'
'If you are on a K3s cluster, manually '
'override runtimeClassName in ~/.sky/config.yaml. '
- 'For more details, refer to https://skypilot.readthedocs.io/en/latest/reference/config.html') # pylint: disable=line-too-long
+ 'For more details, refer to https://docs.skypilot.co/en/latest/reference/config.html') # pylint: disable=line-too-long
needs_gpus = False
limits = pod_spec['spec']['containers'][0].get('resources',
{}).get('limits')
if limits is not None:
- needs_gpus = limits.get(kubernetes_utils.GPU_RESOURCE_KEY, 0) > 0
+ needs_gpus = limits.get(kubernetes_utils.get_gpu_resource_key(), 0) > 0
# TPU pods provisioned on GKE use the default containerd runtime.
# Reference: https://cloud.google.com/kubernetes-engine/docs/how-to/migrate-containerd#overview # pylint: disable=line-too-long
@@ -879,27 +881,62 @@ def _terminate_node(namespace: str, context: Optional[str],
pod_name: str) -> None:
"""Terminate a pod."""
logger.debug('terminate_instances: calling delete_namespaced_pod')
- try:
- kubernetes.core_api(context).delete_namespaced_service(
- pod_name, namespace, _request_timeout=config_lib.DELETION_TIMEOUT)
- kubernetes.core_api(context).delete_namespaced_service(
- f'{pod_name}-ssh',
- namespace,
- _request_timeout=config_lib.DELETION_TIMEOUT)
- except kubernetes.api_exception():
- pass
+
+ def _delete_k8s_resource_with_retry(delete_func: Callable,
+ resource_type: str,
+ resource_name: str) -> None:
+ """Helper to delete Kubernetes resources with 404 handling and retries.
+
+ Args:
+ delete_func: Function to call to delete the resource
+ resource_type: Type of resource being deleted (e.g. 'service'),
+ used in logging
+ resource_name: Name of the resource being deleted, used in logging
+ """
+ max_retries = 3
+ retry_delay = 5 # seconds
+
+ for attempt in range(max_retries):
+ try:
+ delete_func()
+ return
+ except kubernetes.api_exception() as e:
+ if e.status == 404:
+ logger.warning(
+ f'terminate_instances: Tried to delete {resource_type} '
+ f'{resource_name}, but the {resource_type} was not '
+ 'found (404).')
+ return
+ elif attempt < max_retries - 1:
+ logger.warning(f'terminate_instances: Failed to delete '
+ f'{resource_type} {resource_name} (attempt '
+ f'{attempt + 1}/{max_retries}). Error: {e}. '
+ f'Retrying in {retry_delay} seconds...')
+ time.sleep(retry_delay)
+ else:
+ raise
+
+ # Delete services for the pod
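+    # Bind service_name via a lambda default argument so each call deletes
+    # the service from the current iteration.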
+ for service_name in [pod_name, f'{pod_name}-ssh']:
+ _delete_k8s_resource_with_retry(
+ delete_func=lambda name=service_name: kubernetes.core_api(
+ context).delete_namespaced_service(name=name,
+ namespace=namespace,
+ _request_timeout=config_lib.
+ DELETION_TIMEOUT),
+ resource_type='service',
+ resource_name=service_name)
+
# Note - delete pod after all other resources are deleted.
# This is to ensure there are no leftover resources if this down is run
# from within the pod, e.g., for autodown.
- try:
- kubernetes.core_api(context).delete_namespaced_pod(
- pod_name, namespace, _request_timeout=config_lib.DELETION_TIMEOUT)
- except kubernetes.api_exception() as e:
- if e.status == 404:
- logger.warning('terminate_instances: Tried to delete pod '
- f'{pod_name}, but the pod was not found (404).')
- else:
- raise
+ _delete_k8s_resource_with_retry(
+ delete_func=lambda: kubernetes.core_api(context).delete_namespaced_pod(
+ name=pod_name,
+ namespace=namespace,
+ _request_timeout=config_lib.DELETION_TIMEOUT),
+ resource_type='pod',
+ resource_name=pod_name)
def terminate_instances(
diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py
index b16482e5072..29fcf181edd 100644
--- a/sky/provision/kubernetes/network_utils.py
+++ b/sky/provision/kubernetes/network_utils.py
@@ -230,7 +230,7 @@ def get_ingress_external_ip_and_ports(
namespace, _request_timeout=kubernetes.API_TIMEOUT).items
if item.metadata.name == 'ingress-nginx-controller'
]
- if len(ingress_services) == 0:
+ if not ingress_services:
return (None, None)
ingress_service = ingress_services[0]
diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py
index 64054d0e362..b51b1153646 100644
--- a/sky/provision/kubernetes/utils.py
+++ b/sky/provision/kubernetes/utils.py
@@ -342,14 +342,15 @@ def get_accelerator_from_label_value(cls, value: str) -> str:
"""
canonical_gpu_names = [
'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4',
- 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L4'
+ 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4'
]
for canonical_name in canonical_gpu_names:
# A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB
if canonical_name == 'A100-80GB' and re.search(
r'A100.*-80GB', value):
return canonical_name
- elif canonical_name in value:
+ # Use word boundary matching to prevent substring matches
+ elif re.search(rf'\b{re.escape(canonical_name)}\b', value):
return canonical_name
# If we didn't find a canonical name:
@@ -440,7 +441,7 @@ def detect_accelerator_resource(
nodes = get_kubernetes_nodes(context)
for node in nodes:
cluster_resources.update(node.status.allocatable.keys())
- has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or
+ has_accelerator = (get_gpu_resource_key() in cluster_resources or
TPU_RESOURCE_KEY in cluster_resources)
return has_accelerator, cluster_resources
@@ -585,7 +586,7 @@ def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType',
node for node in nodes if gpu_label_key in node.metadata.labels and
node.metadata.labels[gpu_label_key] == gpu_label_val
]
- assert len(gpu_nodes) > 0, 'GPU nodes not found'
+ assert gpu_nodes, 'GPU nodes not found'
if is_tpu_on_gke(acc_type):
# If requested accelerator is a TPU type, check if the cluster
# has sufficient TPU resource to meet the requirement.
@@ -894,6 +895,52 @@ def check_credentials(context: Optional[str],
return True, None
+def check_pod_config(pod_config: dict) \
+ -> Tuple[bool, Optional[str]]:
+ """Check if the pod_config is a valid pod config
+
+ Using deserialize api to check the pod_config is valid or not.
+
+ Returns:
+ bool: True if pod_config is valid.
+ str: Error message about why the pod_config is invalid, None otherwise.
+ """
+ errors = []
+    # This api_client won't be used to send any requests, so there is no need
+    # to load a kubeconfig.
+ api_client = kubernetes.kubernetes.client.ApiClient()
+
+    # Used by the Kubernetes api_client's deserialize function, which reads
+    # the `data` attribute of the response object. For details, see:
+    # https://github.com/kubernetes-client/python/blob/master/kubernetes/client/api_client.py#L244
+ class InnerResponse():
+
+ def __init__(self, data: dict):
+ self.data = json.dumps(data)
+
+ try:
+ # Validate metadata if present
+ if 'metadata' in pod_config:
+ try:
+ value = InnerResponse(pod_config['metadata'])
+ api_client.deserialize(
+ value, kubernetes.kubernetes.client.V1ObjectMeta)
+ except ValueError as e:
+ errors.append(f'Invalid metadata: {str(e)}')
+ # Validate spec if present
+ if 'spec' in pod_config:
+ try:
+ value = InnerResponse(pod_config['spec'])
+ api_client.deserialize(value,
+ kubernetes.kubernetes.client.V1PodSpec)
+ except ValueError as e:
+ errors.append(f'Invalid spec: {str(e)}')
+        if not errors:
+            return True, None
+        return False, '.'.join(errors)
+ except Exception as e: # pylint: disable=broad-except
+ errors.append(f'Validation error: {str(e)}')
+ return False, '.'.join(errors)
+
+
def is_kubeconfig_exec_auth(
context: Optional[str] = None) -> Tuple[bool, Optional[str]]:
"""Checks if the kubeconfig file uses exec-based authentication
@@ -974,7 +1021,7 @@ def is_kubeconfig_exec_auth(
'~/.sky/config.yaml:\n'
' kubernetes:\n'
' remote_identity: SERVICE_ACCOUNT\n'
- ' More: https://skypilot.readthedocs.io/en/latest/'
+ ' More: https://docs.skypilot.co/en/latest/'
'reference/config.html')
return True, exec_msg
return False, None
@@ -1236,7 +1283,8 @@ def construct_ssh_jump_command(
'-o StrictHostKeyChecking=no '
'-o UserKnownHostsFile=/dev/null '
f'-o IdentitiesOnly=yes '
- f'-W %h:%p {ssh_jump_user}@{ssh_jump_ip}')
+ r'-W \[%h\]:%p '
+ f'{ssh_jump_user}@{ssh_jump_ip}')
if ssh_jump_port is not None:
ssh_jump_proxy_command += f' -p {ssh_jump_port} '
if proxy_cmd_path is not None:
@@ -1530,7 +1578,7 @@ def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str],
def find(l, predicate):
"""Utility function to find element in given list"""
results = [x for x in l if predicate(x)]
- return results[0] if len(results) > 0 else None
+ return results[0] if results else None
# Get the SSH jump pod name from the head pod
try:
@@ -2204,10 +2252,11 @@ def get_node_accelerator_count(attribute_dict: dict) -> int:
Number of accelerators allocated or available from the node. If no
resource is found, it returns 0.
"""
- assert not (GPU_RESOURCE_KEY in attribute_dict and
+ gpu_resource_name = get_gpu_resource_key()
+ assert not (gpu_resource_name in attribute_dict and
TPU_RESOURCE_KEY in attribute_dict)
- if GPU_RESOURCE_KEY in attribute_dict:
- return int(attribute_dict[GPU_RESOURCE_KEY])
+ if gpu_resource_name in attribute_dict:
+ return int(attribute_dict[gpu_resource_name])
elif TPU_RESOURCE_KEY in attribute_dict:
return int(attribute_dict[TPU_RESOURCE_KEY])
return 0
@@ -2391,3 +2440,18 @@ def process_skypilot_pods(
num_pods = len(cluster.pods)
cluster.resources_str = f'{num_pods}x {cluster.resources}'
return list(clusters.values()), jobs_controllers, serve_controllers
+
+
+def get_gpu_resource_key() -> str:
+    """Get the GPU resource name to use in Kubernetes.
+
+    The function first checks the CUSTOM_GPU_RESOURCE_KEY environment
+    variable. If it is set, its value is used; otherwise, the default
+    GPU_RESOURCE_KEY ("nvidia.com/gpu") is returned.
+
+    Returns:
+        str: The selected GPU resource name.
+    """
+ # Retrieve GPU resource name from environment variable, if set.
+ # Else use default.
+ # E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc.
+ return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY)
diff --git a/sky/provision/lambda_cloud/lambda_utils.py b/sky/provision/lambda_cloud/lambda_utils.py
index 4d8e6246b6d..cfd8e02ad23 100644
--- a/sky/provision/lambda_cloud/lambda_utils.py
+++ b/sky/provision/lambda_cloud/lambda_utils.py
@@ -50,7 +50,7 @@ def set(self, instance_id: str, value: Optional[Dict[str, Any]]) -> None:
if value is None:
if instance_id in metadata:
metadata.pop(instance_id) # del entry
- if len(metadata) == 0:
+ if not metadata:
if os.path.exists(self.path):
os.remove(self.path)
return
@@ -69,7 +69,7 @@ def refresh(self, instance_ids: List[str]) -> None:
for instance_id in list(metadata.keys()):
if instance_id not in instance_ids:
del metadata[instance_id]
- if len(metadata) == 0:
+ if not metadata:
os.remove(self.path)
return
with open(self.path, 'w', encoding='utf-8') as f:
@@ -150,7 +150,7 @@ def create_instances(
['regions_with_capacity_available'])
available_regions = [reg['name'] for reg in available_regions]
if region not in available_regions:
- if len(available_regions) > 0:
+ if available_regions:
aval_reg = ' '.join(available_regions)
else:
aval_reg = 'None'
diff --git a/sky/provision/oci/query_utils.py b/sky/provision/oci/query_utils.py
index 47a0438cb21..3f545aca4ba 100644
--- a/sky/provision/oci/query_utils.py
+++ b/sky/provision/oci/query_utils.py
@@ -7,6 +7,8 @@
find_compartment: allow search subtree when find a compartment.
- Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add methods to
Add/remove security rules: create_nsg_rules & remove_nsg
+ - Hysun He (hysun.he@oracle.com) @ Jan.01, 2025: Support reuse existing
+ VCN for SkyServe.
"""
from datetime import datetime
import functools
@@ -17,7 +19,6 @@
import typing
from typing import List, Optional, Tuple
-from sky import exceptions
from sky import sky_logging
from sky.adaptors import common as adaptors_common
from sky.adaptors import oci as oci_adaptor
@@ -248,7 +249,7 @@ def find_compartment(cls, region) -> str:
limit=1)
compartments = list_compartments_response.data
- if len(compartments) > 0:
+ if compartments:
skypilot_compartment = compartments[0].id
return skypilot_compartment
@@ -274,7 +275,7 @@ def find_create_vcn_subnet(cls, region) -> Optional[str]:
display_name=oci_utils.oci_config.VCN_NAME,
lifecycle_state='AVAILABLE')
vcns = list_vcns_response.data
- if len(vcns) > 0:
+ if vcns:
# Found the VCN.
skypilot_vcn = vcns[0].id
list_subnets_response = net_client.list_subnets(
@@ -359,7 +360,7 @@ def create_vcn_subnet(cls, net_client,
if str(s.cidr_block).startswith('all-') and str(s.cidr_block).
endswith('-services-in-oracle-services-network')
]
- if len(services) > 0:
+ if services:
# Create service gateway for regional services.
create_sg_response = net_client.create_service_gateway(
create_service_gateway_details=oci_adaptor.oci.core.models.
@@ -496,23 +497,25 @@ def find_nsg(cls, region: str, nsg_name: str,
compartment = cls.find_compartment(region)
- list_vcns_resp = net_client.list_vcns(
- compartment_id=compartment,
- display_name=oci_utils.oci_config.VCN_NAME,
- lifecycle_state='AVAILABLE',
- )
+ vcn_id = oci_utils.oci_config.get_vcn_ocid(region)
+ if vcn_id is None:
+ list_vcns_resp = net_client.list_vcns(
+ compartment_id=compartment,
+ display_name=oci_utils.oci_config.VCN_NAME,
+ lifecycle_state='AVAILABLE',
+ )
- if not list_vcns_resp:
- raise exceptions.ResourcesUnavailableError(
- 'The VCN is not available')
+ # Get the primary vnic. The vnic might be an empty list for the
+ # corner case when the cluster was exited during provision.
+ if not list_vcns_resp.data:
+ return None
- # Get the primary vnic.
- assert len(list_vcns_resp.data) > 0
- vcn = list_vcns_resp.data[0]
+ vcn = list_vcns_resp.data[0]
+ vcn_id = vcn.id
list_nsg_resp = net_client.list_network_security_groups(
compartment_id=compartment,
- vcn_id=vcn.id,
+ vcn_id=vcn_id,
limit=1,
display_name=nsg_name,
)
@@ -529,7 +532,7 @@ def find_nsg(cls, region: str, nsg_name: str,
create_network_security_group_details=oci_adaptor.oci.core.models.
CreateNetworkSecurityGroupDetails(
compartment_id=compartment,
- vcn_id=vcn.id,
+ vcn_id=vcn_id,
display_name=nsg_name,
))
get_nsg_resp = net_client.get_network_security_group(
diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py
index fb9ea380c0d..2f81aaf4de0 100644
--- a/sky/provision/provisioner.py
+++ b/sky/provision/provisioner.py
@@ -416,7 +416,6 @@ def _post_provision_setup(
f'{json.dumps(dataclasses.asdict(provision_record), indent=2)}\n'
'Cluster info:\n'
f'{json.dumps(dataclasses.asdict(cluster_info), indent=2)}')
-
head_instance = cluster_info.get_head_instance()
if head_instance is None:
e = RuntimeError(f'Provision failed for cluster {cluster_name!r}. '
diff --git a/sky/provision/vsphere/common/vim_utils.py b/sky/provision/vsphere/common/vim_utils.py
index 33c02db8feb..bde1bc25cf0 100644
--- a/sky/provision/vsphere/common/vim_utils.py
+++ b/sky/provision/vsphere/common/vim_utils.py
@@ -56,7 +56,7 @@ def get_hosts_by_cluster_names(content, vcenter_name, cluster_name_dicts=None):
'name': cluster.name
} for cluster in cluster_view.view]
cluster_view.Destroy()
- if len(cluster_name_dicts) == 0:
+ if not cluster_name_dicts:
logger.warning(f'vCenter \'{vcenter_name}\' has no clusters')
# Retrieve all cluster names from the cluster_name_dicts
diff --git a/sky/provision/vsphere/instance.py b/sky/provision/vsphere/instance.py
index 9a15c1e9602..c66d9760efe 100644
--- a/sky/provision/vsphere/instance.py
+++ b/sky/provision/vsphere/instance.py
@@ -159,7 +159,7 @@ def _create_instances(
if not gpu_instance:
# Find an image for CPU
images_df = images_df[images_df['GpuTags'] == '\'[]\'']
- if len(images_df) == 0:
+        if images_df.empty:
logger.error(
f'Can not find an image for instance type: {instance_type}.')
raise Exception(
@@ -182,7 +182,7 @@ def _create_instances(
image_instance_mapping_df = image_instance_mapping_df[
image_instance_mapping_df['InstanceType'] == instance_type]
- if len(image_instance_mapping_df) == 0:
+    if image_instance_mapping_df.empty:
raise Exception(f"""There is no image can match instance type named
{instance_type}
If you are using CPU-only instance, assign an image with tag
@@ -215,10 +215,9 @@ def _create_instances(
hosts_df = hosts_df[(hosts_df['AvailableCPUs'] /
hosts_df['cpuMhz']) >= cpus_needed]
hosts_df = hosts_df[hosts_df['AvailableMemory(MB)'] >= memory_needed]
- assert len(hosts_df) > 0, (
- f'There is no host available to create the instance '
- f'{vms_item["InstanceType"]}, at least {cpus_needed} '
- f'cpus and {memory_needed}MB memory are required.')
+        assert not hosts_df.empty, (
+            f'There is no host available to create the instance '
+            f'{vms_item["InstanceType"]}, at least {cpus_needed} '
+            f'cpus and {memory_needed}MB memory are required.')
# Sort the hosts df by AvailableCPUs to get the compatible host with the
# least resource
@@ -356,7 +355,7 @@ def _choose_vsphere_cluster_name(config: common.ProvisionConfig, region: str,
skypilot framework-optimized availability_zones"""
vsphere_cluster_name = None
vsphere_cluster_name_str = config.provider_config['availability_zone']
- if len(vc_object.clusters) > 0:
+ if vc_object.clusters:
for optimized_cluster_name in vsphere_cluster_name_str.split(','):
if optimized_cluster_name in [
item['name'] for item in vc_object.clusters
diff --git a/sky/provision/vsphere/vsphere_utils.py b/sky/provision/vsphere/vsphere_utils.py
index faec5d54930..51f284b0fc6 100644
--- a/sky/provision/vsphere/vsphere_utils.py
+++ b/sky/provision/vsphere/vsphere_utils.py
@@ -257,7 +257,7 @@ def get_skypilot_profile_id(self):
# hard code here. should support configure later.
profile_name = 'skypilot_policy'
storage_profile_id = None
- if len(profile_ids) > 0:
+ if profile_ids:
profiles = pm.PbmRetrieveContent(profileIds=profile_ids)
for profile in profiles:
if profile_name in profile.name:
diff --git a/sky/resources.py b/sky/resources.py
index 68944e37177..625bb7b85f9 100644
--- a/sky/resources.py
+++ b/sky/resources.py
@@ -662,7 +662,7 @@ def _try_validate_and_set_region_zone(self) -> None:
continue
valid_clouds.append(cloud)
- if len(valid_clouds) == 0:
+ if not valid_clouds:
if len(enabled_clouds) == 1:
cloud_str = f'for cloud {enabled_clouds[0]}'
else:
@@ -775,7 +775,7 @@ def _try_validate_instance_type(self) -> None:
for cloud in enabled_clouds:
if cloud.instance_type_exists(self._instance_type):
valid_clouds.append(cloud)
- if len(valid_clouds) == 0:
+ if not valid_clouds:
if len(enabled_clouds) == 1:
cloud_str = f'for cloud {enabled_clouds[0]}'
else:
@@ -1006,7 +1006,7 @@ def _try_validate_labels(self) -> None:
f'Label rejected due to {cloud}: {err_msg}'
])
break
- if len(invalid_table.rows) > 0:
+ if invalid_table.rows:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'The following labels are invalid:'
@@ -1281,7 +1281,7 @@ def copy(self, **override) -> 'Resources':
_cluster_config_overrides=override.pop(
'_cluster_config_overrides', self._cluster_config_overrides),
)
- assert len(override) == 0
+ assert not override
return resources
def valid_on_region_zones(self, region: str, zones: List[str]) -> bool:
diff --git a/sky/serve/api/core.py b/sky/serve/api/core.py
index 45bc350a1f3..4a50dca35a3 100644
--- a/sky/serve/api/core.py
+++ b/sky/serve/api/core.py
@@ -362,7 +362,7 @@ def update(
raise RuntimeError(e.error_msg) from e
service_statuses = serve_utils.load_service_status(serve_status_payload)
- if len(service_statuses) == 0:
+ if not service_statuses:
with ux_utils.print_exception_no_traceback():
raise RuntimeError(f'Cannot find service {service_name!r}.'
f'To spin up a service, use {ux_utils.BOLD}'
@@ -386,6 +386,17 @@ def update(
with ux_utils.print_exception_no_traceback():
raise RuntimeError(prompt)
+ original_lb_policy = service_record['load_balancing_policy']
+ assert task.service is not None, 'Service section not found.'
+ if original_lb_policy != task.service.load_balancing_policy:
+ logger.warning(
+ f'{colorama.Fore.YELLOW}Current load balancing policy '
+ f'{original_lb_policy!r} is different from the new policy '
+ f'{task.service.load_balancing_policy!r}. Updating the load '
+ 'balancing policy is not supported yet and it will be ignored. '
+ 'The service will continue to use the current load balancing '
+ f'policy.{colorama.Style.RESET_ALL}')
+
with rich_utils.safe_status(
ux_utils.spinner_message('Initializing service')):
controller_utils.maybe_translate_local_file_mounts_and_sync_up(
@@ -482,9 +493,9 @@ def down(
stopped_message='All services should have terminated.')
service_names_str = ','.join(service_names)
- if sum([len(service_names) > 0, all]) != 1:
- argument_str = f'service_names={service_names_str}' if len(
- service_names) > 0 else ''
+ if sum([bool(service_names), all]) != 1:
+ argument_str = (f'service_names={service_names_str}'
+ if service_names else '')
argument_str += ' all' if all else ''
raise ValueError('Can only specify one of service_names or all. '
f'Provided {argument_str!r}.')
@@ -583,9 +594,10 @@ def status(
'status': (sky.ServiceStatus) service status,
'controller_port': (Optional[int]) controller port,
'load_balancer_port': (Optional[int]) load balancer port,
- 'policy': (Optional[str]) load balancer policy description,
+ 'policy': (Optional[str]) autoscaling policy description,
'requested_resources_str': (str) str representation of
requested resources,
+ 'load_balancing_policy': (str) load balancing policy name,
'replica_info': (List[Dict[str, Any]]) replica information,
}
diff --git a/sky/serve/autoscalers.py b/sky/serve/autoscalers.py
index a4278f192fb..7a6311ad535 100644
--- a/sky/serve/autoscalers.py
+++ b/sky/serve/autoscalers.py
@@ -320,8 +320,8 @@ def select_outdated_replicas_to_scale_down(
"""Select outdated replicas to scale down."""
if self.update_mode == serve_utils.UpdateMode.ROLLING:
- latest_ready_replicas = []
- old_nonterminal_replicas = []
+ latest_ready_replicas: List['replica_managers.ReplicaInfo'] = []
+ old_nonterminal_replicas: List['replica_managers.ReplicaInfo'] = []
for info in replica_infos:
if info.version == self.latest_version:
if info.is_ready:
diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py
index 30697532a22..6b4621569d6 100644
--- a/sky/serve/load_balancer.py
+++ b/sky/serve/load_balancer.py
@@ -45,6 +45,8 @@ def __init__(self,
# Use the registry to create the load balancing policy
self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make(
load_balancing_policy_name)
+ logger.info('Starting load balancer with policy '
+ f'{load_balancing_policy_name}.')
self._request_aggregator: serve_utils.RequestsAggregator = (
serve_utils.RequestTimestamp())
# TODO(tian): httpx.Client has a resource limit of 100 max connections
@@ -128,6 +130,7 @@ async def _proxy_request_to(
encountered if anything goes wrong.
"""
logger.info(f'Proxy request to {url}')
+ self._load_balancing_policy.pre_execute_hook(url, request)
try:
# We defer the get of the client here on purpose, for case when the
# replica is ready in `_proxy_with_retries` but refreshed before
@@ -147,11 +150,16 @@ async def _proxy_request_to(
content=await request.body(),
timeout=constants.LB_STREAM_TIMEOUT)
proxy_response = await client.send(proxy_request, stream=True)
+
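+            # Runs after the streamed response completes: close the upstream
+            # response and invoke the policy's post-execute hook.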
+ async def background_func():
+ await proxy_response.aclose()
+ self._load_balancing_policy.post_execute_hook(url, request)
+
return fastapi.responses.StreamingResponse(
content=proxy_response.aiter_raw(),
status_code=proxy_response.status_code,
headers=proxy_response.headers,
- background=background.BackgroundTask(proxy_response.aclose))
+ background=background.BackgroundTask(background_func))
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f'Error when proxy request to {url}: '
f'{common_utils.format_exception(e)}')
@@ -263,7 +271,7 @@ def run_load_balancer(controller_addr: str,
parser.add_argument(
'--load-balancing-policy',
choices=available_policies,
- default='round_robin',
+ default=lb_policies.DEFAULT_LB_POLICY,
help=f'The load balancing policy to use. Available policies: '
f'{", ".join(available_policies)}.')
args = parser.parse_args()
diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py
index aec6eb01487..4ad69f78943 100644
--- a/sky/serve/load_balancing_policies.py
+++ b/sky/serve/load_balancing_policies.py
@@ -1,7 +1,9 @@
"""LoadBalancingPolicy: Policy to select endpoint."""
+import collections
import random
+import threading
import typing
-from typing import List, Optional
+from typing import Dict, List, Optional
from sky import sky_logging
@@ -13,6 +15,10 @@
# Define a registry for load balancing policies
LB_POLICIES = {}
DEFAULT_LB_POLICY = None
+# Prior to #4439, the default policy was round_robin. We store the legacy
+# default policy here to maintain backwards compatibility. Remove this after
+# 2 minor releases, i.e., 0.9.0.
+LEGACY_DEFAULT_POLICY = 'round_robin'
def _request_repr(request: 'fastapi.Request') -> str:
@@ -38,11 +44,17 @@ def __init_subclass__(cls, name: str, default: bool = False):
DEFAULT_LB_POLICY = name
@classmethod
- def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
- """Create a load balancing policy from a name."""
+ def make_policy_name(cls, policy_name: Optional[str]) -> str:
+ """Return the policy name."""
+ assert DEFAULT_LB_POLICY is not None, 'No default policy set.'
if policy_name is None:
- policy_name = DEFAULT_LB_POLICY
+ return DEFAULT_LB_POLICY
+ return policy_name
+ @classmethod
+ def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
+ """Create a load balancing policy from a name."""
+ policy_name = cls.make_policy_name(policy_name)
if policy_name not in LB_POLICIES:
raise ValueError(f'Unknown load balancing policy: {policy_name}')
return LB_POLICIES[policy_name]()
@@ -65,8 +77,16 @@ def select_replica(self, request: 'fastapi.Request') -> Optional[str]:
def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
raise NotImplementedError
+ def pre_execute_hook(self, replica_url: str,
+ request: 'fastapi.Request') -> None:
+ pass
+
+ def post_execute_hook(self, replica_url: str,
+ request: 'fastapi.Request') -> None:
+ pass
+
-class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True):
+class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin'):
"""Round-robin load balancing policy."""
def __init__(self) -> None:
@@ -90,3 +110,43 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
ready_replica_url = self.ready_replicas[self.index]
self.index = (self.index + 1) % len(self.ready_replicas)
return ready_replica_url
+
+
+class LeastLoadPolicy(LoadBalancingPolicy, name='least_load', default=True):
+ """Least load load balancing policy."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.load_map: Dict[str, int] = collections.defaultdict(int)
+ self.lock = threading.Lock()
+
+    def set_ready_replicas(self, ready_replicas: List[str]) -> None:
+        if set(self.ready_replicas) == set(ready_replicas):
+            return
+        with self.lock:
+            # Drop load entries for replicas that are no longer ready before
+            # adopting the new replica list.
+            for r in self.ready_replicas:
+                if r not in ready_replicas:
+                    self.load_map.pop(r, None)
+            self.ready_replicas = ready_replicas
+            for replica in ready_replicas:
+                self.load_map[replica] = self.load_map.get(replica, 0)
+
+ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
+ del request # Unused.
+ if not self.ready_replicas:
+ return None
+ with self.lock:
+ return min(self.ready_replicas,
+ key=lambda replica: self.load_map.get(replica, 0))
+
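+    # Track in-flight requests per replica: increment the load before a
+    # request is proxied and decrement it once the response completes.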
+ def pre_execute_hook(self, replica_url: str,
+ request: 'fastapi.Request') -> None:
+ del request # Unused.
+ with self.lock:
+ self.load_map[replica_url] += 1
+
+ def post_execute_hook(self, replica_url: str,
+ request: 'fastapi.Request') -> None:
+ del request # Unused.
+ with self.lock:
+ self.load_map[replica_url] -= 1
diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py
index 9f8e59eee8b..454cea00c80 100644
--- a/sky/serve/replica_managers.py
+++ b/sky/serve/replica_managers.py
@@ -171,7 +171,7 @@ def _get_resources_ports(task_yaml: str) -> str:
"""Get the resources ports used by the task."""
task = sky.Task.from_yaml(task_yaml)
# Already checked all ports are the same in sky.serve.core.up
- assert len(task.resources) >= 1, task
+ assert task.resources, task
task_resources: 'resources.Resources' = list(task.resources)[0]
# Already checked the resources have and only have one port
# before upload the task yaml.
diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py
index 333e0138fb4..f3e8fbf1e53 100644
--- a/sky/serve/serve_state.py
+++ b/sky/serve/serve_state.py
@@ -11,6 +11,7 @@
import colorama
from sky.serve import constants
+from sky.serve import load_balancing_policies as lb_policies
from sky.utils import db_utils
if typing.TYPE_CHECKING:
@@ -76,6 +77,8 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
'active_versions',
f'TEXT DEFAULT {json.dumps([])!r}')
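+# Stores the load balancing policy of each service; NULL for services created
+# before this column was added.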
+db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
+ 'load_balancing_policy', 'TEXT DEFAULT NULL')
_UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'
@@ -223,7 +226,7 @@ def from_replica_statuses(
for status in ReplicaStatus.failed_statuses()) > 0:
return cls.FAILED
# When min_replicas = 0, there is no (provisioning) replica.
- if len(replica_statuses) == 0:
+ if not replica_statuses:
return cls.NO_REPLICA
return cls.REPLICA_INIT
@@ -241,7 +244,8 @@ def from_replica_statuses(
def add_service(name: str, controller_job_id: int, policy: str,
- requested_resources_str: str, status: ServiceStatus) -> bool:
+ requested_resources_str: str, load_balancing_policy: str,
+ status: ServiceStatus) -> bool:
"""Add a service in the database.
Returns:
@@ -254,10 +258,10 @@ def add_service(name: str, controller_job_id: int, policy: str,
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
- requested_resources_str)
- VALUES (?, ?, ?, ?, ?)""",
+ requested_resources_str, load_balancing_policy)
+ VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
- requested_resources_str))
+ requested_resources_str, load_balancing_policy))
except sqlite3.IntegrityError as e:
if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
@@ -324,7 +328,12 @@ def set_service_load_balancer_port(service_name: str,
def _get_service_from_row(row) -> Dict[str, Any]:
(current_version, name, controller_job_id, controller_port,
load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
- _, active_versions) = row[:13]
+ _, active_versions, load_balancing_policy) = row[:14]
+ if load_balancing_policy is None:
+        # This column was added to the database in #4439 and is always set to
+        # a str value for new entries. If it is None, this is a legacy entry
+        # that uses the legacy default policy.
+ load_balancing_policy = lb_policies.LEGACY_DEFAULT_POLICY
return {
'name': name,
'controller_job_id': controller_job_id,
@@ -341,6 +350,7 @@ def _get_service_from_row(row) -> Dict[str, Any]:
# integers in json format. This is mainly for display purpose.
'active_versions': json.loads(active_versions),
'requested_resources_str': requested_resources_str,
+ 'load_balancing_policy': load_balancing_policy,
}
diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py
index 1ef953a28c7..d2af4ab307b 100644
--- a/sky/serve/serve_utils.py
+++ b/sky/serve/serve_utils.py
@@ -108,7 +108,7 @@ class UpdateMode(enum.Enum):
class ThreadSafeDict(Generic[KeyType, ValueType]):
"""A thread-safe dict."""
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._dict: Dict[KeyType, ValueType] = dict(*args, **kwargs)
self._lock = threading.Lock()
@@ -381,7 +381,7 @@ def _get_service_status(
def get_service_status_encoded(service_names: Optional[List[str]]) -> str:
- service_statuses = []
+ service_statuses: List[Dict[str, str]] = []
if service_names is None:
# Get all service names
service_names = serve_state.get_glob_service_names(None)
@@ -400,7 +400,7 @@ def get_service_status_encoded(service_names: Optional[List[str]]) -> str:
def load_service_status(payload: str) -> List[Dict[str, Any]]:
service_statuses_encoded = message_utils.decode_payload(
payload, payload_type='service_status')
- service_statuses = []
+ service_statuses: List[Dict[str, Any]] = []
for service_status in service_statuses_encoded:
service_statuses.append({
k: pickle.loads(base64.b64decode(v))
@@ -432,7 +432,7 @@ def _terminate_failed_services(
A message indicating potential resource leak (if any). If no
resource leak is detected, return None.
"""
- remaining_replica_clusters = []
+ remaining_replica_clusters: List[str] = []
# The controller should have already attempted to terminate those
# replicas, so we don't need to try again here.
for replica_info in serve_state.get_replica_infos(service_name):
@@ -459,8 +459,8 @@ def _terminate_failed_services(
def terminate_services(service_names: Optional[List[str]], purge: bool) -> str:
service_names = serve_state.get_glob_service_names(service_names)
- terminated_service_names = []
- messages = []
+ terminated_service_names: List[str] = []
+ messages: List[str] = []
for service_name in service_names:
service_status = _get_service_status(service_name,
with_replica_info=False)
@@ -506,7 +506,7 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str:
f.write(UserSignal.TERMINATE.value)
f.flush()
terminated_service_names.append(f'{service_name!r}')
- if len(terminated_service_names) == 0:
+ if not terminated_service_names:
messages.append('No service to terminate.')
else:
identity_str = f'Service {terminated_service_names[0]} is'
@@ -839,10 +839,12 @@ def format_service_table(service_records: List[Dict[str, Any]],
'NAME', 'VERSION', 'UPTIME', 'STATUS', 'REPLICAS', 'ENDPOINT'
]
if show_all:
- service_columns.extend(['POLICY', 'REQUESTED_RESOURCES'])
+ service_columns.extend([
+ 'AUTOSCALING_POLICY', 'LOAD_BALANCING_POLICY', 'REQUESTED_RESOURCES'
+ ])
service_table = log_utils.create_table(service_columns)
- replica_infos = []
+ replica_infos: List[Dict[str, Any]] = []
for record in service_records:
for replica in record['replica_info']:
replica['service_name'] = record['name']
@@ -860,6 +862,7 @@ def format_service_table(service_records: List[Dict[str, Any]],
endpoint = get_endpoint(record)
policy = record['policy']
requested_resources_str = record['requested_resources_str']
+ load_balancing_policy = record['load_balancing_policy']
service_values = [
service_name,
@@ -870,7 +873,8 @@ def format_service_table(service_records: List[Dict[str, Any]],
endpoint,
]
if show_all:
- service_values.extend([policy, requested_resources_str])
+ service_values.extend(
+ [policy, load_balancing_policy, requested_resources_str])
service_table.add_row(service_values)
replica_table = _format_replica_table(replica_infos, show_all)
@@ -912,7 +916,8 @@ def _format_replica_table(replica_records: List[Dict[str, Any]],
region = '-'
zone = '-'
- replica_handle: 'backends.CloudVmRayResourceHandle' = record['handle']
+ replica_handle: Optional['backends.CloudVmRayResourceHandle'] = record[
+ 'handle']
if replica_handle is not None:
resources_str = resources_utils.get_readable_resources_repr(
replica_handle, simplify=not show_all)
diff --git a/sky/serve/service.py b/sky/serve/service.py
index ba74d45ec91..ec549d489ca 100644
--- a/sky/serve/service.py
+++ b/sky/serve/service.py
@@ -150,6 +150,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
controller_job_id=job_id,
policy=service_spec.autoscaling_policy_str(),
requested_resources_str=backend_utils.get_task_resources_str(task),
+ load_balancing_policy=service_spec.load_balancing_policy,
status=serve_state.ServiceStatus.CONTROLLER_INIT)
# Directly throw an error here. See sky/serve/api.py::up
# for more details.
diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py
index 000eed139f1..41de54cf806 100644
--- a/sky/serve/service_spec.py
+++ b/sky/serve/service_spec.py
@@ -2,12 +2,13 @@
import json
import os
import textwrap
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
import yaml
from sky import serve
from sky.serve import constants
+from sky.serve import load_balancing_policies as lb_policies
from sky.utils import common_utils
from sky.utils import schemas
from sky.utils import ux_utils
@@ -185,9 +186,12 @@ def from_yaml(yaml_path: str) -> 'SkyServiceSpec':
return SkyServiceSpec.from_yaml_config(config['service'])
def to_yaml_config(self) -> Dict[str, Any]:
- config = dict()
+ config: Dict[str, Any] = {}
- def add_if_not_none(section, key, value, no_empty: bool = False):
+ def add_if_not_none(section: str,
+ key: Optional[str],
+ value: Any,
+ no_empty: bool = False):
if no_empty and not value:
return
if value is not None:
@@ -230,8 +234,8 @@ def probe_str(self):
' with custom headers')
return f'{method}{headers}'
- def spot_policy_str(self):
- policy_strs = []
+ def spot_policy_str(self) -> str:
+ policy_strs: List[str] = []
if (self.dynamic_ondemand_fallback is not None and
self.dynamic_ondemand_fallback):
policy_strs.append('Dynamic on-demand fallback')
@@ -327,5 +331,6 @@ def use_ondemand_fallback(self) -> bool:
return self._use_ondemand_fallback
@property
- def load_balancing_policy(self) -> Optional[str]:
- return self._load_balancing_policy
+ def load_balancing_policy(self) -> str:
+ return lb_policies.LoadBalancingPolicy.make_policy_name(
+ self._load_balancing_policy)
diff --git a/sky/setup_files/dependencies.py b/sky/setup_files/dependencies.py
index 0c9e2fcfc29..37ec14b0e01 100644
--- a/sky/setup_files/dependencies.py
+++ b/sky/setup_files/dependencies.py
@@ -135,6 +135,7 @@
'fluidstack': [], # No dependencies needed for fluidstack
'cudo': ['cudo-compute>=0.1.10'],
'paperspace': [], # No dependencies needed for paperspace
+ 'do': ['pydo>=0.3.0', 'azure-core>=1.24.0', 'azure-common'],
'vsphere': [
'pyvmomi==8.0.1.0.2',
# vsphere-automation-sdk is also required, but it does not have
diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py
index 42ee901f473..dc7196514f1 100644
--- a/sky/setup_files/setup.py
+++ b/sky/setup_files/setup.py
@@ -186,6 +186,6 @@ def parse_readme(readme: str) -> str:
'Homepage': 'https://github.com/skypilot-org/skypilot',
'Issues': 'https://github.com/skypilot-org/skypilot/issues',
'Discussion': 'https://github.com/skypilot-org/skypilot/discussions',
- 'Documentation': 'https://skypilot.readthedocs.io/en/latest/',
+ 'Documentation': 'https://docs.skypilot.co/',
},
)
diff --git a/sky/sky_logging.py b/sky/sky_logging.py
index 8370b8d0042..2555ce1e1d4 100644
--- a/sky/sky_logging.py
+++ b/sky/sky_logging.py
@@ -1,18 +1,22 @@
"""Logging utilities."""
import builtins
import contextlib
+from datetime import datetime
import logging
+import os
import sys
import threading
import colorama
+from sky.skylet import constants
from sky.utils import env_options
from sky.utils import rich_utils
# UX: Should we show logging prefixes and some extra information in optimizer?
_FORMAT = '%(levelname).1s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
_DATE_FORMAT = '%m-%d %H:%M:%S'
+_SENSITIVE_LOGGER = ['sky.provisioner', 'sky.optimizer']
def _show_logging_prefix():
@@ -71,6 +75,23 @@ def _setup_logger():
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False
+ if env_options.Options.SUPPRESS_SENSITIVE_LOG.get():
+        # If sensitive-log suppression is enabled, we attach a dedicated
+        # handler and force the level to INFO to suppress the debug logs of
+        # these loggers.
+ for logger_name in _SENSITIVE_LOGGER:
+ logger = logging.getLogger(logger_name)
+ handler_to_logger = rich_utils.RichSafeStreamHandler(sys.stdout)
+ handler_to_logger.flush = sys.stdout.flush # type: ignore
+ logger.addHandler(handler_to_logger)
+ logger.setLevel(logging.INFO)
+ if _show_logging_prefix():
+ handler_to_logger.setFormatter(FORMATTER)
+ else:
+ handler_to_logger.setFormatter(NO_PREFIX_FORMATTER)
+ # Do not propagate to the parent logger to avoid parent
+ # logger printing the logs.
+ logger.propagate = False
def reload_logger():
@@ -91,7 +112,7 @@ def reload_logger():
_setup_logger()
-def init_logger(name: str):
+def init_logger(name: str) -> logging.Logger:
return logging.getLogger(name)
@@ -139,3 +160,16 @@ def is_silent():
# threads.
_logging_config.is_silent = False
return _logging_config.is_silent
+
+
+def get_run_timestamp() -> str:
+ return 'sky-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
+
+
+def generate_tmp_logging_file_path(file_name: str) -> str:
+ """Generate an absolute path of a tmp file for logging."""
+ run_timestamp = get_run_timestamp()
+ log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp)
+ log_path = os.path.expanduser(os.path.join(log_dir, file_name))
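+    # Only the path is constructed here; the log directory is not created.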
+
+ return log_path
diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py
index bae521ddb7d..c1f65b0c367 100644
--- a/sky/skylet/constants.py
+++ b/sky/skylet/constants.py
@@ -271,9 +271,7 @@
# Used for translate local file mounts to cloud storage. Please refer to
# sky/execution.py::_maybe_translate_local_file_mounts_and_sync_up for
# more details.
-WORKDIR_BUCKET_NAME = 'skypilot-workdir-{username}-{id}'
-FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-folder-{username}-{id}'
-FILE_MOUNTS_FILE_ONLY_BUCKET_NAME = 'skypilot-filemounts-files-{username}-{id}'
+FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-{username}-{id}'
FILE_MOUNTS_LOCAL_TMP_DIR = 'skypilot-filemounts-files-{id}'
FILE_MOUNTS_REMOTE_TMP_DIR = '/tmp/sky-{}-filemounts-files'
# For API server, the use a temporary directory in the same path as the upload
@@ -282,6 +280,12 @@
# persistent volume, so any contents in ~/.sky/ cannot be hard linked elsewhere.
FILE_MOUNTS_LOCAL_TMP_BASE_PATH = '~/.sky/tmp/'
+# Used when a managed job is created and its files are synced up to the
+# cloud.
+FILE_MOUNTS_WORKDIR_SUBPATH = 'job-{run_id}/workdir'
+FILE_MOUNTS_SUBPATH = 'job-{run_id}/local-file-mounts/{i}'
+FILE_MOUNTS_TMP_SUBPATH = 'job-{run_id}/tmp-files'
+
# The default idle timeout for SkyPilot controllers. This include spot
# controller and sky serve controller.
# TODO(tian): Refactor to controller_utils. Current blocker: circular import.
diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py
index ecc5e81c1eb..f67494e4beb 100644
--- a/sky/skylet/job_lib.py
+++ b/sky/skylet/job_lib.py
@@ -588,7 +588,7 @@ def update_job_status(job_ids: List[int],
This function should only be run on the remote instance with ray>=2.4.0.
"""
echo = logger.info if not silent else logger.debug
- if len(job_ids) == 0:
+ if not job_ids:
return []
statuses = []
diff --git a/sky/skylet/providers/ibm/node_provider.py b/sky/skylet/providers/ibm/node_provider.py
index 5e2a2d64493..44622369e92 100644
--- a/sky/skylet/providers/ibm/node_provider.py
+++ b/sky/skylet/providers/ibm/node_provider.py
@@ -377,7 +377,7 @@ def non_terminated_nodes(self, tag_filters) -> List[str]:
node["id"], nic_id
).get_result()
floating_ips = res["floating_ips"]
- if len(floating_ips) == 0:
+ if not floating_ips:
# not adding a node that's yet/failed to
# to get a floating ip provisioned
continue
@@ -485,7 +485,7 @@ def _get_instance_data(self, name):
"""Returns instance (node) information matching the specified name"""
instances_data = self.ibm_vpc_client.list_instances(name=name).get_result()
- if len(instances_data["instances"]) > 0:
+ if instances_data["instances"]:
return instances_data["instances"][0]
return None
diff --git a/sky/skylet/providers/scp/config.py b/sky/skylet/providers/scp/config.py
index c20b1837f26..d19744e7322 100644
--- a/sky/skylet/providers/scp/config.py
+++ b/sky/skylet/providers/scp/config.py
@@ -107,7 +107,7 @@ def get_vcp_subnets(self):
for item in subnet_contents
if item['subnetState'] == 'ACTIVE' and item["vpcId"] == vpc
]
- if len(subnet_list) > 0:
+ if subnet_list:
vpc_subnets[vpc] = subnet_list
return vpc_subnets
diff --git a/sky/skylet/providers/scp/node_provider.py b/sky/skylet/providers/scp/node_provider.py
index 004eaac3830..f99b477ab06 100644
--- a/sky/skylet/providers/scp/node_provider.py
+++ b/sky/skylet/providers/scp/node_provider.py
@@ -259,7 +259,7 @@ def _config_security_group(self, zone_id, vpc, cluster_name):
for sg in sg_contents
if sg["securityGroupId"] == sg_id
]
- if len(sg) != 0 and sg[0] == "ACTIVE":
+ if sg and sg[0] == "ACTIVE":
break
time.sleep(5)
@@ -282,16 +282,16 @@ def _del_security_group(self, sg_id):
for sg in sg_contents
if sg["securityGroupId"] == sg_id
]
- if len(sg) == 0:
+ if not sg:
break
def _refresh_security_group(self, vms):
- if len(vms) > 0:
+ if vms:
return
# remove security group if vm does not exist
keys = self.metadata.keys()
security_group_id = self.metadata[
- keys[0]]['creation']['securityGroupId'] if len(keys) > 0 else None
+ keys[0]]['creation']['securityGroupId'] if keys else None
if security_group_id:
try:
self._del_security_group(security_group_id)
@@ -308,7 +308,7 @@ def _del_vm(self, vm_id):
for vm in vm_contents
if vm["virtualServerId"] == vm_id
]
- if len(vms) == 0:
+ if not vms:
break
def _del_firwall_rules(self, firewall_id, rule_ids):
@@ -391,7 +391,7 @@ def _create_instance_sequence(self, vpc, instance_config):
return None, None, None, None
def _undo_funcs(self, undo_func_list):
- while len(undo_func_list) > 0:
+ while undo_func_list:
func = undo_func_list.pop()
func()
@@ -468,7 +468,7 @@ def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
zone_config = ZoneConfig(self.scp_client, node_config)
vpc_subnets = zone_config.get_vcp_subnets()
- if (len(vpc_subnets) == 0):
+ if not vpc_subnets:
raise SCPError("This region/zone does not have available VPCs.")
instance_config = zone_config.bootstrap_instance_config(node_config)
diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py
index 10c3fe16d3c..082ec993b8e 100644
--- a/sky/skypilot_config.py
+++ b/sky/skypilot_config.py
@@ -164,7 +164,7 @@ def _reload_config() -> None:
_dict,
schemas.get_config_schema(),
f'Invalid config YAML ({config_path}). See: '
- 'https://skypilot.readthedocs.io/en/latest/reference/config.html. ' # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/reference/config.html. ' # pylint: disable=line-too-long
'Error: ',
skip_none=False)
diff --git a/sky/task.py b/sky/task.py
index 978d374fe5a..0ab267a6679 100644
--- a/sky/task.py
+++ b/sky/task.py
@@ -975,10 +975,6 @@ def _get_preferred_store(
store_type = storage_lib.StoreType.from_cloud(storage_cloud_str)
return store_type, storage_region
- # def initialize_all_storage_stores(self):
- # for _, storage in self.storage_mounts.items():
- # storage.initialize_all_stores()
-
def sync_storage_mounts(self) -> None:
"""(INTERNAL) Eagerly syncs storage mounts to cloud storage.
@@ -989,12 +985,15 @@ def sync_storage_mounts(self) -> None:
for storage in self.storage_mounts.values():
storage.construct()
assert storage.name is not None, storage
- if len(storage.stores) == 0:
+ if not storage.stores:
store_type, store_region = self._get_preferred_store()
self.storage_plans[storage] = store_type
storage.add_store(store_type, store_region)
storage.sync_all_stores()
else:
+            # No need to sync the storage here: if the stores are not empty,
+            # the storage has already been synced during construct() above.
# We will download the first store that is added to remote.
self.storage_plans[storage] = list(storage.stores.keys())[0]
@@ -1013,6 +1012,7 @@ def sync_storage_mounts(self) -> None:
else:
assert storage.name is not None, storage
blob_path = 's3://' + storage.name
+ blob_path = storage.get_bucket_sub_path_prefix(blob_path)
self.update_file_mounts({
mnt_path: blob_path,
})
@@ -1023,6 +1023,7 @@ def sync_storage_mounts(self) -> None:
else:
assert storage.name is not None, storage
blob_path = 'gs://' + storage.name
+ blob_path = storage.get_bucket_sub_path_prefix(blob_path)
self.update_file_mounts({
mnt_path: blob_path,
})
@@ -1041,6 +1042,7 @@ def sync_storage_mounts(self) -> None:
blob_path = data_utils.AZURE_CONTAINER_URL.format(
storage_account_name=storage_account_name,
container_name=storage.name)
+ blob_path = storage.get_bucket_sub_path_prefix(blob_path)
self.update_file_mounts({
mnt_path: blob_path,
})
@@ -1051,6 +1053,7 @@ def sync_storage_mounts(self) -> None:
blob_path = storage.source
else:
blob_path = 'r2://' + storage.name
+ blob_path = storage.get_bucket_sub_path_prefix(blob_path)
self.update_file_mounts({
mnt_path: blob_path,
})
@@ -1066,7 +1069,18 @@ def sync_storage_mounts(self) -> None:
cos_region = data_utils.Rclone.get_region_from_rclone(
storage.name, data_utils.Rclone.RcloneClouds.IBM)
blob_path = f'cos://{cos_region}/{storage.name}'
+ blob_path = storage.get_bucket_sub_path_prefix(blob_path)
self.update_file_mounts({mnt_path: blob_path})
+ elif store_type is storage_lib.StoreType.OCI:
+ if storage.source is not None and not isinstance(
+ storage.source,
+ list) and storage.source.startswith('oci://'):
+ blob_path = storage.source
+ else:
+ blob_path = 'oci://' + storage.name
+ self.update_file_mounts({
+ mnt_path: blob_path,
+ })
else:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Storage Type {store_type} '
diff --git a/sky/templates/do-ray.yml.j2 b/sky/templates/do-ray.yml.j2
new file mode 100644
index 00000000000..ea9db59398e
--- /dev/null
+++ b/sky/templates/do-ray.yml.j2
@@ -0,0 +1,98 @@
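+# Cluster config template for the DO cloud (provisioned via sky.provision.do).
+# The Jinja variables below are filled in by SkyPilot at launch time.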
+cluster_name: {{cluster_name_on_cloud}}
+
+# The maximum number of workers nodes to launch in addition to the head node.
+max_workers: {{num_nodes - 1}}
+upscaling_speed: {{num_nodes - 1}}
+idle_timeout_minutes: 60
+
+{%- if docker_image is not none %}
+docker:
+ image: {{docker_image}}
+ container_name: {{docker_container_name}}
+ run_options:
+ - --ulimit nofile=1048576:1048576
+ {%- for run_option in docker_run_options %}
+ - {{run_option}}
+ {%- endfor %}
+ {%- if docker_login_config is not none %}
+ docker_login_config:
+ username: |-
+ {{docker_login_config.username}}
+ password: |-
+ {{docker_login_config.password}}
+ server: |-
+ {{docker_login_config.server}}
+ {%- endif %}
+{%- endif %}
+
+provider:
+ type: external
+ module: sky.provision.do
+ region: "{{region}}"
+
+auth:
+ ssh_user: root
+ ssh_private_key: {{ssh_private_key}}
+ ssh_public_key: |-
+ skypilot:ssh_public_key_content
+
+available_node_types:
+ ray_head_default:
+ resources: {}
+ node_config:
+ InstanceType: {{instance_type}}
+ DiskSize: {{disk_size}}
+ {%- if image_id is not none %}
+ ImageId: {{image_id}}
+ {%- endif %}
+
+head_node_type: ray_head_default
+
+# Format: `REMOTE_PATH : LOCAL_PATH`
+file_mounts: {
+ "{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}",
+ "{{sky_remote_path}}/{{sky_wheel_hash}}": "{{sky_local_path}}",
+{%- for remote_path, local_path in credentials.items() %}
+ "{{remote_path}}": "{{local_path}}",
+{%- endfor %}
+}
+
+rsync_exclude: []
+
+initialization_commands: []
+
+# List of shell commands to run to set up nodes.
+# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH
+# connection, which is expensive. Try your best to co-locate commands into fewer
+# items!
+#
+# Increment the following for catching performance bugs easier:
+# current num items (num SSH connections): 1
+setup_commands:
+ # Disable `unattended-upgrades` to prevent apt-get from hanging. It should be called at the beginning before the process started to avoid being blocked. (This is a temporary fix.)
+ # Create ~/.ssh/config file in case the file does not exist in the image.
+ # Line 'rm ..': there is another installation of pip.
+ # Line 'sudo bash ..': set the ulimit as suggested by ray docs for performance. https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html#system-configuration
+  # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit getting stuck when the number of running ray jobs increases.
+ # Line 'mkdir -p ..': disable host key check
+ # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys`
+ - {%- for initial_setup_command in initial_setup_commands %}
+ {{ initial_setup_command }}
+ {%- endfor %}
+ sudo systemctl stop unattended-upgrades || true;
+ sudo systemctl disable unattended-upgrades || true;
+ sudo sed -i 's/Unattended-Upgrade "1"/Unattended-Upgrade "0"/g' /etc/apt/apt.conf.d/20auto-upgrades || true;
+ sudo kill -9 `sudo lsof /var/lib/dpkg/lock-frontend | awk '{print $2}' | tail -n 1` || true;
+ sudo pkill -9 apt-get;
+ sudo pkill -9 dpkg;
+ sudo dpkg --configure -a;
+ mkdir -p ~/.ssh; touch ~/.ssh/config;
+ {{ conda_installation_commands }}
+ {{ ray_skypilot_installation_commands }}
+ sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf';
+ sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload;
+ mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config;
+ [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf');
+
+# Command to start ray clusters are now placed in `sky.provision.instance_setup`.
+# We do not need to list it here anymore.
diff --git a/sky/utils/accelerator_registry.py b/sky/utils/accelerator_registry.py
index 161b1f55c01..c96bde9f0bc 100644
--- a/sky/utils/accelerator_registry.py
+++ b/sky/utils/accelerator_registry.py
@@ -70,7 +70,7 @@ def canonicalize_accelerator_name(accelerator: str,
return name
if cloud_str is None or cloud_str in clouds:
names.append(name)
- if len(names) == 0:
+ if not names:
return accelerator
if len(names) == 1:
return names[0]
diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py
index c14227f8727..f7cb2fb1f97 100644
--- a/sky/utils/common_utils.py
+++ b/sky/utils/common_utils.py
@@ -616,7 +616,7 @@ def get_cleaned_username(username: str = '') -> str:
return username
-def fill_template(template_name: str, variables: Dict,
+def fill_template(template_name: str, variables: Dict[str, Any],
output_path: str) -> None:
"""Create a file from a Jinja template and return the filename."""
assert template_name.endswith('.j2'), template_name
diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py
index ef1923e4588..8a687b03242 100644
--- a/sky/utils/controller_utils.py
+++ b/sky/utils/controller_utils.py
@@ -650,10 +650,16 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
still sync up any storage mounts with local source paths (which do not
undergo translation).
"""
+
# ================================================================
# Translate the workdir and local file mounts to cloud file mounts.
# ================================================================
+ def _sub_path_join(sub_path: Optional[str], path: str) -> str:
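+        # Illustrative: _sub_path_join('jobs-prefix', 'workdir-run1') returns
+        # 'jobs-prefix/workdir-run1' (with leading/trailing '/' stripped);
+        # a None sub_path returns the path unchanged.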
+ if sub_path is None:
+ return path
+ return os.path.join(sub_path, path).strip('/')
+
run_id = common_utils.get_usage_run_id()[:8]
original_file_mounts = task.file_mounts if task.file_mounts else {}
original_storage_mounts = task.storage_mounts if task.storage_mounts else {}
@@ -680,11 +686,26 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
ux_utils.spinner_message(
f'Translating {msg} to SkyPilot Storage...'))
+    # Get the bucket name for the workdir and file mounts; if a bucket is
+    # specified in the config, all these files are stored in that same bucket.
+ bucket_wth_prefix = skypilot_config.get_nested(('jobs', 'bucket'), None)
+ store_kwargs: Dict[str, Any] = {}
+ if bucket_wth_prefix is None:
+ store_type = sub_path = None
+ storage_account_name = region = None
+ bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format(
+ username=common_utils.get_cleaned_username(), id=run_id)
+ else:
+ (store_type, bucket_name, sub_path, storage_account_name, region) = (
+ storage_lib.StoreType.get_fields_from_store_url(bucket_wth_prefix))
+ if storage_account_name is not None:
+ store_kwargs['storage_account_name'] = storage_account_name
+ if region is not None:
+ store_kwargs['region'] = region
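+        # Illustrative (assuming a config value like 's3://my-bucket/team-a'):
+        # store_type would identify S3, bucket_name would be 'my-bucket', and
+        # sub_path 'team-a', so all translated uploads share that
+        # user-provided bucket and prefix.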
+
# Step 1: Translate the workdir to SkyPilot storage.
new_storage_mounts = {}
if task.workdir is not None:
- bucket_name = constants.WORKDIR_BUCKET_NAME.format(
- username=common_utils.get_cleaned_username(), id=run_id)
workdir = task.workdir
task.workdir = None
if (constants.SKY_REMOTE_WORKDIR in original_file_mounts or
@@ -692,14 +713,21 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
raise ValueError(
f'Cannot mount {constants.SKY_REMOTE_WORKDIR} as both the '
'workdir and file_mounts contains it as the target.')
- new_storage_mounts[
- constants.
- SKY_REMOTE_WORKDIR] = storage_lib.Storage.from_yaml_config({
- 'name': bucket_name,
- 'source': workdir,
- 'persistent': False,
- 'mode': 'COPY',
- })
+ bucket_sub_path = _sub_path_join(
+ sub_path,
+ constants.FILE_MOUNTS_WORKDIR_SUBPATH.format(run_id=run_id))
+ stores = None
+ if store_type is not None:
+ stores = [store_type]
+
+ storage_obj = storage_lib.Storage(name=bucket_name,
+ source=workdir,
+ persistent=False,
+ mode=storage_lib.StorageMode.COPY,
+ stores=stores,
+ _is_sky_managed=not bucket_wth_prefix,
+ _bucket_sub_path=bucket_sub_path)
+ new_storage_mounts[constants.SKY_REMOTE_WORKDIR] = storage_obj
# Check of the existence of the workdir in file_mounts is done in
# the task construction.
logger.info(f' {colorama.Style.DIM}Workdir: {workdir!r} '
@@ -717,98 +745,100 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
if os.path.isfile(os.path.abspath(os.path.expanduser(src))):
copy_mounts_with_file_in_src[dst] = src
continue
- bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format(
- username=common_utils.get_cleaned_username(),
- id=f'{run_id}-{i}',
- )
- new_storage_mounts[dst] = storage_lib.Storage.from_yaml_config({
- 'name': bucket_name,
- 'source': src,
- 'persistent': False,
- 'mode': 'COPY',
- })
+ bucket_sub_path = _sub_path_join(
+ sub_path, constants.FILE_MOUNTS_SUBPATH.format(i=i, run_id=run_id))
+ stores = None
+ if store_type is not None:
+ stores = [store_type]
+ storage_obj = storage_lib.Storage(name=bucket_name,
+ source=src,
+ persistent=False,
+ mode=storage_lib.StorageMode.COPY,
+ stores=stores,
+ _is_sky_managed=not bucket_wth_prefix,
+ _bucket_sub_path=bucket_sub_path)
+ new_storage_mounts[dst] = storage_obj
logger.info(f' {colorama.Style.DIM}Folder : {src!r} '
f'-> storage: {bucket_name!r}.{colorama.Style.RESET_ALL}')
# Step 3: Translate local file mounts with file in src to SkyPilot storage.
# Hard link the files in src to a temporary directory, and upload folder.
- base_tmp_dir = os.path.expanduser(constants.FILE_MOUNTS_LOCAL_TMP_BASE_PATH)
- os.makedirs(base_tmp_dir, exist_ok=True)
- with tempfile.TemporaryDirectory(dir=base_tmp_dir) as temp_path:
- local_fm_path = os.path.join(
- temp_path, constants.FILE_MOUNTS_LOCAL_TMP_DIR.format(id=run_id))
- os.makedirs(local_fm_path, exist_ok=True)
- file_bucket_name = constants.FILE_MOUNTS_FILE_ONLY_BUCKET_NAME.format(
- username=common_utils.get_cleaned_username(), id=run_id)
- file_mount_remote_tmp_dir = constants.FILE_MOUNTS_REMOTE_TMP_DIR.format(
- path)
- if copy_mounts_with_file_in_src:
- src_to_file_id = {}
- for i, src in enumerate(set(copy_mounts_with_file_in_src.values())):
- src_to_file_id[src] = i
- os.link(os.path.abspath(os.path.expanduser(src)),
- os.path.join(local_fm_path, f'file-{i}'))
-
- new_storage_mounts[file_mount_remote_tmp_dir] = (
- storage_lib.Storage.from_yaml_config({
- 'name': file_bucket_name,
- 'source': local_fm_path,
- 'persistent': False,
- 'mode': 'MOUNT',
- }))
- if file_mount_remote_tmp_dir in original_storage_mounts:
- with ux_utils.print_exception_no_traceback():
- raise ValueError(
- 'Failed to translate file mounts, due to the default '
- f'destination {file_mount_remote_tmp_dir} '
- 'being taken.')
- sources = list(src_to_file_id.keys())
- sources_str = '\n '.join(sources)
- logger.info(f' {colorama.Style.DIM}Files (listed below) '
- f' -> storage: {file_bucket_name}:'
- f'\n {sources_str}{colorama.Style.RESET_ALL}')
+ file_mounts_tmp_subpath = _sub_path_join(
+ sub_path, constants.FILE_MOUNTS_TMP_SUBPATH.format(run_id=run_id))
+ local_fm_path = os.path.join(
+ tempfile.gettempdir(),
+ constants.FILE_MOUNTS_LOCAL_TMP_DIR.format(id=run_id))
+ os.makedirs(local_fm_path, exist_ok=True)
+ file_mount_remote_tmp_dir = constants.FILE_MOUNTS_REMOTE_TMP_DIR.format(
+ path)
+ if copy_mounts_with_file_in_src:
+ src_to_file_id = {}
+ for i, src in enumerate(set(copy_mounts_with_file_in_src.values())):
+ src_to_file_id[src] = i
+ os.link(os.path.abspath(os.path.expanduser(src)),
+ os.path.join(local_fm_path, f'file-{i}'))
+ stores = None
+ if store_type is not None:
+ stores = [store_type]
+ storage_obj = storage_lib.Storage(
+ name=bucket_name,
+ source=local_fm_path,
+ persistent=False,
+ mode=storage_lib.StorageMode.MOUNT,
+ stores=stores,
+ _is_sky_managed=not bucket_wth_prefix,
+ _bucket_sub_path=file_mounts_tmp_subpath)
+
+ new_storage_mounts[file_mount_remote_tmp_dir] = storage_obj
+ if file_mount_remote_tmp_dir in original_storage_mounts:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError(
+ 'Failed to translate file mounts, due to the default '
+ f'destination {file_mount_remote_tmp_dir} '
+ 'being taken.')
+ sources = list(src_to_file_id.keys())
+ sources_str = '\n '.join(sources)
+ logger.info(f' {colorama.Style.DIM}Files (listed below) '
+ f' -> storage: {bucket_name}:'
+ f'\n {sources_str}{colorama.Style.RESET_ALL}')
+
+ rich_utils.force_update_status(
+ ux_utils.spinner_message('Uploading translated local files/folders'))
+ task.update_storage_mounts(new_storage_mounts)
+
+ # Step 4: Upload storage from sources
+ # Upload the local source to a bucket. The task will not be executed
+ # locally, so we need to upload the files/folders to the bucket manually
+ # here before sending the task to the remote jobs controller.
+ if task.storage_mounts:
+ # There may be existing (non-translated) storage mounts, so log this
+ # whenever task.storage_mounts is non-empty.
rich_utils.force_update_status(
- ux_utils.spinner_message(
- 'Uploading translated local files/folders'))
- task.update_storage_mounts(new_storage_mounts)
-
- # Step 4: Upload storage from sources
- # Upload the local source to a bucket. The task will not be executed
- # locally, so we need to upload the files/folders to the bucket manually
- # here before sending the task to the remote jobs controller.
- if task.storage_mounts:
- # There may be existing (non-translated) storage mounts, so log this
- # whenever task.storage_mounts is non-empty.
- rich_utils.force_update_status(
- ux_utils.spinner_message(
- 'Uploading local sources to storage[/] '
- '[dim]View storages: sky storage ls'))
- try:
- task.sync_storage_mounts()
- except (ValueError, exceptions.NoCloudAccessError) as e:
- if 'No enabled cloud for storage' in str(e) or isinstance(
- e, exceptions.NoCloudAccessError):
- data_src = None
- if has_local_source_paths_file_mounts:
- data_src = 'file_mounts'
- if has_local_source_paths_workdir:
- if data_src:
- data_src += ' and workdir'
- else:
- data_src = 'workdir'
- store_enabled_clouds = ', '.join(
- storage_lib.STORE_ENABLED_CLOUDS)
- with ux_utils.print_exception_no_traceback():
- raise exceptions.NotSupportedError(
- f'Unable to use {data_src} - no cloud with object '
- 'store support is enabled. Please enable at least one '
- 'cloud with object store support '
- f'({store_enabled_clouds}) by running `sky check`, or '
- 'remove {data_src} from your task.'
- '\nHint: If you do not have any cloud access, you may '
- 'still download data and code over the network using '
- 'curl or other tools in the `setup` section of the '
- 'task.') from None
+ ux_utils.spinner_message('Uploading local sources to storage[/] '
+ '[dim]View storages: sky storage ls'))
+ try:
+ task.sync_storage_mounts()
+ except (ValueError, exceptions.NoCloudAccessError) as e:
+ if 'No enabled cloud for storage' in str(e) or isinstance(
+ e, exceptions.NoCloudAccessError):
+ data_src = None
+ if has_local_source_paths_file_mounts:
+ data_src = 'file_mounts'
+ if has_local_source_paths_workdir:
+ if data_src:
+ data_src += ' and workdir'
+ else:
+ data_src = 'workdir'
+ store_enabled_clouds = ', '.join(storage_lib.STORE_ENABLED_CLOUDS)
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.NotSupportedError(
+ f'Unable to use {data_src} - no cloud with object store '
+                    'support is enabled. Please enable at least one cloud with '
+ f'object store support ({store_enabled_clouds}) by running '
+ f'`sky check`, or remove {data_src} from your task.'
+ '\nHint: If you do not have any cloud access, you may still'
+ ' download data and code over the network using curl or '
+ 'other tools in the `setup` section of the task.') from None
# Step 5: Add the file download into the file mounts, such as
# /original-dst: s3://spot-fm-file-only-bucket-name/file-0
@@ -817,11 +847,12 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
# file_mount_remote_tmp_dir will only exist when there are files in
# the src for copy mounts.
storage_obj = task.storage_mounts[file_mount_remote_tmp_dir]
- store_type = list(storage_obj.stores.keys())[0]
- store_object = storage_obj.stores[store_type]
- assert store_object is not None, storage_obj
+ curr_store_type = list(storage_obj.stores.keys())[0]
+ store_object = storage_obj.stores[curr_store_type]
+ assert store_object is not None
bucket_url = storage_lib.StoreType.get_endpoint_url(
- store_object, file_bucket_name)
+ store_object, bucket_name)
+ bucket_url += f'/{file_mounts_tmp_subpath}'
for dst, src in copy_mounts_with_file_in_src.items():
file_id = src_to_file_id[src]
new_file_mounts[dst] = bucket_url + f'/file-{file_id}'
@@ -838,10 +869,10 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
store_types = list(storage_obj.stores.keys())
assert len(store_types) == 1, (
'We only support one store type for now.', storage_obj.stores)
- store_type = store_types[0]
- store_object = storage_obj.stores[store_type]
- assert store_object is not None, storage_obj
- assert storage_obj.name is not None, storage_obj
+ curr_store_type = store_types[0]
+ store_object = storage_obj.stores[curr_store_type]
+ assert store_object is not None and storage_obj.name is not None, (
+ store_object, storage_obj.name)
storage_obj.source = storage_lib.StoreType.get_endpoint_url(
store_object, storage_obj.name)
storage_obj.force_delete = True
@@ -858,12 +889,14 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
store_types = list(storage_obj.stores.keys())
assert len(store_types) == 1, (
'We only support one store type for now.', storage_obj.stores)
- store_type = store_types[0]
- store_object = storage_obj.stores[store_type]
- assert store_object is not None, storage_obj
- assert storage_obj.name is not None, storage_obj
+ curr_store_type = store_types[0]
+ store_object = storage_obj.stores[curr_store_type]
+ assert store_object is not None and storage_obj.name is not None, (
+ store_object, storage_obj.name)
source = storage_lib.StoreType.get_endpoint_url(
store_object, storage_obj.name)
new_storage = storage_lib.Storage.from_yaml_config({
'source': source,
'persistent': storage_obj.persistent,
diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py
index 661053b7175..7721151422c 100644
--- a/sky/utils/dag_utils.py
+++ b/sky/utils/dag_utils.py
@@ -89,7 +89,7 @@ def load_chain_dag_from_yaml(
elif len(configs) == 1:
dag_name = configs[0].get('name')
- if len(configs) == 0:
+ if not configs:
# YAML has only `name: xxx`. Still instantiate a task.
configs = [{'name': dag_name}]
diff --git a/sky/utils/env_options.py b/sky/utils/env_options.py
index ebec8eeb90d..cfc20a76253 100644
--- a/sky/utils/env_options.py
+++ b/sky/utils/env_options.py
@@ -11,6 +11,7 @@ class Options(enum.Enum):
SHOW_DEBUG_INFO = ('SKYPILOT_DEBUG', False)
DISABLE_LOGGING = ('SKYPILOT_DISABLE_USAGE_COLLECTION', False)
MINIMIZE_LOGGING = ('SKYPILOT_MINIMIZE_LOGGING', True)
+ SUPPRESS_SENSITIVE_LOG = ('SKYPILOT_SUPPRESS_SENSITIVE_LOG', False)
# Internal: this is used to skip the cloud user identity check, which is
# used to protect cluster operations in a multi-identity scenario.
# Currently, this is only used in the job and serve controller, as there
diff --git a/sky/utils/kubernetes/deploy_remote_cluster.sh b/sky/utils/kubernetes/deploy_remote_cluster.sh
index 94736474289..8d7ba3e5729 100755
--- a/sky/utils/kubernetes/deploy_remote_cluster.sh
+++ b/sky/utils/kubernetes/deploy_remote_cluster.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-# Refer to https://skypilot.readthedocs.io/en/latest/reservations/existing-machines.html for details on how to use this script.
+# Refer to https://docs.skypilot.co/en/latest/reservations/existing-machines.html for details on how to use this script.
set -e
# Colors for nicer UX
diff --git a/sky/utils/kubernetes/gpu_labeler.py b/sky/utils/kubernetes/gpu_labeler.py
index 14fbbdedca5..9f5a11cea42 100644
--- a/sky/utils/kubernetes/gpu_labeler.py
+++ b/sky/utils/kubernetes/gpu_labeler.py
@@ -101,7 +101,7 @@ def label():
# Get the list of nodes with GPUs
gpu_nodes = []
for node in nodes:
- if kubernetes_utils.GPU_RESOURCE_KEY in node.status.capacity:
+ if kubernetes_utils.get_gpu_resource_key() in node.status.capacity:
gpu_nodes.append(node)
print(f'Found {len(gpu_nodes)} GPU nodes in the cluster')
@@ -115,7 +115,7 @@ def label():
print('Continuing without using nvidia RuntimeClass. '
'This may fail on K3s clusters. '
'For more details, refer to K3s deployment notes at: '
- 'https://skypilot.readthedocs.io/en/latest/reference/kubernetes/kubernetes-setup.html') # pylint: disable=line-too-long
+ 'https://docs.skypilot.co/en/latest/reference/kubernetes/kubernetes-setup.html') # pylint: disable=line-too-long
nvidia_exists = False
if nvidia_exists:
@@ -139,10 +139,10 @@ def label():
# Create the job for this node`
batch_v1.create_namespaced_job(namespace, job_manifest)
print(f'Created GPU labeler job for node {node_name}')
- if len(gpu_nodes) == 0:
+ if not gpu_nodes:
print('No GPU nodes found in the cluster. If you have GPU nodes, '
'please ensure that they have the label '
- f'`{kubernetes_utils.GPU_RESOURCE_KEY}: `')
+ f'`{kubernetes_utils.get_gpu_resource_key()}: `')
else:
print('GPU labeling started - this may take 10 min or more to complete.'
'\nTo check the status of GPU labeling jobs, run '
diff --git a/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py b/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py
index 380c82f8c88..a764fb6e5e4 100644
--- a/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py
+++ b/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py
@@ -126,7 +126,7 @@ def manage_lifecycle():
f'error: {e}\n')
raise
- if len(ret.items) == 0:
+ if not ret.items:
sys.stdout.write(
f'[Lifecycle] Did not find pods with label '
f'"{label_selector}" in namespace {current_namespace}\n')
diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py
index b3d1748612a..07269f25dc4 100644
--- a/sky/utils/schemas.py
+++ b/sky/utils/schemas.py
@@ -299,6 +299,12 @@ def get_storage_schema():
mode.value for mode in storage.StorageMode
]
},
+ '_is_sky_managed': {
+ 'type': 'boolean',
+ },
+ '_bucket_sub_path': {
+ 'type': 'string',
+ },
'_force_delete': {
'type': 'boolean',
}
@@ -724,6 +730,11 @@ def get_config_schema():
'resources': resources_schema,
}
},
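+            # User-specified bucket (a store URL, optionally with a sub-path,
+            # e.g. 's3://my-bucket/my-prefix') that the jobs controller uses
+            # for uploading translated workdir and file mounts.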
+ 'bucket': {
+ 'type': 'string',
+ 'pattern': '^(https|s3|gs|r2|cos)://.+',
+ 'required': [],
+ }
}
}
cloud_configs = {
@@ -878,6 +889,9 @@ def get_config_schema():
'image_tag_gpu': {
'type': 'string',
},
+ 'vcn_ocid': {
+ 'type': 'string',
+ },
'vcn_subnet': {
'type': 'string',
},
diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh
index 40609c315e1..af120060a84 100644
--- a/tests/backward_compatibility_tests.sh
+++ b/tests/backward_compatibility_tests.sh
@@ -35,18 +35,20 @@ gcloud --version || conda install -c conda-forge google-cloud-sdk -y
rm -r ~/.sky/wheels || true
cd ../sky-master
git pull origin master
-$UV pip uninstall skypilot
-$UV pip install --prerelease=allow azure-cli
-$UV pip install -e ".[all]"
+pip uninstall -y skypilot
+pip install uv
+uv pip install --prerelease=allow "azure-cli>=2.65.0"
+uv pip install -e ".[all]"
cd -
conda env list | grep sky-back-compat-current || conda create -n sky-back-compat-current -y python=3.9
conda activate sky-back-compat-current
gcloud --version || conda install -c conda-forge google-cloud-sdk -y
rm -r ~/.sky/wheels || true
-$UV pip uninstall skypilot
-$UV pip install --prerelease=allow azure-cli
-$UV pip install -e ".[all]"
+pip uninstall -y skypilot
+pip install uv
+uv pip install --prerelease=allow "azure-cli>=2.65.0"
+uv pip install -e ".[all]"
# exec + launch
diff --git a/tests/conftest.py b/tests/conftest.py
index afbd9f32b09..9b43d55214a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -23,7 +23,7 @@
# To only run tests for managed jobs (without generic tests), use
# --managed-jobs.
all_clouds_in_smoke_tests = [
- 'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci',
+ 'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci', 'do',
'kubernetes', 'vsphere', 'cudo', 'fluidstack', 'paperspace', 'runpod'
]
default_clouds_to_run = ['aws', 'azure']
@@ -46,6 +46,7 @@
'fluidstack': 'fluidstack',
'cudo': 'cudo',
'paperspace': 'paperspace',
+ 'do': 'do',
'runpod': 'runpod'
}
diff --git a/tests/skyserve/load_balancer/service.yaml b/tests/skyserve/load_balancer/service.yaml
index 742b8efd2f4..232136d4a61 100644
--- a/tests/skyserve/load_balancer/service.yaml
+++ b/tests/skyserve/load_balancer/service.yaml
@@ -5,6 +5,7 @@ service:
initial_delay_seconds: 180
replica_policy:
min_replicas: 3
+ load_balancing_policy: round_robin
resources:
ports: 8080
diff --git a/tests/skyserve/update/bump_version_after.yaml b/tests/skyserve/update/bump_version_after.yaml
index 6e845f54b9e..0f2c6925bc6 100644
--- a/tests/skyserve/update/bump_version_after.yaml
+++ b/tests/skyserve/update/bump_version_after.yaml
@@ -16,7 +16,7 @@ service:
replicas: 3
resources:
- ports: 8081
+ ports: 8080
cpus: 2+
setup: |
diff --git a/tests/skyserve/update/bump_version_before.yaml b/tests/skyserve/update/bump_version_before.yaml
index c9fd957e41a..de922b66434 100644
--- a/tests/skyserve/update/bump_version_before.yaml
+++ b/tests/skyserve/update/bump_version_before.yaml
@@ -16,7 +16,7 @@ service:
replicas: 2
resources:
- ports: 8081
+ ports: 8080
cpus: 2+
setup: |
diff --git a/tests/skyserve/update/new.yaml b/tests/skyserve/update/new.yaml
index 2c9cebd0cb5..5e5d853e09d 100644
--- a/tests/skyserve/update/new.yaml
+++ b/tests/skyserve/update/new.yaml
@@ -3,6 +3,7 @@ service:
path: /health
initial_delay_seconds: 100
replicas: 2
+ load_balancing_policy: round_robin
resources:
ports: 8081
diff --git a/tests/skyserve/update/old.yaml b/tests/skyserve/update/old.yaml
index 4b99cb92e8c..4cb19b8327b 100644
--- a/tests/skyserve/update/old.yaml
+++ b/tests/skyserve/update/old.yaml
@@ -3,6 +3,7 @@ service:
path: /health
initial_delay_seconds: 20
replicas: 2
+ load_balancing_policy: round_robin
resources:
ports: 8080
diff --git a/tests/smoke_tests/__init__.py b/tests/smoke_tests/__init__.py
new file mode 100644
index 00000000000..63d4cd2b811
--- /dev/null
+++ b/tests/smoke_tests/__init__.py
@@ -0,0 +1,2 @@
+"""For smoke tests import."""
+__all__ = ['smoke_tests_utils']
diff --git a/tests/smoke_tests/smoke_tests_utils.py b/tests/smoke_tests/smoke_tests_utils.py
new file mode 100644
index 00000000000..6cd8c4ecaeb
--- /dev/null
+++ b/tests/smoke_tests/smoke_tests_utils.py
@@ -0,0 +1,440 @@
+import enum
+import inspect
+import os
+import subprocess
+import sys
+import tempfile
+from typing import Dict, List, NamedTuple, Optional, Tuple
+import uuid
+
+import colorama
+import pytest
+
+import sky
+from sky import serve
+from sky.clouds import AWS
+from sky.clouds import GCP
+from sky.utils import common_utils
+from sky.utils import subprocess_utils
+
+# To avoid the second smoke test reusing the cluster launched in the first
+# smoke test. Also required for test_managed_jobs_recovery to make sure the
+# manual termination with aws ec2 does not accidentally terminate other
+# clusters for different managed jobs launched with the same job name but a
+# different job id.
+test_id = str(uuid.uuid4())[-2:]
+
+LAMBDA_TYPE = '--cloud lambda --gpus A10'
+FLUIDSTACK_TYPE = '--cloud fluidstack --gpus RTXA4000'
+
+SCP_TYPE = '--cloud scp'
+SCP_GPU_V100 = '--gpus V100-32GB'
+
+STORAGE_SETUP_COMMANDS = [
+ 'touch ~/tmpfile', 'mkdir -p ~/tmp-workdir',
+ 'touch ~/tmp-workdir/tmp\ file', 'touch ~/tmp-workdir/tmp\ file2',
+ 'touch ~/tmp-workdir/foo',
+ '[ ! -e ~/tmp-workdir/circle-link ] && ln -s ~/tmp-workdir/ ~/tmp-workdir/circle-link || true',
+ 'touch ~/.ssh/id_rsa.pub'
+]
+
+# Get the job queue, and print it once on its own, then print it again so the
+# caller can grep it.
+GET_JOB_QUEUE = 's=$(sky jobs queue); echo "$s"; echo "$s"'
+# Wait for a job to be not in RUNNING state. Used to check for RECOVERING.
+JOB_WAIT_NOT_RUNNING = (
+ 's=$(sky jobs queue);'
+ 'until ! echo "$s" | grep "{job_name}" | grep "RUNNING"; do '
+ 'sleep 10; s=$(sky jobs queue);'
+ 'echo "Waiting for job to stop RUNNING"; echo "$s"; done')
+
+# Cluster functions
+_ALL_JOB_STATUSES = "|".join([status.value for status in sky.JobStatus])
+_ALL_CLUSTER_STATUSES = "|".join([status.value for status in sky.ClusterStatus])
+_ALL_MANAGED_JOB_STATUSES = "|".join(
+ [status.value for status in sky.ManagedJobStatus])
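+# The '|'-joined status values above are interpolated into the awk regexes in
+# the wait commands below to pull the status column out of `sky status` /
+# `sky queue` output.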
+
+
+def _statuses_to_str(statuses: List[enum.Enum]):
+ """Convert a list of enums to a string with all the values separated by |."""
+ assert len(statuses) > 0, 'statuses must not be empty'
+ if len(statuses) > 1:
+ return '(' + '|'.join([status.value for status in statuses]) + ')'
+ else:
+ return statuses[0].value
+
+
+_WAIT_UNTIL_CLUSTER_STATUS_CONTAINS = (
+ # A while loop to wait until the cluster status
+ # becomes certain status, with timeout.
+ 'start_time=$SECONDS; '
+ 'while true; do '
+ 'if (( $SECONDS - $start_time > {timeout} )); then '
+ ' echo "Timeout after {timeout} seconds waiting for cluster status \'{cluster_status}\'"; exit 1; '
+ 'fi; '
+ 'current_status=$(sky status {cluster_name} --refresh | '
+ 'awk "/^{cluster_name}/ '
+ '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_CLUSTER_STATUSES +
+ ')$/) print \$i}}"); '
+ 'if [[ "$current_status" =~ {cluster_status} ]]; '
+ 'then echo "Target cluster status {cluster_status} reached."; break; fi; '
+ 'echo "Waiting for cluster status to become {cluster_status}, current status: $current_status"; '
+ 'sleep 10; '
+ 'done')
+
+
+def get_cmd_wait_until_cluster_status_contains(
+ cluster_name: str, cluster_status: List[sky.ClusterStatus],
+ timeout: int):
+ return _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.format(
+ cluster_name=cluster_name,
+ cluster_status=_statuses_to_str(cluster_status),
+ timeout=timeout)
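+# Illustrative usage (hypothetical cluster name): generate a bash snippet that
+# waits up to 10 minutes for 'my-cluster' to reach STOPPED:
+#   get_cmd_wait_until_cluster_status_contains(
+#       cluster_name='my-cluster',
+#       cluster_status=[sky.ClusterStatus.STOPPED],
+#       timeout=600)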
+
+
+def get_cmd_wait_until_cluster_status_contains_wildcard(
+ cluster_name_wildcard: str, cluster_status: List[sky.ClusterStatus],
+ timeout: int):
+ wait_cmd = _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.replace(
+ 'sky status {cluster_name}',
+ 'sky status "{cluster_name}"').replace('awk "/^{cluster_name}/',
+ 'awk "/^{cluster_name_awk}/')
+ return wait_cmd.format(cluster_name=cluster_name_wildcard,
+ cluster_name_awk=cluster_name_wildcard.replace(
+ '*', '.*'),
+ cluster_status=_statuses_to_str(cluster_status),
+ timeout=timeout)
+
+
+_WAIT_UNTIL_CLUSTER_IS_NOT_FOUND = (
+ # A while loop to wait until the cluster is not found or timeout
+ 'start_time=$SECONDS; '
+ 'while true; do '
+ 'if (( $SECONDS - $start_time > {timeout} )); then '
+ ' echo "Timeout after {timeout} seconds waiting for cluster to be removed"; exit 1; '
+ 'fi; '
+ 'if sky status -r {cluster_name}; sky status {cluster_name} | grep "{cluster_name} not found"; then '
+ ' echo "Cluster {cluster_name} successfully removed."; break; '
+ 'fi; '
+ 'echo "Waiting for cluster {cluster_name} to be removed..."; '
+ 'sleep 10; '
+ 'done')
+
+
+def get_cmd_wait_until_cluster_is_not_found(cluster_name: str, timeout: int):
+ return _WAIT_UNTIL_CLUSTER_IS_NOT_FOUND.format(cluster_name=cluster_name,
+ timeout=timeout)
+
+
+_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID = (
+ # A while loop to wait until the job status
+ # contains certain status, with timeout.
+ 'start_time=$SECONDS; '
+ 'while true; do '
+ 'if (( $SECONDS - $start_time > {timeout} )); then '
+ ' echo "Timeout after {timeout} seconds waiting for job status \'{job_status}\'"; exit 1; '
+ 'fi; '
+ 'current_status=$(sky queue {cluster_name} | '
+ 'awk "\\$1 == \\"{job_id}\\" '
+ '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_JOB_STATUSES +
+ ')$/) print \$i}}"); '
+ 'found=0; ' # Initialize found variable outside the loop
+ 'while read -r line; do ' # Read line by line
+ ' if [[ "$line" =~ {job_status} ]]; then ' # Check each line
+ ' echo "Target job status {job_status} reached."; '
+ ' found=1; '
+ ' break; ' # Break inner loop
+ ' fi; '
+ 'done <<< "$current_status"; '
+ 'if [ "$found" -eq 1 ]; then break; fi; ' # Break outer loop if match found
+ 'echo "Waiting for job status to contain {job_status}, current status: $current_status"; '
+ 'sleep 10; '
+ 'done')
+
+_WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace(
+ 'awk "\\$1 == \\"{job_id}\\"', 'awk "')
+
+_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace(
+ 'awk "\\$1 == \\"{job_id}\\"', 'awk "\\$2 == \\"{job_name}\\"')
+
+
+def get_cmd_wait_until_job_status_contains_matching_job_id(
+ cluster_name: str, job_id: str, job_status: List[sky.JobStatus],
+ timeout: int):
+ return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.format(
+ cluster_name=cluster_name,
+ job_id=job_id,
+ job_status=_statuses_to_str(job_status),
+ timeout=timeout)
+
+
+def get_cmd_wait_until_job_status_contains_without_matching_job(
+ cluster_name: str, job_status: List[sky.JobStatus], timeout: int):
+ return _WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB.format(
+ cluster_name=cluster_name,
+ job_status=_statuses_to_str(job_status),
+ timeout=timeout)
+
+
+def get_cmd_wait_until_job_status_contains_matching_job_name(
+ cluster_name: str, job_name: str, job_status: List[sky.JobStatus],
+ timeout: int):
+ return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format(
+ cluster_name=cluster_name,
+ job_name=job_name,
+ job_status=_statuses_to_str(job_status),
+ timeout=timeout)
+
+
+# Managed job functions
+
+_WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.replace(
+ 'sky queue {cluster_name}', 'sky jobs queue').replace(
+ 'awk "\\$2 == \\"{job_name}\\"',
+ 'awk "\\$2 == \\"{job_name}\\" || \\$3 == \\"{job_name}\\"').replace(
+ _ALL_JOB_STATUSES, _ALL_MANAGED_JOB_STATUSES)
+
+
+def get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name: str, job_status: List[sky.JobStatus], timeout: int):
+ return _WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format(
+ job_name=job_name,
+ job_status=_statuses_to_str(job_status),
+ timeout=timeout)
+
+
+# After the timeout, the cluster will stop if autostop is set, so our checks
+# need to wait longer than that timeout. To address this, we extend the wait
+# by BUMP_UP_SECONDS before exiting.
+BUMP_UP_SECONDS = 35
+
+DEFAULT_CMD_TIMEOUT = 15 * 60
+
+
+class Test(NamedTuple):
+ name: str
+ # Each command is executed serially. If any failed, the remaining commands
+ # are not run and the test is treated as failed.
+ commands: List[str]
+ teardown: Optional[str] = None
+ # Timeout for each command in seconds.
+ timeout: int = DEFAULT_CMD_TIMEOUT
+ # Environment variables to set for each command.
+    env: Optional[Dict[str, str]] = None
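+    # Illustrative (hypothetical) construction; real tests live in the
+    # tests/smoke_tests/test_*.py modules:
+    #   Test('minimal',
+    #        commands=['sky launch -y -c mycluster tests/test_yamls/minimal.yaml',
+    #                  'sky logs mycluster 1 --status'],
+    #        teardown='sky down -y mycluster')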
+
+ def echo(self, message: str):
+ # pytest's xdist plugin captures stdout; print to stderr so that the
+ # logs are streaming while the tests are running.
+ prefix = f'[{self.name}]'
+ message = f'{prefix} {message}'
+ message = message.replace('\n', f'\n{prefix} ')
+ print(message, file=sys.stderr, flush=True)
+
+
+def get_timeout(generic_cloud: str,
+ override_timeout: int = DEFAULT_CMD_TIMEOUT):
+ timeouts = {'fluidstack': 60 * 60} # file_mounts
+ return timeouts.get(generic_cloud, override_timeout)
+
+
+def get_cluster_name() -> str:
+ """Returns a user-unique cluster name for each test_().
+
+ Must be called from each test_().
+ """
+ caller_func_name = inspect.stack()[1][3]
+ test_name = caller_func_name.replace('_', '-').replace('test-', 't-')
+ test_name = common_utils.make_cluster_name_on_cloud(test_name,
+ 24,
+ add_user_hash=False)
+ return f'{test_name}-{test_id}'
+
+
+def terminate_gcp_replica(name: str, zone: str, replica_id: int) -> str:
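+    """Returns a gcloud command that deletes the GCP instance backing the
+    given serve replica (e.g., to simulate a replica failure in serve tests).
+    """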
+ cluster_name = serve.generate_replica_cluster_name(name, replica_id)
+ query_cmd = (f'gcloud compute instances list --filter='
+ f'"(labels.ray-cluster-name:{cluster_name})" '
+ f'--zones={zone} --format="value(name)"')
+ return (f'gcloud compute instances delete --zone={zone}'
+ f' --quiet $({query_cmd})')
+
+
+def run_one_test(test: Test) -> Tuple[int, str, str]:
+ # Fail fast if `sky` CLI somehow errors out.
+ subprocess.run(['sky', 'status'], stdout=subprocess.DEVNULL, check=True)
+ log_to_stdout = os.environ.get('LOG_TO_STDOUT', None)
+ if log_to_stdout:
+ write = test.echo
+ flush = lambda: None
+ subprocess_out = sys.stderr
+ test.echo(f'Test started. Log to stdout')
+ else:
+ log_file = tempfile.NamedTemporaryFile('a',
+ prefix=f'{test.name}-',
+ suffix='.log',
+ delete=False)
+ write = log_file.write
+ flush = log_file.flush
+ subprocess_out = log_file
+ test.echo(f'Test started. Log: less {log_file.name}')
+
+ env_dict = os.environ.copy()
+ if test.env:
+ env_dict.update(test.env)
+ for command in test.commands:
+ write(f'+ {command}\n')
+ flush()
+ proc = subprocess.Popen(
+ command,
+ stdout=subprocess_out,
+ stderr=subprocess.STDOUT,
+ shell=True,
+ executable='/bin/bash',
+ env=env_dict,
+ )
+ try:
+ proc.wait(timeout=test.timeout)
+ except subprocess.TimeoutExpired as e:
+ flush()
+ test.echo(f'Timeout after {test.timeout} seconds.')
+ test.echo(str(e))
+ write(f'Timeout after {test.timeout} seconds.\n')
+ flush()
+ # Kill the current process.
+ proc.terminate()
+ proc.returncode = 1 # None if we don't set it.
+ break
+
+ if proc.returncode:
+ break
+
+ style = colorama.Style
+ fore = colorama.Fore
+ outcome = (f'{fore.RED}Failed{style.RESET_ALL}'
+ if proc.returncode else f'{fore.GREEN}Passed{style.RESET_ALL}')
+ reason = f'\nReason: {command}' if proc.returncode else ''
+ msg = (f'{outcome}.'
+ f'{reason}')
+ if log_to_stdout:
+ test.echo(msg)
+ else:
+ msg += f'\nLog: less {log_file.name}\n'
+ test.echo(msg)
+ write(msg)
+
+ if (proc.returncode == 0 or
+ pytest.terminate_on_failure) and test.teardown is not None:
+ subprocess_utils.run(
+ test.teardown,
+ stdout=subprocess_out,
+ stderr=subprocess.STDOUT,
+ timeout=10 * 60, # 10 mins
+ shell=True,
+ )
+
+ if proc.returncode:
+ if log_to_stdout:
+ raise Exception(f'test failed')
+ else:
+ raise Exception(f'test failed: less {log_file.name}')
+
+
+def get_aws_region_for_quota_failover() -> Optional[str]:
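+    """Returns an AWS region where spot p3.16xlarge quota appears unavailable.
+
+    Used to exercise quota-based failover; returns None if no such region is
+    found among the filtered candidate regions.
+    """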
+ candidate_regions = AWS.regions_with_offering(instance_type='p3.16xlarge',
+ accelerators=None,
+ use_spot=True,
+ region=None,
+ zone=None)
+ original_resources = sky.Resources(cloud=sky.AWS(),
+ instance_type='p3.16xlarge',
+ use_spot=True)
+
+ # Filter the regions with proxy command in ~/.sky/config.yaml.
+ filtered_regions = original_resources.get_valid_regions_for_launchable()
+ candidate_regions = [
+ region for region in candidate_regions
+ if region.name in filtered_regions
+ ]
+
+ for region in candidate_regions:
+ resources = original_resources.copy(region=region.name)
+ if not AWS.check_quota_available(resources):
+ return region.name
+
+ return None
+
+
+def get_gcp_region_for_quota_failover() -> Optional[str]:
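+    """Returns a GCP region where spot A100-80GB quota appears unavailable.
+
+    Used to exercise quota-based failover; returns None if no such region is
+    found among the filtered candidate regions.
+    """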
+
+ candidate_regions = GCP.regions_with_offering(instance_type=None,
+ accelerators={'A100-80GB': 1},
+ use_spot=True,
+ region=None,
+ zone=None)
+
+ original_resources = sky.Resources(cloud=sky.GCP(),
+ instance_type='a2-ultragpu-1g',
+ accelerators={'A100-80GB': 1},
+ use_spot=True)
+
+ # Filter the regions with proxy command in ~/.sky/config.yaml.
+ filtered_regions = original_resources.get_valid_regions_for_launchable()
+ candidate_regions = [
+ region for region in candidate_regions
+ if region.name in filtered_regions
+ ]
+
+ for region in candidate_regions:
+ if not GCP.check_quota_available(
+ original_resources.copy(region=region.name)):
+ return region.name
+
+ return None
+
+
+VALIDATE_LAUNCH_OUTPUT = (
+ # Validate the output of the job submission:
+ # ⚙️ Launching on Kubernetes.
+ # Pod is up.
+ # ✓ Cluster launched: test. View logs at: ~/sky_logs/sky-2024-10-07-19-44-18-177288/provision.log
+ # ✓ Setup Detached.
+ # ⚙️ Job submitted, ID: 1.
+ # ├── Waiting for task resources on 1 node.
+ # └── Job started. Streaming logs... (Ctrl-C to exit log streaming; job will not be killed)
+ # (setup pid=1277) running setup
+ # (min, pid=1277) # conda environments:
+ # (min, pid=1277) #
+ # (min, pid=1277) base * /opt/conda
+ # (min, pid=1277)
+ # (min, pid=1277) task run finish
+ # ✓ Job finished (status: SUCCEEDED).
+ #
+ # Job ID: 1
+ # 📋 Useful Commands
+ # ├── To cancel the job: sky cancel test 1
+ # ├── To stream job logs: sky logs test 1
+ # └── To view job queue: sky queue test
+ #
+ # Cluster name: test
+ # ├── To log into the head VM: ssh test
+ # ├── To submit a job: sky exec test yaml_file
+ # ├── To stop the cluster: sky stop test
+ # └── To teardown the cluster: sky down test
+ 'echo "$s" && echo "==Validating launching==" && '
+ 'echo "$s" | grep -A 1 "Launching on" | grep "is up." && '
+ 'echo "$s" && echo "==Validating setup output==" && '
+ 'echo "$s" | grep -A 1 "Setup detached" | grep "Job submitted" && '
+ 'echo "==Validating running output hints==" && echo "$s" | '
+ 'grep -A 1 "Job submitted, ID:" | '
+ 'grep "Waiting for task resources on " && '
+ 'echo "==Validating task setup/run output starting==" && echo "$s" | '
+ 'grep -A 1 "Job started. Streaming logs..." | grep "(setup" | '
+ 'grep "running setup" && '
+ 'echo "$s" | grep -A 1 "(setup" | grep "(min, pid=" && '
+ 'echo "==Validating task output ending==" && '
+ 'echo "$s" | grep -A 1 "task run finish" | '
+ 'grep "Job finished (status: SUCCEEDED)" && '
+ 'echo "==Validating task output ending 2==" && '
+ 'echo "$s" | grep -A 5 "Job finished (status: SUCCEEDED)" | '
+ 'grep "Job ID:" && '
+ 'echo "$s" | grep -A 1 "Useful Commands" | grep "Job ID:"')
diff --git a/tests/smoke_tests/test_api_server.py b/tests/smoke_tests/test_api_server.py
new file mode 100644
index 00000000000..c07379b6edd
--- /dev/null
+++ b/tests/smoke_tests/test_api_server.py
@@ -0,0 +1,96 @@
+from typing import List
+
+from smoke_tests import smoke_tests_utils
+
+from sky.skylet import constants
+
+
+# ---------- Test multi-tenant ----------
+def test_multi_tenant(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ user_1 = 'abcdef12'
+ user_1_name = 'user1'
+ user_2 = 'abcdef13'
+ user_2_name = 'user2'
+
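+    # Helper: prefix each command with SkyPilot's user env vars so it runs as
+    # the given (synthetic) user, letting the test exercise multi-tenant
+    # isolation between user_1 and user_2.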
+ def set_user(user_id: str, user_name: str,
+ commands: List[str]) -> List[str]:
+ return [
+ f'export {constants.USER_ID_ENV_VAR}="{user_id}"; '
+ f'export {constants.USER_ENV_VAR}="{user_name}"; ' + cmd
+ for cmd in commands
+ ]
+
+ stop_test_cmds = [
+ 'echo "==== Test multi-tenant cluster stop ===="',
+ *set_user(
+ user_2,
+ user_2_name,
+ [
+ f'sky stop -y -a',
+ # -a should only stop clusters from the current user.
+ f's=$(sky status -u {name}-1) && echo "$s" && echo "$s" | grep {user_1_name} | grep UP',
+ f's=$(sky status -u {name}-2) && echo "$s" && echo "$s" | grep {user_2_name} | grep STOPPED',
+ # Explicit cluster name should stop the cluster.
+ f'sky stop -y {name}-1',
+ # Stopping cluster should not change the ownership of the cluster.
+ f's=$(sky status) && echo "$s" && echo "$s" | grep {name}-1 && exit 1 || true',
+ f'sky status {name}-1 | grep STOPPED',
+ # Both clusters should be stopped.
+ f'sky status -u | grep {name}-1 | grep STOPPED',
+ f'sky status -u | grep {name}-2 | grep STOPPED',
+ ]),
+ ]
+ if generic_cloud == 'kubernetes':
+ # Skip the stop test for Kubernetes, as stopping is not supported.
+ stop_test_cmds = []
+
+ test = Test(
+ 'test_multi_tenant',
+ [
+ 'echo "==== Test multi-tenant job on single cluster ===="',
+ *set_user(user_1, user_1_name, [
+ f'sky launch -y -c {name}-1 --cloud {generic_cloud} --cpus 2+ -n job-1 tests/test_yamls/minimal.yaml',
+ f's=$(sky queue {name}-1) && echo "$s" && echo "$s" | grep job-1 | grep SUCCEEDED | awk \'{{print $1}}\' | grep 1',
+ f's=$(sky queue -u {name}-1) && echo "$s" && echo "$s" | grep {user_1_name} | grep job-1 | grep SUCCEEDED',
+ ]),
+ *set_user(user_2, user_2_name, [
+ f'sky exec {name}-1 -n job-2 \'echo "hello" && exit 1\'',
+ f's=$(sky queue {name}-1) && echo "$s" && echo "$s" | grep job-2 | grep FAILED | awk \'{{print $1}}\' | grep 2',
+ f's=$(sky queue {name}-1) && echo "$s" && echo "$s" | grep job-1 && exit 1 || true',
+ f's=$(sky queue {name}-1 -u) && echo "$s" && echo "$s" | grep {user_2_name} | grep job-2 | grep FAILED',
+ f's=$(sky queue {name}-1 -u) && echo "$s" && echo "$s" | grep {user_1_name} | grep job-1 | grep SUCCEEDED',
+ ]),
+ 'echo "==== Test clusters from different users ===="',
+ *set_user(
+ user_2,
+ user_2_name,
+ [
+ f'sky launch -y -c {name}-2 --cloud {generic_cloud} --cpus 2+ -n job-3 tests/test_yamls/minimal.yaml',
+ f's=$(sky status {name}-2) && echo "$s" && echo "$s" | grep UP',
+ # sky status should not show other clusters from other users.
+ f's=$(sky status) && echo "$s" && echo "$s" | grep {name}-1 && exit 1 || true',
+ # Explicit cluster name should show the cluster.
+ f's=$(sky status {name}-1) && echo "$s" && echo "$s" | grep UP',
+ f's=$(sky status -u) && echo "$s" && echo "$s" | grep {user_2_name} | grep {name}-2 | grep UP',
+ f's=$(sky status -u) && echo "$s" && echo "$s" | grep {user_1_name} | grep {name}-1 | grep UP',
+ ]),
+ *stop_test_cmds,
+ 'echo "==== Test multi-tenant cluster down ===="',
+ *set_user(
+ user_2,
+ user_2_name,
+ [
+ f'sky down -y -a',
+ # STOPPED or UP based on whether we run the stop_test_cmds.
+ f'sky status -u | grep {name}-1 | grep "STOPPED\|UP"',
+ # Current user's clusters should be down'ed.
+ f'sky status -u | grep {name}-2 && exit 1 || true',
+ # Explicit cluster name should delete the cluster.
+ f'sky down -y {name}-1',
+ f'sky status | grep {name}-1 && exit 1 || true',
+ ]),
+ ],
+ f'sky down -y {name}-1 {name}-2',
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_basic.py b/tests/smoke_tests/test_basic.py
new file mode 100644
index 00000000000..9b8cad2f77a
--- /dev/null
+++ b/tests/smoke_tests/test_basic.py
@@ -0,0 +1,607 @@
+# Smoke tests for SkyPilot for basic functionality
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_basic.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_basic.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_basic.py::test_minimal
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_basic.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_basic.py --generic-cloud aws
+
+import pathlib
+import subprocess
+import tempfile
+import textwrap
+import time
+
+import pytest
+from smoke_tests import smoke_tests_utils
+
+import sky
+from sky.skylet import events
+from sky.utils import common_utils
+
+
+# ---------- Dry run: 2 Tasks in a chain. ----------
+@pytest.mark.no_fluidstack #requires GCP and AWS set up
+def test_example_app():
+ test = smoke_tests_utils.Test(
+ 'example_app',
+ ['python examples/example_app.py'],
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- A minimal task ----------
+def test_minimal(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'minimal',
+ [
+ f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --cpus 2+ tests/test_yamls/minimal.yaml) && {smoke_tests_utils.VALIDATE_LAUNCH_OUTPUT}',
+ # Output validation done.
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} --status | grep "Job 1: SUCCEEDED"', # Equivalent.
+ # Test launch output again on existing cluster
+ f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} tests/test_yamls/minimal.yaml) && {smoke_tests_utils.VALIDATE_LAUNCH_OUTPUT}',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
+ # Check the logs downloading
+ # TODO(zhwu): Fix the logs downloading.
+ # f'log_path=$(sky logs {name} 1 --sync-down | grep "Job 1 logs:" | sed -E "s/^.*Job 1 logs: (.*)\\x1b\\[0m/\\1/g") && echo "$log_path" && test -f $log_path/run.log',
+ # Ensure the raylet process has the correct file descriptor limit.
+ f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ # Install jq for the next test.
+ f'sky exec {name} \'sudo apt-get update && sudo apt-get install -y jq\'',
+ # Check the cluster info
+ f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cluster_name | grep {name}\'',
+ f'sky logs {name} 5 --status', # Ensure the job succeeded.
+ f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cloud | grep -i {generic_cloud}\'',
+ f'sky logs {name} 6 --status', # Ensure the job succeeded.
+ # Test '-c' for exec
+ f'sky exec -c {name} echo',
+ f'sky logs {name} 7 --status',
+ f'sky exec echo -c {name}',
+ f'sky logs {name} 8 --status',
+ f'sky exec -c {name} echo hi test',
+ f'sky logs {name} 9 | grep "hi test"',
+ f'sky exec {name} && exit 1 || true',
+ f'sky exec -c {name} && exit 1 || true',
+ ],
+ f'sky down -y {name}',
+ smoke_tests_utils.get_timeout(generic_cloud),
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Test fast launch ----------
+def test_launch_fast(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+
+ test = smoke_tests_utils.Test(
+ 'test_launch_fast',
+ [
+ # First launch to create the cluster
+ f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast tests/test_yamls/minimal.yaml) && {smoke_tests_utils.VALIDATE_LAUNCH_OUTPUT}',
+ f'sky logs {name} 1 --status',
+
+ # Second launch to test fast launch - should not reprovision
+ f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast tests/test_yamls/minimal.yaml) && '
+ ' echo "$s" && '
+ # Validate that cluster was not re-launched.
+ '! echo "$s" | grep -A 1 "Launching on" | grep "is up." && '
+ # Validate that setup was not re-run.
+ '! echo "$s" | grep -A 1 "Running setup on" | grep "running setup" && '
+ # Validate that the task ran and finished.
+ 'echo "$s" | grep -A 1 "task run finish" | grep "Job finished (status: SUCCEEDED)"',
+ f'sky logs {name} 2 --status',
+ f'sky status -r {name} | grep UP',
+ ],
+ f'sky down -y {name}',
+ timeout=smoke_tests_utils.get_timeout(generic_cloud),
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# See cloud exclusion explanations in test_autostop
+@pytest.mark.no_fluidstack
+@pytest.mark.no_lambda_cloud
+@pytest.mark.no_ibm
+@pytest.mark.no_kubernetes
+def test_launch_fast_with_autostop(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure
+ # the VM is stopped.
+ autostop_timeout = 600 if generic_cloud == 'azure' else 250
+ test = smoke_tests_utils.Test(
+ 'test_launch_fast_with_autostop',
+ [
+ # First launch to create the cluster with a short autostop
+ f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast -i 1 tests/test_yamls/minimal.yaml) && {smoke_tests_utils.VALIDATE_LAUNCH_OUTPUT}',
+ f'sky logs {name} 1 --status',
+ f'sky status -r {name} | grep UP',
+
+ # Ensure cluster is stopped
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=autostop_timeout),
+            # Even though the cluster is stopped, the cloud platform may take
+            # a while to delete the VM.
+ f'sleep {smoke_tests_utils.BUMP_UP_SECONDS}',
+ # Launch again. Do full output validation - we expect the cluster to re-launch
+ f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast -i 1 tests/test_yamls/minimal.yaml) && {smoke_tests_utils.VALIDATE_LAUNCH_OUTPUT}',
+ f'sky logs {name} 2 --status',
+ f'sky status -r {name} | grep UP',
+ ],
+ f'sky down -y {name}',
+ timeout=smoke_tests_utils.get_timeout(generic_cloud) + autostop_timeout,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ------------ Test stale job ------------
+@pytest.mark.no_fluidstack # FluidStack does not support stopping instances in SkyPilot implementation
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support stopping instances
+@pytest.mark.no_kubernetes # Kubernetes does not support stopping instances
+def test_stale_job(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'stale_job',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} "echo hi"',
+ f'sky exec {name} -d "echo start; sleep 10000"',
+ f'sky stop {name} -y',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=100),
+ f'sky start {name} -y',
+ f'sky logs {name} 1 --status',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_aws_stale_job_manual_restart():
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, sky.AWS.max_cluster_name_length())
+ region = 'us-east-2'
+ test = smoke_tests_utils.Test(
+ 'aws_stale_job_manual_restart',
+ [
+ f'sky launch -y -c {name} --cloud aws --region {region} "echo hi"',
+ f'sky exec {name} -d "echo start; sleep 10000"',
+ # Stop the cluster manually.
+ f'id=`aws ec2 describe-instances --region {region} --filters '
+ f'Name=tag:ray-cluster-name,Values={name_on_cloud} '
+ f'--query Reservations[].Instances[].InstanceId '
+ '--output text`; '
+ f'aws ec2 stop-instances --region {region} '
+ '--instance-ids $id',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=40),
+ f'sky launch -c {name} -y "echo hi"',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 3 --status',
+ # Ensure the skylet updated the stale job status.
+ smoke_tests_utils.
+ get_cmd_wait_until_job_status_contains_without_matching_job(
+ cluster_name=name,
+ job_status=[sky.JobStatus.FAILED_DRIVER],
+ timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS),
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_stale_job_manual_restart():
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, sky.GCP.max_cluster_name_length())
+ zone = 'us-west2-a'
+ query_cmd = (f'gcloud compute instances list --filter='
+ f'"(labels.ray-cluster-name={name_on_cloud})" '
+ f'--zones={zone} --format="value(name)"')
+ stop_cmd = (f'gcloud compute instances stop --zone={zone}'
+ f' --quiet $({query_cmd})')
+ test = smoke_tests_utils.Test(
+ 'gcp_stale_job_manual_restart',
+ [
+ f'sky launch -y -c {name} --cloud gcp --zone {zone} "echo hi"',
+ f'sky exec {name} -d "echo start; sleep 10000"',
+ # Stop the cluster manually.
+ stop_cmd,
+ 'sleep 40',
+ f'sky launch -c {name} -y "echo hi"',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 3 --status',
+ # Ensure the skylet updated the stale job status.
+ smoke_tests_utils.
+ get_cmd_wait_until_job_status_contains_without_matching_job(
+ cluster_name=name,
+ job_status=[sky.JobStatus.FAILED_DRIVER],
+ timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS)
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Check Sky's environment variables; workdir. ----------
+@pytest.mark.no_fluidstack # Requires Amazon S3
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+def test_env_check(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ total_timeout_minutes = 25 if generic_cloud == 'azure' else 15
+ test = smoke_tests_utils.Test(
+ 'env_check',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --detach-setup examples/env_check.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ # Test --detach-setup with only setup.
+ f'sky launch -y -c {name} --detach-setup tests/test_yamls/test_only_setup.yaml',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 2 | grep "hello world"',
+ ],
+ f'sky down -y {name}',
+ timeout=total_timeout_minutes * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- CLI logs ----------
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet. Run test_scp_logs instead.
+def test_cli_logs(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ num_nodes = 2
+ if generic_cloud == 'kubernetes':
+ # Kubernetes does not support multi-node
+ num_nodes = 1
+ timestamp = time.time()
+ test = smoke_tests_utils.Test('cli_logs', [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --num-nodes {num_nodes} "echo {timestamp} 1"',
+ f'sky exec {name} "echo {timestamp} 2"',
+ f'sky exec {name} "echo {timestamp} 3"',
+ f'sky exec {name} "echo {timestamp} 4"',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 4 --sync-down',
+ f'sky logs {name} * --sync-down',
+ f'sky logs {name} 1 | grep "{timestamp} 1"',
+ f'sky logs {name} | grep "{timestamp} 4"',
+ ], f'sky down -y {name}')
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.scp
+def test_scp_logs():
+ name = smoke_tests_utils.get_cluster_name()
+ timestamp = time.time()
+ test = smoke_tests_utils.Test(
+ 'SCP_cli_logs',
+ [
+ f'sky launch -y -c {name} {smoke_tests_utils.SCP_TYPE} "echo {timestamp} 1"',
+ f'sky exec {name} "echo {timestamp} 2"',
+ f'sky exec {name} "echo {timestamp} 3"',
+ f'sky exec {name} "echo {timestamp} 4"',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 4 --sync-down',
+ f'sky logs {name} * --sync-down',
+ f'sky logs {name} 1 | grep "{timestamp} 1"',
+ f'sky logs {name} | grep "{timestamp} 4"',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ------- Testing the core API --------
+# Most of the core APIs have been tested in the CLI tests.
+# These tests are for testing the return value of the APIs not fully used in CLI.
+
+
+@pytest.mark.gcp
+def test_core_api_sky_launch_exec():
+ name = smoke_tests_utils.get_cluster_name()
+ task = sky.Task(run="whoami")
+ task.set_resources(sky.Resources(cloud=sky.GCP()))
+ job_id, handle = sky.get(sky.launch(task, cluster_name=name))
+ assert job_id == 1
+ assert handle is not None
+ assert handle.cluster_name == name
+ assert handle.launched_resources.cloud.is_same_cloud(sky.GCP())
+ job_id_exec, handle_exec = sky.get(sky.exec(task, cluster_name=name))
+ assert job_id_exec == 2
+ assert handle_exec is not None
+ assert handle_exec.cluster_name == name
+ assert handle_exec.launched_resources.cloud.is_same_cloud(sky.GCP())
+ # For dummy task (i.e. task.run is None), the job won't be submitted.
+ dummy_task = sky.Task()
+ job_id_dummy, _ = sky.get(sky.exec(dummy_task, cluster_name=name))
+ assert job_id_dummy is None
+ sky.get(sky.down(name))
+
+
+# The sky launch CLI has some additional checks to make sure the cluster is up or
+# restarted. However, the core API doesn't have these; make sure it still works.
+def test_core_api_sky_launch_fast(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ cloud = sky.utils.registry.CLOUD_REGISTRY.from_str(generic_cloud)
+ try:
+ task = sky.Task(run="whoami").set_resources(sky.Resources(cloud=cloud))
+ sky.launch(task,
+ cluster_name=name,
+ idle_minutes_to_autostop=1,
+ fast=True)
+ # Sleep to let the cluster autostop
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=120)
+ # Run it again - should work with fast=True
+ sky.launch(task,
+ cluster_name=name,
+ idle_minutes_to_autostop=1,
+ fast=True)
+ finally:
+ sky.down(name)
+
+
+# ---------- Testing YAML Specs ----------
+# Our sky storage requires credentials to check the bucket existence when
+# loading a task from the yaml file, so we cannot make it a unit test.
+class TestYamlSpecs:
+ # TODO(zhwu): Add test for `to_yaml_config` for the Storage object.
+    # We should not use `examples/storage_demo.yaml` here, since it requires
+    # users to ensure that bucket names do not already exist and/or are unique.
+ _TEST_YAML_PATHS = [
+ 'examples/minimal.yaml', 'examples/managed_job.yaml',
+ 'examples/using_file_mounts.yaml', 'examples/resnet_app.yaml',
+ 'examples/multi_hostname.yaml'
+ ]
+
+ def _is_dict_subset(self, d1, d2):
+ """Check if d1 is the subset of d2."""
+ for k, v in d1.items():
+ if k not in d2:
+                if isinstance(v, (list, dict)):
+ assert len(v) == 0, (k, v)
+ else:
+ assert False, (k, v)
+ elif isinstance(v, dict):
+ assert isinstance(d2[k], dict), (k, v, d2)
+ self._is_dict_subset(v, d2[k])
+ elif isinstance(v, str):
+ if k == 'accelerators':
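+                    # Accelerator strings (e.g., 'V100:1') are canonicalized into
+                    # a dict by Resources, so compare against the parsed form.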
+ resources = sky.Resources()
+ resources._set_accelerators(v, None)
+ assert resources.accelerators == d2[k], (k, v, d2)
+ else:
+ assert v.lower() == d2[k].lower(), (k, v, d2[k])
+ else:
+ assert v == d2[k], (k, v, d2[k])
+
+ def _check_equivalent(self, yaml_path):
+ """Check if the yaml is equivalent after load and dump again."""
+ origin_task_config = common_utils.read_yaml(yaml_path)
+
+ task = sky.Task.from_yaml(yaml_path)
+ new_task_config = task.to_yaml_config()
+        # The original config must be a subset of the dumped config (d1 <= d2).
+ print(origin_task_config, new_task_config)
+ self._is_dict_subset(origin_task_config, new_task_config)
+
+ def test_load_dump_yaml_config_equivalent(self):
+ """Test if the yaml config is equivalent after load and dump again."""
+ pathlib.Path('~/datasets').expanduser().mkdir(exist_ok=True)
+ pathlib.Path('~/tmpfile').expanduser().touch()
+ pathlib.Path('~/.ssh').expanduser().mkdir(exist_ok=True)
+ pathlib.Path('~/.ssh/id_rsa.pub').expanduser().touch()
+ pathlib.Path('~/tmp-workdir').expanduser().mkdir(exist_ok=True)
+ pathlib.Path('~/Downloads/tpu').expanduser().mkdir(parents=True,
+ exist_ok=True)
+ for yaml_path in self._TEST_YAML_PATHS:
+ self._check_equivalent(yaml_path)
+
+
+# ---------- Testing Multiple Accelerators ----------
+@pytest.mark.no_fluidstack # Fluidstack does not support K80 gpus for now
+@pytest.mark.no_paperspace # Paperspace does not support K80 gpus
+def test_multiple_accelerators_ordered():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'multiple-accelerators-ordered',
+ [
+ f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_ordered.yaml | grep "Using user-specified accelerators list"',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs
+@pytest.mark.no_paperspace # Paperspace does not support T4 GPUs
+def test_multiple_accelerators_ordered_with_default():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+        'multiple-accelerators-ordered-with-default',
+ [
+ f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_ordered_with_default.yaml | grep "Using user-specified accelerators list"',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status {name} | grep Spot',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs
+@pytest.mark.no_paperspace # Paperspace does not support T4 GPUs
+def test_multiple_accelerators_unordered():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'multiple-accelerators-unordered',
+ [
+ f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_unordered.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs
+@pytest.mark.no_paperspace # Paperspace does not support T4 GPUs
+def test_multiple_accelerators_unordered_with_default():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'multiple-accelerators-unordered-with-default',
+ [
+ f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_unordered_with_default.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status {name} | grep Spot',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Requires other clouds to be enabled
+def test_multiple_resources():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'multiple-resources',
+ [
+ f'sky launch -y -c {name} tests/test_yamls/test_multiple_resources.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Sky Benchmark ----------
+@pytest.mark.no_fluidstack # Requires other clouds to be enabled
+@pytest.mark.no_paperspace # Requires other clouds to be enabled
+@pytest.mark.no_kubernetes
+@pytest.mark.aws # SkyBenchmark requires S3 access
+def test_sky_bench(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'sky-bench',
+ [
+ f'sky bench launch -y -b {name} --cloud {generic_cloud} -i0 tests/test_yamls/minimal.yaml',
+ 'sleep 120',
+ f'sky bench show {name} | grep sky-bench-{name} | grep FINISHED',
+ ],
+ f'sky bench down {name} -y; sky bench delete {name} -y',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+def test_kubernetes_context_failover():
+ """Test if the kubernetes context failover works.
+
+ This test requires two kubernetes clusters:
+ - kind-skypilot: the local cluster with mock labels for 8 H100 GPUs.
+ - another accessible cluster: with enough CPUs
+ To start the first cluster, run:
+ sky local up
+ # Add mock label for accelerator
+ kubectl label node --overwrite skypilot-control-plane skypilot.co/accelerator=h100 --context kind-skypilot
+ # Get the token for the cluster in context kind-skypilot
+ TOKEN=$(kubectl config view --minify --context kind-skypilot -o jsonpath=\'{.users[0].user.token}\')
+ # Get the API URL for the cluster in context kind-skypilot
+ API_URL=$(kubectl config view --minify --context kind-skypilot -o jsonpath=\'{.clusters[0].cluster.server}\')
+ # Add mock capacity for GPU
+ curl --header "Content-Type: application/json-patch+json" --header "Authorization: Bearer $TOKEN" --request PATCH --data \'[{"op": "add", "path": "/status/capacity/nvidia.com~1gpu", "value": "8"}]\' "$API_URL/api/v1/nodes/skypilot-control-plane/status"
+ # Add a new namespace to test the handling of namespaces
+ kubectl create namespace test-namespace --context kind-skypilot
+ # Set the namespace to test-namespace
+ kubectl config set-context kind-skypilot --namespace=test-namespace --context kind-skypilot
+ """
+ # Get context that is not kind-skypilot
+ contexts = subprocess.check_output('kubectl config get-contexts -o name',
+ shell=True).decode('utf-8').split('\n')
+ context = [context for context in contexts if context != 'kind-skypilot'][0]
+ config = textwrap.dedent(f"""\
+ kubernetes:
+ allowed_contexts:
+ - kind-skypilot
+ - {context}
+ """)
+ with tempfile.NamedTemporaryFile(delete=True) as f:
+ f.write(config.encode('utf-8'))
+ f.flush()
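+        # This temp config is passed via SKYPILOT_CONFIG (see env below) so that
+        # only these two allowed_contexts are visible to SkyPilot during the test.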
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'kubernetes-context-failover',
+ [
+ # Check if kind-skypilot is provisioned with H100 annotations already
+ 'NODE_INFO=$(kubectl get nodes -o yaml --context kind-skypilot) && '
+ 'echo "$NODE_INFO" | grep nvidia.com/gpu | grep 8 && '
+ 'echo "$NODE_INFO" | grep skypilot.co/accelerator | grep h100 || '
+ '{ echo "kind-skypilot does not exist '
+ 'or does not have mock labels for GPUs. Check the instructions in '
+ 'tests/test_smoke.py::test_kubernetes_context_failover." && exit 1; }',
+ # Check namespace for kind-skypilot is test-namespace
+ 'kubectl get namespaces --context kind-skypilot | grep test-namespace || '
+ '{ echo "Should set the namespace to test-namespace for kind-skypilot. Check the instructions in '
+ 'tests/test_smoke.py::test_kubernetes_context_failover." && exit 1; }',
+ 'sky show-gpus --cloud kubernetes --region kind-skypilot | grep H100 | grep "1, 2, 3, 4, 5, 6, 7, 8"',
+ # Get contexts and set current context to the other cluster that is not kind-skypilot
+ f'kubectl config use-context {context}',
+                # H100 should not be in the current context
+ '! sky show-gpus --cloud kubernetes | grep H100',
+ f'sky launch -y -c {name}-1 --cpus 1 echo hi',
+ f'sky logs {name}-1 --status',
+                # It should be launched on the other context, not on kind-skypilot
+ f'sky status -v {name}-1 | grep "{context}"',
+ # Test failure for launching H100 on other cluster
+ f'sky launch -y -c {name}-2 --gpus H100 --cpus 1 --cloud kubernetes --region {context} echo hi && exit 1 || true',
+ # Test failover
+ f'sky launch -y -c {name}-3 --gpus H100 --cpus 1 --cloud kubernetes echo hi',
+ f'sky logs {name}-3 --status',
+ # Test pods
+ f'kubectl get pods --context kind-skypilot | grep "{name}-3"',
+ # It should be launched on kind-skypilot
+ f'sky status -v {name}-3 | grep "kind-skypilot"',
+ # Should be 7 free GPUs
+ f'sky show-gpus --cloud kubernetes --region kind-skypilot | grep H100 | grep " 7"',
+                # Remove kind-skypilot from allowed_contexts in the config file
+ f'sed -i "/kind-skypilot/d" {f.name}',
+ # Should still be able to exec and launch on existing cluster
+ f'sky exec {name}-3 "echo hi"',
+ f'sky logs {name}-3 --status',
+ f'sky status -r {name}-3 | grep UP',
+ f'sky launch -c {name}-3 --gpus h100 echo hi',
+ f'sky logs {name}-3 --status',
+ f'sky status -r {name}-3 | grep UP',
+ ],
+ f'sky down -y {name}-1 {name}-3',
+ env={'SKYPILOT_CONFIG': f.name},
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_cluster_job.py b/tests/smoke_tests/test_cluster_job.py
new file mode 100644
index 00000000000..24eb95687ca
--- /dev/null
+++ b/tests/smoke_tests/test_cluster_job.py
@@ -0,0 +1,1694 @@
+# Smoke tests for SkyPilot for sky launched clusters and cluster jobs
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_cluster_job.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_cluster_job.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_cluster_job.py::test_job_queue
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_cluster_job.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_cluster_job.py --generic-cloud aws
+
+import pathlib
+import tempfile
+import textwrap
+from typing import Dict
+
+import jinja2
+import pytest
+from smoke_tests import smoke_tests_utils
+
+import sky
+from sky import AWS
+from sky import Azure
+from sky import GCP
+from sky.skylet import constants
+from sky.utils import common_utils
+from sky.utils import resources_utils
+
+
+# ---------- Job Queue. ----------
+@pytest.mark.no_fluidstack # FluidStack DC has low availability of T4 GPUs
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not have T4 gpus
+@pytest.mark.no_ibm # IBM Cloud does not have T4 gpus. Run test_ibm_job_queue instead
+@pytest.mark.no_scp # SCP does not have T4 gpus. Run test_scp_job_queue instead
+@pytest.mark.no_paperspace # Paperspace does not have T4 gpus.
+@pytest.mark.no_oci # OCI does not have T4 gpus
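+# Override the default T4 accelerator for the clouds listed below (e.g., DO uses
+# H100); all other clouds fall back to T4 in the test body.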
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
+def test_job_queue(generic_cloud: str, accelerator: Dict[str, str]):
+ accelerator = accelerator.get(generic_cloud, 'T4')
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'job_queue',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/job_queue/cluster.yaml',
+ f'sky exec {name} -n {name}-1 -d --gpus {accelerator}:0.5 examples/job_queue/job.yaml',
+ f'sky exec {name} -n {name}-2 -d --gpus {accelerator}:0.5 examples/job_queue/job.yaml',
+ f'sky exec {name} -n {name}-3 -d --gpus {accelerator}:0.5 examples/job_queue/job.yaml',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep RUNNING',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep RUNNING',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep PENDING',
+ f'sky cancel -y {name} 2',
+ 'sleep 5',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep RUNNING',
+ f'sky cancel -y {name} 3',
+ f'sky exec {name} --gpus {accelerator}:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky exec {name} --gpus {accelerator}:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky logs {name} 4 --status',
+ f'sky logs {name} 5 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Job Queue with Docker. ----------
+@pytest.mark.no_fluidstack # FluidStack does not support docker for now
+@pytest.mark.no_lambda_cloud # Doesn't support Lambda Cloud for now
+@pytest.mark.no_ibm # Doesn't support IBM Cloud for now
+@pytest.mark.no_paperspace # Paperspace doesn't have T4 GPUs
+@pytest.mark.no_scp # Doesn't support SCP for now
+@pytest.mark.no_oci # Doesn't support OCI for now
+@pytest.mark.no_kubernetes # Doesn't support Kubernetes for now
+@pytest.mark.parametrize(
+ 'image_id',
+ [
+ 'docker:nvidia/cuda:11.8.0-devel-ubuntu18.04',
+ 'docker:ubuntu:18.04',
+ # Test latest image with python 3.11 installed by default.
+ 'docker:continuumio/miniconda3:24.1.2-0',
+ # Test python>=3.12 where SkyPilot should automatically create a separate
+ # conda env for runtime with python 3.10.
+ 'docker:continuumio/miniconda3:latest',
+        # The Axolotl image is a good example of a custom image that sets its
+        # conda path in PATH via the Dockerfile and uses python>=3.12. It tests:
+        # 1. env vars set in the Dockerfile are handled correctly;
+        # 2. python>=3.12 works with the SkyPilot runtime.
+ 'docker:winglian/axolotl:main-latest'
+ ])
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
+def test_job_queue_with_docker(generic_cloud: str, image_id: str,
+                               accelerator: Dict[str, str]):
+ accelerator = accelerator.get(generic_cloud, 'T4')
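+    # Append a short suffix derived from the image id to the cluster name to
+    # help distinguish parametrized runs.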
+ name = smoke_tests_utils.get_cluster_name() + image_id[len('docker:'):][:4]
+ total_timeout_minutes = 40 if generic_cloud == 'azure' else 15
+ time_to_sleep = 300 if generic_cloud == 'azure' else 180
+ test = smoke_tests_utils.Test(
+ 'job_queue_with_docker',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} --image-id {image_id} examples/job_queue/cluster_docker.yaml',
+ f'sky exec {name} -n {name}-1 -d --gpus {accelerator}:0.5 --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml',
+ f'sky exec {name} -n {name}-2 -d --gpus {accelerator}:0.5 --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml',
+ f'sky exec {name} -n {name}-3 -d --gpus {accelerator}:0.5 --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep RUNNING',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep RUNNING',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep PENDING',
+ f'sky cancel -y {name} 2',
+ 'sleep 5',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep RUNNING',
+ f'sky cancel -y {name} 3',
+ # Make sure the GPU is still visible to the container.
+ f'sky exec {name} --image-id {image_id} nvidia-smi | grep -i "{accelerator}"',
+ f'sky logs {name} 4 --status',
+ f'sky stop -y {name}',
+            # Make sure the job statuses are preserved after stopping and
+            # starting the cluster. This also tests that the docker container
+            # is preserved across a stop and start.
+ f'sky start -y {name}',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep FAILED',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep CANCELLED',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep CANCELLED',
+ f'sky exec {name} --gpus {accelerator}:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky exec {name} --gpus {accelerator}:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+            # Make sure the GPU is still visible after a stop & start cycle.
+ f'sky exec {name} --image-id {image_id} nvidia-smi | grep "Tesla T4"',
+ f'sky logs {name} 7 --status'
+ ],
+ f'sky down -y {name}',
+ timeout=total_timeout_minutes * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.lambda_cloud
+def test_lambda_job_queue():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'lambda_job_queue',
+ [
+ f'sky launch -y -c {name} {smoke_tests_utils.LAMBDA_TYPE} examples/job_queue/cluster.yaml',
+ f'sky exec {name} -n {name}-1 --gpus A10:0.5 -d examples/job_queue/job.yaml',
+ f'sky exec {name} -n {name}-2 --gpus A10:0.5 -d examples/job_queue/job.yaml',
+ f'sky exec {name} -n {name}-3 --gpus A10:0.5 -d examples/job_queue/job.yaml',
+ f'sky queue {name} | grep {name}-1 | grep RUNNING',
+ f'sky queue {name} | grep {name}-2 | grep RUNNING',
+ f'sky queue {name} | grep {name}-3 | grep PENDING',
+ f'sky cancel -y {name} 2',
+ 'sleep 5',
+ f'sky queue {name} | grep {name}-3 | grep RUNNING',
+ f'sky cancel -y {name} 3',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.ibm
+def test_ibm_job_queue():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'ibm_job_queue',
+ [
+ f'sky launch -y -c {name} --cloud ibm --gpus v100',
+ f'sky exec {name} -n {name}-1 --cloud ibm -d examples/job_queue/job_ibm.yaml',
+ f'sky exec {name} -n {name}-2 --cloud ibm -d examples/job_queue/job_ibm.yaml',
+ f'sky exec {name} -n {name}-3 --cloud ibm -d examples/job_queue/job_ibm.yaml',
+ f'sky queue {name} | grep {name}-1 | grep RUNNING',
+ f'sky queue {name} | grep {name}-2 | grep RUNNING',
+ f'sky queue {name} | grep {name}-3 | grep PENDING',
+ f'sky cancel -y {name} 2',
+ 'sleep 5',
+ f'sky queue {name} | grep {name}-3 | grep RUNNING',
+ f'sky cancel -y {name} 3',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.scp
+def test_scp_job_queue():
+ name = smoke_tests_utils.get_cluster_name()
+ num_of_gpu_launch = 1
+ num_of_gpu_exec = 0.5
+ test = smoke_tests_utils.Test(
+ 'SCP_job_queue',
+ [
+ f'sky launch -y -c {name} {smoke_tests_utils.SCP_TYPE} {smoke_tests_utils.SCP_GPU_V100}:{num_of_gpu_launch} examples/job_queue/cluster.yaml',
+ f'sky exec {name} -n {name}-1 {smoke_tests_utils.SCP_GPU_V100}:{num_of_gpu_exec} -d examples/job_queue/job.yaml',
+ f'sky exec {name} -n {name}-2 {smoke_tests_utils.SCP_GPU_V100}:{num_of_gpu_exec} -d examples/job_queue/job.yaml',
+ f'sky exec {name} -n {name}-3 {smoke_tests_utils.SCP_GPU_V100}:{num_of_gpu_exec} -d examples/job_queue/job.yaml',
+ f'sky queue {name} | grep {name}-1 | grep RUNNING',
+ f'sky queue {name} | grep {name}-2 | grep RUNNING',
+ f'sky queue {name} | grep {name}-3 | grep PENDING',
+ f'sky cancel -y {name} 2',
+ 'sleep 5',
+ f'sky queue {name} | grep {name}-3 | grep RUNNING',
+ f'sky cancel -y {name} 3',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # FluidStack DC has low availability of T4 GPUs
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not have T4 gpus
+@pytest.mark.no_ibm # IBM Cloud does not have T4 gpus. Run test_ibm_job_queue_multinode instead
+@pytest.mark.no_paperspace # Paperspace does not have T4 gpus.
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+@pytest.mark.no_oci # OCI Cloud does not have T4 gpus.
+@pytest.mark.no_kubernetes # Kubernetes does not support num_nodes > 1 yet
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
+def test_job_queue_multinode(generic_cloud: str, accelerator: Dict[str, str]):
+ accelerator = accelerator.get(generic_cloud, 'T4')
+ name = smoke_tests_utils.get_cluster_name()
+ total_timeout_minutes = 30 if generic_cloud == 'azure' else 15
+ test = smoke_tests_utils.Test(
+ 'job_queue_multinode',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/job_queue/cluster_multinode.yaml',
+ f'sky exec {name} -n {name}-1 -d --gpus {accelerator}:0.5 examples/job_queue/job_multinode.yaml',
+ f'sky exec {name} -n {name}-2 -d --gpus {accelerator}:0.5 examples/job_queue/job_multinode.yaml',
+ f'sky launch -c {name} -n {name}-3 --detach-setup -d --gpus {accelerator}:0.5 examples/job_queue/job_multinode.yaml',
+ f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-1 | grep RUNNING)',
+ f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-2 | grep RUNNING)',
+ f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-3 | grep PENDING)',
+ 'sleep 90',
+ f'sky cancel -y {name} 1',
+ 'sleep 5',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep SETTING_UP',
+ f'sky cancel -y {name} 1 2 3',
+ f'sky launch -c {name} -n {name}-4 --detach-setup -d --gpus {accelerator} examples/job_queue/job_multinode.yaml',
+            # Test that the job status is correctly set to SETTING_UP while the
+            # setup is running, and that the job can be cancelled during setup.
+ 'sleep 5',
+ f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-4 | grep SETTING_UP)',
+ f'sky cancel -y {name} 4',
+ f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-4 | grep CANCELLED)',
+ f'sky exec {name} --gpus {accelerator}:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky exec {name} --gpus {accelerator}:0.2 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky exec {name} --gpus {accelerator}:1 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+ f'sky logs {name} 7 --status',
+ ],
+ f'sky down -y {name}',
+ timeout=total_timeout_minutes * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # No FluidStack VM has 8 CPUs
+@pytest.mark.no_lambda_cloud # No Lambda Cloud VM has 8 CPUs
+def test_large_job_queue(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'large_job_queue',
+ [
+ f'sky launch -y -c {name} --cpus 8 --cloud {generic_cloud}',
+ f'for i in `seq 1 75`; do sky exec {name} -n {name}-$i -d "echo $i; sleep 100000000"; done',
+ f'sky cancel -y {name} 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16',
+ 'sleep 90',
+
+            # Each job takes 0.5 CPU and the VM has 8 CPUs, so 8 / 0.5 = 16 jobs
+            # run at a time. The first 16 jobs are canceled and the next 16 are
+            # running, so 75 - 16 - 16 = 43 jobs should be PENDING.
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep PENDING | wc -l | grep 43',
+ # Make sure the jobs are scheduled in FIFO order
+ *[
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep CANCELLED'
+ for i in range(1, 17)
+ ],
+ *[
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep RUNNING'
+ for i in range(17, 33)
+ ],
+ *[
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep PENDING'
+ for i in range(33, 75)
+ ],
+ f'sky cancel -y {name} 33 35 37 39 17 18 19',
+ *[
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep CANCELLED'
+ for i in range(33, 40, 2)
+ ],
+ 'sleep 10',
+ *[
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep RUNNING'
+ for i in [34, 36, 38]
+ ],
+ ],
+ f'sky down -y {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # No FluidStack VM has 8 CPUs
+@pytest.mark.no_lambda_cloud # No Lambda Cloud VM has 8 CPUs
+def test_fast_large_job_queue(generic_cloud: str):
+    # Test that jobs can be scheduled quickly even when there are many jobs in the queue.
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'fast_large_job_queue',
+ [
+ f'sky launch -y -c {name} --cpus 8 --cloud {generic_cloud}',
+ f'for i in `seq 1 32`; do sky exec {name} -n {name}-$i -d "echo $i"; done',
+ 'sleep 60',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep SUCCEEDED | wc -l | grep 32',
+ ],
+ f'sky down -y {name}',
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.ibm
+def test_ibm_job_queue_multinode():
+ name = smoke_tests_utils.get_cluster_name()
+ task_file = 'examples/job_queue/job_multinode_ibm.yaml'
+ test = smoke_tests_utils.Test(
+ 'ibm_job_queue_multinode',
+ [
+ f'sky launch -y -c {name} --cloud ibm --gpus v100 --num-nodes 2',
+ f'sky exec {name} -n {name}-1 -d {task_file}',
+ f'sky exec {name} -n {name}-2 -d {task_file}',
+ f'sky launch -y -c {name} -n {name}-3 --detach-setup -d {task_file}',
+ f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-1 | grep RUNNING)',
+ f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-2 | grep RUNNING)',
+ f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-3 | grep SETTING_UP)',
+ 'sleep 90',
+ f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-3 | grep PENDING)',
+ f'sky cancel -y {name} 1',
+ 'sleep 5',
+ f'sky queue {name} | grep {name}-3 | grep RUNNING',
+ f'sky cancel -y {name} 1 2 3',
+ f'sky launch -c {name} -n {name}-4 --detach-setup -d {task_file}',
+            # Test that the job status is correctly set to SETTING_UP while the
+            # setup is running, and that the job can be cancelled during setup.
+ f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-4 | grep SETTING_UP)',
+ f'sky cancel -y {name} 4',
+ f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-4 | grep CANCELLED)',
+ f'sky exec {name} --gpus v100:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky exec {name} --gpus v100:0.2 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky exec {name} --gpus v100:1 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+ f'sky logs {name} 7 --status',
+ ],
+ f'sky down -y {name}',
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Docker with preinstalled package. ----------
+@pytest.mark.no_fluidstack # Doesn't support Fluidstack for now
+@pytest.mark.no_lambda_cloud # Doesn't support Lambda Cloud for now
+@pytest.mark.no_ibm # Doesn't support IBM Cloud for now
+@pytest.mark.no_scp # Doesn't support SCP for now
+@pytest.mark.no_oci # Doesn't support OCI for now
+@pytest.mark.no_kubernetes # Doesn't support Kubernetes for now
+# TODO(zhwu): we should fix this for kubernetes
+def test_docker_preinstalled_package(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'docker_with_preinstalled_package',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --image-id docker:nginx',
+ f'sky exec {name} "nginx -V"',
+ f'sky logs {name} 1 --status',
+ f'sky exec {name} whoami | grep root',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Submitting multiple tasks to the same cluster. ----------
+@pytest.mark.no_fluidstack # FluidStack DC has low availability of T4 GPUs
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not have T4 gpus
+@pytest.mark.no_paperspace # Paperspace does not have T4 gpus
+@pytest.mark.no_ibm # IBM Cloud does not have T4 gpus
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+@pytest.mark.no_oci # OCI Cloud does not have T4 gpus
+def test_multi_echo(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'multi_echo',
+ [
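+            # multi_echo.py submits 32 short echo jobs to the cluster.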
+ f'python examples/multi_echo.py {name} {generic_cloud}',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
+ 'sleep 10',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
+ 'sleep 30',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
+ 'sleep 30',
+            # Make sure that our job scheduler is fast enough to have at least
+            # 20 RUNNING jobs in parallel (the awk check below requires >= 20).
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "RUNNING" | wc -l | awk \'{{if ($1 < 20) exit 1}}\'',
+ 'sleep 30',
+ f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
+ # This is to make sure we can finish job 32 before the test timeout.
+ f'until sky logs {name} 32 --status; do echo "Waiting for job 32 to finish..."; sleep 1; done',
+ ] +
+ # Ensure jobs succeeded.
+ [
+ smoke_tests_utils.
+ get_cmd_wait_until_job_status_contains_matching_job_id(
+ cluster_name=name,
+ job_id=i + 1,
+ job_status=[sky.JobStatus.SUCCEEDED],
+ timeout=120) for i in range(32)
+ ] +
+ # Ensure monitor/autoscaler didn't crash on the 'assert not
+ # unfulfilled' error. If process not found, grep->ssh returns 1.
+ [f'ssh {name} \'ps aux | grep "[/]"monitor.py\''],
+ f'sky down -y {name}',
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Task: 1 node training. ----------
+@pytest.mark.no_fluidstack # Fluidstack does not have T4 gpus for now
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus
+@pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA
+@pytest.mark.no_scp # SCP does not have V100 (16GB) GPUs. Run test_scp_huggingface instead.
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
+def test_huggingface(generic_cloud: str, accelerator: Dict[str, str]):
+ accelerator = accelerator.get(generic_cloud, 'T4')
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'huggingface_glue_imdb_app',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/huggingface_glue_imdb_app.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} --gpus {accelerator} examples/huggingface_glue_imdb_app.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.lambda_cloud
+def test_lambda_huggingface(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'lambda_huggingface_glue_imdb_app',
+ [
+ f'sky launch -y -c {name} {smoke_tests_utils.LAMBDA_TYPE} examples/huggingface_glue_imdb_app.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} {smoke_tests_utils.LAMBDA_TYPE} examples/huggingface_glue_imdb_app.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.scp
+def test_scp_huggingface(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ num_of_gpu_launch = 1
+ test = smoke_tests_utils.Test(
+ 'SCP_huggingface_glue_imdb_app',
+ [
+ f'sky launch -y -c {name} {smoke_tests_utils.SCP_TYPE} {smoke_tests_utils.SCP_GPU_V100}:{num_of_gpu_launch} examples/huggingface_glue_imdb_app.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} {smoke_tests_utils.SCP_TYPE} {smoke_tests_utils.SCP_GPU_V100}:{num_of_gpu_launch} examples/huggingface_glue_imdb_app.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Inferentia. ----------
+@pytest.mark.aws
+def test_inferentia():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test_inferentia',
+ [
+ f'sky launch -y -c {name} -t inf2.xlarge -- echo hi',
+ f'sky exec {name} --gpus Inferentia:1 echo hi',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- TPU. ----------
+@pytest.mark.gcp
+@pytest.mark.tpu
+def test_tpu():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'tpu_app',
+ [
+ f'sky launch -y -c {name} examples/tpu/tpu_app.yaml',
+ f'sky logs {name} 1', # Ensure the job finished.
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky launch -y -c {name} examples/tpu/tpu_app.yaml | grep "TPU .* already exists"', # Ensure sky launch won't create another TPU.
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60, # can take >20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- TPU VM. ----------
+@pytest.mark.gcp
+@pytest.mark.tpu
+def test_tpu_vm():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'tpu_vm_app',
+ [
+ f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml',
+ f'sky logs {name} 1', # Ensure the job finished.
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky stop -y {name}',
+ f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', # Ensure the cluster is STOPPED.
+ # Use retry: guard against transient errors observed for
+ # just-stopped TPU VMs (#962).
+ f'sky start --retry-until-up -y {name}',
+ f'sky exec {name} examples/tpu/tpuvm_mnist.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ f'sky stop -y {name}',
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60, # can take 30 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- TPU VM Pod. ----------
+@pytest.mark.gcp
+@pytest.mark.tpu
+def test_tpu_vm_pod():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'tpu_pod',
+ [
+ f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --gpus tpu-v2-32 --use-spot --zone europe-west4-a',
+ f'sky logs {name} 1', # Ensure the job finished.
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60, # can take 30 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- TPU Pod Slice on GKE. ----------
+@pytest.mark.tpu
+@pytest.mark.kubernetes
+def test_tpu_pod_slice_gke():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'tpu_pod_slice_gke',
+ [
+ f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --cloud kubernetes --gpus tpu-v5-lite-podslice',
+ f'sky logs {name} 1', # Ensure the job finished.
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} "conda activate flax; python -c \'import jax; print(jax.devices()[0].platform);\' | grep tpu || exit 1;"', # Ensure TPU is reachable.
+ f'sky logs {name} 2 --status'
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60, # can take 30 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Simple apps. ----------
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+def test_multi_hostname(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ total_timeout_minutes = 25 if generic_cloud == 'azure' else 15
+ test = smoke_tests_utils.Test(
+ 'multi_hostname',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} examples/multi_hostname.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky logs {name} 1 | grep "My hostname:" | wc -l | grep 2', # Ensure there are 2 hosts.
+ f'sky exec {name} examples/multi_hostname.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ timeout=smoke_tests_utils.get_timeout(generic_cloud,
+ total_timeout_minutes * 60),
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+def test_multi_node_failure(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'multi_node_failure',
+ [
+ # TODO(zhwu): we use multi-thread to run the commands in setup
+ # commands in parallel, which makes it impossible to fail fast
+ # when one of the nodes fails. We should fix this in the future.
+ # The --detach-setup version can fail fast, as the setup is
+ # submitted to the remote machine, which does not use multi-thread.
+ # Refer to the comment in `subprocess_utils.run_in_parallel`.
+ # f'sky launch -y -c {name} --cloud {generic_cloud} tests/test_yamls/failed_worker_setup.yaml && exit 1', # Ensure the job setup failed.
+ f'sky launch -y -c {name} --cloud {generic_cloud} --detach-setup tests/test_yamls/failed_worker_setup.yaml',
+ f'sky logs {name} 1 --status | grep FAILED_SETUP', # Ensure the job setup failed.
+ f'sky exec {name} tests/test_yamls/failed_worker_run.yaml',
+ f'sky logs {name} 2 --status | grep FAILED', # Ensure the job failed.
+            f'sky logs {name} 2 | grep "My hostname:" | wc -l | grep 2', # Ensure 2 of the hosts printed their hostname.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Web apps with custom ports on GCP. ----------
+@pytest.mark.gcp
+def test_gcp_http_server_with_custom_ports():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_http_server_with_custom_ports',
+ [
+ f'sky launch -y -d -c {name} --cloud gcp examples/http_server_with_custom_ports/task.yaml',
+ f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
+ # Retry a few times to avoid flakiness in ports being open.
+ f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Web apps with custom ports on AWS. ----------
+@pytest.mark.aws
+def test_aws_http_server_with_custom_ports():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_http_server_with_custom_ports',
+ [
+ f'sky launch -y -d -c {name} --cloud aws examples/http_server_with_custom_ports/task.yaml',
+ f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
+ # Retry a few times to avoid flakiness in ports being open.
+ f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi'
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Web apps with custom ports on Azure. ----------
+@pytest.mark.azure
+def test_azure_http_server_with_custom_ports():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'azure_http_server_with_custom_ports',
+ [
+ f'sky launch -y -d -c {name} --cloud azure examples/http_server_with_custom_ports/task.yaml',
+ f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
+ # Retry a few times to avoid flakiness in ports being open.
+ f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi'
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Web apps with custom ports on Kubernetes. ----------
+@pytest.mark.kubernetes
+def test_kubernetes_http_server_with_custom_ports():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'kubernetes_http_server_with_custom_ports',
+ [
+ f'sky launch -y -d -c {name} --cloud kubernetes examples/http_server_with_custom_ports/task.yaml',
+ f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
+ # Retry a few times to avoid flakiness in ports being open.
+ f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 100); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 5; done; if [ "$success" = false ]; then exit 1; fi'
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Web apps with custom ports on Paperspace. ----------
+@pytest.mark.paperspace
+def test_paperspace_http_server_with_custom_ports():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'paperspace_http_server_with_custom_ports',
+ [
+ f'sky launch -y -d -c {name} --cloud paperspace examples/http_server_with_custom_ports/task.yaml',
+ f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
+ # Retry a few times to avoid flakiness in ports being open.
+ f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Web apps with custom ports on RunPod. ----------
+@pytest.mark.runpod
+def test_runpod_http_server_with_custom_ports():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'runpod_http_server_with_custom_ports',
+ [
+ f'sky launch -y -d -c {name} --cloud runpod examples/http_server_with_custom_ports/task.yaml',
+ f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
+ # Retry a few times to avoid flakiness in ports being open.
+ f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Labels from task on AWS (instance_tags) ----------
+@pytest.mark.aws
+def test_task_labels_aws():
+ name = smoke_tests_utils.get_cluster_name()
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_labels.yaml.j2').read_text()
+ template = jinja2.Template(template_str)
+ content = template.render(cloud='aws', region='us-east-1')
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test = smoke_tests_utils.Test(
+ 'task_labels_aws',
+ [
+ f'sky launch -y -c {name} {file_path}',
+ # Verify with aws cli that the tags are set.
+ 'aws ec2 describe-instances '
+ '--query "Reservations[*].Instances[*].InstanceId" '
+ '--filters "Name=instance-state-name,Values=running" '
+ f'--filters "Name=tag:skypilot-cluster-name,Values={name}*" '
+ '--filters "Name=tag:inlinelabel1,Values=inlinevalue1" '
+ '--filters "Name=tag:inlinelabel2,Values=inlinevalue2" '
+ '--region us-east-1 --output text',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Labels from task on GCP (labels) ----------
+@pytest.mark.gcp
+def test_task_labels_gcp():
+ name = smoke_tests_utils.get_cluster_name()
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_labels.yaml.j2').read_text()
+ template = jinja2.Template(template_str)
+ content = template.render(cloud='gcp')
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test = smoke_tests_utils.Test(
+ 'task_labels_gcp',
+ [
+ f'sky launch -y -c {name} {file_path}',
+ # Verify with gcloud cli that the tags are set
+ f'gcloud compute instances list --filter="name~\'^{name}\' AND '
+ 'labels.inlinelabel1=\'inlinevalue1\' AND '
+ 'labels.inlinelabel2=\'inlinevalue2\'" '
+ '--format="value(name)" | grep .',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Labels from task on Kubernetes (labels) ----------
+@pytest.mark.kubernetes
+def test_task_labels_kubernetes():
+ name = smoke_tests_utils.get_cluster_name()
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_labels.yaml.j2').read_text()
+ template = jinja2.Template(template_str)
+ content = template.render(cloud='kubernetes')
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test = smoke_tests_utils.Test(
+ 'task_labels_kubernetes',
+ [
+ f'sky launch -y -c {name} {file_path}',
+ # Verify with kubectl that the labels are set.
+ 'kubectl get pods '
+ '--selector inlinelabel1=inlinevalue1 '
+ '--selector inlinelabel2=inlinevalue2 '
+ '-o jsonpath=\'{.items[*].metadata.name}\' | '
+ f'grep \'^{name}\''
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Pod Annotations on Kubernetes ----------
+@pytest.mark.kubernetes
+def test_add_pod_annotations_for_autodown_with_launch():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'add_pod_annotations_for_autodown_with_launch',
+ [
+            # Launch a Kubernetes cluster with two nodes: a head node and a worker node.
+ # Autodown is set.
+ f'sky launch -y -c {name} -i 10 --down --num-nodes 2 --cpus=1 --cloud kubernetes',
+ # Get names of the pods containing cluster name.
+ f'pod_1=$(kubectl get pods -o name | grep {name} | sed -n 1p)',
+ f'pod_2=$(kubectl get pods -o name | grep {name} | sed -n 2p)',
+ # Describe the first pod and check for annotations.
+ 'kubectl describe pod $pod_1 | grep -q skypilot.co/autodown',
+ 'kubectl describe pod $pod_1 | grep -q skypilot.co/idle_minutes_to_autostop',
+ # Describe the second pod and check for annotations.
+ 'kubectl describe pod $pod_2 | grep -q skypilot.co/autodown',
+ 'kubectl describe pod $pod_2 | grep -q skypilot.co/idle_minutes_to_autostop'
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+def test_add_and_remove_pod_annotations_with_autostop():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'add_and_remove_pod_annotations_with_autostop',
+ [
+            # Launch a Kubernetes cluster with two nodes: a head node and a worker node.
+ f'sky launch -y -c {name} --num-nodes 2 --cpus=1 --cloud kubernetes',
+ # Set autodown on the cluster with 'autostop' command.
+ f'sky autostop -y {name} -i 20 --down',
+ # Get names of the pods containing cluster name.
+ f'pod_1=$(kubectl get pods -o name | grep {name} | sed -n 1p)',
+ f'pod_2=$(kubectl get pods -o name | grep {name} | sed -n 2p)',
+ # Describe the first pod and check for annotations.
+ 'kubectl describe pod $pod_1 | grep -q skypilot.co/autodown',
+ 'kubectl describe pod $pod_1 | grep -q skypilot.co/idle_minutes_to_autostop',
+ # Describe the second pod and check for annotations.
+ 'kubectl describe pod $pod_2 | grep -q skypilot.co/autodown',
+ 'kubectl describe pod $pod_2 | grep -q skypilot.co/idle_minutes_to_autostop',
+ # Cancel the set autodown to remove the annotations from the pods.
+ f'sky autostop -y {name} --cancel',
+ # Describe the first pod and check if annotations are removed.
+ '! kubectl describe pod $pod_1 | grep -q skypilot.co/autodown',
+ '! kubectl describe pod $pod_1 | grep -q skypilot.co/idle_minutes_to_autostop',
+ # Describe the second pod and check if annotations are removed.
+ '! kubectl describe pod $pod_2 | grep -q skypilot.co/autodown',
+ '! kubectl describe pod $pod_2 | grep -q skypilot.co/idle_minutes_to_autostop',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Container logs from task on Kubernetes ----------
+@pytest.mark.kubernetes
+def test_container_logs_multinode_kubernetes():
+ name = smoke_tests_utils.get_cluster_name()
+ task_yaml = 'tests/test_yamls/test_k8s_logs.yaml'
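+    # test_k8s_logs.yaml prints nine numbered lines, so each pod's container log
+    # should contain 9 lines per job run.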
+ head_logs = ('kubectl get pods '
+ f' | grep {name} | grep head | '
+ " awk '{print $1}' | xargs -I {} kubectl logs {}")
+ worker_logs = ('kubectl get pods '
+ f' | grep {name} | grep worker |'
+ " awk '{print $1}' | xargs -I {} kubectl logs {}")
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test = smoke_tests_utils.Test(
+ 'container_logs_multinode_kubernetes',
+ [
+ f'sky launch -y -c {name} {task_yaml} --num-nodes 2',
+ f'{head_logs} | wc -l | grep 9',
+ f'{worker_logs} | wc -l | grep 9',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+def test_container_logs_two_jobs_kubernetes():
+ name = smoke_tests_utils.get_cluster_name()
+ task_yaml = 'tests/test_yamls/test_k8s_logs.yaml'
+ pod_logs = ('kubectl get pods '
+ f' | grep {name} | grep head |'
+ " awk '{print $1}' | xargs -I {} kubectl logs {}")
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test = smoke_tests_utils.Test(
+ 'test_container_logs_two_jobs_kubernetes',
+ [
+ f'sky launch -y -c {name} {task_yaml}',
+ f'{pod_logs} | wc -l | grep 9',
+ f'sky launch -y -c {name} {task_yaml}',
+ f'{pod_logs} | wc -l | grep 18',
+ f'{pod_logs} | grep 1 | wc -l | grep 2',
+ f'{pod_logs} | grep 2 | wc -l | grep 2',
+ f'{pod_logs} | grep 3 | wc -l | grep 2',
+ f'{pod_logs} | grep 4 | wc -l | grep 2',
+ f'{pod_logs} | grep 5 | wc -l | grep 2',
+ f'{pod_logs} | grep 6 | wc -l | grep 2',
+ f'{pod_logs} | grep 7 | wc -l | grep 2',
+ f'{pod_logs} | grep 8 | wc -l | grep 2',
+ f'{pod_logs} | grep 9 | wc -l | grep 2',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+def test_container_logs_two_simultaneous_jobs_kubernetes():
+ name = smoke_tests_utils.get_cluster_name()
+    task_yaml = 'tests/test_yamls/test_k8s_logs.yaml'
+ pod_logs = ('kubectl get pods '
+ f' | grep {name} | grep head |'
+ " awk '{print $1}' | xargs -I {} kubectl logs {}")
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test = smoke_tests_utils.Test(
+ 'test_container_logs_two_simultaneous_jobs_kubernetes',
+ [
+ f'sky launch -y -c {name}',
+ f'sky exec -c {name} -d {task_yaml}',
+ f'sky exec -c {name} -d {task_yaml}',
+ 'sleep 30',
+ f'{pod_logs} | wc -l | grep 18',
+ f'{pod_logs} | grep 1 | wc -l | grep 2',
+ f'{pod_logs} | grep 2 | wc -l | grep 2',
+ f'{pod_logs} | grep 3 | wc -l | grep 2',
+ f'{pod_logs} | grep 4 | wc -l | grep 2',
+ f'{pod_logs} | grep 5 | wc -l | grep 2',
+ f'{pod_logs} | grep 6 | wc -l | grep 2',
+ f'{pod_logs} | grep 7 | wc -l | grep 2',
+ f'{pod_logs} | grep 8 | wc -l | grep 2',
+ f'{pod_logs} | grep 9 | wc -l | grep 2',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Task: n=2 nodes with setups. ----------
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus
+@pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+@pytest.mark.no_do # DO does not have V100 gpus
+@pytest.mark.skip(
+ reason=
+ 'The resnet_distributed_tf_app is flaky, due to it failing to detect GPUs.')
+def test_distributed_tf(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'resnet_distributed_tf_app',
+ [
+ # NOTE: running it twice will hang (sometimes?) - an app-level bug.
+ f'python examples/resnet_distributed_tf_app.py {name} {generic_cloud}',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ timeout=25 * 60, # 25 mins (it takes around ~19 mins)
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing GCP start and stop instances ----------
+@pytest.mark.gcp
+def test_gcp_start_stop():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp-start-stop',
+ [
+ f'sky launch -y -c {name} examples/gcp_start_stop.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} examples/gcp_start_stop.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit.
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ f'sky stop -y {name}',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=40),
+ f'sky start -y {name} -i 1',
+ f'sky exec {name} examples/gcp_start_stop.yaml',
+ f'sky logs {name} 4 --status', # Ensure the job succeeded.
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[
+ sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT
+ ],
+ timeout=200),
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing Azure start and stop instances ----------
+@pytest.mark.azure
+def test_azure_start_stop():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'azure-start-stop',
+ [
+ f'sky launch -y -c {name} examples/azure_start_stop.yaml',
+ f'sky exec {name} examples/azure_start_stop.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit.
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ f'sky stop -y {name}',
+ f'sky start -y {name} -i 1',
+ f'sky exec {name} examples/azure_start_stop.yaml',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[
+ sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT
+ ],
+ timeout=280) +
+ f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}',
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60, # 30 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing Autostopping ----------
+@pytest.mark.no_fluidstack # FluidStack does not support stopping in SkyPilot implementation
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support stopping instances
+@pytest.mark.no_ibm # FIX(IBM) sporadically fails, as restarted workers stay uninitialized indefinitely
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+@pytest.mark.no_kubernetes # Kubernetes does not autostop yet
+def test_autostop(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure
+ # the VM is stopped.
+ autostop_timeout = 600 if generic_cloud == 'azure' else 250
+    # Launching and starting Azure clusters can take a long time too, e.g.,
+    # restarting a stopped Azure cluster can take 7m. So we set the total
+    # timeout to 70m.
+ total_timeout_minutes = 70 if generic_cloud == 'azure' else 20
+ test = smoke_tests_utils.Test(
+ 'autostop',
+ [
+ f'sky launch -y -d -c {name} --num-nodes 2 --cloud {generic_cloud} tests/test_yamls/minimal.yaml',
+ f'sky autostop -y {name} -i 1',
+
+ # Ensure autostop is set.
+ f'sky status | grep {name} | grep "1m"',
+
+ # Ensure the cluster is not stopped early.
+ 'sleep 40',
+ f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP',
+
+ # Ensure the cluster is STOPPED.
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=autostop_timeout),
+
+ # Ensure the cluster is UP and the autostop setting is reset ('-').
+ f'sky start -y {name}',
+ f'sky status | grep {name} | grep -E "UP\s+-"',
+
+ # Ensure the job succeeded.
+ f'sky exec {name} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 2 --status',
+
+ # Test restarting the idleness timer via reset:
+ f'sky autostop -y {name} -i 1', # Idleness starts counting.
+ 'sleep 40', # Almost reached the threshold.
+ f'sky autostop -y {name} -i 1', # Should restart the timer.
+ 'sleep 40',
+ f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=autostop_timeout),
+
+ # Test restarting the idleness timer via exec:
+ f'sky start -y {name}',
+ f'sky status | grep {name} | grep -E "UP\s+-"',
+ f'sky autostop -y {name} -i 1', # Idleness starts counting.
+ 'sleep 45', # Almost reached the threshold.
+ f'sky exec {name} echo hi', # Should restart the timer.
+ 'sleep 45',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+                timeout=autostop_timeout + smoke_tests_utils.BUMP_UP_SECONDS),
+ ],
+ f'sky down -y {name}',
+ timeout=total_timeout_minutes * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing Autodowning ----------
+@pytest.mark.no_fluidstack # FluidStack does not support stopping in SkyPilot implementation
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet. Run test_scp_autodown instead.
+def test_autodown(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ # Azure takes ~ 13m30s (810s) to autodown a VM, so here we use 900 to ensure
+ # the VM is terminated.
+ autodown_timeout = 900 if generic_cloud == 'azure' else 240
+ total_timeout_minutes = 90 if generic_cloud == 'azure' else 20
+ test = smoke_tests_utils.Test(
+ 'autodown',
+ [
+ f'sky launch -y -d -c {name} --num-nodes 2 --cloud {generic_cloud} tests/test_yamls/minimal.yaml',
+ f'sky autostop -y {name} --down -i 1',
+ # Ensure autostop is set.
+ f'sky status | grep {name} | grep "1m (down)"',
+ # Ensure the cluster is not terminated early.
+ 'sleep 40',
+ f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP',
+ # Ensure the cluster is terminated.
+ f'sleep {autodown_timeout}',
+ f's=$(SKYPILOT_DEBUG=0 sky status {name} --refresh) && echo "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
+ f'sky launch -y -d -c {name} --cloud {generic_cloud} --num-nodes 2 --down tests/test_yamls/minimal.yaml',
+ f'sky status | grep {name} | grep UP', # Ensure the cluster is UP.
+ f'sky exec {name} --cloud {generic_cloud} tests/test_yamls/minimal.yaml',
+ f'sky status | grep {name} | grep "1m (down)"',
+ f'sleep {autodown_timeout}',
+ # Ensure the cluster is terminated.
+ f's=$(SKYPILOT_DEBUG=0 sky status {name} --refresh) && echo "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
+ f'sky launch -y -d -c {name} --cloud {generic_cloud} --num-nodes 2 --down tests/test_yamls/minimal.yaml',
+ f'sky autostop -y {name} --cancel',
+ f'sleep {autodown_timeout}',
+ # Ensure the cluster is still UP.
+ f's=$(SKYPILOT_DEBUG=0 sky status {name} --refresh) && echo "$s" && echo "$s" | grep {name} | grep UP',
+ ],
+ f'sky down -y {name}',
+ timeout=total_timeout_minutes * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.scp
+def test_scp_autodown():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'SCP_autodown',
+ [
+ f'sky launch -y -d -c {name} {smoke_tests_utils.SCP_TYPE} tests/test_yamls/minimal.yaml',
+ f'sky autostop -y {name} --down -i 1',
+ # Ensure autostop is set.
+ f'sky status | grep {name} | grep "1m (down)"',
+ # Ensure the cluster is not terminated early.
+ 'sleep 45',
+ f'sky status --refresh | grep {name} | grep UP',
+ # Ensure the cluster is terminated.
+ 'sleep 200',
+ f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
+ f'sky launch -y -d -c {name} {smoke_tests_utils.SCP_TYPE} --down tests/test_yamls/minimal.yaml',
+ f'sky status | grep {name} | grep UP', # Ensure the cluster is UP.
+ f'sky exec {name} {smoke_tests_utils.SCP_TYPE} tests/test_yamls/minimal.yaml',
+ f'sky status | grep {name} | grep "1m (down)"',
+ 'sleep 200',
+ # Ensure the cluster is terminated.
+ f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
+ f'sky launch -y -d -c {name} {smoke_tests_utils.SCP_TYPE} --down tests/test_yamls/minimal.yaml',
+ f'sky autostop -y {name} --cancel',
+ 'sleep 200',
+ # Ensure the cluster is still UP.
+ f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && echo "$s" | grep {name} | grep UP',
+ ],
+ f'sky down -y {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+def _get_cancel_task_with_cloud(name, cloud, timeout=15 * 60):
+ test = smoke_tests_utils.Test(
+ f'{cloud}-cancel-task',
+ [
+ f'sky launch -c {name} examples/resnet_app.yaml --cloud {cloud} -y -d',
+            # Wait for the job to be scheduled and finish setup.
+ f'until sky queue {name} | grep "RUNNING"; do sleep 10; done',
+            # Wait for setup and initialization to finish before the GPU process starts.
+ 'sleep 120',
+ f'sky exec {name} "nvidia-smi | grep python"',
+ f'sky logs {name} 2 --status || {{ sky logs {name} --no-follow 1 && exit 1; }}', # Ensure the job succeeded.
+ f'sky cancel -y {name} 1',
+ 'sleep 60',
+            # Check that the python job is gone.
+ f'sky exec {name} "! nvidia-smi | grep python"',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ timeout=timeout,
+ )
+ return test
+
+
+# ---------- Testing `sky cancel` ----------
+@pytest.mark.aws
+def test_cancel_aws():
+ name = smoke_tests_utils.get_cluster_name()
+ test = _get_cancel_task_with_cloud(name, 'aws')
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_cancel_gcp():
+ name = smoke_tests_utils.get_cluster_name()
+ test = _get_cancel_task_with_cloud(name, 'gcp')
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_cancel_azure():
+ name = smoke_tests_utils.get_cluster_name()
+ test = _get_cancel_task_with_cloud(name, 'azure', timeout=30 * 60)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack does not support V100 gpus for now
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus
+@pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA
+@pytest.mark.no_paperspace # Paperspace has `gnome-shell` on nvidia-smi
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
+def test_cancel_pytorch(generic_cloud: str, accelerator: Dict[str, str]):
+ accelerator = accelerator.get(generic_cloud, 'T4')
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'cancel-pytorch',
+ [
+ f'sky launch -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/resnet_distributed_torch.yaml -y -d',
+            # Wait for the GPU process to start.
+ 'sleep 90',
+ f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep python) || '
+ # When run inside container/k8s, nvidia-smi cannot show process ids.
+ # See https://github.com/NVIDIA/nvidia-docker/issues/179
+ # To work around, we check if GPU utilization is greater than 0.
+ f'[ \$(nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits) -gt 0 ]"',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ f'sky cancel -y {name} 1',
+ 'sleep 60',
+ f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep \'No running process\') || '
+ # Ensure Xorg is the only process running.
+ '[ \$(nvidia-smi | grep -A 10 Processes | grep -A 10 === | grep -v Xorg) -eq 2 ]"',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# Can't use `_get_cancel_task_with_cloud()`, as the `nvidia-smi` command
+# requires a CUDA public image, which IBM doesn't offer.
+@pytest.mark.ibm
+def test_cancel_ibm():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'ibm-cancel-task',
+ [
+ f'sky launch -y -c {name} --cloud ibm examples/minimal.yaml',
+ f'sky exec {name} -n {name}-1 -d "while true; do echo \'Hello SkyPilot\'; sleep 2; done"',
+ 'sleep 20',
+ f'sky queue {name} | grep {name}-1 | grep RUNNING',
+ f'sky cancel -y {name} 2',
+ f'sleep 5',
+ f'sky queue {name} | grep {name}-1 | grep CANCELLED',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing use-spot option ----------
+@pytest.mark.no_fluidstack # FluidStack does not support spot instances
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
+@pytest.mark.no_paperspace # Paperspace does not support spot instances
+@pytest.mark.no_ibm # IBM Cloud does not support spot instances
+@pytest.mark.no_scp # SCP does not support spot instances
+@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
+@pytest.mark.no_do
+def test_use_spot(generic_cloud: str):
+ """Test use-spot and sky exec."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'use-spot',
+ [
+ f'sky launch -c {name} --cloud {generic_cloud} tests/test_yamls/minimal.yaml --use-spot -y',
+ f'sky logs {name} 1 --status',
+ f'sky exec {name} echo hi',
+ f'sky logs {name} 2 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_spot_instance_verification():
+ """Test Azure spot instance provisioning with explicit verification.
+ This test verifies that when --use-spot is specified for Azure:
+ 1. The cluster launches successfully
+ 2. The instances are actually provisioned as spot instances
+ """
+ name = smoke_tests_utils.get_cluster_name()
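+    # The verification command below queries the az CLI for the VM backing
+    # this cluster and checks that its Priority column reports 'Spot'.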
+ test = smoke_tests_utils.Test(
+ 'azure-spot-verification',
+ [
+ f'sky launch -c {name} --cloud azure tests/test_yamls/minimal.yaml --use-spot -y',
+            f'sky logs {name} 1 --status',
+            f'TARGET_VM_NAME="{name}"; '
+            'VM_INFO=$(az vm list --query "[?contains(name, \'$TARGET_VM_NAME\')].{Name:name, ResourceGroup:resourceGroup}" -o tsv); '
+ '[[ -z "$VM_INFO" ]] && exit 1; '
+ 'FULL_VM_NAME=$(echo "$VM_INFO" | awk \'{print $1}\'); '
+ 'RESOURCE_GROUP=$(echo "$VM_INFO" | awk \'{print $2}\'); '
+ 'VM_DETAILS=$(az vm list --resource-group "$RESOURCE_GROUP" '
+ '--query "[?name==\'$FULL_VM_NAME\'].{Name:name, Location:location, Priority:priority}" -o table); '
+ '[[ -z "$VM_DETAILS" ]] && exit 1; '
+ 'echo "VM Details:"; echo "$VM_DETAILS"; '
+ 'echo "$VM_DETAILS" | grep -qw "Spot" && exit 0 || exit 1'
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_stop_gcp_spot():
+ """Test GCP spot can be stopped, autostopped, restarted."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'stop_gcp_spot',
+ [
+ f'sky launch -c {name} --cloud gcp --use-spot --cpus 2+ -y -- touch myfile',
+ # stop should go through:
+ f'sky stop {name} -y',
+ f'sky start {name} -y',
+ f'sky exec {name} -- ls myfile',
+ f'sky logs {name} 2 --status',
+ f'sky autostop {name} -i0 -y',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=90),
+ f'sky start {name} -y',
+ f'sky exec {name} -- ls myfile',
+ f'sky logs {name} 3 --status',
+ # -i option at launch should go through:
+ f'sky launch -c {name} -i0 -y',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=120),
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing env ----------
+def test_inline_env(generic_cloud: str):
+ """Test env"""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test-inline-env',
+ [
+ f'sky launch -c {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
+ 'sleep 20',
+ f'sky logs {name} 1 --status',
+ f'sky exec {name} --env TEST_ENV2="success" "([[ ! -z \\"\$TEST_ENV2\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
+ f'sky logs {name} 2 --status',
+ ],
+ f'sky down -y {name}',
+ smoke_tests_utils.get_timeout(generic_cloud),
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing env file ----------
+def test_inline_env_file(generic_cloud: str):
+ """Test env"""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test-inline-env-file',
+ [
+ f'sky launch -c {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
+ f'sky logs {name} 1 --status',
+ f'sky exec {name} --env-file examples/sample_dotenv "([[ ! -z \\"\$TEST_ENV2\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
+ f'sky logs {name} 2 --status',
+ ],
+ f'sky down -y {name}',
+ smoke_tests_utils.get_timeout(generic_cloud),
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing custom image ----------
+@pytest.mark.aws
+def test_aws_custom_image():
+ """Test AWS custom image"""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test-aws-custom-image',
+ [
+ f'sky launch -c {name} --retry-until-up -y tests/test_yamls/test_custom_image.yaml --cloud aws --region us-east-2 --image-id ami-062ddd90fb6f8267a', # Nvidia image
+ f'sky logs {name} 1 --status',
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+@pytest.mark.parametrize(
+ 'image_id',
+ [
+ 'docker:nvidia/cuda:11.8.0-devel-ubuntu18.04',
+ 'docker:ubuntu:18.04',
+ # Test latest image with python 3.11 installed by default.
+ 'docker:continuumio/miniconda3:24.1.2-0',
+ # Test python>=3.12 where SkyPilot should automatically create a separate
+ # conda env for runtime with python 3.10.
+ 'docker:continuumio/miniconda3:latest',
+ ])
+def test_kubernetes_custom_image(image_id):
+ """Test Kubernetes custom image"""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test-kubernetes-custom-image',
+ [
+ f'sky launch -c {name} --retry-until-up -y tests/test_yamls/test_custom_image.yaml --cloud kubernetes --image-id {image_id} --region None --gpus T4:1',
+ f'sky logs {name} 1 --status',
+ # Try exec to run again and check if the logs are printed
+ f'sky exec {name} tests/test_yamls/test_custom_image.yaml --cloud kubernetes --image-id {image_id} --region None --gpus T4:1 | grep "Hello 100"',
+ # Make sure ssh is working with custom username
+ f'ssh {name} echo hi | grep hi',
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_start_stop_two_nodes():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'azure-start-stop-two-nodes',
+ [
+ f'sky launch --num-nodes=2 -y -c {name} examples/azure_start_stop.yaml',
+ f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky stop -y {name}',
+ f'sky start -y {name} -i 1',
+ f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[
+ sky.ClusterStatus.INIT, sky.ClusterStatus.STOPPED
+ ],
+ timeout=200 + smoke_tests_utils.BUMP_UP_SECONDS) +
+ f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}'
+ ],
+ f'sky down -y {name}',
+ timeout=30 * 60, # 30 mins (it takes around ~23 mins)
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing env for disk tier ----------
+@pytest.mark.aws
+def test_aws_disk_tier():
+
+ def _get_aws_query_command(region, instance_id, field, expected):
+ return (f'aws ec2 describe-volumes --region {region} '
+ f'--filters Name=attachment.instance-id,Values={instance_id} '
+ f'--query Volumes[*].{field} | grep {expected} ; ')
+
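+    # For each disk tier: launch a cluster with --disk-tier, look up the EC2
+    # instance by its ray-cluster-name tag, and cross-check the attached
+    # volume's type (and, where applicable, IOPS/throughput) against the
+    # specs SkyPilot requested.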
+ for disk_tier in list(resources_utils.DiskTier):
+ specs = AWS._get_disk_specs(disk_tier)
+ name = smoke_tests_utils.get_cluster_name() + '-' + disk_tier.value
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, sky.AWS.max_cluster_name_length())
+ region = 'us-east-2'
+ test = smoke_tests_utils.Test(
+ 'aws-disk-tier-' + disk_tier.value,
+ [
+ f'sky launch -y -c {name} --cloud aws --region {region} '
+ f'--disk-tier {disk_tier.value} echo "hello sky"',
+ f'id=`aws ec2 describe-instances --region {region} --filters '
+ f'Name=tag:ray-cluster-name,Values={name_on_cloud} --query '
+ f'Reservations[].Instances[].InstanceId --output text`; ' +
+ _get_aws_query_command(region, '$id', 'VolumeType',
+ specs['disk_tier']) +
+ ('' if specs['disk_tier']
+ == 'standard' else _get_aws_query_command(
+ region, '$id', 'Iops', specs['disk_iops'])) +
+ ('' if specs['disk_tier'] != 'gp3' else _get_aws_query_command(
+ region, '$id', 'Throughput', specs['disk_throughput'])),
+ ],
+ f'sky down -y {name}',
+ timeout=10 * 60, # 10 mins (it takes around ~6 mins)
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_disk_tier():
+ for disk_tier in list(resources_utils.DiskTier):
+ disk_types = [GCP._get_disk_type(disk_tier)]
+ name = smoke_tests_utils.get_cluster_name() + '-' + disk_tier.value
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, sky.GCP.max_cluster_name_length())
+ region = 'us-west2'
+ instance_type_options = ['']
+ if disk_tier == resources_utils.DiskTier.BEST:
+            # The ultra disk tier requires n2 instance types with more than 64 CPUs.
+            # With the default instance type, only the high disk tier is enabled.
+ disk_types = [
+ GCP._get_disk_type(resources_utils.DiskTier.HIGH),
+ GCP._get_disk_type(resources_utils.DiskTier.ULTRA),
+ ]
+ instance_type_options = ['', '--instance-type n2-standard-64']
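+        # For the 'best' tier, zip() pairs the default instance type with the
+        # high disk type and n2-standard-64 with the ultra disk type; for the
+        # other tiers there is a single (disk type, default instance) pair.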
+ for disk_type, instance_type_option in zip(disk_types,
+ instance_type_options):
+ test = smoke_tests_utils.Test(
+ 'gcp-disk-tier-' + disk_tier.value,
+ [
+ f'sky launch -y -c {name} --cloud gcp --region {region} '
+ f'--disk-tier {disk_tier.value} {instance_type_option} ',
+ f'name=`gcloud compute instances list --filter='
+ f'"labels.ray-cluster-name:{name_on_cloud}" '
+ '--format="value(name)"`; '
+ f'gcloud compute disks list --filter="name=$name" '
+ f'--format="value(type)" | grep {disk_type} '
+ ],
+ f'sky down -y {name}',
+ timeout=6 * 60, # 6 mins (it takes around ~3 mins)
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_disk_tier():
+ for disk_tier in list(resources_utils.DiskTier):
+ if disk_tier == resources_utils.DiskTier.HIGH or disk_tier == resources_utils.DiskTier.ULTRA:
+ # Azure does not support high and ultra disk tier.
+ continue
+        disk_type = Azure._get_disk_type(disk_tier)
+ name = smoke_tests_utils.get_cluster_name() + '-' + disk_tier.value
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, sky.Azure.max_cluster_name_length())
+ region = 'westus2'
+ test = smoke_tests_utils.Test(
+ 'azure-disk-tier-' + disk_tier.value,
+ [
+ f'sky launch -y -c {name} --cloud azure --region {region} '
+ f'--disk-tier {disk_tier.value} echo "hello sky"',
+ f'az resource list --tag ray-cluster-name={name_on_cloud} --query '
+ f'"[?type==\'Microsoft.Compute/disks\'].sku.name" '
+                f'--output tsv | grep {disk_type}'
+ ],
+ f'sky down -y {name}',
+ timeout=20 * 60, # 20 mins (it takes around ~12 mins)
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_best_tier_failover():
+    disk_type = Azure._get_disk_type(resources_utils.DiskTier.LOW)
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, sky.Azure.max_cluster_name_length())
+ region = 'westus2'
+ test = smoke_tests_utils.Test(
+ 'azure-best-tier-failover',
+ [
+ f'sky launch -y -c {name} --cloud azure --region {region} '
+ f'--disk-tier best --instance-type Standard_D8_v5 echo "hello sky"',
+ f'az resource list --tag ray-cluster-name={name_on_cloud} --query '
+ f'"[?type==\'Microsoft.Compute/disks\'].sku.name" '
+            f'--output tsv | grep {disk_type}',
+ ],
+ f'sky down -y {name}',
+ timeout=20 * 60, # 20 mins (it takes around ~12 mins)
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ------ Testing Zero Quota Failover ------
+@pytest.mark.aws
+def test_aws_zero_quota_failover():
+
+ name = smoke_tests_utils.get_cluster_name()
+ region = smoke_tests_utils.get_aws_region_for_quota_failover()
+
+ if not region:
+ pytest.xfail(
+ 'Unable to test zero quota failover optimization — quotas '
+ 'for EC2 P3 instances were found on all AWS regions. Is this '
+ 'expected for your account?')
+ return
+
+ test = smoke_tests_utils.Test(
+ 'aws-zero-quota-failover',
+ [
+ f'sky launch -y -c {name} --cloud aws --region {region} --gpus V100:8 --use-spot | grep "Found no quota"',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_zero_quota_failover():
+
+ name = smoke_tests_utils.get_cluster_name()
+ region = smoke_tests_utils.get_gcp_region_for_quota_failover()
+
+ if not region:
+ pytest.xfail(
+ 'Unable to test zero quota failover optimization — quotas '
+ 'for A100-80GB GPUs were found on all GCP regions. Is this '
+ 'expected for your account?')
+ return
+
+ test = smoke_tests_utils.Test(
+ 'gcp-zero-quota-failover',
+ [
+ f'sky launch -y -c {name} --cloud gcp --region {region} --gpus A100-80GB:1 --use-spot | grep "Found no quota"',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+def test_long_setup_run_script(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ with tempfile.NamedTemporaryFile('w', prefix='sky_app_',
+ suffix='.yaml') as f:
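+        # Generate a YAML whose setup and run sections each contain roughly
+        # 200k echo lines (1024 * 200), to exercise SkyPilot's handling of
+        # very large setup/run scripts.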
+ f.write(
+ textwrap.dedent(""" \
+ setup: |
+ echo "start long setup"
+ """))
+ for i in range(1024 * 200):
+ f.write(f' echo {i}\n')
+ f.write(' echo "end long setup"\n')
+ f.write(
+ textwrap.dedent(""" \
+ run: |
+ echo "run"
+ """))
+ for i in range(1024 * 200):
+ f.write(f' echo {i}\n')
+ f.write(' echo "end run"\n')
+ f.flush()
+
+ test = smoke_tests_utils.Test(
+ 'long-setup-run-script',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} --cpus 2+ {f.name}',
+ f'sky exec {name} "echo hello"',
+ f'sky exec {name} {f.name}',
+ f'sky logs {name} --status 1',
+ f'sky logs {name} --status 2',
+ f'sky logs {name} --status 3',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_images.py b/tests/smoke_tests/test_images.py
new file mode 100644
index 00000000000..800279d9161
--- /dev/null
+++ b/tests/smoke_tests/test_images.py
@@ -0,0 +1,407 @@
+# Smoke tests for SkyPilot for image functionality
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_images.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_images.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_images.py::test_aws_images
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_images.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_images.py --generic-cloud aws
+
+import pytest
+from smoke_tests import smoke_tests_utils
+
+import sky
+
+
+# ---------- Test the image ----------
+@pytest.mark.aws
+def test_aws_images():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_images',
+ [
+ f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-1804 examples/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky launch -c {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml && exit 1 || true',
+ f'sky launch -y -c {name} examples/minimal.yaml',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .cloud | grep -i aws\'',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_images():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_images',
+ [
+ f'sky launch -y -c {name} --image-id skypilot:gpu-debian-10 --cloud gcp tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky launch -c {name} --image-id skypilot:cpu-debian-10 --cloud gcp tests/test_yamls/minimal.yaml && exit 1 || true',
+ f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .cloud | grep -i gcp\'',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_images():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'azure_images',
+ [
+ f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-2204 --cloud azure tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky launch -c {name} --image-id skypilot:v1-ubuntu-2004 --cloud azure tests/test_yamls/minimal.yaml && exit 1 || true',
+ f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .cloud | grep -i azure\'',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_aws_image_id_dict():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_image_id_dict',
+ [
+ # Use image id dict.
+ f'sky launch -y -c {name} examples/per_region_images.yaml',
+ f'sky exec {name} examples/per_region_images.yaml',
+ f'sky exec {name} "ls ~"',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_image_id_dict():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_image_id_dict',
+ [
+ # Use image id dict.
+ f'sky launch -y -c {name} tests/test_yamls/gcp_per_region_images.yaml',
+ f'sky exec {name} tests/test_yamls/gcp_per_region_images.yaml',
+ f'sky exec {name} "ls ~"',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_aws_image_id_dict_region():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_image_id_dict_region',
+ [
+ # YAML has
+ # image_id:
+ # us-west-2: skypilot:gpu-ubuntu-1804
+ # us-east-2: skypilot:gpu-ubuntu-2004
+ # Use region to filter image_id dict.
+ f'sky launch -y -c {name} --region us-east-1 examples/per_region_images.yaml && exit 1 || true',
+ f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
+ f'sky launch -y -c {name} --region us-east-2 examples/per_region_images.yaml',
+            # Should succeed because the image id matches the region.
+ f'sky launch -c {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
+ f'sky exec {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
+ f'sky exec {name} --image-id skypilot:gpu-ubuntu-1804 examples/minimal.yaml && exit 1 || true',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 --status',
+ f'sky status -v | grep {name} | grep us-east-2', # Ensure the region is correct.
+ # Ensure exec works.
+ f'sky exec {name} --region us-east-2 examples/per_region_images.yaml',
+ f'sky exec {name} examples/per_region_images.yaml',
+ f'sky exec {name} --cloud aws --region us-east-2 "ls ~"',
+ f'sky exec {name} "ls ~"',
+ f'sky logs {name} 4 --status',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+ f'sky logs {name} 7 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_image_id_dict_region():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_image_id_dict_region',
+ [
+ # Use region to filter image_id dict.
+ f'sky launch -y -c {name} --region us-east1 tests/test_yamls/gcp_per_region_images.yaml && exit 1 || true',
+ f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
+ f'sky launch -y -c {name} --region us-west3 tests/test_yamls/gcp_per_region_images.yaml',
+            # Should succeed because the image id matches the region.
+ f'sky launch -c {name} --cloud gcp --image-id projects/ubuntu-os-cloud/global/images/ubuntu-1804-bionic-v20230112 tests/test_yamls/minimal.yaml',
+ f'sky exec {name} --cloud gcp --image-id projects/ubuntu-os-cloud/global/images/ubuntu-1804-bionic-v20230112 tests/test_yamls/minimal.yaml',
+ f'sky exec {name} --cloud gcp --image-id skypilot:cpu-debian-10 tests/test_yamls/minimal.yaml && exit 1 || true',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 --status',
+ f'sky status -v | grep {name} | grep us-west3', # Ensure the region is correct.
+ # Ensure exec works.
+ f'sky exec {name} --region us-west3 tests/test_yamls/gcp_per_region_images.yaml',
+ f'sky exec {name} tests/test_yamls/gcp_per_region_images.yaml',
+ f'sky exec {name} --cloud gcp --region us-west3 "ls ~"',
+ f'sky exec {name} "ls ~"',
+ f'sky logs {name} 4 --status',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+ f'sky logs {name} 7 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_aws_image_id_dict_zone():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_image_id_dict_zone',
+ [
+ # YAML has
+ # image_id:
+ # us-west-2: skypilot:gpu-ubuntu-1804
+ # us-east-2: skypilot:gpu-ubuntu-2004
+ # Use zone to filter image_id dict.
+ f'sky launch -y -c {name} --zone us-east-1b examples/per_region_images.yaml && exit 1 || true',
+ f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
+ f'sky launch -y -c {name} --zone us-east-2a examples/per_region_images.yaml',
+            # Should succeed because the image id matches the zone.
+ f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
+ f'sky exec {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
+ # Fail due to image id mismatch.
+ f'sky exec {name} --image-id skypilot:gpu-ubuntu-1804 examples/minimal.yaml && exit 1 || true',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 --status',
+ f'sky status -v | grep {name} | grep us-east-2a', # Ensure the zone is correct.
+ # Ensure exec works.
+ f'sky exec {name} --zone us-east-2a examples/per_region_images.yaml',
+ f'sky exec {name} examples/per_region_images.yaml',
+ f'sky exec {name} --cloud aws --region us-east-2 "ls ~"',
+ f'sky exec {name} "ls ~"',
+ f'sky logs {name} 4 --status',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+ f'sky logs {name} 7 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_image_id_dict_zone():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_image_id_dict_zone',
+ [
+ # Use zone to filter image_id dict.
+ f'sky launch -y -c {name} --zone us-east1-a tests/test_yamls/gcp_per_region_images.yaml && exit 1 || true',
+ f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
+ f'sky launch -y -c {name} --zone us-central1-a tests/test_yamls/gcp_per_region_images.yaml',
+            # Should succeed because the image id matches the zone.
+ f'sky launch -y -c {name} --cloud gcp --image-id skypilot:cpu-debian-10 tests/test_yamls/minimal.yaml',
+ f'sky exec {name} --cloud gcp --image-id skypilot:cpu-debian-10 tests/test_yamls/minimal.yaml',
+ # Fail due to image id mismatch.
+ f'sky exec {name} --cloud gcp --image-id skypilot:gpu-debian-10 tests/test_yamls/minimal.yaml && exit 1 || true',
+ f'sky logs {name} 1 --status',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} 3 --status',
+ f'sky status -v | grep {name} | grep us-central1', # Ensure the zone is correct.
+ # Ensure exec works.
+ f'sky exec {name} --cloud gcp --zone us-central1-a tests/test_yamls/gcp_per_region_images.yaml',
+ f'sky exec {name} tests/test_yamls/gcp_per_region_images.yaml',
+ f'sky exec {name} --cloud gcp --region us-central1 "ls ~"',
+ f'sky exec {name} "ls ~"',
+ f'sky logs {name} 4 --status',
+ f'sky logs {name} 5 --status',
+ f'sky logs {name} 6 --status',
+ f'sky logs {name} 7 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_clone_disk_aws():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'clone_disk_aws',
+ [
+ f'sky launch -y -c {name} --cloud aws --region us-east-2 --retry-until-up "echo hello > ~/user_file.txt"',
+ f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true',
+ f'sky stop {name} -y',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=60),
+ # Wait for EC2 instance to be in stopped state.
+ # TODO: event based wait.
+ 'sleep 60',
+ f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"',
+ f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"',
+ f'sky logs {name}-clone 1 --status',
+ f'sky logs {name}-clone-2 1 --status',
+ ],
+ f'sky down -y {name} {name}-clone {name}-clone-2',
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_clone_disk_gcp():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'clone_disk_gcp',
+ [
+ f'sky launch -y -c {name} --cloud gcp --zone us-east1-b --retry-until-up "echo hello > ~/user_file.txt"',
+ f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true',
+ f'sky stop {name} -y',
+ f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud gcp --zone us-central1-a "cat ~/user_file.txt | grep hello"',
+ f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud gcp --zone us-east1-b "cat ~/user_file.txt | grep hello"',
+ f'sky logs {name}-clone 1 --status',
+ f'sky logs {name}-clone-2 1 --status',
+ ],
+ f'sky down -y {name} {name}-clone {name}-clone-2',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_mig():
+ name = smoke_tests_utils.get_cluster_name()
+ region = 'us-central1'
+ test = smoke_tests_utils.Test(
+ 'gcp_mig',
+ [
+ f'sky launch -y -c {name} --gpus t4 --num-nodes 2 --image-id skypilot:gpu-debian-10 --cloud gcp --region {region} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 2 --status',
+ f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
+ # Check MIG exists.
+ f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"',
+ f'sky autostop -i 0 --down -y {name}',
+ smoke_tests_utils.get_cmd_wait_until_cluster_is_not_found(
+ cluster_name=name, timeout=120),
+ f'gcloud compute instance-templates list | grep "sky-it-{name}"',
+ # Launch again with the same region. The original instance template
+ # should be removed.
+ f'sky launch -y -c {name} --gpus L4 --num-nodes 2 --region {region} nvidia-smi',
+ f'sky logs {name} 1 | grep "L4"',
+ f'sky down -y {name}',
+ f'gcloud compute instance-templates list | grep "sky-it-{name}" && exit 1 || true',
+ ],
+ f'sky down -y {name}',
+ env={'SKYPILOT_CONFIG': 'tests/test_yamls/use_mig_config.yaml'})
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_force_enable_external_ips():
+ name = smoke_tests_utils.get_cluster_name()
+ test_commands = [
+ f'sky launch -y -c {name} --cloud gcp --cpus 2 tests/test_yamls/minimal.yaml',
+ # Check network of vm is "default"
+ (f'gcloud compute instances list --filter=name~"{name}" --format='
+ '"value(networkInterfaces.network)" | grep "networks/default"'),
+ # Check External NAT in network access configs, corresponds to external ip
+ (f'gcloud compute instances list --filter=name~"{name}" --format='
+ '"value(networkInterfaces.accessConfigs[0].name)" | grep "External NAT"'
+ ),
+ f'sky down -y {name}',
+ ]
+ skypilot_config = 'tests/test_yamls/force_enable_external_ips_config.yaml'
+ test = smoke_tests_utils.Test('gcp_force_enable_external_ips',
+ test_commands,
+ f'sky down -y {name}',
+ env={'SKYPILOT_CONFIG': skypilot_config})
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_image_no_conda():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'image_no_conda',
+ [
+ # Use image id dict.
+ f'sky launch -y -c {name} --region us-east-2 examples/per_region_images.yaml',
+ f'sky logs {name} 1 --status',
+ f'sky stop {name} -y',
+ f'sky start {name} -y',
+ f'sky exec {name} examples/per_region_images.yaml',
+ f'sky logs {name} 2 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # FluidStack does not support stopping instances in SkyPilot implementation
+@pytest.mark.no_kubernetes # Kubernetes does not support stopping instances
+def test_custom_default_conda_env(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test('custom_default_conda_env', [
+ f'sky launch -c {name} -y --cloud {generic_cloud} tests/test_yamls/test_custom_default_conda_env.yaml',
+ f'sky status -r {name} | grep "UP"',
+ f'sky logs {name} 1 --status',
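+        # The task presumably runs `conda env list`; the active environment is
+        # marked with `*`, so this checks that `myenv` is the default env.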
+ f'sky logs {name} 1 --no-follow | grep -E "myenv\\s+\\*"',
+ f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml',
+ f'sky logs {name} 2 --status',
+ f'sky autostop -y -i 0 {name}',
+ smoke_tests_utils.get_cmd_wait_until_cluster_status_contains(
+ cluster_name=name,
+ cluster_status=[sky.ClusterStatus.STOPPED],
+ timeout=80),
+ f'sky start -y {name}',
+ f'sky logs {name} 2 --no-follow | grep -E "myenv\\s+\\*"',
+ f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml',
+ f'sky logs {name} 3 --status',
+ ], f'sky down -y {name}')
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_managed_job.py b/tests/smoke_tests/test_managed_job.py
new file mode 100644
index 00000000000..88e22a4758f
--- /dev/null
+++ b/tests/smoke_tests/test_managed_job.py
@@ -0,0 +1,874 @@
+# Smoke tests for SkyPilot for managed jobs
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_managed_job.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_managed_job.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_managed_job.py::test_managed_jobs
+#
+# Only run managed job tests
+# > pytest tests/smoke_tests/test_managed_job.py --managed-jobs
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_managed_job.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_managed_job.py --generic-cloud aws
+
+import pathlib
+import re
+import tempfile
+import time
+
+import pytest
+from smoke_tests import smoke_tests_utils
+from smoke_tests.test_mount_and_storage import TestStorageWithCredentials
+
+import sky
+from sky import jobs
+from sky.data import storage as storage_lib
+from sky.skylet import constants
+from sky.utils import common_utils
+
+
+# ---------- Testing managed job ----------
+# TODO(zhwu): run the jobs controller on GCP to avoid parallel test issues
+# when the controller is on Azure, which takes a long time for the launching
+# step.
+@pytest.mark.managed_jobs
+def test_managed_jobs(generic_cloud: str):
+ """Test the managed jobs yaml."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'managed-jobs',
+ [
+ f'sky jobs launch -n {name}-1 --cloud {generic_cloud} examples/managed_job.yaml -y -d',
+ f'sky jobs launch -n {name}-2 --cloud {generic_cloud} examples/managed_job.yaml -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-1',
+ job_status=[
+ sky.ManagedJobStatus.PENDING,
+ sky.ManagedJobStatus.SUBMITTED,
+ sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING
+ ],
+ timeout=60),
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-2',
+ job_status=[
+ sky.ManagedJobStatus.PENDING,
+ sky.ManagedJobStatus.SUBMITTED,
+ sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING
+ ],
+ timeout=60),
+ f'sky jobs cancel -y -n {name}-1',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-1',
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=230),
+ # Test the functionality for logging.
+ f's=$(sky jobs logs -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "start counting"',
+ f's=$(sky jobs logs --controller -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "Cluster launched:"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "RUNNING\|SUCCEEDED"',
+ ],
+ # TODO(zhwu): Change to f'sky jobs cancel -y -n {name}-1 -n {name}-2' when
+ # canceling multiple job names is supported.
+ f'sky jobs cancel -y -n {name}-1; sky jobs cancel -y -n {name}-2',
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # FluidStack does not support spot instances
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
+@pytest.mark.no_ibm # IBM Cloud does not support spot instances
+@pytest.mark.no_scp # SCP does not support spot instances
+@pytest.mark.no_paperspace # Paperspace does not support spot instances
+@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
+@pytest.mark.no_do # DO does not support spot instances
+@pytest.mark.managed_jobs
+def test_job_pipeline(generic_cloud: str):
+ """Test a job pipeline."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'job_pipeline',
+ [
+ f'sky jobs launch -n {name} tests/test_yamls/pipeline.yaml -y -d',
+ # Need to wait for setup and job initialization.
+ 'sleep 30',
+ f'{smoke_tests_utils.GET_JOB_QUEUE}| grep {name} | head -n1 | grep "STARTING\|RUNNING"',
+ # `grep -A 4 {name}` finds the job with {name} and the 4 lines
+ # after it, i.e. the 4 tasks within the job.
+ # `sed -n 2p` gets the second line of the 4 lines, i.e. the first
+ # task within the job.
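+            # Lines 3p-5p below map to the second, third, and fourth tasks in
+            # the same way.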
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "STARTING\|RUNNING"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "PENDING"',
+ f'sky jobs cancel -y -n {name}',
+ 'sleep 5',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "CANCELLING\|CANCELLED"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "CANCELLING\|CANCELLED"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLING\|CANCELLED"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLING\|CANCELLED"',
+ 'sleep 200',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "CANCELLED"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "CANCELLED"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"',
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # FluidStack does not support spot instances
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
+@pytest.mark.no_ibm # IBM Cloud does not support spot instances
+@pytest.mark.no_scp # SCP does not support spot instances
+@pytest.mark.no_paperspace # Paperspace does not support spot instances
+@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
+@pytest.mark.no_do # DO does not support spot instances
+@pytest.mark.managed_jobs
+def test_managed_jobs_failed_setup(generic_cloud: str):
+ """Test managed job with failed setup."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_failed_setup',
+ [
+ f'sky jobs launch -n {name} --cloud {generic_cloud} -y -d tests/test_yamls/failed_setup.yaml',
+ # Make sure the job failed quickly.
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.FAILED_SETUP],
+ timeout=330 + smoke_tests_utils.BUMP_UP_SECONDS),
+ ],
+ f'sky jobs cancel -y -n {name}',
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # FluidStack does not support spot instances
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
+@pytest.mark.no_ibm # IBM Cloud does not support spot instances
+@pytest.mark.no_scp # SCP does not support spot instances
+@pytest.mark.no_paperspace # Paperspace does not support spot instances
+@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
+@pytest.mark.managed_jobs
+def test_managed_jobs_pipeline_failed_setup(generic_cloud: str):
+ """Test managed job with failed setup for a pipeline."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_pipeline_failed_setup',
+ [
+ f'sky jobs launch -n {name} -y -d tests/test_yamls/failed_setup_pipeline.yaml',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.FAILED_SETUP],
+ timeout=600),
+ # Make sure the job failed quickly.
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"',
+ # Task 0 should be SUCCEEDED.
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "SUCCEEDED"',
+ # Task 1 should be FAILED_SETUP.
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "FAILED_SETUP"',
+ # Task 2 should be CANCELLED.
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"',
+ # Task 3 should be CANCELLED.
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing managed job recovery ----------
+
+
+@pytest.mark.aws
+@pytest.mark.managed_jobs
+def test_managed_jobs_recovery_aws(aws_config_region):
+ """Test managed job recovery."""
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ region = aws_config_region
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_recovery_aws',
+ [
+ f'sky jobs launch --cloud aws --region {region} --use-spot -n {name} "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=600),
+ f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
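+            # Save the run id so we can later verify that the recovered job
+            # resumes with the same SKYPILOT_TASK_ID.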
+ # Terminate the cluster manually.
+ (f'aws ec2 terminate-instances --region {region} --instance-ids $('
+ f'aws ec2 describe-instances --region {region} '
+ f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}* '
+ f'--query Reservations[].Instances[].InstanceId '
+ '--output text)'),
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=200),
+ f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | grep "$RUN_ID"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+@pytest.mark.managed_jobs
+def test_managed_jobs_recovery_gcp():
+ """Test managed job recovery."""
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ zone = 'us-east4-b'
+ query_cmd = (
+ f'gcloud compute instances list --filter='
+ # `:` means prefix match.
+ f'"(labels.ray-cluster-name:{name_on_cloud})" '
+ f'--zones={zone} --format="value(name)"')
+ terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
+ f' --quiet $({query_cmd})')
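+    # Deleting the instance out-of-band mimics a spot preemption and forces
+    # the managed job to go through recovery.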
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_recovery_gcp',
+ [
+ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --cpus 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=300),
+ f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
+ # Terminate the cluster manually.
+ terminate_cmd,
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=200),
+ f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+@pytest.mark.managed_jobs
+def test_managed_jobs_pipeline_recovery_aws(aws_config_region):
+ """Test managed job recovery for a pipeline."""
+ name = smoke_tests_utils.get_cluster_name()
+ user_hash = common_utils.get_user_hash()
+ user_hash = user_hash[:common_utils.USER_HASH_LENGTH_IN_CLUSTER_NAME]
+ region = aws_config_region
+ if region != 'us-east-2':
+ pytest.skip('Only run spot pipeline recovery test in us-east-2')
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_pipeline_recovery_aws',
+ [
+ f'sky jobs launch -n {name} tests/test_yamls/pipeline_aws.yaml -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=400),
+ f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
+ f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids',
+ # Terminate the cluster manually.
+            # The `cat ... | rev` pipeline retrieves the job_id from the
+            # SKYPILOT_TASK_ID, i.e., the second-to-last field separated
+            # by `-`.
+ (
+ f'MANAGED_JOB_ID=`cat /tmp/{name}-run-id | rev | '
+ 'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`;'
+ f'aws ec2 terminate-instances --region {region} --instance-ids $('
+ f'aws ec2 describe-instances --region {region} '
+ # TODO(zhwu): fix the name for spot cluster.
+ '--filters Name=tag:ray-cluster-name,Values=*-${MANAGED_JOB_ID}'
+ f'-{user_hash} '
+ f'--query Reservations[].Instances[].InstanceId '
+ '--output text)'),
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=200),
+ f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"',
+ f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new',
+ f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new',
+ f'cat /tmp/{name}-run-ids | sed -n 2p | grep `cat /tmp/{name}-run-id`',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+@pytest.mark.managed_jobs
+def test_managed_jobs_pipeline_recovery_gcp():
+ """Test managed job recovery for a pipeline."""
+ name = smoke_tests_utils.get_cluster_name()
+ zone = 'us-east4-b'
+ user_hash = common_utils.get_user_hash()
+ user_hash = user_hash[:common_utils.USER_HASH_LENGTH_IN_CLUSTER_NAME]
+ query_cmd = (
+ 'gcloud compute instances list --filter='
+ f'"(labels.ray-cluster-name:*-${{MANAGED_JOB_ID}}-{user_hash})" '
+ f'--zones={zone} --format="value(name)"')
+ terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
+ f' --quiet $({query_cmd})')
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_pipeline_recovery_gcp',
+ [
+ f'sky jobs launch -n {name} tests/test_yamls/pipeline_gcp.yaml -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=400),
+ f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
+ f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids',
+ # Terminate the cluster manually.
+            # The `cat ... | rev` pipeline retrieves the job_id from the
+            # SKYPILOT_TASK_ID, i.e., the second-to-last field separated
+            # by `-`.
+ (f'MANAGED_JOB_ID=`cat /tmp/{name}-run-id | rev | '
+ f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'),
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=200),
+ f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"',
+ f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new',
+ f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new',
+ f'cat /tmp/{name}-run-ids | sed -n 2p | grep `cat /tmp/{name}-run-id`',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
+@pytest.mark.no_ibm # IBM Cloud does not support spot instances
+@pytest.mark.no_scp # SCP does not support spot instances
+@pytest.mark.no_paperspace # Paperspace does not support spot instances
+@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
+@pytest.mark.no_do # DO does not have spot instances
+@pytest.mark.managed_jobs
+def test_managed_jobs_recovery_default_resources(generic_cloud: str):
+ """Test managed job recovery for default resources."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'managed-spot-recovery-default-resources',
+ [
+ f'sky jobs launch -n {name} --cloud {generic_cloud} --use-spot "sleep 30 && sudo shutdown now && sleep 1000" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[
+ sky.ManagedJobStatus.RUNNING,
+ sky.ManagedJobStatus.RECOVERING
+ ],
+ timeout=360),
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+@pytest.mark.managed_jobs
+def test_managed_jobs_recovery_multi_node_aws(aws_config_region):
+ """Test managed job recovery."""
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ region = aws_config_region
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_recovery_multi_node_aws',
+ [
+ f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=450),
+ f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
+ # Terminate the worker manually.
+ (f'aws ec2 terminate-instances --region {region} --instance-ids $('
+ f'aws ec2 describe-instances --region {region} '
+ f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}* '
+ 'Name=tag:ray-node-type,Values=worker '
+ f'--query Reservations[].Instances[].InstanceId '
+ '--output text)'),
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=560),
+ f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+@pytest.mark.managed_jobs
+def test_managed_jobs_recovery_multi_node_gcp():
+ """Test managed job recovery."""
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ zone = 'us-west2-a'
+    # Use ':' (prefix match), as the cluster name will contain a suffix
+    # with the job id.
+ query_cmd = (
+ f'gcloud compute instances list --filter='
+ f'"(labels.ray-cluster-name:{name_on_cloud} AND '
+ f'labels.ray-node-type=worker)" --zones={zone} --format="value(name)"')
+ terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
+ f' --quiet $({query_cmd})')
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_recovery_multi_node_gcp',
+ [
+ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=400),
+ f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
+ # Terminate the worker manually.
+ terminate_cmd,
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=560),
+ f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ timeout=25 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+@pytest.mark.managed_jobs
+def test_managed_jobs_cancellation_aws(aws_config_region):
+ name = smoke_tests_utils.get_cluster_name()
+ name_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ name_2_on_cloud = common_utils.make_cluster_name_on_cloud(
+ f'{name}-2', jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ name_3_on_cloud = common_utils.make_cluster_name_on_cloud(
+ f'{name}-3', jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ region = aws_config_region
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_cancellation_aws',
+ [
+ # Test cancellation during spot cluster being launched.
+ f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot "sleep 1000" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[
+ sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING
+ ],
+ timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS),
+ f'sky jobs cancel -y -n {name}',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=120 + smoke_tests_utils.BUMP_UP_SECONDS),
+ (f's=$(aws ec2 describe-instances --region {region} '
+ f'--filters "Name=tag:ray-cluster-name,Values={name_on_cloud}-*" '
+ '--query "Reservations[].Instances[].State[].Name" '
+ '--output text) && echo "$s" && echo; [[ -z "$s" ]] || [[ "$s" = "terminated" ]] || [[ "$s" = "shutting-down" ]]'
+ ),
+ # Test cancelling the spot cluster during spot job being setup.
+ f'sky jobs launch --cloud aws --region {region} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d',
+            # The job is being set up on the cluster and will show as RUNNING.
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-2',
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS),
+ f'sky jobs cancel -y -n {name}-2',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-2',
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=120 + smoke_tests_utils.BUMP_UP_SECONDS),
+ (f's=$(aws ec2 describe-instances --region {region} '
+ f'--filters "Name=tag:ray-cluster-name,Values={name_2_on_cloud}-*" '
+ '--query "Reservations[].Instances[].State[].Name" '
+ '--output text) && echo "$s" && echo; [[ -z "$s" ]] || [[ "$s" = "terminated" ]] || [[ "$s" = "shutting-down" ]]'
+ ),
+ # Test cancellation during spot job is recovering.
+ f'sky jobs launch --cloud aws --region {region} -n {name}-3 --use-spot "sleep 1000" -y -d',
+            # The job is running on the cluster and will show as RUNNING.
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-3',
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS),
+ # Terminate the cluster manually.
+ (f'aws ec2 terminate-instances --region {region} --instance-ids $('
+ f'aws ec2 describe-instances --region {region} '
+ f'--filters "Name=tag:ray-cluster-name,Values={name_3_on_cloud}-*" '
+ f'--query "Reservations[].Instances[].InstanceId" '
+ '--output text)'),
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"',
+ f'sky jobs cancel -y -n {name}-3',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-3',
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=120 + smoke_tests_utils.BUMP_UP_SECONDS),
+            # The cluster should be terminated (shutting-down) after
+            # cancellation. We don't use the `=` operator here because there
+            # can be multiple VMs with the same name due to recovery.
+ (f's=$(aws ec2 describe-instances --region {region} '
+ f'--filters "Name=tag:ray-cluster-name,Values={name_3_on_cloud}-*" '
+ '--query "Reservations[].Instances[].State[].Name" '
+ '--output text) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "pending|running|stopped|stopping"'
+ ),
+ ],
+ timeout=25 * 60)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+@pytest.mark.managed_jobs
+def test_managed_jobs_cancellation_gcp():
+ name = smoke_tests_utils.get_cluster_name()
+ name_3 = f'{name}-3'
+ name_3_on_cloud = common_utils.make_cluster_name_on_cloud(
+ name_3, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
+ zone = 'us-west3-b'
+ query_state_cmd = (
+ 'gcloud compute instances list '
+ f'--filter="(labels.ray-cluster-name:{name_3_on_cloud})" '
+ '--format="value(status)"')
+ query_cmd = (f'gcloud compute instances list --filter='
+ f'"(labels.ray-cluster-name:{name_3_on_cloud})" '
+ f'--zones={zone} --format="value(name)"')
+ terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
+ f' --quiet $({query_cmd})')
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_cancellation_gcp',
+ [
+ # Test cancellation during spot cluster being launched.
+ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot "sleep 1000" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.STARTING],
+ timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS),
+ f'sky jobs cancel -y -n {name}',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=120 + smoke_tests_utils.BUMP_UP_SECONDS),
+ # Test cancelling the spot cluster during spot job being setup.
+ f'sky jobs launch --cloud gcp --zone {zone} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d',
+            # The job is being set up on the cluster and will show as RUNNING.
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-2',
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS),
+ f'sky jobs cancel -y -n {name}-2',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-2',
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=120 + smoke_tests_utils.BUMP_UP_SECONDS),
+ # Test cancellation during spot job is recovering.
+ f'sky jobs launch --cloud gcp --zone {zone} -n {name}-3 --use-spot "sleep 1000" -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-3',
+ job_status=[sky.ManagedJobStatus.RUNNING],
+ timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS),
+ # Terminate the cluster manually.
+ terminate_cmd,
+ smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'),
+ f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"',
+ f'sky jobs cancel -y -n {name}-3',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=f'{name}-3',
+ job_status=[sky.ManagedJobStatus.CANCELLED],
+ timeout=120 + smoke_tests_utils.BUMP_UP_SECONDS),
+            # The cluster should be terminated (STOPPING) after cancellation.
+            # We don't use the `=` operator here because there can be multiple
+            # VMs with the same name due to recovery.
+            (f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED"'
+ ),
+ ],
+ timeout=25 * 60)
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing storage for managed job ----------
+@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
+@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
+@pytest.mark.no_ibm # IBM Cloud does not support spot instances
+@pytest.mark.no_paperspace # Paperspace does not support spot instances
+@pytest.mark.no_scp # SCP does not support spot instances
+@pytest.mark.no_do # DO does not support spot instances
+@pytest.mark.managed_jobs
+def test_managed_jobs_storage(generic_cloud: str):
+ """Test storage with managed job"""
+ name = smoke_tests_utils.get_cluster_name()
+ yaml_str = pathlib.Path(
+ 'examples/managed_job_with_storage.yaml').read_text()
+ timestamp = int(time.time())
+ storage_name = f'sky-test-{timestamp}'
+ output_storage_name = f'sky-test-output-{timestamp}'
+
+    # Also perform region testing for bucket creation to validate that buckets
+    # are created in the correct region and correctly mounted in managed jobs.
+    # We only inject this region check for AWS, GCP, and Azure, the object
+    # storage providers supported by SkyPilot; for Kubernetes, the backing
+    # object store is unknown, so only the output file is checked.
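+    # For example, on AWS the composed validation command is roughly
+    # `aws s3api get-bucket-location --bucket <storage_name> --output text | grep <region>`
+    # (see TestStorageWithCredentials.cli_region_cmd).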
+ region_flag = ''
+ region_validation_cmd = 'true'
+ use_spot = ' --use-spot'
+ if generic_cloud == 'aws':
+ region = 'eu-central-1'
+ region_flag = f' --region {region}'
+ region_cmd = TestStorageWithCredentials.cli_region_cmd(
+ storage_lib.StoreType.S3, bucket_name=storage_name)
+ region_validation_cmd = f'{region_cmd} | grep {region}'
+ s3_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
+ storage_lib.StoreType.S3, output_storage_name, 'output.txt')
+ output_check_cmd = f'{s3_check_file_count} | grep 1'
+ elif generic_cloud == 'gcp':
+ region = 'us-west2'
+ region_flag = f' --region {region}'
+ region_cmd = TestStorageWithCredentials.cli_region_cmd(
+ storage_lib.StoreType.GCS, bucket_name=storage_name)
+ region_validation_cmd = f'{region_cmd} | grep {region}'
+ gcs_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
+ storage_lib.StoreType.GCS, output_storage_name, 'output.txt')
+ output_check_cmd = f'{gcs_check_file_count} | grep 1'
+ elif generic_cloud == 'azure':
+ region = 'westus2'
+ region_flag = f' --region {region}'
+ storage_account_name = (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(region))
+ region_cmd = TestStorageWithCredentials.cli_region_cmd(
+ storage_lib.StoreType.AZURE,
+ storage_account_name=storage_account_name)
+ region_validation_cmd = f'{region_cmd} | grep {region}'
+ az_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
+ storage_lib.StoreType.AZURE,
+ output_storage_name,
+ 'output.txt',
+ storage_account_name=storage_account_name)
+ output_check_cmd = f'{az_check_file_count} | grep 1'
+ elif generic_cloud == 'kubernetes':
+        # With Kubernetes, we don't know which object storage provider is used.
+        # Check whether the output bucket exists in either S3 or GCS.
+ s3_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
+ storage_lib.StoreType.S3, output_storage_name, 'output.txt')
+ s3_output_check_cmd = f'{s3_check_file_count} | grep 1'
+ gcs_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
+ storage_lib.StoreType.GCS, output_storage_name, 'output.txt')
+ gcs_output_check_cmd = f'{gcs_check_file_count} | grep 1'
+ output_check_cmd = f'{s3_output_check_cmd} || {gcs_output_check_cmd}'
+ use_spot = ' --no-use-spot'
+
+ yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name)
+ yaml_str = yaml_str.replace('sky-output-bucket', output_storage_name)
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(yaml_str)
+ f.flush()
+ file_path = f.name
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_storage',
+ [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky jobs launch -n {name}{use_spot} --cloud {generic_cloud}{region_flag} {file_path} -y',
+ region_validation_cmd, # Check if the bucket is created in the correct region
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.SUCCEEDED],
+ timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS),
+ # Wait for the job to be cleaned up.
+ 'sleep 30',
+ f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]',
+ # Check if file was written to the mounted output bucket
+ output_check_cmd
+ ],
+ (f'sky jobs cancel -y -n {name}'
+ f'; sky storage delete {output_storage_name} -y || true'),
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_managed_jobs_intermediate_storage(generic_cloud: str):
+ """Test storage with managed job"""
+ name = smoke_tests_utils.get_cluster_name()
+ yaml_str = pathlib.Path(
+ 'examples/managed_job_with_storage.yaml').read_text()
+ timestamp = int(time.time())
+ storage_name = f'sky-test-{timestamp}'
+ output_storage_name = f'sky-test-output-{timestamp}'
+
+ yaml_str_user_config = pathlib.Path(
+ 'tests/test_yamls/use_intermediate_bucket_config.yaml').read_text()
+ intermediate_storage_name = f'intermediate-smoke-test-{timestamp}'
+
+ yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name)
+ yaml_str = yaml_str.replace('sky-output-bucket', output_storage_name)
+ yaml_str_user_config = re.sub(r'bucket-jobs-[\w\d]+',
+ intermediate_storage_name,
+ yaml_str_user_config)
+
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_user_config:
+ f_user_config.write(yaml_str_user_config)
+ f_user_config.flush()
+ user_config_path = f_user_config.name
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_task:
+ f_task.write(yaml_str)
+ f_task.flush()
+ file_path = f_task.name
+
+ test = smoke_tests_utils.Test(
+ 'managed_jobs_intermediate_storage',
+ [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ # Verify command fails with correct error - run only once
+ f'err=$(sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y 2>&1); ret=$?; echo "$err" ; [ $ret -eq 0 ] || ! echo "$err" | grep "StorageBucketCreateError: Jobs bucket \'{intermediate_storage_name}\' does not exist. Please check jobs.bucket configuration in your SkyPilot config." > /dev/null && exit 1 || exit 0',
+ f'aws s3api create-bucket --bucket {intermediate_storage_name}',
+ f'sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y',
+            # The job should now succeed since the intermediate bucket exists.
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.SUCCEEDED],
+ timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS),
+            # Check that the intermediate bucket exists; it won't be deleted
+            # since it is user-specified.
+ f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{intermediate_storage_name}\')].Name" --output text | wc -l) -eq 1 ]',
+ ],
+ (f'sky jobs cancel -y -n {name}'
+ f'; aws s3 rb s3://{intermediate_storage_name} --force'
+ f'; sky storage delete {output_storage_name} -y || true'),
+ env={'SKYPILOT_CONFIG': user_config_path},
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing spot TPU ----------
+@pytest.mark.gcp
+@pytest.mark.managed_jobs
+@pytest.mark.tpu
+def test_managed_jobs_tpu():
+ """Test managed job on TPU."""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test-spot-tpu',
+ [
+ f'sky jobs launch -n {name} --use-spot examples/tpu/tpuvm_mnist.yaml -y -d',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.STARTING],
+ timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS),
+ # TPU takes a while to launch
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[
+ sky.ManagedJobStatus.RUNNING, sky.ManagedJobStatus.SUCCEEDED
+ ],
+ timeout=900 + smoke_tests_utils.BUMP_UP_SECONDS),
+ ],
+ f'sky jobs cancel -y -n {name}',
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing env for managed jobs ----------
+@pytest.mark.managed_jobs
+def test_managed_jobs_inline_env(generic_cloud: str):
+ """Test managed jobs env"""
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test-managed-jobs-inline-env',
+ [
+ f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "echo "\\$TEST_ENV"; ([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[sky.ManagedJobStatus.SUCCEEDED],
+ timeout=20 + smoke_tests_utils.BUMP_UP_SECONDS),
+ f'JOB_ROW=$(sky jobs queue | grep {name} | head -n1) && '
+ f'echo "$JOB_ROW" && echo "$JOB_ROW" | grep "SUCCEEDED" && '
+ f'JOB_ID=$(echo "$JOB_ROW" | awk \'{{print $1}}\') && '
+ f'echo "JOB_ID=$JOB_ID" && '
+ # Test that logs are still available after the job finishes.
+ 'unset SKYPILOT_DEBUG; s=$(sky jobs logs $JOB_ID --refresh) && echo "$s" && echo "$s" | grep "hello world" && '
+ # Make sure we skip the unnecessary logs.
+ 'echo "$s" | head -n1 | grep "Waiting for"',
+ ],
+ f'sky jobs cancel -y -n {name}',
+ # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py
new file mode 100644
index 00000000000..e1eabc70523
--- /dev/null
+++ b/tests/smoke_tests/test_mount_and_storage.py
@@ -0,0 +1,1663 @@
+# Smoke tests for SkyPilot for mounting storage
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_mount_and_storage.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_mount_and_storage.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_mount_and_storage.py::test_file_mounts
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_mount_and_storage.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_mount_and_storage.py --generic-cloud aws
+
+import json
+import os
+import pathlib
+import shlex
+import shutil
+import subprocess
+import tempfile
+import time
+from typing import Dict, Optional, TextIO
+import urllib.parse
+import uuid
+
+import jinja2
+import pytest
+from smoke_tests import smoke_tests_utils
+
+import sky
+from sky import global_user_state
+from sky import skypilot_config
+from sky.adaptors import azure
+from sky.adaptors import cloudflare
+from sky.adaptors import ibm
+from sky.data import data_utils
+from sky.data import storage as storage_lib
+from sky.data.data_utils import Rclone
+from sky.skylet import constants
+
+
+# ---------- file_mounts ----------
+@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet. Run test_scp_file_mounts instead.
+def test_file_mounts(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ extra_flags = ''
+    if generic_cloud == 'kubernetes':
+        # Kubernetes does not support multi-node.
+ # NOTE: This test will fail if you have a Kubernetes cluster running on
+ # arm64 (e.g., Apple Silicon) since goofys does not work on arm64.
+ extra_flags = '--num-nodes 1'
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {name} --cloud {generic_cloud} {extra_flags} examples/using_file_mounts.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ]
+ test = smoke_tests_utils.Test(
+ 'using_file_mounts',
+ test_commands,
+ f'sky down -y {name}',
+ smoke_tests_utils.get_timeout(generic_cloud, 20 * 60), # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.scp
+def test_scp_file_mounts():
+ name = smoke_tests_utils.get_cluster_name()
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {name} {smoke_tests_utils.SCP_TYPE} --num-nodes 1 examples/using_file_mounts.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ]
+ test = smoke_tests_utils.Test(
+ 'SCP_using_file_mounts',
+ test_commands,
+ f'sky down -y {name}',
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.oci # For OCI object storage mounts and file mounts.
+def test_oci_mounts():
+ name = smoke_tests_utils.get_cluster_name()
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {name} --cloud oci --num-nodes 2 examples/oci/oci-mounts.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ ]
+ test = smoke_tests_utils.Test(
+ 'oci_mounts',
+ test_commands,
+ f'sky down -y {name}',
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Requires GCP to be enabled
+def test_using_file_mounts_with_env_vars(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ storage_name = TestStorageWithCredentials.generate_bucket_name()
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ (f'sky launch -y -c {name} --cpus 2+ --cloud {generic_cloud} '
+ 'examples/using_file_mounts_with_env_vars.yaml '
+ f'--env MY_BUCKET={storage_name}'),
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ # Override with --env:
+ (f'sky launch -y -c {name}-2 --cpus 2+ --cloud {generic_cloud} '
+ 'examples/using_file_mounts_with_env_vars.yaml '
+ f'--env MY_BUCKET={storage_name} '
+ '--env MY_LOCAL_PATH=tmpfile'),
+ f'sky logs {name}-2 1 --status', # Ensure the job succeeded.
+ ]
+ test = smoke_tests_utils.Test(
+ 'using_file_mounts_with_env_vars',
+ test_commands,
+        (f'sky down -y {name} {name}-2; '
+         f'sky storage delete -y {storage_name} {storage_name}-2'),
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- storage ----------
+def _storage_mounts_commands_generator(f: TextIO, cluster_name: str,
+ storage_name: str, ls_hello_command: str,
+ cloud: str, only_mount: bool):
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_storage_mounting.yaml.j2').read_text()
+ template = jinja2.Template(template_str)
+
+ # Set mount flags based on cloud provider
+ include_s3_mount = cloud in ['aws', 'kubernetes']
+ include_gcs_mount = cloud in ['gcp', 'kubernetes']
+ include_azure_mount = cloud == 'azure'
+
+ content = template.render(
+ storage_name=storage_name,
+ cloud=cloud,
+ only_mount=only_mount,
+ include_s3_mount=include_s3_mount,
+ include_gcs_mount=include_gcs_mount,
+ include_azure_mount=include_azure_mount,
+ )
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {cluster_name} --cloud {cloud} {file_path}',
+ f'sky logs {cluster_name} 1 --status', # Ensure job succeeded.
+ ls_hello_command,
+ f'sky stop -y {cluster_name}',
+ f'sky start -y {cluster_name}',
+ # Check if hello.txt from mounting bucket exists after restart in
+ # the mounted directory
+ f'sky exec {cluster_name} -- "set -ex; ls /mount_private_mount/hello.txt"',
+ ]
+ clean_command = f'sky down -y {cluster_name}; sky storage delete -y {storage_name}'
+ return test_commands, clean_command
+
+
+@pytest.mark.aws
+def test_aws_storage_mounts_with_stop():
+ name = smoke_tests_utils.get_cluster_name()
+ cloud = 'aws'
+ storage_name = f'sky-test-{int(time.time())}'
+ ls_hello_command = f'aws s3 ls {storage_name}/hello.txt'
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test_commands, clean_command = _storage_mounts_commands_generator(
+ f, name, storage_name, ls_hello_command, cloud, False)
+ test = smoke_tests_utils.Test(
+ 'aws_storage_mounts',
+ test_commands,
+ clean_command,
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+def test_aws_storage_mounts_with_stop_only_mount():
+ name = smoke_tests_utils.get_cluster_name()
+ cloud = 'aws'
+ storage_name = f'sky-test-{int(time.time())}'
+ ls_hello_command = f'aws s3 ls {storage_name}/hello.txt'
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test_commands, clean_command = _storage_mounts_commands_generator(
+ f, name, storage_name, ls_hello_command, cloud, True)
+ test = smoke_tests_utils.Test(
+ 'aws_storage_mounts_only_mount',
+ test_commands,
+ clean_command,
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_storage_mounts_with_stop():
+ name = smoke_tests_utils.get_cluster_name()
+ cloud = 'gcp'
+ storage_name = f'sky-test-{int(time.time())}'
+ ls_hello_command = f'gsutil ls gs://{storage_name}/hello.txt'
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test_commands, clean_command = _storage_mounts_commands_generator(
+ f, name, storage_name, ls_hello_command, cloud, False)
+ test = smoke_tests_utils.Test(
+ 'gcp_storage_mounts',
+ test_commands,
+ clean_command,
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_storage_mounts_with_stop():
+ name = smoke_tests_utils.get_cluster_name()
+ cloud = 'azure'
+ storage_name = f'sky-test-{int(time.time())}'
+ default_region = 'eastus'
+ storage_account_name = (storage_lib.AzureBlobStore.
+ get_default_storage_account_name(default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+    # If the file does not exist, `az storage blob list` returns '[]'.
+    ls_hello_command = (f'output=$(az storage blob list -c {storage_name} '
+                        f'--account-name {storage_account_name} '
+                        f'--account-key {storage_account_key} '
+                        f'--prefix hello.txt); '
+                        f'[ "$output" = "[]" ] && exit 1 || exit 0')
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test_commands, clean_command = _storage_mounts_commands_generator(
+ f, name, storage_name, ls_hello_command, cloud, False)
+ test = smoke_tests_utils.Test(
+ 'azure_storage_mounts',
+ test_commands,
+ clean_command,
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+def test_kubernetes_storage_mounts():
+ # Tests bucket mounting on k8s, assuming S3 is configured.
+ # This test will fail if run on non x86_64 architecture, since goofys is
+ # built for x86_64 only.
+ name = smoke_tests_utils.get_cluster_name()
+ storage_name = f'sky-test-{int(time.time())}'
+ ls_hello_command = (f'aws s3 ls {storage_name}/hello.txt || '
+ f'gsutil ls gs://{storage_name}/hello.txt')
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ test_commands, clean_command = _storage_mounts_commands_generator(
+ f, name, storage_name, ls_hello_command, 'kubernetes', False)
+ test = smoke_tests_utils.Test(
+ 'kubernetes_storage_mounts',
+ test_commands,
+ clean_command,
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+def test_kubernetes_context_switch():
+ name = smoke_tests_utils.get_cluster_name()
+ new_context = f'sky-test-context-{int(time.time())}'
+ new_namespace = f'sky-test-namespace-{int(time.time())}'
+
+ test_commands = [
+ # Launch a cluster and run a simple task
+ f'sky launch -y -c {name} --cloud kubernetes "echo Hello from original context"',
+ f'sky logs {name} 1 --status', # Ensure job succeeded
+
+ # Get current context details and save to a file for later use in cleanup
+ 'CURRENT_CONTEXT=$(kubectl config current-context); '
+ 'echo "$CURRENT_CONTEXT" > /tmp/sky_test_current_context; '
+ 'CURRENT_CLUSTER=$(kubectl config view -o jsonpath="{.contexts[?(@.name==\\"$CURRENT_CONTEXT\\")].context.cluster}"); '
+ 'CURRENT_USER=$(kubectl config view -o jsonpath="{.contexts[?(@.name==\\"$CURRENT_CONTEXT\\")].context.user}"); '
+
+ # Create a new context with a different name and namespace
+ f'kubectl config set-context {new_context} --cluster="$CURRENT_CLUSTER" --user="$CURRENT_USER" --namespace={new_namespace}',
+
+ # Create the new namespace if it doesn't exist
+ f'kubectl create namespace {new_namespace} --dry-run=client -o yaml | kubectl apply -f -',
+
+ # Set the new context as active
+ f'kubectl config use-context {new_context}',
+
+ # Verify the new context is active
+ f'[ "$(kubectl config current-context)" = "{new_context}" ] || exit 1',
+
+ # Try to run sky exec on the original cluster (should still work)
+ f'sky exec {name} "echo Success: sky exec works after context switch"',
+
+ # Test sky queue
+ f'sky queue {name}',
+
+ # Test SSH access
+ f'ssh {name} whoami',
+ ]
+
+ cleanup_commands = (
+ f'kubectl delete namespace {new_namespace}; '
+ f'kubectl config delete-context {new_context}; '
+ 'kubectl config use-context $(cat /tmp/sky_test_current_context); '
+ 'rm /tmp/sky_test_current_context; '
+ f'sky down -y {name}')
+
+ test = smoke_tests_utils.Test(
+ 'kubernetes_context_switch',
+ test_commands,
+ cleanup_commands,
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# TODO (zhwu): These tests may fail as they can require access to cloud
+# credentials, even though the API server is running remotely. We should fix
+# this.
+@pytest.mark.parametrize(
+ 'image_id',
+ [
+ 'docker:nvidia/cuda:11.8.0-devel-ubuntu18.04',
+ 'docker:ubuntu:18.04',
+ # Test image with python 3.11 installed by default.
+ 'docker:continuumio/miniconda3:24.1.2-0',
+ # Test python>=3.12 where SkyPilot should automatically create a separate
+ # conda env for runtime with python 3.10.
+ 'docker:continuumio/miniconda3:latest',
+ ])
+def test_docker_storage_mounts(generic_cloud: str, image_id: str):
+ # Tests bucket mounting on docker container
+ name = smoke_tests_utils.get_cluster_name()
+ timestamp = str(time.time()).replace('.', '')
+ storage_name = f'sky-test-{timestamp}'
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_storage_mounting.yaml.j2').read_text()
+ template = jinja2.Template(template_str)
+ # ubuntu 18.04 does not support fuse3, and blobfuse2 depends on fuse3.
+ azure_mount_unsupported_ubuntu_version = '18.04'
+ # Commands to verify bucket upload. We need to check all three
+ # storage types because the optimizer may pick any of them.
+ s3_command = f'aws s3 ls {storage_name}/hello.txt'
+ gsutil_command = f'gsutil ls gs://{storage_name}/hello.txt'
+ azure_blob_command = TestStorageWithCredentials.cli_ls_cmd(
+ storage_lib.StoreType.AZURE, storage_name, suffix='hello.txt')
+ if azure_mount_unsupported_ubuntu_version in image_id:
+ # The store for mount_private_mount is not specified in the template.
+ # If we're running on Azure, the private mount will be created on
+ # azure blob. That will not be supported on the ubuntu 18.04 image
+ # and thus fail. For other clouds, the private mount on other
+ # storage types (GCS/S3) should succeed.
+ include_private_mount = False if generic_cloud == 'azure' else True
+ content = template.render(storage_name=storage_name,
+ include_azure_mount=False,
+ include_private_mount=include_private_mount)
+ else:
+ content = template.render(storage_name=storage_name,)
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {name} --cloud {generic_cloud} --image-id {image_id} {file_path}',
+ f'sky logs {name} 1 --status', # Ensure job succeeded.
+ # Check AWS, GCP, or Azure storage mount.
+ f'sky exec {name} -- "{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV}; {s3_command} || {gsutil_command} || {azure_blob_command}"',
+ f'sky logs {name} 2 --status', # Ensure the bucket check succeeded.
+ ]
+ test = smoke_tests_utils.Test(
+ 'docker_storage_mounts',
+ test_commands,
+ f'sky down -y {name}; sky storage delete -y {storage_name}',
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.cloudflare
+def test_cloudflare_storage_mounts(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ storage_name = f'sky-test-{int(time.time())}'
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_r2_storage_mounting.yaml').read_text()
+ template = jinja2.Template(template_str)
+ content = template.render(storage_name=storage_name)
+ endpoint_url = cloudflare.create_endpoint()
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {name} --cloud {generic_cloud} {file_path}',
+ f'sky logs {name} 1 --status', # Ensure job succeeded.
+ f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls s3://{storage_name}/hello.txt --endpoint {endpoint_url} --profile=r2'
+ ]
+
+ test = smoke_tests_utils.Test(
+ 'cloudflare_storage_mounts',
+ test_commands,
+ f'sky down -y {name}; sky storage delete -y {storage_name}',
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.ibm
+def test_ibm_storage_mounts():
+ name = smoke_tests_utils.get_cluster_name()
+ storage_name = f'sky-test-{int(time.time())}'
+ bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
+ storage_name, Rclone.RcloneClouds.IBM)
+ template_str = pathlib.Path(
+ 'tests/test_yamls/test_ibm_cos_storage_mounting.yaml').read_text()
+ template = jinja2.Template(template_str)
+ content = template.render(storage_name=storage_name)
+ with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
+ f.write(content)
+ f.flush()
+ file_path = f.name
+ test_commands = [
+ *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+ f'sky launch -y -c {name} --cloud ibm {file_path}',
+ f'sky logs {name} 1 --status', # Ensure job succeeded.
+ f'rclone ls {bucket_rclone_profile}:{storage_name}/hello.txt',
+ ]
+ test = smoke_tests_utils.Test(
+ 'ibm_storage_mounts',
+ test_commands,
+ f'sky down -y {name}; sky storage delete -y {storage_name}',
+ timeout=20 * 60, # 20 mins
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Testing Storage ----------
+# These tests are essentially unit tests for Storage, but they require
+# credentials and network connection. Thus, they are included with smoke tests.
+# Since these tests require cloud credentials to verify bucket operations,
+# they should not be run when the API server is remote and the user does not
+# have any credentials locally.
+# TODO(romilb): In the future, we should figure out a way to ship these tests
+# to the API server and run them there. Maybe these tests can be packaged as a
+# SkyPilot task run on a remote cluster launched via the API server.
+@pytest.mark.local
+class TestStorageWithCredentials:
+ """Storage tests which require credentials and network connection"""
+
+ AWS_INVALID_NAMES = [
+ 'ab', # less than 3 characters
+ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
+ # more than 63 characters
+ 'Abcdef', # contains an uppercase letter
+ 'abc def', # contains a space
+ 'abc..def', # two adjacent periods
+ '192.168.5.4', # formatted as an IP address
+ 'xn--bucket', # starts with 'xn--' prefix
+ 'bucket-s3alias', # ends with '-s3alias' suffix
+ 'bucket--ol-s3', # ends with '--ol-s3' suffix
+ '.abc', # starts with a dot
+ 'abc.', # ends with a dot
+ '-abc', # starts with a hyphen
+ 'abc-', # ends with a hyphen
+ ]
+
+ GCS_INVALID_NAMES = [
+ 'ab', # less than 3 characters
+ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
+ # more than 63 characters (without dots)
+ 'Abcdef', # contains an uppercase letter
+ 'abc def', # contains a space
+ 'abc..def', # two adjacent periods
+        'abc_.def.ghi.jklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
+ # More than 63 characters between dots
+ 'abc_.def.ghi.jklmnopqrstuvwxyzabcdefghijklmnopqfghijklmnopqrstuvw' * 5,
+ # more than 222 characters (with dots)
+ '192.168.5.4', # formatted as an IP address
+ 'googbucket', # starts with 'goog' prefix
+ 'googlebucket', # contains 'google'
+ 'g00glebucket', # variant of 'google'
+ 'go0glebucket', # variant of 'google'
+ 'g0oglebucket', # variant of 'google'
+ '.abc', # starts with a dot
+ 'abc.', # ends with a dot
+ '_abc', # starts with an underscore
+ 'abc_', # ends with an underscore
+ ]
+
+ AZURE_INVALID_NAMES = [
+ 'ab', # less than 3 characters
+ # more than 63 characters
+ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
+ 'Abcdef', # contains an uppercase letter
+ '.abc', # starts with a non-letter(dot)
+ 'a--bc', # contains consecutive hyphens
+ ]
+
+ IBM_INVALID_NAMES = [
+ 'ab', # less than 3 characters
+ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
+ # more than 63 characters
+ 'Abcdef', # contains an uppercase letter
+ 'abc def', # contains a space
+ 'abc..def', # two adjacent periods
+ '192.168.5.4', # formatted as an IP address
+ 'xn--bucket', # starts with 'xn--' prefix
+ '.abc', # starts with a dot
+ 'abc.', # ends with a dot
+ '-abc', # starts with a hyphen
+ 'abc-', # ends with a hyphen
+ 'a.-bc', # contains the sequence '.-'
+ 'a-.bc', # contains the sequence '-.'
+        'a&bc', # contains special characters
+        'ab^c', # contains special characters
+ ]
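+    # In the directory structure below, a value of None denotes a file and a
+    # nested dict denotes a directory (mirroring create_dir_structure below).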
+ GITIGNORE_SYNC_TEST_DIR_STRUCTURE = {
+ 'double_asterisk': {
+ 'double_asterisk_excluded': None,
+ 'double_asterisk_excluded_dir': {
+ 'dir_excluded': None,
+ },
+ },
+ 'double_asterisk_parent': {
+ 'parent': {
+ 'also_excluded.txt': None,
+ 'child': {
+ 'double_asterisk_parent_child_excluded.txt': None,
+ },
+ 'double_asterisk_parent_excluded.txt': None,
+ },
+ },
+ 'excluded.log': None,
+ 'excluded_dir': {
+ 'excluded.txt': None,
+ 'nested_excluded': {
+ 'excluded': None,
+ },
+ },
+ 'exp-1': {
+ 'be_excluded': None,
+ },
+ 'exp-2': {
+ 'be_excluded': None,
+ },
+ 'front_slash_excluded': None,
+ 'included.log': None,
+ 'included.txt': None,
+ 'include_dir': {
+ 'excluded.log': None,
+ 'included.log': None,
+ },
+ 'nested_double_asterisk': {
+ 'one': {
+ 'also_exclude.txt': None,
+ },
+ 'two': {
+ 'also_exclude.txt': None,
+ },
+ },
+ 'nested_wildcard_dir': {
+ 'monday': {
+ 'also_exclude.txt': None,
+ },
+ 'tuesday': {
+ 'also_exclude.txt': None,
+ },
+ },
+ 'no_slash_excluded': None,
+ 'no_slash_tests': {
+ 'no_slash_excluded': {
+ 'also_excluded.txt': None,
+ },
+ },
+ 'question_mark': {
+ 'excluded1.txt': None,
+ 'excluded@.txt': None,
+ },
+ 'square_bracket': {
+ 'excluded1.txt': None,
+ },
+ 'square_bracket_alpha': {
+ 'excludedz.txt': None,
+ },
+ 'square_bracket_excla': {
+ 'excluded2.txt': None,
+ 'excluded@.txt': None,
+ },
+ 'square_bracket_single': {
+ 'excluded0.txt': None,
+ },
+ }
+
+ @staticmethod
+ def create_dir_structure(base_path, structure):
+ # creates a given file STRUCTURE in BASE_PATH
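+        # For example (hypothetical paths):
+        #   create_dir_structure('/tmp/base', {'a.txt': None, 'd': {'b.txt': None}})
+        # creates the file /tmp/base/a.txt and the directory /tmp/base/d
+        # containing b.txt.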
+ for name, substructure in structure.items():
+ path = os.path.join(base_path, name)
+ if substructure is None:
+ # Create a file
+ open(path, 'a', encoding='utf-8').close()
+ else:
+ # Create a subdirectory
+ os.mkdir(path)
+ TestStorageWithCredentials.create_dir_structure(
+ path, substructure)
+
+ @staticmethod
+    def cli_delete_cmd(store_type,
+                       bucket_name,
+                       storage_account_name: Optional[str] = None):
+ if store_type == storage_lib.StoreType.S3:
+ url = f's3://{bucket_name}'
+ return f'aws s3 rb {url} --force'
+ if store_type == storage_lib.StoreType.GCS:
+ url = f'gs://{bucket_name}'
+ gsutil_alias, alias_gen = data_utils.get_gsutil_command()
+ return f'{alias_gen}; {gsutil_alias} rm -r {url}'
+ if store_type == storage_lib.StoreType.AZURE:
+ default_region = 'eastus'
+ storage_account_name = (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(
+ default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+ return ('az storage container delete '
+ f'--account-name {storage_account_name} '
+ f'--account-key {storage_account_key} '
+ f'--name {bucket_name}')
+ if store_type == storage_lib.StoreType.R2:
+ endpoint_url = cloudflare.create_endpoint()
+ url = f's3://{bucket_name}'
+ return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 rb {url} --force --endpoint {endpoint_url} --profile=r2'
+ if store_type == storage_lib.StoreType.IBM:
+ bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
+ bucket_name, Rclone.RcloneClouds.IBM)
+ return f'rclone purge {bucket_rclone_profile}:{bucket_name} && rclone config delete {bucket_rclone_profile}'
+
+ @classmethod
+ def list_all_files(cls, store_type, bucket_name):
+ cmd = cls.cli_ls_cmd(store_type, bucket_name, recursive=True)
+ if store_type == storage_lib.StoreType.GCS:
+ try:
+ out = subprocess.check_output(cmd,
+ shell=True,
+ stderr=subprocess.PIPE)
+ files = [line[5:] for line in out.decode('utf-8').splitlines()]
+ except subprocess.CalledProcessError as e:
+ error_output = e.stderr.decode('utf-8')
+ if "One or more URLs matched no objects" in error_output:
+ files = []
+ else:
+ raise
+ elif store_type == storage_lib.StoreType.AZURE:
+ out = subprocess.check_output(cmd, shell=True)
+ try:
+ blobs = json.loads(out.decode('utf-8'))
+ files = [blob['name'] for blob in blobs]
+ except json.JSONDecodeError:
+ files = []
+ elif store_type == storage_lib.StoreType.IBM:
+ # rclone ls format: " 1234 path/to/file"
+ out = subprocess.check_output(cmd, shell=True)
+ files = []
+ for line in out.decode('utf-8').splitlines():
+ # Skip empty lines
+ if not line.strip():
+ continue
+ # Split by whitespace and get the file path (last column)
+ parts = line.strip().split(
+ None, 1) # Split into max 2 parts (size and path)
+ if len(parts) == 2:
+ files.append(parts[1])
+ else:
+ out = subprocess.check_output(cmd, shell=True)
+ files = [
+ line.split()[-1] for line in out.decode('utf-8').splitlines()
+ ]
+ return files
+
+ @staticmethod
+ def cli_ls_cmd(store_type, bucket_name, suffix='', recursive=False):
+ if store_type == storage_lib.StoreType.S3:
+ if suffix:
+ url = f's3://{bucket_name}/{suffix}'
+ else:
+ url = f's3://{bucket_name}'
+ cmd = f'aws s3 ls {url}'
+ if recursive:
+ cmd += ' --recursive'
+ return cmd
+ if store_type == storage_lib.StoreType.GCS:
+ if suffix:
+ url = f'gs://{bucket_name}/{suffix}'
+ else:
+ url = f'gs://{bucket_name}'
+ if recursive:
+ url = f'"{url}/**"'
+ return f'gsutil ls {url}'
+ if store_type == storage_lib.StoreType.AZURE:
+            # Azure blob list is recursive by default.
+ default_region = 'eastus'
+ config_storage_account = skypilot_config.get_nested(
+ ('azure', 'storage_account'), None)
+ storage_account_name = config_storage_account if (
+ config_storage_account is not None) else (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(
+ default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+ list_cmd = ('az storage blob list '
+ f'--container-name {bucket_name} '
+ f'--prefix {shlex.quote(suffix)} '
+ f'--account-name {storage_account_name} '
+ f'--account-key {storage_account_key}')
+ return list_cmd
+ if store_type == storage_lib.StoreType.R2:
+ endpoint_url = cloudflare.create_endpoint()
+ if suffix:
+ url = f's3://{bucket_name}/{suffix}'
+ else:
+ url = f's3://{bucket_name}'
+ recursive_flag = '--recursive' if recursive else ''
+ return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2 {recursive_flag}'
+ if store_type == storage_lib.StoreType.IBM:
+ # rclone ls is recursive by default
+ bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
+ bucket_name, Rclone.RcloneClouds.IBM)
+ return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}'
+
+ @staticmethod
+ def cli_region_cmd(store_type, bucket_name=None, storage_account_name=None):
+ if store_type == storage_lib.StoreType.S3:
+ assert bucket_name is not None
+ return ('aws s3api get-bucket-location '
+ f'--bucket {bucket_name} --output text')
+ elif store_type == storage_lib.StoreType.GCS:
+ assert bucket_name is not None
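+            # `gsutil ls -L -b` prints a line such as
+            # 'Location constraint:  US-WEST2' (illustrative); the awk below
+            # lowercases its last field, yielding e.g. 'us-west2'.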
+ return (f'gsutil ls -L -b gs://{bucket_name}/ | '
+ 'grep "Location constraint" | '
+ 'awk \'{print tolower($NF)}\'')
+ elif store_type == storage_lib.StoreType.AZURE:
+ # For Azure Blob Storage, the location of the containers are
+ # determined by the location of storage accounts.
+ assert storage_account_name is not None
+ return (f'az storage account show --name {storage_account_name} '
+ '--query "primaryLocation" --output tsv')
+ else:
+ raise NotImplementedError(f'Region command not implemented for '
+ f'{store_type}')
+
+ @staticmethod
+ def cli_count_name_in_bucket(store_type,
+ bucket_name,
+ file_name,
+ suffix='',
+ storage_account_name=None):
+ if store_type == storage_lib.StoreType.S3:
+ if suffix:
+ return f'aws s3api list-objects --bucket "{bucket_name}" --prefix {suffix} --query "length(Contents[?contains(Key,\'{file_name}\')].Key)"'
+ else:
+ return f'aws s3api list-objects --bucket "{bucket_name}" --query "length(Contents[?contains(Key,\'{file_name}\')].Key)"'
+ elif store_type == storage_lib.StoreType.GCS:
+ if suffix:
+ return f'gsutil ls -r gs://{bucket_name}/{suffix} | grep "{file_name}" | wc -l'
+ else:
+ return f'gsutil ls -r gs://{bucket_name} | grep "{file_name}" | wc -l'
+ elif store_type == storage_lib.StoreType.AZURE:
+ if storage_account_name is None:
+ default_region = 'eastus'
+ storage_account_name = (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(
+ default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+ return ('az storage blob list '
+ f'--container-name {bucket_name} '
+ f'--prefix {shlex.quote(suffix)} '
+ f'--account-name {storage_account_name} '
+ f'--account-key {storage_account_key} | '
+ f'grep {file_name} | '
+ 'wc -l')
+ elif store_type == storage_lib.StoreType.R2:
+ endpoint_url = cloudflare.create_endpoint()
+ if suffix:
+ return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3api list-objects --bucket "{bucket_name}" --prefix {suffix} --query "length(Contents[?contains(Key,\'{file_name}\')].Key)" --endpoint {endpoint_url} --profile=r2'
+ else:
+ return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3api list-objects --bucket "{bucket_name}" --query "length(Contents[?contains(Key,\'{file_name}\')].Key)" --endpoint {endpoint_url} --profile=r2'
+
+ @staticmethod
+ def cli_count_file_in_bucket(store_type, bucket_name):
+ if store_type == storage_lib.StoreType.S3:
+ return f'aws s3 ls s3://{bucket_name} --recursive | wc -l'
+ elif store_type == storage_lib.StoreType.GCS:
+ return f'gsutil ls -r gs://{bucket_name}/** | wc -l'
+ elif store_type == storage_lib.StoreType.AZURE:
+ default_region = 'eastus'
+ storage_account_name = (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(
+ default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+ return ('az storage blob list '
+ f'--container-name {bucket_name} '
+ f'--account-name {storage_account_name} '
+ f'--account-key {storage_account_key} | '
+ 'grep \\"name\\": | '
+ 'wc -l')
+ elif store_type == storage_lib.StoreType.R2:
+ endpoint_url = cloudflare.create_endpoint()
+ return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls s3://{bucket_name} --recursive --endpoint {endpoint_url} --profile=r2 | wc -l'
+
+ @pytest.fixture
+ def tmp_source(self, tmp_path):
+ # Creates a temporary directory with a file in it
+ tmp_dir = tmp_path / 'tmp-source'
+ tmp_dir.mkdir()
+ tmp_file = tmp_dir / 'tmp-file'
+ tmp_file.write_text('test')
+ circle_link = tmp_dir / 'circle-link'
+ circle_link.symlink_to(tmp_dir, target_is_directory=True)
+ yield str(tmp_dir)
+
+ @pytest.fixture
+ def tmp_sub_path(self):
+ tmp_dir1 = uuid.uuid4().hex[:8]
+ tmp_dir2 = uuid.uuid4().hex[:8]
+ yield "/".join([tmp_dir1, tmp_dir2])
+
+ @staticmethod
+ def generate_bucket_name():
+ # Creates a temporary bucket name
+ # time.time() returns varying precision on different systems, so we
+ # replace the decimal point and use whatever precision we can get.
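+        # The result looks like e.g. 'sky-test-1700000000123456' (illustrative
+        # value; the number of digits depends on the system clock precision).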
+ timestamp = str(time.time()).replace('.', '')
+ return f'sky-test-{timestamp}'
+
+ @pytest.fixture
+ def tmp_bucket_name(self):
+ yield self.generate_bucket_name()
+
+ @staticmethod
+ def yield_storage_object(
+ name: Optional[str] = None,
+ source: Optional[storage_lib.Path] = None,
+ stores: Optional[Dict[storage_lib.StoreType,
+ storage_lib.AbstractStore]] = None,
+ persistent: Optional[bool] = True,
+ mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT,
+ _bucket_sub_path: Optional[str] = None):
+ # Creates a temporary storage object. Stores must be added in the test.
+ storage_obj = storage_lib.Storage(name=name,
+ source=source,
+ stores=stores,
+ persistent=persistent,
+ mode=mode,
+ _bucket_sub_path=_bucket_sub_path)
+ storage_obj.construct()
+ yield storage_obj
+ handle = global_user_state.get_handle_from_storage_name(
+ storage_obj.name)
+ if handle:
+ # If handle exists, delete manually
+ # TODO(romilb): This is potentially risky - if the delete method has
+ # bugs, this can cause resource leaks. Ideally we should manually
+ # eject storage from global_user_state and delete the bucket using
+ # boto3 directly.
+ storage_obj.delete()
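+    # Note: the fixtures below re-yield from yield_storage_object via
+    # `yield from`, so the bucket cleanup above runs after each test that
+    # uses them finishes.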
+
+ @pytest.fixture
+ def tmp_scratch_storage_obj(self, tmp_bucket_name):
+ # Creates a storage object with no source to create a scratch storage.
+ # Stores must be added in the test.
+ yield from self.yield_storage_object(name=tmp_bucket_name)
+
+ @pytest.fixture
+ def tmp_multiple_scratch_storage_obj(self):
+ # Creates a list of 5 storage objects with no source to create
+ # multiple scratch storages.
+ # Stores for each object in the list must be added in the test.
+ storage_mult_obj = []
+ for _ in range(5):
+ timestamp = str(time.time()).replace('.', '')
+ store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}')
+ store_obj.construct()
+ storage_mult_obj.append(store_obj)
+ yield storage_mult_obj
+ for storage_obj in storage_mult_obj:
+ handle = global_user_state.get_handle_from_storage_name(
+ storage_obj.name)
+ if handle:
+ # If handle exists, delete manually
+ # TODO(romilb): This is potentially risky - if the delete method has
+ # bugs, this can cause resource leaks. Ideally we should manually
+ # eject storage from global_user_state and delete the bucket using
+ # boto3 directly.
+ storage_obj.delete()
+
+ @pytest.fixture
+ def tmp_multiple_custom_source_storage_obj(self):
+ # Creates a list of storage objects with custom source names to
+ # create multiple scratch storages.
+ # Stores for each object in the list must be added in the test.
+ custom_source_names = ['"path With Spaces"', 'path With Spaces']
+ storage_mult_obj = []
+ for name in custom_source_names:
+ src_path = os.path.expanduser(f'~/{name}')
+ pathlib.Path(src_path).expanduser().mkdir(exist_ok=True)
+ timestamp = str(time.time()).replace('.', '')
+ store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}',
+ source=src_path)
+ store_obj.construct()
+ storage_mult_obj.append(store_obj)
+ yield storage_mult_obj
+ for storage_obj in storage_mult_obj:
+ handle = global_user_state.get_handle_from_storage_name(
+ storage_obj.name)
+ if handle:
+ storage_obj.delete()
+
+ @pytest.fixture
+ def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source):
+ # Creates a temporary storage object. Stores must be added in the test.
+ yield from self.yield_storage_object(name=tmp_bucket_name,
+ source=tmp_source)
+
+ @pytest.fixture
+ def tmp_local_storage_obj_with_sub_path(self, tmp_bucket_name, tmp_source,
+ tmp_sub_path):
+        # Creates a temporary storage object with a bucket sub path.
+        # Stores must be added in the test.
+ list_source = [tmp_source, tmp_source + '/tmp-file']
+ yield from self.yield_storage_object(name=tmp_bucket_name,
+ source=list_source,
+ _bucket_sub_path=tmp_sub_path)
+
+ @pytest.fixture
+ def tmp_local_list_storage_obj(self, tmp_bucket_name, tmp_source):
+ # Creates a temp storage object which uses a list of paths as source.
+ # Stores must be added in the test. After upload, the bucket should
+ # have two files - /tmp-file and /tmp-source/tmp-file
+ list_source = [tmp_source, tmp_source + '/tmp-file']
+ yield from self.yield_storage_object(name=tmp_bucket_name,
+ source=list_source)
+
+ @pytest.fixture
+ def tmp_bulk_del_storage_obj(self, tmp_bucket_name):
+ # Creates a temporary storage object for testing bulk deletion.
+ # Stores must be added in the test.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ subprocess.check_output(f'mkdir -p {tmpdir}/folder{{000..255}}',
+ shell=True)
+ subprocess.check_output(f'touch {tmpdir}/test{{000..255}}.txt',
+ shell=True)
+ subprocess.check_output(
+ f'touch {tmpdir}/folder{{000..255}}/test.txt', shell=True)
+ yield from self.yield_storage_object(name=tmp_bucket_name,
+ source=tmpdir)
+
+ @pytest.fixture
+ def tmp_copy_mnt_existing_storage_obj(self, tmp_scratch_storage_obj):
+ # Creates a copy mount storage which reuses an existing storage object.
+ tmp_scratch_storage_obj.add_store(storage_lib.StoreType.S3)
+ storage_name = tmp_scratch_storage_obj.name
+
+ # Try to initialize another storage with the storage object created
+ # above, but now in COPY mode. This should succeed.
+ yield from self.yield_storage_object(name=storage_name,
+ mode=storage_lib.StorageMode.COPY)
+
+ @pytest.fixture
+ def tmp_gitignore_storage_obj(self, tmp_bucket_name, gitignore_structure):
+ # Creates a temporary storage object for testing .gitignore filter.
+        # gitignore_structure represents a file structure in dictionary
+        # format. The created storage object will contain the file structure
+        # along with .gitignore and .git/info/exclude files to test the
+        # exclude filter.
+ # Stores must be added in the test.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Creates file structure to be uploaded in the Storage
+ self.create_dir_structure(tmpdir, gitignore_structure)
+
+ # Create .gitignore and list files/dirs to be excluded in it
+ skypilot_path = os.path.dirname(os.path.dirname(sky.__file__))
+ temp_path = f'{tmpdir}/.gitignore'
+ file_path = os.path.join(skypilot_path, 'tests/gitignore_test')
+ shutil.copyfile(file_path, temp_path)
+
+ # Create .git/info/exclude and list files/dirs to be excluded in it
+ temp_path = f'{tmpdir}/.git/info/'
+ os.makedirs(temp_path)
+ temp_exclude_path = os.path.join(temp_path, 'exclude')
+ file_path = os.path.join(skypilot_path,
+ 'tests/git_info_exclude_test')
+ shutil.copyfile(file_path, temp_exclude_path)
+
+ # Create sky Storage with the files created
+ yield from self.yield_storage_object(
+ name=tmp_bucket_name,
+ source=tmpdir,
+ mode=storage_lib.StorageMode.COPY)
+
+ @pytest.fixture
+ def tmp_awscli_bucket(self, tmp_bucket_name):
+ # Creates a temporary bucket using awscli
+ bucket_uri = f's3://{tmp_bucket_name}'
+ subprocess.check_call(['aws', 's3', 'mb', bucket_uri])
+ yield tmp_bucket_name, bucket_uri
+ subprocess.check_call(['aws', 's3', 'rb', bucket_uri, '--force'])
+
+ @pytest.fixture
+ def tmp_gsutil_bucket(self, tmp_bucket_name):
+ # Creates a temporary bucket using gsutil
+ bucket_uri = f'gs://{tmp_bucket_name}'
+ subprocess.check_call(['gsutil', 'mb', bucket_uri])
+ yield tmp_bucket_name, bucket_uri
+ subprocess.check_call(['gsutil', 'rm', '-r', bucket_uri])
+
+ @pytest.fixture
+ def tmp_az_bucket(self, tmp_bucket_name):
+        # Creates a temporary container using the Azure CLI (az)
+ default_region = 'eastus'
+ storage_account_name = (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(
+ default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+ bucket_uri = data_utils.AZURE_CONTAINER_URL.format(
+ storage_account_name=storage_account_name,
+ container_name=tmp_bucket_name)
+ subprocess.check_call([
+ 'az', 'storage', 'container', 'create', '--name',
+ f'{tmp_bucket_name}', '--account-name', f'{storage_account_name}',
+ '--account-key', f'{storage_account_key}'
+ ])
+ yield tmp_bucket_name, bucket_uri
+ subprocess.check_call([
+ 'az', 'storage', 'container', 'delete', '--name',
+ f'{tmp_bucket_name}', '--account-name', f'{storage_account_name}',
+ '--account-key', f'{storage_account_key}'
+ ])
+
+ @pytest.fixture
+ def tmp_awscli_bucket_r2(self, tmp_bucket_name):
+ # Creates a temporary bucket using awscli
+ endpoint_url = cloudflare.create_endpoint()
+ bucket_uri = f's3://{tmp_bucket_name}'
+ subprocess.check_call(
+ f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 mb {bucket_uri} --endpoint {endpoint_url} --profile=r2',
+ shell=True)
+ yield tmp_bucket_name, bucket_uri
+ subprocess.check_call(
+ f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 rb {bucket_uri} --force --endpoint {endpoint_url} --profile=r2',
+ shell=True)
+
+ @pytest.fixture
+ def tmp_ibm_cos_bucket(self, tmp_bucket_name):
+ # Creates a temporary bucket using IBM COS API
+ storage_obj = storage_lib.IBMCosStore(source="", name=tmp_bucket_name)
+ yield tmp_bucket_name
+ storage_obj.delete()
+
+ @pytest.fixture
+ def tmp_public_storage_obj(self, request):
+ # Initializes a storage object with a public bucket
+ storage_obj = storage_lib.Storage(source=request.param)
+ storage_obj.construct()
+ yield storage_obj
+ # This does not require any deletion logic because it is a public bucket
+ # and should not get added to global_user_state.
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('store_type', [
+ storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+ ])
+ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj,
+ store_type):
+ # Creates a new bucket with a local source, uploads files to it
+ # and deletes it.
+ tmp_local_storage_obj.add_store(store_type)
+
+ # Run sky storage ls to check if storage object exists in the output
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_local_storage_obj.name in out.decode('utf-8')
+
+ # Run sky storage delete to delete the storage object
+ subprocess.check_output(
+ ['sky', 'storage', 'delete', tmp_local_storage_obj.name, '--yes'])
+
+ # Run sky storage ls to check if storage object is deleted
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_local_storage_obj.name not in out.decode('utf-8')
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('store_type', [
+ pytest.param(storage_lib.StoreType.S3, marks=pytest.mark.aws),
+ pytest.param(storage_lib.StoreType.GCS, marks=pytest.mark.gcp),
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+ ])
+ def test_bucket_sub_path(self, tmp_local_storage_obj_with_sub_path,
+ store_type):
+        # Creates a new bucket with a local source and a bucket sub path,
+        # uploads files under the sub path, then deletes the sub path
+        # contents and finally the bucket itself.
+ tmp_local_storage_obj_with_sub_path.add_store(store_type)
+
+ # Check files under bucket and filter by prefix
+ files = self.list_all_files(store_type,
+ tmp_local_storage_obj_with_sub_path.name)
+ assert len(files) > 0
+ if store_type == storage_lib.StoreType.GCS:
+ assert all([
+ file.startswith(
+ tmp_local_storage_obj_with_sub_path.name + '/' +
+ tmp_local_storage_obj_with_sub_path._bucket_sub_path)
+ for file in files
+ ])
+ else:
+ assert all([
+ file.startswith(
+ tmp_local_storage_obj_with_sub_path._bucket_sub_path)
+ for file in files
+ ])
+
+        # Check that the bucket is empty; all files under the sub directory
+        # should have been deleted.
+ store = tmp_local_storage_obj_with_sub_path.stores[store_type]
+ store.is_sky_managed = False
+ if store_type == storage_lib.StoreType.AZURE:
+ azure.assign_storage_account_iam_role(
+ storage_account_name=store.storage_account_name,
+ resource_group_name=store.resource_group_name)
+ store.delete()
+ files = self.list_all_files(store_type,
+ tmp_local_storage_obj_with_sub_path.name)
+ assert len(files) == 0
+
+ # Now, delete the entire bucket
+ store.is_sky_managed = True
+ tmp_local_storage_obj_with_sub_path.delete()
+
+ # Run sky storage ls to check if storage object is deleted
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_local_storage_obj_with_sub_path.name not in out.decode(
+ 'utf-8')
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.xdist_group('multiple_bucket_deletion')
+ @pytest.mark.parametrize('store_type', [
+ storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm)
+ ])
+ def test_multiple_buckets_creation_and_deletion(
+ self, tmp_multiple_scratch_storage_obj, store_type):
+        # Creates multiple new buckets (5 scratch buckets with no source)
+        # and deletes them.
+ storage_obj_name = []
+ for store_obj in tmp_multiple_scratch_storage_obj:
+ store_obj.add_store(store_type)
+ storage_obj_name.append(store_obj.name)
+
+ # Run sky storage ls to check if all storage objects exists in the
+ # output filtered by store type
+ out_all = subprocess.check_output(['sky', 'storage', 'ls'])
+ out = [
+ item.split()[0]
+ for item in out_all.decode('utf-8').splitlines()
+ if store_type.value in item
+ ]
+ assert all([item in out for item in storage_obj_name])
+
+ # Run sky storage delete all to delete all storage objects
+ delete_cmd = ['sky', 'storage', 'delete', '--yes']
+ delete_cmd += storage_obj_name
+ subprocess.check_output(delete_cmd)
+
+ # Run sky storage ls to check if all storage objects filtered by store
+ # type are deleted
+ out_all = subprocess.check_output(['sky', 'storage', 'ls'])
+ out = [
+ item.split()[0]
+ for item in out_all.decode('utf-8').splitlines()
+ if store_type.value in item
+ ]
+ assert all([item not in out for item in storage_obj_name])
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('store_type', [
+ storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+ ])
+ def test_upload_source_with_spaces(self, store_type,
+ tmp_multiple_custom_source_storage_obj):
+        # Creates two buckets from local sources whose paths contain
+        # spaces in their names.
+ storage_obj_names = []
+ for storage_obj in tmp_multiple_custom_source_storage_obj:
+ storage_obj.add_store(store_type)
+ storage_obj_names.append(storage_obj.name)
+
+ # Run sky storage ls to check if all storage objects exists in the
+ # output filtered by store type
+ out_all = subprocess.check_output(['sky', 'storage', 'ls'])
+ out = [
+ item.split()[0]
+ for item in out_all.decode('utf-8').splitlines()
+ if store_type.value in item
+ ]
+ assert all([item in out for item in storage_obj_names])
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('store_type', [
+ storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+ ])
+ def test_bucket_external_deletion(self, tmp_scratch_storage_obj,
+ store_type):
+ # Creates a bucket, deletes it externally using cloud cli commands
+ # and then tries to delete it using sky storage delete.
+ tmp_scratch_storage_obj.add_store(store_type)
+
+ # Run sky storage ls to check if storage object exists in the output
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_scratch_storage_obj.name in out.decode('utf-8')
+
+ # Delete bucket externally
+ cmd = self.cli_delete_cmd(store_type, tmp_scratch_storage_obj.name)
+ subprocess.check_output(cmd, shell=True)
+
+ # Run sky storage delete to delete the storage object
+ out = subprocess.check_output(
+ ['sky', 'storage', 'delete', tmp_scratch_storage_obj.name, '--yes'])
+ # Make sure bucket was not created during deletion (see issue #1322)
+ assert 'created' not in out.decode('utf-8').lower()
+
+ # Run sky storage ls to check if storage object is deleted
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_scratch_storage_obj.name not in out.decode('utf-8')
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('store_type', [
+ storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+ ])
+ def test_bucket_bulk_deletion(self, store_type, tmp_bulk_del_storage_obj):
+        # Creates a temp folder with over 256 files and folders, uploads
+        # them to a new bucket, then deletes the bucket.
+ tmp_bulk_del_storage_obj.add_store(store_type)
+
+ subprocess.check_output([
+ 'sky', 'storage', 'delete', tmp_bulk_del_storage_obj.name, '--yes'
+ ])
+
+ output = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_bulk_del_storage_obj.name not in output.decode('utf-8')
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize(
+ 'tmp_public_storage_obj, store_type',
+ [('s3://tcga-2-open', storage_lib.StoreType.S3),
+ ('s3://digitalcorpora', storage_lib.StoreType.S3),
+ ('gs://gcp-public-data-sentinel-2', storage_lib.StoreType.GCS),
+ pytest.param(
+ 'https://azureopendatastorage.blob.core.windows.net/nyctlc',
+ storage_lib.StoreType.AZURE,
+ marks=pytest.mark.azure)],
+ indirect=['tmp_public_storage_obj'])
+ def test_public_bucket(self, tmp_public_storage_obj, store_type):
+ # Creates a new bucket with a public source and verifies that it is not
+ # added to global_user_state.
+ tmp_public_storage_obj.add_store(store_type)
+
+ # Run sky storage ls to check if storage object exists in the output
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert tmp_public_storage_obj.name not in out.decode('utf-8')
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize(
+ 'nonexist_bucket_url',
+ [
+ 's3://{random_name}',
+ 'gs://{random_name}',
+ pytest.param(
+ 'https://{account_name}.blob.core.windows.net/{random_name}', # pylint: disable=line-too-long
+ marks=pytest.mark.azure),
+ pytest.param('cos://us-east/{random_name}', marks=pytest.mark.ibm),
+ pytest.param('r2://{random_name}', marks=pytest.mark.cloudflare)
+ ])
+ def test_nonexistent_bucket(self, nonexist_bucket_url):
+        # Attempts to fetch a storage object with a non-existent source.
+ # Generate a random bucket name and verify it doesn't exist:
+ retry_count = 0
+ while True:
+ nonexist_bucket_name = str(uuid.uuid4())
+ if nonexist_bucket_url.startswith('s3'):
+ command = f'aws s3api head-bucket --bucket {nonexist_bucket_name}'
+ expected_output = '404'
+ elif nonexist_bucket_url.startswith('gs'):
+ command = f'gsutil ls {nonexist_bucket_url.format(random_name=nonexist_bucket_name)}'
+ expected_output = 'BucketNotFoundException'
+ elif nonexist_bucket_url.startswith('https'):
+ default_region = 'eastus'
+ storage_account_name = (
+ storage_lib.AzureBlobStore.get_default_storage_account_name(
+ default_region))
+ storage_account_key = data_utils.get_az_storage_account_key(
+ storage_account_name)
+ command = f'az storage container exists --account-name {storage_account_name} --account-key {storage_account_key} --name {nonexist_bucket_name}'
+ expected_output = '"exists": false'
+ elif nonexist_bucket_url.startswith('r2'):
+ endpoint_url = cloudflare.create_endpoint()
+ command = f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3api head-bucket --bucket {nonexist_bucket_name} --endpoint {endpoint_url} --profile=r2'
+ expected_output = '404'
+ elif nonexist_bucket_url.startswith('cos'):
+                # Use API calls, since rclone requires a profile name.
+ try:
+ expected_output = command = "echo" # avoid unrelated exception in case of failure.
+ bucket_name = urllib.parse.urlsplit(
+ nonexist_bucket_url.format(
+ random_name=nonexist_bucket_name)).path.strip('/')
+ client = ibm.get_cos_client('us-east')
+ client.head_bucket(Bucket=bucket_name)
+ except ibm.ibm_botocore.exceptions.ClientError as e:
+ if e.response['Error']['Code'] == '404':
+ # success
+ return
+ else:
+ raise ValueError('Unsupported bucket type '
+ f'{nonexist_bucket_url}')
+
+ # Check if bucket exists using the cli:
+ try:
+ out = subprocess.check_output(command,
+ stderr=subprocess.STDOUT,
+ shell=True)
+ except subprocess.CalledProcessError as e:
+ out = e.output
+ out = out.decode('utf-8')
+ if expected_output in out:
+ break
+ else:
+ retry_count += 1
+ if retry_count > 3:
+ raise RuntimeError('Unable to find a nonexistent bucket '
+                                   'to use. This is highly unlikely - '
+ 'check if the tests are correct.')
+
+ with pytest.raises(sky.exceptions.StorageBucketGetError,
+ match='Attempted to use a non-existent'):
+ if nonexist_bucket_url.startswith('https'):
+ storage_obj = storage_lib.Storage(
+ source=nonexist_bucket_url.format(
+ account_name=storage_account_name,
+ random_name=nonexist_bucket_name))
+ else:
+ storage_obj = storage_lib.Storage(
+ source=nonexist_bucket_url.format(
+ random_name=nonexist_bucket_name))
+ storage_obj.construct()
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize(
+ 'private_bucket',
+ [
+ f's3://imagenet',
+ f'gs://imagenet',
+ pytest.param('https://smoketestprivate.blob.core.windows.net/test',
+ marks=pytest.mark.azure), # pylint: disable=line-too-long
+ pytest.param('cos://us-east/bucket1', marks=pytest.mark.ibm)
+ ])
+ def test_private_bucket(self, private_bucket):
+ # Attempts to access private buckets not belonging to the user.
+ # These buckets are known to be private, but may need to be updated if
+ # they are removed by their owners.
+ store_type = urllib.parse.urlsplit(private_bucket).scheme
+ if store_type == 'https' or store_type == 'cos':
+ private_bucket_name = urllib.parse.urlsplit(
+ private_bucket).path.strip('/')
+ else:
+ private_bucket_name = urllib.parse.urlsplit(private_bucket).netloc
+ with pytest.raises(
+ sky.exceptions.StorageBucketGetError,
+ match=storage_lib._BUCKET_FAIL_TO_CONNECT_MESSAGE.format(
+ name=private_bucket_name)):
+ storage_obj = storage_lib.Storage(source=private_bucket)
+ storage_obj.construct()
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('ext_bucket_fixture, store_type',
+ [('tmp_awscli_bucket', storage_lib.StoreType.S3),
+ ('tmp_gsutil_bucket', storage_lib.StoreType.GCS),
+ pytest.param('tmp_az_bucket',
+ storage_lib.StoreType.AZURE,
+ marks=pytest.mark.azure),
+ pytest.param('tmp_ibm_cos_bucket',
+ storage_lib.StoreType.IBM,
+ marks=pytest.mark.ibm),
+ pytest.param('tmp_awscli_bucket_r2',
+ storage_lib.StoreType.R2,
+ marks=pytest.mark.cloudflare)])
+ def test_upload_to_existing_bucket(self, ext_bucket_fixture, request,
+ tmp_source, store_type):
+        # Tries uploading existing files to a bucket newly created outside of
+        # sky and verifies that the files are written.
+ bucket_name, _ = request.getfixturevalue(ext_bucket_fixture)
+ storage_obj = storage_lib.Storage(name=bucket_name, source=tmp_source)
+ storage_obj.construct()
+ storage_obj.add_store(store_type)
+
+        # Check if tmp_source/tmp-file exists in the bucket using the cloud CLI
+ out = subprocess.check_output(self.cli_ls_cmd(store_type, bucket_name),
+ shell=True)
+ assert 'tmp-file' in out.decode('utf-8'), \
+ 'File not found in bucket - output was : {}'.format(out.decode
+ ('utf-8'))
+
+ # Check symlinks - symlinks don't get copied by sky storage
+ assert (pathlib.Path(tmp_source) / 'circle-link').is_symlink(), (
+ 'circle-link was not found in the upload source - '
+ 'are the test fixtures correct?')
+ assert 'circle-link' not in out.decode('utf-8'), (
+ 'Symlink found in bucket - ls output was : {}'.format(
+ out.decode('utf-8')))
+
+ # Run sky storage ls to check if storage object exists in the output.
+ # It should not exist because the bucket was created externally.
+ out = subprocess.check_output(['sky', 'storage', 'ls'])
+ assert storage_obj.name not in out.decode('utf-8')
+
+ @pytest.mark.no_fluidstack
+ def test_copy_mount_existing_storage(self,
+ tmp_copy_mnt_existing_storage_obj):
+ # Creates a bucket with no source in MOUNT mode (empty bucket), and
+ # then tries to load the same storage in COPY mode.
+ tmp_copy_mnt_existing_storage_obj.add_store(storage_lib.StoreType.S3)
+ storage_name = tmp_copy_mnt_existing_storage_obj.name
+
+ # Check `sky storage ls` to ensure storage object exists
+ out = subprocess.check_output(['sky', 'storage', 'ls']).decode('utf-8')
+ assert storage_name in out, f'Storage {storage_name} not found in sky storage ls.'
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('store_type', [
+ storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
+ pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+ pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+ pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+ ])
+ def test_list_source(self, tmp_local_list_storage_obj, store_type):
+ # Uses a list in the source field to specify a file and a directory to
+ # be uploaded to the storage object.
+ tmp_local_list_storage_obj.add_store(store_type)
+
+ # Check if tmp-file exists in the bucket root using cli
+ out = subprocess.check_output(self.cli_ls_cmd(
+ store_type, tmp_local_list_storage_obj.name),
+ shell=True)
+ assert 'tmp-file' in out.decode('utf-8'), \
+ 'File not found in bucket - output was : {}'.format(out.decode
+ ('utf-8'))
+
+ # Check if tmp-file exists in the bucket/tmp-source using cli
+ out = subprocess.check_output(self.cli_ls_cmd(
+ store_type, tmp_local_list_storage_obj.name, 'tmp-source/'),
+ shell=True)
+ assert 'tmp-file' in out.decode('utf-8'), \
+ 'File not found in bucket - output was : {}'.format(out.decode
+ ('utf-8'))
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('invalid_name_list, store_type',
+ [(AWS_INVALID_NAMES, storage_lib.StoreType.S3),
+ (GCS_INVALID_NAMES, storage_lib.StoreType.GCS),
+ pytest.param(AZURE_INVALID_NAMES,
+ storage_lib.StoreType.AZURE,
+ marks=pytest.mark.azure),
+ pytest.param(IBM_INVALID_NAMES,
+ storage_lib.StoreType.IBM,
+ marks=pytest.mark.ibm),
+ pytest.param(AWS_INVALID_NAMES,
+ storage_lib.StoreType.R2,
+ marks=pytest.mark.cloudflare)])
+ def test_invalid_names(self, invalid_name_list, store_type):
+        # Attempts to create storage objects with invalid names and verifies
+        # that a StorageNameError is raised for each.
+ for name in invalid_name_list:
+ with pytest.raises(sky.exceptions.StorageNameError):
+ storage_obj = storage_lib.Storage(name=name)
+ storage_obj.construct()
+ storage_obj.add_store(store_type)
+
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize(
+ 'gitignore_structure, store_type',
+ [(GITIGNORE_SYNC_TEST_DIR_STRUCTURE, storage_lib.StoreType.S3),
+ (GITIGNORE_SYNC_TEST_DIR_STRUCTURE, storage_lib.StoreType.GCS),
+ (GITIGNORE_SYNC_TEST_DIR_STRUCTURE, storage_lib.StoreType.AZURE),
+ pytest.param(GITIGNORE_SYNC_TEST_DIR_STRUCTURE,
+ storage_lib.StoreType.R2,
+ marks=pytest.mark.cloudflare)])
+ def test_excluded_file_cloud_storage_upload_copy(self, gitignore_structure,
+ store_type,
+ tmp_gitignore_storage_obj):
+        # Tests whether files listed in .gitignore and .git/info/exclude are
+        # excluded from being transferred to Storage.
+
+ tmp_gitignore_storage_obj.add_store(store_type)
+
+ upload_file_name = 'included'
+ # Count the number of files with the given file name
+ up_cmd = self.cli_count_name_in_bucket(store_type, \
+ tmp_gitignore_storage_obj.name, file_name=upload_file_name)
+ git_exclude_cmd = self.cli_count_name_in_bucket(store_type, \
+ tmp_gitignore_storage_obj.name, file_name='.git')
+ cnt_num_file_cmd = self.cli_count_file_in_bucket(
+ store_type, tmp_gitignore_storage_obj.name)
+
+ up_output = subprocess.check_output(up_cmd, shell=True)
+ git_exclude_output = subprocess.check_output(git_exclude_cmd,
+ shell=True)
+ cnt_output = subprocess.check_output(cnt_num_file_cmd, shell=True)
+
+ assert '3' in up_output.decode('utf-8'), \
+ 'Files to be included are not completely uploaded.'
+        # The count is 1 because .gitignore (which matches the '.git' grep)
+        # is uploaded, while the .git directory itself is excluded.
+ assert '1' in git_exclude_output.decode('utf-8'), \
+ '.git directory should not be uploaded.'
+        # Expect 4 files: .gitignore, included.log, included.txt, and
+        # include_dir/included.log.
+ assert '4' in cnt_output.decode('utf-8'), \
+ 'Some items listed in .gitignore and .git/info/exclude are not excluded.'
+
+ @pytest.mark.parametrize('ext_bucket_fixture, store_type',
+ [('tmp_awscli_bucket', storage_lib.StoreType.S3),
+ ('tmp_gsutil_bucket', storage_lib.StoreType.GCS),
+ pytest.param('tmp_awscli_bucket_r2',
+ storage_lib.StoreType.R2,
+ marks=pytest.mark.cloudflare)])
+ def test_externally_created_bucket_mount_without_source(
+ self, ext_bucket_fixture, request, store_type):
+        # Non-sky-managed buckets (buckets created outside of the SkyPilot
+        # CLI) may only be MOUNTed by specifying the bucket URI in the source
+        # field. Attempting to mount by specifying only the bucket name
+        # should error out.
+        #
+        # TODO(doyoung): Add test for IBM COS. Currently, this is blocked
+        # because rclone, which is used to interact with IBM COS, does not
+        # support creating a bucket, and the ibmcloud CLI is not supported in
+        # SkyPilot. Either feature is necessary to simulate external bucket
+        # creation for IBM COS.
+ # https://github.com/skypilot-org/skypilot/pull/1966/files#r1253439837
+
+ ext_bucket_name, ext_bucket_uri = request.getfixturevalue(
+ ext_bucket_fixture)
+ # invalid spec
+ with pytest.raises(sky.exceptions.StorageSpecError) as e:
+ storage_obj = storage_lib.Storage(
+ name=ext_bucket_name, mode=storage_lib.StorageMode.MOUNT)
+ storage_obj.construct()
+ storage_obj.add_store(store_type)
+
+ assert 'Attempted to mount a non-sky managed bucket' in str(e)
+
+ # valid spec
+ storage_obj = storage_lib.Storage(source=ext_bucket_uri,
+ mode=storage_lib.StorageMode.MOUNT)
+ storage_obj.construct()
+ handle = global_user_state.get_handle_from_storage_name(
+ storage_obj.name)
+ if handle:
+ storage_obj.delete()
+
+ @pytest.mark.aws
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('region', [
+ 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-south-1',
+ 'ap-southeast-1', 'ap-southeast-2', 'eu-central-1', 'eu-north-1',
+ 'eu-west-1', 'eu-west-2', 'eu-west-3', 'sa-east-1', 'us-east-1',
+ 'us-east-2', 'us-west-1', 'us-west-2'
+ ])
+ def test_aws_regions(self, tmp_local_storage_obj, region):
+        # This tests bucket creation and upload in all AWS S3 regions.
+ # To test full functionality, use test_managed_jobs_storage above.
+ store_type = storage_lib.StoreType.S3
+ tmp_local_storage_obj.add_store(store_type, region=region)
+ bucket_name = tmp_local_storage_obj.name
+
+ # Confirm that the bucket was created in the correct region
+ region_cmd = self.cli_region_cmd(store_type, bucket_name=bucket_name)
+ out = subprocess.check_output(region_cmd, shell=True)
+ output = out.decode('utf-8')
+ expected_output_region = region
+ if region == 'us-east-1':
+ expected_output_region = 'None' # us-east-1 is the default region
+ assert expected_output_region in out.decode('utf-8'), (
+ f'Bucket was not found in region {region} - '
+ f'output of {region_cmd} was: {output}')
+
+ # Check if tmp_source/tmp-file exists in the bucket using cli
+ ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
+ out = subprocess.check_output(ls_cmd, shell=True)
+ output = out.decode('utf-8')
+ assert 'tmp-file' in output, (
+ f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')
+
+ @pytest.mark.gcp
+ @pytest.mark.no_fluidstack
+ @pytest.mark.parametrize('region', [
+ 'northamerica-northeast1', 'northamerica-northeast2', 'us-central1',
+ 'us-east1', 'us-east4', 'us-east5', 'us-south1', 'us-west1', 'us-west2',
+ 'us-west3', 'us-west4', 'southamerica-east1', 'southamerica-west1',
+ 'europe-central2', 'europe-north1', 'europe-southwest1', 'europe-west1',
+ 'europe-west2', 'europe-west3', 'europe-west4', 'europe-west6',
+ 'europe-west8', 'europe-west9', 'europe-west10', 'europe-west12',
+ 'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast2',
+ 'asia-northeast3', 'asia-southeast1', 'asia-south1', 'asia-south2',
+ 'asia-southeast2', 'me-central1', 'me-central2', 'me-west1',
+ 'australia-southeast1', 'australia-southeast2', 'africa-south1'
+ ])
+ def test_gcs_regions(self, tmp_local_storage_obj, region):
+        # This tests bucket creation and upload in all GCS regions.
+ # To test full functionality, use test_managed_jobs_storage above.
+ store_type = storage_lib.StoreType.GCS
+ tmp_local_storage_obj.add_store(store_type, region=region)
+ bucket_name = tmp_local_storage_obj.name
+
+ # Confirm that the bucket was created in the correct region
+ region_cmd = self.cli_region_cmd(store_type, bucket_name=bucket_name)
+ out = subprocess.check_output(region_cmd, shell=True)
+ output = out.decode('utf-8')
+ assert region in out.decode('utf-8'), (
+ f'Bucket was not found in region {region} - '
+ f'output of {region_cmd} was: {output}')
+
+ # Check if tmp_source/tmp-file exists in the bucket using cli
+ ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
+ out = subprocess.check_output(ls_cmd, shell=True)
+ output = out.decode('utf-8')
+ assert 'tmp-file' in output, (
+ f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')
diff --git a/tests/smoke_tests/test_quick_tests_core.py b/tests/smoke_tests/test_quick_tests_core.py
new file mode 100644
index 00000000000..48df4ef9a2b
--- /dev/null
+++ b/tests/smoke_tests/test_quick_tests_core.py
@@ -0,0 +1,47 @@
+# Smoke tests for SkyPilot that are required before merging.
+# If the change includes an interface modification or touches the core API,
+# the reviewer may decide a pre-merge test is necessary and leave a
+# /quicktest-core comment, which will then trigger this test.
+#
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_quick_tests_core.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_quick_tests_core.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_quick_tests_core.py::test_yaml_launch_and_mount
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_quick_tests_core.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_quick_tests_core.py --generic-cloud aws
+
+from smoke_tests import smoke_tests_utils
+
+import sky
+
+
+def test_yaml_launch_and_mount(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'test_yaml_launch_and_mount',
+ [
+ f'sky launch -y -c {name} tests/test_yamls/minimal_test_quick_tests_core.yaml',
+ smoke_tests_utils.
+ get_cmd_wait_until_job_status_contains_matching_job_id(
+ cluster_name=name,
+ job_id=1,
+ job_status=[sky.JobStatus.SUCCEEDED],
+ timeout=2 * 60),
+ ],
+ f'sky down -y {name}',
+ timeout=5 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_region_and_zone.py b/tests/smoke_tests/test_region_and_zone.py
new file mode 100644
index 00000000000..617088f24f3
--- /dev/null
+++ b/tests/smoke_tests/test_region_and_zone.py
@@ -0,0 +1,220 @@
+# Smoke tests for SkyPilot region and zone selection
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_region_and_zone.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_region_and_zone.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_region_and_zone.py::test_aws_region
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_region_and_zone.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_region_and_zone.py --generic-cloud aws
+
+import tempfile
+import textwrap
+
+import pytest
+from smoke_tests import smoke_tests_utils
+
+import sky
+from sky import skypilot_config
+from sky.skylet import constants
+
+
+# ---------- Test region ----------
+@pytest.mark.aws
+def test_aws_region():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_region',
+ [
+ f'sky launch -y -c {name} --region us-east-2 examples/minimal.yaml',
+ f'sky exec {name} examples/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep us-east-2', # Ensure the region is correct.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .region | grep us-east-2\'',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ # A user program should not access SkyPilot runtime env python by default.
+ f'sky exec {name} \'which python | grep {constants.SKY_REMOTE_PYTHON_ENV_NAME} && exit 1 || true\'',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# TODO(zhwu): Fix this test -- the jump-{name} cluster cannot be found by ssh,
+# when API server is hosted remotely, since we don't add jump-{name} to the ssh
+# config on the API server.
+@pytest.mark.aws
+def test_aws_with_ssh_proxy_command():
+ name = smoke_tests_utils.get_cluster_name()
+ api_server = skypilot_config.get_nested(('api_server', 'endpoint'), None)
+ with tempfile.NamedTemporaryFile(mode='w') as f:
+ f.write(
+ textwrap.dedent(f"""\
+ aws:
+ ssh_proxy_command: ssh -W %h:%p -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null jump-{name}
+ """))
+ if api_server is not None:
+ f.write(
+ textwrap.dedent(f"""\
+ api_server:
+ endpoint: {api_server}
+ """))
+ f.flush()
+ test = smoke_tests_utils.Test(
+ 'aws_with_ssh_proxy_command',
+ [
+ f'sky launch -y -c jump-{name} --cloud aws --cpus 2 --region us-east-1',
+ # Use jump config
+ f'export SKYPILOT_CONFIG={f.name}; '
+ f'sky launch -y -c {name} --cloud aws --cpus 2 --region us-east-1 echo hi',
+ f'sky logs {name} 1 --status',
+ f'export SKYPILOT_CONFIG={f.name}; sky exec {name} echo hi',
+ f'sky logs {name} 2 --status',
+ # Start a small job to make sure the controller is created.
+ f'sky jobs launch -n {name}-0 --cloud aws --cpus 2 --use-spot -y echo hi',
+            # Wait for other tests to create the job controller first, so that
+            # the job controller is not launched with the proxy command.
+ smoke_tests_utils.
+ get_cmd_wait_until_cluster_status_contains_wildcard(
+ cluster_name_wildcard='sky-jobs-controller-*',
+ cluster_status=[sky.ClusterStatus.UP],
+ timeout=300),
+ f'export SKYPILOT_CONFIG={f.name}; sky jobs launch -n {name} --cpus 2 --cloud aws --region us-east-1 -yd echo hi',
+ smoke_tests_utils.
+ get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+ job_name=name,
+ job_status=[
+ sky.ManagedJobStatus.SUCCEEDED,
+ sky.ManagedJobStatus.RUNNING,
+ sky.ManagedJobStatus.STARTING
+ ],
+ timeout=300),
+ ],
+ f'sky down -y {name} jump-{name}; sky jobs cancel -y -n {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_region_and_service_account():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_region',
+ [
+ f'sky launch -y -c {name} --region us-central1 --cloud gcp tests/test_yamls/minimal.yaml',
+ f'sky exec {name} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky exec {name} \'curl -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity?format=standard&audience=gcp"\'',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep us-central1', # Ensure the region is correct.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .region | grep us-central1\'',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ # A user program should not access SkyPilot runtime env python by default.
+ f'sky exec {name} \'which python | grep {constants.SKY_REMOTE_PYTHON_ENV_NAME} && exit 1 || true\'',
+ f'sky logs {name} 4 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.ibm
+def test_ibm_region():
+ name = smoke_tests_utils.get_cluster_name()
+ region = 'eu-de'
+ test = smoke_tests_utils.Test(
+ 'region',
+ [
+ f'sky launch -y -c {name} --cloud ibm --region {region} examples/minimal.yaml',
+ f'sky exec {name} --cloud ibm examples/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep {region}', # Ensure the region is correct.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+def test_azure_region():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'azure_region',
+ [
+ f'sky launch -y -c {name} --region eastus2 --cloud azure tests/test_yamls/minimal.yaml',
+ f'sky exec {name} tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep eastus2', # Ensure the region is correct.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .region | grep eastus2\'',
+ f'sky logs {name} 2 --status', # Ensure the job succeeded.
+ f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .zone | grep null\'',
+ f'sky logs {name} 3 --status', # Ensure the job succeeded.
+ # A user program should not access SkyPilot runtime env python by default.
+ f'sky exec {name} \'which python | grep {constants.SKY_REMOTE_PYTHON_ENV_NAME} && exit 1 || true\'',
+ f'sky logs {name} 4 --status', # Ensure the job succeeded.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# ---------- Test zone ----------
+@pytest.mark.aws
+def test_aws_zone():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'aws_zone',
+ [
+ f'sky launch -y -c {name} examples/minimal.yaml --zone us-east-2b',
+ f'sky exec {name} examples/minimal.yaml --zone us-east-2b',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep us-east-2b', # Ensure the zone is correct.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.ibm
+def test_ibm_zone():
+ name = smoke_tests_utils.get_cluster_name()
+ zone = 'eu-de-2'
+ test = smoke_tests_utils.Test(
+ 'zone',
+ [
+ f'sky launch -y -c {name} --cloud ibm examples/minimal.yaml --zone {zone}',
+ f'sky exec {name} --cloud ibm examples/minimal.yaml --zone {zone}',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep {zone}', # Ensure the zone is correct.
+ ],
+ f'sky down -y {name} {name}-2 {name}-3',
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+def test_gcp_zone():
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'gcp_zone',
+ [
+ f'sky launch -y -c {name} --zone us-central1-a --cloud gcp tests/test_yamls/minimal.yaml',
+ f'sky exec {name} --zone us-central1-a --cloud gcp tests/test_yamls/minimal.yaml',
+ f'sky logs {name} 1 --status', # Ensure the job succeeded.
+ f'sky status -v | grep {name} | grep us-central1-a', # Ensure the zone is correct.
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/smoke_tests/test_sky_serve.py b/tests/smoke_tests/test_sky_serve.py
new file mode 100644
index 00000000000..11375558e2e
--- /dev/null
+++ b/tests/smoke_tests/test_sky_serve.py
@@ -0,0 +1,800 @@
+# Smoke tests for SkyPilot sky serve
+# Default options are set in pyproject.toml
+# Example usage:
+# Run all tests except for AWS and Lambda Cloud
+# > pytest tests/smoke_tests/test_sky_serve.py
+#
+# Terminate failed clusters after test finishes
+# > pytest tests/smoke_tests/test_sky_serve.py --terminate-on-failure
+#
+# Re-run last failed tests
+# > pytest --lf
+#
+# Run one of the smoke tests
+# > pytest tests/smoke_tests/test_sky_serve.py::test_skyserve_gcp_http
+#
+# Only run sky serve tests
+# > pytest tests/smoke_tests/test_sky_serve.py --sky-serve
+#
+# Only run test for AWS + generic tests
+# > pytest tests/smoke_tests/test_sky_serve.py --aws
+#
+# Change cloud for generic tests to aws
+# > pytest tests/smoke_tests/test_sky_serve.py --generic-cloud aws
+
+import inspect
+import json
+import shlex
+from typing import Dict, List, Tuple
+
+import pytest
+from smoke_tests import smoke_tests_utils
+
+from sky import serve
+from sky.utils import common_utils
+
+# ---------- Testing skyserve ----------
+
+
+def _get_service_name() -> str:
+ """Returns a user-unique service name for each test_skyserve_().
+
+ Must be called from each test_skyserve_().
+ """
+ caller_func_name = inspect.stack()[1][3]
+ test_name = caller_func_name.replace('_', '-').replace('test-', 't-')
+ test_name = test_name.replace('skyserve-', 'ss-')
+ test_name = common_utils.make_cluster_name_on_cloud(test_name, 24)
+ return f'{test_name}-{smoke_tests_utils.test_id}'
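+# For illustration only (approximate; make_cluster_name_on_cloud may truncate
+# or add a suffix): a caller named test_skyserve_gcp_http gets a service name
+# like 't-ss-gcp-http-<test_id>'.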
+
+
+# We check the output of the skyserve service to see if it is ready. The
+# `REPLICAS` column is in the form `1/2`, where the first number is the number
+# of ready replicas and the second is the total number of replicas. We grep
+# for this format to ensure that the service is ready, and exit early if any
+# failure is detected. At the end we sleep for
+# serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS to make sure the load balancer has
+# enough time to sync with the controller and get all ready replica IPs.
+_SERVE_WAIT_UNTIL_READY = (
+ '{{ while true; do'
+ ' s=$(sky serve status {name}); echo "$s";'
+ ' echo "$s" | grep -q "{replica_num}/{replica_num}" && break;'
+ ' echo "$s" | grep -q "FAILED" && exit 1;'
+ ' sleep 10;'
+ ' done; }}; echo "Got service status $s";'
+ f'sleep {serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS + 2};')
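+# For illustration only ('ss-demo' is a hypothetical service name):
+# _SERVE_WAIT_UNTIL_READY.format(name='ss-demo', replica_num=2) renders a
+# shell loop that repeatedly runs `sky serve status ss-demo`, breaks once the
+# output contains "2/2", and exits with an error as soon as "FAILED" appears.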
+_IP_REGEX = r'([0-9]{1,3}\.){3}[0-9]{1,3}'
+_AWK_ALL_LINES_BELOW_REPLICAS = r'/Replicas/{flag=1; next} flag'
+_SERVICE_LAUNCHING_STATUS_REGEX = 'PROVISIONING\|STARTING'
+# Since we don't allow terminating the service while the controller is INIT,
+# which is common when multiple pytest runs happen simultaneously, we need to
+# wait until the controller is UP before we can terminate the service.
+# The teardown command has a 10-minute timeout, so we don't need a timeout
+# here. See the implementation of run_one_test() for details.
+_TEARDOWN_SERVICE = (
+ '(for i in `seq 1 20`; do'
+ ' s=$(sky serve down -y {name});'
+ ' echo "Trying to terminate {name}";'
+ ' echo "$s";'
+ ' echo "$s" | grep -q "scheduled to be terminated\|No service to terminate" && break;'
+ ' sleep 10;'
+ ' [ $i -eq 20 ] && echo "Failed to terminate service {name}";'
+ 'done)')
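+# The teardown loop above retries `sky serve down` up to 20 times with a
+# 10-second sleep between attempts, so a stuck termination is reported after
+# roughly three minutes (plus command time).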
+
+_SERVE_ENDPOINT_WAIT = (
+ 'export ORIGIN_SKYPILOT_DEBUG=$SKYPILOT_DEBUG; export SKYPILOT_DEBUG=0; '
+ 'endpoint=$(sky serve status --endpoint {name}); '
+ 'until ! echo "$endpoint" | grep "Controller is initializing"; '
+ 'do echo "Waiting for serve endpoint to be ready..."; '
+ 'sleep 5; endpoint=$(sky serve status --endpoint {name}); done; '
+ 'export SKYPILOT_DEBUG=$ORIGIN_SKYPILOT_DEBUG; echo "$endpoint"')
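+# _SERVE_ENDPOINT_WAIT temporarily sets SKYPILOT_DEBUG=0 so that debug output
+# does not pollute the captured endpoint string, and restores the original
+# value once the endpoint is available.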
+
+_SERVE_STATUS_WAIT = ('s=$(sky serve status {name}); '
+ 'until ! echo "$s" | grep "Controller is initializing."; '
+ 'do echo "Waiting for serve status to be ready..."; '
+ 'sleep 5; s=$(sky serve status {name}); done; echo "$s"')
+
+
+def _get_replica_ip(name: str, replica_id: int) -> str:
+ return (f'ip{replica_id}=$(echo "$s" | '
+ f'awk "{_AWK_ALL_LINES_BELOW_REPLICAS}" | '
+ f'grep -E "{name}\s+{replica_id}" | '
+ f'grep -Eo "{_IP_REGEX}")')
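+# For illustration only: _get_replica_ip(name, 1) emits a shell fragment that
+# scans the replica table in "$s" (the saved `sky serve status` output) for
+# the row of replica 1 and stores its IP in the shell variable $ip1.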
+
+
+def _get_skyserve_http_test(name: str, cloud: str,
+ timeout_minutes: int) -> smoke_tests_utils.Test:
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-{cloud.replace("_", "-")}',
+ [
+ f'sky serve up -n {name} -y tests/skyserve/http/{cloud}.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'curl http://$endpoint | grep "Hi, SkyPilot here"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=timeout_minutes * 60,
+ )
+ return test
+
+
+def _check_replica_in_status(name: str, check_tuples: List[Tuple[int, bool,
+ str]]) -> str:
+ """Check replicas' status and count in sky serve status
+
+ We will check vCPU=2, as all our tests use vCPU=2.
+
+ Args:
+ name: the name of the service
+ check_tuples: A list of replica property to check. Each tuple is
+ (count, is_spot, status)
+ """
+ check_cmd = ''
+ for check_tuple in check_tuples:
+ count, is_spot, status = check_tuple
+ resource_str = ''
+ if status not in ['PENDING', 'SHUTTING_DOWN'
+ ] and not status.startswith('FAILED'):
+ spot_str = ''
+ if is_spot:
+ spot_str = '\[Spot\]'
+ resource_str = f'({spot_str}vCPU=2)'
+ check_cmd += (f' echo "$s" | grep "{resource_str}" | '
+ f'grep "{status}" | wc -l | grep {count} || exit 1;')
+ return (f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s"; ' + check_cmd)
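+# For example, the call in test_skyserve_base_ondemand_fallback below,
+# _check_replica_in_status(name, [(1, True, 'READY'), (1, False, 'READY')]),
+# asserts exactly one READY spot replica and one READY on-demand replica.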
+
+
+def _check_service_version(service_name: str, version: str) -> str:
+ # Grep the lines before 'Service Replicas' and check if the service version
+ # is correct.
+ return (f'echo "$s" | grep -B1000 "Service Replicas" | '
+ f'grep -E "{service_name}\s+{version}" || exit 1; ')
+
+
+@pytest.mark.gcp
+@pytest.mark.serve
+def test_skyserve_gcp_http():
+ """Test skyserve on GCP"""
+ name = _get_service_name()
+ test = _get_skyserve_http_test(name, 'gcp', 20)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.aws
+@pytest.mark.serve
+def test_skyserve_aws_http():
+ """Test skyserve on AWS"""
+ name = _get_service_name()
+ test = _get_skyserve_http_test(name, 'aws', 20)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.azure
+@pytest.mark.serve
+def test_skyserve_azure_http():
+ """Test skyserve on Azure"""
+ name = _get_service_name()
+ test = _get_skyserve_http_test(name, 'azure', 30)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.kubernetes
+@pytest.mark.serve
+def test_skyserve_kubernetes_http():
+ """Test skyserve on Kubernetes"""
+ name = _get_service_name()
+ test = _get_skyserve_http_test(name, 'kubernetes', 30)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.oci
+@pytest.mark.serve
+def test_skyserve_oci_http():
+ """Test skyserve on OCI"""
+ name = _get_service_name()
+ test = _get_skyserve_http_test(name, 'oci', 20)
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack does not support T4 gpus for now
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
+@pytest.mark.serve
+def test_skyserve_llm(generic_cloud: str, accelerator: Dict[str, str]):
+ """Test skyserve with real LLM usecase"""
+ accelerator = accelerator.get(generic_cloud, 'T4')
+ name = _get_service_name()
+
+ def generate_llm_test_command(prompt: str, expected_output: str) -> str:
+ prompt = shlex.quote(prompt)
+ expected_output = shlex.quote(expected_output)
+ return (
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'python tests/skyserve/llm/get_response.py --endpoint $endpoint '
+ f'--prompt {prompt} | grep {expected_output}')
+
+ with open('tests/skyserve/llm/prompt_output.json', 'r',
+ encoding='utf-8') as f:
+ prompt2output = json.load(f)
+
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-llm',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} --gpus {accelerator} -y tests/skyserve/llm/service.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ *[
+ generate_llm_test_command(prompt, output)
+ for prompt, output in prompt2output.items()
+ ],
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=40 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+@pytest.mark.serve
+def test_skyserve_spot_recovery():
+ name = _get_service_name()
+ zone = 'us-central1-a'
+
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-spot-recovery-gcp',
+ [
+ f'sky serve up -n {name} -y tests/skyserve/spot/recovery.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
+ smoke_tests_utils.terminate_gcp_replica(name, zone, 1),
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
+@pytest.mark.serve
+@pytest.mark.no_kubernetes
+@pytest.mark.no_do
+def test_skyserve_base_ondemand_fallback(generic_cloud: str):
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-base-ondemand-fallback',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/spot/base_ondemand_fallback.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ _check_replica_in_status(name, [(1, True, 'READY'),
+ (1, False, 'READY')]),
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.gcp
+@pytest.mark.serve
+def test_skyserve_dynamic_ondemand_fallback():
+ name = _get_service_name()
+ zone = 'us-central1-a'
+
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-dynamic-ondemand-fallback',
+ [
+ f'sky serve up -n {name} --cloud gcp -y tests/skyserve/spot/dynamic_ondemand_fallback.yaml',
+ f'sleep 40',
+ # 2 on-demand (provisioning) + 2 Spot (provisioning).
+ f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s";'
+ 'echo "$s" | grep -q "0/4" || exit 1',
+            # Wait for the provisioning to start
+ f'sleep 40',
+ _check_replica_in_status(name, [
+ (2, True, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY'),
+ (2, False, _SERVICE_LAUNCHING_STATUS_REGEX + '\|SHUTTING_DOWN')
+ ]),
+
+ # Wait until 2 spot instances are ready.
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ _check_replica_in_status(name, [(2, True, 'READY'),
+ (0, False, '')]),
+ smoke_tests_utils.terminate_gcp_replica(name, zone, 1),
+ f'sleep 40',
+ # 1 on-demand (provisioning) + 1 Spot (ready) + 1 spot (provisioning).
+ f'{_SERVE_STATUS_WAIT.format(name=name)}; '
+ 'echo "$s" | grep -q "1/3"',
+ _check_replica_in_status(
+ name, [(1, True, 'READY'),
+ (1, True, _SERVICE_LAUNCHING_STATUS_REGEX),
+ (1, False, _SERVICE_LAUNCHING_STATUS_REGEX)]),
+
+ # Wait until 2 spot instances are ready.
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ _check_replica_in_status(name, [(2, True, 'READY'),
+ (0, False, '')]),
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
+@pytest.mark.no_fluidstack
+@pytest.mark.no_do # DO does not support `--cpus 2`
+@pytest.mark.serve
+def test_skyserve_user_bug_restart(generic_cloud: str):
+ """Tests that we restart the service after user bug."""
+ # TODO(zhwu): this behavior needs some rethinking.
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-user-bug-restart',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/restart/user_bug.yaml',
+ f's=$(sky serve status {name}); echo "$s";'
+ 'until echo "$s" | grep -A 100 "Service Replicas" | grep "SHUTTING_DOWN"; '
+ 'do echo "Waiting for first service to be SHUTTING DOWN..."; '
+ f'sleep 5; s=$(sky serve status {name}); echo "$s"; done; ',
+ f's=$(sky serve status {name}); echo "$s";'
+ 'until echo "$s" | grep -A 100 "Service Replicas" | grep "FAILED"; '
+ 'do echo "Waiting for first service to be FAILED..."; '
+ f'sleep 2; s=$(sky serve status {name}); echo "$s"; done; echo "$s"; '
+ + _check_replica_in_status(name, [(1, True, 'FAILED')]) +
+ # User bug failure will cause no further scaling.
+ f'echo "$s" | grep -A 100 "Service Replicas" | grep "{name}" | wc -l | grep 1; '
+ f'echo "$s" | grep -B 100 "NO_REPLICA" | grep "0/0"',
+ f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/auto_restart.yaml',
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'until curl --connect-timeout 10 --max-time 10 http://$endpoint | grep "Hi, SkyPilot here"; do sleep 1; done; sleep 2; '
+ + _check_replica_in_status(name, [(1, False, 'READY'),
+ (1, False, 'FAILED')]),
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.serve
+@pytest.mark.no_kubernetes # Replicas on k8s may be running on the same node and have the same public IP
+def test_skyserve_load_balancer(generic_cloud: str):
+ """Test skyserve load balancer round-robin policy"""
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-load-balancer',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/load_balancer/service.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=3),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ f'{_SERVE_STATUS_WAIT.format(name=name)}; '
+ f'{_get_replica_ip(name, 1)}; '
+ f'{_get_replica_ip(name, 2)}; {_get_replica_ip(name, 3)}; '
+ 'python tests/skyserve/load_balancer/test_round_robin.py '
+ '--endpoint $endpoint --replica-num 3 --replica-ips $ip1 $ip2 $ip3',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
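+
+
+# The round-robin check above is delegated to
+# tests/skyserve/load_balancer/test_round_robin.py. As a rough sketch only
+# (a hypothetical helper, not the actual script), and assuming each replica
+# echoes its own IP in the response body, the check boils down to:
+def _round_robin_sketch(endpoint: str, replica_ips, rounds: int = 10) -> None:
+    """Asserts every replica IP in `replica_ips` serves at least one request."""
+    import collections
+    import urllib.request
+    counts = collections.Counter()
+    for _ in range(rounds * len(replica_ips)):
+        with urllib.request.urlopen(f'http://{endpoint}') as resp:
+            counts[resp.read().decode().strip()] += 1
+    assert set(counts) >= set(replica_ips), counts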
+
+
+@pytest.mark.gcp
+@pytest.mark.serve
+@pytest.mark.no_kubernetes
+def test_skyserve_auto_restart():
+ """Test skyserve with auto restart"""
+ name = _get_service_name()
+ zone = 'us-central1-a'
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-auto-restart',
+ [
+ # TODO(tian): we can dynamically generate YAML from template to
+ # avoid maintaining too many YAML files
+ f'sky serve up -n {name} -y tests/skyserve/auto_restart.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
+ # Sleep for 20 seconds (the initial delay) to make sure the replica
+ # will be restarted.
+ f'sleep 20',
+ smoke_tests_utils.terminate_gcp_replica(name, zone, 1),
+ # Wait for the consecutive-failure timeout to pass.
+ # If the cluster is not using spot, the controller won't check the
+ # cluster status on the cloud (since manual shutdown is not a common
+ # behavior and such queries take a lot of time). Instead, 3 minutes of
+ # continuous probe failures is treated as a real failure rather than a
+ # temporary problem.
+ 'sleep 180',
+ # We cannot use _SERVE_WAIT_UNTIL_READY: there is an intermediate period
+ # during which the output of `sky serve status` shows FAILED, and that
+ # status would cause _SERVE_WAIT_UNTIL_READY to quit early (see the
+ # sketch after this test for a tolerant alternative).
+ '(while true; do'
+ f' output=$(sky serve status {name});'
+ ' echo "$output" | grep -q "1/1" && break;'
+ ' sleep 10;'
+ f'done); sleep {serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS};',
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
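+
+
+# A minimal sketch (a hypothetical helper, not used by the test above) of how
+# the tolerant wait loop in test_skyserve_auto_restart could be factored out:
+# keep polling `sky serve status` until the READY fraction shows up, tolerating
+# the transient FAILED status that would make _SERVE_WAIT_UNTIL_READY quit
+# early.
+def _wait_for_replica_fraction_sketch(name: str,
+                                      fraction: str = '1/1',
+                                      interval: int = 10) -> str:
+    return ('(while true; do'
+            f' output=$(sky serve status {name});'
+            f' echo "$output" | grep -q "{fraction}" && break;'
+            f' sleep {interval};'
+            'done)')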
+
+
+@pytest.mark.serve
+def test_skyserve_cancel(generic_cloud: str):
+ """Test skyserve with cancel"""
+ name = _get_service_name()
+
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-cancel',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/cancel/cancel.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; python3 '
+ 'tests/skyserve/cancel/send_cancel_request.py '
+ '--endpoint $endpoint | grep "Request was cancelled"',
+ f's=$(sky serve logs {name} 1 --no-follow); '
+ 'until ! echo "$s" | grep "Please wait for the controller to be"; '
+ 'do echo "Waiting for serve logs"; sleep 10; '
+ f's=$(sky serve logs {name} 1 --no-follow); done; '
+ 'echo "$s"; echo "$s" | grep "Client disconnected, stopping computation"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.serve
+def test_skyserve_streaming(generic_cloud: str):
+ """Test skyserve with streaming"""
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-streaming',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/streaming/streaming.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'python3 tests/skyserve/streaming/send_streaming_request.py '
+ '--endpoint $endpoint | grep "Streaming test passed"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.serve
+def test_skyserve_readiness_timeout_fail(generic_cloud: str):
+ """Test skyserve with large readiness probe latency, expected to fail"""
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-readiness-timeout-fail',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task.yaml',
+ # None of the readiness probes will pass, so the service will be
+ # terminated after the initial delay.
+ f's=$(sky serve status {name}); '
+ f'until echo "$s" | grep "FAILED_INITIAL_DELAY"; do '
+ 'echo "Waiting for replica to be failed..."; sleep 5; '
+ f's=$(sky serve status {name}); echo "$s"; done;',
+ 'sleep 60',
+ f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s" | grep "{name}" | grep "FAILED_INITIAL_DELAY" | wc -l | grep 1;'
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.serve
+def test_skyserve_large_readiness_timeout(generic_cloud: str):
+ """Test skyserve with customized large readiness timeout"""
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-large-readiness-timeout',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task_large_timeout.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
+@pytest.mark.no_fluidstack
+@pytest.mark.no_do # DO does not support `--cpus 2`
+@pytest.mark.serve
+def test_skyserve_update(generic_cloud: str):
+ """Test skyserve with update"""
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-update',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
+ f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/new.yaml',
+ # sleep before update is registered.
+ 'sleep 20',
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done;'
+ # Make sure the traffic is not mixed
+ 'curl http://$endpoint | grep "Hi, new SkyPilot here"',
+ # The 2 replicas of the latest version should be READY and the 2 replicas of the old version should be SHUTTING_DOWN.
+ (_check_replica_in_status(name, [(2, False, 'READY'),
+ (2, False, 'SHUTTING_DOWN')]) +
+ _check_service_version(name, "2")),
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
+@pytest.mark.no_fluidstack
+@pytest.mark.no_do # DO does not support `--cpus 2`
+@pytest.mark.serve
+def test_skyserve_rolling_update(generic_cloud: str):
+ """Test skyserve with rolling update"""
+ name = _get_service_name()
+ single_new_replica = _check_replica_in_status(
+ name, [(2, False, 'READY'), (1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
+ (1, False, 'SHUTTING_DOWN')])
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-rolling-update',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
+ f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/new.yaml',
+ # Make sure the traffic is mixed across the two versions. The replicas
+ # with even ids will sleep 60 seconds before becoming ready, so we
+ # should be able to observe a period during which the traffic is mixed
+ # across the two versions.
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done; sleep 2; '
+ # The latest version should have one replica READY, and one replica of the older version should be SHUTTING_DOWN.
+ f'{single_new_replica} {_check_service_version(name, "1,2")} '
+ # Check the output from the old version, immediately after the
+ # output from the new version appears. This is guaranteed by the
+ # round robin load balancing policy.
+ # TODO(zhwu): we should have a more generalized way of checking the
+ # mixed versions of replicas to avoid depending on the specific
+ # round-robin load balancing policy (a rough sketch follows after
+ # this test).
+ 'curl http://$endpoint | grep "Hi, SkyPilot here"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
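+
+
+# A hypothetical, more general mixed-version probe (not used above; see the
+# TODO in test_skyserve_rolling_update): hit the endpoint several times and
+# require both the old and the new response to appear at least once, instead
+# of relying on the exact ordering of the round-robin load balancer.
+def _mixed_version_probe_sketch(old_msg: str,
+                                new_msg: str,
+                                attempts: int = 10) -> str:
+    return (f'out=$(for i in $(seq 1 {attempts}); do '
+            'curl -s http://$endpoint; echo; done); '
+            f'echo "$out" | grep -q "{old_msg}" && '
+            f'echo "$out" | grep -q "{new_msg}"')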
+
+
+@pytest.mark.no_fluidstack
+@pytest.mark.serve
+def test_skyserve_fast_update(generic_cloud: str):
+ """Test skyserve with fast update (Increment version of old replicas)"""
+ name = _get_service_name()
+
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-fast-update',
+ [
+ f'sky serve up -n {name} -y --cloud {generic_cloud} tests/skyserve/update/bump_version_before.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
+ f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/bump_version_after.yaml',
+ # sleep to wait for update to be registered.
+ 'sleep 40',
+ # 2 on-demand (ready) + 1 on-demand (provisioning).
+ (
+ _check_replica_in_status(
+ name, [(2, False, 'READY'),
+ (1, False, _SERVICE_LAUNCHING_STATUS_REGEX)]) +
+ # Fast update will directly have the latest version ready.
+ _check_service_version(name, "2")),
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=3) +
+ _check_service_version(name, "2"),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
+ # Test rolling update
+ f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/bump_version_before.yaml',
+ # sleep to wait for update to be registered.
+ 'sleep 25',
+ # 2 on-demand (ready) + 1 on-demand (shutting down).
+ _check_replica_in_status(name, [(2, False, 'READY'),
+ (1, False, 'SHUTTING_DOWN')]),
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
+ _check_service_version(name, "3"),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.serve
+def test_skyserve_update_autoscale(generic_cloud: str):
+ """Test skyserve update with autoscale"""
+ name = _get_service_name()
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-update-autoscale',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/num_min_two.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
+ _check_service_version(name, "1"),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'curl http://$endpoint | grep "Hi, SkyPilot here"',
+ f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/num_min_one.yaml',
+ # sleep before update is registered.
+ 'sleep 20',
+ # Timeout will be triggered when update fails.
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1) +
+ _check_service_version(name, "2"),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'curl http://$endpoint | grep "Hi, SkyPilot here!"',
+ # Rolling Update
+ f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/num_min_two.yaml',
+ # sleep before update is registered.
+ 'sleep 20',
+ # Timeout will be triggered when update fails.
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
+ _check_service_version(name, "3"),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'curl http://$endpoint | grep "Hi, SkyPilot here!"',
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=30 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.no_fluidstack  # Spot instances are not supported by Fluidstack
+@pytest.mark.serve
+@pytest.mark.no_kubernetes # Spot instances are not supported in Kubernetes
+@pytest.mark.no_do # Spot instances not on DO
+@pytest.mark.parametrize('mode', ['rolling', 'blue_green'])
+def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str):
+ """Test skyserve with update that changes autoscaler"""
+ name = f'{_get_service_name()}-{mode}'
+
+ wait_until_no_pending = (
+ f's=$(sky serve status {name}); echo "$s"; '
+ 'until ! echo "$s" | grep PENDING; do '
+ ' echo "Waiting for replica to be out of pending..."; '
+ f' sleep 5; s=$(sky serve status {name}); '
+ ' echo "$s"; '
+ 'done')
+ four_spot_up_cmd = _check_replica_in_status(name, [(4, True, 'READY')])
+ update_check = [f'until ({four_spot_up_cmd}); do sleep 5; done; sleep 15;']
+ if mode == 'rolling':
+ # Check rolling update: it will terminate one of the old on-demand
+ # instances once there are 4 spot instances ready.
+ update_check += [
+ _check_replica_in_status(
+ name, [(1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
+ (1, False, 'SHUTTING_DOWN'), (1, False, 'READY')]) +
+ _check_service_version(name, "1,2"),
+ ]
+ else:
+ # Check blue-green update: it will keep both old on-demand instances
+ # running once there are 4 spot instances ready.
+ update_check += [
+ _check_replica_in_status(
+ name, [(1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
+ (2, False, 'READY')]) +
+ _check_service_version(name, "1"),
+ ]
+ test = smoke_tests_utils.Test(
+ f'test-skyserve-new-autoscaler-update-{mode}',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/new_autoscaler_before.yaml',
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
+ _check_service_version(name, "1"),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 's=$(curl http://$endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"',
+ f'sky serve update {name} --cloud {generic_cloud} --mode {mode} -y tests/skyserve/update/new_autoscaler_after.yaml',
+ # Wait for update to be registered
+ f'sleep 90',
+ wait_until_no_pending,
+ _check_replica_in_status(
+ name, [(4, True, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY'),
+ (1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
+ (2, False, 'READY')]),
+ *update_check,
+ _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=5),
+ f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
+ 'curl http://$endpoint | grep "Hi, SkyPilot here"',
+ _check_replica_in_status(name, [(4, True, 'READY'),
+ (1, False, 'READY')]),
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
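+
+
+# A rough sketch (a hypothetical helper, not used above) of how the ad-hoc
+# `wait_until_no_pending` loop in test_skyserve_new_autoscaler_update could be
+# generalized: poll `sky serve status` until a given pattern no longer appears.
+def _wait_until_status_clear_sketch(name: str,
+                                    pattern: str = 'PENDING',
+                                    interval: int = 5) -> str:
+    return (f's=$(sky serve status {name}); echo "$s"; '
+            f'until ! echo "$s" | grep "{pattern}"; do '
+            f' echo "Waiting for {pattern} replicas to clear..."; '
+            f' sleep {interval}; s=$(sky serve status {name}); '
+            ' echo "$s"; '
+            'done')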
+
+
+# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
+@pytest.mark.no_fluidstack
+@pytest.mark.no_do # DO does not support `--cpus 2`
+@pytest.mark.serve
+def test_skyserve_failures(generic_cloud: str):
+ """Test replica failure statuses"""
+ name = _get_service_name()
+
+ test = smoke_tests_utils.Test(
+ 'test-skyserve-failures',
+ [
+ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/failures/initial_delay.yaml',
+ f's=$(sky serve status {name}); '
+ f'until echo "$s" | grep "FAILED_INITIAL_DELAY"; do '
+ 'echo "Waiting for replica to be failed..."; sleep 5; '
+ f's=$(sky serve status {name}); echo "$s"; done;',
+ 'sleep 60',
+ f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s" | grep "{name}" | grep "FAILED_INITIAL_DELAY" | wc -l | grep 2; '
+ # Make sure no new replicas are started for early failure.
+ f'echo "$s" | grep -A 100 "Service Replicas" | grep "{name}" | wc -l | grep 2;',
+ f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/failures/probing.yaml',
+ f's=$(sky serve status {name}); '
+ # Wait for replica to be ready.
+ f'until echo "$s" | grep "READY"; do '
+ 'echo "Waiting for replica to be failed..."; sleep 5; '
+ f's=$(sky serve status {name}); echo "$s"; done;',
+ # Wait for replica to change to FAILED_PROBING
+ f's=$(sky serve status {name}); '
+ f'until echo "$s" | grep "FAILED_PROBING"; do '
+ 'echo "Waiting for replica to be failed..."; sleep 5; '
+ f's=$(sky serve status {name}); echo "$s"; done',
+ # Wait for the PENDING replica to appear.
+ 'sleep 10',
+ # Wait until the replica is out of PENDING.
+ f's=$(sky serve status {name}); '
+ f'until ! echo "$s" | grep "PENDING" && ! echo "$s" | grep "Please wait for the controller to be ready."; do '
+ 'echo "Waiting for replica to be out of pending..."; sleep 5; '
+ f's=$(sky serve status {name}); echo "$s"; done; ' +
+ _check_replica_in_status(name, [
+ (1, False, 'FAILED_PROBING'),
+ (1, False, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY')
+ ]),
+ # TODO(zhwu): add test for FAILED_PROVISION
+ ],
+ _TEARDOWN_SERVICE.format(name=name),
+ timeout=20 * 60,
+ )
+ smoke_tests_utils.run_one_test(test)
+
+
+# TODO(Ziming, Tian): Add tests for autoscaling.
+
+
+# ------- Testing user dependencies --------
+def test_user_dependencies(generic_cloud: str):
+ name = smoke_tests_utils.get_cluster_name()
+ test = smoke_tests_utils.Test(
+ 'user-dependencies',
+ [
+ f'sky launch -y -c {name} --cloud {generic_cloud} "pip install ray>2.11; ray start --head"',
+ f'sky logs {name} 1 --status',
+ f'sky exec {name} "echo hi"',
+ f'sky logs {name} 2 --status',
+ f'sky status -r {name} | grep UP',
+ f'sky exec {name} "echo bye"',
+ f'sky logs {name} 3 --status',
+ f'sky launch -c {name} tests/test_yamls/different_default_conda_env.yaml',
+ f'sky logs {name} 4 --status',
+ # Launch again to test that the default env does not affect the
+ # SkyPilot runtime setup
+ f'sky launch -c {name} "python --version 2>&1 | grep \'Python 3.6\' || exit 1"',
+ f'sky logs {name} 5 --status',
+ ],
+ f'sky down -y {name}',
+ )
+ smoke_tests_utils.run_one_test(test)
diff --git a/tests/test_config.py b/tests/test_config.py
index 7ed212a58a5..e607861cb65 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -10,6 +10,7 @@
import sky
from sky import skypilot_config
from sky.api.requests import payloads
+import sky.exceptions
from sky.skylet import constants
from sky.utils import common_utils
from sky.utils import config_utils
@@ -97,6 +98,45 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None:
"""))
+def _create_invalid_config_yaml_file(task_file_path: pathlib.Path) -> None:
+ task_file_path.write_text(
+ textwrap.dedent("""\
+ experimental:
+ config_overrides:
+ kubernetes:
+ pod_config:
+ metadata:
+ labels:
+ test-key: test-value
+ annotations:
+ abc: def
+ spec:
+ containers:
+ - name:
+ imagePullSecrets:
+ - name: my-secret-2
+
+ setup: echo 'Setting up...'
+ run: echo 'Running...'
+ """))
+
+
+def test_nested_config(monkeypatch) -> None:
+ """Test that the nested config works."""
+ config = skypilot_config.Config()
+ config.set_nested(('aws', 'ssh_proxy_command'), 'value')
+ assert config == {'aws': {'ssh_proxy_command': 'value'}}
+
+ assert config.get_nested(('admin_policy',), 'default') == 'default'
+ config.set_nested(('aws', 'use_internal_ips'), True)
+ assert config == {
+ 'aws': {
+ 'ssh_proxy_command': 'value',
+ 'use_internal_ips': True
+ }
+ }
+
+
def test_no_config(monkeypatch) -> None:
"""Test that the config is not loaded if the config file does not exist."""
monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', '/tmp/does_not_exist')
@@ -317,6 +357,28 @@ def test_k8s_config_with_override(monkeypatch, tmp_path,
assert cluster_pod_config['spec']['runtimeClassName'] == 'nvidia'
+def test_k8s_config_with_invalid_config(monkeypatch, tmp_path,
+ enable_all_clouds) -> None:
+ config_path = tmp_path / 'config.yaml'
+ _create_config_file(config_path)
+ monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path)
+
+ _reload_config()
+ task_path = tmp_path / 'task.yaml'
+ _create_invalid_config_yaml_file(task_path)
+ task = sky.Task.from_yaml(task_path)
+
+ # Test Kubernetes pod_config invalid
+ cluster_name = 'test_k8s_config_with_invalid_config'
+ task.set_resources_override({'cloud': sky.Kubernetes()})
+ exception_occurred = False
+ try:
+ sky.launch(task, cluster_name=cluster_name, dryrun=True)
+ except sky.exceptions.ResourcesUnavailableError:
+ exception_occurred = True
+ assert exception_occurred
+
+
def test_gcp_config_with_override(monkeypatch, tmp_path,
enable_all_clouds) -> None:
config_path = tmp_path / 'config.yaml'
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 618cab5ed33..b4f3a97bf70 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -25,6288 +25,11 @@
# Change cloud for generic tests to aws
# > pytest tests/test_smoke.py --generic-cloud aws
-import enum
-import inspect
-import json
-import os
-import pathlib
-import shlex
-import shutil
-import subprocess
-import sys
-import tempfile
-import textwrap
-import time
-from typing import Dict, List, NamedTuple, Optional, TextIO, Tuple
-import urllib.parse
-import uuid
-
-import colorama
-import jinja2
-import pytest
-
-import sky
-from sky import global_user_state
-from sky import jobs
-from sky import serve
-from sky import skypilot_config
-from sky.adaptors import cloudflare
-from sky.adaptors import ibm
-from sky.clouds import AWS
-from sky.clouds import Azure
-from sky.clouds import GCP
-from sky.data import data_utils
-from sky.data import storage as storage_lib
-from sky.data.data_utils import Rclone
-from sky.skylet import constants
-from sky.skylet import events
-from sky.utils import common_utils
-from sky.utils import resources_utils
-from sky.utils import subprocess_utils
-
-# To avoid the second smoke test reusing the cluster launched in the first
-# smoke test. Also required for test_managed_jobs_recovery to make sure the
-# manual termination with aws ec2 does not accidentally terminate other clusters
-# for for the different managed jobs launch with the same job name but a
-# different job id.
-test_id = str(uuid.uuid4())[-2:]
-
-LAMBDA_TYPE = '--cloud lambda --gpus A10'
-FLUIDSTACK_TYPE = '--cloud fluidstack --gpus RTXA4000'
-
-SCP_TYPE = '--cloud scp'
-SCP_GPU_V100 = '--gpus V100-32GB'
-
-STORAGE_SETUP_COMMANDS = [
- 'touch ~/tmpfile', 'mkdir -p ~/tmp-workdir',
- 'touch ~/tmp-workdir/tmp\ file', 'touch ~/tmp-workdir/tmp\ file2',
- 'touch ~/tmp-workdir/foo',
- '[ ! -e ~/tmp-workdir/circle-link ] && ln -s ~/tmp-workdir/ ~/tmp-workdir/circle-link || true',
- 'touch ~/.ssh/id_rsa.pub'
-]
-
-# Get the job queue, and print it once on its own, then print it again to
-# use with grep by the caller.
-_GET_JOB_QUEUE = 's=$(sky jobs queue); echo "$s"; echo "$s"'
-# Wait for a job to be not in RUNNING state. Used to check for RECOVERING.
-_JOB_WAIT_NOT_RUNNING = (
- 's=$(sky jobs queue);'
- 'until ! echo "$s" | grep "{job_name}" | grep "RUNNING"; do '
- 'sleep 10; s=$(sky jobs queue);'
- 'echo "Waiting for job to stop RUNNING"; echo "$s"; done')
-
-# Cluster functions
-_ALL_JOB_STATUSES = "|".join([status.value for status in sky.JobStatus])
-_ALL_CLUSTER_STATUSES = "|".join([status.value for status in sky.ClusterStatus])
-_ALL_MANAGED_JOB_STATUSES = "|".join(
- [status.value for status in sky.ManagedJobStatus])
-
-
-def _statuses_to_str(statuses: List[enum.Enum]):
- """Convert a list of enums to a string with all the values separated by |."""
- assert len(statuses) > 0, 'statuses must not be empty'
- if len(statuses) > 1:
- return '(' + '|'.join([status.value for status in statuses]) + ')'
- else:
- return statuses[0].value
-
-
-_WAIT_UNTIL_CLUSTER_STATUS_CONTAINS = (
- # A while loop to wait until the cluster status
- # becomes certain status, with timeout.
- 'start_time=$SECONDS; '
- 'while true; do '
- 'if (( $SECONDS - $start_time > {timeout} )); then '
- ' echo "Timeout after {timeout} seconds waiting for cluster status \'{cluster_status}\'"; exit 1; '
- 'fi; '
- 'current_status=$(sky status {cluster_name} --refresh | '
- 'awk "/^{cluster_name}/ '
- '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_CLUSTER_STATUSES +
- ')$/) print \$i}}"); '
- 'if [[ "$current_status" =~ {cluster_status} ]]; '
- 'then echo "Target cluster status {cluster_status} reached."; break; fi; '
- 'echo "Waiting for cluster status to become {cluster_status}, current status: $current_status"; '
- 'sleep 10; '
- 'done')
-
-
-def _get_cmd_wait_until_cluster_status_contains(
- cluster_name: str, cluster_status: List[sky.ClusterStatus],
- timeout: int):
- return _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.format(
- cluster_name=cluster_name,
- cluster_status=_statuses_to_str(cluster_status),
- timeout=timeout)
-
-
-def _get_cmd_wait_until_cluster_status_contains_wildcard(
- cluster_name_wildcard: str, cluster_status: List[sky.ClusterStatus],
- timeout: int):
- wait_cmd = _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.replace(
- 'sky status {cluster_name}',
- 'sky status "{cluster_name}"').replace('awk "/^{cluster_name}/',
- 'awk "/^{cluster_name_awk}/')
- return wait_cmd.format(cluster_name=cluster_name_wildcard,
- cluster_name_awk=cluster_name_wildcard.replace(
- '*', '.*'),
- cluster_status=_statuses_to_str(cluster_status),
- timeout=timeout)
-
-
-_WAIT_UNTIL_CLUSTER_IS_NOT_FOUND = (
- # A while loop to wait until the cluster is not found or timeout
- 'start_time=$SECONDS; '
- 'while true; do '
- 'if (( $SECONDS - $start_time > {timeout} )); then '
- ' echo "Timeout after {timeout} seconds waiting for cluster to be removed"; exit 1; '
- 'fi; '
- 'if sky status -r {cluster_name}; sky status {cluster_name} | grep "{cluster_name} not found"; then '
- ' echo "Cluster {cluster_name} successfully removed."; break; '
- 'fi; '
- 'echo "Waiting for cluster {cluster_name} to be removed..."; '
- 'sleep 10; '
- 'done')
-
-
-def _get_cmd_wait_until_cluster_is_not_found(cluster_name: str, timeout: int):
- return _WAIT_UNTIL_CLUSTER_IS_NOT_FOUND.format(cluster_name=cluster_name,
- timeout=timeout)
-
-
-_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID = (
- # A while loop to wait until the job status
- # contains certain status, with timeout.
- 'start_time=$SECONDS; '
- 'while true; do '
- 'if (( $SECONDS - $start_time > {timeout} )); then '
- ' echo "Timeout after {timeout} seconds waiting for job status \'{job_status}\'"; exit 1; '
- 'fi; '
- 'current_status=$(sky queue {cluster_name} | '
- 'awk "\\$1 == \\"{job_id}\\" '
- '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_JOB_STATUSES +
- ')$/) print \$i}}"); '
- 'found=0; ' # Initialize found variable outside the loop
- 'while read -r line; do ' # Read line by line
- ' if [[ "$line" =~ {job_status} ]]; then ' # Check each line
- ' echo "Target job status {job_status} reached."; '
- ' found=1; '
- ' break; ' # Break inner loop
- ' fi; '
- 'done <<< "$current_status"; '
- 'if [ "$found" -eq 1 ]; then break; fi; ' # Break outer loop if match found
- 'echo "Waiting for job status to contain {job_status}, current status: $current_status"; '
- 'sleep 10; '
- 'done')
-
-_WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace(
- 'awk "\\$1 == \\"{job_id}\\"', 'awk "')
-
-_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace(
- 'awk "\\$1 == \\"{job_id}\\"', 'awk "\\$2 == \\"{job_name}\\"')
-
-
-def _get_cmd_wait_until_job_status_contains_matching_job_id(
- cluster_name: str, job_id: str, job_status: List[sky.JobStatus],
- timeout: int):
- return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.format(
- cluster_name=cluster_name,
- job_id=job_id,
- job_status=_statuses_to_str(job_status),
- timeout=timeout)
-
-
-def _get_cmd_wait_until_job_status_contains_without_matching_job(
- cluster_name: str, job_status: List[sky.JobStatus], timeout: int):
- return _WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB.format(
- cluster_name=cluster_name,
- job_status=_statuses_to_str(job_status),
- timeout=timeout)
-
-
-def _get_cmd_wait_until_job_status_contains_matching_job_name(
- cluster_name: str, job_name: str, job_status: List[sky.JobStatus],
- timeout: int):
- return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format(
- cluster_name=cluster_name,
- job_name=job_name,
- job_status=_statuses_to_str(job_status),
- timeout=timeout)
-
-
-# Managed job functions
-
-_WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.replace(
- 'sky queue {cluster_name}', 'sky jobs queue').replace(
- 'awk "\\$2 == \\"{job_name}\\"',
- 'awk "\\$2 == \\"{job_name}\\" || \\$3 == \\"{job_name}\\"').replace(
- _ALL_JOB_STATUSES, _ALL_MANAGED_JOB_STATUSES)
-
-
-def _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name: str, job_status: List[sky.JobStatus], timeout: int):
- return _WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format(
- job_name=job_name,
- job_status=_statuses_to_str(job_status),
- timeout=timeout)
-
-
-# After the timeout, the cluster will stop if autostop is set, and our check
-# should be more than the timeout. To address this, we extend the timeout by
-# _BUMP_UP_SECONDS before exiting.
-_BUMP_UP_SECONDS = 35
-
-DEFAULT_CMD_TIMEOUT = 15 * 60
-
-
-class Test(NamedTuple):
- name: str
- # Each command is executed serially. If any failed, the remaining commands
- # are not run and the test is treated as failed.
- commands: List[str]
- teardown: Optional[str] = None
- # Timeout for each command in seconds.
- timeout: int = DEFAULT_CMD_TIMEOUT
- # Environment variables to set for each command.
- env: Dict[str, str] = None
-
- def echo(self, message: str):
- # pytest's xdist plugin captures stdout; print to stderr so that the
- # logs are streaming while the tests are running.
- prefix = f'[{self.name}]'
- message = f'{prefix} {message}'
- message = message.replace('\n', f'\n{prefix} ')
- print(message, file=sys.stderr, flush=True)
-
-
-def _get_timeout(generic_cloud: str,
- override_timeout: int = DEFAULT_CMD_TIMEOUT):
- timeouts = {'fluidstack': 60 * 60} # file_mounts
- return timeouts.get(generic_cloud, override_timeout)
-
-
-def _get_cluster_name() -> str:
- """Returns a user-unique cluster name for each test_().
-
- Must be called from each test_().
- """
- caller_func_name = inspect.stack()[1][3]
- test_name = caller_func_name.replace('_', '-').replace('test-', 't-')
- test_name = common_utils.make_cluster_name_on_cloud(test_name,
- 24,
- add_user_hash=False)
- return f'{test_name}-{test_id}'
-
-
-def _terminate_gcp_replica(name: str, zone: str, replica_id: int) -> str:
- cluster_name = serve.generate_replica_cluster_name(name, replica_id)
- query_cmd = (f'gcloud compute instances list --filter='
- f'"(labels.ray-cluster-name:{cluster_name})" '
- f'--zones={zone} --format="value(name)"')
- return (f'gcloud compute instances delete --zone={zone}'
- f' --quiet $({query_cmd})')
-
-
-def run_one_test(test: Test) -> Tuple[int, str, str]:
- # Fail fast if `sky` CLI somehow errors out.
- subprocess.run(['sky', 'status'], stdout=subprocess.DEVNULL, check=True)
- log_file = tempfile.NamedTemporaryFile('a',
- prefix=f'{test.name}-',
- suffix='.log',
- delete=False)
- test.echo(f'Test started. Log: less {log_file.name}')
- env_dict = os.environ.copy()
- if test.env:
- env_dict.update(test.env)
- for command in test.commands:
- log_file.write(f'+ {command}\n')
- log_file.flush()
- proc = subprocess.Popen(
- command,
- stdout=log_file,
- stderr=subprocess.STDOUT,
- shell=True,
- executable='/bin/bash',
- env=env_dict,
- )
- try:
- proc.wait(timeout=test.timeout)
- except subprocess.TimeoutExpired as e:
- log_file.flush()
- test.echo(f'Timeout after {test.timeout} seconds.')
- test.echo(str(e))
- log_file.write(f'Timeout after {test.timeout} seconds.\n')
- log_file.flush()
- # Kill the current process.
- proc.terminate()
- proc.returncode = 1 # None if we don't set it.
- break
-
- if proc.returncode:
- break
-
- style = colorama.Style
- fore = colorama.Fore
- outcome = (f'{fore.RED}Failed{style.RESET_ALL}'
- if proc.returncode else f'{fore.GREEN}Passed{style.RESET_ALL}')
- reason = f'\nReason: {command}' if proc.returncode else ''
- msg = (f'{outcome}.'
- f'{reason}'
- f'\nLog: less {log_file.name}\n')
- test.echo(msg)
- log_file.write(msg)
- if (proc.returncode == 0 or
- pytest.terminate_on_failure) and test.teardown is not None:
- subprocess_utils.run(
- test.teardown,
- stdout=log_file,
- stderr=subprocess.STDOUT,
- timeout=10 * 60, # 10 mins
- shell=True,
- )
-
- if proc.returncode:
- raise Exception(f'test failed: less {log_file.name}')
-
-
-def get_aws_region_for_quota_failover() -> Optional[str]:
- candidate_regions = AWS.regions_with_offering(instance_type='p3.16xlarge',
- accelerators=None,
- use_spot=True,
- region=None,
- zone=None)
- original_resources = sky.Resources(cloud=sky.AWS(),
- instance_type='p3.16xlarge',
- use_spot=True)
-
- # Filter the regions with proxy command in ~/.sky/config.yaml.
- filtered_regions = original_resources.get_valid_regions_for_launchable()
- candidate_regions = [
- region for region in candidate_regions
- if region.name in filtered_regions
- ]
-
- for region in candidate_regions:
- resources = original_resources.copy(region=region.name)
- if not AWS.check_quota_available(resources):
- return region.name
-
- return None
-
-
-def get_gcp_region_for_quota_failover() -> Optional[str]:
-
- candidate_regions = GCP.regions_with_offering(instance_type=None,
- accelerators={'A100-80GB': 1},
- use_spot=True,
- region=None,
- zone=None)
-
- original_resources = sky.Resources(cloud=sky.GCP(),
- instance_type='a2-ultragpu-1g',
- accelerators={'A100-80GB': 1},
- use_spot=True)
-
- # Filter the regions with proxy command in ~/.sky/config.yaml.
- filtered_regions = original_resources.get_valid_regions_for_launchable()
- candidate_regions = [
- region for region in candidate_regions
- if region.name in filtered_regions
- ]
-
- for region in candidate_regions:
- if not GCP.check_quota_available(
- original_resources.copy(region=region.name)):
- return region.name
-
- return None
-
-
-# ---------- Dry run: 2 Tasks in a chain. ----------
-@pytest.mark.no_fluidstack #requires GCP and AWS set up
-def test_example_app():
- test = Test(
- 'example_app',
- ['python examples/example_app.py'],
- )
- run_one_test(test)
-
-
-_VALIDATE_LAUNCH_OUTPUT = (
- # Validate the output of the job submission:
- # ⚙️ Launching on Kubernetes.
- # Pod is up.
- # ✓ Cluster launched: test. View logs at: ~/sky_logs/sky-2024-10-07-19-44-18-177288/provision.log
- # ✓ Setup Detached.
- # ⚙️ Job submitted, ID: 1.
- # ├── Waiting for task resources on 1 node.
- # └── Job started. Streaming logs... (Ctrl-C to exit log streaming; job will not be killed)
- # (setup pid=1277) running setup
- # (min, pid=1277) # conda environments:
- # (min, pid=1277) #
- # (min, pid=1277) base * /opt/conda
- # (min, pid=1277)
- # (min, pid=1277) task run finish
- # ✓ Job finished (status: SUCCEEDED).
- #
- # Job ID: 1
- # 📋 Useful Commands
- # ├── To cancel the job: sky cancel test 1
- # ├── To stream job logs: sky logs test 1
- # └── To view job queue: sky queue test
- #
- # Cluster name: test
- # ├── To log into the head VM: ssh test
- # ├── To submit a job: sky exec test yaml_file
- # ├── To stop the cluster: sky stop test
- # └── To teardown the cluster: sky down test
- 'echo "$s" && echo "==Validating launching==" && '
- 'echo "$s" | grep -A 1 "Launching on" | grep "is up." && '
- 'echo "$s" && echo "==Validating setup output==" && '
- 'echo "$s" | grep -A 1 "Setup detached" | grep "Job submitted" && '
- 'echo "==Validating running output hints==" && echo "$s" | '
- 'grep -A 1 "Job submitted, ID:" | '
- 'grep "Waiting for task resources on " && '
- 'echo "==Validating task setup/run output starting==" && echo "$s" | '
- 'grep -A 1 "Job started. Streaming logs..." | grep "(setup" | '
- 'grep "running setup" && '
- 'echo "$s" | grep -A 1 "(setup" | grep "(min, pid=" && '
- 'echo "==Validating task output ending==" && '
- 'echo "$s" | grep -A 1 "task run finish" | '
- 'grep "Job finished (status: SUCCEEDED)" && '
- 'echo "==Validating task output ending 2==" && '
- 'echo "$s" | grep -A 5 "Job finished (status: SUCCEEDED)" | '
- 'grep "Job ID:" && '
- 'echo "$s" | grep -A 1 "Useful Commands" | grep "Job ID:"')
-
-
-# ---------- A minimal task ----------
-def test_minimal(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'minimal',
- [
- f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --cpus 2+ tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}',
- # Output validation done.
- f'sky logs {name} 1 --status',
- f'sky logs {name} --status | grep "Job 1: SUCCEEDED"', # Equivalent.
- # Test launch output again on existing cluster
- f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}',
- f'sky logs {name} 2 --status',
- f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
- # Check the logs downloading
- # TODO(zhwu): Fix the logs downloading.
- # f'log_path=$(sky logs {name} 1 --sync-down | grep "Job 1 logs:" | sed -E "s/^.*Job 1 logs: (.*)\\x1b\\[0m/\\1/g") && echo "$log_path" && test -f $log_path/run.log',
- # Ensure the raylet process has the correct file descriptor limit.
- f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- # Install jq for the next test.
- f'sky exec {name} \'sudo apt-get update && sudo apt-get install -y jq\'',
- # Check the cluster info
- f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cluster_name | grep {name}\'',
- f'sky logs {name} 5 --status', # Ensure the job succeeded.
- f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cloud | grep -i {generic_cloud}\'',
- f'sky logs {name} 6 --status', # Ensure the job succeeded.
- # Test '-c' for exec
- f'sky exec -c {name} echo',
- f'sky logs {name} 7 --status',
- f'sky exec echo -c {name}',
- f'sky logs {name} 8 --status',
- f'sky exec -c {name} echo hi test',
- f'sky logs {name} 9 | grep "hi test"',
- f'sky exec {name} && exit 1 || true',
- f'sky exec -c {name} && exit 1 || true',
- ],
- f'sky down -y {name}',
- _get_timeout(generic_cloud),
- )
- run_one_test(test)
-
-
-# ---------- Test fast launch ----------
-def test_launch_fast(generic_cloud: str):
- name = _get_cluster_name()
-
- test = Test(
- 'test_launch_fast',
- [
- # First launch to create the cluster
- f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}',
- f'sky logs {name} 1 --status',
-
- # Second launch to test fast launch - should not reprovision
- f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast tests/test_yamls/minimal.yaml) && '
- ' echo "$s" && '
- # Validate that cluster was not re-launched.
- '! echo "$s" | grep -A 1 "Launching on" | grep "is up." && '
- # Validate that setup was not re-run.
- '! echo "$s" | grep -A 1 "Running setup on" | grep "running setup" && '
- # Validate that the task ran and finished.
- 'echo "$s" | grep -A 1 "task run finish" | grep "Job finished (status: SUCCEEDED)"',
- f'sky logs {name} 2 --status',
- f'sky status -r {name} | grep UP',
- ],
- f'sky down -y {name}',
- timeout=_get_timeout(generic_cloud),
- )
- run_one_test(test)
-
-
-# ---------- Test multi-tenant ----------
-def test_multi_tenant(generic_cloud: str):
- name = _get_cluster_name()
- user_1 = 'abcdef12'
- user_1_name = 'user1'
- user_2 = 'abcdef13'
- user_2_name = 'user2'
-
- def set_user(user_id: str, user_name: str,
- commands: List[str]) -> List[str]:
- return [
- f'export {constants.USER_ID_ENV_VAR}="{user_id}"; '
- f'export {constants.USER_ENV_VAR}="{user_name}"; ' + cmd
- for cmd in commands
- ]
-
- stop_test_cmds = [
- 'echo "==== Test multi-tenant cluster stop ===="',
- *set_user(
- user_2,
- user_2_name,
- [
- f'sky stop -y -a',
- # -a should only stop clusters from the current user.
- f's=$(sky status -u {name}-1) && echo "$s" && echo "$s" | grep {user_1_name} | grep UP',
- f's=$(sky status -u {name}-2) && echo "$s" && echo "$s" | grep {user_2_name} | grep STOPPED',
- # Explicit cluster name should stop the cluster.
- f'sky stop -y {name}-1',
- # Stopping cluster should not change the ownership of the cluster.
- f's=$(sky status) && echo "$s" && echo "$s" | grep {name}-1 && exit 1 || true',
- f'sky status {name}-1 | grep STOPPED',
- # Both clusters should be stopped.
- f'sky status -u | grep {name}-1 | grep STOPPED',
- f'sky status -u | grep {name}-2 | grep STOPPED',
- ]),
- ]
- if generic_cloud == 'kubernetes':
- # Skip the stop test for Kubernetes, as stopping is not supported.
- stop_test_cmds = []
-
- test = Test(
- 'test_multi_tenant',
- [
- 'echo "==== Test multi-tenant job on single cluster ===="',
- *set_user(user_1, user_1_name, [
- f'sky launch -y -c {name}-1 --cloud {generic_cloud} --cpus 2+ -n job-1 tests/test_yamls/minimal.yaml',
- f's=$(sky queue {name}-1) && echo "$s" && echo "$s" | grep job-1 | grep SUCCEEDED | awk \'{{print $1}}\' | grep 1',
- f's=$(sky queue -u {name}-1) && echo "$s" && echo "$s" | grep {user_1_name} | grep job-1 | grep SUCCEEDED',
- ]),
- *set_user(user_2, user_2_name, [
- f'sky exec {name}-1 -n job-2 \'echo "hello" && exit 1\'',
- f's=$(sky queue {name}-1) && echo "$s" && echo "$s" | grep job-2 | grep FAILED | awk \'{{print $1}}\' | grep 2',
- f's=$(sky queue {name}-1) && echo "$s" && echo "$s" | grep job-1 && exit 1 || true',
- f's=$(sky queue {name}-1 -u) && echo "$s" && echo "$s" | grep {user_2_name} | grep job-2 | grep FAILED',
- f's=$(sky queue {name}-1 -u) && echo "$s" && echo "$s" | grep {user_1_name} | grep job-1 | grep SUCCEEDED',
- ]),
- 'echo "==== Test clusters from different users ===="',
- *set_user(
- user_2,
- user_2_name,
- [
- f'sky launch -y -c {name}-2 --cloud {generic_cloud} --cpus 2+ -n job-3 tests/test_yamls/minimal.yaml',
- f's=$(sky status {name}-2) && echo "$s" && echo "$s" | grep UP',
- # sky status should not show other clusters from other users.
- f's=$(sky status) && echo "$s" && echo "$s" | grep {name}-1 && exit 1 || true',
- # Explicit cluster name should show the cluster.
- f's=$(sky status {name}-1) && echo "$s" && echo "$s" | grep UP',
- f's=$(sky status -u) && echo "$s" && echo "$s" | grep {user_2_name} | grep {name}-2 | grep UP',
- f's=$(sky status -u) && echo "$s" && echo "$s" | grep {user_1_name} | grep {name}-1 | grep UP',
- ]),
- *stop_test_cmds,
- 'echo "==== Test multi-tenant cluster down ===="',
- *set_user(
- user_2,
- user_2_name,
- [
- f'sky down -y -a',
- # STOPPED or UP based on whether we run the stop_test_cmds.
- f'sky status -u | grep {name}-1 | grep "STOPPED\|UP"',
- # Current user's clusters should be down'ed.
- f'sky status -u | grep {name}-2 && exit 1 || true',
- # Explicit cluster name should delete the cluster.
- f'sky down -y {name}-1',
- f'sky status | grep {name}-1 && exit 1 || true',
- ]),
- ],
- f'sky down -y {name}-1 {name}-2',
- )
- run_one_test(test)
-
-
-# See cloud exclusion explanations in test_autostop
-@pytest.mark.no_fluidstack
-@pytest.mark.no_lambda_cloud
-@pytest.mark.no_ibm
-@pytest.mark.no_kubernetes
-def test_launch_fast_with_autostop(generic_cloud: str):
- name = _get_cluster_name()
- # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure
- # the VM is stopped.
- autostop_timeout = 600 if generic_cloud == 'azure' else 250
- test = Test(
- 'test_launch_fast_with_autostop',
- [
- # First launch to create the cluster with a short autostop
- f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}',
- f'sky logs {name} 1 --status',
- f'sky status -r {name} | grep UP',
-
- # Ensure cluster is stopped
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=autostop_timeout),
- # Even the cluster is stopped, cloud platform may take a while to
- # delete the VM.
- f'sleep {_BUMP_UP_SECONDS}',
- # Launch again. Do full output validation - we expect the cluster to re-launch
- f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}',
- f'sky logs {name} 2 --status',
- f'sky status -r {name} | grep UP',
- ],
- f'sky down -y {name}',
- timeout=_get_timeout(generic_cloud) + autostop_timeout,
- )
- run_one_test(test)
-
-
-# ---------- Test region ----------
-@pytest.mark.aws
-def test_aws_region():
- name = _get_cluster_name()
- test = Test(
- 'aws_region',
- [
- f'sky launch -y -c {name} --region us-east-2 examples/minimal.yaml',
- f'sky exec {name} examples/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep us-east-2', # Ensure the region is correct.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .region | grep us-east-2\'',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- # A user program should not access SkyPilot runtime env python by default.
- f'sky exec {name} \'which python | grep {constants.SKY_REMOTE_PYTHON_ENV_NAME} && exit 1 || true\'',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# TODO(zhwu): Fix this test -- the jump-{name} cluster cannot be found by ssh,
-# when API server is hosted remotely, since we don't add jump-{name} to the ssh
-# config on the API server.
-@pytest.mark.aws
-def test_aws_with_ssh_proxy_command():
- name = _get_cluster_name()
- api_server = skypilot_config.get_nested(('api_server', 'endpoint'), None)
- with tempfile.NamedTemporaryFile(mode='w') as f:
- f.write(
- textwrap.dedent(f"""\
- aws:
- ssh_proxy_command: ssh -W %h:%p -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null jump-{name}
- """))
- if api_server is not None:
- f.write(
- textwrap.dedent(f"""\
- api_server:
- endpoint: {api_server}
- """))
- f.flush()
-
- test = Test(
- 'aws_with_ssh_proxy_command',
- [
- f'sky launch -y -c jump-{name} --cloud aws --cpus 2 --region us-east-1',
- # Use jump config
- f'export SKYPILOT_CONFIG={f.name}; '
- f'sky launch -y -c {name} --cloud aws --cpus 2 --region us-east-1 echo hi',
- f'sky logs {name} 1 --status',
- f'export SKYPILOT_CONFIG={f.name}; sky exec {name} echo hi',
- f'sky logs {name} 2 --status',
- # Start a small job to make sure the controller is created.
- f'sky jobs launch -n {name}-0 --cloud aws --cpus 2 --use-spot -y echo hi',
- # Wait other tests to create the job controller first, so that
- # the job controller is not launched with proxy command.
- _get_cmd_wait_until_cluster_status_contains_wildcard(
- cluster_name_wildcard='sky-jobs-controller-*',
- cluster_status=[sky.ClusterStatus.UP],
- timeout=300),
- f'export SKYPILOT_CONFIG={f.name}; sky jobs launch -n {name} --cpus 2 --cloud aws --region us-east-1 -yd echo hi',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[
- sky.ManagedJobStatus.SUCCEEDED,
- sky.ManagedJobStatus.RUNNING,
- sky.ManagedJobStatus.STARTING
- ],
- timeout=300),
- ],
- f'sky down -y {name} jump-{name}; sky jobs cancel -y -n {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_region_and_service_account():
- name = _get_cluster_name()
- test = Test(
- 'gcp_region',
- [
- f'sky launch -y -c {name} --region us-central1 --cloud gcp tests/test_yamls/minimal.yaml',
- f'sky exec {name} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} \'curl -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity?format=standard&audience=gcp"\'',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep us-central1', # Ensure the region is correct.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .region | grep us-central1\'',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- # A user program should not access SkyPilot runtime env python by default.
- f'sky exec {name} \'which python | grep {constants.SKY_REMOTE_PYTHON_ENV_NAME} && exit 1 || true\'',
- f'sky logs {name} 4 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.ibm
-def test_ibm_region():
- name = _get_cluster_name()
- region = 'eu-de'
- test = Test(
- 'region',
- [
- f'sky launch -y -c {name} --cloud ibm --region {region} examples/minimal.yaml',
- f'sky exec {name} --cloud ibm examples/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep {region}', # Ensure the region is correct.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_azure_region():
- name = _get_cluster_name()
- test = Test(
- 'azure_region',
- [
- f'sky launch -y -c {name} --region eastus2 --cloud azure tests/test_yamls/minimal.yaml',
- f'sky exec {name} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep eastus2', # Ensure the region is correct.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .region | grep eastus2\'',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .zone | grep null\'',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- # A user program should not access SkyPilot runtime env python by default.
- f'sky exec {name} \'which python | grep {constants.SKY_REMOTE_PYTHON_ENV_NAME} && exit 1 || true\'',
- f'sky logs {name} 4 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Test zone ----------
-@pytest.mark.aws
-def test_aws_zone():
- name = _get_cluster_name()
- test = Test(
- 'aws_zone',
- [
- f'sky launch -y -c {name} examples/minimal.yaml --zone us-east-2b',
- f'sky exec {name} examples/minimal.yaml --zone us-east-2b',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep us-east-2b', # Ensure the zone is correct.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.ibm
-def test_ibm_zone():
- name = _get_cluster_name()
- zone = 'eu-de-2'
- test = Test(
- 'zone',
- [
- f'sky launch -y -c {name} --cloud ibm examples/minimal.yaml --zone {zone}',
- f'sky exec {name} --cloud ibm examples/minimal.yaml --zone {zone}',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep {zone}', # Ensure the zone is correct.
- ],
- f'sky down -y {name} {name}-2 {name}-3',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_zone():
- name = _get_cluster_name()
- test = Test(
- 'gcp_zone',
- [
- f'sky launch -y -c {name} --zone us-central1-a --cloud gcp tests/test_yamls/minimal.yaml',
- f'sky exec {name} --zone us-central1-a --cloud gcp tests/test_yamls/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status -v | grep {name} | grep us-central1-a', # Ensure the zone is correct.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Test the image ----------
-@pytest.mark.aws
-def test_aws_images():
- name = _get_cluster_name()
- test = Test(
- 'aws_images',
- [
- f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-1804 examples/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky launch -c {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml && exit 1 || true',
- f'sky launch -y -c {name} examples/minimal.yaml',
- f'sky logs {name} 2 --status',
- f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .cloud | grep -i aws\'',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_images():
- name = _get_cluster_name()
- test = Test(
- 'gcp_images',
- [
- f'sky launch -y -c {name} --image-id skypilot:gpu-debian-10 --cloud gcp tests/test_yamls/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky launch -c {name} --image-id skypilot:cpu-debian-10 --cloud gcp tests/test_yamls/minimal.yaml && exit 1 || true',
- f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 2 --status',
- f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .cloud | grep -i gcp\'',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_azure_images():
- name = _get_cluster_name()
- test = Test(
- 'azure_images',
- [
- f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-2204 --cloud azure tests/test_yamls/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky launch -c {name} --image-id skypilot:v1-ubuntu-2004 --cloud azure tests/test_yamls/minimal.yaml && exit 1 || true',
- f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 2 --status',
- f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
- f'sky exec {name} \'echo $SKYPILOT_CLUSTER_INFO | jq .cloud | grep -i azure\'',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_aws_image_id_dict():
- name = _get_cluster_name()
- test = Test(
- 'aws_image_id_dict',
- [
- # Use image id dict.
- f'sky launch -y -c {name} examples/per_region_images.yaml',
- f'sky exec {name} examples/per_region_images.yaml',
- f'sky exec {name} "ls ~"',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_image_id_dict():
- name = _get_cluster_name()
- test = Test(
- 'gcp_image_id_dict',
- [
- # Use image id dict.
- f'sky launch -y -c {name} tests/test_yamls/gcp_per_region_images.yaml',
- f'sky exec {name} tests/test_yamls/gcp_per_region_images.yaml',
- f'sky exec {name} "ls ~"',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_aws_image_id_dict_region():
- name = _get_cluster_name()
- test = Test(
- 'aws_image_id_dict_region',
- [
- # YAML has
- # image_id:
- # us-west-2: skypilot:gpu-ubuntu-1804
- # us-east-2: skypilot:gpu-ubuntu-2004
- # Use region to filter image_id dict.
- f'sky launch -y -c {name} --region us-east-1 examples/per_region_images.yaml && exit 1 || true',
- f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
- f'sky launch -y -c {name} --region us-east-2 examples/per_region_images.yaml',
- # Should success because the image id match for the region.
- f'sky launch -c {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
- f'sky exec {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
- f'sky exec {name} --image-id skypilot:gpu-ubuntu-1804 examples/minimal.yaml && exit 1 || true',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 --status',
- f'sky status -v | grep {name} | grep us-east-2', # Ensure the region is correct.
- # Ensure exec works.
- f'sky exec {name} --region us-east-2 examples/per_region_images.yaml',
- f'sky exec {name} examples/per_region_images.yaml',
- f'sky exec {name} --cloud aws --region us-east-2 "ls ~"',
- f'sky exec {name} "ls ~"',
- f'sky logs {name} 4 --status',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
- f'sky logs {name} 7 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_image_id_dict_region():
- name = _get_cluster_name()
- test = Test(
- 'gcp_image_id_dict_region',
- [
- # Use region to filter image_id dict.
- f'sky launch -y -c {name} --region us-east1 tests/test_yamls/gcp_per_region_images.yaml && exit 1 || true',
- f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
- f'sky launch -y -c {name} --region us-west3 tests/test_yamls/gcp_per_region_images.yaml',
- # Should success because the image id match for the region.
- f'sky launch -c {name} --cloud gcp --image-id projects/ubuntu-os-cloud/global/images/ubuntu-1804-bionic-v20230112 tests/test_yamls/minimal.yaml',
- f'sky exec {name} --cloud gcp --image-id projects/ubuntu-os-cloud/global/images/ubuntu-1804-bionic-v20230112 tests/test_yamls/minimal.yaml',
- f'sky exec {name} --cloud gcp --image-id skypilot:cpu-debian-10 tests/test_yamls/minimal.yaml && exit 1 || true',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 --status',
- f'sky status -v | grep {name} | grep us-west3', # Ensure the region is correct.
- # Ensure exec works.
- f'sky exec {name} --region us-west3 tests/test_yamls/gcp_per_region_images.yaml',
- f'sky exec {name} tests/test_yamls/gcp_per_region_images.yaml',
- f'sky exec {name} --cloud gcp --region us-west3 "ls ~"',
- f'sky exec {name} "ls ~"',
- f'sky logs {name} 4 --status',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
- f'sky logs {name} 7 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_aws_image_id_dict_zone():
- name = _get_cluster_name()
- test = Test(
- 'aws_image_id_dict_zone',
- [
- # YAML has
- # image_id:
- # us-west-2: skypilot:gpu-ubuntu-1804
- # us-east-2: skypilot:gpu-ubuntu-2004
- # Use zone to filter image_id dict.
- f'sky launch -y -c {name} --zone us-east-1b examples/per_region_images.yaml && exit 1 || true',
- f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
- f'sky launch -y -c {name} --zone us-east-2a examples/per_region_images.yaml',
-            # Should succeed because the image id matches the zone.
- f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
- f'sky exec {name} --image-id skypilot:gpu-ubuntu-2004 examples/minimal.yaml',
- # Fail due to image id mismatch.
- f'sky exec {name} --image-id skypilot:gpu-ubuntu-1804 examples/minimal.yaml && exit 1 || true',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 --status',
- f'sky status -v | grep {name} | grep us-east-2a', # Ensure the zone is correct.
- # Ensure exec works.
- f'sky exec {name} --zone us-east-2a examples/per_region_images.yaml',
- f'sky exec {name} examples/per_region_images.yaml',
- f'sky exec {name} --cloud aws --region us-east-2 "ls ~"',
- f'sky exec {name} "ls ~"',
- f'sky logs {name} 4 --status',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
- f'sky logs {name} 7 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_image_id_dict_zone():
- name = _get_cluster_name()
- test = Test(
- 'gcp_image_id_dict_zone',
- [
- # Use zone to filter image_id dict.
- f'sky launch -y -c {name} --zone us-east1-a tests/test_yamls/gcp_per_region_images.yaml && exit 1 || true',
- f'sky status | grep {name} && exit 1 || true', # Ensure the cluster is not created.
- f'sky launch -y -c {name} --zone us-central1-a tests/test_yamls/gcp_per_region_images.yaml',
-            # Should succeed because the image id matches the zone.
- f'sky launch -y -c {name} --cloud gcp --image-id skypilot:cpu-debian-10 tests/test_yamls/minimal.yaml',
- f'sky exec {name} --cloud gcp --image-id skypilot:cpu-debian-10 tests/test_yamls/minimal.yaml',
- # Fail due to image id mismatch.
- f'sky exec {name} --cloud gcp --image-id skypilot:gpu-debian-10 tests/test_yamls/minimal.yaml && exit 1 || true',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 --status',
- f'sky status -v | grep {name} | grep us-central1', # Ensure the zone is correct.
- # Ensure exec works.
- f'sky exec {name} --cloud gcp --zone us-central1-a tests/test_yamls/gcp_per_region_images.yaml',
- f'sky exec {name} tests/test_yamls/gcp_per_region_images.yaml',
- f'sky exec {name} --cloud gcp --region us-central1 "ls ~"',
- f'sky exec {name} "ls ~"',
- f'sky logs {name} 4 --status',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
- f'sky logs {name} 7 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_clone_disk_aws():
- name = _get_cluster_name()
- test = Test(
- 'clone_disk_aws',
- [
- f'sky launch -y -c {name} --cloud aws --region us-east-2 --retry-until-up "echo hello > ~/user_file.txt"',
- f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true',
- f'sky stop {name} -y',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=60),
- # Wait for EC2 instance to be in stopped state.
- # TODO: event based wait.
- 'sleep 60',
- f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"',
- f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"',
- f'sky logs {name}-clone 1 --status',
- f'sky logs {name}-clone-2 1 --status',
- ],
- f'sky down -y {name} {name}-clone {name}-clone-2',
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_clone_disk_gcp():
- name = _get_cluster_name()
- test = Test(
- 'clone_disk_gcp',
- [
- f'sky launch -y -c {name} --cloud gcp --zone us-east1-b --retry-until-up "echo hello > ~/user_file.txt"',
- f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true',
- f'sky stop {name} -y',
- f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud gcp --zone us-central1-a "cat ~/user_file.txt | grep hello"',
- f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud gcp --zone us-east1-b "cat ~/user_file.txt | grep hello"',
- f'sky logs {name}-clone 1 --status',
- f'sky logs {name}-clone-2 1 --status',
- ],
- f'sky down -y {name} {name}-clone {name}-clone-2',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_mig():
- name = _get_cluster_name()
- region = 'us-central1'
- test = Test(
- 'gcp_mig',
- [
- f'sky launch -y -c {name} --gpus t4 --num-nodes 2 --image-id skypilot:gpu-debian-10 --cloud gcp --region {region} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 2 --status',
- f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent.
- # Check MIG exists.
- f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"',
- f'sky autostop -i 0 --down -y {name}',
- _get_cmd_wait_until_cluster_is_not_found(cluster_name=name,
- timeout=120),
- f'gcloud compute instance-templates list | grep "sky-it-{name}"',
- # Launch again with the same region. The original instance template
- # should be removed.
- f'sky launch -y -c {name} --gpus L4 --num-nodes 2 --region {region} nvidia-smi',
- f'sky logs {name} 1 | grep "L4"',
- f'sky down -y {name}',
- f'gcloud compute instance-templates list | grep "sky-it-{name}" && exit 1 || true',
- ],
- f'sky down -y {name}',
- env={'SKYPILOT_CONFIG': 'tests/test_yamls/use_mig_config.yaml'})
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_force_enable_external_ips():
- name = _get_cluster_name()
- test_commands = [
- f'sky launch -y -c {name} --cloud gcp --cpus 2 tests/test_yamls/minimal.yaml',
-        # Check that the VM's network is "default".
- (f'gcloud compute instances list --filter=name~"{name}" --format='
- '"value(networkInterfaces.network)" | grep "networks/default"'),
-        # Check for "External NAT" in the network access configs; it corresponds to the external IP.
- (f'gcloud compute instances list --filter=name~"{name}" --format='
- '"value(networkInterfaces.accessConfigs[0].name)" | grep "External NAT"'
- ),
- f'sky down -y {name}',
- ]
- skypilot_config = 'tests/test_yamls/force_enable_external_ips_config.yaml'
- test = Test('gcp_force_enable_external_ips',
- test_commands,
- f'sky down -y {name}',
- env={'SKYPILOT_CONFIG': skypilot_config})
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_image_no_conda():
- name = _get_cluster_name()
- test = Test(
- 'image_no_conda',
- [
- # Use image id dict.
- f'sky launch -y -c {name} --region us-east-2 examples/per_region_images.yaml',
- f'sky logs {name} 1 --status',
- f'sky stop {name} -y',
- f'sky start {name} -y',
- f'sky exec {name} examples/per_region_images.yaml',
- f'sky logs {name} 2 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack  # FluidStack does not support stopping instances in the SkyPilot implementation
-@pytest.mark.no_kubernetes # Kubernetes does not support stopping instances
-def test_custom_default_conda_env(generic_cloud: str):
- name = _get_cluster_name()
- test = Test('custom_default_conda_env', [
- f'sky launch -c {name} -y --cloud {generic_cloud} tests/test_yamls/test_custom_default_conda_env.yaml',
- f'sky status -r {name} | grep "UP"',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 1 --no-follow | grep -E "myenv\\s+\\*"',
- f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml',
- f'sky logs {name} 2 --status',
- f'sky autostop -y -i 0 {name}',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=80),
- f'sky start -y {name}',
- f'sky logs {name} 2 --no-follow | grep -E "myenv\\s+\\*"',
- f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml',
- f'sky logs {name} 3 --status',
- ], f'sky down -y {name}')
- run_one_test(test)
-
-
-# ------------ Test stale job ------------
-@pytest.mark.no_fluidstack  # FluidStack does not support stopping instances in the SkyPilot implementation
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support stopping instances
-@pytest.mark.no_kubernetes # Kubernetes does not support stopping instances
-def test_stale_job(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'stale_job',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} "echo hi"',
- f'sky exec {name} -d "echo start; sleep 10000"',
- f'sky stop {name} -y',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=100),
- f'sky start {name} -y',
- f'sky logs {name} 1 --status',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_aws_stale_job_manual_restart():
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, sky.AWS.max_cluster_name_length())
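-    # The name used on the cloud may be truncated to AWS's cluster-name
-    # length limit; use it below to locate the instances by tag.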
- region = 'us-east-2'
- test = Test(
- 'aws_stale_job_manual_restart',
- [
- f'sky launch -y -c {name} --cloud aws --region {region} "echo hi"',
- f'sky exec {name} -d "echo start; sleep 10000"',
- # Stop the cluster manually.
- f'id=`aws ec2 describe-instances --region {region} --filters '
- f'Name=tag:ray-cluster-name,Values={name_on_cloud} '
- f'--query Reservations[].Instances[].InstanceId '
- '--output text`; '
- f'aws ec2 stop-instances --region {region} '
- '--instance-ids $id',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=40),
- f'sky launch -c {name} -y "echo hi"',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 3 --status',
- # Ensure the skylet updated the stale job status.
- _get_cmd_wait_until_job_status_contains_without_matching_job(
- cluster_name=name,
- job_status=[sky.JobStatus.FAILED_DRIVER],
- timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS),
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_stale_job_manual_restart():
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, sky.GCP.max_cluster_name_length())
- zone = 'us-west2-a'
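-    # Build gcloud commands to find this cluster's instances and stop them
-    # outside of SkyPilot (to simulate a manual stop).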
- query_cmd = (f'gcloud compute instances list --filter='
- f'"(labels.ray-cluster-name={name_on_cloud})" '
- f'--zones={zone} --format="value(name)"')
- stop_cmd = (f'gcloud compute instances stop --zone={zone}'
- f' --quiet $({query_cmd})')
- test = Test(
- 'gcp_stale_job_manual_restart',
- [
- f'sky launch -y -c {name} --cloud gcp --zone {zone} "echo hi"',
- f'sky exec {name} -d "echo start; sleep 10000"',
- # Stop the cluster manually.
- stop_cmd,
- 'sleep 40',
- f'sky launch -c {name} -y "echo hi"',
- f'sky logs {name} 1 --status',
- f'sky logs {name} 3 --status',
- # Ensure the skylet updated the stale job status.
- _get_cmd_wait_until_job_status_contains_without_matching_job(
- cluster_name=name,
- job_status=[sky.JobStatus.FAILED_DRIVER],
- timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS)
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Check Sky's environment variables; workdir. ----------
-@pytest.mark.no_fluidstack  # Requires Amazon S3
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-def test_env_check(generic_cloud: str):
- name = _get_cluster_name()
- total_timeout_minutes = 25 if generic_cloud == 'azure' else 15
- test = Test(
- 'env_check',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} --detach-setup examples/env_check.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- # Test --detach-setup with only setup.
- f'sky launch -y -c {name} --detach-setup tests/test_yamls/test_only_setup.yaml',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 2 | grep "hello world"',
- ],
- f'sky down -y {name}',
- timeout=total_timeout_minutes * 60,
- )
- run_one_test(test)
-
-
-# ---------- file_mounts ----------
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet. Run test_scp_file_mounts instead.
-def test_file_mounts(generic_cloud: str):
- name = _get_cluster_name()
- extra_flags = ''
-    if generic_cloud == 'kubernetes':
- # Kubernetes does not support multi-node
- # NOTE: This test will fail if you have a Kubernetes cluster running on
- # arm64 (e.g., Apple Silicon) since goofys does not work on arm64.
- extra_flags = '--num-nodes 1'
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- f'sky launch -y -c {name} --cloud {generic_cloud} {extra_flags} examples/using_file_mounts.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ]
- test = Test(
- 'using_file_mounts',
- test_commands,
- f'sky down -y {name}',
- _get_timeout(generic_cloud, 20 * 60), # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.scp
-def test_scp_file_mounts():
- name = _get_cluster_name()
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- f'sky launch -y -c {name} {SCP_TYPE} --num-nodes 1 examples/using_file_mounts.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ]
- test = Test(
- 'SCP_using_file_mounts',
- test_commands,
- f'sky down -y {name}',
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Requires GCP to be enabled
-def test_using_file_mounts_with_env_vars(generic_cloud: str):
- name = _get_cluster_name()
- storage_name = TestStorageWithCredentials.generate_bucket_name()
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- (f'sky launch -y -c {name} --cpus 2+ --cloud {generic_cloud} '
- 'examples/using_file_mounts_with_env_vars.yaml '
- f'--env MY_BUCKET={storage_name}'),
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- # Override with --env:
- (f'sky launch -y -c {name}-2 --cpus 2+ --cloud {generic_cloud} '
- 'examples/using_file_mounts_with_env_vars.yaml '
- f'--env MY_BUCKET={storage_name} '
- '--env MY_LOCAL_PATH=tmpfile'),
- f'sky logs {name}-2 1 --status', # Ensure the job succeeded.
- ]
- test = Test(
- 'using_file_mounts_with_env_vars',
- test_commands,
- (f'sky down -y {name} {name}-2',
- f'sky storage delete -y {storage_name} {storage_name}-2'),
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-# ---------- storage ----------
-
-
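-# Helper for the storage-mount tests below: renders
-# tests/test_yamls/test_storage_mounting.yaml.j2 into the given temp file
-# and returns (test_commands, clean_command) for the given cloud.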
-def _storage_mounts_commands_generator(f: TextIO, cluster_name: str,
- storage_name: str, ls_hello_command: str,
- cloud: str, only_mount: bool):
- template_str = pathlib.Path(
- 'tests/test_yamls/test_storage_mounting.yaml.j2').read_text()
- template = jinja2.Template(template_str)
-
- # Set mount flags based on cloud provider
- include_s3_mount = cloud in ['aws', 'kubernetes']
- include_gcs_mount = cloud in ['gcp', 'kubernetes']
- include_azure_mount = cloud == 'azure'
-
- content = template.render(
- storage_name=storage_name,
- cloud=cloud,
- only_mount=only_mount,
- include_s3_mount=include_s3_mount,
- include_gcs_mount=include_gcs_mount,
- include_azure_mount=include_azure_mount,
- )
- f.write(content)
- f.flush()
- file_path = f.name
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- f'sky launch -y -c {cluster_name} --cloud {cloud} {file_path}',
- f'sky logs {cluster_name} 1 --status', # Ensure job succeeded.
- ls_hello_command,
- f'sky stop -y {cluster_name}',
- f'sky start -y {cluster_name}',
-        # Check that hello.txt from the mounted bucket still exists in the
-        # mounted directory after restart.
- f'sky exec {cluster_name} -- "set -ex; ls /mount_private_mount/hello.txt"',
- ]
- clean_command = f'sky down -y {cluster_name}; sky storage delete -y {storage_name}'
- return test_commands, clean_command
-
-
-@pytest.mark.aws
-def test_aws_storage_mounts_with_stop():
- name = _get_cluster_name()
- cloud = 'aws'
- storage_name = f'sky-test-{int(time.time())}'
- ls_hello_command = f'aws s3 ls {storage_name}/hello.txt'
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test_commands, clean_command = _storage_mounts_commands_generator(
- f, name, storage_name, ls_hello_command, cloud, False)
- test = Test(
- 'aws_storage_mounts',
- test_commands,
- clean_command,
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-def test_aws_storage_mounts_with_stop_only_mount():
- name = _get_cluster_name()
- cloud = 'aws'
- storage_name = f'sky-test-{int(time.time())}'
- ls_hello_command = f'aws s3 ls {storage_name}/hello.txt'
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test_commands, clean_command = _storage_mounts_commands_generator(
- f, name, storage_name, ls_hello_command, cloud, True)
- test = Test(
- 'aws_storage_mounts_only_mount',
- test_commands,
- clean_command,
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_storage_mounts_with_stop():
- name = _get_cluster_name()
- cloud = 'gcp'
- storage_name = f'sky-test-{int(time.time())}'
- ls_hello_command = f'gsutil ls gs://{storage_name}/hello.txt'
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test_commands, clean_command = _storage_mounts_commands_generator(
- f, name, storage_name, ls_hello_command, cloud, False)
- test = Test(
- 'gcp_storage_mounts',
- test_commands,
- clean_command,
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_azure_storage_mounts_with_stop():
- name = _get_cluster_name()
- cloud = 'azure'
- storage_name = f'sky-test-{int(time.time())}'
- default_region = 'eastus'
- storage_account_name = (storage_lib.AzureBlobStore.
- get_default_storage_account_name(default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
-    # If the file does not exist, `az storage blob list` returns '[]'.
-    ls_hello_command = (f'output=$(az storage blob list -c {storage_name} '
-                        f'--account-name {storage_account_name} '
-                        f'--account-key {storage_account_key} '
-                        f'--prefix hello.txt); '
-                        f'[ "$output" = "[]" ] && exit 1 || exit 0')
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test_commands, clean_command = _storage_mounts_commands_generator(
- f, name, storage_name, ls_hello_command, cloud, False)
- test = Test(
- 'azure_storage_mounts',
- test_commands,
- clean_command,
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-def test_kubernetes_storage_mounts():
- # Tests bucket mounting on k8s, assuming S3 is configured.
-    # This test will fail if run on a non-x86_64 architecture, since goofys is
- # built for x86_64 only.
- name = _get_cluster_name()
- storage_name = f'sky-test-{int(time.time())}'
- ls_hello_command = (f'aws s3 ls {storage_name}/hello.txt || '
- f'gsutil ls gs://{storage_name}/hello.txt')
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test_commands, clean_command = _storage_mounts_commands_generator(
- f, name, storage_name, ls_hello_command, 'kubernetes', False)
- test = Test(
- 'kubernetes_storage_mounts',
- test_commands,
- clean_command,
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-def test_kubernetes_context_switch():
- name = _get_cluster_name()
- new_context = f'sky-test-context-{int(time.time())}'
- new_namespace = f'sky-test-namespace-{int(time.time())}'
-
- test_commands = [
- # Launch a cluster and run a simple task
- f'sky launch -y -c {name} --cloud kubernetes "echo Hello from original context"',
- f'sky logs {name} 1 --status', # Ensure job succeeded
-
- # Get current context details and save to a file for later use in cleanup
- 'CURRENT_CONTEXT=$(kubectl config current-context); '
- 'echo "$CURRENT_CONTEXT" > /tmp/sky_test_current_context; '
- 'CURRENT_CLUSTER=$(kubectl config view -o jsonpath="{.contexts[?(@.name==\\"$CURRENT_CONTEXT\\")].context.cluster}"); '
- 'CURRENT_USER=$(kubectl config view -o jsonpath="{.contexts[?(@.name==\\"$CURRENT_CONTEXT\\")].context.user}"); '
-
- # Create a new context with a different name and namespace
- f'kubectl config set-context {new_context} --cluster="$CURRENT_CLUSTER" --user="$CURRENT_USER" --namespace={new_namespace}',
-
- # Create the new namespace if it doesn't exist
- f'kubectl create namespace {new_namespace} --dry-run=client -o yaml | kubectl apply -f -',
-
- # Set the new context as active
- f'kubectl config use-context {new_context}',
-
- # Verify the new context is active
- f'[ "$(kubectl config current-context)" = "{new_context}" ] || exit 1',
-
- # Try to run sky exec on the original cluster (should still work)
- f'sky exec {name} "echo Success: sky exec works after context switch"',
-
- # Test sky queue
- f'sky queue {name}',
-
- # Test SSH access
- f'ssh {name} whoami',
- ]
-
- cleanup_commands = (
- f'kubectl delete namespace {new_namespace}; '
- f'kubectl config delete-context {new_context}; '
- 'kubectl config use-context $(cat /tmp/sky_test_current_context); '
- 'rm /tmp/sky_test_current_context; '
- f'sky down -y {name}')
-
- test = Test(
- 'kubernetes_context_switch',
- test_commands,
- cleanup_commands,
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-# TODO (zhwu): These tests may fail as they can require access to cloud
-# credentials, even though the API server is running remotely. We should fix
-# this.
-@pytest.mark.parametrize(
- 'image_id',
- [
- 'docker:nvidia/cuda:11.8.0-devel-ubuntu18.04',
- 'docker:ubuntu:18.04',
- # Test image with python 3.11 installed by default.
- 'docker:continuumio/miniconda3:24.1.2-0',
-        # Test python>=3.12, where SkyPilot should automatically create a
-        # separate conda env with python 3.10 for its runtime.
- 'docker:continuumio/miniconda3:latest',
- ])
-def test_docker_storage_mounts(generic_cloud: str, image_id: str):
-    # Tests bucket mounting in a docker container.
- name = _get_cluster_name()
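-    # Avoid '.' in the storage name; some object stores disallow it in
-    # bucket/container names.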
- timestamp = str(time.time()).replace('.', '')
- storage_name = f'sky-test-{timestamp}'
- template_str = pathlib.Path(
- 'tests/test_yamls/test_storage_mounting.yaml.j2').read_text()
- template = jinja2.Template(template_str)
- # ubuntu 18.04 does not support fuse3, and blobfuse2 depends on fuse3.
- azure_mount_unsupported_ubuntu_version = '18.04'
- # Commands to verify bucket upload. We need to check all three
- # storage types because the optimizer may pick any of them.
- s3_command = f'aws s3 ls {storage_name}/hello.txt'
- gsutil_command = f'gsutil ls gs://{storage_name}/hello.txt'
- azure_blob_command = TestStorageWithCredentials.cli_ls_cmd(
- storage_lib.StoreType.AZURE, storage_name, suffix='hello.txt')
-
- # Set mount flags based on cloud provider
- include_s3_mount = generic_cloud in ['aws', 'kubernetes']
- include_gcs_mount = generic_cloud == 'gcp'
- include_azure_mount = generic_cloud == 'azure'
- include_private_mount = True # Default to True
-
- if azure_mount_unsupported_ubuntu_version in image_id:
- # The store for mount_private_mount is not specified in the template.
- # If we're running on Azure, the private mount will be created on
- # azure blob. That will not be supported on the ubuntu 18.04 image
- # and thus fail. For other clouds, the private mount on other
- # storage types (GCS/S3) should succeed.
-        include_private_mount = generic_cloud != 'azure'
- content = template.render(storage_name=storage_name,
- include_azure_mount=False,
- include_s3_mount=include_s3_mount,
- include_gcs_mount=include_gcs_mount,
- include_private_mount=include_private_mount)
- else:
- content = template.render(storage_name=storage_name,
- include_azure_mount=include_azure_mount,
- include_s3_mount=include_s3_mount,
- include_gcs_mount=include_gcs_mount,
- include_private_mount=include_private_mount)
-
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(content)
- f.flush()
- file_path = f.name
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- f'sky launch -y -c {name} --cloud {generic_cloud} --image-id {image_id} {file_path}',
- f'sky logs {name} 1 --status', # Ensure job succeeded.
- # Check AWS, GCP, or Azure storage mount.
- f'sky exec {name} -- "{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV}; {s3_command} || {gsutil_command} || {azure_blob_command}"',
- f'sky logs {name} 2 --status', # Ensure the bucket check succeeded.
- ]
- test = Test(
- 'docker_storage_mounts',
- test_commands,
- f'sky down -y {name}; sky storage delete -y {storage_name}',
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.cloudflare
-def test_cloudflare_storage_mounts(generic_cloud: str):
- name = _get_cluster_name()
- storage_name = f'sky-test-{int(time.time())}'
- template_str = pathlib.Path(
- 'tests/test_yamls/test_r2_storage_mounting.yaml').read_text()
- template = jinja2.Template(template_str)
- content = template.render(storage_name=storage_name)
- endpoint_url = cloudflare.create_endpoint()
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(content)
- f.flush()
- file_path = f.name
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- f'sky launch -y -c {name} --cloud {generic_cloud} {file_path}',
- f'sky logs {name} 1 --status', # Ensure job succeeded.
- f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls s3://{storage_name}/hello.txt --endpoint {endpoint_url} --profile=r2'
- ]
-
- test = Test(
- 'cloudflare_storage_mounts',
- test_commands,
- f'sky down -y {name}; sky storage delete -y {storage_name}',
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-@pytest.mark.ibm
-def test_ibm_storage_mounts():
- name = _get_cluster_name()
- storage_name = f'sky-test-{int(time.time())}'
- bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
- storage_name, Rclone.RcloneClouds.IBM)
- template_str = pathlib.Path(
- 'tests/test_yamls/test_ibm_cos_storage_mounting.yaml').read_text()
- template = jinja2.Template(template_str)
- content = template.render(storage_name=storage_name)
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(content)
- f.flush()
- file_path = f.name
- test_commands = [
- *STORAGE_SETUP_COMMANDS,
- f'sky launch -y -c {name} --cloud ibm {file_path}',
- f'sky logs {name} 1 --status', # Ensure job succeeded.
- f'rclone ls {bucket_rclone_profile}:{storage_name}/hello.txt',
- ]
- test = Test(
- 'ibm_storage_mounts',
- test_commands,
- f'sky down -y {name}; sky storage delete -y {storage_name}',
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-# ---------- CLI logs ----------
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet. Run test_scp_logs instead.
-def test_cli_logs(generic_cloud: str):
- name = _get_cluster_name()
- num_nodes = 2
- if generic_cloud == 'kubernetes':
- # Kubernetes does not support multi-node
- num_nodes = 1
- timestamp = time.time()
- test = Test('cli_logs', [
- f'sky launch -y -c {name} --cloud {generic_cloud} --num-nodes {num_nodes} "echo {timestamp} 1"',
- f'sky exec {name} "echo {timestamp} 2"',
- f'sky exec {name} "echo {timestamp} 3"',
- f'sky exec {name} "echo {timestamp} 4"',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 4 --sync-down',
- f'sky logs {name} * --sync-down',
- f'sky logs {name} 1 | grep "{timestamp} 1"',
- f'sky logs {name} | grep "{timestamp} 4"',
- ], f'sky down -y {name}')
- run_one_test(test)
-
-
-@pytest.mark.scp
-def test_scp_logs():
- name = _get_cluster_name()
- timestamp = time.time()
- test = Test(
- 'SCP_cli_logs',
- [
- f'sky launch -y -c {name} {SCP_TYPE} "echo {timestamp} 1"',
- f'sky exec {name} "echo {timestamp} 2"',
- f'sky exec {name} "echo {timestamp} 3"',
- f'sky exec {name} "echo {timestamp} 4"',
- f'sky logs {name} 2 --status',
- f'sky logs {name} 3 4 --sync-down',
- f'sky logs {name} * --sync-down',
- f'sky logs {name} 1 | grep "{timestamp} 1"',
- f'sky logs {name} | grep "{timestamp} 4"',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Job Queue. ----------
-@pytest.mark.no_fluidstack # FluidStack DC has low availability of T4 GPUs
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not have T4 gpus
-@pytest.mark.no_ibm # IBM Cloud does not have T4 gpus. run test_ibm_job_queue instead
-@pytest.mark.no_scp # SCP does not have T4 gpus. Run test_scp_job_queue instead
-@pytest.mark.no_paperspace # Paperspace does not have T4 gpus.
-@pytest.mark.no_oci # OCI does not have T4 gpus
-def test_job_queue(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'job_queue',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} examples/job_queue/cluster.yaml',
- f'sky exec {name} -n {name}-1 -d examples/job_queue/job.yaml',
- f'sky exec {name} -n {name}-2 -d examples/job_queue/job.yaml',
- f'sky exec {name} -n {name}-3 -d examples/job_queue/job.yaml',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep RUNNING',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep RUNNING',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep PENDING',
- f'sky cancel -y {name} 2',
- 'sleep 5',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep RUNNING',
- f'sky cancel -y {name} 3',
- f'sky exec {name} --gpus T4:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky exec {name} --gpus T4:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky logs {name} 4 --status',
- f'sky logs {name} 5 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Job Queue with Docker. ----------
-@pytest.mark.no_fluidstack # FluidStack does not support docker for now
-@pytest.mark.no_lambda_cloud # Doesn't support Lambda Cloud for now
-@pytest.mark.no_ibm # Doesn't support IBM Cloud for now
-@pytest.mark.no_paperspace # Paperspace doesn't have T4 GPUs
-@pytest.mark.no_scp # Doesn't support SCP for now
-@pytest.mark.no_oci # Doesn't support OCI for now
-@pytest.mark.no_kubernetes # Doesn't support Kubernetes for now
-@pytest.mark.parametrize(
- 'image_id',
- [
- 'docker:nvidia/cuda:11.8.0-devel-ubuntu18.04',
- 'docker:ubuntu:18.04',
-        # Test image with python 3.11 installed by default.
- 'docker:continuumio/miniconda3:24.1.2-0',
-        # Test python>=3.12, where SkyPilot should automatically create a
-        # separate conda env with python 3.10 for its runtime.
- 'docker:continuumio/miniconda3:latest',
-        # The Axolotl image is a good example of a custom image that has its
-        # conda path added to PATH in its Dockerfile and uses python>=3.12.
-        # It tests that:
-        # 1. we handle the env var set in the Dockerfile correctly;
-        # 2. python>=3.12 works with the SkyPilot runtime.
- 'docker:winglian/axolotl:main-latest'
- ])
-def test_job_queue_with_docker(generic_cloud: str, image_id: str):
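-    # Append the first few characters of the image tag so each parametrized
-    # image gets a distinct cluster name.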
- name = _get_cluster_name() + image_id[len('docker:'):][:4]
- total_timeout_minutes = 40 if generic_cloud == 'azure' else 15
- time_to_sleep = 300 if generic_cloud == 'azure' else 180
- test = Test(
- 'job_queue_with_docker',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} --image-id {image_id} examples/job_queue/cluster_docker.yaml',
- f'sky exec {name} -n {name}-1 -d --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml',
- f'sky exec {name} -n {name}-2 -d --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml',
- f'sky exec {name} -n {name}-3 -d --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep RUNNING',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep RUNNING',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep PENDING',
- f'sky cancel -y {name} 2',
- 'sleep 5',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep RUNNING',
- f'sky cancel -y {name} 3',
- # Make sure the GPU is still visible to the container.
- f'sky exec {name} --image-id {image_id} nvidia-smi | grep "Tesla T4"',
- f'sky logs {name} 4 --status',
- f'sky stop -y {name}',
-            # Make sure the job status is preserved after stopping and
-            # starting the cluster. This also tests that the docker container
-            # is preserved across the stop/start cycle.
- f'sky start -y {name}',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep FAILED',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep CANCELLED',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep CANCELLED',
- f'sky exec {name} --gpus T4:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky exec {name} --gpus T4:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
-            # Make sure it is still visible after a stop & start cycle.
- f'sky exec {name} --image-id {image_id} nvidia-smi | grep "Tesla T4"',
- f'sky logs {name} 7 --status'
- ],
- f'sky down -y {name}',
- timeout=total_timeout_minutes * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.lambda_cloud
-def test_lambda_job_queue():
- name = _get_cluster_name()
- test = Test(
- 'lambda_job_queue',
- [
- f'sky launch -y -c {name} {LAMBDA_TYPE} examples/job_queue/cluster.yaml',
- f'sky exec {name} -n {name}-1 --gpus A10:0.5 -d examples/job_queue/job.yaml',
- f'sky exec {name} -n {name}-2 --gpus A10:0.5 -d examples/job_queue/job.yaml',
- f'sky exec {name} -n {name}-3 --gpus A10:0.5 -d examples/job_queue/job.yaml',
- f'sky queue {name} | grep {name}-1 | grep RUNNING',
- f'sky queue {name} | grep {name}-2 | grep RUNNING',
- f'sky queue {name} | grep {name}-3 | grep PENDING',
- f'sky cancel -y {name} 2',
- 'sleep 5',
- f'sky queue {name} | grep {name}-3 | grep RUNNING',
- f'sky cancel -y {name} 3',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.ibm
-def test_ibm_job_queue():
- name = _get_cluster_name()
- test = Test(
- 'ibm_job_queue',
- [
- f'sky launch -y -c {name} --cloud ibm --gpus v100',
- f'sky exec {name} -n {name}-1 --cloud ibm -d examples/job_queue/job_ibm.yaml',
- f'sky exec {name} -n {name}-2 --cloud ibm -d examples/job_queue/job_ibm.yaml',
- f'sky exec {name} -n {name}-3 --cloud ibm -d examples/job_queue/job_ibm.yaml',
- f'sky queue {name} | grep {name}-1 | grep RUNNING',
- f'sky queue {name} | grep {name}-2 | grep RUNNING',
- f'sky queue {name} | grep {name}-3 | grep PENDING',
- f'sky cancel -y {name} 2',
- 'sleep 5',
- f'sky queue {name} | grep {name}-3 | grep RUNNING',
- f'sky cancel -y {name} 3',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.scp
-def test_scp_job_queue():
- name = _get_cluster_name()
- num_of_gpu_launch = 1
- num_of_gpu_exec = 0.5
- test = Test(
- 'SCP_job_queue',
- [
- f'sky launch -y -c {name} {SCP_TYPE} {SCP_GPU_V100}:{num_of_gpu_launch} examples/job_queue/cluster.yaml',
- f'sky exec {name} -n {name}-1 {SCP_GPU_V100}:{num_of_gpu_exec} -d examples/job_queue/job.yaml',
- f'sky exec {name} -n {name}-2 {SCP_GPU_V100}:{num_of_gpu_exec} -d examples/job_queue/job.yaml',
- f'sky exec {name} -n {name}-3 {SCP_GPU_V100}:{num_of_gpu_exec} -d examples/job_queue/job.yaml',
- f'sky queue {name} | grep {name}-1 | grep RUNNING',
- f'sky queue {name} | grep {name}-2 | grep RUNNING',
- f'sky queue {name} | grep {name}-3 | grep PENDING',
- f'sky cancel -y {name} 2',
- 'sleep 5',
- f'sky queue {name} | grep {name}-3 | grep RUNNING',
- f'sky cancel -y {name} 3',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # FluidStack DC has low availability of T4 GPUs
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not have T4 gpus
-@pytest.mark.no_ibm # IBM Cloud does not have T4 gpus. run test_ibm_job_queue_multinode instead
-@pytest.mark.no_paperspace # Paperspace does not have T4 gpus.
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-@pytest.mark.no_oci # OCI Cloud does not have T4 gpus.
-@pytest.mark.no_kubernetes  # Kubernetes does not support num_nodes > 1 yet
-def test_job_queue_multinode(generic_cloud: str):
- name = _get_cluster_name()
- total_timeout_minutes = 30 if generic_cloud == 'azure' else 15
- test = Test(
- 'job_queue_multinode',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} examples/job_queue/cluster_multinode.yaml',
- f'sky exec {name} -n {name}-1 -d examples/job_queue/job_multinode.yaml',
- f'sky exec {name} -n {name}-2 -d examples/job_queue/job_multinode.yaml',
- f'sky launch -c {name} -n {name}-3 --detach-setup -d examples/job_queue/job_multinode.yaml',
- f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-1 | grep RUNNING)',
- f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-2 | grep RUNNING)',
- f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-3 | grep PENDING)',
- 'sleep 90',
- f'sky cancel -y {name} 1',
- 'sleep 5',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep SETTING_UP',
- f'sky cancel -y {name} 1 2 3',
- f'sky launch -c {name} -n {name}-4 --detach-setup -d examples/job_queue/job_multinode.yaml',
-            # Test that the job status is correctly set to SETTING_UP while
-            # the setup is running, and that the job can be cancelled during
-            # the setup.
- 'sleep 5',
- f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-4 | grep SETTING_UP)',
- f'sky cancel -y {name} 4',
- f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-4 | grep CANCELLED)',
- f'sky exec {name} --gpus T4:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky exec {name} --gpus T4:0.2 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky exec {name} --gpus T4:1 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
- f'sky logs {name} 7 --status',
- ],
- f'sky down -y {name}',
- timeout=total_timeout_minutes * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # No FluidStack VM has 8 CPUs
-@pytest.mark.no_lambda_cloud # No Lambda Cloud VM has 8 CPUs
-def test_large_job_queue(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'large_job_queue',
- [
- f'sky launch -y -c {name} --cpus 8 --cloud {generic_cloud}',
- f'for i in `seq 1 75`; do sky exec {name} -n {name}-$i -d "echo $i; sleep 100000000"; done',
- f'sky cancel -y {name} 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16',
- 'sleep 90',
-
-            # Each job takes 0.5 CPU and the default VM has 8 CPUs, so there should be 8 / 0.5 = 16 jobs running.
-            # The first 16 jobs are canceled and another 16 are running, so there should be 75 - (16 + 16) = 43 jobs PENDING.
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep PENDING | wc -l | grep 43',
- # Make sure the jobs are scheduled in FIFO order
- *[
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep CANCELLED'
- for i in range(1, 17)
- ],
- *[
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep RUNNING'
- for i in range(17, 33)
- ],
- *[
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep PENDING'
- for i in range(33, 75)
- ],
- f'sky cancel -y {name} 33 35 37 39 17 18 19',
- *[
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep CANCELLED'
- for i in range(33, 40, 2)
- ],
- 'sleep 10',
- *[
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-{i} | grep RUNNING'
- for i in [34, 36, 38]
- ],
- ],
- f'sky down -y {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # No FluidStack VM has 8 CPUs
-@pytest.mark.no_lambda_cloud # No Lambda Cloud VM has 8 CPUs
-def test_fast_large_job_queue(generic_cloud: str):
-    # This is to test that jobs can be scheduled quickly when there are many jobs in the queue.
- name = _get_cluster_name()
- test = Test(
- 'fast_large_job_queue',
- [
- f'sky launch -y -c {name} --cpus 8 --cloud {generic_cloud}',
- f'for i in `seq 1 32`; do sky exec {name} -n {name}-$i -d "echo $i"; done',
- 'sleep 60',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep SUCCEEDED | wc -l | grep 32',
- ],
- f'sky down -y {name}',
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.ibm
-def test_ibm_job_queue_multinode():
- name = _get_cluster_name()
- task_file = 'examples/job_queue/job_multinode_ibm.yaml'
- test = Test(
- 'ibm_job_queue_multinode',
- [
- f'sky launch -y -c {name} --cloud ibm --gpus v100 --num-nodes 2',
- f'sky exec {name} -n {name}-1 -d {task_file}',
- f'sky exec {name} -n {name}-2 -d {task_file}',
- f'sky launch -y -c {name} -n {name}-3 --detach-setup -d {task_file}',
- f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-1 | grep RUNNING)',
- f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-2 | grep RUNNING)',
- f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-3 | grep SETTING_UP)',
- 'sleep 90',
- f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-3 | grep PENDING)',
- f'sky cancel -y {name} 1',
- 'sleep 5',
- f'sky queue {name} | grep {name}-3 | grep RUNNING',
- f'sky cancel -y {name} 1 2 3',
- f'sky launch -c {name} -n {name}-4 --detach-setup -d {task_file}',
-            # Test that the job status is correctly set to SETTING_UP while
-            # the setup is running, and that the job can be cancelled during
-            # the setup.
- f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-4 | grep SETTING_UP)',
- f'sky cancel -y {name} 4',
- f's=$(sky queue {name}) && printf "$s" && (echo "$s" | grep {name}-4 | grep CANCELLED)',
- f'sky exec {name} --gpus v100:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky exec {name} --gpus v100:0.2 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky exec {name} --gpus v100:1 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
- f'sky logs {name} 5 --status',
- f'sky logs {name} 6 --status',
- f'sky logs {name} 7 --status',
- ],
- f'sky down -y {name}',
- timeout=20 * 60, # 20 mins
- )
- run_one_test(test)
-
-
-# ---------- Docker with preinstalled package. ----------
-@pytest.mark.no_fluidstack # Doesn't support Fluidstack for now
-@pytest.mark.no_lambda_cloud # Doesn't support Lambda Cloud for now
-@pytest.mark.no_ibm # Doesn't support IBM Cloud for now
-@pytest.mark.no_scp # Doesn't support SCP for now
-@pytest.mark.no_oci # Doesn't support OCI for now
-@pytest.mark.no_kubernetes # Doesn't support Kubernetes for now
-# TODO(zhwu): we should fix this for kubernetes
-def test_docker_preinstalled_package(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'docker_with_preinstalled_package',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} --image-id docker:nginx',
- f'sky exec {name} "nginx -V"',
- f'sky logs {name} 1 --status',
- f'sky exec {name} whoami | grep root',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Submitting multiple tasks to the same cluster. ----------
-@pytest.mark.no_fluidstack # FluidStack DC has low availability of T4 GPUs
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not have T4 gpus
-@pytest.mark.no_paperspace # Paperspace does not have T4 gpus
-@pytest.mark.no_ibm # IBM Cloud does not have T4 gpus
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-@pytest.mark.no_oci # OCI Cloud does not have T4 gpus
-def test_multi_echo(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'multi_echo',
- [
- f'python examples/multi_echo.py {name} {generic_cloud}',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
- 'sleep 10',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
- 'sleep 30',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
- 'sleep 30',
-            # Make sure that our job scheduler is fast enough to have at least
-            # 20 RUNNING jobs in parallel.
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "RUNNING" | wc -l | awk \'{{if ($1 < 20) exit 1}}\'',
- 'sleep 30',
- f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep "FAILED" && exit 1 || true',
- # This is to make sure we can finish job 32 before the test timeout.
- f'until sky logs {name} 32 --status; do echo "Waiting for job 32 to finish..."; sleep 1; done',
- ] +
- # Ensure jobs succeeded.
- [
- _get_cmd_wait_until_job_status_contains_matching_job_id(
- cluster_name=name,
- job_id=i + 1,
- job_status=[sky.JobStatus.SUCCEEDED],
- timeout=120) for i in range(32)
- ] +
- # Ensure monitor/autoscaler didn't crash on the 'assert not
- # unfulfilled' error. If process not found, grep->ssh returns 1.
- [f'ssh {name} \'ps aux | grep "[/]"monitor.py\''],
- f'sky down -y {name}',
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# ---------- Task: 1 node training. ----------
-@pytest.mark.no_fluidstack # Fluidstack does not have T4 gpus for now
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus
-@pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA
-@pytest.mark.no_scp # SCP does not have V100 (16GB) GPUs. Run test_scp_huggingface instead.
-def test_huggingface(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'huggingface_glue_imdb_app',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} examples/huggingface_glue_imdb_app.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} examples/huggingface_glue_imdb_app.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.lambda_cloud
-def test_lambda_huggingface(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'lambda_huggingface_glue_imdb_app',
- [
- f'sky launch -y -c {name} {LAMBDA_TYPE} examples/huggingface_glue_imdb_app.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} {LAMBDA_TYPE} examples/huggingface_glue_imdb_app.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.scp
-def test_scp_huggingface(generic_cloud: str):
- name = _get_cluster_name()
- num_of_gpu_launch = 1
- test = Test(
- 'SCP_huggingface_glue_imdb_app',
- [
- f'sky launch -y -c {name} {SCP_TYPE} {SCP_GPU_V100}:{num_of_gpu_launch} examples/huggingface_glue_imdb_app.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} {SCP_TYPE} {SCP_GPU_V100}:{num_of_gpu_launch} examples/huggingface_glue_imdb_app.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Inferentia. ----------
-@pytest.mark.aws
-def test_inferentia():
- name = _get_cluster_name()
- test = Test(
- 'test_inferentia',
- [
- f'sky launch -y -c {name} -t inf2.xlarge -- echo hi',
- f'sky exec {name} --gpus Inferentia:1 echo hi',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- TPU. ----------
-@pytest.mark.gcp
-@pytest.mark.tpu
-def test_tpu():
- name = _get_cluster_name()
- test = Test(
- 'tpu_app',
- [
- f'sky launch -y -c {name} examples/tpu/tpu_app.yaml',
- f'sky logs {name} 1', # Ensure the job finished.
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky launch -y -c {name} examples/tpu/tpu_app.yaml | grep "TPU .* already exists"', # Ensure sky launch won't create another TPU.
- ],
- f'sky down -y {name}',
- timeout=30 * 60, # can take >20 mins
- )
- run_one_test(test)
-
-
-# ---------- TPU VM. ----------
-@pytest.mark.gcp
-@pytest.mark.tpu
-def test_tpu_vm():
- name = _get_cluster_name()
- test = Test(
- 'tpu_vm_app',
- [
- f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml',
- f'sky logs {name} 1', # Ensure the job finished.
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky stop -y {name}',
- f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', # Ensure the cluster is STOPPED.
- # Use retry: guard against transient errors observed for
- # just-stopped TPU VMs (#962).
- f'sky start --retry-until-up -y {name}',
- f'sky exec {name} examples/tpu/tpuvm_mnist.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- f'sky stop -y {name}',
- ],
- f'sky down -y {name}',
- timeout=30 * 60, # can take 30 mins
- )
- run_one_test(test)
-
-
-# ---------- TPU VM Pod. ----------
-@pytest.mark.gcp
-@pytest.mark.tpu
-def test_tpu_vm_pod():
- name = _get_cluster_name()
- test = Test(
- 'tpu_pod',
- [
- f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --gpus tpu-v2-32 --use-spot --zone europe-west4-a',
- f'sky logs {name} 1', # Ensure the job finished.
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- timeout=30 * 60, # can take 30 mins
- )
- run_one_test(test)
-
-
-# ---------- TPU Pod Slice on GKE. ----------
-@pytest.mark.tpu
-@pytest.mark.kubernetes
-def test_tpu_pod_slice_gke():
- name = _get_cluster_name()
- test = Test(
- 'tpu_pod_slice_gke',
- [
- f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --cloud kubernetes --gpus tpu-v5-lite-podslice',
- f'sky logs {name} 1', # Ensure the job finished.
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} "conda activate flax; python -c \'import jax; print(jax.devices()[0].platform);\' | grep tpu || exit 1;"', # Ensure TPU is reachable.
- f'sky logs {name} 2 --status'
- ],
- f'sky down -y {name}',
- timeout=30 * 60, # can take 30 mins
- )
- run_one_test(test)
-
-
-# ---------- Simple apps. ----------
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-def test_multi_hostname(generic_cloud: str):
- name = _get_cluster_name()
- total_timeout_minutes = 25 if generic_cloud == 'azure' else 15
- test = Test(
- 'multi_hostname',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} examples/multi_hostname.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky logs {name} 1 | grep "My hostname:" | wc -l | grep 2', # Ensure there are 2 hosts.
- f'sky exec {name} examples/multi_hostname.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- timeout=_get_timeout(generic_cloud, total_timeout_minutes * 60),
- )
- run_one_test(test)
-
-
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-def test_multi_node_failure(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'multi_node_failure',
- [
-            # TODO(zhwu): we use multi-threading to run the setup commands in
-            # parallel, which makes it impossible to fail fast when one of the
-            # nodes fails. We should fix this in the future. The --detach-setup
-            # version can fail fast, as the setup is submitted to the remote
-            # machine, which does not use multi-threading. Refer to the comment
-            # in `subprocess_utils.run_in_parallel`.
- # f'sky launch -y -c {name} --cloud {generic_cloud} tests/test_yamls/failed_worker_setup.yaml && exit 1', # Ensure the job setup failed.
- f'sky launch -y -c {name} --cloud {generic_cloud} --detach-setup tests/test_yamls/failed_worker_setup.yaml',
- f'sky logs {name} 1 --status | grep FAILED_SETUP', # Ensure the job setup failed.
- f'sky exec {name} tests/test_yamls/failed_worker_run.yaml',
- f'sky logs {name} 2 --status | grep FAILED', # Ensure the job failed.
- f'sky logs {name} 2 | grep "My hostname:" | wc -l | grep 2', # Ensure there 2 of the hosts printed their hostname.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Web apps with custom ports on GCP. ----------
-@pytest.mark.gcp
-def test_gcp_http_server_with_custom_ports():
- name = _get_cluster_name()
- test = Test(
- 'gcp_http_server_with_custom_ports',
- [
- f'sky launch -y -d -c {name} --cloud gcp examples/http_server_with_custom_ports/task.yaml',
- f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
- # Retry a few times to avoid flakiness in ports being open.
- f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Web apps with custom ports on AWS. ----------
-@pytest.mark.aws
-def test_aws_http_server_with_custom_ports():
- name = _get_cluster_name()
- test = Test(
- 'aws_http_server_with_custom_ports',
- [
- f'sky launch -y -d -c {name} --cloud aws examples/http_server_with_custom_ports/task.yaml',
- f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
- # Retry a few times to avoid flakiness in ports being open.
- f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi'
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Web apps with custom ports on Azure. ----------
-@pytest.mark.azure
-def test_azure_http_server_with_custom_ports():
- name = _get_cluster_name()
- test = Test(
- 'azure_http_server_with_custom_ports',
- [
- f'sky launch -y -d -c {name} --cloud azure examples/http_server_with_custom_ports/task.yaml',
- f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
- # Retry a few times to avoid flakiness in ports being open.
- f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi'
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Web apps with custom ports on Kubernetes. ----------
-@pytest.mark.kubernetes
-def test_kubernetes_http_server_with_custom_ports():
- name = _get_cluster_name()
- test = Test(
- 'kubernetes_http_server_with_custom_ports',
- [
- f'sky launch -y -d -c {name} --cloud kubernetes examples/http_server_with_custom_ports/task.yaml',
- f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
- # Retry a few times to avoid flakiness in ports being open.
- f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 100); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 5; done; if [ "$success" = false ]; then exit 1; fi'
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Web apps with custom ports on Paperspace. ----------
-@pytest.mark.paperspace
-def test_paperspace_http_server_with_custom_ports():
- name = _get_cluster_name()
- test = Test(
- 'paperspace_http_server_with_custom_ports',
- [
- f'sky launch -y -d -c {name} --cloud paperspace examples/http_server_with_custom_ports/task.yaml',
- f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
- # Retry a few times to avoid flakiness in ports being open.
- f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Web apps with custom ports on RunPod. ----------
-@pytest.mark.runpod
-def test_runpod_http_server_with_custom_ports():
- name = _get_cluster_name()
- test = Test(
- 'runpod_http_server_with_custom_ports',
- [
- f'sky launch -y -d -c {name} --cloud runpod examples/http_server_with_custom_ports/task.yaml',
- f'until SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}; do sleep 10; done',
- # Retry a few times to avoid flakiness in ports being open.
- f'ip=$(SKYPILOT_DEBUG=0 sky status --endpoint 33828 {name}); success=false; for i in $(seq 1 5); do if curl $ip | grep "This is a demo HTML page. "; then success=true; break; fi; sleep 10; done; if [ "$success" = false ]; then exit 1; fi',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Labels from task on AWS (instance_tags) ----------
-@pytest.mark.aws
-def test_task_labels_aws():
- name = _get_cluster_name()
- template_str = pathlib.Path(
- 'tests/test_yamls/test_labels.yaml.j2').read_text()
- template = jinja2.Template(template_str)
- content = template.render(cloud='aws', region='us-east-1')
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(content)
- f.flush()
- file_path = f.name
- test = Test(
- 'task_labels_aws',
- [
- f'sky launch -y -c {name} {file_path}',
- # Verify with aws cli that the tags are set.
-            'aws ec2 describe-instances '
-            '--query "Reservations[*].Instances[*].InstanceId" '
-            # All filters must go into a single --filters option; repeating
-            # the option keeps only the last filter.
-            '--filters "Name=instance-state-name,Values=running" '
-            f'"Name=tag:skypilot-cluster-name,Values={name}*" '
-            '"Name=tag:inlinelabel1,Values=inlinevalue1" '
-            '"Name=tag:inlinelabel2,Values=inlinevalue2" '
-            '--region us-east-1 --output text',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Labels from task on GCP (labels) ----------
-@pytest.mark.gcp
-def test_task_labels_gcp():
- name = _get_cluster_name()
- template_str = pathlib.Path(
- 'tests/test_yamls/test_labels.yaml.j2').read_text()
- template = jinja2.Template(template_str)
- content = template.render(cloud='gcp')
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(content)
- f.flush()
- file_path = f.name
- test = Test(
- 'task_labels_gcp',
- [
- f'sky launch -y -c {name} {file_path}',
-            # Verify with the gcloud CLI that the labels are set.
- f'gcloud compute instances list --filter="name~\'^{name}\' AND '
- 'labels.inlinelabel1=\'inlinevalue1\' AND '
- 'labels.inlinelabel2=\'inlinevalue2\'" '
- '--format="value(name)" | grep .',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Labels from task on Kubernetes (labels) ----------
-@pytest.mark.kubernetes
-def test_task_labels_kubernetes():
- name = _get_cluster_name()
- template_str = pathlib.Path(
- 'tests/test_yamls/test_labels.yaml.j2').read_text()
- template = jinja2.Template(template_str)
- content = template.render(cloud='kubernetes')
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(content)
- f.flush()
- file_path = f.name
- test = Test(
- 'task_labels_kubernetes',
- [
- f'sky launch -y -c {name} {file_path}',
- # Verify with kubectl that the labels are set.
- 'kubectl get pods '
- '--selector inlinelabel1=inlinevalue1 '
- '--selector inlinelabel2=inlinevalue2 '
- '-o jsonpath=\'{.items[*].metadata.name}\' | '
- f'grep \'^{name}\''
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
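
The three task-label tests above share the same render-a-Jinja-template-into-a-temporary-YAML pattern. A minimal sketch of a hypothetical context manager capturing that pattern (`_render_labels_yaml` does not exist in this file; imports are repeated so the sketch is self-contained):

import contextlib
import pathlib
import tempfile

import jinja2


@contextlib.contextmanager
def _render_labels_yaml(cloud: str, **render_kwargs):
    template_str = pathlib.Path(
        'tests/test_yamls/test_labels.yaml.j2').read_text()
    content = jinja2.Template(template_str).render(cloud=cloud, **render_kwargs)
    with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
        f.write(content)
        f.flush()
        # The rendered file is only valid while the context manager is open.
        yield f.name

A test could then use `with _render_labels_yaml('aws', region='us-east-1') as file_path:` and launch `file_path` inside that block.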
-
-
-# ---------- Pod Annotations on Kubernetes ----------
-@pytest.mark.kubernetes
-def test_add_pod_annotations_for_autodown_with_launch():
- name = _get_cluster_name()
- test = Test(
- 'add_pod_annotations_for_autodown_with_launch',
- [
-            # Launch a Kubernetes cluster with two nodes: one head node and one worker node.
- # Autodown is set.
- f'sky launch -y -c {name} -i 10 --down --num-nodes 2 --cpus=1 --cloud kubernetes',
- # Get names of the pods containing cluster name.
- f'pod_1=$(kubectl get pods -o name | grep {name} | sed -n 1p)',
- f'pod_2=$(kubectl get pods -o name | grep {name} | sed -n 2p)',
- # Describe the first pod and check for annotations.
- 'kubectl describe pod $pod_1 | grep -q skypilot.co/autodown',
- 'kubectl describe pod $pod_1 | grep -q skypilot.co/idle_minutes_to_autostop',
- # Describe the second pod and check for annotations.
- 'kubectl describe pod $pod_2 | grep -q skypilot.co/autodown',
- 'kubectl describe pod $pod_2 | grep -q skypilot.co/idle_minutes_to_autostop'
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-def test_add_and_remove_pod_annotations_with_autostop():
- name = _get_cluster_name()
- test = Test(
- 'add_and_remove_pod_annotations_with_autostop',
- [
-            # Launch a Kubernetes cluster with two nodes: one head node and one worker node.
- f'sky launch -y -c {name} --num-nodes 2 --cpus=1 --cloud kubernetes',
-            # Set autodown on the cluster with the 'autostop' command.
- f'sky autostop -y {name} -i 20 --down',
- # Get names of the pods containing cluster name.
- f'pod_1=$(kubectl get pods -o name | grep {name} | sed -n 1p)',
- f'pod_2=$(kubectl get pods -o name | grep {name} | sed -n 2p)',
- # Describe the first pod and check for annotations.
- 'kubectl describe pod $pod_1 | grep -q skypilot.co/autodown',
- 'kubectl describe pod $pod_1 | grep -q skypilot.co/idle_minutes_to_autostop',
- # Describe the second pod and check for annotations.
- 'kubectl describe pod $pod_2 | grep -q skypilot.co/autodown',
- 'kubectl describe pod $pod_2 | grep -q skypilot.co/idle_minutes_to_autostop',
- # Cancel the set autodown to remove the annotations from the pods.
- f'sky autostop -y {name} --cancel',
- # Describe the first pod and check if annotations are removed.
- '! kubectl describe pod $pod_1 | grep -q skypilot.co/autodown',
- '! kubectl describe pod $pod_1 | grep -q skypilot.co/idle_minutes_to_autostop',
- # Describe the second pod and check if annotations are removed.
- '! kubectl describe pod $pod_2 | grep -q skypilot.co/autodown',
- '! kubectl describe pod $pod_2 | grep -q skypilot.co/idle_minutes_to_autostop',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Container logs from task on Kubernetes ----------
-@pytest.mark.kubernetes
-def test_container_logs_multinode_kubernetes():
- name = _get_cluster_name()
- task_yaml = 'tests/test_yamls/test_k8s_logs.yaml'
- head_logs = ('kubectl get pods '
- f' | grep {name} | grep head | '
- " awk '{print $1}' | xargs -I {} kubectl logs {}")
- worker_logs = ('kubectl get pods '
- f' | grep {name} | grep worker |'
- " awk '{print $1}' | xargs -I {} kubectl logs {}")
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test = Test(
- 'container_logs_multinode_kubernetes',
- [
- f'sky launch -y -c {name} {task_yaml} --num-nodes 2',
- f'{head_logs} | wc -l | grep 9',
- f'{worker_logs} | wc -l | grep 9',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-def test_container_logs_two_jobs_kubernetes():
- name = _get_cluster_name()
- task_yaml = 'tests/test_yamls/test_k8s_logs.yaml'
- pod_logs = ('kubectl get pods '
- f' | grep {name} | grep head |'
- " awk '{print $1}' | xargs -I {} kubectl logs {}")
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test = Test(
- 'test_container_logs_two_jobs_kubernetes',
- [
- f'sky launch -y -c {name} {task_yaml}',
- f'{pod_logs} | wc -l | grep 9',
- f'sky launch -y -c {name} {task_yaml}',
- f'{pod_logs} | wc -l | grep 18',
- f'{pod_logs} | grep 1 | wc -l | grep 2',
- f'{pod_logs} | grep 2 | wc -l | grep 2',
- f'{pod_logs} | grep 3 | wc -l | grep 2',
- f'{pod_logs} | grep 4 | wc -l | grep 2',
- f'{pod_logs} | grep 5 | wc -l | grep 2',
- f'{pod_logs} | grep 6 | wc -l | grep 2',
- f'{pod_logs} | grep 7 | wc -l | grep 2',
- f'{pod_logs} | grep 8 | wc -l | grep 2',
- f'{pod_logs} | grep 9 | wc -l | grep 2',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-def test_container_logs_two_simultaneous_jobs_kubernetes():
- name = _get_cluster_name()
-    task_yaml = 'tests/test_yamls/test_k8s_logs.yaml'
- pod_logs = ('kubectl get pods '
- f' | grep {name} | grep head |'
- " awk '{print $1}' | xargs -I {} kubectl logs {}")
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- test = Test(
- 'test_container_logs_two_simultaneous_jobs_kubernetes',
- [
- f'sky launch -y -c {name}',
- f'sky exec -c {name} -d {task_yaml}',
- f'sky exec -c {name} -d {task_yaml}',
- 'sleep 30',
- f'{pod_logs} | wc -l | grep 18',
- f'{pod_logs} | grep 1 | wc -l | grep 2',
- f'{pod_logs} | grep 2 | wc -l | grep 2',
- f'{pod_logs} | grep 3 | wc -l | grep 2',
- f'{pod_logs} | grep 4 | wc -l | grep 2',
- f'{pod_logs} | grep 5 | wc -l | grep 2',
- f'{pod_logs} | grep 6 | wc -l | grep 2',
- f'{pod_logs} | grep 7 | wc -l | grep 2',
- f'{pod_logs} | grep 8 | wc -l | grep 2',
- f'{pod_logs} | grep 9 | wc -l | grep 2',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Task: n=2 nodes with setups. ----------
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus
-@pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-@pytest.mark.skip(
- reason=
-    'The resnet_distributed_tf_app is flaky because it fails to detect GPUs.')
-def test_distributed_tf(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'resnet_distributed_tf_app',
- [
- # NOTE: running it twice will hang (sometimes?) - an app-level bug.
- f'python examples/resnet_distributed_tf_app.py {name} {generic_cloud}',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- timeout=25 * 60, # 25 mins (it takes around ~19 mins)
- )
- run_one_test(test)
-
-
-# ---------- Testing GCP start and stop instances ----------
-@pytest.mark.gcp
-def test_gcp_start_stop():
- name = _get_cluster_name()
- test = Test(
- 'gcp-start-stop',
- [
- f'sky launch -y -c {name} examples/gcp_start_stop.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} examples/gcp_start_stop.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit.
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- f'sky stop -y {name}',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=40),
- f'sky start -y {name} -i 1',
- f'sky exec {name} examples/gcp_start_stop.yaml',
- f'sky logs {name} 4 --status', # Ensure the job succeeded.
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[
- sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT
- ],
- timeout=200),
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Testing Azure start and stop instances ----------
-@pytest.mark.azure
-def test_azure_start_stop():
- name = _get_cluster_name()
- test = Test(
- 'azure-start-stop',
- [
- f'sky launch -y -c {name} examples/azure_start_stop.yaml',
- f'sky exec {name} examples/azure_start_stop.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit.
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- f'sky stop -y {name}',
- f'sky start -y {name} -i 1',
- f'sky exec {name} examples/azure_start_stop.yaml',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[
- sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT
- ],
- timeout=280) +
- f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}',
- ],
- f'sky down -y {name}',
- timeout=30 * 60, # 30 mins
- )
- run_one_test(test)
-
-
-# ---------- Testing Autostopping ----------
-@pytest.mark.no_fluidstack # FluidStack does not support stopping in the SkyPilot implementation
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support stopping instances
-@pytest.mark.no_ibm # FIX(IBM) sporadically fails, as restarted workers stay uninitialized indefinitely
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-@pytest.mark.no_kubernetes # Kubernetes does not autostop yet
-def test_autostop(generic_cloud: str):
- name = _get_cluster_name()
- # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure
- # the VM is stopped.
- autostop_timeout = 600 if generic_cloud == 'azure' else 250
-    # Launching and starting Azure clusters can take a long time too, e.g.,
-    # restarting a stopped Azure cluster can take 7m. So we set the total
-    # timeout to 70m.
- total_timeout_minutes = 70 if generic_cloud == 'azure' else 20
- test = Test(
- 'autostop',
- [
- f'sky launch -y -d -c {name} --num-nodes 2 --cloud {generic_cloud} tests/test_yamls/minimal.yaml',
- f'sky autostop -y {name} -i 1',
-
- # Ensure autostop is set.
- f'sky status | grep {name} | grep "1m"',
-
- # Ensure the cluster is not stopped early.
- 'sleep 40',
- f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP',
-
- # Ensure the cluster is STOPPED.
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=autostop_timeout),
-
- # Ensure the cluster is UP and the autostop setting is reset ('-').
- f'sky start -y {name}',
- f'sky status | grep {name} | grep -E "UP\s+-"',
-
- # Ensure the job succeeded.
- f'sky exec {name} tests/test_yamls/minimal.yaml',
- f'sky logs {name} 2 --status',
-
- # Test restarting the idleness timer via reset:
- f'sky autostop -y {name} -i 1', # Idleness starts counting.
- 'sleep 40', # Almost reached the threshold.
- f'sky autostop -y {name} -i 1', # Should restart the timer.
- 'sleep 40',
- f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=autostop_timeout),
-
- # Test restarting the idleness timer via exec:
- f'sky start -y {name}',
- f'sky status | grep {name} | grep -E "UP\s+-"',
- f'sky autostop -y {name} -i 1', # Idleness starts counting.
- 'sleep 45', # Almost reached the threshold.
- f'sky exec {name} echo hi', # Should restart the timer.
- 'sleep 45',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=autostop_timeout + _BUMP_UP_SECONDS),
- ],
- f'sky down -y {name}',
- timeout=total_timeout_minutes * 60,
- )
- run_one_test(test)
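
test_autostop and the tests below poll cluster state through _get_cmd_wait_until_cluster_status_contains, which is defined elsewhere in this file and not shown in this hunk. A minimal sketch of what such a polling-command generator could look like, assuming ClusterStatus values render to the strings printed by `sky status` (the name and details are illustrative, not the actual helper):

def _sketch_wait_until_cluster_status_contains(cluster_name, cluster_status,
                                               timeout):
    statuses = '|'.join(status.value for status in cluster_status)
    return ('start_time=$SECONDS; '
            f'until sky status {cluster_name} --refresh | '
            f'grep {cluster_name} | grep -E "{statuses}"; do '
            f'if [ $((SECONDS - start_time)) -gt {timeout} ]; then '
            f'echo "Timed out waiting for {cluster_name}"; exit 1; fi; '
            'sleep 10; done')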
-
-
-# ---------- Testing Autodowning ----------
-@pytest.mark.no_fluidstack # FluidStack does not support stopping in the SkyPilot implementation
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet. Run test_scp_autodown instead.
-def test_autodown(generic_cloud: str):
- name = _get_cluster_name()
- # Azure takes ~ 13m30s (810s) to autodown a VM, so here we use 900 to ensure
- # the VM is terminated.
- autodown_timeout = 900 if generic_cloud == 'azure' else 240
- total_timeout_minutes = 90 if generic_cloud == 'azure' else 20
- test = Test(
- 'autodown',
- [
- f'sky launch -y -d -c {name} --num-nodes 2 --cloud {generic_cloud} tests/test_yamls/minimal.yaml',
- f'sky autostop -y {name} --down -i 1',
- # Ensure autostop is set.
- f'sky status | grep {name} | grep "1m (down)"',
- # Ensure the cluster is not terminated early.
- 'sleep 40',
- f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP',
- # Ensure the cluster is terminated.
- f'sleep {autodown_timeout}',
- f's=$(SKYPILOT_DEBUG=0 sky status {name} --refresh) && echo "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
- f'sky launch -y -d -c {name} --cloud {generic_cloud} --num-nodes 2 --down tests/test_yamls/minimal.yaml',
- f'sky status | grep {name} | grep UP', # Ensure the cluster is UP.
- f'sky exec {name} --cloud {generic_cloud} tests/test_yamls/minimal.yaml',
- f'sky status | grep {name} | grep "1m (down)"',
- f'sleep {autodown_timeout}',
- # Ensure the cluster is terminated.
- f's=$(SKYPILOT_DEBUG=0 sky status {name} --refresh) && echo "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
- f'sky launch -y -d -c {name} --cloud {generic_cloud} --num-nodes 2 --down tests/test_yamls/minimal.yaml',
- f'sky autostop -y {name} --cancel',
- f'sleep {autodown_timeout}',
- # Ensure the cluster is still UP.
- f's=$(SKYPILOT_DEBUG=0 sky status {name} --refresh) && echo "$s" && echo "$s" | grep {name} | grep UP',
- ],
- f'sky down -y {name}',
- timeout=total_timeout_minutes * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.scp
-def test_scp_autodown():
- name = _get_cluster_name()
- test = Test(
- 'SCP_autodown',
- [
- f'sky launch -y -d -c {name} {SCP_TYPE} tests/test_yamls/minimal.yaml',
- f'sky autostop -y {name} --down -i 1',
- # Ensure autostop is set.
- f'sky status | grep {name} | grep "1m (down)"',
- # Ensure the cluster is not terminated early.
- 'sleep 45',
- f'sky status --refresh | grep {name} | grep UP',
- # Ensure the cluster is terminated.
- 'sleep 200',
- f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
- f'sky launch -y -d -c {name} {SCP_TYPE} --down tests/test_yamls/minimal.yaml',
- f'sky status | grep {name} | grep UP', # Ensure the cluster is UP.
- f'sky exec {name} {SCP_TYPE} tests/test_yamls/minimal.yaml',
- f'sky status | grep {name} | grep "1m (down)"',
- 'sleep 200',
- # Ensure the cluster is terminated.
- f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
- f'sky launch -y -d -c {name} {SCP_TYPE} --down tests/test_yamls/minimal.yaml',
- f'sky autostop -y {name} --cancel',
- 'sleep 200',
- # Ensure the cluster is still UP.
- f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && echo "$s" | grep {name} | grep UP',
- ],
- f'sky down -y {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
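
The autodown checks above pack the "cluster is gone or reported as autodowned" logic into a single shell line. A minimal sketch of a hypothetical helper building the same check (illustrative only; the tests inline the command):

def _autodowned_check_cmd(name: str) -> str:
    # Pass if the cluster is reported as autodowned/terminated, or no longer
    # appears in `sky status --refresh`; fail if it is still listed as alive.
    return (f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && '
            f'{{ echo "$s" | grep {name} | '
            f'grep "Autodowned cluster\\|terminated on the cloud"; }} || '
            f'{{ echo "$s" | grep {name} && exit 1 || exit 0; }}')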
-
-
-def _get_cancel_task_with_cloud(name, cloud, timeout=15 * 60):
- test = Test(
- f'{cloud}-cancel-task',
- [
- f'sky launch -c {name} examples/resnet_app.yaml --cloud {cloud} -y -d',
-            # Wait for the job to be scheduled and finish its setup.
- f'until sky queue {name} | grep "RUNNING"; do sleep 10; done',
-            # Wait for setup and initialization before the GPU process starts.
- 'sleep 120',
- f'sky exec {name} "nvidia-smi | grep python"',
- f'sky logs {name} 2 --status || {{ sky logs {name} --no-follow 1 && exit 1; }}', # Ensure the job succeeded.
- f'sky cancel -y {name} 1',
- 'sleep 60',
-            # Check that the python job is gone.
- f'sky exec {name} "! nvidia-smi | grep python"',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- timeout=timeout,
- )
- return test
-
-
-# ---------- Testing `sky cancel` ----------
-@pytest.mark.aws
-def test_cancel_aws():
- name = _get_cluster_name()
- test = _get_cancel_task_with_cloud(name, 'aws')
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_cancel_gcp():
- name = _get_cluster_name()
- test = _get_cancel_task_with_cloud(name, 'gcp')
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_cancel_azure():
- name = _get_cluster_name()
- test = _get_cancel_task_with_cloud(name, 'azure', timeout=30 * 60)
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack does not support V100 gpus for now
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus
-@pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA
-@pytest.mark.no_paperspace # Paperspace has `gnome-shell` on nvidia-smi
-@pytest.mark.no_scp # SCP does not support num_nodes > 1 yet
-def test_cancel_pytorch(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'cancel-pytorch',
- [
- f'sky launch -c {name} --cloud {generic_cloud} examples/resnet_distributed_torch.yaml -y -d',
-            # Wait for the GPU process to start.
- 'sleep 90',
- f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep python) || '
- # When run inside container/k8s, nvidia-smi cannot show process ids.
- # See https://github.com/NVIDIA/nvidia-docker/issues/179
- # To work around, we check if GPU utilization is greater than 0.
- f'[ \$(nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits) -gt 0 ]"',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- f'sky cancel -y {name} 1',
- 'sleep 60',
- f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep \'No running process\') || '
- # Ensure Xorg is the only process running.
- '[ \$(nvidia-smi | grep -A 10 Processes | grep -A 10 === | grep -v Xorg) -eq 2 ]"',
- f'sky logs {name} 3 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# Can't use `_get_cancel_task_with_cloud()`, as the `nvidia-smi` command
-# requires a public CUDA image, which IBM doesn't offer.
-@pytest.mark.ibm
-def test_cancel_ibm():
- name = _get_cluster_name()
- test = Test(
- 'ibm-cancel-task',
- [
- f'sky launch -y -c {name} --cloud ibm examples/minimal.yaml',
- f'sky exec {name} -n {name}-1 -d "while true; do echo \'Hello SkyPilot\'; sleep 2; done"',
- 'sleep 20',
- f'sky queue {name} | grep {name}-1 | grep RUNNING',
- f'sky cancel -y {name} 2',
- f'sleep 5',
- f'sky queue {name} | grep {name}-1 | grep CANCELLED',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Testing use-spot option ----------
-@pytest.mark.no_fluidstack # FluidStack does not support spot instances
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
-@pytest.mark.no_paperspace # Paperspace does not support spot instances
-@pytest.mark.no_ibm # IBM Cloud does not support spot instances
-@pytest.mark.no_scp # SCP does not support spot instances
-@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
-def test_use_spot(generic_cloud: str):
- """Test use-spot and sky exec."""
- name = _get_cluster_name()
- test = Test(
- 'use-spot',
- [
- f'sky launch -c {name} --cloud {generic_cloud} tests/test_yamls/minimal.yaml --use-spot -y',
- f'sky logs {name} 1 --status',
- f'sky exec {name} echo hi',
- f'sky logs {name} 2 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_stop_gcp_spot():
- """Test GCP spot can be stopped, autostopped, restarted."""
- name = _get_cluster_name()
- test = Test(
- 'stop_gcp_spot',
- [
- f'sky launch -c {name} --cloud gcp --use-spot --cpus 2+ -y -- touch myfile',
- # stop should go through:
- f'sky stop {name} -y',
- f'sky start {name} -y',
- f'sky exec {name} -- ls myfile',
- f'sky logs {name} 2 --status',
- f'sky autostop {name} -i0 -y',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=90),
- f'sky start {name} -y',
- f'sky exec {name} -- ls myfile',
- f'sky logs {name} 3 --status',
- # -i option at launch should go through:
- f'sky launch -c {name} -i0 -y',
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=120),
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Testing managed job ----------
-# TODO(zhwu): run the jobs controller on GCP to avoid parallel test issues
-# when the controller is on Azure, where the launching step takes a long time.
-@pytest.mark.managed_jobs
-def test_managed_jobs(generic_cloud: str):
- """Test the managed jobs yaml."""
- name = _get_cluster_name()
- test = Test(
- 'managed-jobs',
- [
- f'sky jobs launch -n {name}-1 --cloud {generic_cloud} examples/managed_job.yaml -y -d',
- f'sky jobs launch -n {name}-2 --cloud {generic_cloud} examples/managed_job.yaml -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-1',
- job_status=[
- sky.ManagedJobStatus.PENDING,
- sky.ManagedJobStatus.SUBMITTED,
- sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING
- ],
- timeout=60),
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-2',
- job_status=[
- sky.ManagedJobStatus.PENDING,
- sky.ManagedJobStatus.SUBMITTED,
- sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING
- ],
- timeout=60),
- f'sky jobs cancel -y -n {name}-1',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-1',
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=230),
- # Test the functionality for logging.
- f's=$(sky jobs logs -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "start counting"',
- f's=$(sky jobs logs --controller -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "Cluster launched:"',
- f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "RUNNING\|SUCCEEDED"',
- ],
- # TODO(zhwu): Change to f'sky jobs cancel -y -n {name}-1 -n {name}-2' when
- # canceling multiple job names is supported.
- f'sky jobs cancel -y -n {name}-1; sky jobs cancel -y -n {name}-2',
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=20 * 60,
- )
- run_one_test(test)
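
The managed-jobs tests poll the queue through _get_cmd_wait_until_managed_job_status_contains_matching_job_name, defined elsewhere in this file and not shown in this hunk. A minimal sketch of what such a generator could look like, assuming ManagedJobStatus values render to the strings shown by `sky jobs queue` (illustrative name and details only):

def _sketch_wait_until_managed_job_status(job_name, job_status, timeout):
    statuses = '|'.join(status.value for status in job_status)
    return ('start_time=$SECONDS; '
            f'until sky jobs queue | grep {job_name} | head -n1 | '
            f'grep -E "{statuses}"; do '
            f'if [ $((SECONDS - start_time)) -gt {timeout} ]; then '
            f'echo "Timed out waiting for job {job_name}"; exit 1; fi; '
            'sleep 15; done')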
-
-
-@pytest.mark.no_fluidstack # FluidStack does not support spot instances
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
-@pytest.mark.no_ibm # IBM Cloud does not support spot instances
-@pytest.mark.no_scp # SCP does not support spot instances
-@pytest.mark.no_paperspace # Paperspace does not support spot instances
-@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
-@pytest.mark.managed_jobs
-def test_job_pipeline(generic_cloud: str):
- """Test a job pipeline."""
- name = _get_cluster_name()
- test = Test(
- 'job_pipeline',
- [
- f'sky jobs launch -n {name} tests/test_yamls/pipeline.yaml -y -d',
- # Need to wait for setup and job initialization.
- 'sleep 30',
- f'{_GET_JOB_QUEUE}| grep {name} | head -n1 | grep "STARTING\|RUNNING"',
-            # `grep -A 4 {name}` finds the line with {name} (the job itself)
-            # and the 4 lines after it, i.e. the 4 tasks within the job.
-            # `sed -n 2p` then picks the second line of that output, i.e. the
-            # first task within the job.
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "STARTING\|RUNNING"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "PENDING"',
- f'sky jobs cancel -y -n {name}',
- 'sleep 5',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "CANCELLING\|CANCELLED"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "CANCELLING\|CANCELLED"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLING\|CANCELLED"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLING\|CANCELLED"',
- 'sleep 200',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "CANCELLED"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "CANCELLED"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"',
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"',
- ],
- f'sky jobs cancel -y -n {name}',
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=30 * 60,
- )
- run_one_test(test)
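
The pipeline assertions above reuse the `grep -A 4 {name} | sed -n Np` idiom to pick one task row out of the queue output; line 1 of the grep output is the job itself and lines 2-5 are its tasks. A minimal sketch of a hypothetical helper that builds the same command:

def _pipeline_task_row_cmd(job_queue_cmd: str, job_name: str,
                           task_index: int) -> str:
    # task_index is 0-based; the first task sits on line 2 of the grep output.
    return f'{job_queue_cmd} | grep -A 4 {job_name} | sed -n {task_index + 2}p'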
-
-
-@pytest.mark.no_fluidstack # FluidStack does not support spot instances
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
-@pytest.mark.no_ibm # IBM Cloud does not support spot instances
-@pytest.mark.no_scp # SCP does not support spot instances
-@pytest.mark.no_paperspace # Paperspace does not support spot instances
-@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
-@pytest.mark.managed_jobs
-def test_managed_jobs_failed_setup(generic_cloud: str):
- """Test managed job with failed setup."""
- name = _get_cluster_name()
- test = Test(
- 'managed_jobs_failed_setup',
- [
- f'sky jobs launch -n {name} --cloud {generic_cloud} -y -d tests/test_yamls/failed_setup.yaml',
- # Make sure the job failed quickly.
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.FAILED_SETUP],
- timeout=330 + _BUMP_UP_SECONDS),
- ],
- f'sky jobs cancel -y -n {name}',
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # FluidStack does not support spot instances
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
-@pytest.mark.no_ibm # IBM Cloud does not support spot instances
-@pytest.mark.no_scp # SCP does not support spot instances
-@pytest.mark.no_paperspace # Paperspace does not support spot instances
-@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
-@pytest.mark.managed_jobs
-def test_managed_jobs_pipeline_failed_setup(generic_cloud: str):
- """Test managed job with failed setup for a pipeline."""
- name = _get_cluster_name()
- test = Test(
- 'managed_jobs_pipeline_failed_setup',
- [
- f'sky jobs launch -n {name} -y -d tests/test_yamls/failed_setup_pipeline.yaml',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.FAILED_SETUP],
- timeout=600),
- # Make sure the job failed quickly.
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"',
- # Task 0 should be SUCCEEDED.
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "SUCCEEDED"',
- # Task 1 should be FAILED_SETUP.
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "FAILED_SETUP"',
- # Task 2 should be CANCELLED.
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"',
- # Task 3 should be CANCELLED.
- f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"',
- ],
- f'sky jobs cancel -y -n {name}',
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-# ---------- Testing managed job recovery ----------
-
-
-@pytest.mark.aws
-@pytest.mark.managed_jobs
-def test_managed_jobs_recovery_aws(aws_config_region):
- """Test managed job recovery."""
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- region = aws_config_region
- test = Test(
- 'managed_jobs_recovery_aws',
- [
- f'sky jobs launch --cloud aws --region {region} --use-spot -n {name} "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=600),
- f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
- # Terminate the cluster manually.
- (f'aws ec2 terminate-instances --region {region} --instance-ids $('
- f'aws ec2 describe-instances --region {region} '
- f'--filters "Name=tag:ray-cluster-name,Values={name_on_cloud}*" '
- '--query "Reservations[].Instances[].InstanceId" '
- '--output text)'),
- _JOB_WAIT_NOT_RUNNING.format(job_name=name),
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=200),
- f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | grep "$RUN_ID"',
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-@pytest.mark.managed_jobs
-def test_managed_jobs_recovery_gcp():
- """Test managed job recovery."""
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- zone = 'us-east4-b'
- query_cmd = (
- f'gcloud compute instances list --filter='
- # `:` means prefix match.
- f'"(labels.ray-cluster-name:{name_on_cloud})" '
- f'--zones={zone} --format="value(name)"')
- terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
- f' --quiet $({query_cmd})')
- test = Test(
- 'managed_jobs_recovery_gcp',
- [
- f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --cpus 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=300),
- f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
- # Terminate the cluster manually.
- terminate_cmd,
- _JOB_WAIT_NOT_RUNNING.format(job_name=name),
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=200),
- f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"',
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-@pytest.mark.managed_jobs
-def test_managed_jobs_pipeline_recovery_aws(aws_config_region):
- """Test managed job recovery for a pipeline."""
- name = _get_cluster_name()
- user_hash = common_utils.get_user_hash()
- user_hash = user_hash[:common_utils.USER_HASH_LENGTH_IN_CLUSTER_NAME]
- region = aws_config_region
- if region != 'us-east-2':
- pytest.skip('Only run spot pipeline recovery test in us-east-2')
- test = Test(
- 'managed_jobs_pipeline_recovery_aws',
- [
- f'sky jobs launch -n {name} tests/test_yamls/pipeline_aws.yaml -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=400),
- f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
- f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids',
- # Terminate the cluster manually.
-            # The `rev | cut | rev` pipeline retrieves the job_id from the
-            # SKYPILOT_TASK_ID: it takes the last `_`-separated field and
-            # keeps the part before the first `-`.
- (
- f'MANAGED_JOB_ID=`cat /tmp/{name}-run-id | rev | '
- 'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`;'
- f'aws ec2 terminate-instances --region {region} --instance-ids $('
- f'aws ec2 describe-instances --region {region} '
- # TODO(zhwu): fix the name for spot cluster.
- '--filters "Name=tag:ray-cluster-name,Values=*-${MANAGED_JOB_ID}"'
- f'-{user_hash} '
- '--query "Reservations[].Instances[].InstanceId" '
- '--output text)'),
- _JOB_WAIT_NOT_RUNNING.format(job_name=name),
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=200),
- f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"',
- f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new',
- f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new',
- f'cat /tmp/{name}-run-ids | sed -n 2p | grep `cat /tmp/{name}-run-id`',
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
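
The `rev | cut -d'_' -f1 | rev | cut -d'-' -f1` pipeline above extracts the managed job id from a SKYPILOT_TASK_ID. A minimal sketch of the equivalent string manipulation in Python, using a made-up task id purely to illustrate the shape <prefix>_<name>_<job_id>-<task_id>:

def _extract_managed_job_id(task_id: str) -> str:
    # Take the last '_'-separated field, then keep the part before the
    # first '-' (the job id, dropping the task id).
    return task_id.rsplit('_', 1)[-1].split('-', 1)[0]


assert _extract_managed_job_id('sky-2024-01-01-00-00-000000_pipeline_7-0') == '7'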
-
-
-@pytest.mark.gcp
-@pytest.mark.managed_jobs
-def test_managed_jobs_pipeline_recovery_gcp():
- """Test managed job recovery for a pipeline."""
- name = _get_cluster_name()
- zone = 'us-east4-b'
- user_hash = common_utils.get_user_hash()
- user_hash = user_hash[:common_utils.USER_HASH_LENGTH_IN_CLUSTER_NAME]
- query_cmd = (
- 'gcloud compute instances list --filter='
- f'"(labels.ray-cluster-name:*-${{MANAGED_JOB_ID}}-{user_hash})" '
- f'--zones={zone} --format="value(name)"')
- terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
- f' --quiet $({query_cmd})')
- test = Test(
- 'managed_jobs_pipeline_recovery_gcp',
- [
- f'sky jobs launch -n {name} tests/test_yamls/pipeline_gcp.yaml -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=400),
- f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
- f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids',
- # Terminate the cluster manually.
-            # The `rev | cut | rev` pipeline retrieves the job_id from the
-            # SKYPILOT_TASK_ID: it takes the last `_`-separated field and
-            # keeps the part before the first `-`.
- (f'MANAGED_JOB_ID=`cat /tmp/{name}-run-id | rev | '
- f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'),
- _JOB_WAIT_NOT_RUNNING.format(job_name=name),
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=200),
- f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"',
- f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new',
- f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new',
- f'cat /tmp/{name}-run-ids | sed -n 2p | grep `cat /tmp/{name}-run-id`',
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
-@pytest.mark.no_ibm # IBM Cloud does not support spot instances
-@pytest.mark.no_scp # SCP does not support spot instances
-@pytest.mark.no_paperspace # Paperspace does not support spot instances
-@pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances
-@pytest.mark.managed_jobs
-def test_managed_jobs_recovery_default_resources(generic_cloud: str):
- """Test managed job recovery for default resources."""
- name = _get_cluster_name()
- test = Test(
- 'managed-spot-recovery-default-resources',
- [
- f'sky jobs launch -n {name} --cloud {generic_cloud} --use-spot "sleep 30 && sudo shutdown now && sleep 1000" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[
- sky.ManagedJobStatus.RUNNING,
- sky.ManagedJobStatus.RECOVERING
- ],
- timeout=360),
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.aws
-@pytest.mark.managed_jobs
-def test_managed_jobs_recovery_multi_node_aws(aws_config_region):
- """Test managed job recovery."""
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- region = aws_config_region
- test = Test(
- 'managed_jobs_recovery_multi_node_aws',
- [
- f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=450),
- f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
- # Terminate the worker manually.
- (f'aws ec2 terminate-instances --region {region} --instance-ids $('
- f'aws ec2 describe-instances --region {region} '
- f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}* '
- 'Name=tag:ray-node-type,Values=worker '
- f'--query Reservations[].Instances[].InstanceId '
- '--output text)'),
- _JOB_WAIT_NOT_RUNNING.format(job_name=name),
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=560),
- f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"',
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-@pytest.mark.managed_jobs
-def test_managed_jobs_recovery_multi_node_gcp():
- """Test managed job recovery."""
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- zone = 'us-west2-a'
-    # Use ':' for prefix matching, since the cluster name will contain a
-    # suffix with the job id.
- query_cmd = (
- f'gcloud compute instances list --filter='
- f'"(labels.ray-cluster-name:{name_on_cloud} AND '
- f'labels.ray-node-type=worker)" --zones={zone} --format="value(name)"')
- terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
- f' --quiet $({query_cmd})')
- test = Test(
- 'managed_jobs_recovery_multi_node_gcp',
- [
- f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=400),
- f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
- # Terminate the worker manually.
- terminate_cmd,
- _JOB_WAIT_NOT_RUNNING.format(job_name=name),
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=560),
- f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"',
- ],
- f'sky jobs cancel -y -n {name}',
- timeout=25 * 60,
- )
- run_one_test(test)
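
The GCP recovery tests above all build their query/terminate command pair the same way; the ':' filter operator gives prefix matching on the label value. A minimal sketch of a hypothetical helper capturing that pattern (not part of this file):

def _gcp_terminate_cmds(name_on_cloud: str, zone: str,
                        worker_only: bool = False):
    node_filter = ' AND labels.ray-node-type=worker' if worker_only else ''
    query_cmd = ('gcloud compute instances list --filter='
                 f'"(labels.ray-cluster-name:{name_on_cloud}{node_filter})" '
                 f'--zones={zone} --format="value(name)"')
    terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
                     f' --quiet $({query_cmd})')
    return query_cmd, terminate_cmd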
-
-
-@pytest.mark.aws
-@pytest.mark.managed_jobs
-def test_managed_jobs_cancellation_aws(aws_config_region):
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- name_2_on_cloud = common_utils.make_cluster_name_on_cloud(
- f'{name}-2', jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- name_3_on_cloud = common_utils.make_cluster_name_on_cloud(
- f'{name}-3', jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- region = aws_config_region
- test = Test(
- 'managed_jobs_cancellation_aws',
- [
- # Test cancellation during spot cluster being launched.
- f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot "sleep 1000" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[
- sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING
- ],
- timeout=60 + _BUMP_UP_SECONDS),
- f'sky jobs cancel -y -n {name}',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=120 + _BUMP_UP_SECONDS),
- (f's=$(aws ec2 describe-instances --region {region} '
- f'--filters "Name=tag:ray-cluster-name,Values={name_on_cloud}-*" '
- '--query "Reservations[].Instances[].State[].Name" '
- '--output text) && echo "$s" && echo; [[ -z "$s" ]] || [[ "$s" = "terminated" ]] || [[ "$s" = "shutting-down" ]]'
- ),
- # Test cancelling the spot cluster during spot job being setup.
- f'sky jobs launch --cloud aws --region {region} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d',
-            # The job is being set up in the cluster and will be shown as RUNNING.
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-2',
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=300 + _BUMP_UP_SECONDS),
- f'sky jobs cancel -y -n {name}-2',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-2',
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=120 + _BUMP_UP_SECONDS),
- (f's=$(aws ec2 describe-instances --region {region} '
- f'--filters "Name=tag:ray-cluster-name,Values={name_2_on_cloud}-*" '
- '--query "Reservations[].Instances[].State[].Name" '
- '--output text) && echo "$s" && echo; [[ -z "$s" ]] || [[ "$s" = "terminated" ]] || [[ "$s" = "shutting-down" ]]'
- ),
- # Test cancellation during spot job is recovering.
- f'sky jobs launch --cloud aws --region {region} -n {name}-3 --use-spot "sleep 1000" -y -d',
-            # The job is running in the cluster and will be shown as RUNNING.
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-3',
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=300 + _BUMP_UP_SECONDS),
- # Terminate the cluster manually.
- (f'aws ec2 terminate-instances --region {region} --instance-ids $('
- f'aws ec2 describe-instances --region {region} '
- f'--filters "Name=tag:ray-cluster-name,Values={name_3_on_cloud}-*" '
- f'--query "Reservations[].Instances[].InstanceId" '
- '--output text)'),
- _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'),
- f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"',
- f'sky jobs cancel -y -n {name}-3',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-3',
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=120 + _BUMP_UP_SECONDS),
-            # The cluster should be terminated (shutting-down) after
-            # cancellation. We don't use the `=` operator here because there
-            # can be multiple VMs with the same name due to recovery.
- (f's=$(aws ec2 describe-instances --region {region} '
- f'--filters "Name=tag:ray-cluster-name,Values={name_3_on_cloud}-*" '
- '--query "Reservations[].Instances[].State[].Name" '
- '--output text) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "pending|running|stopped|stopping"'
- ),
- ],
- timeout=25 * 60)
- run_one_test(test)
-
-
-@pytest.mark.gcp
-@pytest.mark.managed_jobs
-def test_managed_jobs_cancellation_gcp():
- name = _get_cluster_name()
- name_3 = f'{name}-3'
- name_3_on_cloud = common_utils.make_cluster_name_on_cloud(
- name_3, jobs.JOBS_CLUSTER_NAME_PREFIX_LENGTH, add_user_hash=False)
- zone = 'us-west3-b'
- query_state_cmd = (
- 'gcloud compute instances list '
- f'--filter="(labels.ray-cluster-name:{name_3_on_cloud})" '
- '--format="value(status)"')
- query_cmd = (f'gcloud compute instances list --filter='
- f'"(labels.ray-cluster-name:{name_3_on_cloud})" '
- f'--zones={zone} --format="value(name)"')
- terminate_cmd = (f'gcloud compute instances delete --zone={zone}'
- f' --quiet $({query_cmd})')
- test = Test(
- 'managed_jobs_cancellation_gcp',
- [
- # Test cancellation during spot cluster being launched.
- f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot "sleep 1000" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.STARTING],
- timeout=60 + _BUMP_UP_SECONDS),
- f'sky jobs cancel -y -n {name}',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=120 + _BUMP_UP_SECONDS),
- # Test cancelling the spot cluster during spot job being setup.
- f'sky jobs launch --cloud gcp --zone {zone} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d',
-            # The job is being set up in the cluster and will be shown as RUNNING.
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-2',
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=300 + _BUMP_UP_SECONDS),
- f'sky jobs cancel -y -n {name}-2',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-2',
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=120 + _BUMP_UP_SECONDS),
- # Test cancellation during spot job is recovering.
- f'sky jobs launch --cloud gcp --zone {zone} -n {name}-3 --use-spot "sleep 1000" -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-3',
- job_status=[sky.ManagedJobStatus.RUNNING],
- timeout=300 + _BUMP_UP_SECONDS),
- # Terminate the cluster manually.
- terminate_cmd,
- _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'),
- f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"',
- f'sky jobs cancel -y -n {name}-3',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=f'{name}-3',
- job_status=[sky.ManagedJobStatus.CANCELLED],
- timeout=120 + _BUMP_UP_SECONDS),
-            # The cluster should be terminated (STOPPING) after cancellation.
-            # We don't use the `=` operator here because there can be
-            # multiple VMs with the same name due to recovery.
-            (f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED"'
- ),
- ],
- timeout=25 * 60)
- run_one_test(test)
-
-
-# ---------- Testing storage for managed job ----------
-@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
-@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
-@pytest.mark.no_ibm # IBM Cloud does not support spot instances
-@pytest.mark.no_paperspace # Paperspace does not support spot instances
-@pytest.mark.no_scp # SCP does not support spot instances
-@pytest.mark.managed_jobs
-def test_managed_jobs_storage(generic_cloud: str):
- """Test storage with managed job"""
- name = _get_cluster_name()
- yaml_str = pathlib.Path(
- 'examples/managed_job_with_storage.yaml').read_text()
- timestamp = int(time.time())
- storage_name = f'sky-test-{timestamp}'
- output_storage_name = f'sky-test-output-{timestamp}'
-
-    # Also perform region testing for bucket creation to validate that buckets
-    # are created in the correct region and correctly mounted in managed jobs.
-    # This check is only injected for AWS, GCP, and Azure, the object storage
-    # providers supported in SkyPilot.
- region_flag = ''
- region_validation_cmd = 'true'
- use_spot = ' --use-spot'
- if generic_cloud == 'aws':
- region = 'eu-central-1'
- region_flag = f' --region {region}'
- region_cmd = TestStorageWithCredentials.cli_region_cmd(
- storage_lib.StoreType.S3, bucket_name=storage_name)
- region_validation_cmd = f'{region_cmd} | grep {region}'
- s3_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
- storage_lib.StoreType.S3, output_storage_name, 'output.txt')
- output_check_cmd = f'{s3_check_file_count} | grep 1'
- elif generic_cloud == 'gcp':
- region = 'us-west2'
- region_flag = f' --region {region}'
- region_cmd = TestStorageWithCredentials.cli_region_cmd(
- storage_lib.StoreType.GCS, bucket_name=storage_name)
- region_validation_cmd = f'{region_cmd} | grep {region}'
- gcs_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
- storage_lib.StoreType.GCS, output_storage_name, 'output.txt')
- output_check_cmd = f'{gcs_check_file_count} | grep 1'
- elif generic_cloud == 'azure':
- region = 'westus2'
- region_flag = f' --region {region}'
- storage_account_name = (
- storage_lib.AzureBlobStore.get_default_storage_account_name(region))
- region_cmd = TestStorageWithCredentials.cli_region_cmd(
- storage_lib.StoreType.AZURE,
- storage_account_name=storage_account_name)
- region_validation_cmd = f'{region_cmd} | grep {region}'
- az_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
- storage_lib.StoreType.AZURE,
- output_storage_name,
- 'output.txt',
- storage_account_name=storage_account_name)
- output_check_cmd = f'{az_check_file_count} | grep 1'
- elif generic_cloud == 'kubernetes':
-        # With Kubernetes, we don't know which object storage provider is used.
-        # Check whether the output file exists in either S3 or GCS.
- s3_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
- storage_lib.StoreType.S3, output_storage_name, 'output.txt')
- s3_output_check_cmd = f'{s3_check_file_count} | grep 1'
- gcs_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
- storage_lib.StoreType.GCS, output_storage_name, 'output.txt')
- gcs_output_check_cmd = f'{gcs_check_file_count} | grep 1'
- output_check_cmd = f'{s3_output_check_cmd} || {gcs_output_check_cmd}'
- use_spot = ' --no-use-spot'
-
- yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name)
- yaml_str = yaml_str.replace('sky-output-bucket', output_storage_name)
- with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
- f.write(yaml_str)
- f.flush()
- file_path = f.name
- test = Test(
- 'managed_jobs_storage',
- [
- *STORAGE_SETUP_COMMANDS,
- f'sky jobs launch -n {name}{use_spot} --cloud {generic_cloud}{region_flag} {file_path} -y',
- region_validation_cmd, # Check if the bucket is created in the correct region
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.SUCCEEDED],
- timeout=60 + _BUMP_UP_SECONDS),
- # Wait for the job to be cleaned up.
- 'sleep 30',
- f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]',
- # Check if file was written to the mounted output bucket
- output_check_cmd
- ],
- (f'sky jobs cancel -y -n {name}',
- f'; sky storage delete {output_storage_name} || true'),
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# ---------- Testing spot TPU ----------
-@pytest.mark.gcp
-@pytest.mark.managed_jobs
-@pytest.mark.tpu
-def test_managed_jobs_tpu():
- """Test managed job on TPU."""
- name = _get_cluster_name()
- test = Test(
- 'test-spot-tpu',
- [
- f'sky jobs launch -n {name} --use-spot examples/tpu/tpuvm_mnist.yaml -y -d',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.STARTING],
- timeout=60 + _BUMP_UP_SECONDS),
- # TPU takes a while to launch
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[
- sky.ManagedJobStatus.RUNNING, sky.ManagedJobStatus.SUCCEEDED
- ],
- timeout=900 + _BUMP_UP_SECONDS),
- ],
- f'sky jobs cancel -y -n {name}',
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# ---------- Testing env for managed jobs ----------
-@pytest.mark.managed_jobs
-def test_managed_jobs_inline_env(generic_cloud: str):
- """Test managed jobs env"""
- name = _get_cluster_name()
- test = Test(
- 'test-managed-jobs-inline-env',
- [
- f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "echo "\\$TEST_ENV"; ([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
- _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
- job_name=name,
- job_status=[sky.ManagedJobStatus.SUCCEEDED],
- timeout=20 + _BUMP_UP_SECONDS),
- f'JOB_ROW=$(sky jobs queue | grep {name} | head -n1) && '
- f'echo "$JOB_ROW" && echo "$JOB_ROW" | grep "SUCCEEDED" && '
- f'JOB_ID=$(echo "$JOB_ROW" | awk \'{{print $1}}\') && '
- f'echo "JOB_ID=$JOB_ID" && '
- # Test that logs are still available after the job finishes.
- 'unset SKYPILOT_DEBUG; s=$(sky jobs logs $JOB_ID --refresh) && echo "$s" && echo "$s" | grep "hello world" && '
- # Make sure we skip the unnecessary logs.
- 'echo "$s" | head -n1 | grep "Waiting for"',
- ],
- f'sky jobs cancel -y -n {name}',
- # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# ---------- Testing env ----------
-def test_inline_env(generic_cloud: str):
- """Test env"""
- name = _get_cluster_name()
- test = Test(
- 'test-inline-env',
- [
- f'sky launch -c {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
- 'sleep 20',
- f'sky logs {name} 1 --status',
- f'sky exec {name} --env TEST_ENV2="success" "([[ ! -z \\"\$TEST_ENV2\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
- f'sky logs {name} 2 --status',
- ],
- f'sky down -y {name}',
- _get_timeout(generic_cloud),
- )
- run_one_test(test)
-
-
-# ---------- Testing env file ----------
-def test_inline_env_file(generic_cloud: str):
- """Test env"""
- name = _get_cluster_name()
- test = Test(
- 'test-inline-env-file',
- [
- f'sky launch -c {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
- f'sky logs {name} 1 --status',
- f'sky exec {name} --env-file examples/sample_dotenv "([[ ! -z \\"\$TEST_ENV2\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"',
- f'sky logs {name} 2 --status',
- ],
- f'sky down -y {name}',
- _get_timeout(generic_cloud),
- )
- run_one_test(test)
-
-
-# ---------- Testing custom image ----------
-@pytest.mark.aws
-def test_aws_custom_image():
- """Test AWS custom image"""
- name = _get_cluster_name()
- test = Test(
- 'test-aws-custom-image',
- [
- f'sky launch -c {name} --retry-until-up -y tests/test_yamls/test_custom_image.yaml --cloud aws --region us-east-2 --image-id ami-062ddd90fb6f8267a', # Nvidia image
- f'sky logs {name} 1 --status',
- ],
- f'sky down -y {name}',
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-@pytest.mark.parametrize(
- 'image_id',
- [
- 'docker:nvidia/cuda:11.8.0-devel-ubuntu18.04',
- 'docker:ubuntu:18.04',
- # Test latest image with python 3.11 installed by default.
- 'docker:continuumio/miniconda3:24.1.2-0',
- # Test python>=3.12 where SkyPilot should automatically create a separate
- # conda env for runtime with python 3.10.
- 'docker:continuumio/miniconda3:latest',
- ])
-def test_kubernetes_custom_image(image_id):
- """Test Kubernetes custom image"""
- name = _get_cluster_name()
- test = Test(
- 'test-kubernetes-custom-image',
- [
- f'sky launch -c {name} --retry-until-up -y tests/test_yamls/test_custom_image.yaml --cloud kubernetes --image-id {image_id} --region None --gpus T4:1',
- f'sky logs {name} 1 --status',
- # Try exec to run again and check if the logs are printed
- f'sky exec {name} tests/test_yamls/test_custom_image.yaml --cloud kubernetes --image-id {image_id} --region None --gpus T4:1 | grep "Hello 100"',
- # Make sure ssh is working with custom username
- f'ssh {name} echo hi | grep hi',
- ],
- f'sky down -y {name}',
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_azure_start_stop_two_nodes():
- name = _get_cluster_name()
- test = Test(
- 'azure-start-stop-two-nodes',
- [
- f'sky launch --num-nodes=2 -y -c {name} examples/azure_start_stop.yaml',
- f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky stop -y {name}',
- f'sky start -y {name} -i 1',
- f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml',
- f'sky logs {name} 2 --status', # Ensure the job succeeded.
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[
- sky.ClusterStatus.INIT, sky.ClusterStatus.STOPPED
- ],
- timeout=200 + _BUMP_UP_SECONDS) +
- f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}'
- ],
- f'sky down -y {name}',
- timeout=30 * 60, # 30 mins (it takes around ~23 mins)
- )
- run_one_test(test)
-
-
-# ---------- Testing disk tiers ----------
-@pytest.mark.aws
-def test_aws_disk_tier():
-
- def _get_aws_query_command(region, instance_id, field, expected):
- return (f'aws ec2 describe-volumes --region {region} '
- f'--filters Name=attachment.instance-id,Values={instance_id} '
- f'--query Volumes[*].{field} | grep {expected} ; ')
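-    # Illustrative expansion (see the loop below): _get_aws_query_command(
-    #     region, '$id', 'Iops', specs['disk_iops']) greps the attached
-    #     volumes' Iops field for the expected value.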
-
- for disk_tier in list(resources_utils.DiskTier):
- specs = AWS._get_disk_specs(disk_tier)
- name = _get_cluster_name() + '-' + disk_tier.value
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, sky.AWS.max_cluster_name_length())
- region = 'us-east-2'
- test = Test(
- 'aws-disk-tier-' + disk_tier.value,
- [
- f'sky launch -y -c {name} --cloud aws --region {region} '
- f'--disk-tier {disk_tier.value} echo "hello sky"',
- f'id=`aws ec2 describe-instances --region {region} --filters '
- f'Name=tag:ray-cluster-name,Values={name_on_cloud} --query '
- f'Reservations[].Instances[].InstanceId --output text`; ' +
- _get_aws_query_command(region, '$id', 'VolumeType',
- specs['disk_tier']) +
- ('' if specs['disk_tier']
- == 'standard' else _get_aws_query_command(
- region, '$id', 'Iops', specs['disk_iops'])) +
- ('' if specs['disk_tier'] != 'gp3' else _get_aws_query_command(
- region, '$id', 'Throughput', specs['disk_throughput'])),
- ],
- f'sky down -y {name}',
- timeout=10 * 60, # 10 mins (it takes around ~6 mins)
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_disk_tier():
- for disk_tier in list(resources_utils.DiskTier):
- disk_types = [GCP._get_disk_type(disk_tier)]
- name = _get_cluster_name() + '-' + disk_tier.value
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, sky.GCP.max_cluster_name_length())
- region = 'us-west2'
- instance_type_options = ['']
- if disk_tier == resources_utils.DiskTier.BEST:
-            # The ultra disk tier requires n2 instance types with more than 64 CPUs.
-            # With the default instance type, only the high disk tier is enabled.
- disk_types = [
- GCP._get_disk_type(resources_utils.DiskTier.HIGH),
- GCP._get_disk_type(resources_utils.DiskTier.ULTRA),
- ]
- instance_type_options = ['', '--instance-type n2-standard-64']
- for disk_type, instance_type_option in zip(disk_types,
- instance_type_options):
- test = Test(
- 'gcp-disk-tier-' + disk_tier.value,
- [
- f'sky launch -y -c {name} --cloud gcp --region {region} '
- f'--disk-tier {disk_tier.value} {instance_type_option} ',
- f'name=`gcloud compute instances list --filter='
- f'"labels.ray-cluster-name:{name_on_cloud}" '
- '--format="value(name)"`; '
- f'gcloud compute disks list --filter="name=$name" '
- f'--format="value(type)" | grep {disk_type} '
- ],
- f'sky down -y {name}',
- timeout=6 * 60, # 6 mins (it takes around ~3 mins)
- )
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_azure_disk_tier():
- for disk_tier in list(resources_utils.DiskTier):
- if disk_tier == resources_utils.DiskTier.HIGH or disk_tier == resources_utils.DiskTier.ULTRA:
-            # Azure does not support the high and ultra disk tiers.
- continue
- type = Azure._get_disk_type(disk_tier)
- name = _get_cluster_name() + '-' + disk_tier.value
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, sky.Azure.max_cluster_name_length())
- region = 'westus2'
- test = Test(
- 'azure-disk-tier-' + disk_tier.value,
- [
- f'sky launch -y -c {name} --cloud azure --region {region} '
- f'--disk-tier {disk_tier.value} echo "hello sky"',
- f'az resource list --tag ray-cluster-name={name_on_cloud} --query '
- f'"[?type==\'Microsoft.Compute/disks\'].sku.name" '
- f'--output tsv | grep {type}'
- ],
- f'sky down -y {name}',
- timeout=20 * 60, # 20 mins (it takes around ~12 mins)
- )
- run_one_test(test)
-
-
-@pytest.mark.azure
-def test_azure_best_tier_failover():
- type = Azure._get_disk_type(resources_utils.DiskTier.LOW)
- name = _get_cluster_name()
- name_on_cloud = common_utils.make_cluster_name_on_cloud(
- name, sky.Azure.max_cluster_name_length())
- region = 'westus2'
- test = Test(
- 'azure-best-tier-failover',
- [
- f'sky launch -y -c {name} --cloud azure --region {region} '
- f'--disk-tier best --instance-type Standard_D8_v5 echo "hello sky"',
- f'az resource list --tag ray-cluster-name={name_on_cloud} --query '
- f'"[?type==\'Microsoft.Compute/disks\'].sku.name" '
- f'--output tsv | grep {type}',
- ],
- f'sky down -y {name}',
- timeout=20 * 60, # 20 mins (it takes around ~12 mins)
- )
- run_one_test(test)
-
-
-# ------ Testing Zero Quota Failover ------
-@pytest.mark.aws
-def test_aws_zero_quota_failover():
-
- name = _get_cluster_name()
- region = get_aws_region_for_quota_failover()
-
- if not region:
- pytest.xfail(
- 'Unable to test zero quota failover optimization — quotas '
- 'for EC2 P3 instances were found on all AWS regions. Is this '
- 'expected for your account?')
- return
-
- test = Test(
- 'aws-zero-quota-failover',
- [
- f'sky launch -y -c {name} --cloud aws --region {region} --gpus V100:8 --use-spot | grep "Found no quota"',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-def test_gcp_zero_quota_failover():
-
- name = _get_cluster_name()
- region = get_gcp_region_for_quota_failover()
-
- if not region:
- pytest.xfail(
- 'Unable to test zero quota failover optimization — quotas '
- 'for A100-80GB GPUs were found on all GCP regions. Is this '
- 'expected for your account?')
- return
-
- test = Test(
- 'gcp-zero-quota-failover',
- [
- f'sky launch -y -c {name} --cloud gcp --region {region} --gpus A100-80GB:1 --use-spot | grep "Found no quota"',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-def test_long_setup_run_script(generic_cloud: str):
- name = _get_cluster_name()
- with tempfile.NamedTemporaryFile('w', prefix='sky_app_',
- suffix='.yaml') as f:
- f.write(
- textwrap.dedent(""" \
- setup: |
- echo "start long setup"
- """))
- for i in range(1024 * 200):
- f.write(f' echo {i}\n')
- f.write(' echo "end long setup"\n')
- f.write(
- textwrap.dedent(""" \
- run: |
- echo "run"
- """))
- for i in range(1024 * 200):
- f.write(f' echo {i}\n')
- f.write(' echo "end run"\n')
- f.flush()
-
- test = Test(
- 'long-setup-run-script',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} --cpus 2+ {f.name}',
- f'sky exec {name} "echo hello"',
- f'sky exec {name} {f.name}',
- f'sky logs {name} --status 1',
- f'sky logs {name} --status 2',
- f'sky logs {name} --status 3',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Testing skyserve ----------
-
-
-def _get_service_name() -> str:
-    """Returns a user-unique service name for each test_skyserve_*() test.
-
-    Must be called directly from each test_skyserve_*() function, since it
-    inspects the caller's function name.
- """
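-    # For example, test_skyserve_gcp_http roughly maps to a name like
-    # 't-ss-gcp-http-<test_id>' (make_cluster_name_on_cloud may further
-    # truncate/hash it to fit the 24-character limit).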
- caller_func_name = inspect.stack()[1][3]
- test_name = caller_func_name.replace('_', '-').replace('test-', 't-')
- test_name = test_name.replace('skyserve-', 'ss-')
- test_name = common_utils.make_cluster_name_on_cloud(test_name, 24)
- return f'{test_name}-{test_id}'
-
-
-# We check the output of `sky serve status` to see if the service is ready.
-# The `REPLICAS` column is in the form `1/2`, where the first number is the
-# number of ready replicas and the second is the total number of replicas. We
-# grep for that format to ensure the service is ready, and exit early if any
-# failure is detected. At the end we sleep for
-# serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS to make sure the load balancer has
-# enough time to sync with the controller and get all ready replica IPs.
-_SERVE_WAIT_UNTIL_READY = (
- '{{ while true; do'
- ' s=$(sky serve status {name}); echo "$s";'
- ' echo "$s" | grep -q "{replica_num}/{replica_num}" && break;'
- ' echo "$s" | grep -q "FAILED" && exit 1;'
- ' sleep 10;'
- ' done; }}; echo "Got service status $s";'
- f'sleep {serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS + 2};')
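-# Example usage (as in the tests below):
-#   _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2)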
-_IP_REGEX = r'([0-9]{1,3}\.){3}[0-9]{1,3}'
-_AWK_ALL_LINES_BELOW_REPLICAS = r'/Replicas/{flag=1; next} flag'
-_SERVICE_LAUNCHING_STATUS_REGEX = 'PROVISIONING\|STARTING'
-# Since we don't allow terminating the service while the controller is INIT,
-# which is common when multiple pytest runs happen simultaneously, we need to
-# wait until the controller is UP before we can terminate the service.
-# The teardown command has a 10-minute timeout, so we don't need to set one
-# here. See the implementation of run_one_test() for details.
-_TEARDOWN_SERVICE = (
- '(for i in `seq 1 20`; do'
- ' s=$(sky serve down -y {name});'
- ' echo "Trying to terminate {name}";'
- ' echo "$s";'
- ' echo "$s" | grep -q "scheduled to be terminated\|No service to terminate" && break;'
- ' sleep 10;'
- ' [ $i -eq 20 ] && echo "Failed to terminate service {name}";'
- 'done)')
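-# Used as the teardown command of each skyserve Test below, e.g.
-#   _TEARDOWN_SERVICE.format(name=name)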
-
-_SERVE_ENDPOINT_WAIT = (
- 'export ORIGIN_SKYPILOT_DEBUG=$SKYPILOT_DEBUG; export SKYPILOT_DEBUG=0; '
- 'endpoint=$(sky serve status --endpoint {name}); '
- 'until ! echo "$endpoint" | grep "Controller is initializing"; '
- 'do echo "Waiting for serve endpoint to be ready..."; '
- 'sleep 5; endpoint=$(sky serve status --endpoint {name}); done; '
- 'export SKYPILOT_DEBUG=$ORIGIN_SKYPILOT_DEBUG; echo "$endpoint"')
-
-_SERVE_STATUS_WAIT = ('s=$(sky serve status {name}); '
- 'until ! echo "$s" | grep "Controller is initializing."; '
- 'do echo "Waiting for serve status to be ready..."; '
- 'sleep 5; s=$(sky serve status {name}); done; echo "$s"')
-
-
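-# Builds a shell snippet that, given the `sky serve status` output in $s,
-# stores the IP of replica `replica_id` in the shell variable ip{replica_id}
-# (e.g. $ip1) by scanning the lines below the `Replicas` section.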
-def _get_replica_ip(name: str, replica_id: int) -> str:
- return (f'ip{replica_id}=$(echo "$s" | '
- f'awk "{_AWK_ALL_LINES_BELOW_REPLICAS}" | '
- f'grep -E "{name}\s+{replica_id}" | '
- f'grep -Eo "{_IP_REGEX}")')
-
-
-def _get_skyserve_http_test(name: str, cloud: str,
- timeout_minutes: int) -> Test:
- test = Test(
- f'test-skyserve-{cloud.replace("_", "-")}',
- [
- f'sky serve up -n {name} -y tests/skyserve/http/{cloud}.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'curl http://$endpoint | grep "Hi, SkyPilot here"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=timeout_minutes * 60,
- )
- return test
-
-
-def _check_replica_in_status(name: str, check_tuples: List[Tuple[int, bool,
- str]]) -> str:
- """Check replicas' status and count in sky serve status
-
- We will check vCPU=2, as all our tests use vCPU=2.
-
- Args:
- name: the name of the service
-        check_tuples: A list of replica properties to check. Each tuple is
-            (count, is_spot, status).
- """
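-    # Illustrative expansion: a tuple (2, False, 'READY') roughly becomes
-    #   echo "$s" | grep "(vCPU=2)" | grep "READY" | wc -l | grep 2 || exit 1;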
- check_cmd = ''
- for check_tuple in check_tuples:
- count, is_spot, status = check_tuple
- resource_str = ''
- if status not in ['PENDING', 'SHUTTING_DOWN'
- ] and not status.startswith('FAILED'):
- spot_str = ''
- if is_spot:
- spot_str = '\[Spot\]'
- resource_str = f'({spot_str}vCPU=2)'
- check_cmd += (f' echo "$s" | grep "{resource_str}" | '
- f'grep "{status}" | wc -l | grep {count} || exit 1;')
- return (f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s"; ' + check_cmd)
-
-
-def _check_service_version(service_name: str, version: str) -> str:
- # Grep the lines before 'Service Replicas' and check if the service version
- # is correct.
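-    # e.g. _check_service_version(name, "1,2") expects the service row to show
-    # versions "1,2" (both versions active during a rolling update below).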
- return (f'echo "$s" | grep -B1000 "Service Replicas" | '
- f'grep -E "{service_name}\s+{version}" || exit 1; ')
-
-
-@pytest.mark.gcp
-@pytest.mark.serve
-def test_skyserve_gcp_http():
- """Test skyserve on GCP"""
- name = _get_service_name()
- test = _get_skyserve_http_test(name, 'gcp', 20)
- run_one_test(test)
-
-
-@pytest.mark.aws
-@pytest.mark.serve
-def test_skyserve_aws_http():
- """Test skyserve on AWS"""
- name = _get_service_name()
- test = _get_skyserve_http_test(name, 'aws', 20)
- run_one_test(test)
-
-
-@pytest.mark.azure
-@pytest.mark.serve
-def test_skyserve_azure_http():
- """Test skyserve on Azure"""
- name = _get_service_name()
- test = _get_skyserve_http_test(name, 'azure', 30)
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-@pytest.mark.serve
-def test_skyserve_kubernetes_http():
- """Test skyserve on Kubernetes"""
- name = _get_service_name()
- test = _get_skyserve_http_test(name, 'kubernetes', 30)
- run_one_test(test)
-
-
-@pytest.mark.oci
-@pytest.mark.serve
-def test_skyserve_oci_http():
- """Test skyserve on OCI"""
- name = _get_service_name()
- test = _get_skyserve_http_test(name, 'oci', 20)
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack does not support T4 gpus for now
-@pytest.mark.serve
-def test_skyserve_llm(generic_cloud: str):
-    """Test skyserve with a real LLM use case"""
- name = _get_service_name()
-
- def generate_llm_test_command(prompt: str, expected_output: str) -> str:
- prompt = shlex.quote(prompt)
- expected_output = shlex.quote(expected_output)
- return (
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'python tests/skyserve/llm/get_response.py --endpoint $endpoint '
- f'--prompt {prompt} | grep {expected_output}')
-
- with open('tests/skyserve/llm/prompt_output.json', 'r',
- encoding='utf-8') as f:
- prompt2output = json.load(f)
-
- test = Test(
- f'test-skyserve-llm',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/llm/service.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- *[
- generate_llm_test_command(prompt, output)
- for prompt, output in prompt2output.items()
- ],
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=40 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-@pytest.mark.serve
-def test_skyserve_spot_recovery():
- name = _get_service_name()
- zone = 'us-central1-a'
-
- test = Test(
- f'test-skyserve-spot-recovery-gcp',
- [
- f'sky serve up -n {name} -y tests/skyserve/spot/recovery.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
- _terminate_gcp_replica(name, zone, 1),
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
-@pytest.mark.serve
-@pytest.mark.no_kubernetes
-def test_skyserve_base_ondemand_fallback(generic_cloud: str):
- name = _get_service_name()
- test = Test(
- f'test-skyserve-base-ondemand-fallback',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/spot/base_ondemand_fallback.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- _check_replica_in_status(name, [(1, True, 'READY'),
- (1, False, 'READY')]),
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-@pytest.mark.serve
-def test_skyserve_dynamic_ondemand_fallback():
- name = _get_service_name()
- zone = 'us-central1-a'
-
- test = Test(
- f'test-skyserve-dynamic-ondemand-fallback',
- [
- f'sky serve up -n {name} --cloud gcp -y tests/skyserve/spot/dynamic_ondemand_fallback.yaml',
- f'sleep 40',
- # 2 on-demand (provisioning) + 2 Spot (provisioning).
- f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s";'
- 'echo "$s" | grep -q "0/4" || exit 1',
-            # Wait for provisioning to start.
- f'sleep 40',
- _check_replica_in_status(name, [
- (2, True, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY'),
- (2, False, _SERVICE_LAUNCHING_STATUS_REGEX + '\|SHUTTING_DOWN')
- ]),
-
- # Wait until 2 spot instances are ready.
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- _check_replica_in_status(name, [(2, True, 'READY'),
- (0, False, '')]),
- _terminate_gcp_replica(name, zone, 1),
- f'sleep 40',
- # 1 on-demand (provisioning) + 1 Spot (ready) + 1 spot (provisioning).
- f'{_SERVE_STATUS_WAIT.format(name=name)}; '
- 'echo "$s" | grep -q "1/3"',
- _check_replica_in_status(
- name, [(1, True, 'READY'),
- (1, True, _SERVICE_LAUNCHING_STATUS_REGEX),
- (1, False, _SERVICE_LAUNCHING_STATUS_REGEX)]),
-
- # Wait until 2 spot instances are ready.
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- _check_replica_in_status(name, [(2, True, 'READY'),
- (0, False, '')]),
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
-@pytest.mark.no_fluidstack
-@pytest.mark.serve
-def test_skyserve_user_bug_restart(generic_cloud: str):
- """Tests that we restart the service after user bug."""
- # TODO(zhwu): this behavior needs some rethinking.
- name = _get_service_name()
- test = Test(
- f'test-skyserve-user-bug-restart',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/restart/user_bug.yaml',
- f's=$(sky serve status {name}); echo "$s";'
- 'until echo "$s" | grep -A 100 "Service Replicas" | grep "SHUTTING_DOWN"; '
- 'do echo "Waiting for first service to be SHUTTING DOWN..."; '
- f'sleep 5; s=$(sky serve status {name}); echo "$s"; done; ',
- f's=$(sky serve status {name}); echo "$s";'
- 'until echo "$s" | grep -A 100 "Service Replicas" | grep "FAILED"; '
- 'do echo "Waiting for first service to be FAILED..."; '
- f'sleep 2; s=$(sky serve status {name}); echo "$s"; done; echo "$s"; '
- + _check_replica_in_status(name, [(1, True, 'FAILED')]) +
- # User bug failure will cause no further scaling.
- f'echo "$s" | grep -A 100 "Service Replicas" | grep "{name}" | wc -l | grep 1; '
- f'echo "$s" | grep -B 100 "NO_REPLICA" | grep "0/0"',
- f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/auto_restart.yaml',
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'until curl --connect-timeout 10 --max-time 10 http://$endpoint | grep "Hi, SkyPilot here"; do sleep 1; done; sleep 2; '
- + _check_replica_in_status(name, [(1, False, 'READY'),
- (1, False, 'FAILED')]),
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.serve
-@pytest.mark.no_kubernetes # Replicas on k8s may be running on the same node and have the same public IP
-def test_skyserve_load_balancer(generic_cloud: str):
- """Test skyserve load balancer round-robin policy"""
- name = _get_service_name()
- test = Test(
- f'test-skyserve-load-balancer',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/load_balancer/service.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=3),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- f'{_SERVE_STATUS_WAIT.format(name=name)}; '
- f'{_get_replica_ip(name, 1)}; '
- f'{_get_replica_ip(name, 2)}; {_get_replica_ip(name, 3)}; '
- 'python tests/skyserve/load_balancer/test_round_robin.py '
- '--endpoint $endpoint --replica-num 3 --replica-ips $ip1 $ip2 $ip3',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.gcp
-@pytest.mark.serve
-@pytest.mark.no_kubernetes
-def test_skyserve_auto_restart():
- """Test skyserve with auto restart"""
- name = _get_service_name()
- zone = 'us-central1-a'
- test = Test(
- f'test-skyserve-auto-restart',
- [
- # TODO(tian): we can dynamically generate YAML from template to
- # avoid maintaining too many YAML files
- f'sky serve up -n {name} -y tests/skyserve/auto_restart.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
- # sleep for 20 seconds (initial delay) to make sure it will
- # be restarted
- f'sleep 20',
- _terminate_gcp_replica(name, zone, 1),
-            # Wait for the consecutive-failure timeout to pass.
-            # If the cluster is not using spot, we won't check the cluster status
-            # on the cloud (manual shutdown is not a common behavior and such
-            # queries take a lot of time). Instead, we treat 3 minutes of
-            # continuous probe failures as a real failure rather than a
-            # temporary problem.
- 'sleep 180',
-            # We cannot use _SERVE_WAIT_UNTIL_READY: there is an intermediate
-            # period during which `sky serve status` shows FAILED, and that
-            # status would cause _SERVE_WAIT_UNTIL_READY to exit early.
- '(while true; do'
- f' output=$(sky serve status {name});'
- ' echo "$output" | grep -q "1/1" && break;'
- ' sleep 10;'
- f'done); sleep {serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS};',
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.serve
-def test_skyserve_cancel(generic_cloud: str):
- """Test skyserve with cancel"""
- name = _get_service_name()
-
- test = Test(
- f'test-skyserve-cancel',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/cancel/cancel.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; python3 '
- 'tests/skyserve/cancel/send_cancel_request.py '
- '--endpoint $endpoint | grep "Request was cancelled"',
- f's=$(sky serve logs {name} 1 --no-follow); '
- 'until ! echo "$s" | grep "Please wait for the controller to be"; '
- 'do echo "Waiting for serve logs"; sleep 10; '
- f's=$(sky serve logs {name} 1 --no-follow); done; '
- 'echo "$s"; echo "$s" | grep "Client disconnected, stopping computation"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.serve
-def test_skyserve_streaming(generic_cloud: str):
- """Test skyserve with streaming"""
- name = _get_service_name()
- test = Test(
- f'test-skyserve-streaming',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/streaming/streaming.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'python3 tests/skyserve/streaming/send_streaming_request.py '
- '--endpoint $endpoint | grep "Streaming test passed"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.serve
-def test_skyserve_readiness_timeout_fail(generic_cloud: str):
- """Test skyserve with large readiness probe latency, expected to fail"""
- name = _get_service_name()
- test = Test(
- f'test-skyserve-readiness-timeout-fail',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task.yaml',
-            # None of the readiness probes will pass, so the service will be
- # terminated after the initial delay.
- f's=$(sky serve status {name}); '
- f'until echo "$s" | grep "FAILED_INITIAL_DELAY"; do '
- 'echo "Waiting for replica to be failed..."; sleep 5; '
- f's=$(sky serve status {name}); echo "$s"; done;',
- 'sleep 60',
- f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s" | grep "{name}" | grep "FAILED_INITIAL_DELAY" | wc -l | grep 1;'
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.serve
-def test_skyserve_large_readiness_timeout(generic_cloud: str):
- """Test skyserve with customized large readiness timeout"""
- name = _get_service_name()
- test = Test(
- f'test-skyserve-large-readiness-timeout',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task_large_timeout.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
-@pytest.mark.no_fluidstack
-@pytest.mark.serve
-def test_skyserve_update(generic_cloud: str):
- """Test skyserve with update"""
- name = _get_service_name()
- test = Test(
- f'test-skyserve-update',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
- f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/new.yaml',
- # sleep before update is registered.
- 'sleep 20',
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done;'
- # Make sure the traffic is not mixed
- 'curl http://$endpoint | grep "Hi, new SkyPilot here"',
-            # The latest 2 versions should be READY and the older versions should be shutting down.
- (_check_replica_in_status(name, [(2, False, 'READY'),
- (2, False, 'SHUTTING_DOWN')]) +
- _check_service_version(name, "2")),
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
-@pytest.mark.no_fluidstack
-@pytest.mark.serve
-def test_skyserve_rolling_update(generic_cloud: str):
- """Test skyserve with rolling update"""
- name = _get_service_name()
- single_new_replica = _check_replica_in_status(
- name, [(2, False, 'READY'), (1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
- (1, False, 'SHUTTING_DOWN')])
- test = Test(
- f'test-skyserve-rolling-update',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
- f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/new.yaml',
-            # Make sure the traffic is mixed across the two versions: replicas
-            # with an even id sleep 60 seconds before becoming ready, so we
-            # should be able to observe a period during which the traffic is
-            # mixed across the two versions.
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done; sleep 2; '
-            # The latest version should have one replica READY and one of the older versions should be shutting down.
- f'{single_new_replica} {_check_service_version(name, "1,2")} '
- # Check the output from the old version, immediately after the
- # output from the new version appears. This is guaranteed by the
- # round robin load balancing policy.
- # TODO(zhwu): we should have a more generalized way for checking the
- # mixed version of replicas to avoid depending on the specific
- # round robin load balancing policy.
- 'curl http://$endpoint | grep "Hi, SkyPilot here"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack
-@pytest.mark.serve
-def test_skyserve_fast_update(generic_cloud: str):
- """Test skyserve with fast update (Increment version of old replicas)"""
- name = _get_service_name()
-
- test = Test(
- f'test-skyserve-fast-update',
- [
- f'sky serve up -n {name} -y --cloud {generic_cloud} tests/skyserve/update/bump_version_before.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
- f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/bump_version_after.yaml',
- # sleep to wait for update to be registered.
- 'sleep 40',
-            # 2 on-demand (ready) + 1 on-demand (provisioning).
- (
- _check_replica_in_status(
- name, [(2, False, 'READY'),
- (1, False, _SERVICE_LAUNCHING_STATUS_REGEX)]) +
- # Fast update will directly have the latest version ready.
- _check_service_version(name, "2")),
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=3) +
- _check_service_version(name, "2"),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
- # Test rolling update
- f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/bump_version_before.yaml',
- # sleep to wait for update to be registered.
- 'sleep 25',
-            # 2 on-demand (ready) + 1 on-demand (shutting down).
- _check_replica_in_status(name, [(2, False, 'READY'),
- (1, False, 'SHUTTING_DOWN')]),
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
- _check_service_version(name, "3"),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.serve
-def test_skyserve_update_autoscale(generic_cloud: str):
- """Test skyserve update with autoscale"""
- name = _get_service_name()
- test = Test(
- f'test-skyserve-update-autoscale',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/num_min_two.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
- _check_service_version(name, "1"),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'curl http://$endpoint | grep "Hi, SkyPilot here"',
- f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/num_min_one.yaml',
- # sleep before update is registered.
- 'sleep 20',
- # Timeout will be triggered when update fails.
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1) +
- _check_service_version(name, "2"),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'curl http://$endpoint | grep "Hi, SkyPilot here!"',
- # Rolling Update
- f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/num_min_two.yaml',
- # sleep before update is registered.
- 'sleep 20',
- # Timeout will be triggered when update fails.
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
- _check_service_version(name, "3"),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'curl http://$endpoint | grep "Hi, SkyPilot here!"',
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=30 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Spot instances are not supported by Fluidstack
-@pytest.mark.serve
-@pytest.mark.no_kubernetes # Spot instances are not supported in Kubernetes
-@pytest.mark.parametrize('mode', ['rolling', 'blue_green'])
-def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str):
- """Test skyserve with update that changes autoscaler"""
- name = f'{_get_service_name()}-{mode}'
-
- wait_until_no_pending = (
- f's=$(sky serve status {name}); echo "$s"; '
- 'until ! echo "$s" | grep PENDING; do '
- ' echo "Waiting for replica to be out of pending..."; '
- f' sleep 5; s=$(sky serve status {name}); '
- ' echo "$s"; '
- 'done')
- four_spot_up_cmd = _check_replica_in_status(name, [(4, True, 'READY')])
- update_check = [f'until ({four_spot_up_cmd}); do sleep 5; done; sleep 15;']
- if mode == 'rolling':
-        # Check rolling update: it will terminate one of the old on-demand
-        # instances once there are 4 spot instances ready.
- update_check += [
- _check_replica_in_status(
- name, [(1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
- (1, False, 'SHUTTING_DOWN'), (1, False, 'READY')]) +
- _check_service_version(name, "1,2"),
- ]
- else:
-        # Check blue-green update: it will keep both old on-demand instances
-        # running once there are 4 spot instances ready.
- update_check += [
- _check_replica_in_status(
- name, [(1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
- (2, False, 'READY')]) +
- _check_service_version(name, "1"),
- ]
- test = Test(
- f'test-skyserve-new-autoscaler-update-{mode}',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/new_autoscaler_before.yaml',
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) +
- _check_service_version(name, "1"),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 's=$(curl http://$endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"',
- f'sky serve update {name} --cloud {generic_cloud} --mode {mode} -y tests/skyserve/update/new_autoscaler_after.yaml',
- # Wait for update to be registered
- f'sleep 90',
- wait_until_no_pending,
- _check_replica_in_status(
- name, [(4, True, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY'),
- (1, False, _SERVICE_LAUNCHING_STATUS_REGEX),
- (2, False, 'READY')]),
- *update_check,
- _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=5),
- f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
- 'curl http://$endpoint | grep "Hi, SkyPilot here"',
- _check_replica_in_status(name, [(4, True, 'READY'),
- (1, False, 'READY')]),
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
-@pytest.mark.no_fluidstack
-@pytest.mark.serve
-def test_skyserve_failures(generic_cloud: str):
- """Test replica failure statuses"""
- name = _get_service_name()
-
- test = Test(
- 'test-skyserve-failures',
- [
- f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/failures/initial_delay.yaml',
- f's=$(sky serve status {name}); '
- f'until echo "$s" | grep "FAILED_INITIAL_DELAY"; do '
- 'echo "Waiting for replica to be failed..."; sleep 5; '
- f's=$(sky serve status {name}); echo "$s"; done;',
- 'sleep 60',
- f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s" | grep "{name}" | grep "FAILED_INITIAL_DELAY" | wc -l | grep 2; '
-            # Make sure no new replicas are started after an early failure.
- f'echo "$s" | grep -A 100 "Service Replicas" | grep "{name}" | wc -l | grep 2;',
- f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/failures/probing.yaml',
- f's=$(sky serve status {name}); '
- # Wait for replica to be ready.
- f'until echo "$s" | grep "READY"; do '
-            'echo "Waiting for replica to be ready..."; sleep 5; '
- f's=$(sky serve status {name}); echo "$s"; done;',
- # Wait for replica to change to FAILED_PROBING
- f's=$(sky serve status {name}); '
- f'until echo "$s" | grep "FAILED_PROBING"; do '
- 'echo "Waiting for replica to be failed..."; sleep 5; '
- f's=$(sky serve status {name}); echo "$s"; done',
- # Wait for the PENDING replica to appear.
- 'sleep 10',
- # Wait until the replica is out of PENDING.
- f's=$(sky serve status {name}); '
- f'until ! echo "$s" | grep "PENDING" && ! echo "$s" | grep "Please wait for the controller to be ready."; do '
- 'echo "Waiting for replica to be out of pending..."; sleep 5; '
- f's=$(sky serve status {name}); echo "$s"; done; ' +
- _check_replica_in_status(name, [
- (1, False, 'FAILED_PROBING'),
- (1, False, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY')
- ]),
- # TODO(zhwu): add test for FAILED_PROVISION
- ],
- _TEARDOWN_SERVICE.format(name=name),
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-# TODO(Ziming, Tian): Add tests for autoscaling.
-
-
-# ------- Testing user dependencies --------
-def test_user_dependencies(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'user-dependencies',
- [
- f'sky launch -y -c {name} --cloud {generic_cloud} "pip install ray>2.11; ray start --head"',
- f'sky logs {name} 1 --status',
- f'sky exec {name} "echo hi"',
- f'sky logs {name} 2 --status',
- f'sky status -r {name} | grep UP',
- f'sky exec {name} "echo bye"',
- f'sky logs {name} 3 --status',
- f'sky launch -c {name} tests/test_yamls/different_default_conda_env.yaml',
- f'sky logs {name} 4 --status',
-            # Launch again to test that the default env does not affect the
-            # SkyPilot runtime setup
- f'sky launch -c {name} "python --version 2>&1 | grep \'Python 3.6\' || exit 1"',
- f'sky logs {name} 5 --status',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ------- Testing the core API --------
-# Most of the core APIs have been tested in the CLI tests.
-# These tests check the return values of APIs that are not fully exercised by
-# the CLI.
-
-
-@pytest.mark.gcp
-def test_core_api_sky_launch_exec():
- name = _get_cluster_name()
- task = sky.Task(run="whoami")
- task.set_resources(sky.Resources(cloud=sky.GCP()))
- job_id, handle = sky.get(sky.launch(task, cluster_name=name))
- assert job_id == 1
- assert handle is not None
- assert handle.cluster_name == name
- assert handle.launched_resources.cloud.is_same_cloud(sky.GCP())
- job_id_exec, handle_exec = sky.get(sky.exec(task, cluster_name=name))
- assert job_id_exec == 2
- assert handle_exec is not None
- assert handle_exec.cluster_name == name
- assert handle_exec.launched_resources.cloud.is_same_cloud(sky.GCP())
-    # For a dummy task (i.e., task.run is None), the job won't be submitted.
- dummy_task = sky.Task()
- job_id_dummy, _ = sky.get(sky.exec(dummy_task, cluster_name=name))
- assert job_id_dummy is None
- sky.get(sky.down(name))
-
-
-# The sky launch CLI has some additional checks to make sure the cluster is
-# up/restarted. However, the core API doesn't have these checks; make sure it
-# still works.
-def test_core_api_sky_launch_fast(generic_cloud: str):
- name = _get_cluster_name()
- cloud = sky.utils.registry.CLOUD_REGISTRY.from_str(generic_cloud)
- try:
- task = sky.Task(run="whoami").set_resources(sky.Resources(cloud=cloud))
- sky.launch(task,
- cluster_name=name,
- idle_minutes_to_autostop=1,
- fast=True)
- # Sleep to let the cluster autostop
- _get_cmd_wait_until_cluster_status_contains(
- cluster_name=name,
- cluster_status=[sky.ClusterStatus.STOPPED],
- timeout=120)
- # Run it again - should work with fast=True
- sky.launch(task,
- cluster_name=name,
- idle_minutes_to_autostop=1,
- fast=True)
- finally:
- sky.down(name)
-
-
-# ---------- Testing Storage ----------
-# These tests are essentially unit tests for Storage, but they require
-# credentials and network connection. Thus, they are included with smoke tests.
-# Since these tests require cloud credentials to verify bucket operations,
-# they should not be run when the API server is remote and the user does not
-# have any credentials locally.
-# TODO(romilb): In the future, we should figure out a way to ship these tests
-# to the API server and run them there. Maybe these tests can be packaged as a
-# SkyPilot task run on a remote cluster launched via the API server.
-@pytest.mark.local
-class TestStorageWithCredentials:
- """Storage tests which require credentials and network connection"""
-
- AWS_INVALID_NAMES = [
- 'ab', # less than 3 characters
- 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
- # more than 63 characters
- 'Abcdef', # contains an uppercase letter
- 'abc def', # contains a space
- 'abc..def', # two adjacent periods
- '192.168.5.4', # formatted as an IP address
- 'xn--bucket', # starts with 'xn--' prefix
- 'bucket-s3alias', # ends with '-s3alias' suffix
- 'bucket--ol-s3', # ends with '--ol-s3' suffix
- '.abc', # starts with a dot
- 'abc.', # ends with a dot
- '-abc', # starts with a hyphen
- 'abc-', # ends with a hyphen
- ]
-
- GCS_INVALID_NAMES = [
- 'ab', # less than 3 characters
- 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
- # more than 63 characters (without dots)
- 'Abcdef', # contains an uppercase letter
- 'abc def', # contains a space
- 'abc..def', # two adjacent periods
-        'abc_.def.ghi.jklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
-        # more than 63 characters between dots
- 'abc_.def.ghi.jklmnopqrstuvwxyzabcdefghijklmnopqfghijklmnopqrstuvw' * 5,
- # more than 222 characters (with dots)
- '192.168.5.4', # formatted as an IP address
- 'googbucket', # starts with 'goog' prefix
- 'googlebucket', # contains 'google'
- 'g00glebucket', # variant of 'google'
- 'go0glebucket', # variant of 'google'
- 'g0oglebucket', # variant of 'google'
- '.abc', # starts with a dot
- 'abc.', # ends with a dot
- '_abc', # starts with an underscore
- 'abc_', # ends with an underscore
- ]
-
- AZURE_INVALID_NAMES = [
- 'ab', # less than 3 characters
- # more than 63 characters
- 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
- 'Abcdef', # contains an uppercase letter
- '.abc', # starts with a non-letter(dot)
- 'a--bc', # contains consecutive hyphens
- ]
-
- IBM_INVALID_NAMES = [
- 'ab', # less than 3 characters
- 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz1',
- # more than 63 characters
- 'Abcdef', # contains an uppercase letter
- 'abc def', # contains a space
- 'abc..def', # two adjacent periods
- '192.168.5.4', # formatted as an IP address
- 'xn--bucket', # starts with 'xn--' prefix
- '.abc', # starts with a dot
- 'abc.', # ends with a dot
- '-abc', # starts with a hyphen
- 'abc-', # ends with a hyphen
- 'a.-bc', # contains the sequence '.-'
- 'a-.bc', # contains the sequence '-.'
-        'a&bc', # contains special characters
-        'ab^c', # contains special characters
- ]
- GITIGNORE_SYNC_TEST_DIR_STRUCTURE = {
- 'double_asterisk': {
- 'double_asterisk_excluded': None,
- 'double_asterisk_excluded_dir': {
- 'dir_excluded': None,
- },
- },
- 'double_asterisk_parent': {
- 'parent': {
- 'also_excluded.txt': None,
- 'child': {
- 'double_asterisk_parent_child_excluded.txt': None,
- },
- 'double_asterisk_parent_excluded.txt': None,
- },
- },
- 'excluded.log': None,
- 'excluded_dir': {
- 'excluded.txt': None,
- 'nested_excluded': {
- 'excluded': None,
- },
- },
- 'exp-1': {
- 'be_excluded': None,
- },
- 'exp-2': {
- 'be_excluded': None,
- },
- 'front_slash_excluded': None,
- 'included.log': None,
- 'included.txt': None,
- 'include_dir': {
- 'excluded.log': None,
- 'included.log': None,
- },
- 'nested_double_asterisk': {
- 'one': {
- 'also_exclude.txt': None,
- },
- 'two': {
- 'also_exclude.txt': None,
- },
- },
- 'nested_wildcard_dir': {
- 'monday': {
- 'also_exclude.txt': None,
- },
- 'tuesday': {
- 'also_exclude.txt': None,
- },
- },
- 'no_slash_excluded': None,
- 'no_slash_tests': {
- 'no_slash_excluded': {
- 'also_excluded.txt': None,
- },
- },
- 'question_mark': {
- 'excluded1.txt': None,
- 'excluded@.txt': None,
- },
- 'square_bracket': {
- 'excluded1.txt': None,
- },
- 'square_bracket_alpha': {
- 'excludedz.txt': None,
- },
- 'square_bracket_excla': {
- 'excluded2.txt': None,
- 'excluded@.txt': None,
- },
- 'square_bracket_single': {
- 'excluded0.txt': None,
- },
- }
-
- @staticmethod
- def create_dir_structure(base_path, structure):
- # creates a given file STRUCTURE in BASE_PATH
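-        # Used by the gitignore fixtures below (e.g. tmp_gitignore_storage_obj)
-        # to materialize a structure like GITIGNORE_SYNC_TEST_DIR_STRUCTURE
-        # inside a temporary directory.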
- for name, substructure in structure.items():
- path = os.path.join(base_path, name)
- if substructure is None:
- # Create a file
- open(path, 'a', encoding='utf-8').close()
- else:
- # Create a subdirectory
- os.mkdir(path)
- TestStorageWithCredentials.create_dir_structure(
- path, substructure)
-
- @staticmethod
- def cli_delete_cmd(store_type,
- bucket_name,
- storage_account_name: str = None):
- if store_type == storage_lib.StoreType.S3:
- url = f's3://{bucket_name}'
- return f'aws s3 rb {url} --force'
- if store_type == storage_lib.StoreType.GCS:
- url = f'gs://{bucket_name}'
- gsutil_alias, alias_gen = data_utils.get_gsutil_command()
- return f'{alias_gen}; {gsutil_alias} rm -r {url}'
- if store_type == storage_lib.StoreType.AZURE:
- default_region = 'eastus'
- storage_account_name = (
- storage_lib.AzureBlobStore.get_default_storage_account_name(
- default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
- return ('az storage container delete '
- f'--account-name {storage_account_name} '
- f'--account-key {storage_account_key} '
- f'--name {bucket_name}')
- if store_type == storage_lib.StoreType.R2:
- endpoint_url = cloudflare.create_endpoint()
- url = f's3://{bucket_name}'
- return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 rb {url} --force --endpoint {endpoint_url} --profile=r2'
- if store_type == storage_lib.StoreType.IBM:
- bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
- bucket_name, Rclone.RcloneClouds.IBM)
- return f'rclone purge {bucket_rclone_profile}:{bucket_name} && rclone config delete {bucket_rclone_profile}'
-
- @staticmethod
- def cli_ls_cmd(store_type, bucket_name, suffix=''):
- if store_type == storage_lib.StoreType.S3:
- if suffix:
- url = f's3://{bucket_name}/{suffix}'
- else:
- url = f's3://{bucket_name}'
- return f'aws s3 ls {url}'
- if store_type == storage_lib.StoreType.GCS:
- if suffix:
- url = f'gs://{bucket_name}/{suffix}'
- else:
- url = f'gs://{bucket_name}'
- return f'gsutil ls {url}'
- if store_type == storage_lib.StoreType.AZURE:
- default_region = 'eastus'
- config_storage_account = skypilot_config.get_nested(
- ('azure', 'storage_account'), None)
- storage_account_name = config_storage_account if (
- config_storage_account is not None) else (
- storage_lib.AzureBlobStore.get_default_storage_account_name(
- default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
- list_cmd = ('az storage blob list '
- f'--container-name {bucket_name} '
- f'--prefix {shlex.quote(suffix)} '
- f'--account-name {storage_account_name} '
- f'--account-key {storage_account_key}')
- return list_cmd
- if store_type == storage_lib.StoreType.R2:
- endpoint_url = cloudflare.create_endpoint()
- if suffix:
- url = f's3://{bucket_name}/{suffix}'
- else:
- url = f's3://{bucket_name}'
- return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2'
- if store_type == storage_lib.StoreType.IBM:
- bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
- bucket_name, Rclone.RcloneClouds.IBM)
- return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}'
-
- @staticmethod
- def cli_region_cmd(store_type, bucket_name=None, storage_account_name=None):
- if store_type == storage_lib.StoreType.S3:
- assert bucket_name is not None
- return ('aws s3api get-bucket-location '
- f'--bucket {bucket_name} --output text')
- elif store_type == storage_lib.StoreType.GCS:
- assert bucket_name is not None
- return (f'gsutil ls -L -b gs://{bucket_name}/ | '
- 'grep "Location constraint" | '
- 'awk \'{print tolower($NF)}\'')
- elif store_type == storage_lib.StoreType.AZURE:
-            # For Azure Blob Storage, the location of a container is
-            # determined by the location of its storage account.
- assert storage_account_name is not None
- return (f'az storage account show --name {storage_account_name} '
- '--query "primaryLocation" --output tsv')
- else:
- raise NotImplementedError(f'Region command not implemented for '
- f'{store_type}')
-
- @staticmethod
- def cli_count_name_in_bucket(store_type,
- bucket_name,
- file_name,
- suffix='',
- storage_account_name=None):
- if store_type == storage_lib.StoreType.S3:
- if suffix:
- return f'aws s3api list-objects --bucket "{bucket_name}" --prefix {suffix} --query "length(Contents[?contains(Key,\'{file_name}\')].Key)"'
- else:
- return f'aws s3api list-objects --bucket "{bucket_name}" --query "length(Contents[?contains(Key,\'{file_name}\')].Key)"'
- elif store_type == storage_lib.StoreType.GCS:
- if suffix:
- return f'gsutil ls -r gs://{bucket_name}/{suffix} | grep "{file_name}" | wc -l'
- else:
- return f'gsutil ls -r gs://{bucket_name} | grep "{file_name}" | wc -l'
- elif store_type == storage_lib.StoreType.AZURE:
- if storage_account_name is None:
- default_region = 'eastus'
- storage_account_name = (
- storage_lib.AzureBlobStore.get_default_storage_account_name(
- default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
- return ('az storage blob list '
- f'--container-name {bucket_name} '
- f'--prefix {shlex.quote(suffix)} '
- f'--account-name {storage_account_name} '
- f'--account-key {storage_account_key} | '
- f'grep {file_name} | '
- 'wc -l')
- elif store_type == storage_lib.StoreType.R2:
- endpoint_url = cloudflare.create_endpoint()
- if suffix:
- return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3api list-objects --bucket "{bucket_name}" --prefix {suffix} --query "length(Contents[?contains(Key,\'{file_name}\')].Key)" --endpoint {endpoint_url} --profile=r2'
- else:
- return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3api list-objects --bucket "{bucket_name}" --query "length(Contents[?contains(Key,\'{file_name}\')].Key)" --endpoint {endpoint_url} --profile=r2'
-
- @staticmethod
- def cli_count_file_in_bucket(store_type, bucket_name):
- if store_type == storage_lib.StoreType.S3:
- return f'aws s3 ls s3://{bucket_name} --recursive | wc -l'
- elif store_type == storage_lib.StoreType.GCS:
- return f'gsutil ls -r gs://{bucket_name}/** | wc -l'
- elif store_type == storage_lib.StoreType.AZURE:
- default_region = 'eastus'
- storage_account_name = (
- storage_lib.AzureBlobStore.get_default_storage_account_name(
- default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
- return ('az storage blob list '
- f'--container-name {bucket_name} '
- f'--account-name {storage_account_name} '
- f'--account-key {storage_account_key} | '
- 'grep \\"name\\": | '
- 'wc -l')
- elif store_type == storage_lib.StoreType.R2:
- endpoint_url = cloudflare.create_endpoint()
- return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls s3://{bucket_name} --recursive --endpoint {endpoint_url} --profile=r2 | wc -l'
-
- @pytest.fixture
- def tmp_source(self, tmp_path):
- # Creates a temporary directory with a file in it
-        # Creates a temporary directory containing a file and a circular symlink.
- tmp_dir.mkdir()
- tmp_file = tmp_dir / 'tmp-file'
- tmp_file.write_text('test')
- circle_link = tmp_dir / 'circle-link'
- circle_link.symlink_to(tmp_dir, target_is_directory=True)
- yield str(tmp_dir)
-
- @staticmethod
- def generate_bucket_name():
- # Creates a temporary bucket name
- # time.time() returns varying precision on different systems, so we
- # replace the decimal point and use whatever precision we can get.
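-        # The generated name looks roughly like 'sky-test-<timestamp digits>'.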
- timestamp = str(time.time()).replace('.', '')
- return f'sky-test-{timestamp}'
-
- @pytest.fixture
- def tmp_bucket_name(self):
- yield self.generate_bucket_name()
-
- @staticmethod
- def yield_storage_object(
- name: Optional[str] = None,
- source: Optional[storage_lib.Path] = None,
- stores: Optional[Dict[storage_lib.StoreType,
- storage_lib.AbstractStore]] = None,
- persistent: Optional[bool] = True,
- mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT):
- # Creates a temporary storage object. Stores must be added in the test.
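-        # Yields the Storage object to the test, then deletes it afterwards if
-        # a handle was registered in global_user_state (see fixtures below).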
- storage_obj = storage_lib.Storage(name=name,
- source=source,
- stores=stores,
- persistent=persistent,
- mode=mode)
- storage_obj.construct()
- yield storage_obj
- handle = global_user_state.get_handle_from_storage_name(
- storage_obj.name)
- if handle:
- # If handle exists, delete manually
- # TODO(romilb): This is potentially risky - if the delete method has
- # bugs, this can cause resource leaks. Ideally we should manually
- # eject storage from global_user_state and delete the bucket using
- # boto3 directly.
- storage_obj.delete()
-
- @pytest.fixture
- def tmp_scratch_storage_obj(self, tmp_bucket_name):
- # Creates a storage object with no source to create a scratch storage.
- # Stores must be added in the test.
- yield from self.yield_storage_object(name=tmp_bucket_name)
-
- @pytest.fixture
- def tmp_multiple_scratch_storage_obj(self):
- # Creates a list of 5 storage objects with no source to create
- # multiple scratch storages.
- # Stores for each object in the list must be added in the test.
- storage_mult_obj = []
- for _ in range(5):
- timestamp = str(time.time()).replace('.', '')
- store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}')
- store_obj.construct()
- storage_mult_obj.append(store_obj)
- yield storage_mult_obj
- for storage_obj in storage_mult_obj:
- handle = global_user_state.get_handle_from_storage_name(
- storage_obj.name)
- if handle:
- # If handle exists, delete manually
- # TODO(romilb): This is potentially risky - if the delete method has
- # bugs, this can cause resource leaks. Ideally we should manually
- # eject storage from global_user_state and delete the bucket using
- # boto3 directly.
- storage_obj.delete()
-
- @pytest.fixture
- def tmp_multiple_custom_source_storage_obj(self):
- # Creates a list of storage objects with custom source names to
- # create multiple scratch storages.
- # Stores for each object in the list must be added in the test.
- custom_source_names = ['"path With Spaces"', 'path With Spaces']
- storage_mult_obj = []
- for name in custom_source_names:
- src_path = os.path.expanduser(f'~/{name}')
- pathlib.Path(src_path).expanduser().mkdir(exist_ok=True)
- timestamp = str(time.time()).replace('.', '')
- store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}',
- source=src_path)
- store_obj.construct()
- storage_mult_obj.append(store_obj)
- yield storage_mult_obj
- for storage_obj in storage_mult_obj:
- handle = global_user_state.get_handle_from_storage_name(
- storage_obj.name)
- if handle:
- storage_obj.delete()
-
- @pytest.fixture
- def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source):
- # Creates a temporary storage object. Stores must be added in the test.
- yield from self.yield_storage_object(name=tmp_bucket_name,
- source=tmp_source)
-
- @pytest.fixture
- def tmp_local_list_storage_obj(self, tmp_bucket_name, tmp_source):
- # Creates a temp storage object which uses a list of paths as source.
- # Stores must be added in the test. After upload, the bucket should
- # have two files - /tmp-file and /tmp-source/tmp-file
- list_source = [tmp_source, tmp_source + '/tmp-file']
- yield from self.yield_storage_object(name=tmp_bucket_name,
- source=list_source)
-
- @pytest.fixture
- def tmp_bulk_del_storage_obj(self, tmp_bucket_name):
- # Creates a temporary storage object for testing bulk deletion.
- # Stores must be added in the test.
- with tempfile.TemporaryDirectory() as tmpdir:
- subprocess.check_output(f'mkdir -p {tmpdir}/folder{{000..255}}',
- shell=True)
- subprocess.check_output(f'touch {tmpdir}/test{{000..255}}.txt',
- shell=True)
- subprocess.check_output(
- f'touch {tmpdir}/folder{{000..255}}/test.txt', shell=True)
- yield from self.yield_storage_object(name=tmp_bucket_name,
- source=tmpdir)
-
- @pytest.fixture
- def tmp_copy_mnt_existing_storage_obj(self, tmp_scratch_storage_obj):
- # Creates a copy mount storage which reuses an existing storage object.
- tmp_scratch_storage_obj.add_store(storage_lib.StoreType.S3)
- storage_name = tmp_scratch_storage_obj.name
-
- # Try to initialize another storage with the storage object created
- # above, but now in COPY mode. This should succeed.
- yield from self.yield_storage_object(name=storage_name,
- mode=storage_lib.StorageMode.COPY)
-
- @pytest.fixture
- def tmp_gitignore_storage_obj(self, tmp_bucket_name, gitignore_structure):
- # Creates a temporary storage object for testing .gitignore filter.
- # gitignore_structure represents a file structure in a dictionary
- # format. The created storage object will contain the file structure along
- # with .gitignore and .git/info/exclude files to test the exclude filter.
- # Stores must be added in the test.
- with tempfile.TemporaryDirectory() as tmpdir:
- # Creates file structure to be uploaded in the Storage
- self.create_dir_structure(tmpdir, gitignore_structure)
-
- # Create .gitignore and list files/dirs to be excluded in it
- skypilot_path = os.path.dirname(os.path.dirname(sky.__file__))
- temp_path = f'{tmpdir}/.gitignore'
- file_path = os.path.join(skypilot_path, 'tests/gitignore_test')
- shutil.copyfile(file_path, temp_path)
-
- # Create .git/info/exclude and list files/dirs to be excluded in it
- temp_path = f'{tmpdir}/.git/info/'
- os.makedirs(temp_path)
- temp_exclude_path = os.path.join(temp_path, 'exclude')
- file_path = os.path.join(skypilot_path,
- 'tests/git_info_exclude_test')
- shutil.copyfile(file_path, temp_exclude_path)
-
- # Create sky Storage with the files created
- yield from self.yield_storage_object(
- name=tmp_bucket_name,
- source=tmpdir,
- mode=storage_lib.StorageMode.COPY)
-
- @pytest.fixture
- def tmp_awscli_bucket(self, tmp_bucket_name):
- # Creates a temporary bucket using awscli
- bucket_uri = f's3://{tmp_bucket_name}'
- subprocess.check_call(['aws', 's3', 'mb', bucket_uri])
- yield tmp_bucket_name, bucket_uri
- subprocess.check_call(['aws', 's3', 'rb', bucket_uri, '--force'])
-
- @pytest.fixture
- def tmp_gsutil_bucket(self, tmp_bucket_name):
- # Creates a temporary bucket using gsutil
- bucket_uri = f'gs://{tmp_bucket_name}'
- subprocess.check_call(['gsutil', 'mb', bucket_uri])
- yield tmp_bucket_name, bucket_uri
- subprocess.check_call(['gsutil', 'rm', '-r', bucket_uri])
-
- @pytest.fixture
- def tmp_az_bucket(self, tmp_bucket_name):
- # Creates a temporary Azure Blob container using the az CLI
- default_region = 'eastus'
- storage_account_name = (
- storage_lib.AzureBlobStore.get_default_storage_account_name(
- default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
- bucket_uri = data_utils.AZURE_CONTAINER_URL.format(
- storage_account_name=storage_account_name,
- container_name=tmp_bucket_name)
- subprocess.check_call([
- 'az', 'storage', 'container', 'create', '--name',
- f'{tmp_bucket_name}', '--account-name', f'{storage_account_name}',
- '--account-key', f'{storage_account_key}'
- ])
- yield tmp_bucket_name, bucket_uri
- subprocess.check_call([
- 'az', 'storage', 'container', 'delete', '--name',
- f'{tmp_bucket_name}', '--account-name', f'{storage_account_name}',
- '--account-key', f'{storage_account_key}'
- ])
-
- @pytest.fixture
- def tmp_awscli_bucket_r2(self, tmp_bucket_name):
- # Creates a temporary R2 bucket using awscli against the Cloudflare endpoint
- endpoint_url = cloudflare.create_endpoint()
- bucket_uri = f's3://{tmp_bucket_name}'
- subprocess.check_call(
- f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 mb {bucket_uri} --endpoint {endpoint_url} --profile=r2',
- shell=True)
- yield tmp_bucket_name, bucket_uri
- subprocess.check_call(
- f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 rb {bucket_uri} --force --endpoint {endpoint_url} --profile=r2',
- shell=True)
-
- @pytest.fixture
- def tmp_ibm_cos_bucket(self, tmp_bucket_name):
- # Creates a temporary bucket using IBM COS API
- storage_obj = storage_lib.IBMCosStore(source="", name=tmp_bucket_name)
- yield tmp_bucket_name
- storage_obj.delete()
-
- @pytest.fixture
- def tmp_public_storage_obj(self, request):
- # Initializes a storage object with a public bucket
- storage_obj = storage_lib.Storage(source=request.param)
- storage_obj.construct()
- yield storage_obj
- # This does not require any deletion logic because it is a public bucket
- # and should not get added to global_user_state.
-
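The fixtures above all rely on pytest's yield-fixture teardown semantics. A generic illustration follows (the fixture and test names are hypothetical and not tied to these storage fixtures): code before `yield` is setup, code after `yield` runs as cleanup even when the test fails.

import pytest


@pytest.fixture
def tmp_resource_sketch():
    resource = {'created': True}  # setup: stand-in for bucket/storage creation
    yield resource                # the value handed to the test
    resource.clear()              # teardown: stand-in for deletion/cleanup


def test_uses_tmp_resource_sketch(tmp_resource_sketch):
    assert tmp_resource_sketch['created']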
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('store_type', [
- storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
- pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
- pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
- pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
- ])
- def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj,
- store_type):
- # Creates a new bucket with a local source, uploads files to it
- # and deletes it.
- tmp_local_storage_obj.add_store(store_type)
-
- # Run sky storage ls to check if storage object exists in the output
- out = subprocess.check_output(['sky', 'storage', 'ls'])
- assert tmp_local_storage_obj.name in out.decode('utf-8')
-
- # Run sky storage delete to delete the storage object
- subprocess.check_output(
- ['sky', 'storage', 'delete', tmp_local_storage_obj.name, '--yes'])
-
- # Run sky storage ls to check if storage object is deleted
- out = subprocess.check_output(['sky', 'storage', 'ls'])
- assert tmp_local_storage_obj.name not in out.decode('utf-8')
-
- @pytest.mark.no_fluidstack
- @pytest.mark.xdist_group('multiple_bucket_deletion')
- @pytest.mark.parametrize('store_type', [
- storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
- pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
- pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare),
- pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm)
- ])
- def test_multiple_buckets_creation_and_deletion(
- self, tmp_multiple_scratch_storage_obj, store_type):
- # Creates multiple new buckets (5 buckets) with no source
- # and deletes them.
- storage_obj_name = []
- for store_obj in tmp_multiple_scratch_storage_obj:
- store_obj.add_store(store_type)
- storage_obj_name.append(store_obj.name)
-
- # Run sky storage ls to check if all storage objects exists in the
- # output filtered by store type
- out_all = subprocess.check_output(['sky', 'storage', 'ls'])
- out = [
- item.split()[0]
- for item in out_all.decode('utf-8').splitlines()
- if store_type.value in item
- ]
- assert all([item in out for item in storage_obj_name])
-
- # Run sky storage delete all to delete all storage objects
- delete_cmd = ['sky', 'storage', 'delete', '--yes']
- delete_cmd += storage_obj_name
- subprocess.check_output(delete_cmd)
-
- # Run sky storage ls to check if all storage objects filtered by store
- # type are deleted
- out_all = subprocess.check_output(['sky', 'storage', 'ls'])
- out = [
- item.split()[0]
- for item in out_all.decode('utf-8').splitlines()
- if store_type.value in item
- ]
- assert all([item not in out for item in storage_obj_name])
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('store_type', [
- storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
- pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
- pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
- pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
- ])
- def test_upload_source_with_spaces(self, store_type,
- tmp_multiple_custom_source_storage_obj):
- # Creates two buckets from specified local sources whose
- # names contain spaces.
- storage_obj_names = []
- for storage_obj in tmp_multiple_custom_source_storage_obj:
- storage_obj.add_store(store_type)
- storage_obj_names.append(storage_obj.name)
-
- # Run sky storage ls to check if all storage objects exists in the
- # output filtered by store type
- out_all = subprocess.check_output(['sky', 'storage', 'ls'])
- out = [
- item.split()[0]
- for item in out_all.decode('utf-8').splitlines()
- if store_type.value in item
- ]
- assert all([item in out for item in storage_obj_names])
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('store_type', [
- storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
- pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
- pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
- pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
- ])
- def test_bucket_external_deletion(self, tmp_scratch_storage_obj,
- store_type):
- # Creates a bucket, deletes it externally using cloud cli commands
- # and then tries to delete it using sky storage delete.
- tmp_scratch_storage_obj.add_store(store_type)
-
- # Run sky storage ls to check if storage object exists in the output
- out = subprocess.check_output(['sky', 'storage', 'ls'])
- assert tmp_scratch_storage_obj.name in out.decode('utf-8')
-
- # Delete bucket externally
- cmd = self.cli_delete_cmd(store_type, tmp_scratch_storage_obj.name)
- subprocess.check_output(cmd, shell=True)
-
- # Run sky storage delete to delete the storage object
- out = subprocess.check_output(
- ['sky', 'storage', 'delete', tmp_scratch_storage_obj.name, '--yes'])
- # Make sure bucket was not created during deletion (see issue #1322)
- assert 'created' not in out.decode('utf-8').lower()
-
- # Run sky storage ls to check if storage object is deleted
- out = subprocess.check_output(['sky', 'storage', 'ls'])
- assert tmp_scratch_storage_obj.name not in out.decode('utf-8')
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('store_type', [
- storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
- pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
- pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
- pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
- ])
- def test_bucket_bulk_deletion(self, store_type, tmp_bulk_del_storage_obj):
- # Creates a temp folder with over 256 files and folders, upload
- # files and folders to a new bucket, then delete bucket.
- tmp_bulk_del_storage_obj.add_store(store_type)
-
- subprocess.check_output([
- 'sky', 'storage', 'delete', tmp_bulk_del_storage_obj.name, '--yes'
- ])
-
- output = subprocess.check_output(['sky', 'storage', 'ls'])
- assert tmp_bulk_del_storage_obj.name not in output.decode('utf-8')
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize(
- 'tmp_public_storage_obj, store_type',
- [('s3://tcga-2-open', storage_lib.StoreType.S3),
- ('s3://digitalcorpora', storage_lib.StoreType.S3),
- ('gs://gcp-public-data-sentinel-2', storage_lib.StoreType.GCS),
- pytest.param(
- 'https://azureopendatastorage.blob.core.windows.net/nyctlc',
- storage_lib.StoreType.AZURE,
- marks=pytest.mark.azure)],
- indirect=['tmp_public_storage_obj'])
- def test_public_bucket(self, tmp_public_storage_obj, store_type):
- # Creates a new bucket with a public source and verifies that it is not
- # added to global_user_state.
- tmp_public_storage_obj.add_store(store_type)
-
- # Run sky storage ls to check if storage object exists in the output
- out = subprocess.check_output(['sky', 'storage', 'ls'])
- assert tmp_public_storage_obj.name not in out.decode('utf-8')
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize(
- 'nonexist_bucket_url',
- [
- 's3://{random_name}',
- 'gs://{random_name}',
- pytest.param(
- 'https://{account_name}.blob.core.windows.net/{random_name}', # pylint: disable=line-too-long
- marks=pytest.mark.azure),
- pytest.param('cos://us-east/{random_name}', marks=pytest.mark.ibm),
- pytest.param('r2://{random_name}', marks=pytest.mark.cloudflare)
- ])
- def test_nonexistent_bucket(self, nonexist_bucket_url):
- # Attempts to fetch a storage object with a non-existent source.
- # Generate a random bucket name and verify it doesn't exist:
- retry_count = 0
- while True:
- nonexist_bucket_name = str(uuid.uuid4())
- if nonexist_bucket_url.startswith('s3'):
- command = f'aws s3api head-bucket --bucket {nonexist_bucket_name}'
- expected_output = '404'
- elif nonexist_bucket_url.startswith('gs'):
- command = f'gsutil ls {nonexist_bucket_url.format(random_name=nonexist_bucket_name)}'
- expected_output = 'BucketNotFoundException'
- elif nonexist_bucket_url.startswith('https'):
- default_region = 'eastus'
- storage_account_name = (
- storage_lib.AzureBlobStore.get_default_storage_account_name(
- default_region))
- storage_account_key = data_utils.get_az_storage_account_key(
- storage_account_name)
- command = f'az storage container exists --account-name {storage_account_name} --account-key {storage_account_key} --name {nonexist_bucket_name}'
- expected_output = '"exists": false'
- elif nonexist_bucket_url.startswith('r2'):
- endpoint_url = cloudflare.create_endpoint()
- command = f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3api head-bucket --bucket {nonexist_bucket_name} --endpoint {endpoint_url} --profile=r2'
- expected_output = '404'
- elif nonexist_bucket_url.startswith('cos'):
- # Use API calls, since using rclone requires a profile name.
- try:
- expected_output = command = "echo" # avoid unrelated exception in case of failure.
- bucket_name = urllib.parse.urlsplit(
- nonexist_bucket_url.format(
- random_name=nonexist_bucket_name)).path.strip('/')
- client = ibm.get_cos_client('us-east')
- client.head_bucket(Bucket=bucket_name)
- except ibm.ibm_botocore.exceptions.ClientError as e:
- if e.response['Error']['Code'] == '404':
- # success
- return
- else:
- raise ValueError('Unsupported bucket type '
- f'{nonexist_bucket_url}')
-
- # Check if bucket exists using the cli:
- try:
- out = subprocess.check_output(command,
- stderr=subprocess.STDOUT,
- shell=True)
- except subprocess.CalledProcessError as e:
- out = e.output
- out = out.decode('utf-8')
- if expected_output in out:
- break
- else:
- retry_count += 1
- if retry_count > 3:
- raise RuntimeError('Unable to find a nonexistent bucket '
- 'to use. This is highly unlikely - '
- 'check if the tests are correct.')
-
- with pytest.raises(sky.exceptions.StorageBucketGetError,
- match='Attempted to use a non-existent'):
- if nonexist_bucket_url.startswith('https'):
- storage_obj = storage_lib.Storage(
- source=nonexist_bucket_url.format(
- account_name=storage_account_name,
- random_name=nonexist_bucket_name))
- else:
- storage_obj = storage_lib.Storage(
- source=nonexist_bucket_url.format(
- random_name=nonexist_bucket_name))
- storage_obj.construct()
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize(
- 'private_bucket',
- [
- 's3://imagenet',
- 'gs://imagenet',
- pytest.param('https://smoketestprivate.blob.core.windows.net/test',
- marks=pytest.mark.azure), # pylint: disable=line-too-long
- pytest.param('cos://us-east/bucket1', marks=pytest.mark.ibm)
- ])
- def test_private_bucket(self, private_bucket):
- # Attempts to access private buckets not belonging to the user.
- # These buckets are known to be private, but may need to be updated if
- # they are removed by their owners.
- store_type = urllib.parse.urlsplit(private_bucket).scheme
- if store_type == 'https' or store_type == 'cos':
- private_bucket_name = urllib.parse.urlsplit(
- private_bucket).path.strip('/')
- else:
- private_bucket_name = urllib.parse.urlsplit(private_bucket).netloc
- with pytest.raises(
- sky.exceptions.StorageBucketGetError,
- match=storage_lib._BUCKET_FAIL_TO_CONNECT_MESSAGE.format(
- name=private_bucket_name)):
- storage_obj = storage_lib.Storage(source=private_bucket)
- storage_obj.construct()
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('ext_bucket_fixture, store_type',
- [('tmp_awscli_bucket', storage_lib.StoreType.S3),
- ('tmp_gsutil_bucket', storage_lib.StoreType.GCS),
- pytest.param('tmp_az_bucket',
- storage_lib.StoreType.AZURE,
- marks=pytest.mark.azure),
- pytest.param('tmp_ibm_cos_bucket',
- storage_lib.StoreType.IBM,
- marks=pytest.mark.ibm),
- pytest.param('tmp_awscli_bucket_r2',
- storage_lib.StoreType.R2,
- marks=pytest.mark.cloudflare)])
- def test_upload_to_existing_bucket(self, ext_bucket_fixture, request,
- tmp_source, store_type):
- # Tries uploading existing files to newly created bucket (outside of
- # sky) and verifies that files are written.
- bucket_name, _ = request.getfixturevalue(ext_bucket_fixture)
- storage_obj = storage_lib.Storage(name=bucket_name, source=tmp_source)
- storage_obj.construct()
- storage_obj.add_store(store_type)
-
- # Check if tmp_source/tmp-file exists in the bucket using aws cli
- out = subprocess.check_output(self.cli_ls_cmd(store_type, bucket_name),
- shell=True)
- assert 'tmp-file' in out.decode('utf-8'), \
- 'File not found in bucket - output was: {}'.format(
- out.decode('utf-8'))
-
- # Check symlinks - symlinks don't get copied by sky storage
- assert (pathlib.Path(tmp_source) / 'circle-link').is_symlink(), (
- 'circle-link was not found in the upload source - '
- 'are the test fixtures correct?')
- assert 'circle-link' not in out.decode('utf-8'), (
- 'Symlink found in bucket - ls output was : {}'.format(
- out.decode('utf-8')))
-
- # Run sky storage ls to check if storage object exists in the output.
- # It should not exist because the bucket was created externally.
- out = subprocess.check_output(['sky', 'storage', 'ls'])
- assert storage_obj.name not in out.decode('utf-8')
-
- @pytest.mark.no_fluidstack
- def test_copy_mount_existing_storage(self,
- tmp_copy_mnt_existing_storage_obj):
- # Creates a bucket with no source in MOUNT mode (empty bucket), and
- # then tries to load the same storage in COPY mode.
- tmp_copy_mnt_existing_storage_obj.add_store(storage_lib.StoreType.S3)
- storage_name = tmp_copy_mnt_existing_storage_obj.name
-
- # Check `sky storage ls` to ensure storage object exists
- out = subprocess.check_output(['sky', 'storage', 'ls']).decode('utf-8')
- assert storage_name in out, f'Storage {storage_name} not found in sky storage ls.'
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('store_type', [
- storage_lib.StoreType.S3, storage_lib.StoreType.GCS,
- pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
- pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
- pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
- ])
- def test_list_source(self, tmp_local_list_storage_obj, store_type):
- # Uses a list in the source field to specify a file and a directory to
- # be uploaded to the storage object.
- tmp_local_list_storage_obj.add_store(store_type)
-
- # Check if tmp-file exists in the bucket root using cli
- out = subprocess.check_output(self.cli_ls_cmd(
- store_type, tmp_local_list_storage_obj.name),
- shell=True)
- assert 'tmp-file' in out.decode('utf-8'), \
- 'File not found in bucket - output was: {}'.format(
- out.decode('utf-8'))
-
- # Check if tmp-file exists in the bucket/tmp-source using cli
- out = subprocess.check_output(self.cli_ls_cmd(
- store_type, tmp_local_list_storage_obj.name, 'tmp-source/'),
- shell=True)
- assert 'tmp-file' in out.decode('utf-8'), \
- 'File not found in bucket - output was: {}'.format(
- out.decode('utf-8'))
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize('invalid_name_list, store_type',
- [(AWS_INVALID_NAMES, storage_lib.StoreType.S3),
- (GCS_INVALID_NAMES, storage_lib.StoreType.GCS),
- pytest.param(AZURE_INVALID_NAMES,
- storage_lib.StoreType.AZURE,
- marks=pytest.mark.azure),
- pytest.param(IBM_INVALID_NAMES,
- storage_lib.StoreType.IBM,
- marks=pytest.mark.ibm),
- pytest.param(AWS_INVALID_NAMES,
- storage_lib.StoreType.R2,
- marks=pytest.mark.cloudflare)])
- def test_invalid_names(self, invalid_name_list, store_type):
- # Attempts to create storage objects with invalid names and verifies
- # that a StorageNameError is raised for each.
- for name in invalid_name_list:
- with pytest.raises(sky.exceptions.StorageNameError):
- storage_obj = storage_lib.Storage(name=name)
- storage_obj.construct()
- storage_obj.add_store(store_type)
-
- @pytest.mark.no_fluidstack
- @pytest.mark.parametrize(
- 'gitignore_structure, store_type',
- [(GITIGNORE_SYNC_TEST_DIR_STRUCTURE, storage_lib.StoreType.S3),
- (GITIGNORE_SYNC_TEST_DIR_STRUCTURE, storage_lib.StoreType.GCS),
- (GITIGNORE_SYNC_TEST_DIR_STRUCTURE, storage_lib.StoreType.AZURE),
- pytest.param(GITIGNORE_SYNC_TEST_DIR_STRUCTURE,
- storage_lib.StoreType.R2,
- marks=pytest.mark.cloudflare)])
- def test_excluded_file_cloud_storage_upload_copy(self, gitignore_structure,
- store_type,
- tmp_gitignore_storage_obj):
- # Tests that files listed in .gitignore and .git/info/exclude are
- # excluded from being transferred to Storage.
-
- tmp_gitignore_storage_obj.add_store(store_type)
-
- upload_file_name = 'included'
- # Count the number of files with the given file name
- up_cmd = self.cli_count_name_in_bucket(store_type, \
- tmp_gitignore_storage_obj.name, file_name=upload_file_name)
- git_exclude_cmd = self.cli_count_name_in_bucket(store_type, \
- tmp_gitignore_storage_obj.name, file_name='.git')
- cnt_num_file_cmd = self.cli_count_file_in_bucket(
- store_type, tmp_gitignore_storage_obj.name)
-
- up_output = subprocess.check_output(up_cmd, shell=True)
- git_exclude_output = subprocess.check_output(git_exclude_cmd,
- shell=True)
- cnt_output = subprocess.check_output(cnt_num_file_cmd, shell=True)
-
- assert '3' in up_output.decode('utf-8'), \
- 'Files to be included are not completely uploaded.'
- # The count is 1 because .gitignore itself is uploaded; nothing from .git/ is.
- assert '1' in git_exclude_output.decode('utf-8'), \
- '.git directory should not be uploaded.'
- # 4 files expected: .gitignore, included.log, included.txt, include_dir/included.log
- assert '4' in cnt_output.decode('utf-8'), \
- 'Some items listed in .gitignore and .git/info/exclude are not excluded.'
-
- @pytest.mark.parametrize('ext_bucket_fixture, store_type',
- [('tmp_awscli_bucket', storage_lib.StoreType.S3),
- ('tmp_gsutil_bucket', storage_lib.StoreType.GCS),
- pytest.param('tmp_awscli_bucket_r2',
- storage_lib.StoreType.R2,
- marks=pytest.mark.cloudflare)])
- def test_externally_created_bucket_mount_without_source(
- self, ext_bucket_fixture, request, store_type):
- # Non-sky-managed buckets (buckets created outside of the SkyPilot CLI)
- # are allowed to be MOUNTed by specifying only the bucket URI in the
- # source field. Attempting to mount by specifying only the bucket name
- # should error out.
- #
- # TODO(doyoung): Add a test for IBM COS. Currently, this is blocked
- # because rclone, which is used to interact with IBM COS, does not
- # support creating a bucket, and the ibmcloud CLI is not supported in
- # SkyPilot. Either feature is necessary to simulate an external bucket
- # creation for IBM COS.
- # https://github.com/skypilot-org/skypilot/pull/1966/files#r1253439837
-
- ext_bucket_name, ext_bucket_uri = request.getfixturevalue(
- ext_bucket_fixture)
- # invalid spec
- with pytest.raises(sky.exceptions.StorageSpecError) as e:
- storage_obj = storage_lib.Storage(
- name=ext_bucket_name, mode=storage_lib.StorageMode.MOUNT)
- storage_obj.construct()
- storage_obj.add_store(store_type)
-
- assert 'Attempted to mount a non-sky managed bucket' in str(e)
-
- # valid spec
- storage_obj = storage_lib.Storage(source=ext_bucket_uri,
- mode=storage_lib.StorageMode.MOUNT)
- storage_obj.construct()
- handle = global_user_state.get_handle_from_storage_name(
- storage_obj.name)
- if handle:
- storage_obj.delete()
-
- @pytest.mark.aws
- @pytest.mark.parametrize('region', [
- 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-south-1',
- 'ap-southeast-1', 'ap-southeast-2', 'eu-central-1', 'eu-north-1',
- 'eu-west-1', 'eu-west-2', 'eu-west-3', 'sa-east-1', 'us-east-1',
- 'us-east-2', 'us-west-1', 'us-west-2'
- ])
- def test_aws_regions(self, tmp_local_storage_obj, region):
- # This tests creation and upload to a bucket in all AWS S3 regions.
- # To test full functionality, use test_managed_jobs_storage above.
- store_type = storage_lib.StoreType.S3
- tmp_local_storage_obj.add_store(store_type, region=region)
- bucket_name = tmp_local_storage_obj.name
-
- # Confirm that the bucket was created in the correct region
- region_cmd = self.cli_region_cmd(store_type, bucket_name=bucket_name)
- out = subprocess.check_output(region_cmd, shell=True)
- output = out.decode('utf-8')
- expected_output_region = region
- if region == 'us-east-1':
- expected_output_region = 'None' # us-east-1 is the default region
- assert expected_output_region in output, (
- f'Bucket was not found in region {region} - '
- f'output of {region_cmd} was: {output}')
-
- # Check if tmp_source/tmp-file exists in the bucket using cli
- ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
- out = subprocess.check_output(ls_cmd, shell=True)
- output = out.decode('utf-8')
- assert 'tmp-file' in output, (
- f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')
-
- @pytest.mark.gcp
- @pytest.mark.parametrize('region', [
- 'northamerica-northeast1', 'northamerica-northeast2', 'us-central1',
- 'us-east1', 'us-east4', 'us-east5', 'us-south1', 'us-west1', 'us-west2',
- 'us-west3', 'us-west4', 'southamerica-east1', 'southamerica-west1',
- 'europe-central2', 'europe-north1', 'europe-southwest1', 'europe-west1',
- 'europe-west2', 'europe-west3', 'europe-west4', 'europe-west6',
- 'europe-west8', 'europe-west9', 'europe-west10', 'europe-west12',
- 'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast2',
- 'asia-northeast3', 'asia-southeast1', 'asia-south1', 'asia-south2',
- 'asia-southeast2', 'me-central1', 'me-central2', 'me-west1',
- 'australia-southeast1', 'australia-southeast2', 'africa-south1'
- ])
- def test_gcs_regions(self, tmp_local_storage_obj, region):
- # This tests creation and upload to a bucket in all GCS regions.
- # To test full functionality, use test_managed_jobs_storage above.
- store_type = storage_lib.StoreType.GCS
- tmp_local_storage_obj.add_store(store_type, region=region)
- bucket_name = tmp_local_storage_obj.name
-
- # Confirm that the bucket was created in the correct region
- region_cmd = self.cli_region_cmd(store_type, bucket_name=bucket_name)
- out = subprocess.check_output(region_cmd, shell=True)
- output = out.decode('utf-8')
- assert region in output, (
- f'Bucket was not found in region {region} - '
- f'output of {region_cmd} was: {output}')
-
- # Check if tmp_source/tmp-file exists in the bucket using cli
- ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
- out = subprocess.check_output(ls_cmd, shell=True)
- output = out.decode('utf-8')
- assert 'tmp-file' in output, (
- f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')
-
-
-# ---------- Testing YAML Specs ----------
- # Our sky storage requires credentials to check the bucket existence when
- # loading a task from the YAML file, so we cannot make it a unit test.
-class TestYamlSpecs:
- # TODO(zhwu): Add test for `to_yaml_config` for the Storage object.
- # We should not use `examples/storage_demo.yaml` here, since it requires
- # users to ensure that bucket names do not exist and/or are unique.
- _TEST_YAML_PATHS = [
- 'examples/minimal.yaml', 'examples/managed_job.yaml',
- 'examples/using_file_mounts.yaml', 'examples/resnet_app.yaml',
- 'examples/multi_hostname.yaml'
- ]
-
- def _is_dict_subset(self, d1, d2):
- """Check if d1 is the subset of d2."""
- for k, v in d1.items():
- if k not in d2:
- if isinstance(v, list) or isinstance(v, dict):
- assert len(v) == 0, (k, v)
- else:
- assert False, (k, v)
- elif isinstance(v, dict):
- assert isinstance(d2[k], dict), (k, v, d2)
- self._is_dict_subset(v, d2[k])
- elif isinstance(v, str):
- if k == 'accelerators':
- resources = sky.Resources()
- resources._set_accelerators(v, None)
- assert resources.accelerators == d2[k], (k, v, d2)
- else:
- assert v.lower() == d2[k].lower(), (k, v, d2[k])
- else:
- assert v == d2[k], (k, v, d2[k])
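For illustration, a hypothetical pair of task configs satisfying the subset relation this helper asserts: every key of d1 must appear in d2 with an equivalent value, strings compare case-insensitively, and empty lists/dicts in d1 may be absent from d2 (the dict contents below are made up).

d1 = {'resources': {'cloud': 'AWS'}, 'file_mounts': {}}
d2 = {'resources': {'cloud': 'aws', 'cpus': '2+'}, 'name': 'demo'}
# d1 <= d2: 'cloud' matches ignoring case, the empty 'file_mounts' is
# tolerated even though d2 omits it, and extra keys in d2 are ignored.
assert d1['resources']['cloud'].lower() == d2['resources']['cloud'].lower()
assert not d1['file_mounts'] and 'file_mounts' not in d2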
-
- def _check_equivalent(self, yaml_path):
- """Check if the yaml is equivalent after load and dump again."""
- origin_task_config = common_utils.read_yaml(yaml_path)
-
- task = sky.Task.from_yaml(yaml_path)
- new_task_config = task.to_yaml_config()
- # d1 <= d2
- print(origin_task_config, new_task_config)
- self._is_dict_subset(origin_task_config, new_task_config)
-
- def test_load_dump_yaml_config_equivalent(self):
- """Test if the yaml config is equivalent after load and dump again."""
- pathlib.Path('~/datasets').expanduser().mkdir(exist_ok=True)
- pathlib.Path('~/tmpfile').expanduser().touch()
- pathlib.Path('~/.ssh').expanduser().mkdir(exist_ok=True)
- pathlib.Path('~/.ssh/id_rsa.pub').expanduser().touch()
- pathlib.Path('~/tmp-workdir').expanduser().mkdir(exist_ok=True)
- pathlib.Path('~/Downloads/tpu').expanduser().mkdir(parents=True,
- exist_ok=True)
- for yaml_path in self._TEST_YAML_PATHS:
- self._check_equivalent(yaml_path)
-
-
-# ---------- Testing Multiple Accelerators ----------
- @pytest.mark.no_fluidstack # Fluidstack does not support K80 GPUs for now
- @pytest.mark.no_paperspace # Paperspace does not support K80 GPUs
-def test_multiple_accelerators_ordered():
- name = _get_cluster_name()
- test = Test(
- 'multiple-accelerators-ordered',
- [
- f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_ordered.yaml | grep "Using user-specified accelerators list"',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- timeout=20 * 60,
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs
-@pytest.mark.no_paperspace # Paperspace does not support T4 GPUs
-def test_multiple_accelerators_ordered_with_default():
- name = _get_cluster_name()
- test = Test(
- 'multiple-accelerators-ordered-with-default',
- [
- f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_ordered_with_default.yaml | grep "Using user-specified accelerators list"',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status {name} | grep Spot',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs
-@pytest.mark.no_paperspace # Paperspace does not support T4 GPUs
-def test_multiple_accelerators_unordered():
- name = _get_cluster_name()
- test = Test(
- 'multiple-accelerators-unordered',
- [
- f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_unordered.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs
-@pytest.mark.no_paperspace # Paperspace does not support T4 GPUs
-def test_multiple_accelerators_unordered_with_default():
- name = _get_cluster_name()
- test = Test(
- 'multiple-accelerators-unordered-with-default',
- [
- f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_unordered_with_default.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- f'sky status {name} | grep Spot',
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-@pytest.mark.no_fluidstack # Requires other clouds to be enabled
-def test_multiple_resources():
- name = _get_cluster_name()
- test = Test(
- 'multiple-resources',
- [
- f'sky launch -y -c {name} tests/test_yamls/test_multiple_resources.yaml',
- f'sky logs {name} 1 --status', # Ensure the job succeeded.
- ],
- f'sky down -y {name}',
- )
- run_one_test(test)
-
-
-# ---------- Sky Benchmark ----------
-@pytest.mark.no_fluidstack # Requires other clouds to be enabled
-@pytest.mark.no_paperspace # Requires other clouds to be enabled
-@pytest.mark.no_kubernetes
-@pytest.mark.aws # SkyBenchmark requires S3 access
-def test_sky_bench(generic_cloud: str):
- name = _get_cluster_name()
- test = Test(
- 'sky-bench',
- [
- f'sky bench launch -y -b {name} --cloud {generic_cloud} -i0 tests/test_yamls/minimal.yaml',
- 'sleep 120',
- f'sky bench show {name} | grep sky-bench-{name} | grep FINISHED',
- ],
- f'sky bench down {name} -y; sky bench delete {name} -y',
- )
- run_one_test(test)
-
-
-@pytest.mark.kubernetes
-def test_kubernetes_context_failover():
- """Test if the kubernetes context failover works.
-
- This test requires two kubernetes clusters:
- - kind-skypilot: the local cluster with mock labels for 8 H100 GPUs.
- - another accessible cluster: with enough CPUs
- To start the first cluster, run:
- sky local up
- # Add mock label for accelerator
- kubectl label node --overwrite skypilot-control-plane skypilot.co/accelerator=h100 --context kind-skypilot
- # Get the token for the cluster in context kind-skypilot
- TOKEN=$(kubectl config view --minify --context kind-skypilot -o jsonpath=\'{.users[0].user.token}\')
- # Get the API URL for the cluster in context kind-skypilot
- API_URL=$(kubectl config view --minify --context kind-skypilot -o jsonpath=\'{.clusters[0].cluster.server}\')
- # Add mock capacity for GPU
- curl --header "Content-Type: application/json-patch+json" --header "Authorization: Bearer $TOKEN" --request PATCH --data \'[{"op": "add", "path": "/status/capacity/nvidia.com~1gpu", "value": "8"}]\' "$API_URL/api/v1/nodes/skypilot-control-plane/status"
- # Add a new namespace to test the handling of namespaces
- kubectl create namespace test-namespace --context kind-skypilot
- # Set the namespace to test-namespace
- kubectl config set-context kind-skypilot --namespace=test-namespace --context kind-skypilot
- """
- # Get context that is not kind-skypilot
- contexts = subprocess.check_output('kubectl config get-contexts -o name',
- shell=True).decode('utf-8').split('\n')
- context = [context for context in contexts if context != 'kind-skypilot'][0]
- config = textwrap.dedent(f"""\
- kubernetes:
- allowed_contexts:
- - kind-skypilot
- - {context}
- """)
- with tempfile.NamedTemporaryFile(delete=True) as f:
- f.write(config.encode('utf-8'))
- f.flush()
- name = _get_cluster_name()
- test = Test(
- 'kubernetes-context-failover',
- [
- # Check if kind-skypilot is provisioned with H100 annotations already
- 'NODE_INFO=$(kubectl get nodes -o yaml --context kind-skypilot) && '
- 'echo "$NODE_INFO" | grep nvidia.com/gpu | grep 8 && '
- 'echo "$NODE_INFO" | grep skypilot.co/accelerator | grep h100 || '
- '{ echo "kind-skypilot does not exist '
- 'or does not have mock labels for GPUs. Check the instructions in '
- 'tests/test_smoke.py::test_kubernetes_context_failover." && exit 1; }',
- # Check namespace for kind-skypilot is test-namespace
- 'kubectl get namespaces --context kind-skypilot | grep test-namespace || '
- '{ echo "Should set the namespace to test-namespace for kind-skypilot. Check the instructions in '
- 'tests/test_smoke.py::test_kubernetes_context_failover." && exit 1; }',
- 'sky show-gpus --cloud kubernetes --region kind-skypilot | grep H100 | grep "1, 2, 3, 4, 5, 6, 7, 8"',
- # Get contexts and set current context to the other cluster that is not kind-skypilot
- f'kubectl config use-context {context}',
- # H100 should not be in the current context
- '! sky show-gpus --cloud kubernetes | grep H100',
- f'sky launch -y -c {name}-1 --cpus 1 echo hi',
- f'sky logs {name}-1 --status',
- # It should not be launched on kind-skypilot
- f'sky status -v {name}-1 | grep "{context}"',
- # Test failure for launching H100 on other cluster
- f'sky launch -y -c {name}-2 --gpus H100 --cpus 1 --cloud kubernetes --region {context} echo hi && exit 1 || true',
- # Test failover
- f'sky launch -y -c {name}-3 --gpus H100 --cpus 1 --cloud kubernetes echo hi',
- f'sky logs {name}-3 --status',
- # Test pods
- f'kubectl get pods --context kind-skypilot | grep "{name}-3"',
- # It should be launched on kind-skypilot
- f'sky status -v {name}-3 | grep "kind-skypilot"',
- # Should be 7 free GPUs
- f'sky show-gpus --cloud kubernetes --region kind-skypilot | grep H100 | grep " 7"',
- # Remove the line with "kind-skypilot"
- f'sed -i "/kind-skypilot/d" {f.name}',
- # Should still be able to exec and launch on existing cluster
- f'sky exec {name}-3 "echo hi"',
- f'sky logs {name}-3 --status',
- f'sky status -r {name}-3 | grep UP',
- f'sky launch -c {name}-3 --gpus h100 echo hi',
- f'sky logs {name}-3 --status',
- f'sky status -r {name}-3 | grep UP',
- ],
- f'sky down -y {name}-1 {name}-3',
- env={'SKYPILOT_CONFIG': f.name},
- )
- run_one_test(test)
+from smoke_tests.test_api_server import *
+from smoke_tests.test_basic import *
+from smoke_tests.test_cluster_job import *
+from smoke_tests.test_images import *
+from smoke_tests.test_managed_job import *
+from smoke_tests.test_mount_and_storage import *
+from smoke_tests.test_region_and_zone import *
+from smoke_tests.test_sky_serve import *
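The wildcard re-exports above keep `tests/test_smoke.py` usable as a single pytest entry point: `from <module> import *` copies each sub-module's public test functions into this module's namespace, where pytest collection finds them. A self-contained sketch of that mechanism (module and test names are made up for illustration):

import types

sub = types.ModuleType('smoke_tests_sketch')
exec('def test_example():\n    assert True', sub.__dict__)

# Equivalent of `from smoke_tests_sketch import *` for public names:
globals().update({
    name: obj for name, obj in sub.__dict__.items()
    if not name.startswith('_')
})

assert 'test_example' in globals()  # pytest would collect test_example here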
diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py
index 7d304b60633..a9fad1b4b83 100644
--- a/tests/test_yaml_parser.py
+++ b/tests/test_yaml_parser.py
@@ -96,8 +96,8 @@ def test_empty_fields_storage(tmp_path):
storage = task.storage_mounts['/mystorage']
assert storage.name == 'sky-dataset'
assert storage.source is None
- assert len(storage.stores) == 0
- assert storage.persistent is True
+ assert not storage.stores
+ assert storage.persistent
def test_invalid_fields_storage(tmp_path):
diff --git a/tests/test_yamls/intermediate_bucket.yaml b/tests/test_yamls/intermediate_bucket.yaml
new file mode 100644
index 00000000000..fe9aafd0675
--- /dev/null
+++ b/tests/test_yamls/intermediate_bucket.yaml
@@ -0,0 +1,21 @@
+name: intermediate-bucket
+
+file_mounts:
+ /setup.py: ./setup.py
+ /sky: .
+ /train-00001-of-01024: gs://cloud-tpu-test-datasets/fake_imagenet/train-00001-of-01024
+
+workdir: .
+
+
+setup: |
+ echo "running setup"
+
+run: |
+ echo "listing workdir"
+ ls .
+ echo "listing file_mounts"
+ ls /setup.py
+ ls /sky
+ ls /train-00001-of-01024
+ echo "task run finish"
diff --git a/tests/test_yamls/minimal_test_quick_tests_core.yaml b/tests/test_yamls/minimal_test_quick_tests_core.yaml
new file mode 100644
index 00000000000..15857e972dd
--- /dev/null
+++ b/tests/test_yamls/minimal_test_quick_tests_core.yaml
@@ -0,0 +1,13 @@
+resources:
+ cloud: aws
+ instance_type: t3.small
+
+file_mounts:
+ ~/aws: .
+
+workdir: .
+
+num_nodes: 1
+
+run: |
+ ls -l ~/aws/tests/test_yamls/minimal_test_quick_tests_core.yaml
diff --git a/tests/test_yamls/use_intermediate_bucket_config.yaml b/tests/test_yamls/use_intermediate_bucket_config.yaml
new file mode 100644
index 00000000000..cdfb5fbabc1
--- /dev/null
+++ b/tests/test_yamls/use_intermediate_bucket_config.yaml
@@ -0,0 +1,2 @@
+jobs:
+ bucket: "s3://bucket-jobs-s3"
diff --git a/tests/unit_tests/kubernetes/test_gpu_label_formatters.py b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py
new file mode 100644
index 00000000000..cd7337dc7a1
--- /dev/null
+++ b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py
@@ -0,0 +1,22 @@
+"""Tests for GPU label formatting in Kubernetes integration.
+
+Tests verify correct GPU detection from Kubernetes labels.
+"""
+import pytest
+
+from sky.provision.kubernetes.utils import GFDLabelFormatter
+
+
+def test_gfd_label_formatter():
+ """Test word boundary regex matching in GFDLabelFormatter."""
+ # Test various GPU name patterns
+ test_cases = [
+ ('NVIDIA-L4-24GB', 'L4'),
+ ('NVIDIA-L40-48GB', 'L40'),
+ ('NVIDIA-L400', 'L400'), # Should not match L4 or L40
+ ('NVIDIA-L4', 'L4'),
+ ('L40-GPU', 'L40'),
+ ]
+ for input_value, expected in test_cases:
+ result = GFDLabelFormatter.get_accelerator_from_label_value(input_value)
+ assert result == expected, f'Failed for {input_value}'
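For reference, a minimal sketch of the word-boundary style matching the cases above exercise. This is not SkyPilot's actual GFDLabelFormatter implementation; the helper name and regex are assumptions, restricted to L-series names for brevity.

import re

# Match an L-series token that is not preceded by an alphanumeric character
# and not followed by another digit, so 'NVIDIA-L400' yields 'L400', not 'L4'.
_L_SERIES = re.compile(r'(?<![A-Za-z0-9])(L\d+)(?![0-9])')


def accelerator_from_label_sketch(label_value: str):
    """Return the L-series model token embedded in a GFD-style label value."""
    match = _L_SERIES.search(label_value)
    return match.group(1) if match else None


assert accelerator_from_label_sketch('NVIDIA-L4-24GB') == 'L4'
assert accelerator_from_label_sketch('NVIDIA-L400') == 'L400'  # not L4/L40
assert accelerator_from_label_sketch('L40-GPU') == 'L40'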
diff --git a/tests/unit_tests/sky/adaptors/test_oci.py b/tests/unit_tests/sky/adaptors/test_oci.py
new file mode 100644
index 00000000000..59c2b1f99b7
--- /dev/null
+++ b/tests/unit_tests/sky/adaptors/test_oci.py
@@ -0,0 +1,65 @@
+"""Tests for OCI adaptor."""
+import logging
+
+import pytest
+
+from sky import check as sky_check
+from sky.adaptors import oci
+from sky.utils import log_utils
+
+
+def test_oci_circuit_breaker_logging():
+ """Test that OCI circuit breaker logging is properly configured."""
+ # Get the circuit breaker logger
+ logger = logging.getLogger('oci.circuit_breaker')
+
+ # Create a handler that captures log records
+ log_records = []
+ test_handler = logging.Handler()
+ test_handler.emit = lambda record: log_records.append(record)
+ logger.addHandler(test_handler)
+
+ # Create a null handler to suppress logs during import
+ null_handler = logging.NullHandler()
+ logger.addHandler(null_handler)
+
+ try:
+ # Verify logger starts at WARNING level (set by adaptor initialization)
+ initial_level = logger.getEffectiveLevel()
+ print(
+ f'Initial logger level: {initial_level} (WARNING={logging.WARNING})'
+ )
+ assert initial_level == logging.WARNING, (
+ 'OCI circuit breaker logger should be set to WARNING before initialization'
+ )
+
+ # Force OCI module import through LazyImport by accessing a module attribute
+ print('Attempting to import OCI module...')
+ try:
+ # This will trigger LazyImport's load_module for the actual OCI module
+ _ = oci.oci.config.DEFAULT_LOCATION
+ except (ImportError, AttributeError) as e:
+ # Expected when OCI SDK is not installed
+ print(f'Import/Attribute error as expected: {e}')
+ pass
+
+ # Verify logger level after import attempt
+ after_level = logger.getEffectiveLevel()
+ print(
+ f'Logger level after import: {after_level} (WARNING={logging.WARNING})'
+ )
+ assert after_level == logging.WARNING, (
+ 'OCI circuit breaker logger should remain at WARNING after initialization'
+ )
+
+ # Verify no circuit breaker logs were emitted
+ circuit_breaker_logs = [
+ record for record in log_records
+ if 'Circuit breaker' in record.getMessage()
+ ]
+ assert not circuit_breaker_logs, (
+ 'No circuit breaker logs should be emitted during initialization')
+ finally:
+ # Clean up the handlers
+ logger.removeHandler(test_handler)
+ logger.removeHandler(null_handler)
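The record-capturing handler used above is a handy pattern on its own. A self-contained sketch (logger name and message are illustrative) that packages it as a context manager so the handler is always detached, even on assertion failure:

import contextlib
import logging


@contextlib.contextmanager
def capture_logs(logger_name: str):
    records = []
    handler = logging.Handler()
    handler.emit = records.append  # store every record passed to emit()
    logger = logging.getLogger(logger_name)
    logger.addHandler(handler)
    try:
        yield records
    finally:
        logger.removeHandler(handler)


with capture_logs('oci.circuit_breaker') as records:
    logging.getLogger('oci.circuit_breaker').warning('sample message')
assert any('sample message' in r.getMessage() for r in records)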
diff --git a/tests/unit_tests/test_storage_utils.py b/tests/unit_tests/test_storage_utils.py
index acb9488feaf..cf0a913fc58 100644
--- a/tests/unit_tests/test_storage_utils.py
+++ b/tests/unit_tests/test_storage_utils.py
@@ -46,7 +46,7 @@ def skyignore_dir():
def test_get_excluded_files_from_skyignore_no_file():
excluded_files = storage_utils.get_excluded_files_from_skyignore('.')
- assert len(excluded_files) == 0
+ assert not excluded_files
def test_get_excluded_files_from_skyignore(skyignore_dir):