Skip to content

Commit

Permalink
update security group rules (#16)
Browse files Browse the repository at this point in the history
* update security group rules
  • Loading branch information
fishnix authored Dec 20, 2021
1 parent 2e43852 commit 211b7fd
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 65 deletions.
39 changes: 39 additions & 0 deletions api/handlers_sgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,45 @@ func (s *server) SecurityGroupCreateHandler(w http.ResponseWriter, r *http.Reque
handleResponseOk(w, out)
}

func (s *server) SecurityGroupUpdateHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
id := vars["id"]

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := sgUpdatePolicy(id)
if err != nil {
handleError(w, err)
return
}

req := &Ec2SecurityGroupRuleRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
handleError(w, err)
return
}

orch, err := s.newEc2Orchestrator(r.Context(), &sessionParams{
inlinePolicy: policy,
role: role,
policyArns: []string{
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
},
})
if err != nil {
handleError(w, err)
return
}

if err := orch.updateSecurityGroup(r.Context(), id, req); err != nil {
handleError(w, err)
return
}

handleResponseOk(w, nil)
}

func (s *server) SecurityGroupListHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
Expand Down
138 changes: 89 additions & 49 deletions api/orchestration_sgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/YaleSpinup/apierror"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/ec2"
log "github.com/sirupsen/logrus"
)
Expand All @@ -14,6 +15,8 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur
return "", apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Debugf("got request to create security group: %s", awsutil.Prettify(req))

var err error
var rollBackTasks []rollbackFunc
defer func() {
Expand All @@ -23,39 +26,34 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur
}
}()

tags := []*ec2.Tag{}
for _, tag := range req.Tags {
for k, v := range tag {
tags = append(tags, &ec2.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
}
}

out, err := o.client.CreateSecurityGroup(ctx, &ec2.CreateSecurityGroupInput{
input := &ec2.CreateSecurityGroupInput{
Description: aws.String(req.Description),
GroupName: aws.String(req.GroupName),
VpcId: aws.String(req.VpcId),
TagSpecifications: []*ec2.TagSpecification{
}

if len(req.Tags) > 0 {
input.SetTagSpecifications([]*ec2.TagSpecification{
{
ResourceType: aws.String("security-group"),
Tags: tags,
Tags: normalizeTags(req.Tags),
},
},
})
})
}

out, err := o.ec2Client.CreateSecurityGroup(ctx, input)
if err != nil {
return "", err
}

err = o.client.WaitUntilSecurityGroupExists(ctx, aws.StringValue(out.GroupId))
if err != nil {
// err is used to trigger rollback, don't shadow it here
if err = o.ec2Client.WaitUntilSecurityGroupExists(ctx, aws.StringValue(out.GroupId)); err != nil {
return "", err
}

rollBackTasks = append(rollBackTasks, func(ctx context.Context) error {
log.Errorf("rollback: deleting security group: %s", aws.StringValue(out.GroupId))
return o.client.DeleteSecurityGroup(ctx, aws.StringValue(out.GroupId))
return o.ec2Client.DeleteSecurityGroup(ctx, aws.StringValue(out.GroupId))
})

if len(req.InitRules) > 0 {
Expand All @@ -66,42 +64,84 @@ func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2Secur
return "", apierror.New(apierror.ErrBadRequest, "cidr_ip or sg_id is required", nil)
}

ipPermissions := []*ec2.IpPermission{}

if r.CidrIp != nil {
ipPermissions = append(ipPermissions, &ec2.IpPermission{
IpProtocol: r.IpProtocol,
FromPort: r.FromPort,
ToPort: r.ToPort,
IpRanges: []*ec2.IpRange{
{
CidrIp: r.CidrIp,
Description: r.Description,
},
},
})
}

if r.SgId != nil {
ipPermissions = append(ipPermissions, &ec2.IpPermission{
IpProtocol: r.IpProtocol,
FromPort: r.FromPort,
ToPort: r.ToPort,
UserIdGroupPairs: []*ec2.UserIdGroupPair{
{
GroupId: r.SgId,
Description: r.Description,
},
},
})
}
ipPermissions := ipPermissionsFromRequest(r)

err = o.client.AuthorizeSecurityGroup(ctx, *r.RuleType, aws.StringValue(out.GroupId), ipPermissions)
if err != nil {
// err is used to trigger rollback, don't shadow it here
if err = o.ec2Client.AuthorizeSecurityGroup(ctx, *r.RuleType, aws.StringValue(out.GroupId), ipPermissions); err != nil {
return "", err
}
}
}

return aws.StringValue(out.GroupId), nil
}

func (o *ec2Orchestrator) updateSecurityGroup(ctx context.Context, id string, req *Ec2SecurityGroupRuleRequest) error {
if id == "" || req == nil {
return apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Debugf("got request to update security group %s: %s", id, awsutil.Prettify(req))

switch *req.Action {
case "add":
if err := o.ec2Client.AuthorizeSecurityGroup(ctx, *req.RuleType, id, ipPermissionsFromRequest(req)); err != nil {
return err
}
case "remove":
if err := o.ec2Client.RevokeSecurityGroup(ctx, *req.RuleType, id, ipPermissionsFromRequest(req)); err != nil {
return err
}
default:
return apierror.New(apierror.ErrBadRequest, "action should be [add|remove]", nil)
}

return nil
}

func ipPermissionsFromRequest(r *Ec2SecurityGroupRuleRequest) []*ec2.IpPermission {
ipPermissions := []*ec2.IpPermission{}

if r.CidrIp != nil {
ipPermissions = append(ipPermissions, &ec2.IpPermission{
IpProtocol: r.IpProtocol,
FromPort: r.FromPort,
ToPort: r.ToPort,
IpRanges: []*ec2.IpRange{
{
CidrIp: r.CidrIp,
Description: r.Description,
},
},
})
}

if r.SgId != nil {
ipPermissions = append(ipPermissions, &ec2.IpPermission{
IpProtocol: r.IpProtocol,
FromPort: r.FromPort,
ToPort: r.ToPort,
UserIdGroupPairs: []*ec2.UserIdGroupPair{
{
GroupId: r.SgId,
Description: r.Description,
},
},
})
}

return ipPermissions
}

func normalizeTags(tags []map[string]string) []*ec2.Tag {
t := []*ec2.Tag{}
for _, tag := range tags {
for k, v := range tag {
t = append(t, &ec2.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
}
}
return t
}
8 changes: 4 additions & 4 deletions api/orchestrators.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type sessionParams struct {
}

type ec2Orchestrator struct {
client *ec2.Ec2
server *server
ec2Client *ec2.Ec2
server *server
}

func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec2Orchestrator, error) {
Expand All @@ -33,7 +33,7 @@ func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec
}

return &ec2Orchestrator{
client: ec2.New(ec2.WithSession(session.Session)),
server: s,
ec2Client: ec2.New(ec2.WithSession(session.Session)),
server: s,
}, nil
}
30 changes: 30 additions & 0 deletions api/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,33 @@ func sgCreatePolicy() (string, error) {

return string(j), nil
}

