diff --git a/api/handlers_instances.go b/api/handlers_instances.go index ad5d77a..506706c 100644 --- a/api/handlers_instances.go +++ b/api/handlers_instances.go @@ -392,3 +392,54 @@ func (s *server) InstanceStateHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } + +func (s *server) InstanceSendCommandHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := s.mapAccountNumber(vars["account"]) + id := vars["id"] + + req := SsmCommandRequest{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + msg := fmt.Sprintf("cannot decode body into ssm send command input: %s", err) + handleError(w, apierror.New(apierror.ErrBadRequest, msg, err)) + return + } + + if req.DocumentName == "" { + handleError(w, apierror.New(apierror.ErrBadRequest, "DocumentName is required", nil)) + return + + } + + if len(req.Parameters) == 0 { + handleError(w, apierror.New(apierror.ErrBadRequest, "Parameters are required", nil)) + return + } + policy, err := sendCommandPolicy() + if err != nil { + handleError(w, err) + return + } + + orch, err := s.newSSMOrchestrator(r.Context(), &sessionParams{ + role: fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName), + inlinePolicy: policy, + policyArns: []string{ + "arn:aws:iam::aws:policy/AmazonSSMReadOnlyAccess", + }, + }) + if err != nil { + handleError(w, err) + return + } + + out, err := orch.sendInstancesCommand(r.Context(), &req, id) + if err != nil { + handleError(w, err) + return + } + + handleResponseOk(w, out) + +} diff --git a/api/orchestration_instances.go b/api/orchestration_instances.go index ad8c19b..0bb2198 100644 --- a/api/orchestration_instances.go +++ b/api/orchestration_instances.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ssm" log "github.com/sirupsen/logrus" ) @@ -111,3 +112,22 @@ func (o *ec2Orchestrator) instancesState(ctx context.Context, state string, ids return apierror.New(apierror.ErrBadRequest, msg, nil) } } + +func (o *ssmOrchestrator) sendInstancesCommand(ctx context.Context, req *SsmCommandRequest, id ...string) (string, error) { + if req == nil { + return "", apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Debugf("got request to send command: %s", awsutil.Prettify(req)) + input := &ssm.SendCommandInput{ + DocumentName: aws.String(req.DocumentName), + Parameters: req.Parameters, + TimeoutSeconds: req.TimeoutSeconds, + InstanceIds: aws.StringSlice(id), + } + cmd, err := o.ssmClient.SendCommand(ctx, input) + if err != nil { + return "", err + } + return aws.StringValue(cmd.CommandId), nil +} diff --git a/api/orchestrators.go b/api/orchestrators.go index 5e1d522..b04817d 100644 --- a/api/orchestrators.go +++ b/api/orchestrators.go @@ -4,6 +4,7 @@ import ( "context" "github.com/YaleSpinup/ec2-api/ec2" + "github.com/YaleSpinup/ec2-api/ssm" log "github.com/sirupsen/logrus" ) @@ -37,3 +38,28 @@ func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec server: s, }, nil } + +type ssmOrchestrator struct { + ssmClient *ssm.SSM + server *server +} + +func (s *server) newSSMOrchestrator(ctx context.Context, sp *sessionParams) (*ssmOrchestrator, error) { + log.Debugf("initializing ssmOrchestrator") + + session, err := s.assumeRole( + ctx, + s.session.ExternalID, + sp.role, + sp.inlinePolicy, + sp.policyArns..., + ) + if err != nil { + return nil, err + } + + return &ssmOrchestrator{ + ssmClient: ssm.New(ssm.WithSession(session.Session)), + server: s, + }, nil +} diff --git a/api/policy.go b/api/policy.go index 5d5ed7f..2055dbf 100644 --- a/api/policy.go +++ b/api/policy.go @@ -195,7 +195,7 @@ func tagCreatePolicy() (string, error) { } func volumeCreatePolicy() (string, error) { - log.Debugf("generating volume crete policy document") + log.Debugf("generating volume create policy document") policy := iam.PolicyDocument{ Version: "2012-10-17", @@ -266,3 +266,27 @@ func changeInstanceStatePolicy() (string, error) { return string(j), nil } + +func sendCommandPolicy() (string, error) { + log.Debugf("generating send command policy document") + + policy := iam.PolicyDocument{ + Version: "2012-10-17", + Statement: []iam.StatementEntry{ + { + Effect: "Allow", + Action: []string{ + "ssm:SendCommand", + }, + Resource: []string{"*"}, + }, + }, + } + + j, err := json.Marshal(policy) + if err != nil { + return "", err + } + + return string(j), nil +} diff --git a/api/routes.go b/api/routes.go index 078ebe8..dada1d4 100644 --- a/api/routes.go +++ b/api/routes.go @@ -67,7 +67,7 @@ func (s *server) routes() { api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/instances/{id}/ssm/command", s.ProxyRequestHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/instances/{id}/ssm/command", s.InstanceSendCommandHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut) diff --git a/api/types.go b/api/types.go index 3640d7e..b74e2b6 100644 --- a/api/types.go +++ b/api/types.go @@ -727,3 +727,9 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) { type Ec2InstanceStateChangeRequest struct { State string } + +type SsmCommandRequest struct { + DocumentName string `json:"document_name"` + Parameters map[string][]*string `json:"parameters"` + TimeoutSeconds *int64 `json:"timeout"` +} diff --git a/ssm/command.go b/ssm/command.go index 1004372..16b4324 100644 --- a/ssm/command.go +++ b/ssm/command.go @@ -24,3 +24,18 @@ func (s *SSM) GetCommandInvocation(ctx context.Context, instanceId, commandId st log.Debugf("got output describing SSM Command: %+v", out) return out, nil } + +func (s *SSM) SendCommand(ctx context.Context, input *ssm.SendCommandInput) (*ssm.Command, error) { + if input == nil { + return nil, apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("sending command with doc name: %s, params: %+v", aws.StringValue(input.DocumentName), input.Parameters) + + out, err := s.Service.SendCommandWithContext(ctx, input) + if err != nil { + return nil, common.ErrCode("failed to send command", err) + } + log.Debugf("got output sending command: %+v", out) + return out.Command, nil +} diff --git a/ssm/command_test.go b/ssm/command_test.go index 24187d9..d57cddd 100644 --- a/ssm/command_test.go +++ b/ssm/command_test.go @@ -32,6 +32,16 @@ func (m *mockSSMClient) GetCommandInvocationWithContext(ctx context.Context, inp }, nil } +func (m *mockSSMClient) SendCommandWithContext(ctx aws.Context, inp *ssm.SendCommandInput, opt ...request.Option) (*ssm.SendCommandOutput, error) { + if m.err != nil { + return nil, m.err + } + return &ssm.SendCommandOutput{ + Command: &ssm.Command{CommandId: aws.String("Command-123")}, + }, nil + +} + func TestSSM_GetCommandInvocation(t *testing.T) { type fields struct { session *session.Session @@ -106,3 +116,57 @@ func TestSSM_GetCommandInvocation(t *testing.T) { }) } } + +func TestSSM_SendCommand(t *testing.T) { + type fields struct { + session *session.Session + Service ssmiface.SSMAPI + } + type args struct { + ctx context.Context + input *ssm.SendCommandInput + } + tests := []struct { + name string + fields fields + s *SSM + args args + want *ssm.Command + wantErr bool + }{ + { + name: "valid input", + fields: fields{Service: newMockSSMClient(t, nil)}, + args: args{ctx: context.TODO(), input: &ssm.SendCommandInput{}}, + want: &ssm.Command{CommandId: aws.String("Command-123")}, + }, + { + name: "valid input, aws error", + fields: fields{Service: newMockSSMClient(t, errors.New("some error"))}, + args: args{ctx: context.TODO(), input: &ssm.SendCommandInput{}}, + wantErr: true, + }, + { + name: "invalid input", + fields: fields{Service: newMockSSMClient(t, errors.New("some error"))}, + args: args{ctx: context.TODO(), input: nil}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SSM{ + session: tt.fields.session, + Service: tt.fields.Service, + } + got, err := s.SendCommand(tt.args.ctx, tt.args.input) + if (err != nil) != tt.wantErr { + t.Errorf("SSM.SendCommand() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SSM.SendCommand() = %v, want %v", got, tt.want) + } + }) + } +}