Skip to content

Commit

Permalink
Support for updating instances (#27)
Browse files Browse the repository at this point in the history
* PUT SSM Association

* SPIN-2936: Implemented below update instances endpoints
PUT "/{account}/instances/{id}"
PUT "/{account}/instances/{id}/power"
PUT "/{account}/instances/{id}/ssm/association"
PUT "/{account}/instances/{id}/tags"
PUT "/{account}/instances/{id}/attribute"

* Fixed comments

* Added policies

* Role updated

* updated PUT SSM

* Print stack trace when encountering panic

* Updated SSM Association

* Fixed AssociationID output

* Fixed unit testcase

* Updated time format
  • Loading branch information
nvnyale authored May 19, 2022
1 parent a7668e9 commit 1a3d9d3
Show file tree
Hide file tree
Showing 9 changed files with 395 additions and 24 deletions.
118 changes: 118 additions & 0 deletions api/handlers_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,121 @@ func (s *server) InstanceSendCommandHandler(w http.ResponseWriter, r *http.Reque
handleResponseOk(w, out)

}

func (s *server) NotImplementedHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
w.WriteHeader(http.StatusNotImplemented)
}

func (s *server) InstanceSSMAssociationHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
instanceId := vars["id"]

req := &SSMAssociationRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
msg := fmt.Sprintf("cannot decode body into ssm create input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if req.Document == "" {
handleError(w, apierror.New(apierror.ErrBadRequest, "Document is mandatory", nil))
return
}

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := ssmAssociationPolicy()
if err != nil {
handleError(w, err)
return
}

session, err := s.assumeRole(
r.Context(),
s.session.ExternalID,
role,
policy,
"arn:aws:iam::aws:policy/AmazonSSMReadOnlyAccess",
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
)
if err != nil {
msg := fmt.Sprintf("failed to assume role in account: %s", account)
handleError(w, apierror.New(apierror.ErrForbidden, msg, err))
return
}

service := ssm.New(
ssm.WithSession(session.Session),
)

out, err := service.CreateAssociation(r.Context(), instanceId, req.Document)
if err != nil {
handleError(w, err)
return
}

handleResponseOk(w, out)
}

func (s *server) InstanceUpdateHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
instanceId := vars["id"]

req := &Ec2InstanceUpdateRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
msg := fmt.Sprintf("cannot decode body into update image input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if len(req.Tags) == 0 && len(req.InstanceType) == 0 {
handleError(w, apierror.New(apierror.ErrBadRequest, "missing required fields: tags or instance_type", nil))
return
} else if len(req.Tags) > 0 && len(req.InstanceType) > 0 {
handleError(w, apierror.New(apierror.ErrBadRequest, "only one of these fields should be provided: tags or instance_type", nil))
return
}

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := instanceUpdatePolicy()
if err != nil {
handleError(w, err)
return
}

session, err := s.assumeRole(
r.Context(),
s.session.ExternalID,
role,
policy,
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
)
if err != nil {
msg := fmt.Sprintf("failed to assume role in account: %s", account)
handleError(w, apierror.New(apierror.ErrForbidden, msg, err))
return
}

service := ec2.New(
ec2.WithSession(session.Session),
ec2.WithOrg(s.org),
)

if len(req.Tags) > 0 {
if err := service.UpdateTags(r.Context(), req.Tags, instanceId); err != nil {
handleError(w, err)
return
}
} else if len(req.InstanceType) > 0 {
if err := service.UpdateAttributes(r.Context(), req.InstanceType["value"], instanceId); err != nil {
handleError(w, err)
return
}
}

w.WriteHeader(http.StatusNoContent)
}
48 changes: 48 additions & 0 deletions api/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,51 @@ func sendCommandPolicy() (string, error) {

return string(j), nil
}

func instanceUpdatePolicy() (string, error) {
log.Debugf("generating tag create policy document")
policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ec2:CreateTags",
"ec2:ModifyInstanceAttribute",
},
Resource: []string{"*"},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}

