diff --git a/api/handlers_images.go b/api/handlers_images.go index 989b0e0..d091fbe 100644 --- a/api/handlers_images.go +++ b/api/handlers_images.go @@ -138,7 +138,7 @@ func (s *server) ImageUpdateHandler(w http.ResponseWriter, r *http.Request) { ec2.WithOrg(s.org), ) - if err := service.UpdateTags(r.Context(), req.Tags, id); err != nil { + if err := service.UpdateRawTags(r.Context(), req.Tags, id); err != nil { handleError(w, err) return } diff --git a/api/handlers_instances.go b/api/handlers_instances.go index ca7a9b3..fb8b3ad 100644 --- a/api/handlers_instances.go +++ b/api/handlers_instances.go @@ -518,38 +518,31 @@ func (s *server) InstanceUpdateHandler(w http.ResponseWriter, r *http.Request) { return } - role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName) - policy, err := instanceUpdatePolicy() + policy, err := generatePolicy([]string{"ec2:CreateTags", "ec2:ModifyInstanceAttribute"}) if err != nil { handleError(w, err) return } - session, err := s.assumeRole( - r.Context(), - s.session.ExternalID, - role, - policy, - "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", - ) + orch, err := s.newEc2Orchestrator(r.Context(), &sessionParams{ + role: fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName), + inlinePolicy: policy, + policyArns: []string{ + "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess", + }, + }) if err != nil { - msg := fmt.Sprintf("failed to assume role in account: %s", account) - handleError(w, apierror.New(apierror.ErrForbidden, msg, err)) + handleError(w, err) return } - service := ec2.New( - ec2.WithSession(session.Session), - ec2.WithOrg(s.org), - ) - if len(req.Tags) > 0 { - if err := service.UpdateTags(r.Context(), req.Tags, instanceId); err != nil { + if err := orch.updateInstanceTags(r.Context(), req.Tags, instanceId); err != nil { handleError(w, err) return } } else if len(req.InstanceType) > 0 { - if err := service.UpdateAttributes(r.Context(), req.InstanceType["value"], instanceId); err != nil { + if err := orch.ec2Client.UpdateAttributes(r.Context(), req.InstanceType["value"], instanceId); err != nil { handleError(w, err) return } diff --git a/api/handlers_sgs.go b/api/handlers_sgs.go index 5e14e1c..76f26d3 100644 --- a/api/handlers_sgs.go +++ b/api/handlers_sgs.go @@ -95,7 +95,7 @@ func (s *server) SecurityGroupUpdateHandler(w http.ResponseWriter, r *http.Reque } if req.Tags != nil { - if err := orch.ec2Client.UpdateTags(r.Context(), *req.Tags, id); err != nil { + if err := orch.ec2Client.UpdateRawTags(r.Context(), *req.Tags, id); err != nil { handleError(w, err) return } diff --git a/api/handlers_volumes.go b/api/handlers_volumes.go index 807111a..8109b11 100644 --- a/api/handlers_volumes.go +++ b/api/handlers_volumes.go @@ -332,7 +332,7 @@ func (s *server) VolumeUpdateHandler(w http.ResponseWriter, r *http.Request) { } if req.Tags != nil { - if err := orch.ec2Client.UpdateTags(r.Context(), *req.Tags, id); err != nil { + if err := orch.ec2Client.UpdateRawTags(r.Context(), *req.Tags, id); err != nil { handleError(w, err) return } diff --git a/api/orchestration_instances.go b/api/orchestration_instances.go index 0bb2198..af0525f 100644 --- a/api/orchestration_instances.go +++ b/api/orchestration_instances.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/YaleSpinup/apierror" + "github.com/YaleSpinup/ec2-api/common" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/aws/aws-sdk-go/service/ec2" @@ -131,3 +132,37 @@ func (o *ssmOrchestrator) sendInstancesCommand(ctx context.Context, req *SsmComm } return aws.StringValue(cmd.CommandId), nil } +func (o *ec2Orchestrator) updateInstanceTags(ctx context.Context, rawTags map[string]string, ids ...string) error { + if len(ids) == 0 || len(rawTags) == 0 { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + var tags []*ec2.Tag + for key, val := range rawTags { + tags = append(tags, &ec2.Tag{Key: aws.String(key), Value: aws.String(val)}) + } + + volumeIds := []string{} + for _, id := range ids { + if strings.HasPrefix(id, "i-") { + vIds, err := o.ec2Client.ListInstanceVolumes(ctx, id) + if err != nil { + return common.ErrCode("describing volumes for instance", err) + } + volumeIds = append(volumeIds, vIds...) + } + } + + ids = append(ids, volumeIds...) + log.Infof("updating resources: %v with tags %+v", ids, tags) + + input := ec2.CreateTagsInput{ + Resources: aws.StringSlice(ids), + Tags: tags, + } + + if err := o.ec2Client.UpdateTags(ctx, &input); err != nil { + return err + } + + return nil +} diff --git a/api/policy.go b/api/policy.go index 7295675..963e843 100644 --- a/api/policy.go +++ b/api/policy.go @@ -289,30 +289,6 @@ func sendCommandPolicy() (string, error) { return string(j), nil } -func instanceUpdatePolicy() (string, error) { - log.Debugf("generating tag create policy document") - policy := iam.PolicyDocument{ - Version: "2012-10-17", - Statement: []iam.StatementEntry{ - { - Effect: "Allow", - Action: []string{ - "ec2:CreateTags", - "ec2:ModifyInstanceAttribute", - }, - Resource: []string{"*"}, - }, - }, - } - - j, err := json.Marshal(policy) - if err != nil { - return "", err - } - - return string(j), nil -} - func ssmAssociationPolicy() (string, error) { log.Debugf("generating tag create policy document") policy := iam.PolicyDocument{ diff --git a/ec2/tags.go b/ec2/tags.go index 7d93d67..b0962d3 100644 --- a/ec2/tags.go +++ b/ec2/tags.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" ) -func (e *Ec2) UpdateTags(ctx context.Context, rawTags map[string]string, ids ...string) error { +func (e *Ec2) UpdateRawTags(ctx context.Context, rawTags map[string]string, ids ...string) error { if len(ids) == 0 || len(rawTags) == 0 { return apierror.New(apierror.ErrBadRequest, "invalid input", nil) } @@ -32,3 +32,17 @@ func (e *Ec2) UpdateTags(ctx context.Context, rawTags map[string]string, ids ... return nil } + +func (e *Ec2) UpdateTags(ctx context.Context, input *ec2.CreateTagsInput) error { + if input == nil { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("updating tags: %v", input) + + if _, err := e.Service.CreateTagsWithContext(ctx, input); err != nil { + return common.ErrCode("creating tags", err) + } + + return nil +} diff --git a/ec2/tags_test.go b/ec2/tags_test.go index 3998008..bd36fd2 100644 --- a/ec2/tags_test.go +++ b/ec2/tags_test.go @@ -29,7 +29,14 @@ func (m *mockEC2Client) CreateTagsWithContext(ctx context.Context, input *ec2.Cr return &ec2.CreateTagsOutput{}, nil } -func TestEc2_UpdateTags(t *testing.T) { +func (m *mockEC2Client) DescribeVolumesWithContext(aws aws.Context, inp *ec2.DescribeVolumesInput, opt ...request.Option) (*ec2.DescribeVolumesOutput, error) { + if m.err != nil { + return nil, m.err + } + return &ec2.DescribeVolumesOutput{}, nil +} + +func TestEc2_UpdateRawTags(t *testing.T) { type fields struct { Service ec2iface.EC2API } @@ -65,7 +72,7 @@ func TestEc2_UpdateTags(t *testing.T) { { name: "no ids", fields: fields{Service: newmockEC2Client(t, nil)}, - args: args{ctx: context.TODO(), tags: inpTags, ids: nil}, + args: args{ctx: context.TODO(), tags: inpTags, ids: []string{}}, wantErr: true, }, } @@ -74,7 +81,7 @@ func TestEc2_UpdateTags(t *testing.T) { e := &Ec2{ Service: tt.fields.Service, } - err := e.UpdateTags(tt.args.ctx, tt.args.tags, tt.args.ids...) + err := e.UpdateRawTags(tt.args.ctx, tt.args.tags, tt.args.ids...) if (err != nil) != tt.wantErr { t.Errorf("Ec2.UpdateTags() error = %v, wantErr %v", err, tt.wantErr) return @@ -82,3 +89,68 @@ func TestEc2_UpdateTags(t *testing.T) { }) } } +func TestEc2_UpdateTags(t *testing.T) { + type fields struct { + Service ec2iface.EC2API + } + type args struct { + ctx context.Context + input *ec2.CreateTagsInput + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "success case", + args: args{ctx: context.TODO(), input: &ec2.CreateTagsInput{ + Resources: aws.StringSlice(inpIds), + Tags: expTags}}, + fields: fields{Service: newmockEC2Client(t, nil)}, + wantErr: false, + }, + { + name: "aws error", + args: args{ctx: context.TODO(), input: &ec2.CreateTagsInput{ + Resources: aws.StringSlice(inpIds), + Tags: expTags}}, + fields: fields{Service: newmockEC2Client(t, awserr.New("Bad Request", "boom.", nil))}, + wantErr: true, + }, + { + name: "no tags", + fields: fields{Service: newmockEC2Client(t, nil)}, + args: args{ctx: context.TODO(), input: &ec2.CreateTagsInput{ + Resources: aws.StringSlice(inpIds), + Tags: nil}}, + wantErr: true, + }, + { + name: "no ids", + fields: fields{Service: newmockEC2Client(t, nil)}, + args: args{ctx: context.TODO(), input: &ec2.CreateTagsInput{ + Resources: aws.StringSlice([]string{}), + Tags: expTags}}, + wantErr: true, + }, + { + name: "no input", + fields: fields{Service: newmockEC2Client(t, nil)}, + args: args{ctx: context.TODO(), input: nil}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Ec2{ + Service: tt.fields.Service, + } + if err := e.UpdateTags(tt.args.ctx, tt.args.input); (err != nil) != tt.wantErr { + t.Errorf("Ec2.UpdateTags() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +}