diff --git a/flintrock/ec2.py b/flintrock/ec2.py index ecd1cda..195781e 100644 --- a/flintrock/ec2.py +++ b/flintrock/ec2.py @@ -177,13 +177,7 @@ def destroy(self): super().destroy() ec2 = boto3.resource(service_name='ec2', region_name=self.region) - # TODO: Centralize logic to get Flintrock base security group. (?) - flintrock_base_group = list( - ec2.security_groups.filter( - Filters=[ - {'Name': 'group-name', 'Values': ['flintrock']}, - {'Name': 'vpc-id', 'Values': [self.vpc_id]}, - ]))[0] + flintrock_base_group = get_base_security_group(vpc_id=self.vpc_id, region=self.region) # We "unassign" the cluster security group here (i.e. the # 'flintrock-clustername' group) so that we can immediately delete it once @@ -196,13 +190,11 @@ def destroy(self): Groups=[flintrock_base_group.id]) time.sleep(1) - # TODO: Centralize logic to get cluster security group name from cluster name. - cluster_group = list( - ec2.security_groups.filter( - Filters=[ - {'Name': 'group-name', 'Values': ['flintrock-' + self.name]}, - {'Name': 'vpc-id', 'Values': [self.vpc_id]}, - ]))[0] + cluster_group = get_cluster_security_group( + vpc_id=self.vpc_id, + region=self.region, + cluster_name=self.name, + ) cluster_group.delete() (ec2.instances @@ -380,13 +372,10 @@ def remove_slaves(self, *, user: str, identity_file: str, num_slaves: int): if self.state == 'running': super().remove_slaves(user=user, identity_file=identity_file) - # TODO: Centralize logic to get Flintrock base security group. - flintrock_base_group = list( - ec2.security_groups.filter( - Filters=[ - {'Name': 'group-name', 'Values': ['flintrock']}, - {'Name': 'vpc-id', 'Values': [self.vpc_id]}, - ]))[0] + flintrock_base_group = get_base_security_group( + vpc_id=self.vpc_id, + region=self.region, + ) # TODO: Is there a way to do this in one call for all instances? for instance in removed_slave_instances: @@ -490,20 +479,61 @@ def check_network_config(*, region_name: str, vpc_id: str, subnet_id: str): ) -def get_security_groups( - *, - vpc_id, - region, - security_group_names) -> "List[boto3.resource('ec2').SecurityGroup]": +BASE_SECURITY_GROUP_NAME = "flintrock" + + +def get_base_security_group(*, vpc_id, region): + """ + The base Flintrock group is common to all Flintrock clusters and authorizes client traffic + to them. + """ ec2 = boto3.resource(service_name='ec2', region_name=region) + base_group = list( + ec2.security_groups.filter( + Filters=[ + {'Name': 'group-name', 'Values': [BASE_SECURITY_GROUP_NAME]}, + {'Name': 'vpc-id', 'Values': [vpc_id]}, + ] + ) + ) + return base_group[0] if base_group else None - groups = list( + +def get_cluster_security_group_name(cluster_name): + return f"flintrock-{cluster_name}" + + +def get_cluster_security_group(*, vpc_id, region, cluster_name): + """ + The cluster group is specific to one Flintrock cluster and authorizes intra-cluster + communication. + """ + ec2 = boto3.resource(service_name='ec2', region_name=region) + cluster_group_name = get_cluster_security_group_name(cluster_name) + cluster_group = list( ec2.security_groups.filter( Filters=[ - {'Name': 'group-name', 'Values': security_group_names}, + {'Name': 'group-name', 'Values': [cluster_group_name]}, {'Name': 'vpc-id', 'Values': [vpc_id]}, ])) + return cluster_group[0] if cluster_group else None + +def get_security_groups( + *, + vpc_id, + region, + security_group_names, +): + ec2 = boto3.resource(service_name='ec2', region_name=region) + groups = list( + ec2.security_groups.filter( + Filters=[ + {'Name': 'group-name', 'Values': security_group_names}, + {'Name': 'vpc-id', 'Values': [vpc_id]}, + ] + ) + ) found_group_names = [group.group_name for group in groups] missing_group_names = set(security_group_names) - set(found_group_names) if missing_group_names: @@ -511,8 +541,9 @@ def get_security_groups( "Could not find the following security group{s}: {groups}" .format( s='' if len(missing_group_names) == 1 else 's', - groups=', '.join(list(missing_group_names)))) - + groups=', '.join(list(missing_group_names)), + ) + ) return groups @@ -520,7 +551,7 @@ def get_ssh_security_group_rules( *, flintrock_client_cidr=None, flintrock_client_group=None, -) -> "boto3.resource('ec2').SecurityGroup": +): return SecurityGroupRule( ip_protocol='tcp', from_port=22, @@ -531,49 +562,26 @@ def get_ssh_security_group_rules( def get_or_create_flintrock_security_groups( - *, - cluster_name, - vpc_id, - region, - services, - ec2_authorize_access_from, -) -> "List[boto3.resource('ec2').SecurityGroup]": + *, + cluster_name, + vpc_id, + region, + services, + ec2_authorize_access_from, +): """ If they do not already exist, create all the security groups needed for a Flintrock cluster. """ ec2 = boto3.resource(service_name='ec2', region_name=region) - # TODO: Make these into methods, since we need this logic (though simple) - # in multiple places. (?) - flintrock_group_name = 'flintrock' - cluster_group_name = 'flintrock-' + cluster_name - - # The Flintrock group is common to all Flintrock clusters and authorizes client traffic - # to them. - flintrock_group = list( - ec2.security_groups.filter( - Filters=[ - {'Name': 'group-name', 'Values': [flintrock_group_name]}, - {'Name': 'vpc-id', 'Values': [vpc_id]}, - ])) - flintrock_group = flintrock_group[0] if flintrock_group else None - - # The cluster group is specific to one Flintrock cluster and authorizes intra-cluster - # communication. - cluster_group = list( - ec2.security_groups.filter( - Filters=[ - {'Name': 'group-name', 'Values': [cluster_group_name]}, - {'Name': 'vpc-id', 'Values': [vpc_id]}, - ])) - cluster_group = cluster_group[0] if cluster_group else None - + flintrock_group = get_base_security_group(vpc_id=vpc_id, region=region) if not flintrock_group: flintrock_group = ec2.create_security_group( - GroupName=flintrock_group_name, + GroupName=BASE_SECURITY_GROUP_NAME, Description="Flintrock base group", - VpcId=vpc_id) + VpcId=vpc_id, + ) # Rules for the client interacting with the cluster. if ec2_authorize_access_from: @@ -607,12 +615,19 @@ def get_or_create_flintrock_security_groups( flintrock_client_cidr=str(IPv4Network(client_source)), ) + cluster_group_name = get_cluster_security_group_name(cluster_name) + cluster_group = get_cluster_security_group( + vpc_id=vpc_id, + region=region, + cluster_name=cluster_name, + ) # Rules for internal cluster communication. if not cluster_group: cluster_group = ec2.create_security_group( GroupName=cluster_group_name, Description="Flintrock cluster group", - VpcId=vpc_id) + VpcId=vpc_id, + ) # TODO: Don't try adding rules that already exist. # TODO: Add rules in one shot.