Skip to content

Commit

Permalink
Migrate eks discovery to aws sdk v2 (#50603)
Browse files Browse the repository at this point in the history
* Remove all references to EKS sdk v1.

* Address PR comments.
  • Loading branch information
creack authored Jan 10, 2025
1 parent f63a099 commit 84956a8
Show file tree
Hide file tree
Showing 25 changed files with 966 additions and 702 deletions.
8 changes: 4 additions & 4 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Confi
}

func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) {
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}
if len(roles) > 0 {
// no point caching every assumed role in the chain, we can just cache
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}
// No point caching every assumed role in the chain, we can just cache
// the last one.
cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, awsCredentialsCacheOptions)
if _, err := cfg.Credentials.Retrieve(ctx); err != nil {
Expand Down
23 changes: 0 additions & 23 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface"
"github.com/aws/aws-sdk-go/service/iam"
Expand Down Expand Up @@ -127,8 +125,6 @@ type AWSClients interface {
GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error)
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error)
// GetAWSEKSClient returns AWS EKS client for the specified region.
GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error)
// GetAWSKMSClient returns AWS KMS client for the specified region.
GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error)
// GetAWSS3Client returns AWS S3 client.
Expand Down Expand Up @@ -585,15 +581,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts
return sts.New(session), nil
}

// GetAWSEKSClient returns AWS EKS client for the specified region.
func (c *cloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return eks.New(session), nil
}

// GetAWSKMSClient returns AWS KMS client for the specified region.
func (c *cloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
Expand Down Expand Up @@ -1032,7 +1019,6 @@ type TestCloudClients struct {
GCPProjects gcp.ProjectsClient
GCPInstances gcp.InstancesClient
InstanceMetadata imds.Client
EKS eksiface.EKSAPI
KMS kmsiface.KMSAPI
S3 s3iface.S3API
AzureMySQL azure.DBServersClient
Expand Down Expand Up @@ -1173,15 +1159,6 @@ func (c *TestCloudClients) GetAWSSTSClient(ctx context.Context, region string, o
return c.STS, nil
}

// GetAWSEKSClient returns AWS EKS client for the specified region.
func (c *TestCloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return c.EKS, nil
}

// GetAWSKMSClient returns AWS KMS client for the specified region.
func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
Expand Down
85 changes: 0 additions & 85 deletions lib/cloud/mocks/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
Expand Down Expand Up @@ -288,86 +286,3 @@ func (m *IAMErrorMock) PutUserPolicyWithContext(ctx aws.Context, input *iam.PutU
}
return nil, trace.AccessDenied("unauthorized")
}

// EKSMock is a mock EKS client.
type EKSMock struct {
eksiface.EKSAPI
Clusters []*eks.Cluster
AccessEntries []*eks.AccessEntry
AssociatedPolicies []*eks.AssociatedAccessPolicy
Notify chan struct{}
}

func (e *EKSMock) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) {
defer func() {
if e.Notify != nil {
e.Notify <- struct{}{}
}
}()
for _, cluster := range e.Clusters {
if aws.StringValue(req.Name) == aws.StringValue(cluster.Name) {
return &eks.DescribeClusterOutput{Cluster: cluster}, nil
}
}
return nil, trace.NotFound("cluster %v not found", aws.StringValue(req.Name))
}

func (e *EKSMock) ListClustersPagesWithContext(_ aws.Context, _ *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error {
defer func() {
if e.Notify != nil {
e.Notify <- struct{}{}
}
}()
clusters := make([]*string, 0, len(e.Clusters))
for _, cluster := range e.Clusters {
clusters = append(clusters, cluster.Name)
}
f(&eks.ListClustersOutput{
Clusters: clusters,
}, true)
return nil
}

func (e *EKSMock) ListAccessEntriesPagesWithContext(_ aws.Context, _ *eks.ListAccessEntriesInput, f func(*eks.ListAccessEntriesOutput, bool) bool, _ ...request.Option) error {
defer func() {
if e.Notify != nil {
e.Notify <- struct{}{}
}
}()
accessEntries := make([]*string, 0, len(e.Clusters))
for _, a := range e.AccessEntries {
accessEntries = append(accessEntries, a.PrincipalArn)
}
f(&eks.ListAccessEntriesOutput{
AccessEntries: accessEntries,
}, true)
return nil
}

func (e *EKSMock) DescribeAccessEntryWithContext(_ aws.Context, req *eks.DescribeAccessEntryInput, _ ...request.Option) (*eks.DescribeAccessEntryOutput, error) {
defer func() {
if e.Notify != nil {
e.Notify <- struct{}{}
}
}()
for _, a := range e.AccessEntries {
if aws.StringValue(req.PrincipalArn) == aws.StringValue(a.PrincipalArn) && aws.StringValue(a.ClusterName) == aws.StringValue(req.ClusterName) {
return &eks.DescribeAccessEntryOutput{AccessEntry: a}, nil
}
}
return nil, trace.NotFound("access entry %v not found", aws.StringValue(req.PrincipalArn))
}

func (e *EKSMock) ListAssociatedAccessPoliciesPagesWithContext(_ aws.Context, _ *eks.ListAssociatedAccessPoliciesInput, f func(*eks.ListAssociatedAccessPoliciesOutput, bool) bool, _ ...request.Option) error {
defer func() {
if e.Notify != nil {
e.Notify <- struct{}{}
}
}()

f(&eks.ListAssociatedAccessPoliciesOutput{
AssociatedAccessPolicies: e.AssociatedPolicies,
}, true)
return nil

}
4 changes: 2 additions & 2 deletions lib/cloud/mocks/aws_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns
if stsClt == nil {
stsClt = &STSClient{}
}
optFns = append(optFns,
optFns = append([]awsconfig.OptionsFn{
awsconfig.WithOIDCIntegrationClient(f.OIDCIntegrationClient),
awsconfig.WithSTSClientProvider(
newAssumeRoleClientProviderFunc(stsClt),
),
)
}, optFns...)
return awsconfig.GetConfig(ctx, region, optFns...)
}

Expand Down
6 changes: 6 additions & 0 deletions lib/cloud/mocks/aws_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ type STSClient struct {
recordFn func(roleARN, externalID string)
}

func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
return &sts.GetCallerIdentityOutput{
Arn: aws.String(m.ARN),
}, nil
}

func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
m.record(aws.ToString(in.RoleArn), "")
expiry := time.Now().Add(60 * time.Minute)
Expand Down
19 changes: 10 additions & 9 deletions lib/integrations/awsoidc/eks_enroll_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ const (
concurrentEKSEnrollingLimit = 5
)

var agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"}
var agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"}
var (
agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"}
agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"}
)

// EnrollEKSClusterResult contains result for a single EKS cluster enrollment, if it was successful 'Error' will be nil
// otherwise it will contain an error happened during enrollment.
Expand Down Expand Up @@ -462,7 +464,6 @@ func enrollEKSCluster(ctx context.Context, log *slog.Logger, clock clockwork.Clo
return "",
issueTypeFromCheckAgentInstalledError(err),
trace.Wrap(err, "could not check if teleport-kube-agent is already installed.")

} else if alreadyInstalled {
return "",
// When using EKS Auto Discovery, after the Kube Agent connects to the Teleport cluster, it is ignored in next discovery iterations.
Expand Down Expand Up @@ -708,19 +709,19 @@ func installKubeAgent(ctx context.Context, cfg installKubeAgentParams) error {
if cfg.req.IsCloud && cfg.req.EnableAutoUpgrades {
vals["updater"] = map[string]any{"enabled": true, "releaseChannel": "stable/cloud"}

vals["highAvailability"] = map[string]any{"replicaCount": 2,
vals["highAvailability"] = map[string]any{
"replicaCount": 2,
"podDisruptionBudget": map[string]any{"enabled": true, "minAvailable": 1},
}
}
if modules.GetModules().BuildType() == modules.BuildEnterprise {
vals["enterprise"] = true
}

eksTags := make(map[string]*string, len(cfg.eksCluster.Tags))
for k, v := range cfg.eksCluster.Tags {
eksTags[k] = aws.String(v)
}
eksTags[types.OriginLabel] = aws.String(types.OriginCloud)
eksTags := make(map[string]string, len(cfg.eksCluster.Tags))
maps.Copy(eksTags, cfg.eksCluster.Tags)
eksTags[types.OriginLabel] = types.OriginCloud

kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(cfg.eksCluster.Name), aws.ToString(cfg.eksCluster.Arn), eksTags)
if err != nil {
return trace.Wrap(err)
Expand Down
Loading

0 comments on commit 84956a8

Please sign in to comment.