func ssmAssociationPolicy() (string, error) {
log.Debugf("generating tag create policy document")
policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ssm:CreateAssociation",
"ssm:UpdateAssociation",
},
Resource: []string{"*"},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}
8 changes: 4 additions & 4 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ func (s *server) routes() {
api.HandleFunc("/{account}/images", s.ProxyRequestHandler).Methods(http.MethodPost)

api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.NotImplementedHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/command", s.InstanceSendCommandHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.InstanceSSMAssociationHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.InstanceUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.InstanceUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
Expand Down
2 changes: 1 addition & 1 deletion api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func NewServer(config common.Config) error {
if config.ListenAddress == "" {
config.ListenAddress = ":8080"
}
handler := handlers.RecoveryHandler()(handlers.LoggingHandler(os.Stdout, TokenMiddleware([]byte(config.Token), publicURLs, s.router)))
handler := handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))(handlers.LoggingHandler(os.Stdout, TokenMiddleware([]byte(config.Token), publicURLs, s.router)))
srv := &http.Server{
Handler: handler,
Addr: config.ListenAddress,
Expand Down
50 changes: 31 additions & 19 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,12 @@ func toSSMGetCommandInvocationOutput(rawOut *ssm.GetCommandInvocationOutput) *SS
}

type Ec2ImageUpdateRequest struct {
Tags map[string]string
Tags map[string]string `json:"tags"`
}

type Ec2InstanceUpdateRequest struct {
Tags map[string]string `json:"tags"`
InstanceType map[string]string `json:"instance_type"`
}
type AssociationDescription struct {
Name string `json:"name"`
Expand Down Expand Up @@ -687,28 +692,31 @@ type AssociationTarget struct {
}

func toSSMAssociationDescription(rawDesc *ssm.DescribeAssociationOutput) *AssociationDescription {
const dateLayout = "2006-01-02 15:04:05 +0000"
var status AssociationStatus
if rawDesc.AssociationDescription.Status != nil {
status.Date = tzTimeFormat(rawDesc.AssociationDescription.Status.Date)
status.Name = aws.StringValue(rawDesc.AssociationDescription.Status.Name)
status.Message = aws.StringValue(rawDesc.AssociationDescription.Status.Message)
}
var overview AssociationOverview
if rawDesc.AssociationDescription.Overview != nil {
overview.Status = aws.StringValue(rawDesc.AssociationDescription.Overview.Status)
overview.DetailedStatus = aws.StringValue(rawDesc.AssociationDescription.Overview.DetailedStatus)
}

return &AssociationDescription{
Name: aws.StringValue(rawDesc.AssociationDescription.Name),
InstanceId: aws.StringValue(rawDesc.AssociationDescription.InstanceId),
AssociationVersion: aws.StringValue(rawDesc.AssociationDescription.AssociationVersion),
Date: rawDesc.AssociationDescription.Date.Format(dateLayout),
LastUpdateAssociationDate: rawDesc.AssociationDescription.LastUpdateAssociationDate.Format(dateLayout),
Status: AssociationStatus{
Date: rawDesc.AssociationDescription.Status.Date.Format(dateLayout),
Name: aws.StringValue(rawDesc.AssociationDescription.Status.Name),
Message: aws.StringValue(rawDesc.AssociationDescription.Status.Message),
},
Overview: AssociationOverview{
Status: aws.StringValue(rawDesc.AssociationDescription.Overview.Status),
DetailedStatus: aws.StringValue(rawDesc.AssociationDescription.Overview.DetailedStatus),
},
Name: aws.StringValue(rawDesc.AssociationDescription.Name),
InstanceId: aws.StringValue(rawDesc.AssociationDescription.InstanceId),
AssociationVersion: aws.StringValue(rawDesc.AssociationDescription.AssociationVersion),
Date: tzTimeFormat(rawDesc.AssociationDescription.Date),
LastUpdateAssociationDate: tzTimeFormat(rawDesc.AssociationDescription.LastUpdateAssociationDate),
Status: status,
Overview: overview,
DocumentVersion: aws.StringValue(rawDesc.AssociationDescription.DocumentVersion),
AssociationId: aws.StringValue(rawDesc.AssociationDescription.AssociationId),
Targets: parseAssociationTargets(rawDesc.AssociationDescription.Targets),
LastExecutionDate: rawDesc.AssociationDescription.LastExecutionDate.Format(dateLayout),
LastSuccessfulExecutionDate: rawDesc.AssociationDescription.LastSuccessfulExecutionDate.Format(dateLayout),
LastExecutionDate: tzTimeFormat(rawDesc.AssociationDescription.LastExecutionDate),
LastSuccessfulExecutionDate: tzTimeFormat(rawDesc.AssociationDescription.LastSuccessfulExecutionDate),
ApplyOnlyAtCronInterval: aws.BoolValue(rawDesc.AssociationDescription.ApplyOnlyAtCronInterval),
}
}
Expand All @@ -725,7 +733,11 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) {
}

