Skip to content

Commit

Permalink
Create security group, initialize rules (#15)
Browse files Browse the repository at this point in the history
* create security group, initialize rules
* update security group rules (#16)
  • Loading branch information
fishnix authored Jan 5, 2022
1 parent 85cbc9b commit 2da0c99
Show file tree
Hide file tree
Showing 13 changed files with 1,473 additions and 229 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ jobs:
strategy:
matrix:
go-version:
- "1.16.x"
- "1.17.x"
os:
- "ubuntu-latest"
Expand Down
79 changes: 79 additions & 0 deletions api/handlers_sgs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"encoding/json"
"fmt"
"net/http"
"strconv"
Expand All @@ -10,6 +11,84 @@ import (
"github.com/gorilla/mux"
)

func (s *server) SecurityGroupCreateHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := sgCreatePolicy()
if err != nil {
handleError(w, err)
return
}

req := &Ec2SecurityGroupRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
handleError(w, err)
return
}

orch, err := s.newEc2Orchestrator(r.Context(), &sessionParams{
inlinePolicy: policy,
role: role,
policyArns: []string{
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
},
})
if err != nil {
handleError(w, err)
return
}

out, err := orch.createSecurityGroup(r.Context(), req)
if err != nil {
handleError(w, err)
return
}

handleResponseOk(w, out)
}

func (s *server) SecurityGroupUpdateHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
id := vars["id"]

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := sgUpdatePolicy(id)
if err != nil {
handleError(w, err)
return
}

req := &Ec2SecurityGroupRuleRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
handleError(w, err)
return
}

orch, err := s.newEc2Orchestrator(r.Context(), &sessionParams{
inlinePolicy: policy,
role: role,
policyArns: []string{
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
},
})
if err != nil {
handleError(w, err)
return
}

if err := orch.updateSecurityGroup(r.Context(), id, req); err != nil {
handleError(w, err)
return
}

handleResponseOk(w, nil)
}

func (s *server) SecurityGroupListHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
Expand Down
147 changes: 147 additions & 0 deletions api/orchestration_sgs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package api

import (
"context"

"github.com/YaleSpinup/apierror"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/ec2"
log "github.com/sirupsen/logrus"
)