func sgUpdatePolicy(id string) (string, error) {
log.Debugf("generating sg crete policy document")

sgResource := fmt.Sprintf("arn:aws:ec2:*:*:security-group/%s", id)

policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ec2:ModifySecurityGroupRules",
"ec2:AuthorizeSecurityGroupEgress",
"ec2:AuthorizeSecurityGroupIngress",
"ec2:RevokeSecurityGroupEgress",
"ec2:RevokeSecurityGroupIngress",
},
Resource: []string{sgResource},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}
2 changes: 1 addition & 1 deletion api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (s *server) routes() {
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
Expand Down
12 changes: 6 additions & 6 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,14 @@ func (e *Ec2ImageVolumeMap) MarshalJSON() ([]byte, error) {
}

type Ec2SecurityGroupRequest struct {
Description string `json:"description"`
GroupName string `json:"group_name"`
InitRules []*Ec2SecurityGroupInitRuleRequest `json:"init_rules"`
Tags []map[string]string `json:"tags"`
VpcId string `json:"vpc_id"`
Description string `json:"description"`
GroupName string `json:"group_name"`
InitRules []*Ec2SecurityGroupRuleRequest `json:"init_rules"`
Tags []map[string]string `json:"tags"`
VpcId string `json:"vpc_id"`
}

type Ec2SecurityGroupInitRuleRequest struct {
type Ec2SecurityGroupRuleRequest struct {
RuleType *string `json:"rule_type"` // Direction of traffic: [inbound|outbound]
Action *string `json:"action"` // Adding or removing the rule: [add|remove]
CidrIp *string `json:"cidr_ip"` // IPv4 CIDR address range to allow traffic to/from
Expand Down
53 changes: 48 additions & 5 deletions ec2/sgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,72 @@ func (e *Ec2) AuthorizeSecurityGroup(ctx context.Context, direction, sg string,
IpPermissions: permissions,
})
if err != nil {
return err
return ErrCode("failed authorizing egress", err)
}

log.Debugf("got output authorizing security group egress: %+v", out)

if !aws.BoolValue(out.Return) {
return apierror.New(apierror.ErrInternalError, "security group authorization failed", nil)
return apierror.New(apierror.ErrBadRequest, "security group authorization rule failed", nil)
}

case "inbound":
var out *ec2.AuthorizeSecurityGroupIngressOutput
out, err := e.Service.AuthorizeSecurityGroupIngressWithContext(ctx, &ec2.AuthorizeSecurityGroupIngressInput{
GroupId: aws.String(sg),
IpPermissions: permissions,
})
if err != nil {
return err
return ErrCode("failed authorizing ingress", err)
}

log.Debugf("got output authorizing security group ingress: %+v", out)

if !aws.BoolValue(out.Return) {
return apierror.New(apierror.ErrInternalError, "security group authorization failed", nil)
return apierror.New(apierror.ErrBadRequest, "security group authorization rule failed", nil)
}
default:
return apierror.New(apierror.ErrBadRequest, "direction is required to be [outbound|inbound]", nil)
}

return nil
}

func (e *Ec2) RevokeSecurityGroup(ctx context.Context, direction, sg string, permissions []*ec2.IpPermission) error {
if direction == "" || sg == "" || permissions == nil {
return apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Infof("Revoking security group %s for %s", direction, sg)

switch direction {
case "outbound":
out, err := e.Service.RevokeSecurityGroupEgressWithContext(ctx, &ec2.RevokeSecurityGroupEgressInput{
GroupId: aws.String(sg),
IpPermissions: permissions,
})
if err != nil {
return ErrCode("failed revoking egress", err)
}

log.Debugf("got output authorizing security group egress: %+v", out)

if !aws.BoolValue(out.Return) {
return apierror.New(apierror.ErrBadRequest, "security group revoke rule failed", nil)
}

case "inbound":
out, err := e.Service.RevokeSecurityGroupIngressWithContext(ctx, &ec2.RevokeSecurityGroupIngressInput{
GroupId: aws.String(sg),
IpPermissions: permissions,
})
if err != nil {
return ErrCode("failed revoking egress", err)
}

log.Debugf("got output authorizing security group ingress: %+v", out)

if !aws.BoolValue(out.Return) {
return apierror.New(apierror.ErrBadRequest, "security group revoke rule failed", nil)
}
default:
return apierror.New(apierror.ErrBadRequest, "direction is required to be [outbound|enbound]", nil)
Expand Down
Loading

0 comments on commit 211b7fd

Please sign in to comment.