Skip to content

Commit

Permalink
Merge pull request #6 from rail-berkeley/add-azure-sweeper
Browse files (browse the repository at this point in the history)
Add azure sweeper, easy_launch, make azure script args work
  • Loading branch information
Jendker authored Apr 21, 2021
2 parents 4dacc3a + f51a3c9 commit 24df365
Show file tree
Hide file tree
Showing 12 changed files with 1,524 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
aws_config/
*.dar
testing/test_outputs/
doodad/wrappers/easy_launch/config_private.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
10 changes: 5 additions & 5 deletions doodad/apis/azure_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import azure
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient

from doodad.utils import hash_file, REPO_DIR, safe_import
#storage = safe_import.try_import('google.cloud.storage')
from doodad.utils import REPO_DIR, safe_import

blob = safe_import.try_import('azure.storage.blob')
azure = safe_import.try_import('azure')

AZURE_STARTUP_SCRIPT_PATH = os.path.join(REPO_DIR, "scripts/azure/azure_startup_script.sh")
AZURE_SHUTDOWN_SCRIPT_PATH = os.path.join(REPO_DIR, "scripts/azure/azure_shutdown_script.sh")
Expand All @@ -24,7 +24,7 @@ def upload_file_to_azure_storage(
remote_path = 'doodad/mount/' + remote_filename

if not dry:
blob_service_client = BlobServiceClient.from_connection_string(connection_str)
blob_service_client = blob.BlobServiceClient.from_connection_string(connection_str)
blob_client = blob_service_client.get_blob_client(container=container_name, blob=remote_path)
if check_exists:
try:
Expand Down
44 changes: 35 additions & 9 deletions doodad/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ def __init__(self,
num_vcpu='default',
promo_price=True,
spot_price=-1,
tags=None,
**kwargs):
super(AzureMode, self).__init__(**kwargs)
self.subscription_id = azure_subscription_id
Expand All @@ -600,19 +601,43 @@ def __init__(self,
self.azure_client_id = azure_client_id
self.azure_authentication_key = azure_authentication_key
self.azure_tenant_id = azure_tenant_id
self.log_path = log_path
self._log_path = log_path
self.terminate_on_end = terminate_on_end
self.preemptible = preemptible
self.region = region
self.instance_type = instance_type
self.spot_max_price = spot_price
if tags is None:
from os import environ, getcwd
getUser = lambda: environ["USERNAME"] if "C:" in getcwd() else environ[
"USER"]
user = getUser()
tags = {
'user': user,
'log_path': log_path,
}
if 'user' not in tags:
raise ValueError("""
Please set `user` in tags (tags = {'user': 'NAME'}) so that we keep
track of who is running which experiment.
""")
self.tags = tags

self.connection_str = azure_storage_connection_str
self.connection_info = dict([k.split('=', 1) for k in self.connection_str.split(';')])

if self.use_gpu:
self.instance_type = azure_util.get_gpu_type_instance(gpu_model, num_gpu, num_vcpu, promo_price)

@property
def log_path(self):
    """Return the log path currently stored on this mode (`_log_path`)."""
    stored_path = self._log_path
    return stored_path

@log_path.setter
def log_path(self, value):
    """Store a new log path and mirror it into the `log_path` tag.

    Args:
        value: The new log path. It is written to `_log_path` and to
            `self.tags['log_path']` so the two never disagree.
    """
    # Chained assignment binds targets left to right: the private
    # attribute is written first, then the tag entry.
    self._log_path = self.tags['log_path'] = value

def __str__(self):
    """Return a label of the form 'Azure-<resource group base>-<instance type>'."""
    label_fields = (self.azure_resource_group_base, self.instance_type)
    return 'Azure-%s-%s' % label_fields

Expand Down Expand Up @@ -660,12 +685,12 @@ def run_script(self, script, dry=False, return_output=False, verbose=False):
'shell_interpreter': self.shell_interpreter,
'azure_container_path': self.log_path,
'remote_script_path': remote_script,
'remote_script_args': script_args,
'container_name': self.azure_container,
'terminate': json.dumps(self.terminate_on_end),
'use_gpu': json.dumps(self.use_gpu),
'script_args': script_args,
'startup-script': start_script,
'shutdown-script': stop_script,
'startup_script': start_script,
'shutdown_script': stop_script,
'region': region
}
success, instance_info = self.create_instance(metadata, verbose=verbose)
Expand Down Expand Up @@ -723,6 +748,7 @@ def create_instance(self, metadata, verbose=False):
)
resource_group_params = {
'location': region,
'tags': self.tags,
}
resource_group = resource_group_client.resource_groups.create_or_update(
azure_resource_group,
Expand Down Expand Up @@ -782,14 +808,16 @@ def create_instance(self, metadata, verbose=False):
)
nic = poller.result()

with open(azure_util.AZURE_STARTUP_SCRIPT_PATH, mode='r') as f:
startup_script_str = f.read()
startup_script_str = metadata['startup_script']
# TODO: how do we use this shutdown script?
shutdown_script_str = metadata['shutdown_script']
for old, new in [
('DOODAD_LOG_PATH', self.log_path),
('DOODAD_STORAGE_ACCOUNT_NAME', self.connection_info['AccountName']),
('DOODAD_STORAGE_ACCOUNT_KEY', self.connection_info['AccountKey']),
('DOODAD_CONTAINER_NAME', self.azure_container),
('DOODAD_REMOTE_SCRIPT_PATH', metadata['remote_script_path']),
('DOODAD_REMOTE_SCRIPT_ARGS', metadata['remote_script_args']),
('DOODAD_SHELL_INTERPRETER', metadata['shell_interpreter']),
('DOODAD_TERMINATE_ON_END', metadata['terminate']),
('DOODAD_USE_GPU', metadata['use_gpu'])
Expand Down Expand Up @@ -830,9 +858,7 @@ def create_instance(self, metadata, verbose=False):
'id': nic.id
}]
},
'tags': {
'log_path': self.log_path,
},
'tags': self.tags,
'identity': params_identity,
}
if self.preemptible:
Expand Down
Loading

0 comments on commit 24df365

Please sign in to comment.