func (o *ec2Orchestrator) createSecurityGroup(ctx context.Context, req *Ec2SecurityGroupRequest) (string, error) {
if req == nil {
return "", apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Debugf("got request to create security group: %s", awsutil.Prettify(req))

var err error
var rollBackTasks []rollbackFunc
defer func() {
if err != nil {
log.Errorf("recovering from error: %s, executing %d rollback tasks", err, len(rollBackTasks))
rollBack(&rollBackTasks)
}
}()

input := &ec2.CreateSecurityGroupInput{
Description: aws.String(req.Description),
GroupName: aws.String(req.GroupName),
VpcId: aws.String(req.VpcId),
}

if len(req.Tags) > 0 {
input.SetTagSpecifications([]*ec2.TagSpecification{
{
ResourceType: aws.String("security-group"),
Tags: normalizeTags(req.Tags),
},
})
}

out, err := o.ec2Client.CreateSecurityGroup(ctx, input)
if err != nil {
return "", err
}

// err is used to trigger rollback, don't shadow it here
if err = o.ec2Client.WaitUntilSecurityGroupExists(ctx, aws.StringValue(out.GroupId)); err != nil {
return "", err
}

rollBackTasks = append(rollBackTasks, func(ctx context.Context) error {
log.Errorf("rollback: deleting security group: %s", aws.StringValue(out.GroupId))
return o.ec2Client.DeleteSecurityGroup(ctx, aws.StringValue(out.GroupId))
})

if len(req.InitRules) > 0 {
for _, r := range req.InitRules {
log.Debugf("creating securitygrouprulerequest with %+v", r)

if r.CidrIp == nil && r.SgId == nil {
return "", apierror.New(apierror.ErrBadRequest, "cidr_ip or sg_id is required", nil)
}

ipPermissions := ipPermissionsFromRequest(r)

// err is used to trigger rollback, don't shadow it here
if err = o.ec2Client.AuthorizeSecurityGroup(ctx, *r.RuleType, aws.StringValue(out.GroupId), ipPermissions); err != nil {
return "", err
}
}
}

return aws.StringValue(out.GroupId), nil
}

func (o *ec2Orchestrator) updateSecurityGroup(ctx context.Context, id string, req *Ec2SecurityGroupRuleRequest) error {
if id == "" || req == nil {
return apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Debugf("got request to update security group %s: %s", id, awsutil.Prettify(req))

switch *req.Action {
case "add":
if err := o.ec2Client.AuthorizeSecurityGroup(ctx, *req.RuleType, id, ipPermissionsFromRequest(req)); err != nil {
return err
}
case "remove":
if err := o.ec2Client.RevokeSecurityGroup(ctx, *req.RuleType, id, ipPermissionsFromRequest(req)); err != nil {
return err
}
default:
return apierror.New(apierror.ErrBadRequest, "action should be [add|remove]", nil)
}

return nil
}

func ipPermissionsFromRequest(r *Ec2SecurityGroupRuleRequest) []*ec2.IpPermission {
ipPermissions := []*ec2.IpPermission{}

if r.CidrIp != nil {
ipPermissions = append(ipPermissions, &ec2.IpPermission{
IpProtocol: r.IpProtocol,
FromPort: r.FromPort,
ToPort: r.ToPort,
IpRanges: []*ec2.IpRange{
{
CidrIp: r.CidrIp,
Description: r.Description,
},
},
})
}

if r.SgId != nil {
ipPermissions = append(ipPermissions, &ec2.IpPermission{
IpProtocol: r.IpProtocol,
FromPort: r.FromPort,
ToPort: r.ToPort,
UserIdGroupPairs: []*ec2.UserIdGroupPair{
{
GroupId: r.SgId,
Description: r.Description,
},
},
})
}

return ipPermissions
}

func normalizeTags(tags []map[string]string) []*ec2.Tag {
t := []*ec2.Tag{}
for _, tag := range tags {
for k, v := range tag {
t = append(t, &ec2.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
}
}
return t
}
39 changes: 39 additions & 0 deletions api/orchestrators.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package api

import (
"context"

"github.com/YaleSpinup/ec2-api/ec2"
log "github.com/sirupsen/logrus"
)

type sessionParams struct {
role string
inlinePolicy string
policyArns []string
}

type ec2Orchestrator struct {
ec2Client *ec2.Ec2
server *server
}

func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec2Orchestrator, error) {
log.Debugf("initializing ec2Orchestrator")

session, err := s.assumeRole(
ctx,
s.session.ExternalID,
sp.role,
sp.inlinePolicy,
sp.policyArns...,
)
if err != nil {
return nil, err
}

return &ec2Orchestrator{
ec2Client: ec2.New(ec2.WithSession(session.Session)),
server: s,
}, nil
}
61 changes: 60 additions & 1 deletion api/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func orgTagAccessPolicy(org string) (string, error) {
}

func sgDeletePolicy(id string) (string, error) {
log.Debugf("generating org policy document")
log.Debugf("generating sg delete policy document")

sgResource := fmt.Sprintf("arn:aws:ec2:*:*:security-group/%s", id)

Expand All @@ -62,3 +62,62 @@ func sgDeletePolicy(id string) (string, error) {

return string(j), nil
}

func sgCreatePolicy() (string, error) {
log.Debugf("generating sg crete policy document")

policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ec2:CreateSecurityGroup",
"ec2:CreateTags",
"ec2:ModifySecurityGroupRules",
"ec2:DeleteSecurityGroup",
"ec2:AuthorizeSecurityGroupEgress",
"ec2:AuthorizeSecurityGroupIngress",
},
Resource: []string{"*"},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}

func sgUpdatePolicy(id string) (string, error) {
log.Debugf("generating sg crete policy document")

sgResource := fmt.Sprintf("arn:aws:ec2:*:*:security-group/%s", id)

policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ec2:ModifySecurityGroupRules",
"ec2:AuthorizeSecurityGroupEgress",
"ec2:AuthorizeSecurityGroupIngress",
"ec2:RevokeSecurityGroupEgress",
"ec2:RevokeSecurityGroupIngress",
},
Resource: []string{sgResource},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}
4 changes: 2 additions & 2 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *server) routes() {

api.HandleFunc("/{account}/instances", s.ProxyRequestHandler).Methods(http.MethodPost)
api.HandleFunc("/{account}/instances/{id}/volumes", s.ProxyRequestHandler).Methods(http.MethodPost)
api.HandleFunc("/{account}/sgs", s.ProxyRequestHandler).Methods(http.MethodPost)
api.HandleFunc("/{account}/sgs", s.SecurityGroupCreateHandler).Methods(http.MethodPost)
api.HandleFunc("/{account}/volumes", s.ProxyRequestHandler).Methods(http.MethodPost)
api.HandleFunc("/{account}/snapshots", s.ProxyRequestHandler).Methods(http.MethodPost)
api.HandleFunc("/{account}/images", s.ProxyRequestHandler).Methods(http.MethodPost)
Expand All @@ -71,7 +71,7 @@ func (s *server) routes() {
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
Expand Down
Loading

0 comments on commit 2da0c99

Please sign in to comment.