Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for updating instances #27

Merged
merged 12 commits into from
May 19, 2022
118 changes: 118 additions & 0 deletions api/handlers_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,121 @@ func (s *server) InstanceSendCommandHandler(w http.ResponseWriter, r *http.Reque
handleResponseOk(w, out)

}

func (s *server) NotImplementedHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
w.WriteHeader(http.StatusNotImplemented)
}

func (s *server) InstanceSSMAssociationHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
instanceId := vars["id"]

req := &SSMAssociationRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
msg := fmt.Sprintf("cannot decode body into ssm create input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if req.Document == "" {
handleError(w, apierror.New(apierror.ErrBadRequest, "Document is mandatory", nil))
return
}

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := ssmAssociationPolicy()
if err != nil {
handleError(w, err)
return
}

session, err := s.assumeRole(
r.Context(),
s.session.ExternalID,
role,
policy,
"arn:aws:iam::aws:policy/AmazonSSMReadOnlyAccess",
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
)
if err != nil {
msg := fmt.Sprintf("failed to assume role in account: %s", account)
handleError(w, apierror.New(apierror.ErrForbidden, msg, err))
return
}

service := ssm.New(
ssm.WithSession(session.Session),
)

out, err := service.CreateAssociation(r.Context(), instanceId, req.Document)
if err != nil {
handleError(w, err)
return
}

handleResponseOk(w, out)
}

func (s *server) InstanceUpdateHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
instanceId := vars["id"]

req := &Ec2InstanceUpdateRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
msg := fmt.Sprintf("cannot decode body into update image input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if len(req.Tags) == 0 && len(req.InstanceType) == 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should also check that both are not > 0 and return an error only one of these is expected: tags or instance_type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

handleError(w, apierror.New(apierror.ErrBadRequest, "missing required fields: tags or instance_type", nil))
return
} else if len(req.Tags) > 0 && len(req.InstanceType) > 0 {
handleError(w, apierror.New(apierror.ErrBadRequest, "only one of these fields should be provided: tags or instance_type", nil))
return
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should create and use a separate instanceUpdatePolicy here (since you need more permissions than just tags)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated code.

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := instanceUpdatePolicy()
if err != nil {
handleError(w, err)
return
}

session, err := s.assumeRole(
r.Context(),
s.session.ExternalID,
role,
policy,
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
)
if err != nil {
msg := fmt.Sprintf("failed to assume role in account: %s", account)
handleError(w, apierror.New(apierror.ErrForbidden, msg, err))
return
}

service := ec2.New(
ec2.WithSession(session.Session),
ec2.WithOrg(s.org),
)

if len(req.Tags) > 0 {
if err := service.UpdateTags(r.Context(), req.Tags, instanceId); err != nil {
handleError(w, err)
return
}
} else if len(req.InstanceType) > 0 {
if err := service.UpdateAttributes(r.Context(), req.InstanceType["value"], instanceId); err != nil {
handleError(w, err)
return
}
}

w.WriteHeader(http.StatusNoContent)
}
48 changes: 48 additions & 0 deletions api/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,51 @@ func sendCommandPolicy() (string, error) {

return string(j), nil
}

func instanceUpdatePolicy() (string, error) {
log.Debugf("generating tag create policy document")
policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ec2:CreateTags",
"ec2:ModifyInstanceAttribute",
},
Resource: []string{"*"},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}

