From 45bc9d993b3427aba91afd0bc33ee07d596e1f5d Mon Sep 17 00:00:00 2001 From: nvnyale <100892976+nvnyale@users.noreply.github.com> Date: Mon, 25 Apr 2022 13:00:35 -0400 Subject: [PATCH] Add support for updating image tags (#23) * Add support for updating image tags * Fixed comments * Added Unit test * Resolved Comments --- api/handlers_images.go | 52 ++++++++++++++++++++++++++ api/policy.go | 22 +++++++++++ api/routes.go | 2 +- api/types.go | 3 ++ ec2/tags.go | 34 +++++++++++++++++ ec2/tags_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 ec2/tags.go create mode 100644 ec2/tags_test.go diff --git a/api/handlers_images.go b/api/handlers_images.go index e4ca22e..5d1791a 100644 --- a/api/handlers_images.go +++ b/api/handlers_images.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "fmt" "net/http" "strconv" @@ -92,3 +93,54 @@ func (s *server) ImageGetHandler(w http.ResponseWriter, r *http.Request) { handleResponseOk(w, toEc2ImageResponse(out[0])) } + +func (s *server) ImageUpdateHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := s.mapAccountNumber(vars["account"]) + id := vars["id"] + + req := &Ec2ImageUpdateRequest{} + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + msg := fmt.Sprintf("cannot decode body into update image input: %s", err) + handleError(w, apierror.New(apierror.ErrBadRequest, msg, err)) + return + } + + if len(req.Tags) == 0 { + handleError(w, apierror.New(apierror.ErrBadRequest, "missing required field: tags", nil)) + return + } + + role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName) + policy, err := tagCreatePolicy() + if err != nil { + handleError(w, err) + return + } + + session, err := s.assumeRole( + r.Context(), + s.session.ExternalID, + role, + policy, + "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", + ) + if err != nil { + msg := fmt.Sprintf("failed to assume role in account: %s", account) + handleError(w, apierror.New(apierror.ErrForbidden, msg, err)) + return + } + + service := ec2.New( + ec2.WithSession(session.Session), + ec2.WithOrg(s.org), + ) + + if err := service.UpdateTags(r.Context(), req.Tags, id); err != nil { + handleError(w, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/api/policy.go b/api/policy.go index ef94ee9..173c196 100644 --- a/api/policy.go +++ b/api/policy.go @@ -170,3 +170,25 @@ func sgUpdatePolicy(id string) (string, error) { return string(j), nil } + +func tagCreatePolicy() (string, error) { + policy := iam.PolicyDocument{ + Version: "2012-10-17", + Statement: []iam.StatementEntry{ + { + Effect: "Allow", + Action: []string{ + "ec2:CreateTags", + }, + Resource: []string{"*"}, + }, + }, + } + + j, err := json.Marshal(policy) + if err != nil { + return "", err + } + + return string(j), nil +} diff --git a/api/routes.go b/api/routes.go index 2851dea..04ecd99 100644 --- a/api/routes.go +++ b/api/routes.go @@ -64,7 +64,7 @@ func (s *server) routes() { api.HandleFunc("/{account}/snapshots", s.ProxyRequestHandler).Methods(http.MethodPost) api.HandleFunc("/{account}/images", s.ProxyRequestHandler).Methods(http.MethodPost) - api.HandleFunc("/{account}/images/{id}/tags", s.ProxyRequestHandler).Methods(http.MethodPut) + api.HandleFunc("/{account}/images/{id}/tags", s.ImageUpdateHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/power", s.ProxyRequestHandler).Methods(http.MethodPut) api.HandleFunc("/{account}/instances/{id}/ssm/command", s.ProxyRequestHandler).Methods(http.MethodPut) diff --git a/api/types.go b/api/types.go index 62b8d7e..742d414 100644 --- a/api/types.go +++ b/api/types.go @@ -643,6 +643,9 @@ func toSSMGetCommandInvocationOutput(rawOut *ssm.GetCommandInvocationOutput) *SS } } +type Ec2ImageUpdateRequest struct { + Tags map[string]string +} type AssociationDescription struct { Name string `json:"name"` InstanceId string `json:"instance_id"` diff --git a/ec2/tags.go b/ec2/tags.go new file mode 100644 index 0000000..7d93d67 --- /dev/null +++ b/ec2/tags.go @@ -0,0 +1,34 @@ +package ec2 + +import ( + "context" + + "github.com/YaleSpinup/apierror" + "github.com/YaleSpinup/ec2-api/common" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + log "github.com/sirupsen/logrus" +) + +func (e *Ec2) UpdateTags(ctx context.Context, rawTags map[string]string, ids ...string) error { + if len(ids) == 0 || len(rawTags) == 0 { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + var tags []*ec2.Tag + for key, val := range rawTags { + tags = append(tags, &ec2.Tag{Key: aws.String(key), Value: aws.String(val)}) + } + + log.Infof("updating resources: %v with tags %+v", ids, tags) + + input := ec2.CreateTagsInput{ + Resources: aws.StringSlice(ids), + Tags: tags, + } + + if _, err := e.Service.CreateTagsWithContext(ctx, &input); err != nil { + return common.ErrCode("creating tags", err) + } + + return nil +} diff --git a/ec2/tags_test.go b/ec2/tags_test.go new file mode 100644 index 0000000..3998008 --- /dev/null +++ b/ec2/tags_test.go @@ -0,0 +1,84 @@ +package ec2 + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" +) + +var ( + inpIds = []string{"id-234"} + inpTags = map[string]string{"foo": "bar"} + expTags = []*ec2.Tag{{Key: aws.String("foo"), Value: aws.String("bar")}} +) + +func (m *mockEC2Client) CreateTagsWithContext(ctx context.Context, input *ec2.CreateTagsInput, opts ...request.Option) (*ec2.CreateTagsOutput, error) { + if m.err != nil { + return nil, m.err + } + if !reflect.DeepEqual(input.Resources, aws.StringSlice(inpIds)) || !reflect.DeepEqual(input.Tags, expTags) { + return nil, errors.New("input does not match") + } + return &ec2.CreateTagsOutput{}, nil +} + +func TestEc2_UpdateTags(t *testing.T) { + type fields struct { + Service ec2iface.EC2API + } + type args struct { + ctx context.Context + tags map[string]string + ids []string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "success case", + args: args{ctx: context.TODO(), tags: inpTags, ids: inpIds}, + fields: fields{Service: newmockEC2Client(t, nil)}, + wantErr: false, + }, + { + name: "aws error", + args: args{ctx: context.TODO(), tags: inpTags, ids: inpIds}, + fields: fields{Service: newmockEC2Client(t, awserr.New("Bad Request", "boom.", nil))}, + wantErr: true, + }, + { + name: "no tags", + fields: fields{Service: newmockEC2Client(t, nil)}, + args: args{ctx: context.TODO(), tags: nil, ids: inpIds}, + wantErr: true, + }, + { + name: "no ids", + fields: fields{Service: newmockEC2Client(t, nil)}, + args: args{ctx: context.TODO(), tags: inpTags, ids: nil}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + Service: tt.fields.Service, + } + err := e.UpdateTags(tt.args.ctx, tt.args.tags, tt.args.ids...) + if (err != nil) != tt.wantErr { + t.Errorf("Ec2.UpdateTags() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +}