Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for updating instances #27

Merged
merged 12 commits into from
May 19, 2022
109 changes: 109 additions & 0 deletions api/handlers_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,112 @@ func (s *server) InstanceSendCommandHandler(w http.ResponseWriter, r *http.Reque
handleResponseOk(w, out)

}

func (s *server) InstanceIDHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
w.WriteHeader(http.StatusNotImplemented)
}

func (s *server) InstanceSSMAssociationHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
instanceId := vars["id"]

req := &SSMAssociationRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
msg := fmt.Sprintf("cannot decode body into ssm create input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if req.Document == "" {
handleError(w, apierror.New(apierror.ErrBadRequest, "Document is mandatory", nil))
return
}

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)

session, err := s.assumeRole(
r.Context(),
s.session.ExternalID,
role,
"",
"arn:aws:iam::aws:policy/AmazonSSMReadOnlyAccess",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, does this work? you're not passing any IAM policy besides the read-only access

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point, I missed it. I am not familiar with the AWS policies and can you please help me to choose the correct policy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case you're working with SSM (Systems Manager) so you can take a look here first: https://docs.aws.amazon.com/systems-manager/latest/userguide/security_iam_service-with-iam.html
Try to find which IAM permissions are required to do the work (in this case CreateAssociation).
For example, here you can see all the different IAM permissions: https://aws.permissions.cloud/iam/ssm
In some cases it may be a bit of trial-and-error until you get all the permissions right, but in this case should be pretty straightforward.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

)
if err != nil {
msg := fmt.Sprintf("failed to assume role in account: %s", account)
handleError(w, apierror.New(apierror.ErrForbidden, msg, err))
return
}

service := ssm.New(
ssm.WithSession(session.Session),
)

out, err := service.CreateAssociation(r.Context(), instanceId, req.Document)
if err != nil {
handleError(w, err)
return
}

handleResponseOk(w, struct{ AssociationId string }{AssociationId: *out.AssociationDescription.AssociationId})
}

func (s *server) InstanceUpdateHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
instanceId := vars["id"]

req := &Ec2InstanceUpdateRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
msg := fmt.Sprintf("cannot decode body into update image input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if len(req.Tags) == 0 && len(req.InstanceType) == 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should also check that both are not > 0 and return an error only one of these is expected: tags or instance_type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

handleError(w, apierror.New(apierror.ErrBadRequest, "missing required fields", nil))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing required fields: tags or instance_type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should create and use a separate instanceUpdatePolicy here (since you need more permissions than just tags)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated code.

role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName)
policy, err := tagCreatePolicy()
if err != nil {
handleError(w, err)
return
}

session, err := s.assumeRole(
r.Context(),
s.session.ExternalID,
role,
policy,
"arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess",
)
if err != nil {
msg := fmt.Sprintf("failed to assume role in account: %s", account)
handleError(w, apierror.New(apierror.ErrForbidden, msg, err))
return
}

service := ec2.New(
ec2.WithSession(session.Session),
ec2.WithOrg(s.org),
)

if len(req.Tags) > 0 {
if err := service.UpdateTags(r.Context(), req.Tags, instanceId); err != nil {
handleError(w, err)
return
}
} else if len(req.InstanceType) > 0 {
if err := service.UpdateAttributes(r.Context(), req.InstanceType["value"], instanceId); err != nil {
handleError(w, err)
return
}
}

w.WriteHeader(http.StatusNoContent)
}
8 changes: 4 additions & 4 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ func (s *server) routes() {
api.HandleFunc("/{account}/images", s.ProxyRequestHandler).Methods(http.MethodPost)

api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.InstanceIDHandler).Methods(http.MethodPut)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InstanceUpdateHandler

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ID's route, handler is not implemented, so I didn't generalize this handler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, then let's just call it NotImplementedHandler, so it's clearer and it can be used by other routes that are not implemented

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed.

api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/command", s.InstanceSendCommandHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.InstanceSSMAssociationHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.InstanceUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.InstanceUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
Expand Down
13 changes: 11 additions & 2 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,12 @@ func toSSMGetCommandInvocationOutput(rawOut *ssm.GetCommandInvocationOutput) *SS
}

type Ec2ImageUpdateRequest struct {
Tags map[string]string
Tags map[string]string `json:"tags"`
}

type Ec2InstanceUpdateRequest struct {
Tags map[string]string `json:"tags"`
InstanceType map[string]string `json:"instance_type"`
}
type AssociationDescription struct {
Name string `json:"name"`
Expand Down Expand Up @@ -725,7 +730,11 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) {
}

type Ec2InstanceStateChangeRequest struct {
State string
State string `json:"state"`
}

type SSMAssociationRequest struct {
Document string `json:"document"`
}

