From 718f11ffd7c440b66789b11d9dc3416174644545 Mon Sep 17 00:00:00 2001 From: nvnyale <100892976+nvnyale@users.noreply.github.com> Date: Thu, 28 Apr 2022 16:27:53 -0400 Subject: [PATCH] Support for Modifying instances state (#25) * Migrate Modify Instances State * Added comments * InstanceState Unit test --- api/handlers_instances.go | 44 ++++++ api/orchestration_instances.go | 22 +++ api/policy.go | 28 +++- api/routes.go | 2 +- api/types.go | 4 + ec2/instances.go | 43 ++++++ ec2/instances_test.go | 236 +++++++++++++++++++++++++++++++++ 7 files changed, 376 insertions(+), 3 deletions(-) diff --git a/api/handlers_instances.go b/api/handlers_instances.go index 7bfadc5..ad5d77a 100644 --- a/api/handlers_instances.go +++ b/api/handlers_instances.go @@ -348,3 +348,47 @@ func (s *server) DescribeAssociationHandler(w http.ResponseWriter, r *http.Reque } handleResponseOk(w, toSSMAssociationDescription(out)) } + +func (s *server) InstanceStateHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := s.mapAccountNumber(vars["account"]) + id := vars["id"] + + req := &Ec2InstanceStateChangeRequest{} + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + msg := fmt.Sprintf("cannot decode body into change power input: %s", err) + handleError(w, apierror.New(apierror.ErrBadRequest, msg, err)) + return + } + + if req.State == "" { + handleError(w, apierror.New(apierror.ErrBadRequest, "missing required field: state", nil)) + return + } + + policy, err := changeInstanceStatePolicy() + if err != nil { + handleError(w, err) + return + } + + orch, err := s.newEc2Orchestrator(r.Context(), &sessionParams{ + role: fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName), + inlinePolicy: policy, + policyArns: []string{ + "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", + }, + }) + if err != nil { + handleError(w, err) + return + } + + if err := orch.instancesState(r.Context(), req.State, id); err != nil { + handleError(w, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/api/orchestration_instances.go b/api/orchestration_instances.go index f7c6fe8..ad8c19b 100644 --- a/api/orchestration_instances.go +++ b/api/orchestration_instances.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "strings" "github.com/YaleSpinup/apierror" @@ -89,3 +90,24 @@ func blockDeviceMappingsFromRequest(r []Ec2BlockDevice) []*ec2.BlockDeviceMappin return blockDeviceMappings } + +// instancesState is used to start, stop and reboot a given instance +func (o *ec2Orchestrator) instancesState(ctx context.Context, state string, ids ...string) error { + if len(ids) == 0 || state == "" { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + state = strings.ToLower(state) + switch state { + case "start": + return o.ec2Client.StartInstance(ctx, ids...) + case "stop", "poweroff": + isForce := state == "poweroff" + return o.ec2Client.StopInstance(ctx, isForce, ids...) + case "reboot": + return o.ec2Client.RebootInstance(ctx, ids...) + default: + msg := fmt.Sprintf("unknown power state %q", state) + return apierror.New(apierror.ErrBadRequest, msg, nil) + } +} diff --git a/api/policy.go b/api/policy.go index 0210930..5d5ed7f 100644 --- a/api/policy.go +++ b/api/policy.go @@ -172,8 +172,7 @@ func sgUpdatePolicy(id string) (string, error) { } func tagCreatePolicy() (string, error) { - log.Debugf("generating tag crete policy document") - + log.Debugf("generating tag create policy document") policy := iam.PolicyDocument{ Version: "2012-10-17", Statement: []iam.StatementEntry{ @@ -242,3 +241,28 @@ func volumeDeletePolicy(id string) (string, error) { return string(j), nil } + +func changeInstanceStatePolicy() (string, error) { + log.Debugf("generating power update policy document") + policy := iam.PolicyDocument{ + Version: "2012-10-17", + Statement: []iam.StatementEntry{ + { + Effect: "Allow", + Action: []string{ + "ec2:StartInstances", + "ec2:StopInstances", + "ec2:RebootInstances", + }, + Resource: []string{"*"}, + }, + }, + } + + j, err := json.Marshal(policy) + if err != nil { + return "", err + } + + return string(j), nil +} diff --git a/api/routes.go b/api/routes.go index 74abe5b..078ebe8 100644 --- a/api/routes.go +++ b/api/routes.go @@ -66,7 +66,7 @@ func (s *server) routes() { api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/instances/{id}/power", s.ProxyRequestHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/ssm/command", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) diff --git a/api/types.go b/api/types.go index 3eabf4c..3640d7e 100644 --- a/api/types.go +++ b/api/types.go @@ -723,3 +723,7 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) { } return tgts } + +type Ec2InstanceStateChangeRequest struct { + State string +} diff --git a/ec2/instances.go b/ec2/instances.go index fa9dd89..0be5a5e 100644 --- a/ec2/instances.go +++ b/ec2/instances.go @@ -240,3 +240,46 @@ func (e *Ec2) GetInstanceVolume(ctx context.Context, id, volid string) (*ec2.Vol return out.Volumes[0], nil } + +func (e *Ec2) StartInstance(ctx context.Context, ids ...string) error { + if len(ids) == 0 { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + log.Infof("starting instance %s/%v", e.org, ids) + inp := &ec2.StartInstancesInput{ + InstanceIds: aws.StringSlice(ids), + } + if _, err := e.Service.StartInstancesWithContext(ctx, inp); err != nil { + return common.ErrCode("starting instance", err) + } + return nil +} + +func (e *Ec2) StopInstance(ctx context.Context, force bool, ids ...string) error { + if len(ids) == 0 { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + log.Infof("stopping instance %s/%v", e.org, ids) + inp := &ec2.StopInstancesInput{ + Force: aws.Bool(force), + InstanceIds: aws.StringSlice(ids), + } + if _, err := e.Service.StopInstancesWithContext(ctx, inp); err != nil { + return common.ErrCode("stopping instance", err) + } + return nil +} + +func (e *Ec2) RebootInstance(ctx context.Context, ids ...string) error { + if len(ids) == 0 { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + log.Infof("rebooting instance %s/%v", e.org, ids) + inp := &ec2.StartInstancesInput{ + InstanceIds: aws.StringSlice(ids), + } + if _, err := e.Service.StartInstancesWithContext(ctx, inp); err != nil { + return common.ErrCode("rebooting instance", err) + } + return nil +} diff --git a/ec2/instances_test.go b/ec2/instances_test.go index 12300b3..83107e4 100644 --- a/ec2/instances_test.go +++ b/ec2/instances_test.go @@ -60,6 +60,30 @@ func (m mockEC2Client) TerminateInstancesWithContext(ctx context.Context, input return &ec2.TerminateInstancesOutput{}, nil } +func (m mockEC2Client) StartInstancesWithContext(ctx context.Context, input *ec2.StartInstancesInput, opts ...request.Option) (*ec2.StartInstancesOutput, error) { + if m.err != nil { + return nil, m.err + } + + return &ec2.StartInstancesOutput{}, nil +} + +func (m mockEC2Client) StopInstancesWithContext(ctx context.Context, input *ec2.StopInstancesInput, opts ...request.Option) (*ec2.StopInstancesOutput, error) { + if m.err != nil { + return nil, m.err + } + + return &ec2.StopInstancesOutput{}, nil +} + +func (m mockEC2Client) RebootInstancesWithContext(ctx context.Context, input *ec2.RebootInstancesInput, opts ...request.Option) (*ec2.RebootInstancesOutput, error) { + if m.err != nil { + return nil, m.err + } + + return &ec2.RebootInstancesOutput{}, nil +} + func TestEc2_CreateInstance(t *testing.T) { type fields struct { session *session.Session @@ -267,3 +291,215 @@ func TestEc2_GetInstance(t *testing.T) { }) } } + +func TestEc2_StartInstance(t *testing.T) { + type fields struct { + session *session.Session + Service ec2iface.EC2API + DefaultKMSKeyId string + DefaultSgs []string + DefaultSubnets []string + org string + } + type args struct { + ctx context.Context + ids []string + } + + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "nil input", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ctx: context.TODO()}, + wantErr: true, + }, + { + name: "good input", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + ids: []string{"i-0123456789abcdef0"}, + }, + }, + { + name: "aws err", + fields: fields{ + Service: newmockEC2Client(t, awserr.New("BadRequest", "boom", nil)), + }, + args: args{ + ctx: context.TODO(), + ids: []string{"i-0123456789abcdef0"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + session: tt.fields.session, + Service: tt.fields.Service, + DefaultKMSKeyId: tt.fields.DefaultKMSKeyId, + DefaultSgs: tt.fields.DefaultSgs, + DefaultSubnets: tt.fields.DefaultSubnets, + org: tt.fields.org, + } + err := e.StartInstance(tt.args.ctx, tt.args.ids...) + if (err != nil) != tt.wantErr { + t.Errorf("Ec2.StartInstance() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestEc2_StopInstance(t *testing.T) { + type fields struct { + session *session.Session + Service ec2iface.EC2API + DefaultKMSKeyId string + DefaultSgs []string + DefaultSubnets []string + org string + } + type args struct { + ctx context.Context + force bool + ids []string + } + + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "nil input", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ctx: context.TODO()}, + wantErr: true, + }, + { + name: "good input", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + force: true, + ids: []string{"i-0123456789abcdef0"}, + }, + }, + { + name: "aws err", + fields: fields{ + Service: newmockEC2Client(t, awserr.New("BadRequest", "boom", nil)), + }, + args: args{ + ctx: context.TODO(), + ids: []string{"i-0123456789abcdef0"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + session: tt.fields.session, + Service: tt.fields.Service, + DefaultKMSKeyId: tt.fields.DefaultKMSKeyId, + DefaultSgs: tt.fields.DefaultSgs, + DefaultSubnets: tt.fields.DefaultSubnets, + org: tt.fields.org, + } + err := e.StopInstance(tt.args.ctx, tt.args.force, tt.args.ids...) + if (err != nil) != tt.wantErr { + t.Errorf("Ec2.StopInstance() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestEc2_RebootInstance(t *testing.T) { + type fields struct { + session *session.Session + Service ec2iface.EC2API + DefaultKMSKeyId string + DefaultSgs []string + DefaultSubnets []string + org string + } + type args struct { + ctx context.Context + ids []string + } + + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "nil input", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ctx: context.TODO()}, + wantErr: true, + }, + { + name: "good input", + fields: fields{ + Service: newmockEC2Client(t, nil), + }, + args: args{ + ctx: context.TODO(), + ids: []string{"i-0123456789abcdef0"}, + }, + }, + { + name: "aws err", + fields: fields{ + Service: newmockEC2Client(t, awserr.New("BadRequest", "boom", nil)), + }, + args: args{ + ctx: context.TODO(), + ids: []string{"i-0123456789abcdef0"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + session: tt.fields.session, + Service: tt.fields.Service, + DefaultKMSKeyId: tt.fields.DefaultKMSKeyId, + DefaultSgs: tt.fields.DefaultSgs, + DefaultSubnets: tt.fields.DefaultSubnets, + org: tt.fields.org, + } + err := e.RebootInstance(tt.args.ctx, tt.args.ids...) + if (err != nil) != tt.wantErr { + t.Errorf("Ec2.RebootInstance() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +}