From 211b7fd117b7592bb697356bbc245c52fdd3f98c Mon Sep 17 00:00:00 2001 From: E Camden Fisher Date: Mon, 20 Dec 2021 11:10:42 -0500 Subject: [PATCH] update security group rules (#16) * update security group rules --- api/handlers_sgs.go | 39 ++++++ api/orchestration_sgs.go | 138 +++++++++++++-------- api/orchestrators.go | 8 +- api/policy.go | 30 +++++ api/routes.go | 2 +- api/types.go | 12 +- ec2/sgs.go | 53 ++++++++- ec2/sgs_test.go | 250 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 467 insertions(+), 65 deletions(-) diff --git a/api/handlers_sgs.go b/api/handlers_sgs.go index 824bbee..966c18f 100644 --- a/api/handlers_sgs.go +++ b/api/handlers_sgs.go @@ -50,6 +50,45 @@ func (s *server) SecurityGroupCreateHandler(w http.ResponseWriter, r *http.Reque handleResponseOk(w, out) } +func (s *server) SecurityGroupUpdateHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := s.mapAccountNumber(vars["account"]) + id := vars["id"] + + role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName) + policy, err := sgUpdatePolicy(id) + if err != nil { + handleError(w, err) + return + } + + req := &Ec2SecurityGroupRuleRequest{} + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + handleError(w, err) + return + } + + orch, err := s.newEc2Orchestrator(r.Context(), &sessionParams{ + inlinePolicy: policy, + role: role, + policyArns: []string{ + "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", + }, + }) + if err != nil { + handleError(w, err) + return + } + + if err := orch.updateSecurityGroup(r.Context(), id, req); err != nil { + handleError(w, err) + return + } + + handleResponseOk(w, nil) +} + func (s *server) SecurityGroupListHandler(w http.ResponseWriter, r *http.Request) { w = LogWriter{w} vars := mux.Vars(r) diff --git a/api/orchestration_sgs.go b/api/orchestration_sgs.go index 3b82132..fc7cf40 100644 --- a/api/orchestration_sgs.go +++ b/api/orchestration_sgs.go @@ -5,6 +5,7 @@ import ( "github.com/YaleSpinup/apierror" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/aws/aws-sdk-go/service/ec2" log "github.com/sirupsen/logrus" ) @@ -14,6 +15,8 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur return "", apierror.New(apierror.ErrBadRequest, "invalid input", nil) } + log.Debugf("got request to create security group: %s", awsutil.Prettify(req)) + var err error var rollBackTasks []rollbackFunc defer func() { @@ -23,39 +26,34 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur } }() - tags := []*ec2.Tag{} - for _, tag := range req.Tags { - for k, v := range tag { - tags = append(tags, &ec2.Tag{ - Key: aws.String(k), - Value: aws.String(v), - }) - } - } - - out, err := o.client.CreateSecurityGroup(ctx, &ec2.CreateSecurityGroupInput{ + input := &ec2.CreateSecurityGroupInput{ Description: aws.String(req.Description), GroupName: aws.String(req.GroupName), VpcId: aws.String(req.VpcId), - TagSpecifications: []*ec2.TagSpecification{ + } + + if len(req.Tags) > 0 { + input.SetTagSpecifications([]*ec2.TagSpecification{ { ResourceType: aws.String("security-group"), - Tags: tags, + Tags: normalizeTags(req.Tags), }, - }, - }) + }) + } + + out, err := o.ec2Client.CreateSecurityGroup(ctx, input) if err != nil { return "", err } - err = o.client.WaitUntilSecurityGroupExists(ctx, aws.StringValue(out.GroupId)) - if err != nil { + // err is used to trigger rollback, don't shadow it here + if err = o.ec2Client.WaitUntilSecurityGroupExists(ctx, aws.StringValue(out.GroupId)); err != nil { return "", err } rollBackTasks = append(rollBackTasks, func(ctx context.Context) error { log.Errorf("rollback: deleting security group: %s", aws.StringValue(out.GroupId)) - return o.client.DeleteSecurityGroup(ctx, aws.StringValue(out.GroupId)) + return o.ec2Client.DeleteSecurityGroup(ctx, aws.StringValue(out.GroupId)) }) if len(req.InitRules) > 0 { @@ -66,38 +64,10 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur return "", apierror.New(apierror.ErrBadRequest, "cidr_ip or sg_id is required", nil) } - ipPermissions := []*ec2.IpPermission{} - - if r.CidrIp != nil { - ipPermissions = append(ipPermissions, &ec2.IpPermission{ - IpProtocol: r.IpProtocol, - FromPort: r.FromPort, - ToPort: r.ToPort, - IpRanges: []*ec2.IpRange{ - { - CidrIp: r.CidrIp, - Description: r.Description, - }, - }, - }) - } - - if r.SgId != nil { - ipPermissions = append(ipPermissions, &ec2.IpPermission{ - IpProtocol: r.IpProtocol, - FromPort: r.FromPort, - ToPort: r.ToPort, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ - { - GroupId: r.SgId, - Description: r.Description, - }, - }, - }) - } + ipPermissions := ipPermissionsFromRequest(r) - err = o.client.AuthorizeSecurityGroup(ctx, *r.RuleType, aws.StringValue(out.GroupId), ipPermissions) - if err != nil { + // err is used to trigger rollback, don't shadow it here + if err = o.ec2Client.AuthorizeSecurityGroup(ctx, *r.RuleType, aws.StringValue(out.GroupId), ipPermissions); err != nil { return "", err } } @@ -105,3 +75,73 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur return aws.StringValue(out.GroupId), nil } + +func (o *ec2Orchestrator) updateSecurityGroup(ctx context.Context, id string, req *Ec2SecurityGroupRuleRequest) error { + if id == "" || req == nil { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Debugf("got request to update security group %s: %s", id, awsutil.Prettify(req)) + + switch *req.Action { + case "add": + if err := o.ec2Client.AuthorizeSecurityGroup(ctx, *req.RuleType, id, ipPermissionsFromRequest(req)); err != nil { + return err + } + case "remove": + if err := o.ec2Client.RevokeSecurityGroup(ctx, *req.RuleType, id, ipPermissionsFromRequest(req)); err != nil { + return err + } + default: + return apierror.New(apierror.ErrBadRequest, "action should be [add|remove]", nil) + } + + return nil +} + +func ipPermissionsFromRequest(r *Ec2SecurityGroupRuleRequest) []*ec2.IpPermission { + ipPermissions := []*ec2.IpPermission{} + + if r.CidrIp != nil { + ipPermissions = append(ipPermissions, &ec2.IpPermission{ + IpProtocol: r.IpProtocol, + FromPort: r.FromPort, + ToPort: r.ToPort, + IpRanges: []*ec2.IpRange{ + { + CidrIp: r.CidrIp, + Description: r.Description, + }, + }, + }) + } + + if r.SgId != nil { + ipPermissions = append(ipPermissions, &ec2.IpPermission{ + IpProtocol: r.IpProtocol, + FromPort: r.FromPort, + ToPort: r.ToPort, + UserIdGroupPairs: []*ec2.UserIdGroupPair{ + { + GroupId: r.SgId, + Description: r.Description, + }, + }, + }) + } + + return ipPermissions +} + +func normalizeTags(tags []map[string]string) []*ec2.Tag { + t := []*ec2.Tag{} + for _, tag := range tags { + for k, v := range tag { + t = append(t, &ec2.Tag{ + Key: aws.String(k), + Value: aws.String(v), + }) + } + } + return t +} diff --git a/api/orchestrators.go b/api/orchestrators.go index a1478c3..5e1d522 100644 --- a/api/orchestrators.go +++ b/api/orchestrators.go @@ -14,8 +14,8 @@ type sessionParams struct { } type ec2Orchestrator struct { - client *ec2.Ec2 - server *server + ec2Client *ec2.Ec2 + server *server } func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec2Orchestrator, error) { @@ -33,7 +33,7 @@ func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec } return &ec2Orchestrator{ - client: ec2.New(ec2.WithSession(session.Session)), - server: s, + ec2Client: ec2.New(ec2.WithSession(session.Session)), + server: s, }, nil } diff --git a/api/policy.go b/api/policy.go index 0c220b0..e396acc 100644 --- a/api/policy.go +++ b/api/policy.go @@ -91,3 +91,33 @@ func sgCreatePolicy() (string, error) { return string(j), nil } + +func sgUpdatePolicy(id string) (string, error) { + log.Debugf("generating sg crete policy document") + + sgResource := fmt.Sprintf("arn:aws:ec2:*:*:security-group/%s", id) + + policy := iam.PolicyDocument{ + Version: "2012-10-17", + Statement: []iam.StatementEntry{ + { + Effect: "Allow", + Action: []string{ + "ec2:ModifySecurityGroupRules", + "ec2:AuthorizeSecurityGroupEgress", + "ec2:AuthorizeSecurityGroupIngress", + "ec2:RevokeSecurityGroupEgress", + "ec2:RevokeSecurityGroupIngress", + }, + Resource: []string{sgResource}, + }, + }, + } + + j, err := json.Marshal(policy) + if err != nil { + return "", err + } + + return string(j), nil +} diff --git a/api/routes.go b/api/routes.go index 2b8f126..38433b9 100644 --- a/api/routes.go +++ b/api/routes.go @@ -71,7 +71,7 @@ func (s *server) routes() { api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/sgs/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/volumes/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) diff --git a/api/types.go b/api/types.go index 4ea39b4..f47e737 100644 --- a/api/types.go +++ b/api/types.go @@ -344,14 +344,14 @@ func (e *Ec2ImageVolumeMap) MarshalJSON() ([]byte, error) { } type Ec2SecurityGroupRequest struct { - Description string `json:"description"` - GroupName string `json:"group_name"` - InitRules []*Ec2SecurityGroupInitRuleRequest `json:"init_rules"` - Tags []map[string]string `json:"tags"` - VpcId string `json:"vpc_id"` + Description string `json:"description"` + GroupName string `json:"group_name"` + InitRules []*Ec2SecurityGroupRuleRequest `json:"init_rules"` + Tags []map[string]string `json:"tags"` + VpcId string `json:"vpc_id"` } -type Ec2SecurityGroupInitRuleRequest struct { +type Ec2SecurityGroupRuleRequest struct { RuleType *string `json:"rule_type"` // Direction of traffic: [inbound|outbound] Action *string `json:"action"` // Adding or removing the rule: [add|remove] CidrIp *string `json:"cidr_ip"` // IPv4 CIDR address range to allow traffic to/from diff --git a/ec2/sgs.go b/ec2/sgs.go index e636d19..29dd91c 100644 --- a/ec2/sgs.go +++ b/ec2/sgs.go @@ -132,29 +132,72 @@ func (e *Ec2) AuthorizeSecurityGroup(ctx context.Context, direction, sg string, IpPermissions: permissions, }) if err != nil { - return err + return ErrCode("failed authorizing egress", err) } log.Debugf("got output authorizing security group egress: %+v", out) if !aws.BoolValue(out.Return) { - return apierror.New(apierror.ErrInternalError, "security group authorization failed", nil) + return apierror.New(apierror.ErrBadRequest, "security group authorization rule failed", nil) } case "inbound": - var out *ec2.AuthorizeSecurityGroupIngressOutput out, err := e.Service.AuthorizeSecurityGroupIngressWithContext(ctx, &ec2.AuthorizeSecurityGroupIngressInput{ GroupId: aws.String(sg), IpPermissions: permissions, }) if err != nil { - return err + return ErrCode("failed authorizing ingress", err) } log.Debugf("got output authorizing security group ingress: %+v", out) if !aws.BoolValue(out.Return) { - return apierror.New(apierror.ErrInternalError, "security group authorization failed", nil) + return apierror.New(apierror.ErrBadRequest, "security group authorization rule failed", nil) + } + default: + return apierror.New(apierror.ErrBadRequest, "direction is required to be [outbound|inbound]", nil) + } + + return nil +} + +func (e *Ec2) RevokeSecurityGroup(ctx context.Context, direction, sg string, permissions []*ec2.IpPermission) error { + if direction == "" || sg == "" || permissions == nil { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("Revoking security group %s for %s", direction, sg) + + switch direction { + case "outbound": + out, err := e.Service.RevokeSecurityGroupEgressWithContext(ctx, &ec2.RevokeSecurityGroupEgressInput{ + GroupId: aws.String(sg), + IpPermissions: permissions, + }) + if err != nil { + return ErrCode("failed revoking egress", err) + } + + log.Debugf("got output authorizing security group egress: %+v", out) + + if !aws.BoolValue(out.Return) { + return apierror.New(apierror.ErrBadRequest, "security group revoke rule failed", nil) + } + + case "inbound": + out, err := e.Service.RevokeSecurityGroupIngressWithContext(ctx, &ec2.RevokeSecurityGroupIngressInput{ + GroupId: aws.String(sg), + IpPermissions: permissions, + }) + if err != nil { + return ErrCode("failed revoking egress", err) + } + + log.Debugf("got output authorizing security group ingress: %+v", out) + + if !aws.BoolValue(out.Return) { + return apierror.New(apierror.ErrBadRequest, "security group revoke rule failed", nil) } default: return apierror.New(apierror.ErrBadRequest, "direction is required to be [outbound|enbound]", nil) diff --git a/ec2/sgs_test.go b/ec2/sgs_test.go index 8b135ee..76a50d4 100644 --- a/ec2/sgs_test.go +++ b/ec2/sgs_test.go @@ -408,6 +408,34 @@ func (m mockEC2Client) AuthorizeSecurityGroupEgressWithContext(ctx context.Conte return nil, awserr.New("NotFound", "Security group not found", nil) } +func (m mockEC2Client) RevokeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.RevokeSecurityGroupIngressInput, opts ...request.Option) (*ec2.RevokeSecurityGroupIngressOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, securityGroup := range securityGroups { + if aws.StringValue(input.GroupId) == aws.StringValue(securityGroup.GroupId) { + return &ec2.RevokeSecurityGroupIngressOutput{Return: aws.Bool(true)}, nil + } + } + + return nil, awserr.New("NotFound", "Security group not found", nil) +} + +func (m mockEC2Client) RevokeSecurityGroupEgressWithContext(ctx context.Context, input *ec2.RevokeSecurityGroupEgressInput, opts ...request.Option) (*ec2.RevokeSecurityGroupEgressOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, securityGroup := range securityGroups { + if aws.StringValue(input.GroupId) == aws.StringValue(securityGroup.GroupId) { + return &ec2.RevokeSecurityGroupEgressOutput{Return: aws.Bool(true)}, nil + } + } + + return nil, awserr.New("NotFound", "Security group not found", nil) +} + func TestEc2_ListSecurityGroups(t *testing.T) { type fields struct { session *session.Session @@ -1447,3 +1475,225 @@ func TestEc2_CreateSecurityGroup(t *testing.T) { }) } } + +func TestEc2_RevokeSecurityGroup(t *testing.T) { + type fields struct { + session *session.Session + Service ec2iface.EC2API + DefaultKMSKeyId string + DefaultSgs []string + DefaultSubnets []string + org string + } + type args struct { + ctx context.Context + direction string + sg string + permissions []*ec2.IpPermission + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "empty direction", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + sg: "sg-0000000001", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "empty sg", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + direction: "inbound", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "empty permissions", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + direction: "inbound", + sg: "sg-0000000001", + }, + wantErr: true, + }, + { + name: "inbound rule", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + direction: "inbound", + sg: "sg-0000000001", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + }, + { + name: "inbound rule err", + fields: fields{ + Service: newmockEC2Client(t, awserr.New("BadRequest", "boom", nil)), + }, + args: args{ + ctx: context.TODO(), + direction: "inbound", + sg: "sg-0000000001", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "outbound rule", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + direction: "outbound", + sg: "sg-0000000001", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + }, + { + name: "outbound rule err", + fields: fields{ + Service: newmockEC2Client(t, awserr.New("BadRequest", "boom", nil)), + }, + args: args{ + ctx: context.TODO(), + direction: "outbound", + sg: "sg-0000000001", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "bad direction", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + direction: "sideways", + sg: "sg-0000000001", + permissions: []*ec2.IpPermission{ + { + IpProtocol: aws.String("tcp"), + FromPort: aws.Int64(-1), + ToPort: aws.Int64(-1), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String("192.168.0.0/24"), + Description: aws.String("hax"), + }, + }, + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + session: tt.fields.session, + Service: tt.fields.Service, + DefaultKMSKeyId: tt.fields.DefaultKMSKeyId, + DefaultSgs: tt.fields.DefaultSgs, + DefaultSubnets: tt.fields.DefaultSubnets, + org: tt.fields.org, + } + if err := e.RevokeSecurityGroup(tt.args.ctx, tt.args.direction, tt.args.sg, tt.args.permissions); (err != nil) != tt.wantErr { + t.Errorf("Ec2.RevokeSecurityGroup() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}