Skip to content

Commit

Permalink
Migrate SSM PUT Command (#26)
Browse files Browse the repository at this point in the history
* PUT SSM Command

* Fixed comments

* Added Unit Test
  • Loading branch information
nvnyale authored May 9, 2022
1 parent 718f11f commit a7668e9
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 2 deletions.
51 changes: 51 additions & 0 deletions api/handlers_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,54 @@ func (s *server) InstanceStateHandler(w http.ResponseWriter, r *http.Request) {

w.WriteHeader(http.StatusNoContent)
}

func (s *server) InstanceSendCommandHandler(w http.ResponseWriter, r *http.Request) {
w = LogWriter{w}
vars := mux.Vars(r)
account := s.mapAccountNumber(vars["account"])
id := vars["id"]

req := SsmCommandRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
msg := fmt.Sprintf("cannot decode body into ssm send command input: %s", err)
handleError(w, apierror.New(apierror.ErrBadRequest, msg, err))
return
}

if req.DocumentName == "" {
handleError(w, apierror.New(apierror.ErrBadRequest, "DocumentName is required", nil))
return

}

if len(req.Parameters) == 0 {
handleError(w, apierror.New(apierror.ErrBadRequest, "Parameters are required", nil))
return
}
policy, err := sendCommandPolicy()
if err != nil {
handleError(w, err)
return
}

orch, err := s.newSSMOrchestrator(r.Context(), &sessionParams{
role: fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName),
inlinePolicy: policy,
policyArns: []string{
"arn:aws:iam::aws:policy/AmazonSSMReadOnlyAccess",
},
})
if err != nil {
handleError(w, err)
return
}

out, err := orch.sendInstancesCommand(r.Context(), &req, id)
if err != nil {
handleError(w, err)
return
}

handleResponseOk(w, out)

}
20 changes: 20 additions & 0 deletions api/orchestration_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ssm"
log "github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -111,3 +112,22 @@ func (o *ec2Orchestrator) instancesState(ctx context.Context, state string, ids
return apierror.New(apierror.ErrBadRequest, msg, nil)
}
}

func (o *ssmOrchestrator) sendInstancesCommand(ctx context.Context, req *SsmCommandRequest, id ...string) (string, error) {
if req == nil {
return "", apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Debugf("got request to send command: %s", awsutil.Prettify(req))
input := &ssm.SendCommandInput{
DocumentName: aws.String(req.DocumentName),
Parameters: req.Parameters,
TimeoutSeconds: req.TimeoutSeconds,
InstanceIds: aws.StringSlice(id),
}
cmd, err := o.ssmClient.SendCommand(ctx, input)
if err != nil {
return "", err
}
return aws.StringValue(cmd.CommandId), nil
}
26 changes: 26 additions & 0 deletions api/orchestrators.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/YaleSpinup/ec2-api/ec2"
"github.com/YaleSpinup/ec2-api/ssm"
log "github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -37,3 +38,28 @@ func (s *server) newEc2Orchestrator(ctx context.Context, sp *sessionParams) (*ec
server: s,
}, nil
}

type ssmOrchestrator struct {
ssmClient *ssm.SSM
server *server
}

func (s *server) newSSMOrchestrator(ctx context.Context, sp *sessionParams) (*ssmOrchestrator, error) {
log.Debugf("initializing ssmOrchestrator")

session, err := s.assumeRole(
ctx,
s.session.ExternalID,
sp.role,
sp.inlinePolicy,
sp.policyArns...,
)
if err != nil {
return nil, err
}

return &ssmOrchestrator{
ssmClient: ssm.New(ssm.WithSession(session.Session)),
server: s,
}, nil
}
26 changes: 25 additions & 1 deletion api/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func tagCreatePolicy() (string, error) {
}

