Skip to content

Commit

Permalink
Merge master + simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
nchammas committed Dec 13, 2024
2 parents d8f996c + a11946d commit f9d3c30
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 71 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/flintrock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
matrix:
os:
- ubuntu-20.04
- macos-11
- macos-14
python-version:
# Update the artifact upload steps below if modifying
# this list of Python versions.
Expand All @@ -31,7 +31,6 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
architecture: x64
- run: "pip install -r requirements/maintainer.pip"
- run: "pytest"
- run: python -m build
Expand Down
154 changes: 85 additions & 69 deletions flintrock/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ def destroy(self):
super().destroy()
ec2 = boto3.resource(service_name='ec2', region_name=self.region)

# TODO: Centralize logic to get Flintrock base security group. (?)
flintrock_base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock']},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
flintrock_base_group = get_base_security_group(vpc_id=self.vpc_id, region=self.region)

# We "unassign" the cluster security group here (i.e. the
# 'flintrock-clustername' group) so that we can immediately delete it once
Expand All @@ -196,16 +190,15 @@ def destroy(self):
Groups=[flintrock_base_group.id])
time.sleep(1)

# TODO: Centralize logic to get cluster security group name from cluster name.
cluster_group_list = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock-' + self.name]},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))
# Cluster group might already have been killed if a destroy was ungracefully stopped during a previous execution
if len(cluster_group_list) > 0:
cluster_group_list[0].delete()
cluster_group = get_cluster_security_group(
vpc_id=self.vpc_id,
region=self.region,
cluster_name=self.name,
)
# Cluster group might already have been killed if a destroy was ungracefully stopped during
# a previous execution.
if cluster_group:
cluster_group.delete()

(ec2.instances
.filter(
Expand Down Expand Up @@ -382,13 +375,10 @@ def remove_slaves(self, *, user: str, identity_file: str, num_slaves: int):
if self.state == 'running':
super().remove_slaves(user=user, identity_file=identity_file)

# TODO: Centralize logic to get Flintrock base security group.
flintrock_base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock']},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
flintrock_base_group = get_base_security_group(
vpc_id=self.vpc_id,
region=self.region,
)

# TODO: Is there a way to do this in one call for all instances?
for instance in removed_slave_instances:
Expand Down Expand Up @@ -492,37 +482,79 @@ def check_network_config(*, region_name: str, vpc_id: str, subnet_id: str):
)


def get_security_groups(
*,
vpc_id,
region,
security_group_names) -> "List[boto3.resource('ec2').SecurityGroup]":
BASE_SECURITY_GROUP_NAME = "flintrock"


def get_base_security_group(*, vpc_id, region):
"""
The base Flintrock group is common to all Flintrock clusters and authorizes client traffic
to them.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)
base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [BASE_SECURITY_GROUP_NAME]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]
)
)
return base_group[0] if base_group else None

groups = list(

def get_cluster_security_group_name(cluster_name):
return f"flintrock-{cluster_name}"


def get_cluster_security_group(*, vpc_id, region, cluster_name):
"""
The cluster group is specific to one Flintrock cluster and authorizes intra-cluster
communication.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)
cluster_group_name = get_cluster_security_group_name(cluster_name)
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': security_group_names},
{'Name': 'group-name', 'Values': [cluster_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
return cluster_group[0] if cluster_group else None


def get_security_groups(
*,
vpc_id,
region,
security_group_names,
):
ec2 = boto3.resource(service_name='ec2', region_name=region)
groups = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': security_group_names},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]
)
)
found_group_names = [group.group_name for group in groups]
missing_group_names = set(security_group_names) - set(found_group_names)
if missing_group_names:
raise Error(
"Could not find the following security group{s}: {groups}"
.format(
s='' if len(missing_group_names) == 1 else 's',
groups=', '.join(list(missing_group_names))))

groups=', '.join(list(missing_group_names)),
)
)
return groups


def get_ssh_security_group_rules(
*,
flintrock_client_cidr=None,
flintrock_client_group=None,
) -> "boto3.resource('ec2').SecurityGroup":
):
return SecurityGroupRule(
ip_protocol='tcp',
from_port=22,
Expand All @@ -533,49 +565,26 @@ def get_ssh_security_group_rules(


def get_or_create_flintrock_security_groups(
*,
cluster_name,
vpc_id,
region,
services,
ec2_authorize_access_from,
) -> "List[boto3.resource('ec2').SecurityGroup]":
*,
cluster_name,
vpc_id,
region,
services,
ec2_authorize_access_from,
):
"""
If they do not already exist, create all the security groups needed for a
Flintrock cluster.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)

# TODO: Make these into methods, since we need this logic (though simple)
# in multiple places. (?)
flintrock_group_name = 'flintrock'
cluster_group_name = 'flintrock-' + cluster_name

# The Flintrock group is common to all Flintrock clusters and authorizes client traffic
# to them.
flintrock_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [flintrock_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
flintrock_group = flintrock_group[0] if flintrock_group else None

# The cluster group is specific to one Flintrock cluster and authorizes intra-cluster
# communication.
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [cluster_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
cluster_group = cluster_group[0] if cluster_group else None

flintrock_group = get_base_security_group(vpc_id=vpc_id, region=region)
if not flintrock_group:
flintrock_group = ec2.create_security_group(
GroupName=flintrock_group_name,
GroupName=BASE_SECURITY_GROUP_NAME,
Description="Flintrock base group",
VpcId=vpc_id)
VpcId=vpc_id,
)

# Rules for the client interacting with the cluster.
if ec2_authorize_access_from:
Expand Down Expand Up @@ -609,12 +618,19 @@ def get_or_create_flintrock_security_groups(
flintrock_client_cidr=str(IPv4Network(client_source)),
)

cluster_group_name = get_cluster_security_group_name(cluster_name)
cluster_group = get_cluster_security_group(
vpc_id=vpc_id,
region=region,
cluster_name=cluster_name,
)
# Rules for internal cluster communication.
if not cluster_group:
cluster_group = ec2.create_security_group(
GroupName=cluster_group_name,
Description="Flintrock cluster group",
VpcId=vpc_id)
VpcId=vpc_id,
)

# TODO: Don't try adding rules that already exist.
# TODO: Add rules in one shot.
Expand Down

0 comments on commit f9d3c30

Please sign in to comment.