Skip to content

Commit

Permalink
Refactor how we launch spot instances (#366)
Browse files Browse the repository at this point in the history
* use create_instances for spot instances too
* deprecate --ec2-spot-request-duration
* remove unused date utilities for translating durations into expirations
* add changelog
  • Loading branch information
nchammas authored Nov 21, 2023
1 parent 7dde875 commit 0a7821b
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 148 deletions.
4 changes: 3 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

* [#348]: Bumped default Spark to 3.2; dropped support for Python 3.6; added CI build for Python 3.10.
* [#361]: Migrated from AdoptOpenJDK, which is deprecated, to Adoptium OpenJDK.
* [#362]: Improved Flintrock's ability to cleanup after launch failures.
* [#362][#366]: Improved Flintrock's ability to clean up after launch failures.
* [#366]: Deprecated `--ec2-spot-request-duration`, which is not needed for one-time spot instances launched using the RunInstances API.

[#348]: https://github.com/nchammas/flintrock/pull/348
[#361]: https://github.com/nchammas/flintrock/pull/361
[#362]: https://github.com/nchammas/flintrock/pull/362
[#366]: https://github.com/nchammas/flintrock/pull/366

## [2.0.0] - 2021-06-10

Expand Down
1 change: 0 additions & 1 deletion flintrock/config.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ providers:
# ami: ami-61bbf104 # CentOS 7, us-east-1
# user: centos
# spot-price: <price>
# spot-request-duration: 7d # duration a spot request is valid, supports d/h/m/s (e.g. 4d 3h 2m 1s)
# vpc-id: <id>
# subnet-id: <id>
# placement-group: <name>
Expand Down
102 changes: 25 additions & 77 deletions flintrock/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
)
from .ssh import generate_ssh_key_pair
from .services import SecurityGroupRule
from .util import duration_to_expiration

logger = logging.getLogger('flintrock.ec2')

Expand Down Expand Up @@ -275,7 +274,6 @@ def add_slaves(
identity_file: str,
num_slaves: int,
spot_price: float,
spot_request_duration: str,
min_root_ebs_size_gb: int,
tags: list,
assume_yes: bool,
Expand Down Expand Up @@ -321,7 +319,6 @@ def add_slaves(
num_instances=num_slaves,
region=self.region,
spot_price=spot_price,
spot_request_valid_until=duration_to_expiration(spot_request_duration),
ami=self.master_instance.image_id,
assume_yes=assume_yes,
key_name=self.master_instance.key_name,
Expand Down Expand Up @@ -704,7 +701,6 @@ def _create_instances(
num_instances,
region,
spot_price,
spot_request_valid_until,
ami,
assume_yes,
key_name,
Expand All @@ -724,7 +720,6 @@ def _create_instances(
ec2 = boto3.resource(service_name='ec2', region_name=region)

cluster_instances = []
spot_requests = []
common_launch_specs = {
'ImageId': ami,
'KeyName': key_name,
Expand All @@ -733,7 +728,6 @@ def _create_instances(
'Placement': {
'AvailabilityZone': availability_zone,
'Tenancy': tenancy,
'GroupName': placement_group,
},
'SecurityGroupIds': security_group_ids,
'SubnetId': subnet_id,
Expand All @@ -748,80 +742,36 @@ def _create_instances(
],
}

if spot_price:
common_launch_specs.update({
'InstanceMarketOptions': {
'MarketType': 'spot',
'SpotOptions': {
'SpotInstanceType': 'one-time',
'MaxPrice': str(spot_price),
'InstanceInterruptionBehavior': 'terminate',
},
}
})
else:
common_launch_specs.update({
'InstanceInitiatedShutdownBehavior': instance_initiated_shutdown_behavior,
})
# This can't be part of the previous update because we need a deep merge.
common_launch_specs['Placement'].update({
'GroupName': placement_group,
})

try:
if spot_price:
user_data = base64.b64encode(user_data.encode('utf-8')).decode()
logger.info("Requesting {c} spot instances at a max price of ${p}...".format(
c=num_instances, p=spot_price))
client = ec2.meta.client
spot_requests = client.request_spot_instances(
SpotPrice=str(spot_price),
InstanceCount=num_instances,
ValidUntil=spot_request_valid_until,
LaunchSpecification=common_launch_specs,
)['SpotInstanceRequests']

request_ids = [r['SpotInstanceRequestId'] for r in spot_requests]
pending_request_ids = request_ids

while pending_request_ids:
logger.info("{grant} of {req} instances granted. Waiting...".format(
grant=num_instances - len(pending_request_ids),
req=num_instances))
time.sleep(30)
spot_requests = client.describe_spot_instance_requests(
SpotInstanceRequestIds=request_ids)['SpotInstanceRequests']

failed_requests = [r for r in spot_requests if r['State'] == 'failed']
if failed_requests:
failure_reasons = {r['Status']['Code'] for r in failed_requests}
raise Error(
"The spot request failed for the following reason{s}: {reasons}"
.format(
s='' if len(failure_reasons) == 1 else 's',
reasons=', '.join(failure_reasons)))

pending_request_ids = [
r['SpotInstanceRequestId'] for r in spot_requests
if r['State'] == 'open']

logger.info("All {c} instances granted.".format(c=num_instances))

cluster_instances = list(
ec2.instances.filter(
Filters=[
{'Name': 'instance-id', 'Values': [r['InstanceId'] for r in spot_requests]}
]))
else:
cluster_instances = ec2.create_instances(
MinCount=num_instances,
MaxCount=num_instances,
# Shutdown Behavior is specific to on-demand instances.
InstanceInitiatedShutdownBehavior=instance_initiated_shutdown_behavior,
**common_launch_specs,
)
cluster_instances = ec2.create_instances(
MinCount=num_instances,
MaxCount=num_instances,
**common_launch_specs,
)
return cluster_instances
except (Exception, KeyboardInterrupt) as e:
if not isinstance(e, KeyboardInterrupt):
print(e, file=sys.stderr)
if spot_requests:
request_ids = [r['SpotInstanceRequestId'] for r in spot_requests]
if any([r['State'] != 'active' for r in spot_requests]):
print("Canceling spot instance requests...", file=sys.stderr)
client.cancel_spot_instance_requests(
SpotInstanceRequestIds=request_ids)
# Make sure we have the latest information on any launched spot instances.
spot_requests = client.describe_spot_instance_requests(
SpotInstanceRequestIds=request_ids)['SpotInstanceRequests']
instance_ids = [
r['InstanceId'] for r in spot_requests
if 'InstanceId' in r]
if instance_ids:
cluster_instances = list(
ec2.instances.filter(
Filters=[
{'Name': 'instance-id', 'Values': instance_ids}
]))
raise InterruptedEC2Operation(instances=cluster_instances) from e


Expand All @@ -842,7 +792,6 @@ def launch(
user,
security_groups,
spot_price=None,
spot_request_duration=None,
min_root_ebs_size_gb,
vpc_id,
subnet_id,
Expand Down Expand Up @@ -916,7 +865,6 @@ def launch(
common_instance_spec = {
'region': region,
'spot_price': spot_price,
'spot_request_valid_until': duration_to_expiration(spot_request_duration),
'ami': ami,
'assume_yes': assume_yes,
'key_name': key_name,
Expand Down
20 changes: 14 additions & 6 deletions flintrock/flintrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
else:
THIS_DIR = os.path.dirname(os.path.realpath(__file__))

EC2_SPOT_REQUEST_DURATION_DEPRECATION_MESSAGE = (
"Deprecation: --ec2-spot-request-duration is deprecated. One-time spot instances do "
"not support a request duration. "
"For more information see: https://github.com/nchammas/flintrock/pull/366"
)

logger = logging.getLogger('flintrock.flintrock')

Expand Down Expand Up @@ -345,8 +350,8 @@ def cli(cli_context, config, provider, debug):
help="Additional security groups names to assign to the instances. "
"You can specify this option multiple times.")
@click.option('--ec2-spot-price', type=float)
@click.option('--ec2-spot-request-duration', default='7d',
help="Duration a spot request is valid (e.g. 3d 2h 1m).")
@click.option('--ec2-spot-request-duration',
help="(DEPRECATED) Duration a spot request is valid (e.g. 3d 2h 1m).")
@click.option('--ec2-min-root-ebs-size-gb', type=int, default=30)
@click.option('--ec2-vpc-id', default='', help="Leave empty for default VPC.")
@click.option('--ec2-subnet-id', default='')
Expand Down Expand Up @@ -414,6 +419,9 @@ def launch(
"""
Launch a new cluster.
"""
if ec2_spot_request_duration:
logger.warning(EC2_SPOT_REQUEST_DURATION_DEPRECATION_MESSAGE)

provider = cli_context.obj['provider']
services = []

Expand Down Expand Up @@ -511,7 +519,6 @@ def launch(
user=ec2_user,
security_groups=ec2_security_groups,
spot_price=ec2_spot_price,
spot_request_duration=ec2_spot_request_duration,
min_root_ebs_size_gb=ec2_min_root_ebs_size_gb,
vpc_id=ec2_vpc_id,
subnet_id=ec2_subnet_id,
Expand Down Expand Up @@ -787,8 +794,8 @@ def stop(cli_context, cluster_name, ec2_region, ec2_vpc_id, assume_yes):
help="Path to SSH .pem file for accessing nodes.")
@click.option('--ec2-user')
@click.option('--ec2-spot-price', type=float)
@click.option('--ec2-spot-request-duration', default='7d',
help="Duration a spot request is valid (e.g. 3d 2h 1m).")
@click.option('--ec2-spot-request-duration',
help="(DEPRECATED) Duration a spot request is valid (e.g. 3d 2h 1m).")
@click.option('--ec2-min-root-ebs-size-gb', type=int, default=30)
@click.option('--assume-yes/--no-assume-yes', default=False)
@click.option('--ec2-tag', 'ec2_tags',
Expand Down Expand Up @@ -816,6 +823,8 @@ def add_slaves(
Flintrock will configure new slaves based on information queried
automatically from the master.
"""
if ec2_spot_request_duration:
logger.warning(EC2_SPOT_REQUEST_DURATION_DEPRECATION_MESSAGE)
provider = cli_context.obj['provider']

option_requires(
Expand All @@ -842,7 +851,6 @@ def add_slaves(
provider_options = {
'min_root_ebs_size_gb': ec2_min_root_ebs_size_gb,
'spot_price': ec2_spot_price,
'spot_request_duration': ec2_spot_request_duration,
'tags': ec2_tags
}
else:
Expand Down
43 changes: 0 additions & 43 deletions flintrock/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import sys
from datetime import datetime, timedelta, timezone
from decimal import Decimal

FROZEN = getattr(sys, 'frozen', False)

Expand All @@ -20,47 +18,6 @@ def get_subprocess_env() -> dict:
return env


def duration_to_timedelta(duration_string):
    """
    Convert a time duration string (e.g. 3h 4m 10s) into a timedelta.

    Supported unit suffixes are d (days), h (hours), m (minutes), and
    s (seconds); whitespace between components is ignored. Numbers with
    an unrecognized suffix contribute nothing to the total.
    """
    # Number of seconds represented by each supported unit suffix.
    seconds_per_unit = {
        'd': Decimal(60 * 60 * 24),
        'h': Decimal(60 * 60),
        'm': Decimal(60),
        's': Decimal(1),
    }

    total_seconds = Decimal('0')
    digits = ''

    for ch in duration_string.lower():
        if ch.isalpha():
            if digits:
                # Convert the accumulated number first so malformed numbers
                # (e.g. '3.4.5x') surface as a Decimal error, then scale by
                # the unit (unknown units multiply by zero, i.e. no-op).
                amount = Decimal(digits)
                total_seconds += amount * seconds_per_unit.get(ch, Decimal(0))
            digits = ''
        elif ch.isnumeric() or ch == '.':
            digits += ch

    return timedelta(seconds=float(total_seconds))


def duration_to_expiration(duration_string):
    """
    Convert a duration string into an absolute expiration time in UTC.

    If no duration is given, the expiration defaults to 7 days from now.
    """
    if duration_string:
        delta = duration_to_timedelta(duration_string)
    else:
        delta = timedelta(days=7)
    return datetime.now(tz=timezone.utc) + delta


def spark_hadoop_build_version(hadoop_version: str) -> str:
"""
Given a Hadoop version, determine the Hadoop build of Spark to use.
Expand Down
21 changes: 1 addition & 20 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,4 @@
from datetime import datetime, timedelta, timezone
from flintrock.util import (
duration_to_timedelta,
duration_to_expiration,
spark_hadoop_build_version,
)
from freezegun import freeze_time


def test_duration_to_timedelta():
    # Each duration string should parse to the equivalent timedelta.
    cases = [
        ('1d', timedelta(days=1)),
        ('3d2h1m', timedelta(days=3, hours=2, minutes=1)),
        ('4d 2h 1m 5s', timedelta(days=4, hours=2, minutes=1, seconds=5)),
        ('36h', timedelta(hours=36)),
        ('7d', timedelta(days=7)),
    ]
    for duration_string, expected in cases:
        assert duration_to_timedelta(duration_string) == expected


@freeze_time("2012-01-14")
def test_duration_to_expiration():
assert duration_to_expiration('5m') == datetime.now(tz=timezone.utc) + timedelta(minutes=5)
from flintrock.util import spark_hadoop_build_version


def test_spark_hadoop_build_version():
Expand Down

0 comments on commit 0a7821b

Please sign in to comment.