Skip to content

Commit

Permalink
Refactor some code related to getting common security groups (#378)
Browse files Browse the repository at this point in the history
Fixes #377.

---------

Co-authored-by: Nicholas Chammas <[email protected]>
  • Loading branch information
tomahawk360 and nchammas authored Dec 13, 2024
1 parent 5a1b681 commit a11946d
Showing 1 changed file with 81 additions and 66 deletions.
147 changes: 81 additions & 66 deletions flintrock/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ def destroy(self):
super().destroy()
ec2 = boto3.resource(service_name='ec2', region_name=self.region)

# TODO: Centralize logic to get Flintrock base security group. (?)
flintrock_base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock']},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
flintrock_base_group = get_base_security_group(vpc_id=self.vpc_id, region=self.region)

# We "unassign" the cluster security group here (i.e. the
# 'flintrock-clustername' group) so that we can immediately delete it once
Expand All @@ -196,13 +190,11 @@ def destroy(self):
Groups=[flintrock_base_group.id])
time.sleep(1)

# TODO: Centralize logic to get cluster security group name from cluster name.
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock-' + self.name]},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
cluster_group = get_cluster_security_group(
vpc_id=self.vpc_id,
region=self.region,
cluster_name=self.name,
)
cluster_group.delete()

(ec2.instances
Expand Down Expand Up @@ -380,13 +372,10 @@ def remove_slaves(self, *, user: str, identity_file: str, num_slaves: int):
if self.state == 'running':
super().remove_slaves(user=user, identity_file=identity_file)

# TODO: Centralize logic to get Flintrock base security group.
flintrock_base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock']},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
flintrock_base_group = get_base_security_group(
vpc_id=self.vpc_id,
region=self.region,
)

# TODO: Is there a way to do this in one call for all instances?
for instance in removed_slave_instances:
Expand Down Expand Up @@ -490,37 +479,79 @@ def check_network_config(*, region_name: str, vpc_id: str, subnet_id: str):
)


def get_security_groups(
*,
vpc_id,
region,
security_group_names) -> "List[boto3.resource('ec2').SecurityGroup]":
BASE_SECURITY_GROUP_NAME = "flintrock"


def get_base_security_group(*, vpc_id, region):
"""
The base Flintrock group is common to all Flintrock clusters and authorizes client traffic
to them.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)
base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [BASE_SECURITY_GROUP_NAME]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]
)
)
return base_group[0] if base_group else None

groups = list(

def get_cluster_security_group_name(cluster_name):
return f"flintrock-{cluster_name}"


def get_cluster_security_group(*, vpc_id, region, cluster_name):
"""
The cluster group is specific to one Flintrock cluster and authorizes intra-cluster
communication.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)
cluster_group_name = get_cluster_security_group_name(cluster_name)
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': security_group_names},
{'Name': 'group-name', 'Values': [cluster_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
return cluster_group[0] if cluster_group else None


def get_security_groups(
*,
vpc_id,
region,
security_group_names,
):
ec2 = boto3.resource(service_name='ec2', region_name=region)
groups = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': security_group_names},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]
)
)
found_group_names = [group.group_name for group in groups]
missing_group_names = set(security_group_names) - set(found_group_names)
if missing_group_names:
raise Error(
"Could not find the following security group{s}: {groups}"
.format(
s='' if len(missing_group_names) == 1 else 's',
groups=', '.join(list(missing_group_names))))

groups=', '.join(list(missing_group_names)),
)
)
return groups


def get_ssh_security_group_rules(
*,
flintrock_client_cidr=None,
flintrock_client_group=None,
) -> "boto3.resource('ec2').SecurityGroup":
):
return SecurityGroupRule(
ip_protocol='tcp',
from_port=22,
Expand All @@ -531,49 +562,26 @@ def get_ssh_security_group_rules(


def get_or_create_flintrock_security_groups(
*,
cluster_name,
vpc_id,
region,
services,
ec2_authorize_access_from,
) -> "List[boto3.resource('ec2').SecurityGroup]":
*,
cluster_name,
vpc_id,
region,
services,
ec2_authorize_access_from,
):
"""
If they do not already exist, create all the security groups needed for a
Flintrock cluster.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)

# TODO: Make these into methods, since we need this logic (though simple)
# in multiple places. (?)
flintrock_group_name = 'flintrock'
cluster_group_name = 'flintrock-' + cluster_name

# The Flintrock group is common to all Flintrock clusters and authorizes client traffic
# to them.
flintrock_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [flintrock_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
flintrock_group = flintrock_group[0] if flintrock_group else None

# The cluster group is specific to one Flintrock cluster and authorizes intra-cluster
# communication.
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [cluster_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
cluster_group = cluster_group[0] if cluster_group else None

flintrock_group = get_base_security_group(vpc_id=vpc_id, region=region)
if not flintrock_group:
flintrock_group = ec2.create_security_group(
GroupName=flintrock_group_name,
GroupName=BASE_SECURITY_GROUP_NAME,
Description="Flintrock base group",
VpcId=vpc_id)
VpcId=vpc_id,
)

# Rules for the client interacting with the cluster.
if ec2_authorize_access_from:
Expand Down Expand Up @@ -607,12 +615,19 @@ def get_or_create_flintrock_security_groups(
flintrock_client_cidr=str(IPv4Network(client_source)),
)

cluster_group_name = get_cluster_security_group_name(cluster_name)
cluster_group = get_cluster_security_group(
vpc_id=vpc_id,
region=region,
cluster_name=cluster_name,
)
# Rules for internal cluster communication.
if not cluster_group:
cluster_group = ec2.create_security_group(
GroupName=cluster_group_name,
Description="Flintrock cluster group",
VpcId=vpc_id)
VpcId=vpc_id,
)

# TODO: Don't try adding rules that already exist.
# TODO: Add rules in one shot.
Expand Down

0 comments on commit a11946d

Please sign in to comment.