type Ec2InstanceStateChangeRequest struct {
State string
State string `json:"state"`
}

type SSMAssociationRequest struct {
Document string `json:"document"`
}

type SsmCommandRequest struct {
Expand Down
30 changes: 30 additions & 0 deletions ec2/attributes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package ec2

import (
"context"

"github.com/YaleSpinup/apierror"
"github.com/YaleSpinup/ec2-api/common"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
log "github.com/sirupsen/logrus"
)

func (e *Ec2) UpdateAttributes(ctx context.Context, instanceType, instanceId string) error {
if len(instanceId) == 0 || len(instanceType) == 0 {
return apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Infof("updating attributes: %v with instance type %+v", instanceId, instanceType)

input := ec2.ModifyInstanceAttributeInput{
InstanceType: &ec2.AttributeValue{Value: aws.String(instanceType)},
InstanceId: aws.String(instanceId),
}

if _, err := e.Service.ModifyInstanceAttributeWithContext(ctx, &input); err != nil {
return common.ErrCode("updating attributes", err)
}

return nil
}
72 changes: 72 additions & 0 deletions ec2/attributes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package ec2

import (
"context"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

func (m *mockEC2Client) ModifyInstanceAttributeWithContext(ctx aws.Context, inp *ec2.ModifyInstanceAttributeInput, opt ...request.Option) (*ec2.ModifyInstanceAttributeOutput, error) {
if m.err != nil {
return nil, m.err
}
return &ec2.ModifyInstanceAttributeOutput{}, nil

}
func TestEc2_UpdateAttributes(t *testing.T) {
type fields struct {
Service ec2iface.EC2API
}
type args struct {
ctx context.Context
instanceType string
instanceId string
}
tests := []struct {
name string
fields fields
e *Ec2
args args
wantErr bool
}{
{
name: "success case",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: false,
},
{
name: "aws error",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, awserr.New("Bad Request", "boom.", nil))},
wantErr: true,
},
{
name: "invalid input, instance id is empty",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: ""},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: true,
},
{
name: "invalid input, instance type is empty",
args: args{ctx: context.TODO(), instanceType: "", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Ec2{
Service: tt.fields.Service,
}
if err := e.UpdateAttributes(tt.args.ctx, tt.args.instanceType, tt.args.instanceId); (err != nil) != tt.wantErr {
t.Errorf("Ec2.UpdateAttributes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
16 changes: 16 additions & 0 deletions ssm/association.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,19 @@ func (s *SSM) DescribeAssociation(ctx context.Context, instanceId, docName strin
log.Debugf("got output describing SSM Association: %+v", out)
return out, nil
}

func (s *SSM) CreateAssociation(ctx context.Context, instanceId, docName string) (string, error) {
if instanceId == "" || docName == "" {
return "", apierror.New(apierror.ErrBadRequest, "both instanceId and docName should be present", nil)
}
inp := &ssm.CreateAssociationInput{
Name: aws.String(docName),
InstanceId: aws.String(instanceId),
}
out, err := s.Service.CreateAssociationWithContext(ctx, inp)
if err != nil {
return "", common.ErrCode("failed to create association", err)
}
log.Debugf("got output creating SSM Association: %+v", out)
return aws.StringValue(out.AssociationDescription.AssociationId), nil
}
Loading

0 comments on commit 1a3d9d3

Please sign in to comment.