type SsmCommandRequest struct {
Expand Down
30 changes: 30 additions & 0 deletions ec2/attributes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package ec2

import (
"context"

"github.com/YaleSpinup/apierror"
"github.com/YaleSpinup/ec2-api/common"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
log "github.com/sirupsen/logrus"
)

func (e *Ec2) UpdateAttributes(ctx context.Context, instanceType, instanceId string) error {
if len(instanceId) == 0 || len(instanceType) == 0 {
return apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Infof("updating attributes: %v with instance type %+v", instanceId, instanceType)

input := ec2.ModifyInstanceAttributeInput{
InstanceType: &ec2.AttributeValue{Value: aws.String(instanceType)},
InstanceId: aws.String(instanceId),
}

if _, err := e.Service.ModifyInstanceAttributeWithContext(ctx, &input); err != nil {
return common.ErrCode("updating attributes", err)
}

return nil
}
72 changes: 72 additions & 0 deletions ec2/attributes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package ec2

import (
"context"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

func (m *mockEC2Client) ModifyInstanceAttributeWithContext(ctx aws.Context, inp *ec2.ModifyInstanceAttributeInput, opt ...request.Option) (*ec2.ModifyInstanceAttributeOutput, error) {
if m.err != nil {
return nil, m.err
}
return &ec2.ModifyInstanceAttributeOutput{}, nil

}
func TestEc2_UpdateAttributes(t *testing.T) {
type fields struct {
Service ec2iface.EC2API
}
type args struct {
ctx context.Context
instanceType string
instanceId string
}
tests := []struct {
name string
fields fields
e *Ec2
args args
wantErr bool
}{
{
name: "success case",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: false,
},
{
name: "aws error",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, awserr.New("Bad Request", "boom.", nil))},
wantErr: true,
},
{
name: "invalid input, instance id is empty",
args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: ""},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: true,
},
{
name: "invalid input, instance type is empty",
args: args{ctx: context.TODO(), instanceType: "", instanceId: "i-123"},
fields: fields{Service: newmockEC2Client(t, nil)},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Ec2{
Service: tt.fields.Service,
}
if err := e.UpdateAttributes(tt.args.ctx, tt.args.instanceType, tt.args.instanceId); (err != nil) != tt.wantErr {
t.Errorf("Ec2.UpdateAttributes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
16 changes: 16 additions & 0 deletions ssm/association.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,19 @@ func (s *SSM) DescribeAssociation(ctx context.Context, instanceId, docName strin
log.Debugf("got output describing SSM Association: %+v", out)
return out, nil
}

func (s *SSM) CreateAssociation(ctx context.Context, instanceId, docName string) (*ssm.CreateAssociationOutput, error) {
if instanceId == "" || docName == "" {
return nil, apierror.New(apierror.ErrBadRequest, "both instanceId and docName should be present", nil)
}
inp:= &ssm.CreateAssociationInput{
Name: aws.String(docName),
InstanceId: aws.String(instanceId),
}
out, err := s.Service.CreateAssociationWithContext(ctx,inp)
if err != nil {
return nil, common.ErrCode("failed to create association", err)
}
log.Debugf("got output creating SSM Association: %+v", out)
return out, nil
}
79 changes: 79 additions & 0 deletions ssm/association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ func (m *mockSSMClient) DescribeAssociationWithContext(ctx context.Context, inp
}, nil
}

func (m *mockSSMClient) CreateAssociationWithContext(ctx context.Context, inp *ssm.CreateAssociationInput, opt ...request.Option) (*ssm.CreateAssociationOutput, error) {
if m.err != nil {
return nil, m.err
}

return &ssm.CreateAssociationOutput{
AssociationDescription: &ssm.AssociationDescription{AssociationId: aws.String("id123")},
}, nil
}

func TestSSM_DescribeAssociation(t *testing.T) {
type fields struct {
session *session.Session
Expand Down Expand Up @@ -109,3 +119,72 @@ func TestSSM_DescribeAssociation(t *testing.T) {
})
}
}

func TestSSM_CreateAssociation(t *testing.T) {
type fields struct {
session *session.Session
Service ssmiface.SSMAPI
}
type args struct {
ctx context.Context
instanceId string
docName string
}
tests := []struct {
name string
fields fields
args args
want *ssm.CreateAssociationOutput
wantErr bool
}{
{
name: "valid input",
fields: fields{Service: newMockSSMClient(t, nil)},
args: args{ctx: context.TODO(), instanceId: "i-123", docName: "doc123"},
want: &ssm.CreateAssociationOutput{
AssociationDescription: &ssm.AssociationDescription{
AssociationId: aws.String("id123"),
},
},
wantErr: false,
},
{
name: "valid input, error from aws",
fields: fields{Service: newMockSSMClient(t, errors.New("some error"))},
args: args{ctx: context.TODO(), instanceId: "i-123", docName: "doc123"},
want: nil,
wantErr: true,
},
{
name: "invalid input, instance id is empty",
fields: fields{Service: newMockSSMClient(t, nil)},
args: args{ctx: context.TODO(), instanceId: "", docName: "doc123"},
want: nil,
wantErr: true,
},
{
name: "invalid input, document name is empty",
fields: fields{Service: newMockSSMClient(t, nil)},
args: args{ctx: context.TODO(), instanceId: "i-123", docName: ""},
want: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &SSM{
session: tt.fields.session,
Service: tt.fields.Service,
}
got, err := s.CreateAssociation(tt.args.ctx, tt.args.instanceId, tt.args.docName)
if (err != nil) != tt.wantErr {
t.Errorf("SSM.CreateAssociation() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSM.CreateAssociation() = %v, want %v", got, tt.want)
}
})
}
}