diff --git a/api/handlers_instances.go b/api/handlers_instances.go index 506706c..8487490 100644 --- a/api/handlers_instances.go +++ b/api/handlers_instances.go @@ -443,3 +443,121 @@ func (s *server) InstanceSendCommandHandler(w http.ResponseWriter, r *http.Reque handleResponseOk(w, out) } + +func (s *server) NotImplementedHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + w.WriteHeader(http.StatusNotImplemented) +} + +func (s *server) InstanceSSMAssociationHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := s.mapAccountNumber(vars["account"]) + instanceId := vars["id"] + + req := &SSMAssociationRequest{} + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + msg := fmt.Sprintf("cannot decode body into ssm create input: %s", err) + handleError(w, apierror.New(apierror.ErrBadRequest, msg, err)) + return + } + + if req.Document == "" { + handleError(w, apierror.New(apierror.ErrBadRequest, "Document is mandatory", nil)) + return + } + + role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName) + policy, err := ssmAssociationPolicy() + if err != nil { + handleError(w, err) + return + } + + session, err := s.assumeRole( + r.Context(), + s.session.ExternalID, + role, + policy, + "arn:aws:iam::aws:policy/AmazonSSMReadOnlyAccess", + "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", + ) + if err != nil { + msg := fmt.Sprintf("failed to assume role in account: %s", account) + handleError(w, apierror.New(apierror.ErrForbidden, msg, err)) + return + } + + service := ssm.New( + ssm.WithSession(session.Session), + ) + + out, err := service.CreateAssociation(r.Context(), instanceId, req.Document) + if err != nil { + handleError(w, err) + return + } + + handleResponseOk(w, out) +} + +func (s *server) InstanceUpdateHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := s.mapAccountNumber(vars["account"]) + instanceId := vars["id"] + + req := &Ec2InstanceUpdateRequest{} + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + msg := fmt.Sprintf("cannot decode body into update image input: %s", err) + handleError(w, apierror.New(apierror.ErrBadRequest, msg, err)) + return + } + + if len(req.Tags) == 0 && len(req.InstanceType) == 0 { + handleError(w, apierror.New(apierror.ErrBadRequest, "missing required fields: tags or instance_type", nil)) + return + } else if len(req.Tags) > 0 && len(req.InstanceType) > 0 { + handleError(w, apierror.New(apierror.ErrBadRequest, "only one of these fields should be provided: tags or instance_type", nil)) + return + } + + role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName) + policy, err := instanceUpdatePolicy() + if err != nil { + handleError(w, err) + return + } + + session, err := s.assumeRole( + r.Context(), + s.session.ExternalID, + role, + policy, + "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", + ) + if err != nil { + msg := fmt.Sprintf("failed to assume role in account: %s", account) + handleError(w, apierror.New(apierror.ErrForbidden, msg, err)) + return + } + + service := ec2.New( + ec2.WithSession(session.Session), + ec2.WithOrg(s.org), + ) + + if len(req.Tags) > 0 { + if err := service.UpdateTags(r.Context(), req.Tags, instanceId); err != nil { + handleError(w, err) + return + } + } else if len(req.InstanceType) > 0 { + if err := service.UpdateAttributes(r.Context(), req.InstanceType["value"], instanceId); err != nil { + handleError(w, err) + return + } + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/api/policy.go b/api/policy.go index 2055dbf..8f4fa07 100644 --- a/api/policy.go +++ b/api/policy.go @@ -290,3 +290,51 @@ func sendCommandPolicy() (string, error) { return string(j), nil } + +func instanceUpdatePolicy() (string, error) { + log.Debugf("generating tag create policy document") + policy := iam.PolicyDocument{ + Version: "2012-10-17", + Statement: []iam.StatementEntry{ + { + Effect: "Allow", + Action: []string{ + "ec2:CreateTags", + "ec2:ModifyInstanceAttribute", + }, + Resource: []string{"*"}, + }, + }, + } + + j, err := json.Marshal(policy) + if err != nil { + return "", err + } + + return string(j), nil +} + +func ssmAssociationPolicy() (string, error) { + log.Debugf("generating tag create policy document") + policy := iam.PolicyDocument{ + Version: "2012-10-17", + Statement: []iam.StatementEntry{ + { + Effect: "Allow", + Action: []string{ + "ssm:CreateAssociation", + "ssm:UpdateAssociation", + }, + Resource: []string{"*"}, + }, + }, + } + + j, err := json.Marshal(policy) + if err != nil { + return "", err + } + + return string(j), nil +} diff --git a/api/routes.go b/api/routes.go index dada1d4..ebc28d5 100644 --- a/api/routes.go +++ b/api/routes.go @@ -65,12 +65,12 @@ func (s *server) routes() { api.HandleFunc("/{account}/images", s.ProxyRequestHandler).Methods(http.MethodPost) api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/instances/{id}", s.NotImplementedHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/power", s.InstanceStateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/ssm/command", s.InstanceSendCommandHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/instances/{id}/ssm/association", s.ProxyRequestHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/instances/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) - api.HandleFunc("/{account}/instances/{id}/attribute", s.ProxyRequestHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/instances/{id}/ssm/association", s.InstanceSSMAssociationHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/instances/{id}/tags", s.InstanceUpdateHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/instances/{id}/attribute", s.InstanceUpdateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/sgs/{id}", s.SecurityGroupUpdateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/sgs/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/volumes/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) diff --git a/api/server.go b/api/server.go index ca28396..0739dfa 100644 --- a/api/server.go +++ b/api/server.go @@ -125,7 +125,7 @@ func NewServer(config common.Config) error { if config.ListenAddress == "" { config.ListenAddress = ":8080" } - handler := handlers.RecoveryHandler()(handlers.LoggingHandler(os.Stdout, TokenMiddleware([]byte(config.Token), publicURLs, s.router))) + handler := handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))(handlers.LoggingHandler(os.Stdout, TokenMiddleware([]byte(config.Token), publicURLs, s.router))) srv := &http.Server{ Handler: handler, Addr: config.ListenAddress, diff --git a/api/types.go b/api/types.go index b74e2b6..220d304 100644 --- a/api/types.go +++ b/api/types.go @@ -654,7 +654,12 @@ func toSSMGetCommandInvocationOutput(rawOut *ssm.GetCommandInvocationOutput) *SS } type Ec2ImageUpdateRequest struct { - Tags map[string]string + Tags map[string]string `json:"tags"` +} + +type Ec2InstanceUpdateRequest struct { + Tags map[string]string `json:"tags"` + InstanceType map[string]string `json:"instance_type"` } type AssociationDescription struct { Name string `json:"name"` @@ -687,28 +692,31 @@ type AssociationTarget struct { } func toSSMAssociationDescription(rawDesc *ssm.DescribeAssociationOutput) *AssociationDescription { - const dateLayout = "2006-01-02 15:04:05 +0000" + var status AssociationStatus + if rawDesc.AssociationDescription.Status != nil { + status.Date = tzTimeFormat(rawDesc.AssociationDescription.Status.Date) + status.Name = aws.StringValue(rawDesc.AssociationDescription.Status.Name) + status.Message = aws.StringValue(rawDesc.AssociationDescription.Status.Message) + } + var overview AssociationOverview + if rawDesc.AssociationDescription.Overview != nil { + overview.Status = aws.StringValue(rawDesc.AssociationDescription.Overview.Status) + overview.DetailedStatus = aws.StringValue(rawDesc.AssociationDescription.Overview.DetailedStatus) + } return &AssociationDescription{ - Name: aws.StringValue(rawDesc.AssociationDescription.Name), - InstanceId: aws.StringValue(rawDesc.AssociationDescription.InstanceId), - AssociationVersion: aws.StringValue(rawDesc.AssociationDescription.AssociationVersion), - Date: rawDesc.AssociationDescription.Date.Format(dateLayout), - LastUpdateAssociationDate: rawDesc.AssociationDescription.LastUpdateAssociationDate.Format(dateLayout), - Status: AssociationStatus{ - Date: rawDesc.AssociationDescription.Status.Date.Format(dateLayout), - Name: aws.StringValue(rawDesc.AssociationDescription.Status.Name), - Message: aws.StringValue(rawDesc.AssociationDescription.Status.Message), - }, - Overview: AssociationOverview{ - Status: aws.StringValue(rawDesc.AssociationDescription.Overview.Status), - DetailedStatus: aws.StringValue(rawDesc.AssociationDescription.Overview.DetailedStatus), - }, + Name: aws.StringValue(rawDesc.AssociationDescription.Name), + InstanceId: aws.StringValue(rawDesc.AssociationDescription.InstanceId), + AssociationVersion: aws.StringValue(rawDesc.AssociationDescription.AssociationVersion), + Date: tzTimeFormat(rawDesc.AssociationDescription.Date), + LastUpdateAssociationDate: tzTimeFormat(rawDesc.AssociationDescription.LastUpdateAssociationDate), + Status: status, + Overview: overview, DocumentVersion: aws.StringValue(rawDesc.AssociationDescription.DocumentVersion), AssociationId: aws.StringValue(rawDesc.AssociationDescription.AssociationId), Targets: parseAssociationTargets(rawDesc.AssociationDescription.Targets), - LastExecutionDate: rawDesc.AssociationDescription.LastExecutionDate.Format(dateLayout), - LastSuccessfulExecutionDate: rawDesc.AssociationDescription.LastSuccessfulExecutionDate.Format(dateLayout), + LastExecutionDate: tzTimeFormat(rawDesc.AssociationDescription.LastExecutionDate), + LastSuccessfulExecutionDate: tzTimeFormat(rawDesc.AssociationDescription.LastSuccessfulExecutionDate), ApplyOnlyAtCronInterval: aws.BoolValue(rawDesc.AssociationDescription.ApplyOnlyAtCronInterval), } } @@ -725,7 +733,11 @@ func parseAssociationTargets(rawTgts []*ssm.Target) (tgts []AssociationTarget) { } type Ec2InstanceStateChangeRequest struct { - State string + State string `json:"state"` +} + +type SSMAssociationRequest struct { + Document string `json:"document"` } type SsmCommandRequest struct { diff --git a/ec2/attributes.go b/ec2/attributes.go new file mode 100644 index 0000000..6f7bdda --- /dev/null +++ b/ec2/attributes.go @@ -0,0 +1,30 @@ +package ec2 + +import ( + "context" + + "github.com/YaleSpinup/apierror" + "github.com/YaleSpinup/ec2-api/common" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + log "github.com/sirupsen/logrus" +) + +func (e *Ec2) UpdateAttributes(ctx context.Context, instanceType, instanceId string) error { + if len(instanceId) == 0 || len(instanceType) == 0 { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("updating attributes: %v with instance type %+v", instanceId, instanceType) + + input := ec2.ModifyInstanceAttributeInput{ + InstanceType: &ec2.AttributeValue{Value: aws.String(instanceType)}, + InstanceId: aws.String(instanceId), + } + + if _, err := e.Service.ModifyInstanceAttributeWithContext(ctx, &input); err != nil { + return common.ErrCode("updating attributes", err) + } + + return nil +} diff --git a/ec2/attributes_test.go b/ec2/attributes_test.go new file mode 100644 index 0000000..34557c4 --- /dev/null +++ b/ec2/attributes_test.go @@ -0,0 +1,72 @@ +package ec2 + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" +) + +func (m *mockEC2Client) ModifyInstanceAttributeWithContext(ctx aws.Context, inp *ec2.ModifyInstanceAttributeInput, opt ...request.Option) (*ec2.ModifyInstanceAttributeOutput, error) { + if m.err != nil { + return nil, m.err + } + return &ec2.ModifyInstanceAttributeOutput{}, nil + +} +func TestEc2_UpdateAttributes(t *testing.T) { + type fields struct { + Service ec2iface.EC2API + } + type args struct { + ctx context.Context + instanceType string + instanceId string + } + tests := []struct { + name string + fields fields + e *Ec2 + args args + wantErr bool + }{ + { + name: "success case", + args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"}, + fields: fields{Service: newmockEC2Client(t, nil)}, + wantErr: false, + }, + { + name: "aws error", + args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: "i-123"}, + fields: fields{Service: newmockEC2Client(t, awserr.New("Bad Request", "boom.", nil))}, + wantErr: true, + }, + { + name: "invalid input, instance id is empty", + args: args{ctx: context.TODO(), instanceType: "Type1", instanceId: ""}, + fields: fields{Service: newmockEC2Client(t, nil)}, + wantErr: true, + }, + { + name: "invalid input, instance type is empty", + args: args{ctx: context.TODO(), instanceType: "", instanceId: "i-123"}, + fields: fields{Service: newmockEC2Client(t, nil)}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + Service: tt.fields.Service, + } + if err := e.UpdateAttributes(tt.args.ctx, tt.args.instanceType, tt.args.instanceId); (err != nil) != tt.wantErr { + t.Errorf("Ec2.UpdateAttributes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/ssm/association.go b/ssm/association.go index bdf7e0f..251a563 100644 --- a/ssm/association.go +++ b/ssm/association.go @@ -24,3 +24,19 @@ func (s *SSM) DescribeAssociation(ctx context.Context, instanceId, docName strin log.Debugf("got output describing SSM Association: %+v", out) return out, nil } + +func (s *SSM) CreateAssociation(ctx context.Context, instanceId, docName string) (string, error) { + if instanceId == "" || docName == "" { + return "", apierror.New(apierror.ErrBadRequest, "both instanceId and docName should be present", nil) + } + inp := &ssm.CreateAssociationInput{ + Name: aws.String(docName), + InstanceId: aws.String(instanceId), + } + out, err := s.Service.CreateAssociationWithContext(ctx, inp) + if err != nil { + return "", common.ErrCode("failed to create association", err) + } + log.Debugf("got output creating SSM Association: %+v", out) + return aws.StringValue(out.AssociationDescription.AssociationId), nil +} diff --git a/ssm/association_test.go b/ssm/association_test.go index 688fdb8..c728da2 100644 --- a/ssm/association_test.go +++ b/ssm/association_test.go @@ -33,6 +33,16 @@ func (m *mockSSMClient) DescribeAssociationWithContext(ctx context.Context, inp }, nil } +func (m *mockSSMClient) CreateAssociationWithContext(ctx context.Context, inp *ssm.CreateAssociationInput, opt ...request.Option) (*ssm.CreateAssociationOutput, error) { + if m.err != nil { + return nil, m.err + } + + return &ssm.CreateAssociationOutput{ + AssociationDescription: &ssm.AssociationDescription{AssociationId: aws.String("id123")}, + }, nil +} + func TestSSM_DescribeAssociation(t *testing.T) { type fields struct { session *session.Session @@ -109,3 +119,68 @@ func TestSSM_DescribeAssociation(t *testing.T) { }) } } + +func TestSSM_CreateAssociation(t *testing.T) { + type fields struct { + session *session.Session + Service ssmiface.SSMAPI + } + type args struct { + ctx context.Context + instanceId string + docName string + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + name: "valid input", + fields: fields{Service: newMockSSMClient(t, nil)}, + args: args{ctx: context.TODO(), instanceId: "i-123", docName: "doc123"}, + want: "id123", + wantErr: false, + }, + { + name: "valid input, error from aws", + fields: fields{Service: newMockSSMClient(t, errors.New("some error"))}, + args: args{ctx: context.TODO(), instanceId: "i-123", docName: "doc123"}, + want: "", + wantErr: true, + }, + { + name: "invalid input, instance id is empty", + fields: fields{Service: newMockSSMClient(t, nil)}, + args: args{ctx: context.TODO(), instanceId: "", docName: "doc123"}, + want: "", + wantErr: true, + }, + { + name: "invalid input, document name is empty", + fields: fields{Service: newMockSSMClient(t, nil)}, + args: args{ctx: context.TODO(), instanceId: "i-123", docName: ""}, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SSM{ + session: tt.fields.session, + Service: tt.fields.Service, + } + got, err := s.CreateAssociation(tt.args.ctx, tt.args.instanceId, tt.args.docName) + if (err != nil) != tt.wantErr { + t.Errorf("SSM.CreateAssociation() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("SSM.CreateAssociation() = %s, want %s", got, tt.want) + } + }) + } +}