func ssmAssociationPolicy() (string, error) {
log.Debugf("generating tag create policy document")
policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ssm:CreateAssociation",
"ssm:UpdateAssociation",
},
Resource: []string{"*"},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}
8 changes: 4 additions & 4 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ func (s *server) routes() {
api.HandleFunc("/{account}/images", s.ProxyRequestHandler).Methods(http.MethodPost)

api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.NotImplementedHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/command", s.InstanceSendCommandHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.InstanceSSMAssociationHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.InstanceUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.InstanceUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
Expand Down
2 changes: 1 addition & 1 deletion api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func NewServer(config common.Config) error {
if config.ListenAddress == "" {
config.ListenAddress = ":8080"
}
handler := handlers.RecoveryHandler()(handlers.LoggingHandler(os.Stdout, TokenMiddleware([]byte(config.Token), publicURLs, s.router)))
handler := handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))(handlers.LoggingHandler(os.Stdout, TokenMiddleware([]byte(config.Token), publicURLs, s.router)))
srv := &http.Server{
Handler: handler,
Addr: config.ListenAddress,
Expand Down
50 changes: 31 additions & 19 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,12 @@ func toSSMGetCommandInvocationOutput(rawOut *ssm.GetCommandInvocationOutput) *SS
}

type Ec2ImageUpdateRequest struct {
Tags map[string]string
Tags map[string]string `json:"tags"`
}

type Ec2InstanceUpdateRequest struct {
Tags map[string]string `json:"tags"`
InstanceType map[string]string `json:"instance_type"`
}
type AssociationDescription struct {
Name string `json:"name"`
Expand Down Expand Up @@ -687,28 +692,31 @@ type AssociationTarget struct {
}

func toSSMAssociationDescription(rawDesc *ssm.DescribeAssociationOutput) *AssociationDescription {
const dateLayout = "2006-01-02 15:04:05 +0000"
var status AssociationStatus
if rawDesc.AssociationDescription.Status != nil {
status.Date = tzTimeFormat(rawDesc.AssociationDescription.Status.Date)
status.Name = aws.StringValue(rawDesc.AssociationDescription.Status.Name)
status.Message = aws.StringValue(rawDesc.AssociationDescription.Status.Message)
}
var overview AssociationOverview
if rawDesc.AssociationDescription.Overview != nil {
overview.Status = aws.StringValue(rawDesc.AssociationDescription.Overview.Status)
overview.DetailedStatus = aws.StringValue(rawDesc.AssociationDescription.Overview.DetailedStatus)
}

return &AssociationDescription{
Name: aws.StringValue(rawDesc.AssociationDescription.Name),
InstanceId: aws.StringValue(rawDesc.AssociationDescription.InstanceId),
AssociationVersion: aws.StringValue(rawDesc.AssociationDescription.AssociationVersion),
Date: rawDesc.AssociationDescription.Date.Format(dateLayout),
LastUpdateAssociationDate: rawDesc.AssociationDescription.LastUpdateAssociationDate.Format(dateLayout),
Status: AssociationStatus{
Date: rawDesc.AssociationDescription.Status.Date.Format(dateLayout),
Name: aws.StringValue(rawDesc.AssociationDescription.Status.Name),
Message: aws.StringValue(rawDesc.AssociationDescription.Status.Message),
},
Overview: AssociationOverview{
Status: aws.StringValue(rawDesc.AssociationDescription.Overview.Status),
DetailedStatus: aws.StringValue(rawDesc.AssociationDescription.Overview.DetailedStatus),
},
Name: aws.StringValue(rawDesc.AssociationDescription.Name),
InstanceId: aws.StringValue(rawDesc.AssociationDescription.InstanceId),
AssociationVersion: aws.StringValue(rawDesc.AssociationDescription.AssociationVersion),
Date: tzTimeFormat(rawDesc.AssociationDescription.Date),
LastUpdateAssociationDate: tzTimeFormat(rawDesc.AssociationDescription.LastUpdateAssociationDate),
Status: status,
Overview: overview,
DocumentVersion: aws.StringValue(rawDesc.AssociationDescription.DocumentVersion),
AssociationId: aws.StringValue(rawDesc.AssociationDescription.AssociationId),
Targets: parseAssociationTargets(rawDesc.AssociationDescription.Targets),
LastExecutionDate: rawDesc.AssociationDescription.LastExecutionDate.Format(dateLayout),
LastSuccessfulExecutionDate: rawDesc.AssociationDescription.LastSuccessfulExecutionDate.Format(dateLayout),
LastExecutionDate: tzTimeFormat(rawDesc.AssociationDescription.LastExecutionDate),
LastSuccessfulExecutionDate: tzTimeFormat(rawDesc.AssociationDescription.LastSuccessfulExecutionDate),
ApplyOnlyAtCronInterval: aws.BoolValue(rawDesc.AssociationDescription.ApplyOnlyAtCronInterval),
}
}
Expand All @@ -725,7 +733,11 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) {
}