func volumeCreatePolicy() (string, error) {
log.Debugf("generating volume crete policy document")
log.Debugf("generating volume create policy document")

policy := iam.PolicyDocument{
Version: "2012-10-17",
Expand Down Expand Up @@ -266,3 +266,27 @@ func changeInstanceStatePolicy() (string, error) {

return string(j), nil
}

func sendCommandPolicy() (string, error) {
log.Debugf("generating send command policy document")

policy := iam.PolicyDocument{
Version: "2012-10-17",
Statement: []iam.StatementEntry{
{
Effect: "Allow",
Action: []string{
"ssm:SendCommand",
},
Resource: []string{"*"},
},
},
}

j, err := json.Marshal(policy)
if err != nil {
return "", err
}

return string(j), nil
}
2 changes: 1 addition & 1 deletion api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (s *server) routes() {
api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/command", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/command", s.InstanceSendCommandHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut)
api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut)
Expand Down
6 changes: 6 additions & 0 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,9 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) {
type Ec2InstanceStateChangeRequest struct {
State string
}

type SsmCommandRequest struct {
DocumentName string `json:"document_name"`
Parameters map[string][]*string `json:"parameters"`
TimeoutSeconds *int64 `json:"timeout"`
}
15 changes: 15 additions & 0 deletions ssm/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,18 @@ func (s *SSM) GetCommandInvocation(ctx context.Context, instanceId, commandId st
log.Debugf("got output describing SSM Command: %+v", out)
return out, nil
}

func (s *SSM) SendCommand(ctx context.Context, input *ssm.SendCommandInput) (*ssm.Command, error) {
if input == nil {
return nil, apierror.New(apierror.ErrBadRequest, "invalid input", nil)
}

log.Infof("sending command with doc name: %s, params: %+v", aws.StringValue(input.DocumentName), input.Parameters)

out, err := s.Service.SendCommandWithContext(ctx, input)
if err != nil {
return nil, common.ErrCode("failed to send command", err)
}
log.Debugf("got output sending command: %+v", out)
return out.Command, nil
}
64 changes: 64 additions & 0 deletions ssm/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ func (m *mockSSMClient) GetCommandInvocationWithContext(ctx context.Context, inp
}, nil
}

func (m *mockSSMClient) SendCommandWithContext(ctx aws.Context, inp *ssm.SendCommandInput, opt ...request.Option) (*ssm.SendCommandOutput, error) {
if m.err != nil {
return nil, m.err
}
return &ssm.SendCommandOutput{
Command: &ssm.Command{CommandId: aws.String("Command-123")},
}, nil

}

func TestSSM_GetCommandInvocation(t *testing.T) {
type fields struct {
session *session.Session
Expand Down Expand Up @@ -106,3 +116,57 @@ func TestSSM_GetCommandInvocation(t *testing.T) {
})
}
}

func TestSSM_SendCommand(t *testing.T) {
type fields struct {
session *session.Session
Service ssmiface.SSMAPI
}
type args struct {
ctx context.Context
input *ssm.SendCommandInput
}
tests := []struct {
name string
fields fields
s *SSM
args args
want *ssm.Command
wantErr bool
}{
{
name: "valid input",
fields: fields{Service: newMockSSMClient(t, nil)},
args: args{ctx: context.TODO(), input: &ssm.SendCommandInput{}},
want: &ssm.Command{CommandId: aws.String("Command-123")},
},
{
name: "valid input, aws error",
fields: fields{Service: newMockSSMClient(t, errors.New("some error"))},
args: args{ctx: context.TODO(), input: &ssm.SendCommandInput{}},
wantErr: true,
},
{
name: "invalid input",
fields: fields{Service: newMockSSMClient(t, errors.New("some error"))},
args: args{ctx: context.TODO(), input: nil},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &SSM{
session: tt.fields.session,
Service: tt.fields.Service,
}
got, err := s.SendCommand(tt.args.ctx, tt.args.input)
if (err != nil) != tt.wantErr {
t.Errorf("SSM.SendCommand() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSM.SendCommand() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit a7668e9

Please sign in to comment.