type Ec2InstanceStateChangeRequest struct {
State string
State string `json:"state"`
}

type SSMAssociationRequest struct {
Document string `json:"document"`
}

type SsmCommandRequest struct {
Expand Down
30 changes: 30 additions & 0 deletions ec2/attributes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package ec2

import (
"context"

"github.com/YaleSpinup/apierror"
"github.com/YaleSpinup/ec2-api/common"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
log "github.com/sirupsen/logrus"
)

func (e *Ec2) UpdateAttributes(ctx context.Context, instanceType, instanceId string) error {
if len(instanceId) == 0 || len(instanceType) == 0 {
return apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Infof("updating attributes: %v with instance type %+v", instanceId, instanceType)

input := ec2.ModifyInstanceAttributeInput{
InstanceType: &ec2.AttributeValue{Value: aws.String(instanceType)},
InstanceId: aws.String(instanceId),
}

if _, err := e.Service.ModifyInstanceAttributeWithContext(ctx, &input); err != nil {
return common.ErrCode("updating attributes", err)
}

return nil
}
72 changes: 72 additions & 0 deletions ec2/attributes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package ec2

import (
"context"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

func (m *mockEC2Client) ModifyInstanceAttributeWithContext(ctx aws.Context, inp *ec2.ModifyInstanceAttributeInput, opt ...request.Option) (*ec2.ModifyInstanceAttributeOutput, error) {
if m.err != nil {
return nil, m.err
}
return &ec2.ModifyInstanceAttributeOutput{}, nil

}
func TestEc2_UpdateAttributes(t *testing.T) {
type fields struct {
Service ec2iface.EC2API
}
type args struct {
ctx context.Context
instanceType string
instanceId string
}
tests := []struct {
name string
fields fields
e *Ec2
args args
wantErr bool
}{
{
name: "success case",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: false,
},
{
name: "aws error",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, awserr.New("Bad Request", "boom.", nil))},
wantErr: true,
},
{
name: "invalid input, instance id is empty",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: ""},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: true,
},
{
name: "invalid input, instance type is empty",
args: args{ctx: context.TODO(), instanceType: "", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Ec2{
Service: tt.fields.Service,
}
if err := e.UpdateAttributes(tt.args.ctx, tt.args.instanceType, tt.args.instanceId); (err != nil) != tt.wantErr {
t.Errorf("Ec2.UpdateAttributes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
16 changes: 16 additions & 0 deletions ssm/association.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,19 @@ func (s *SSM) DescribeAssociation(ctx context.Context, instanceId, docName strin
log.Debugf("got output describing SSM Association: %+v", out)
return out, nil
}

func (s *SSM) CreateAssociation(ctx context.Context, instanceId, docName string) (string, error) {
if instanceId == "" || docName == "" {
return "", apierror.New(apierror.ErrBadRequest, "both instanceId and docName should be present", nil)
}
inp := &ssm.CreateAssociationInput{
Name: aws.String(docName),
InstanceId: aws.String(instanceId),
}
out, err := s.Service.CreateAssociationWithContext(ctx, inp)
if err != nil {
return "", common.ErrCode("failed to create association", err)
}
log.Debugf("got output creating SSM Association: %+v", out)
return aws.StringValue(out.AssociationDescription.AssociationId), nil
}
Loading