diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go index 7b1cabe5ffe75..245fe8a9a6b23 100644 --- a/lib/cloud/awsconfig/awsconfig.go +++ b/lib/cloud/awsconfig/awsconfig.go @@ -280,11 +280,11 @@ func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Confi } func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) { - for _, r := range roles { - cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r) - } if len(roles) > 0 { - // no point caching every assumed role in the chain, we can just cache + for _, r := range roles { + cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r) + } + // No point caching every assumed role in the chain, we can just cache // the last one. cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, awsCredentialsCacheOptions) if _, err := cfg.Credentials.Retrieve(ctx); err != nil { diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 99c2deb4001f0..28e8ebabac598 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -39,8 +39,6 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" awssession "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/iam" @@ -127,8 +125,6 @@ type AWSClients interface { GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error) // GetAWSSTSClient returns AWS STS client for the specified region. GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error) - // GetAWSEKSClient returns AWS EKS client for the specified region. - GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) // GetAWSKMSClient returns AWS KMS client for the specified region. GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) // GetAWSS3Client returns AWS S3 client. @@ -585,15 +581,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts return sts.New(session), nil } -// GetAWSEKSClient returns AWS EKS client for the specified region. -func (c *cloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return eks.New(session), nil -} - // GetAWSKMSClient returns AWS KMS client for the specified region. func (c *cloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -1032,7 +1019,6 @@ type TestCloudClients struct { GCPProjects gcp.ProjectsClient GCPInstances gcp.InstancesClient InstanceMetadata imds.Client - EKS eksiface.EKSAPI KMS kmsiface.KMSAPI S3 s3iface.S3API AzureMySQL azure.DBServersClient @@ -1173,15 +1159,6 @@ func (c *TestCloudClients) GetAWSSTSClient(ctx context.Context, region string, o return c.STS, nil } -// GetAWSEKSClient returns AWS EKS client for the specified region. -func (c *TestCloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return c.EKS, nil -} - // GetAWSKMSClient returns AWS KMS client for the specified region. func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) { _, err := c.GetAWSSession(ctx, region, opts...) diff --git a/lib/cloud/mocks/aws.go b/lib/cloud/mocks/aws.go index ceb50bd822cc2..9ba40628e3a92 100644 --- a/lib/cloud/mocks/aws.go +++ b/lib/cloud/mocks/aws.go @@ -28,8 +28,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" "github.com/aws/aws-sdk-go/service/sts" @@ -288,86 +286,3 @@ func (m *IAMErrorMock) PutUserPolicyWithContext(ctx aws.Context, input *iam.PutU } return nil, trace.AccessDenied("unauthorized") } - -// EKSMock is a mock EKS client. -type EKSMock struct { - eksiface.EKSAPI - Clusters []*eks.Cluster - AccessEntries []*eks.AccessEntry - AssociatedPolicies []*eks.AssociatedAccessPolicy - Notify chan struct{} -} - -func (e *EKSMock) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) { - defer func() { - if e.Notify != nil { - e.Notify <- struct{}{} - } - }() - for _, cluster := range e.Clusters { - if aws.StringValue(req.Name) == aws.StringValue(cluster.Name) { - return &eks.DescribeClusterOutput{Cluster: cluster}, nil - } - } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(req.Name)) -} - -func (e *EKSMock) ListClustersPagesWithContext(_ aws.Context, _ *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error { - defer func() { - if e.Notify != nil { - e.Notify <- struct{}{} - } - }() - clusters := make([]*string, 0, len(e.Clusters)) - for _, cluster := range e.Clusters { - clusters = append(clusters, cluster.Name) - } - f(&eks.ListClustersOutput{ - Clusters: clusters, - }, true) - return nil -} - -func (e *EKSMock) ListAccessEntriesPagesWithContext(_ aws.Context, _ *eks.ListAccessEntriesInput, f func(*eks.ListAccessEntriesOutput, bool) bool, _ ...request.Option) error { - defer func() { - if e.Notify != nil { - e.Notify <- struct{}{} - } - }() - accessEntries := make([]*string, 0, len(e.Clusters)) - for _, a := range e.AccessEntries { - accessEntries = append(accessEntries, a.PrincipalArn) - } - f(&eks.ListAccessEntriesOutput{ - AccessEntries: accessEntries, - }, true) - return nil -} - -func (e *EKSMock) DescribeAccessEntryWithContext(_ aws.Context, req *eks.DescribeAccessEntryInput, _ ...request.Option) (*eks.DescribeAccessEntryOutput, error) { - defer func() { - if e.Notify != nil { - e.Notify <- struct{}{} - } - }() - for _, a := range e.AccessEntries { - if aws.StringValue(req.PrincipalArn) == aws.StringValue(a.PrincipalArn) && aws.StringValue(a.ClusterName) == aws.StringValue(req.ClusterName) { - return &eks.DescribeAccessEntryOutput{AccessEntry: a}, nil - } - } - return nil, trace.NotFound("access entry %v not found", aws.StringValue(req.PrincipalArn)) -} - -func (e *EKSMock) ListAssociatedAccessPoliciesPagesWithContext(_ aws.Context, _ *eks.ListAssociatedAccessPoliciesInput, f func(*eks.ListAssociatedAccessPoliciesOutput, bool) bool, _ ...request.Option) error { - defer func() { - if e.Notify != nil { - e.Notify <- struct{}{} - } - }() - - f(&eks.ListAssociatedAccessPoliciesOutput{ - AssociatedAccessPolicies: e.AssociatedPolicies, - }, true) - return nil - -} diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go index b52dfbd36d74a..819d6ca8f535e 100644 --- a/lib/cloud/mocks/aws_config.go +++ b/lib/cloud/mocks/aws_config.go @@ -38,12 +38,12 @@ func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns if stsClt == nil { stsClt = &STSClient{} } - optFns = append(optFns, + optFns = append([]awsconfig.OptionsFn{ awsconfig.WithOIDCIntegrationClient(f.OIDCIntegrationClient), awsconfig.WithSTSClientProvider( newAssumeRoleClientProviderFunc(stsClt), ), - ) + }, optFns...) return awsconfig.GetConfig(ctx, region, optFns...) } diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go index 178a1259669a4..cf117788e696f 100644 --- a/lib/cloud/mocks/aws_sts.go +++ b/lib/cloud/mocks/aws_sts.go @@ -54,6 +54,12 @@ type STSClient struct { recordFn func(roleARN, externalID string) } +func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return &sts.GetCallerIdentityOutput{ + Arn: aws.String(m.ARN), + }, nil +} + func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { m.record(aws.ToString(in.RoleArn), "") expiry := time.Now().Add(60 * time.Minute) diff --git a/lib/integrations/awsoidc/eks_enroll_clusters.go b/lib/integrations/awsoidc/eks_enroll_clusters.go index dbeb6f2385484..d61b062cccfdb 100644 --- a/lib/integrations/awsoidc/eks_enroll_clusters.go +++ b/lib/integrations/awsoidc/eks_enroll_clusters.go @@ -74,8 +74,10 @@ const ( concurrentEKSEnrollingLimit = 5 ) -var agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"} -var agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"} +var ( + agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"} + agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"} +) // EnrollEKSClusterResult contains result for a single EKS cluster enrollment, if it was successful 'Error' will be nil // otherwise it will contain an error happened during enrollment. @@ -462,7 +464,6 @@ func enrollEKSCluster(ctx context.Context, log *slog.Logger, clock clockwork.Clo return "", issueTypeFromCheckAgentInstalledError(err), trace.Wrap(err, "could not check if teleport-kube-agent is already installed.") - } else if alreadyInstalled { return "", // When using EKS Auto Discovery, after the Kube Agent connects to the Teleport cluster, it is ignored in next discovery iterations. @@ -708,7 +709,8 @@ func installKubeAgent(ctx context.Context, cfg installKubeAgentParams) error { if cfg.req.IsCloud && cfg.req.EnableAutoUpgrades { vals["updater"] = map[string]any{"enabled": true, "releaseChannel": "stable/cloud"} - vals["highAvailability"] = map[string]any{"replicaCount": 2, + vals["highAvailability"] = map[string]any{ + "replicaCount": 2, "podDisruptionBudget": map[string]any{"enabled": true, "minAvailable": 1}, } } @@ -716,11 +718,10 @@ func installKubeAgent(ctx context.Context, cfg installKubeAgentParams) error { vals["enterprise"] = true } - eksTags := make(map[string]*string, len(cfg.eksCluster.Tags)) - for k, v := range cfg.eksCluster.Tags { - eksTags[k] = aws.String(v) - } - eksTags[types.OriginLabel] = aws.String(types.OriginCloud) + eksTags := make(map[string]string, len(cfg.eksCluster.Tags)) + maps.Copy(eksTags, cfg.eksCluster.Tags) + eksTags[types.OriginLabel] = types.OriginCloud + kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(cfg.eksCluster.Name), aws.ToString(cfg.eksCluster.Arn), eksTags) if err != nil { return trace.Wrap(err) diff --git a/lib/kube/proxy/cluster_details.go b/lib/kube/proxy/cluster_details.go index 1a66ce0562978..e1dbc45fca281 100644 --- a/lib/kube/proxy/cluster_details.go +++ b/lib/kube/proxy/cluster_details.go @@ -26,8 +26,8 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "k8s.io/apimachinery/pkg/runtime/schema" @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" kubeutils "github.com/gravitational/teleport/lib/kube/utils" @@ -50,6 +51,7 @@ import ( // kubeDetails contain the cluster-related details including authentication. type kubeDetails struct { kubeCreds + // dynamicLabels is the dynamic labels executor for this cluster. dynamicLabels *labels.Dynamic // kubeCluster is the dynamic kube_cluster or a static generated from kubeconfig and that only has the name populated. @@ -86,6 +88,8 @@ type kubeDetails struct { type clusterDetailsConfig struct { // cloudClients is the cloud clients to use for dynamic clusters. cloudClients cloud.Clients + // awsCloudClients provides AWS SDK clients. + awsCloudClients AWSClientGetter // kubeCreds is the credentials to use for the cluster. kubeCreds kubeCreds // cluster is the cluster to create a proxied cluster for. @@ -103,8 +107,10 @@ type clusterDetailsConfig struct { component KubeServiceType } -const defaultRefreshPeriod = 5 * time.Minute -const backoffRefreshStep = 10 * time.Second +const ( + defaultRefreshPeriod = 5 * time.Minute + backoffRefreshStep = 10 * time.Second +) // newClusterDetails creates a proxied kubeDetails structure given a dynamic cluster. func newClusterDetails(ctx context.Context, cfg clusterDetailsConfig) (_ *kubeDetails, err error) { @@ -263,14 +269,20 @@ func (k *kubeDetails) getObjectGVK(resource apiResource) *schema.GroupVersionKin // getKubeClusterCredentials generates kube credentials for dynamic clusters. func getKubeClusterCredentials(ctx context.Context, cfg clusterDetailsConfig) (kubeCreds, error) { - dynCredsCfg := dynamicCredsConfig{kubeCluster: cfg.cluster, log: cfg.log, checker: cfg.checker, resourceMatchers: cfg.resourceMatchers, clock: cfg.clock, component: cfg.component} - switch { + switch dynCredsCfg := (dynamicCredsConfig{ + kubeCluster: cfg.cluster, + log: cfg.log, + checker: cfg.checker, + resourceMatchers: cfg.resourceMatchers, + clock: cfg.clock, + component: cfg.component, + }); { case cfg.cluster.IsKubeconfig(): return getStaticCredentialsFromKubeconfig(ctx, cfg.component, cfg.cluster, cfg.log, cfg.checker) case cfg.cluster.IsAzure(): return getAzureCredentials(ctx, cfg.cloudClients, dynCredsCfg) case cfg.cluster.IsAWS(): - return getAWSCredentials(ctx, cfg.cloudClients, dynCredsCfg) + return getAWSCredentials(ctx, cfg.awsCloudClients, dynCredsCfg) case cfg.cluster.IsGCP(): return getGCPCredentials(ctx, cfg.cloudClients, dynCredsCfg) default: @@ -308,7 +320,7 @@ func azureRestConfigClient(cloudClients cloud.Clients) dynamicCredsClient { } // getAWSCredentials creates a dynamicKubeCreds that generates and updates the access credentials to a EKS kubernetes cluster. -func getAWSCredentials(ctx context.Context, cloudClients cloud.Clients, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { +func getAWSCredentials(ctx context.Context, cloudClients AWSClientGetter, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { // create a client that returns the credentials for kubeCluster cfg.client = getAWSClientRestConfig(cloudClients, cfg.clock, cfg.resourceMatchers) creds, err := newDynamicKubeCreds(ctx, cfg) @@ -328,51 +340,66 @@ func getAWSResourceMatcherToCluster(kubeCluster types.KubeCluster, resourceMatch if match, _, _ := services.MatchLabels(matcher.Labels, kubeCluster.GetAllLabels()); !match { continue } - - return &(matcher.AWS) + return &matcher.AWS } return nil } +// STSPresignClient is the subset of the STS presign interface we use in fetchers. +type STSPresignClient = kubeutils.STSPresignClient + +// EKSClient is the subset of the EKS Client interface we use. +type EKSClient interface { + eks.DescribeClusterAPIClient +} + +// AWSClientGetter is an interface for getting an EKS client and an STS client. +type AWSClientGetter interface { + awsconfig.Provider + // GetAWSEKSClient returns AWS EKS client for the specified config. + GetAWSEKSClient(aws.Config) EKSClient + // GetAWSSTSPresignClient returns AWS STS presign client for the specified config. + GetAWSSTSPresignClient(aws.Config) STSPresignClient +} + // getAWSClientRestConfig creates a dynamicCredsClient that generates returns credentials to EKS clusters. -func getAWSClientRestConfig(cloudClients cloud.Clients, clock clockwork.Clock, resourceMatchers []services.ResourceMatcher) dynamicCredsClient { +func getAWSClientRestConfig(cloudClients AWSClientGetter, clock clockwork.Clock, resourceMatchers []services.ResourceMatcher) dynamicCredsClient { return func(ctx context.Context, cluster types.KubeCluster) (*rest.Config, time.Time, error) { region := cluster.GetAWSConfig().Region - opts := []cloud.AWSOptionsFn{ - cloud.WithAmbientCredentials(), - cloud.WithoutSessionCache(), + opts := []awsconfig.OptionsFn{ + awsconfig.WithAmbientCredentials(), } if awsAssume := getAWSResourceMatcherToCluster(cluster, resourceMatchers); awsAssume != nil { - opts = append(opts, cloud.WithAssumeRole(awsAssume.AssumeRoleARN, awsAssume.ExternalID)) + opts = append(opts, awsconfig.WithAssumeRole(awsAssume.AssumeRoleARN, awsAssume.ExternalID)) } - regionalClient, err := cloudClients.GetAWSEKSClient(ctx, region, opts...) + + cfg, err := cloudClients.GetConfig(ctx, region, opts...) if err != nil { return nil, time.Time{}, trace.Wrap(err) } - eksCfg, err := regionalClient.DescribeClusterWithContext(ctx, &eks.DescribeClusterInput{ + regionalClient := cloudClients.GetAWSEKSClient(cfg) + + eksCfg, err := regionalClient.DescribeCluster(ctx, &eks.DescribeClusterInput{ Name: aws.String(cluster.GetAWSConfig().Name), }) if err != nil { return nil, time.Time{}, trace.Wrap(err) } - ca, err := base64.StdEncoding.DecodeString(aws.StringValue(eksCfg.Cluster.CertificateAuthority.Data)) + ca, err := base64.StdEncoding.DecodeString(aws.ToString(eksCfg.Cluster.CertificateAuthority.Data)) if err != nil { return nil, time.Time{}, trace.Wrap(err) } - apiEndpoint := aws.StringValue(eksCfg.Cluster.Endpoint) + apiEndpoint := aws.ToString(eksCfg.Cluster.Endpoint) if len(apiEndpoint) == 0 { return nil, time.Time{}, trace.BadParameter("invalid api endpoint for cluster %q", cluster.GetAWSConfig().Name) } - stsClient, err := cloudClients.GetAWSSTSClient(ctx, region, opts...) - if err != nil { - return nil, time.Time{}, trace.Wrap(err) - } + stsPresignClient := cloudClients.GetAWSSTSPresignClient(cfg) - token, exp, err := kubeutils.GenAWSEKSToken(stsClient, cluster.GetAWSConfig().Name, clock) + token, exp, err := kubeutils.GenAWSEKSToken(ctx, stsPresignClient, cluster.GetAWSConfig().Name, clock) if err != nil { return nil, time.Time{}, trace.Wrap(err) } diff --git a/lib/kube/proxy/kube_creds_test.go b/lib/kube/proxy/kube_creds_test.go index ca4f1bd4b58e0..ca2f537e6de05 100644 --- a/lib/kube/proxy/kube_creds_test.go +++ b/lib/kube/proxy/kube_creds_test.go @@ -26,8 +26,11 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/service/eks" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -41,10 +44,65 @@ import ( "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/fixtures" + kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) +type mockEKSClientGetter struct { + mocks.AWSConfigProvider + stsPresignClient *mockSTSPresignAPI + eksClient *mockEKSAPI +} + +func (e *mockEKSClientGetter) GetAWSEKSClient(aws.Config) EKSClient { + return e.eksClient +} + +func (e *mockEKSClientGetter) GetAWSSTSPresignClient(aws.Config) kubeutils.STSPresignClient { + return e.stsPresignClient +} + +type mockSTSPresignAPI struct { + url *url.URL +} + +func (a *mockSTSPresignAPI) PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) { + return &v4.PresignedHTTPRequest{URL: a.url.String()}, nil +} + +type mockEKSAPI struct { + EKSClient + + notify chan struct{} + clusters []*ekstypes.Cluster +} + +func (m *mockEKSAPI) ListClusters(ctx context.Context, req *eks.ListClustersInput, _ ...func(*eks.Options)) (*eks.ListClustersOutput, error) { + defer func() { m.notify <- struct{}{} }() + + var names []string + for _, cluster := range m.clusters { + names = append(names, aws.ToString(cluster.Name)) + } + return &eks.ListClustersOutput{ + Clusters: names, + }, nil +} + +func (m *mockEKSAPI) DescribeCluster(_ context.Context, req *eks.DescribeClusterInput, _ ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) { + defer func() { m.notify <- struct{}{} }() + + for _, cluster := range m.clusters { + if aws.ToString(cluster.Name) == aws.ToString(req.Name) { + return &eks.DescribeClusterOutput{ + Cluster: cluster, + }, nil + } + } + return nil, trace.NotFound("cluster %q not found", aws.ToString(req.Name)) +} + // Test_DynamicKubeCreds tests the dynamic kube credrentials generator for // AWS, GCP, and Azure clusters accessed using their respective IAM credentials. // This test mocks the cloud provider clients and the STS client to generate @@ -99,32 +157,37 @@ func Test_DynamicKubeCreds(t *testing.T) { ) require.NoError(t, err) - // mock sts client + // Mock sts client. u := &url.URL{ Scheme: "https", Host: "sts.amazonaws.com", Path: "/?Action=GetCallerIdentity&Version=2011-06-15", } - sts := &mocks.STSClientV1{ - // u is used to presign the request - // here we just verify the pre-signed request includes this url. - URL: u, - } - // mock clients - cloudclients := &cloud.TestCloudClients{ - STS: sts, - EKS: &mocks.EKSMock{ - Notify: notify, - Clusters: []*eks.Cluster{ + // EKS clients. + eksClients := &mockEKSClientGetter{ + AWSConfigProvider: mocks.AWSConfigProvider{ + STSClient: &mocks.STSClient{}, + }, + stsPresignClient: &mockSTSPresignAPI{ + // u is used to presign the request + // here we just verify the pre-signed request includes this url. + url: u, + }, + eksClient: &mockEKSAPI{ + notify: notify, + clusters: []*ekstypes.Cluster{ { Endpoint: aws.String("https://api.eks.us-west-2.amazonaws.com"), Name: aws.String(awsKube.GetAWSConfig().Name), - CertificateAuthority: &eks.Certificate{ + CertificateAuthority: &ekstypes.Certificate{ Data: aws.String(base64.RawStdEncoding.EncodeToString([]byte(fixtures.TLSCACertPEM))), }, }, }, }, + } + // Mock clients. + cloudclients := &cloud.TestCloudClients{ GCPGKE: &mocks.GKEMock{ Notify: notify, Clock: fakeClock, @@ -204,7 +267,7 @@ func Test_DynamicKubeCreds(t *testing.T) { name: "aws eks cluster without assume role", args: args{ cluster: awsKube, - client: getAWSClientRestConfig(cloudclients, fakeClock, nil), + client: getAWSClientRestConfig(eksClients, fakeClock, nil), validateBearerToken: validateEKSToken, }, wantAddr: "api.eks.us-west-2.amazonaws.com:443", @@ -213,7 +276,7 @@ func Test_DynamicKubeCreds(t *testing.T) { name: "aws eks cluster with unmatched assume role", args: args{ cluster: awsKube, - client: getAWSClientRestConfig(cloudclients, fakeClock, []services.ResourceMatcher{ + client: getAWSClientRestConfig(eksClients, fakeClock, []services.ResourceMatcher{ { Labels: types.Labels{ "rand": []string{"value"}, @@ -233,7 +296,7 @@ func Test_DynamicKubeCreds(t *testing.T) { args: args{ cluster: awsKube, client: getAWSClientRestConfig( - cloudclients, + eksClients, fakeClock, []services.ResourceMatcher{ { @@ -331,6 +394,7 @@ func Test_DynamicKubeCreds(t *testing.T) { } require.NoError(t, got.close()) + sts := eksClients.AWSConfigProvider.STSClient require.Equal(t, tt.wantAssumedRole, apiutils.Deduplicate(sts.GetAssumedRoleARNs())) require.Equal(t, tt.wantExternalIds, apiutils.Deduplicate(sts.GetAssumedRoleExternalIDs())) sts.ResetAssumeRoleHistory() diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index 6ac466746b51f..f153039d60749 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -28,6 +28,9 @@ import ( "sync" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/eks" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "golang.org/x/net/http2" @@ -38,6 +41,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/inventory" "github.com/gravitational/teleport/lib/labels" @@ -74,6 +78,7 @@ type TLSServerConfig struct { OnReconcile func(types.KubeClusters) // CloudClients is a set of cloud clients that Teleport supports. CloudClients cloud.Clients + awsClients *awsClientsGetter // StaticLabels is a map of static labels associated with this service. // Each cluster advertised by this kubernetes_service will include these static labels. // If the service and a cluster define labels with the same key, @@ -106,6 +111,21 @@ type TLSServerConfig struct { InventoryHandle inventory.DownstreamHandle } +type awsClientsGetter struct{} + +func (f *awsClientsGetter) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) { + return awsconfig.GetConfig(ctx, region, optFns...) +} + +func (f *awsClientsGetter) GetAWSEKSClient(cfg aws.Config) EKSClient { + return eks.NewFromConfig(cfg) +} + +func (f *awsClientsGetter) GetAWSSTSPresignClient(cfg aws.Config) STSPresignClient { + stsClient := sts.NewFromConfig(cfg) + return sts.NewPresignClient(stsClient) +} + // CheckAndSetDefaults checks and sets default values func (c *TLSServerConfig) CheckAndSetDefaults() error { if err := c.ForwarderConfig.CheckAndSetDefaults(); err != nil { @@ -142,6 +162,9 @@ func (c *TLSServerConfig) CheckAndSetDefaults() error { } c.CloudClients = cloudClients } + if c.awsClients == nil { + c.awsClients = &awsClientsGetter{} + } if c.ConnectedProxyGetter == nil { c.ConnectedProxyGetter = reversetunnel.NewConnectedProxyGetter() } diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go index 56bea639d5260..fd83ddfd1ad60 100644 --- a/lib/kube/proxy/watcher.go +++ b/lib/kube/proxy/watcher.go @@ -174,6 +174,7 @@ func (m *monitoredKubeClusters) get() map[string]types.KubeCluster { func (s *TLSServer) buildClusterDetailsConfigForCluster(cluster types.KubeCluster) clusterDetailsConfig { return clusterDetailsConfig{ cloudClients: s.CloudClients, + awsCloudClients: s.awsClients, cluster: cluster, log: s.log, checker: s.CheckImpersonationPermissions, diff --git a/lib/kube/utils/eks_token_signed.go b/lib/kube/utils/eks_token_signed.go index 4431cf93dad79..1a1840af888ef 100644 --- a/lib/kube/utils/eks_token_signed.go +++ b/lib/kube/utils/eks_token_signed.go @@ -19,44 +19,64 @@ package utils import ( + "context" "encoding/base64" "time" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" ) +// STSPresignClient is the subset of the STS presign client we need to generate EKS tokens. +type STSPresignClient interface { + PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) +} + // GenAWSEKSToken creates an AWS token to access EKS clusters. // Logic from https://github.com/aws/aws-cli/blob/6c0d168f0b44136fc6175c57c090d4b115437ad1/awscli/customizations/eks/get_token.py#L211-L229 -func GenAWSEKSToken(stsClient stsiface.STSAPI, clusterID string, clock clockwork.Clock) (string, time.Time, error) { +// TODO(@creack): Consolidate with https://github.com/gravitational/teleport/blob/d37da511c944825a47155421bf278777238eecc0/lib/integrations/awsoidc/eks_enroll_clusters.go#L341-L372 +func GenAWSEKSToken(ctx context.Context, stsClient STSPresignClient, clusterID string, clock clockwork.Clock) (string, time.Time, error) { const ( - // The sts GetCallerIdentity request is valid for 15 minutes regardless of this parameters value after it has been - // signed. - requestPresignParam = 60 // The actual token expiration (presigned STS urls are valid for 15 minutes after timestamp in x-amz-date). + expireHeader = "X-Amz-Expires" + expireValue = "60" presignedURLExpiration = 15 * time.Minute v1Prefix = "k8s-aws-v1." clusterIDHeader = "x-k8s-aws-id" ) - // generate an sts:GetCallerIdentity request and add our custom cluster ID header - request, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) - // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date // timestamp regardless. We set it to 60 seconds for backwards compatibility (the // parameter is a required argument to Presign(), and authenticators 0.3.0 and older are expecting a value between // 0 and 60 on the server side). // https://github.com/aws/aws-sdk-go/issues/2167 - presignedURLString, err := request.Presign(requestPresignParam) + presignedReq, err := stsClient.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) { + po.ClientOptions = append(po.ClientOptions, sts.WithAPIOptions(func(stack *middleware.Stack) error { + return stack.Build.Add(middleware.BuildMiddlewareFunc("AddEKSId", func( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, + ) (middleware.BuildOutput, middleware.Metadata, error) { + switch req := in.Request.(type) { + case *smithyhttp.Request: + query := req.URL.Query() + query.Add(expireHeader, expireValue) + req.URL.RawQuery = query.Encode() + + req.Header.Add(clusterIDHeader, clusterID) + } + return next.HandleBuild(ctx, in) + }), middleware.Before) + })) + }) if err != nil { return "", time.Time{}, trace.Wrap(err) } - // Set token expiration to 1 minute before the presigned URL expires for some cushion + // Set token expiration to 1 minute before the presigned URL expires for some cushion. tokenExpiration := clock.Now().Add(presignedURLExpiration - 1*time.Minute) - return v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration, nil + return v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedReq.URL)), tokenExpiration, nil } diff --git a/lib/srv/db/cloud/iam_test.go b/lib/srv/db/cloud/iam_test.go index d13d1fc74b86c..c3b9ecf3dd716 100644 --- a/lib/srv/db/cloud/iam_test.go +++ b/lib/srv/db/cloud/iam_test.go @@ -416,6 +416,7 @@ func (m *mockAccessPoint) GetClusterName(opts ...services.MarshalOption) (types. ClusterID: "cluster-id", }) } + func (m *mockAccessPoint) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) { return &types.SemaphoreLease{ SemaphoreKind: params.SemaphoreKind, @@ -424,6 +425,7 @@ func (m *mockAccessPoint) AcquireSemaphore(ctx context.Context, params types.Acq Expires: params.Expires, }, nil } + func (m *mockAccessPoint) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error { return nil } diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index ae136b4d53c46..63d79af27e500 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -957,8 +957,7 @@ func generateAzureVM(t *testing.T, identities []string) armcompute.VirtualMachin } // authClientMock is a mock that implements AuthClient interface. -type authClientMock struct { -} +type authClientMock struct{} // GenerateDatabaseCert generates a cert using fixtures TLS CA. func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { @@ -996,8 +995,7 @@ func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.Da }, nil } -type accessPointMock struct { -} +type accessPointMock struct{} // GetAuthPreference always returns types.DefaultAuthPreference(). func (m accessPointMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph.go index 4bc207b21df01..9d6d344ac9fda 100644 --- a/lib/srv/discovery/access_graph.go +++ b/lib/srv/discovery/access_graph.go @@ -502,6 +502,7 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M ctx, aws_sync.Config{ CloudClients: s.CloudClients, + GetEKSClient: s.GetAWSSyncEKSClient, GetEC2Client: s.GetEC2Client, AssumeRole: assumeRole, Regions: awsFetcher.Regions, diff --git a/lib/srv/discovery/common/kubernetes.go b/lib/srv/discovery/common/kubernetes.go index 9c383a6213fda..1bddd210493da 100644 --- a/lib/srv/discovery/common/kubernetes.go +++ b/lib/srv/discovery/common/kubernetes.go @@ -24,7 +24,6 @@ import ( "strings" "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go/aws" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" @@ -40,7 +39,7 @@ func setAWSKubeName(meta types.Metadata, firstNamePart string, extraNameParts .. } // NewKubeClusterFromAWSEKS creates a kube_cluster resource from an EKS cluster. -func NewKubeClusterFromAWSEKS(clusterName, clusterArn string, tags map[string]*string) (types.KubeCluster, error) { +func NewKubeClusterFromAWSEKS(clusterName, clusterArn string, tags map[string]string) (types.KubeCluster, error) { parsedARN, err := arn.Parse(clusterArn) if err != nil { return nil, trace.Wrap(err) @@ -64,7 +63,7 @@ func NewKubeClusterFromAWSEKS(clusterName, clusterArn string, tags map[string]*s } // labelsFromAWSKubeClusterTags creates kube cluster labels. -func labelsFromAWSKubeClusterTags(tags map[string]*string, parsedARN arn.ARN) map[string]string { +func labelsFromAWSKubeClusterTags(tags map[string]string, parsedARN arn.ARN) map[string]string { labels := awsEKSTagsToLabels(tags) labels[types.CloudLabel] = types.CloudAWS labels[types.DiscoveryLabelRegion] = parsedARN.Region @@ -74,11 +73,11 @@ func labelsFromAWSKubeClusterTags(tags map[string]*string, parsedARN arn.ARN) ma } // awsEKSTagsToLabels converts AWS tags to a labels map. -func awsEKSTagsToLabels(tags map[string]*string) map[string]string { +func awsEKSTagsToLabels(tags map[string]string) map[string]string { labels := make(map[string]string) for key, val := range tags { if types.IsValidLabelKey(key) { - labels[key] = aws.StringValue(val) + labels[key] = val } else { slog.DebugContext(context.Background(), "Skipping EKS tag that is not a valid label key", "tag", key) } diff --git a/lib/srv/discovery/common/kubernetes_test.go b/lib/srv/discovery/common/kubernetes_test.go index b121c624a1e76..868f9dfac9370 100644 --- a/lib/srv/discovery/common/kubernetes_test.go +++ b/lib/srv/discovery/common/kubernetes_test.go @@ -20,8 +20,8 @@ import ( "testing" "cloud.google.com/go/container/apiv1/containerpb" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go-v2/aws" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -97,16 +97,16 @@ func TestNewKubeClusterFromAWSEKS(t *testing.T) { }) require.NoError(t, err) - cluster := &eks.Cluster{ + cluster := &ekstypes.Cluster{ Name: aws.String("cluster1"), Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - overrideLabel: aws.String("override-1"), - "env": aws.String("prod"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + overrideLabel: "override-1", + "env": "prod", }, } - actual, err := NewKubeClusterFromAWSEKS(aws.StringValue(cluster.Name), aws.StringValue(cluster.Arn), cluster.Tags) + actual, err := NewKubeClusterFromAWSEKS(aws.ToString(cluster.Name), aws.ToString(cluster.Arn), cluster.Tags) require.NoError(t, err) require.Empty(t, cmp.Diff(expected, actual)) require.NoError(t, err) diff --git a/lib/srv/discovery/common/renaming_test.go b/lib/srv/discovery/common/renaming_test.go index b01825725f672..5be2c13f3b3c4 100644 --- a/lib/srv/discovery/common/renaming_test.go +++ b/lib/srv/discovery/common/renaming_test.go @@ -27,8 +27,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/rds" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -498,12 +498,12 @@ func labelsToAzureTags(labels map[string]string) map[string]*string { func makeEKSKubeCluster(t *testing.T, name, region, accountID, overrideLabel string) types.KubeCluster { t.Helper() - eksCluster := &eks.Cluster{ + eksCluster := &ekstypes.Cluster{ Name: aws.String(name), Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - overrideLabel: aws.String(name), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + overrideLabel: name, }, } kubeCluster, err := NewKubeClusterFromAWSEKS(aws.StringValue(eksCluster.Name), aws.StringValue(eksCluster.Arn), eksCluster.Tags) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index f37ba025d2450..047553edeabde 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -32,8 +32,10 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/session" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -115,10 +117,18 @@ type gcpInstaller interface { type Config struct { // CloudClients is an interface for retrieving cloud clients. CloudClients cloud.Clients + + // AWSFetchersClients gets the AWS clients for the given region for the fetchers. + AWSFetchersClients fetchers.AWSClientGetter + + // GetAWSSyncEKSClient gets an AWS EKS client for the given region for fetchers/aws-sync. + GetAWSSyncEKSClient aws_sync.EKSClientGetter + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider // AWSDatabaseFetcherFactory provides AWS database fetchers AWSDatabaseFetcherFactory *db.AWSFetcherFactory + // GetEC2Client gets an AWS EC2 client for the given region. GetEC2Client server.EC2ClientGetter // GetSSMClient gets an AWS SSM client for the given region. @@ -196,6 +206,23 @@ type AccessGraphConfig struct { Insecure bool } +type awsFetchersClientsGetter struct { + awsconfig.Provider +} + +func (f *awsFetchersClientsGetter) GetAWSEKSClient(cfg aws.Config) fetchers.EKSClient { + return eks.NewFromConfig(cfg) +} + +func (f *awsFetchersClientsGetter) GetAWSSTSClient(cfg aws.Config) fetchers.STSClient { + return sts.NewFromConfig(cfg) +} + +func (f *awsFetchersClientsGetter) GetAWSSTSPresignClient(cfg aws.Config) fetchers.STSPresignClient { + stsClient := sts.NewFromConfig(cfg) + return sts.NewPresignClient(stsClient) +} + func (c *Config) CheckAndSetDefaults() error { if c.Matchers.IsEmpty() && c.DiscoveryGroup == "" { return trace.BadParameter("no matchers or discovery group configured for discovery") @@ -253,6 +280,20 @@ kubernetes matchers are present.`) return ec2.NewFromConfig(cfg), nil } } + if c.AWSFetchersClients == nil { + c.AWSFetchersClients = &awsFetchersClientsGetter{ + Provider: awsconfig.ProviderFunc(c.getAWSConfig), + } + } + if c.GetAWSSyncEKSClient == nil { + c.GetAWSSyncEKSClient = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws_sync.EKSClient, error) { + cfg, err := c.getAWSConfig(ctx, region, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + return eks.NewFromConfig(cfg), nil + } + } if c.GetSSMClient == nil { c.GetSSMClient = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) { cfg, err := c.getAWSConfig(ctx, region, opts...) @@ -561,7 +602,7 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error { _, otherMatchers = splitMatchers(otherMatchers, db.IsAWSMatcherType) // Add non-integration kube fetchers. - kubeFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.CloudClients, otherMatchers, noDiscoveryConfig) + kubeFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.AWSFetchersClients, otherMatchers, noDiscoveryConfig) if err != nil { return trace.Wrap(err) } @@ -714,12 +755,12 @@ func (s *Server) databaseFetchersFromMatchers(matchers Matchers, discoveryConfig func (s *Server) kubeFetchersFromMatchers(matchers Matchers, discoveryConfigName string) ([]common.Fetcher, error) { var result []common.Fetcher - // AWS + // AWS. awsKubeMatchers, _ := splitMatchers(matchers.AWS, func(matcherType string) bool { return matcherType == types.AWSMatcherEKS }) if len(awsKubeMatchers) > 0 { - eksFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.CloudClients, awsKubeMatchers, discoveryConfigName) + eksFetchers, err := fetchers.MakeEKSFetchersFromAWSMatchers(s.Log, s.AWSFetchersClients, awsKubeMatchers, discoveryConfigName) if err != nil { return nil, trace.Wrap(err) } @@ -1264,7 +1305,6 @@ func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) erro _, vmOK := labels[types.VMIDLabel] return subscriptionOK && vmOK }) - if err != nil { return trace.Wrap(err) } @@ -1357,7 +1397,6 @@ func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) error { _, nameOK := labels[types.NameLabelDiscovery] return projectIDOK && zoneOK && nameOK }) - if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 865517ba4c33c..3eea560f67174 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -36,17 +36,15 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/eks" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/rds" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -86,6 +84,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" + "github.com/gravitational/teleport/lib/srv/discovery/fetchers" "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/srv/server" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" @@ -175,10 +174,10 @@ func genEC2Instances(n int) []ec2types.Instance { var ec2Instances []ec2types.Instance for _, id := range genEC2InstanceIDs(n) { ec2Instances = append(ec2Instances, ec2types.Instance{ - InstanceId: awsv2.String(id), + InstanceId: aws.String(id), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -324,11 +323,12 @@ func TestDiscoveryServer(t *testing.T) { tcs := []struct { name string - // presentInstances is a list of servers already present in teleport + // presentInstances is a list of servers already present in teleport. presentInstances []types.Server foundEC2Instances []ec2types.Instance ssm *mockSSMClient emitter *mockEmitter + eksClusters []*ekstypes.Cluster eksEnroller eksClustersEnroller discoveryConfig *discoveryconfig.DiscoveryConfig staticMatchers Matchers @@ -339,14 +339,14 @@ func TestDiscoveryServer(t *testing.T) { ssmRunError error }{ { - name: "no nodes present, 1 found ", + name: "no nodes present, 1 found", presentInstances: []types.Server{}, foundEC2Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("instance-id-1"), + InstanceId: aws.String("instance-id-1"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -356,7 +356,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -401,10 +401,10 @@ func TestDiscoveryServer(t *testing.T) { }, foundEC2Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("instance-id-1"), + InstanceId: aws.String("instance-id-1"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -414,7 +414,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -442,10 +442,10 @@ func TestDiscoveryServer(t *testing.T) { }, foundEC2Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("instance-id-1"), + InstanceId: aws.String("instance-id-1"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -455,7 +455,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -474,7 +474,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -491,10 +491,10 @@ func TestDiscoveryServer(t *testing.T) { presentInstances: []types.Server{}, foundEC2Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("instance-id-1"), + InstanceId: aws.String("instance-id-1"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -504,7 +504,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -538,10 +538,10 @@ func TestDiscoveryServer(t *testing.T) { presentInstances: []types.Server{}, foundEC2Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("instance-id-1"), + InstanceId: aws.String("instance-id-1"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -551,7 +551,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -625,10 +625,10 @@ func TestDiscoveryServer(t *testing.T) { presentInstances: []types.Server{}, foundEC2Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("instance-id-1"), + InstanceId: aws.String("instance-id-1"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -638,7 +638,7 @@ func TestDiscoveryServer(t *testing.T) { ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ Command: &ssmtypes.Command{ - CommandId: awsv2.String("command-id-1"), + CommandId: aws.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ @@ -667,7 +667,7 @@ func TestDiscoveryServer(t *testing.T) { staticMatchers: Matchers{}, discoveryConfig: discoveryConfigForUserTaskEC2Test, wantInstalledInstances: []string{}, - userTasksDiscoverCheck: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + userTasksDiscoverCheck: func(t require.TestingT, i1 interface{}, i2 ...interface{}) { existingTasks, ok := i1.([]*usertasksv1.UserTask) require.True(t, ok, "failed to get existing tasks: %T", i1) require.Len(t, existingTasks, 1) @@ -693,26 +693,21 @@ func TestDiscoveryServer(t *testing.T) { presentInstances: []types.Server{}, foundEC2Instances: []ec2types.Instance{}, ssm: &mockSSMClient{}, - cloudClients: &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - EKS: &mocks.EKSMock{ - Clusters: []*eks.Cluster{ - { - Name: aws.String("cluster01"), - Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster01"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "RunDiscover": aws.String("Please"), - }, - }, - { - Name: aws.String("cluster02"), - Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster02"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "RunDiscover": aws.String("Please"), - }, - }, + eksClusters: []*ekstypes.Cluster{ + { + Name: aws.String("cluster01"), + Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster01"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "RunDiscover": "Please", + }, + }, + { + Name: aws.String("cluster02"), + Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster02"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "RunDiscover": "Please", }, }, }, @@ -737,7 +732,7 @@ func TestDiscoveryServer(t *testing.T) { staticMatchers: Matchers{}, discoveryConfig: discoveryConfigForUserTaskEKSTest, wantInstalledInstances: []string{}, - userTasksDiscoverCheck: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + userTasksDiscoverCheck: func(t require.TestingT, i1 interface{}, i2 ...interface{}) { existingTasks, ok := i1.([]*usertasksv1.UserTask) require.True(t, ok, "failed to get existing tasks: %T", i1) require.Len(t, existingTasks, 1) @@ -761,20 +756,21 @@ func TestDiscoveryServer(t *testing.T) { } for _, tc := range tcs { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + ctx := context.Background() - ec2Client := &mockEC2Client{output: &ec2.DescribeInstancesOutput{ - Reservations: []ec2types.Reservation{ - { - OwnerId: awsv2.String("owner"), - Instances: tc.foundEC2Instances, + ec2Client := &mockEC2Client{ + output: &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{ + { + OwnerId: aws.String("owner"), + Instances: tc.foundEC2Instances, + }, }, }, - }} + } - ctx := context.Background() // Create and start test auth server. testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ Dir: t.TempDir(), @@ -782,9 +778,24 @@ func TestDiscoveryServer(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) }) + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{ + Name: "my-integration", + }, &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:iam::123456789012:role/teleport", + }) + require.NoError(t, err) + testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{ + proxies: nil, + integrations: map[string]types.Integration{ + awsOIDCIntegration.GetName(): awsOIDCIntegration, + }, + } + tlsServer, err := testAuthServer.NewTestTLSServer() require.NoError(t, err) t.Cleanup(func() { require.NoError(t, tlsServer.Close()) }) + _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration) + require.NoError(t, err) // Auth client for discovery service. identity := auth.TestServerID(types.RoleDiscovery, "hostID") @@ -816,6 +827,9 @@ func TestDiscoveryServer(t *testing.T) { eksEnroller = tc.eksEnroller } + fakeConfigProvider := mocks.AWSConfigProvider{ + OIDCIntegrationClient: tlsServer.Auth(), + } server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ GetEC2Client: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) { return ec2Client, nil @@ -823,6 +837,11 @@ func TestDiscoveryServer(t *testing.T) { GetSSMClient: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) { return tc.ssm, nil }, + AWSConfigProvider: &fakeConfigProvider, + AWSFetchersClients: &mockFetchersClients{ + AWSConfigProvider: fakeConfigProvider, + eksClusters: tc.eksClusters, + }, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPointWithEKSEnroller(tlsServer.Auth(), authClient, eksEnroller), @@ -916,20 +935,20 @@ func TestDiscoveryServerConcurrency(t *testing.T) { output: &ec2.DescribeInstancesOutput{ Reservations: []ec2types.Reservation{ { - OwnerId: awsv2.String("123456789012"), + OwnerId: aws.String("123456789012"), Instances: []ec2types.Instance{ { - InstanceId: awsv2.String("i-123456789012"), + InstanceId: aws.String("i-123456789012"), Tags: []ec2types.Tag{ { - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }, }, - PrivateIpAddress: awsv2.String("172.0.1.2"), - VpcId: awsv2.String("vpcId"), - SubnetId: awsv2.String("subnetId"), - PrivateDnsName: awsv2.String("privateDnsName"), + PrivateIpAddress: aws.String("172.0.1.2"), + VpcId: aws.String("vpcId"), + SubnetId: aws.String("subnetId"), + PrivateDnsName: aws.String("privateDnsName"), State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, }, @@ -1212,11 +1231,12 @@ func TestDiscoveryKubeServices(t *testing.T) { } func TestDiscoveryInCloudKube(t *testing.T) { + t.Parallel() + const ( mainDiscoveryGroup = "main" otherDiscoveryGroup = "other" ) - t.Parallel() tcs := []struct { name string existingKubeClusters []types.KubeCluster @@ -1440,15 +1460,11 @@ func TestDiscoveryInCloudKube(t *testing.T) { } for _, tc := range tcs { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - sts := &mocks.STSClientV1{} testCloudClients := &cloud.TestCloudClients{ - STS: sts, AzureAKSClient: newPopulatedAKSMock(), - EKS: newPopulatedEKSMock(), GCPGKE: newPopulatedGCPMock(), GCPProjects: newPopulatedGCPProjectsMock(), } @@ -1475,7 +1491,7 @@ func TestDiscoveryInCloudKube(t *testing.T) { err := tlsServer.Auth().CreateKubernetesCluster(ctx, kubeCluster) require.NoError(t, err) } - // we analyze the logs emitted by discovery service to detect clusters that were not updated + // We analyze the logs emitted by discovery service to detect clusters that were not updated // because their state didn't change. r, w := io.Pipe() t.Cleanup(func() { @@ -1506,15 +1522,26 @@ func TestDiscoveryInCloudKube(t *testing.T) { } } }() + reporter := &mockUsageReporter{} tlsServer.Auth().SetUsageReporter(reporter) + + mockedClients := &mockFetchersClients{ + AWSConfigProvider: mocks.AWSConfigProvider{ + STSClient: &mocks.STSClient{}, + OIDCIntegrationClient: newFakeAccessPoint(), + }, + eksClusters: newPopulatedEKSMock().clusters, + } + discServer, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ - CloudClients: testCloudClients, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + CloudClients: testCloudClients, + AWSFetchersClients: mockedClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), Matchers: Matchers{ AWS: tc.awsMatchers, Azure: tc.azureMatchers, @@ -1524,12 +1551,9 @@ func TestDiscoveryInCloudKube(t *testing.T) { Log: logger, DiscoveryGroup: mainDiscoveryGroup, }) - require.NoError(t, err) - t.Cleanup(func() { - discServer.Stop() - }) + t.Cleanup(discServer.Stop) go discServer.Start() clustersNotUpdatedMap := sliceToSet(tc.clustersNotUpdated) @@ -1562,8 +1586,8 @@ func TestDiscoveryInCloudKube(t *testing.T) { return len(clustersNotUpdated) == 0 && clustersFoundInAuth }, 5*time.Second, 200*time.Millisecond) - require.ElementsMatch(t, tc.expectedAssumedRoles, sts.GetAssumedRoleARNs(), "roles incorrectly assumed") - require.ElementsMatch(t, tc.expectedExternalIDs, sts.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed") + require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.STSClient.GetAssumedRoleARNs(), "roles incorrectly assumed") + require.ElementsMatch(t, tc.expectedExternalIDs, mockedClients.STSClient.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed") if tc.wantEvents > 0 { require.Eventually(t, func() bool { @@ -1582,14 +1606,15 @@ func TestDiscoveryServer_New(t *testing.T) { t.Parallel() testCases := []struct { desc string - cloudClients cloud.Clients + cloudClients fetchers.AWSClientGetter matchers Matchers errAssertion require.ErrorAssertionFunc discServerAssertion require.ValueAssertionFunc }{ { - desc: "no matchers error", - cloudClients: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, + desc: "no matchers error", + + cloudClients: &mockFetchersClients{}, matchers: Matchers{}, errAssertion: func(t require.TestingT, err error, i ...interface{}) { require.ErrorIs(t, err, &trace.BadParameterError{Message: "no matchers or discovery group configured for discovery"}) @@ -1597,8 +1622,10 @@ func TestDiscoveryServer_New(t *testing.T) { discServerAssertion: require.Nil, }, { - desc: "success with EKS matcher", - cloudClients: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}, EKS: &mocks.EKSMock{}}, + desc: "success with EKS matcher", + + cloudClients: &mockFetchersClients{}, + matchers: Matchers{ AWS: []types.AWSMatcher{ { @@ -1621,11 +1648,8 @@ func TestDiscoveryServer_New(t *testing.T) { }, }, { - desc: "EKS fetcher is skipped on initialization error (missing region)", - cloudClients: &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - EKS: &mocks.EKSMock{}, - }, + desc: "EKS fetcher is skipped on initialization error (missing region)", + cloudClients: &mockFetchersClients{}, matchers: Matchers{ AWS: []types.AWSMatcher{ { @@ -1666,12 +1690,12 @@ func TestDiscoveryServer_New(t *testing.T) { discServer, err := New( ctx, &Config{ - CloudClients: tt.cloudClients, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - AccessPoint: newFakeAccessPoint(), - Matchers: tt.matchers, - Emitter: &mockEmitter{}, - protocolChecker: &noopProtocolChecker{}, + AWSFetchersClients: tt.cloudClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + AccessPoint: newFakeAccessPoint(), + Matchers: tt.matchers, + Emitter: &mockEmitter{}, + protocolChecker: &noopProtocolChecker{}, }) tt.errAssertion(t, err) @@ -1759,28 +1783,33 @@ var aksMockClusters = map[string][]*azure.AKSCluster{ } type mockEKSAPI struct { - eksiface.EKSAPI - clusters []*eks.Cluster + fetchers.EKSClient + clusters []*ekstypes.Cluster } -func (m *mockEKSAPI) ListClustersPagesWithContext(ctx aws.Context, req *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error { - var names []*string +func (m *mockEKSAPI) ListClusters(ctx context.Context, req *eks.ListClustersInput, _ ...func(*eks.Options)) (*eks.ListClustersOutput, error) { + var names []string for _, cluster := range m.clusters { - names = append(names, cluster.Name) + names = append(names, aws.ToString(cluster.Name)) } - f(&eks.ListClustersOutput{ - Clusters: names[:len(names)/2], - }, false) - f(&eks.ListClustersOutput{ + // First call, no NextToken. Return first half and a NextToken value. + if req.NextToken == nil { + return &eks.ListClustersOutput{ + Clusters: names[:len(names)/2], + NextToken: aws.String("next"), + }, nil + } + + // Second call, we have a NextToken, return the second half. + return &eks.ListClustersOutput{ Clusters: names[len(names)/2:], - }, true) - return nil + }, nil } -func (m *mockEKSAPI) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) { +func (m *mockEKSAPI) DescribeCluster(_ context.Context, req *eks.DescribeClusterInput, _ ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) { for _, cluster := range m.clusters { - if aws.StringValue(cluster.Name) == aws.StringValue(req.Name) { + if aws.ToString(cluster.Name) == aws.ToString(req.Name) { return &eks.DescribeClusterOutput{ Cluster: cluster, }, nil @@ -1795,48 +1824,70 @@ func newPopulatedEKSMock() *mockEKSAPI { } } -var eksMockClusters = []*eks.Cluster{ +type mockFetchersClients struct { + mocks.AWSConfigProvider + eksClusters []*ekstypes.Cluster +} + +func (m *mockFetchersClients) GetAWSEKSClient(aws.Config) fetchers.EKSClient { + return &mockEKSAPI{ + clusters: m.eksClusters, + } +} + +func (m *mockFetchersClients) GetAWSSTSClient(aws.Config) fetchers.STSClient { + if m.AWSConfigProvider.STSClient != nil { + return m.AWSConfigProvider.STSClient + } + return &mocks.STSClient{} +} + +func (m *mockFetchersClients) GetAWSSTSPresignClient(aws.Config) fetchers.STSPresignClient { + return nil +} + +var eksMockClusters = []*ekstypes.Cluster{ { Name: aws.String("eks-cluster1"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("prod"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "prod", + "location": "eu-west-1", }, }, { Name: aws.String("eks-cluster2"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster2"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("prod"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "prod", + "location": "eu-west-1", }, }, { Name: aws.String("eks-cluster3"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster3"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("stg"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "stg", + "location": "eu-west-1", }, }, { Name: aws.String("eks-cluster4"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("stg"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "stg", + "location": "eu-west-1", }, }, } -func mustConvertEKSToKubeCluster(t *testing.T, eksCluster *eks.Cluster, discoveryParams rewriteDiscoveryLabelsParams) types.KubeCluster { - cluster, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(eksCluster.Name), aws.StringValue(eksCluster.Arn), eksCluster.Tags) +func mustConvertEKSToKubeCluster(t *testing.T, eksCluster *ekstypes.Cluster, discoveryParams rewriteDiscoveryLabelsParams) types.KubeCluster { + cluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksCluster.Name), aws.ToString(eksCluster.Arn), eksCluster.Tags) require.NoError(t, err) discoveryParams.matcherType = types.AWSMatcherEKS rewriteCloudResource(t, cluster, discoveryParams) @@ -2027,9 +2078,6 @@ func TestDiscoveryDatabase(t *testing.T) { &azure.ARMRedisEnterpriseClusterMock{}, &azure.ARMRedisEnterpriseDatabaseMock{}, ), - EKS: &mocks.EKSMock{ - Clusters: []*eks.Cluster{eksAWSResource}, - }, } tcs := []struct { @@ -2303,7 +2351,6 @@ func TestDiscoveryDatabase(t *testing.T) { } for _, tc := range tcs { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -2370,12 +2417,16 @@ func TestDiscoveryDatabase(t *testing.T) { authz.ContextWithUser(ctx, identity.I), &Config{ IntegrationOnlyCredentials: integrationOnlyCredential, - CloudClients: testCloudClients, - AWSDatabaseFetcherFactory: dbFetcherFactory, - AWSConfigProvider: fakeConfigProvider, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: accessPoint, + AWSFetchersClients: &mockFetchersClients{ + AWSConfigProvider: *fakeConfigProvider, + eksClusters: []*ekstypes.Cluster{eksAWSResource}, + }, + CloudClients: testCloudClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + AWSDatabaseFetcherFactory: dbFetcherFactory, + AWSConfigProvider: fakeConfigProvider, Matchers: Matchers{ AWS: tc.awsMatchers, Azure: tc.azureMatchers, @@ -2420,7 +2471,7 @@ func TestDiscoveryDatabase(t *testing.T) { cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"), )) case <-time.After(time.Second): - t.Fatal("Didn't receive reconcile event after 1s") + require.FailNow(t, "Didn't receive reconcile event after 1s") } if tc.wantEvents > 0 { @@ -2601,17 +2652,17 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { }) } -func makeEKSCluster(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*eks.Cluster, types.KubeCluster) { +func makeEKSCluster(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*ekstypes.Cluster, types.KubeCluster) { t.Helper() - eksAWSCluster := &eks.Cluster{ + eksAWSCluster := &ekstypes.Cluster{ Name: aws.String(name), Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:123456789012:cluster/%s", region, name)), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("prod"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "prod", }, } - actual, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(eksAWSCluster.Name), aws.StringValue(eksAWSCluster.Arn), eksAWSCluster.Tags) + actual, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksAWSCluster.Name), aws.ToString(eksAWSCluster.Arn), eksAWSCluster.Tags) require.NoError(t, err) discoveryParams.matcherType = types.AWSMatcherEKS rewriteCloudResource(t, actual, discoveryParams) @@ -2986,6 +3037,7 @@ func (m *mockGCPClient) getVMSForProject(projectID string) []*gcpimds.Instance { } return vms } + func (m *mockGCPClient) ListInstances(_ context.Context, projectID, _ string) ([]*gcpimds.Instance, error) { return m.getVMSForProject(projectID), nil } @@ -3697,7 +3749,7 @@ func newPopulatedGCPProjectsMock() *mockProjectsAPI { } func newFakeRedshiftClientProvider(c redshift.DescribeClustersAPIClient) db.RedshiftClientProviderFunc { - return func(cfg awsv2.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { + return func(cfg aws.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { return c } } diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go index 2a7e928370091..adc450ece9fbc 100644 --- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go +++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go @@ -47,6 +47,8 @@ const pageSize int64 = 500 type Config struct { // CloudClients is the cloud clients to use when fetching AWS resources. CloudClients cloud.Clients + // GetEKSClient gets an AWS EKS client for the given region. + GetEKSClient EKSClientGetter // GetEC2Client gets an AWS EC2 client for the given region. GetEC2Client server.EC2ClientGetter // AccountID is the AWS account ID to use when fetching resources. diff --git a/lib/srv/discovery/fetchers/aws-sync/eks.go b/lib/srv/discovery/fetchers/aws-sync/eks.go index e4a7cc768ecd2..fc1791b4cb13a 100644 --- a/lib/srv/discovery/fetchers/aws-sync/eks.go +++ b/lib/srv/discovery/fetchers/aws-sync/eks.go @@ -22,16 +22,32 @@ import ( "context" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/eks" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/gravitational/trace" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/cloud/awsconfig" ) +// EKSClientGetter returns an EKS client for aws-sync. +type EKSClientGetter func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (EKSClient, error) + +// EKSClient is the subset of the EKS interface we use in aws-sync. +type EKSClient interface { + eks.ListClustersAPIClient + eks.DescribeClusterAPIClient + + eks.ListAccessEntriesAPIClient + DescribeAccessEntry(ctx context.Context, params *eks.DescribeAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DescribeAccessEntryOutput, error) + + eks.ListAssociatedAccessPoliciesAPIClient +} + // pollAWSEKSClusters is a function that returns a function that fetches // eks clusters and their access scope levels. func (a *awsFetcher) pollAWSEKSClusters(ctx context.Context, result *Resources, collectErr func(error)) func() error { @@ -70,7 +86,8 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust collectClusters := func(cluster *accessgraphv1alpha.AWSEKSClusterV1, clusterAssociatedPolicies []*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1, clusterAccessEntries []*accessgraphv1alpha.AWSEKSClusterAccessEntryV1, - err error) { + err error, + ) { hostsMu.Lock() defer hostsMu.Unlock() if err != nil { @@ -86,41 +103,34 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust for _, region := range a.Regions { region := region eG.Go(func() error { - eksClient, err := a.CloudClients.GetAWSEKSClient(ctx, region, a.getAWSOptions()...) + eksClient, err := a.GetEKSClient(ctx, region, a.getAWSV2Options()...) if err != nil { collectClusters(nil, nil, nil, trace.Wrap(err)) return nil } var eksClusterNames []string - // ListClustersPagesWithContext returns a list of EKS cluster names existing in the region. - err = eksClient.ListClustersPagesWithContext( - ctx, - &eks.ListClustersInput{}, - func(output *eks.ListClustersOutput, lastPage bool) bool { - for _, cluster := range output.Clusters { - eksClusterNames = append(eksClusterNames, aws.StringValue(cluster)) - } - return !lastPage - - }, - ) - if err != nil { - oldEKSClusters := sliceFilter(existing.EKSClusters, func(cluster *accessgraphv1alpha.AWSEKSClusterV1) bool { - return cluster.Region == region && cluster.AccountId == a.AccountID - }) - oldAccessEntries := sliceFilter(existing.AccessEntries, func(ae *accessgraphv1alpha.AWSEKSClusterAccessEntryV1) bool { - return ae.Cluster.Region == region && ae.AccountId == a.AccountID - }) - oldAssociatedPolicies := sliceFilter(existing.AssociatedAccessPolicies, func(ap *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1) bool { - return ap.Cluster.Region == region && ap.AccountId == a.AccountID - }) - hostsMu.Lock() - output.clusters = append(output.clusters, oldEKSClusters...) - output.associatedPolicies = append(output.associatedPolicies, oldAssociatedPolicies...) - output.accessEntry = append(output.accessEntry, oldAccessEntries...) - hostsMu.Unlock() + for p := eks.NewListClustersPaginator(eksClient, nil); p.HasMorePages(); { + out, err := p.NextPage(ctx) + if err != nil { + oldEKSClusters := sliceFilter(existing.EKSClusters, func(cluster *accessgraphv1alpha.AWSEKSClusterV1) bool { + return cluster.Region == region && cluster.AccountId == a.AccountID + }) + oldAccessEntries := sliceFilter(existing.AccessEntries, func(ae *accessgraphv1alpha.AWSEKSClusterAccessEntryV1) bool { + return ae.Cluster.Region == region && ae.AccountId == a.AccountID + }) + oldAssociatedPolicies := sliceFilter(existing.AssociatedAccessPolicies, func(ap *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1) bool { + return ap.Cluster.Region == region && ap.AccountId == a.AccountID + }) + hostsMu.Lock() + output.clusters = append(output.clusters, oldEKSClusters...) + output.associatedPolicies = append(output.associatedPolicies, oldAssociatedPolicies...) + output.accessEntry = append(output.accessEntry, oldAccessEntries...) + hostsMu.Unlock() + break + } + eksClusterNames = append(eksClusterNames, out.Clusters...) } for _, cluster := range eksClusterNames { @@ -134,7 +144,7 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust return ap.Cluster.Name == cluster && ap.AccountId == a.AccountID && ap.Cluster.Region == region }) // DescribeClusterWithContext retrieves the cluster details. - cluster, err := eksClient.DescribeClusterWithContext(ctx, &eks.DescribeClusterInput{ + cluster, err := eksClient.DescribeCluster(ctx, &eks.DescribeClusterInput{ Name: aws.String(cluster), }, ) @@ -147,7 +157,7 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust // if eks cluster only allows CONFIGMAP auth, skip polling of access entries and // associated policies. if cluster.Cluster != nil && cluster.Cluster.AccessConfig != nil && - aws.StringValue(cluster.Cluster.AccessConfig.AuthenticationMode) == eks.AuthenticationModeConfigMap { + cluster.Cluster.AccessConfig.AuthenticationMode == ekstypes.AuthenticationModeConfigMap { collectClusters(protoCluster, nil, nil, nil) continue } @@ -181,20 +191,20 @@ func (a *awsFetcher) fetchAWSSEKSClusters(ctx context.Context) (fetchAWSEKSClust // awsEKSClusterToProtoCluster converts an eks.Cluster to accessgraphv1alpha.AWSEKSClusterV1 // representation. -func awsEKSClusterToProtoCluster(cluster *eks.Cluster, region, accountID string) *accessgraphv1alpha.AWSEKSClusterV1 { +func awsEKSClusterToProtoCluster(cluster *ekstypes.Cluster, region, accountID string) *accessgraphv1alpha.AWSEKSClusterV1 { var tags []*accessgraphv1alpha.AWSTag for k, v := range cluster.Tags { tags = append(tags, &accessgraphv1alpha.AWSTag{ Key: k, - Value: strPtrToWrapper(v), + Value: wrapperspb.String(v), }) } return &accessgraphv1alpha.AWSEKSClusterV1{ - Name: aws.StringValue(cluster.Name), - Arn: aws.StringValue(cluster.Arn), + Name: aws.ToString(cluster.Name), + Arn: aws.ToString(cluster.Arn), CreatedAt: awsTimeToProtoTime(cluster.CreatedAt), - Status: aws.StringValue(cluster.Status), + Status: string(cluster.Status), Region: region, AccountId: accountID, Tags: tags, @@ -203,33 +213,23 @@ func awsEKSClusterToProtoCluster(cluster *eks.Cluster, region, accountID string) } // fetchAccessEntries fetches the access entries for the given cluster. -func (a *awsFetcher) fetchAccessEntries(ctx context.Context, eksClient eksiface.EKSAPI, cluster *accessgraphv1alpha.AWSEKSClusterV1) ([]*accessgraphv1alpha.AWSEKSClusterAccessEntryV1, error) { +func (a *awsFetcher) fetchAccessEntries(ctx context.Context, eksClient EKSClient, cluster *accessgraphv1alpha.AWSEKSClusterV1) ([]*accessgraphv1alpha.AWSEKSClusterAccessEntryV1, error) { var accessEntries []string - var errs []error - err := eksClient.ListAccessEntriesPagesWithContext( - ctx, - &eks.ListAccessEntriesInput{ - ClusterName: aws.String(cluster.Name), - }, - func(output *eks.ListAccessEntriesOutput, lastPage bool) bool { - for _, accessEntry := range output.AccessEntries { - if aws.StringValue(accessEntry) == "" { - continue - } - accessEntries = append(accessEntries, aws.StringValue(accessEntry)) - } - return !lastPage - }, - ) - if err != nil { - errs = append(errs, trace.Wrap(err)) - return nil, trace.NewAggregate(errs...) + for p := eks.NewListAccessEntriesPaginator(eksClient, + &eks.ListAccessEntriesInput{ClusterName: aws.String(cluster.Name)}, + ); p.HasMorePages(); { + out, err := p.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + accessEntries = append(accessEntries, out.AccessEntries...) } + var errs []error var protoAccessEntries []*accessgraphv1alpha.AWSEKSClusterAccessEntryV1 for _, accessEntry := range accessEntries { - rsp, err := eksClient.DescribeAccessEntryWithContext( + rsp, err := eksClient.DescribeAccessEntry( ctx, &eks.DescribeAccessEntryInput{ PrincipalArn: aws.String(accessEntry), @@ -247,84 +247,81 @@ func (a *awsFetcher) fetchAccessEntries(ctx context.Context, eksClient eksiface. ) protoAccessEntries = append(protoAccessEntries, protoAccessEntry) } + return protoAccessEntries, trace.NewAggregate(errs...) } // awsAccessEntryToProtoAccessEntry converts an eks.AccessEntry to accessgraphv1alpha.AWSEKSClusterV1 -func awsAccessEntryToProtoAccessEntry(accessEntry *eks.AccessEntry, cluster *accessgraphv1alpha.AWSEKSClusterV1, accountID string) *accessgraphv1alpha.AWSEKSClusterAccessEntryV1 { - var tags []*accessgraphv1alpha.AWSTag +func awsAccessEntryToProtoAccessEntry(accessEntry *ekstypes.AccessEntry, cluster *accessgraphv1alpha.AWSEKSClusterV1, accountID string) *accessgraphv1alpha.AWSEKSClusterAccessEntryV1 { + tags := make([]*accessgraphv1alpha.AWSTag, 0, len(accessEntry.Tags)) for k, v := range accessEntry.Tags { tags = append(tags, &accessgraphv1alpha.AWSTag{ Key: k, - Value: strPtrToWrapper(v), + Value: wrapperspb.String(v), }) } - out := &accessgraphv1alpha.AWSEKSClusterAccessEntryV1{ + + return &accessgraphv1alpha.AWSEKSClusterAccessEntryV1{ Cluster: cluster, - AccessEntryArn: aws.StringValue(accessEntry.AccessEntryArn), + AccessEntryArn: aws.ToString(accessEntry.AccessEntryArn), CreatedAt: awsTimeToProtoTime(accessEntry.CreatedAt), - KubernetesGroups: aws.StringValueSlice(accessEntry.KubernetesGroups), - Username: aws.StringValue(accessEntry.Username), + KubernetesGroups: accessEntry.KubernetesGroups, + Username: aws.ToString(accessEntry.Username), ModifiedAt: awsTimeToProtoTime(accessEntry.ModifiedAt), - PrincipalArn: aws.StringValue(accessEntry.PrincipalArn), - Type: aws.StringValue(accessEntry.Type), + PrincipalArn: aws.ToString(accessEntry.PrincipalArn), + Type: aws.ToString(accessEntry.Type), Tags: tags, AccountId: accountID, LastSyncTime: timestamppb.Now(), } - - return out } // fetchAccessEntries fetches the access entries for the given cluster. -func (a *awsFetcher) fetchAssociatedPolicies(ctx context.Context, eksClient eksiface.EKSAPI, cluster *accessgraphv1alpha.AWSEKSClusterV1, arns []string) ([]*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1, error) { +func (a *awsFetcher) fetchAssociatedPolicies(ctx context.Context, eksClient EKSClient, cluster *accessgraphv1alpha.AWSEKSClusterV1, arns []string) ([]*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1, error) { var associatedPolicies []*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1 var errs []error + for _, arn := range arns { - err := eksClient.ListAssociatedAccessPoliciesPagesWithContext( - ctx, + for p := eks.NewListAssociatedAccessPoliciesPaginator(eksClient, &eks.ListAssociatedAccessPoliciesInput{ ClusterName: aws.String(cluster.Name), PrincipalArn: aws.String(arn), }, - func(output *eks.ListAssociatedAccessPoliciesOutput, lastPage bool) bool { - for _, policy := range output.AssociatedAccessPolicies { - associatedPolicies = append(associatedPolicies, - awsAssociatedAccessPolicy(policy, cluster, arn, a.AccountID), - ) - } - return !lastPage - }, - ) - if err != nil { - errs = append(errs, trace.Wrap(err)) - + ); p.HasMorePages(); { + out, err := p.NextPage(ctx) + if err != nil { + errs = append(errs, err) + break + } + for _, policy := range out.AssociatedAccessPolicies { + associatedPolicies = append(associatedPolicies, + awsAssociatedAccessPolicy(policy, cluster, arn, a.AccountID), + ) + } } - } return associatedPolicies, trace.NewAggregate(errs...) } // awsAssociatedAccessPolicy converts an eks.AssociatedAccessPolicy to accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1 -func awsAssociatedAccessPolicy(policy *eks.AssociatedAccessPolicy, cluster *accessgraphv1alpha.AWSEKSClusterV1, principalARN, accountID string) *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1 { +func awsAssociatedAccessPolicy(policy ekstypes.AssociatedAccessPolicy, cluster *accessgraphv1alpha.AWSEKSClusterV1, principalARN, accountID string) *accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1 { var accessScope *accessgraphv1alpha.AWSEKSAccessScopeV1 if policy.AccessScope != nil { accessScope = &accessgraphv1alpha.AWSEKSAccessScopeV1{ - Namespaces: aws.StringValueSlice(policy.AccessScope.Namespaces), - Type: aws.StringValue(policy.AccessScope.Type), + Namespaces: policy.AccessScope.Namespaces, + Type: string(policy.AccessScope.Type), } } - out := &accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1{ + + return &accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1{ Cluster: cluster, AssociatedAt: awsTimeToProtoTime(policy.AssociatedAt), ModifiedAt: awsTimeToProtoTime(policy.ModifiedAt), PrincipalArn: principalARN, - PolicyArn: aws.StringValue(policy.PolicyArn), + PolicyArn: aws.ToString(policy.PolicyArn), Scope: accessScope, AccountId: accountID, LastSyncTime: timestamppb.Now(), } - - return out } diff --git a/lib/srv/discovery/fetchers/aws-sync/eks_test.go b/lib/srv/discovery/fetchers/aws-sync/eks_test.go index 9c6c395018d95..b38f1ff851a92 100644 --- a/lib/srv/discovery/fetchers/aws-sync/eks_test.go +++ b/lib/srv/discovery/fetchers/aws-sync/eks_test.go @@ -24,8 +24,9 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/eks" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" @@ -33,23 +34,82 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" - "github.com/gravitational/teleport/lib/cloud" - "github.com/gravitational/teleport/lib/cloud/mocks" + "github.com/gravitational/teleport/lib/cloud/awsconfig" ) -var ( - date = time.Date(2024, 03, 12, 0, 0, 0, 0, time.UTC) +var date = time.Date(2024, 0o3, 12, 0, 0, 0, 0, time.UTC) + +const ( principalARN = "arn:iam:teleport" accessEntryARN = "arn:iam:access_entry" ) +type mockedEKSClient struct { + clusters []*ekstypes.Cluster + accessEntries []*ekstypes.AccessEntry + associatedAccessPolicies []ekstypes.AssociatedAccessPolicy +} + +func (m *mockedEKSClient) DescribeCluster(ctx context.Context, input *eks.DescribeClusterInput, optFns ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) { + for _, cluster := range m.clusters { + if aws.ToString(cluster.Name) == aws.ToString(input.Name) { + return &eks.DescribeClusterOutput{ + Cluster: cluster, + }, nil + } + } + return nil, nil +} + +func (m *mockedEKSClient) ListClusters(ctx context.Context, input *eks.ListClustersInput, optFns ...func(*eks.Options)) (*eks.ListClustersOutput, error) { + clusterNames := make([]string, 0, len(m.clusters)) + for _, cluster := range m.clusters { + clusterNames = append(clusterNames, aws.ToString(cluster.Name)) + } + return &eks.ListClustersOutput{ + Clusters: clusterNames, + }, nil +} + +func (m *mockedEKSClient) ListAccessEntries(ctx context.Context, input *eks.ListAccessEntriesInput, optFns ...func(*eks.Options)) (*eks.ListAccessEntriesOutput, error) { + accessEntries := make([]string, 0, len(m.accessEntries)) + for _, accessEntry := range m.accessEntries { + accessEntries = append(accessEntries, aws.ToString(accessEntry.AccessEntryArn)) + } + return &eks.ListAccessEntriesOutput{ + AccessEntries: accessEntries, + }, nil +} + +func (m *mockedEKSClient) ListAssociatedAccessPolicies(ctx context.Context, input *eks.ListAssociatedAccessPoliciesInput, optFns ...func(*eks.Options)) (*eks.ListAssociatedAccessPoliciesOutput, error) { + return &eks.ListAssociatedAccessPoliciesOutput{ + AssociatedAccessPolicies: m.associatedAccessPolicies, + }, nil +} + +func (m *mockedEKSClient) DescribeAccessEntry(ctx context.Context, input *eks.DescribeAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DescribeAccessEntryOutput, error) { + return &eks.DescribeAccessEntryOutput{ + AccessEntry: &ekstypes.AccessEntry{ + PrincipalArn: aws.String(principalARN), + AccessEntryArn: aws.String(accessEntryARN), + CreatedAt: aws.Time(date), + ModifiedAt: aws.Time(date), + ClusterName: aws.String("cluster1"), + Tags: map[string]string{ + "t1": "t2", + }, + Type: aws.String(string(ekstypes.AccessScopeTypeCluster)), + Username: aws.String("teleport"), + KubernetesGroups: []string{"teleport"}, + }, + }, nil +} + func TestPollAWSEKSClusters(t *testing.T) { const ( accountID = "12345678" ) - var ( - regions = []string{"eu-west-1"} - ) + regions := []string{"eu-west-1"} cluster := &accessgraphv1alpha.AWSEKSClusterV1{ Name: "cluster1", Arn: "arn:us-west1:eks:cluster1", @@ -58,7 +118,7 @@ func TestPollAWSEKSClusters(t *testing.T) { Tags: []*accessgraphv1alpha.AWSTag{ { Key: "tag1", - Value: nil, + Value: wrapperspb.String(""), }, { Key: "tag2", @@ -102,7 +162,7 @@ func TestPollAWSEKSClusters(t *testing.T) { Cluster: cluster, PrincipalArn: principalARN, Scope: &accessgraphv1alpha.AWSEKSAccessScopeV1{ - Type: eks.AccessScopeTypeCluster, + Type: string(ekstypes.AccessScopeTypeCluster), Namespaces: []string{"ns1"}, }, AssociatedAt: timestamppb.New(date), @@ -116,12 +176,14 @@ func TestPollAWSEKSClusters(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockedClients := &cloud.TestCloudClients{ - EKS: &mocks.EKSMock{ - Clusters: eksClusters(), - AccessEntries: accessEntries(), - AssociatedPolicies: associatedPolicies(), - }, + t.Parallel() + + getEKSClient := func(_ context.Context, _ string, _ ...awsconfig.OptionsFn) (EKSClient, error) { + return &mockedEKSClient{ + clusters: eksClusters(), + accessEntries: accessEntries(), + associatedAccessPolicies: associatedPolicies(), + }, nil } var ( @@ -137,20 +199,21 @@ func TestPollAWSEKSClusters(t *testing.T) { a := &awsFetcher{ Config: Config{ AccountID: accountID, - CloudClients: mockedClients, Regions: regions, Integration: accountID, + GetEKSClient: getEKSClient, }, lastResult: &Resources{}, } - result := &Resources{} - execFunc := a.pollAWSEKSClusters(context.Background(), result, collectErr) + + var result Resources + execFunc := a.pollAWSEKSClusters(context.Background(), &result, collectErr) require.NoError(t, execFunc()) require.Empty(t, cmp.Diff( tt.want, - result, + &result, protocmp.Transform(), - // tags originate from a map so we must sort them before comparing. + // Tags originate from a map so we must sort them before comparing. protocmp.SortRepeated( func(a, b *accessgraphv1alpha.AWSTag) bool { return a.Key < b.Key @@ -159,52 +222,50 @@ func TestPollAWSEKSClusters(t *testing.T) { protocmp.IgnoreFields(&accessgraphv1alpha.AWSEKSClusterV1{}, "last_sync_time"), protocmp.IgnoreFields(&accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1{}, "last_sync_time"), protocmp.IgnoreFields(&accessgraphv1alpha.AWSEKSClusterAccessEntryV1{}, "last_sync_time"), - ), - ) - + )) }) } } -func eksClusters() []*eks.Cluster { - return []*eks.Cluster{ +func eksClusters() []*ekstypes.Cluster { + return []*ekstypes.Cluster{ { Name: aws.String("cluster1"), Arn: aws.String("arn:us-west1:eks:cluster1"), CreatedAt: aws.Time(date), - Status: aws.String(eks.AddonStatusActive), - Tags: map[string]*string{ - "tag1": nil, - "tag2": aws.String("val2"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "tag1": "", + "tag2": "val2", }, }, } } -func accessEntries() []*eks.AccessEntry { - return []*eks.AccessEntry{ +func accessEntries() []*ekstypes.AccessEntry { + return []*ekstypes.AccessEntry{ { PrincipalArn: aws.String(principalARN), AccessEntryArn: aws.String(accessEntryARN), CreatedAt: aws.Time(date), ModifiedAt: aws.Time(date), ClusterName: aws.String("cluster1"), - Tags: map[string]*string{ - "t1": aws.String("t2"), + Tags: map[string]string{ + "t1": "t2", }, - Type: aws.String(eks.AccessScopeTypeCluster), + Type: aws.String(string(ekstypes.AccessScopeTypeCluster)), Username: aws.String("teleport"), - KubernetesGroups: []*string{aws.String("teleport")}, + KubernetesGroups: []string{"teleport"}, }, } } -func associatedPolicies() []*eks.AssociatedAccessPolicy { - return []*eks.AssociatedAccessPolicy{ +func associatedPolicies() []ekstypes.AssociatedAccessPolicy { + return []ekstypes.AssociatedAccessPolicy{ { - AccessScope: &eks.AccessScope{ - Namespaces: []*string{aws.String("ns1")}, - Type: aws.String(eks.AccessScopeTypeCluster), + AccessScope: &ekstypes.AccessScope{ + Namespaces: []string{"ns1"}, + Type: ekstypes.AccessScopeTypeCluster, }, ModifiedAt: aws.Time(date), AssociatedAt: aws.Time(date), diff --git a/lib/srv/discovery/fetchers/eks.go b/lib/srv/discovery/fetchers/eks.go index 193244bba75e3..27dcbdd2d83fd 100644 --- a/lib/srv/discovery/fetchers/eks.go +++ b/lib/srv/discovery/fetchers/eks.go @@ -29,13 +29,12 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/eks" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "golang.org/x/sync/errgroup" @@ -48,8 +47,8 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/fixtures" kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/services" @@ -63,24 +62,48 @@ const ( type eksFetcher struct { EKSFetcherConfig - mu sync.Mutex - client eksiface.EKSAPI - stsClient stsiface.STSAPI - callerIdentity string + mu sync.Mutex + client EKSClient + stsPresignClient STSPresignClient + callerIdentity string } -// ClientGetter is an interface for getting an EKS client and an STS client. -type ClientGetter interface { - // GetAWSEKSClient returns AWS EKS client for the specified region. - GetAWSEKSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (eksiface.EKSAPI, error) - // GetAWSSTSClient returns AWS STS client for the specified region. - GetAWSSTSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (stsiface.STSAPI, error) +// EKSClient is the subset of the EKS interface we use in fetchers. +type EKSClient interface { + eks.DescribeClusterAPIClient + eks.ListClustersAPIClient + + AssociateAccessPolicy(ctx context.Context, params *eks.AssociateAccessPolicyInput, optFns ...func(*eks.Options)) (*eks.AssociateAccessPolicyOutput, error) + CreateAccessEntry(ctx context.Context, params *eks.CreateAccessEntryInput, optFns ...func(*eks.Options)) (*eks.CreateAccessEntryOutput, error) + DeleteAccessEntry(ctx context.Context, params *eks.DeleteAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DeleteAccessEntryOutput, error) + DescribeAccessEntry(ctx context.Context, params *eks.DescribeAccessEntryInput, optFns ...func(*eks.Options)) (*eks.DescribeAccessEntryOutput, error) + UpdateAccessEntry(ctx context.Context, params *eks.UpdateAccessEntryInput, optFns ...func(*eks.Options)) (*eks.UpdateAccessEntryOutput, error) +} + +// STSClient is the subset of the STS interface we use in fetchers. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + stscreds.AssumeRoleAPIClient +} + +// STSPresignClient is the subset of the STS presign interface we use in fetchers. +type STSPresignClient = kubeutils.STSPresignClient + +// AWSClientGetter is an interface for getting an EKS client and an STS client. +type AWSClientGetter interface { + awsconfig.Provider + // GetAWSEKSClient returns AWS EKS client for the specified config. + GetAWSEKSClient(aws.Config) EKSClient + // GetAWSSTSClient returns AWS STS client for the specified config. + GetAWSSTSClient(aws.Config) STSClient + // GetAWSSTSPresignClient returns AWS STS presign client for the specified config. + GetAWSSTSPresignClient(aws.Config) STSPresignClient } // EKSFetcherConfig configures the EKS fetcher. type EKSFetcherConfig struct { // ClientGetter retrieves an EKS client and an STS client. - ClientGetter ClientGetter + ClientGetter AWSClientGetter // AssumeRole provides a role ARN and ExternalID to assume an AWS role // when fetching clusters. AssumeRole types.AssumeRole @@ -133,7 +156,7 @@ func (c *EKSFetcherConfig) CheckAndSetDefaults() error { // MakeEKSFetchersFromAWSMatchers creates fetchers from the provided matchers. Returned fetchers are separated // by their reliance on the integration. -func MakeEKSFetchersFromAWSMatchers(logger *slog.Logger, clients cloud.AWSClients, matchers []types.AWSMatcher, discoveryConfigName string) (kubeFetchers []common.Fetcher, _ error) { +func MakeEKSFetchersFromAWSMatchers(logger *slog.Logger, clients AWSClientGetter, matchers []types.AWSMatcher, discoveryConfigName string) (kubeFetchers []common.Fetcher, _ error) { for _, matcher := range matchers { var matcherAssumeRole types.AssumeRole if matcher.AssumeRole != nil { @@ -162,7 +185,8 @@ func MakeEKSFetchersFromAWSMatchers(logger *slog.Logger, clients cloud.AWSClient "error", err, "region", region, "labels", matcher.Tags, - "assume_role", matcherAssumeRole.RoleARN) + "assume_role", matcherAssumeRole.RoleARN, + ) continue } kubeFetchers = append(kubeFetchers, fetcher) @@ -197,7 +221,7 @@ func NewEKSFetcher(cfg EKSFetcherConfig) (common.Fetcher, error) { return fetcher, nil } -func (a *eksFetcher) getClient(ctx context.Context) (eksiface.EKSAPI, error) { +func (a *eksFetcher) getClient(ctx context.Context) (EKSClient, error) { a.mu.Lock() defer a.mu.Unlock() @@ -205,16 +229,12 @@ func (a *eksFetcher) getClient(ctx context.Context) (eksiface.EKSAPI, error) { return a.client, nil } - client, err := a.ClientGetter.GetAWSEKSClient( - ctx, - a.Region, - a.getAWSOpts()..., - ) + cfg, err := a.ClientGetter.GetConfig(ctx, a.Region, a.getAWSOpts()...) if err != nil { return nil, trace.Wrap(err) } - a.client = client + a.client = a.ClientGetter.GetAWSEKSClient(cfg) return a.client, nil } @@ -280,39 +300,38 @@ func (a *eksFetcher) getEKSClusters(ctx context.Context) (types.KubeClusters, er return nil, trace.Wrap(err, "failed getting AWS EKS client") } - err = client.ListClustersPagesWithContext(ctx, - &eks.ListClustersInput{ - Include: nil, // For now we should only list EKS clusters - }, - func(clustersList *eks.ListClustersOutput, _ bool) bool { - for i := 0; i < len(clustersList.Clusters); i++ { - clusterName := aws.StringValue(clustersList.Clusters[i]) - // group.Go will block if the concurrency limit is reached. - // It will resume once any running function finishes. - group.Go(func() error { - cluster, err := a.getMatchingKubeCluster(groupCtx, clusterName) - // trace.CompareFailed is returned if the cluster did not match the matcher filtering labels - // or if the cluster is not yet active. - if trace.IsCompareFailed(err) { - a.Logger.DebugContext(groupCtx, "Cluster did not match the filtering criteria", "error", err, "cluster", clusterName) - // never return an error otherwise we will impact discovery process - return nil - } else if err != nil { - a.Logger.WarnContext(groupCtx, "Failed to discover EKS cluster", "error", err, "cluster", clusterName) - // never return an error otherwise we will impact discovery process - return nil - } - - mu.Lock() - defer mu.Unlock() - clusters = append(clusters, cluster) + // For now we should only list EKS clusters so we use nil (default) input param. + for p := eks.NewListClustersPaginator(client, nil); p.HasMorePages(); { + out, err := p.NextPage(ctx) + if err != nil { + return clusters, trace.Wrap(err) + } + for _, clusterName := range out.Clusters { + // group.Go will block if the concurrency limit is reached. + // It will resume once any running function finishes. + group.Go(func() error { + cluster, err := a.getMatchingKubeCluster(groupCtx, clusterName) + // trace.CompareFailed is returned if the cluster did not match the matcher filtering labels + // or if the cluster is not yet active. + if trace.IsCompareFailed(err) { + a.Logger.DebugContext(groupCtx, "Cluster did not match the filtering criteria", "error", err, "cluster", clusterName) + // never return an error otherwise we will impact discovery process return nil - }) - } - return true - }, - ) - // error can be discarded since we do not return any error from group.Go closure. + } else if err != nil { + a.Logger.WarnContext(groupCtx, "Failed to discover EKS cluster", "error", err, "cluster", clusterName) + // never return an error otherwise we will impact discovery process + return nil + } + + mu.Lock() + defer mu.Unlock() + clusters = append(clusters, cluster) + return nil + }) + } + } + + // The error can be discarded since we do not return any error from group.Go closure. _ = group.Wait() return clusters, trace.Wrap(err) } @@ -352,7 +371,7 @@ func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName str return nil, trace.Wrap(err, "failed getting AWS EKS client") } - rsp, err := client.DescribeClusterWithContext( + rsp, err := client.DescribeCluster( ctx, &eks.DescribeClusterInput{ Name: aws.String(clusterName), @@ -362,14 +381,14 @@ func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName str return nil, trace.WrapWithMessage(err, "Unable to describe EKS cluster %q", clusterName) } - switch st := aws.StringValue(rsp.Cluster.Status); st { - case eks.ClusterStatusUpdating, eks.ClusterStatusActive: + switch st := rsp.Cluster.Status; st { + case ekstypes.ClusterStatusUpdating, ekstypes.ClusterStatusActive: a.Logger.DebugContext(ctx, "EKS cluster status is valid", "status", st, "cluster", clusterName) default: return nil, trace.CompareFailed("EKS cluster %q not enrolled due to its current status: %s", clusterName, st) } - cluster, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(rsp.Cluster.Name), aws.StringValue(rsp.Cluster.Arn), rsp.Cluster.Tags) + cluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(rsp.Cluster.Name), aws.ToString(rsp.Cluster.Arn), rsp.Cluster.Tags) if err != nil { return nil, trace.WrapWithMessage(err, "Unable to convert eks.Cluster cluster into types.KubernetesClusterV3.") } @@ -388,8 +407,8 @@ func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName str // If the fetcher should setup access for the specified ARN, first check if the cluster authentication mode // is set to either [eks.AuthenticationModeApi] or [eks.AuthenticationModeApiAndConfigMap]. // If the authentication mode is set to [eks.AuthenticationModeConfigMap], the fetcher will ignore the cluster. - switch st := aws.StringValue(rsp.Cluster.AccessConfig.AuthenticationMode); st { - case eks.AuthenticationModeApiAndConfigMap, eks.AuthenticationModeApi: + switch st := rsp.Cluster.AccessConfig.AuthenticationMode; st { + case ekstypes.AuthenticationModeApiAndConfigMap, ekstypes.AuthenticationModeApi: if err := a.checkOrSetupAccessForARN(ctx, client, rsp.Cluster); err != nil { return nil, trace.Wrap(err, "unable to setup access for EKS cluster %q", clusterName) } @@ -427,9 +446,9 @@ var eksDiscoveryPermissions = []string{ // The check involves checking if the access entry exists and if the "teleport:kube-agent:eks" is part of the Kubernetes group. // If the access entry doesn't exist or is misconfigured, the fetcher will temporarily gain admin access and create the role and binding. // The fetcher will then upsert the access entry with the correct Kubernetes group. -func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksiface.EKSAPI, cluster *eks.Cluster) error { +func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client EKSClient, cluster *ekstypes.Cluster) error { entry, err := convertAWSError( - client.DescribeAccessEntryWithContext(ctx, + client.DescribeAccessEntry(ctx, &eks.DescribeAccessEntryInput{ ClusterName: cluster.Name, PrincipalArn: aws.String(a.SetupAccessForARN), @@ -442,13 +461,13 @@ func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksifa // Access denied means that the principal does not have access to setup access entries for the cluster. a.Logger.WarnContext(ctx, "Access denied to setup access for EKS cluster, ensure the required permissions are set", "error", err, - "cluster", aws.StringValue(cluster.Name), + "cluster", aws.ToString(cluster.Name), "required_permissions", eksDiscoveryPermissions, ) return nil case err == nil: // If the access entry exists and the principal has access to the cluster, check if the teleportKubernetesGroup is part of the Kubernetes group. - if entry.AccessEntry != nil && slices.Contains(aws.StringValueSlice(entry.AccessEntry.KubernetesGroups), teleportKubernetesGroup) { + if entry.AccessEntry != nil && slices.Contains(entry.AccessEntry.KubernetesGroups, teleportKubernetesGroup) { return nil } fallthrough @@ -459,12 +478,12 @@ func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksifa // Access denied means that the principal does not have access to setup access entries for the cluster. a.Logger.WarnContext(ctx, "Access denied to setup access for EKS cluster, ensure the required permissions are set", "error", err, - "cluster", aws.StringValue(cluster.Name), + "cluster", aws.ToString(cluster.Name), "required_permissions", eksDiscoveryPermissions, ) return nil } else if err != nil { - return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.StringValue(cluster.Name)) + return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.ToString(cluster.Name)) } // upsert the access entry with the correct Kubernetes group for the final @@ -473,29 +492,29 @@ func (a *eksFetcher) checkOrSetupAccessForARN(ctx context.Context, client eksifa // Access denied means that the principal does not have access to setup access entries for the cluster. a.Logger.WarnContext(ctx, "Access denied to setup access for EKS cluster, ensure the required permissions are set", "error", err, - "cluster", aws.StringValue(cluster.Name), + "cluster", aws.ToString(cluster.Name), "required_permissions", eksDiscoveryPermissions, ) return nil } - return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.StringValue(cluster.Name)) + return trace.Wrap(err, "unable to setup access for EKS cluster %q", aws.ToString(cluster.Name)) default: return trace.Wrap(err) } - } // temporarilyGainAdminAccessAndCreateRole temporarily gains admin access to the EKS cluster by associating the EKS Cluster Admin Policy // to the callerIdentity. The fetcher will then create the role and binding for the teleportKubernetesGroup in the EKS cluster. -func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context, client eksiface.EKSAPI, cluster *eks.Cluster) error { +func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context, client EKSClient, cluster *ekstypes.Cluster) error { const ( // https://docs.aws.amazon.com/eks/latest/userguide/access-policies.html // We use cluster admin policy to create namespace and cluster role. eksClusterAdminPolicy = "arn:aws:eks::aws:cluster-access-policy/AmazonEKSClusterAdminPolicy" ) + // Setup access for the ARN rsp, err := convertAWSError( - client.CreateAccessEntryWithContext(ctx, + client.CreateAccessEntry(ctx, &eks.CreateAccessEntryInput{ ClusterName: cluster.Name, PrincipalArn: aws.String(a.callerIdentity), @@ -510,7 +529,7 @@ func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context if rsp != nil { defer func() { _, err := convertAWSError( - client.DeleteAccessEntryWithContext( + client.DeleteAccessEntry( ctx, &eks.DeleteAccessEntryInput{ ClusterName: cluster.Name, @@ -520,18 +539,17 @@ func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context if err != nil { a.Logger.WarnContext(ctx, "Failed to delete access entry for EKS cluster", "error", err, - "cluster", aws.StringValue(cluster.Name), + "cluster", aws.ToString(cluster.Name), ) } }() - } _, err = convertAWSError( - client.AssociateAccessPolicyWithContext(ctx, &eks.AssociateAccessPolicyInput{ - AccessScope: &eks.AccessScope{ + client.AssociateAccessPolicy(ctx, &eks.AssociateAccessPolicyInput{ + AccessScope: &ekstypes.AccessScope{ Namespaces: nil, - Type: aws.String(eks.AccessScopeTypeCluster), + Type: ekstypes.AccessScopeTypeCluster, }, ClusterName: cluster.Name, PolicyArn: aws.String(eksClusterAdminPolicy), @@ -539,7 +557,7 @@ func (a *eksFetcher) temporarilyGainAdminAccessAndCreateRole(ctx context.Context }), ) if err != nil && !trace.IsAlreadyExists(err) { - return trace.Wrap(err, "unable to associate EKS Access Policy to cluster %q", aws.StringValue(cluster.Name)) + return trace.Wrap(err, "unable to associate EKS Access Policy to cluster %q", aws.ToString(cluster.Name)) } timeout := a.Clock.NewTimer(60 * time.Second) @@ -561,17 +579,19 @@ forLoop: } } - return trace.Wrap(err, "unable to upsert role and binding for cluster %q", aws.StringValue(cluster.Name)) + return trace.Wrap(err, "unable to upsert role and binding for cluster %q", aws.ToString(cluster.Name)) } // upsertRoleAndBinding upserts the ClusterRole and ClusterRoleBinding for the teleportKubernetesGroup in the EKS cluster. -func (a *eksFetcher) upsertRoleAndBinding(ctx context.Context, cluster *eks.Cluster) error { - client, err := a.createKubeClient(cluster) +func (a *eksFetcher) upsertRoleAndBinding(ctx context.Context, cluster *ekstypes.Cluster) error { + client, err := a.createKubeClient(ctx, cluster) if err != nil { - return trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.StringValue(cluster.Name)) + return trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.ToString(cluster.Name)) } + ctx, cancel := context.WithTimeout(ctx, 20*time.Second) defer cancel() + if err := a.upsertClusterRoleWithAdminCredentials(ctx, client); err != nil { return trace.Wrap(err, "unable to upsert ClusterRole for group %q", teleportKubernetesGroup) } @@ -583,23 +603,23 @@ func (a *eksFetcher) upsertRoleAndBinding(ctx context.Context, cluster *eks.Clus return nil } -func (a *eksFetcher) createKubeClient(cluster *eks.Cluster) (*kubernetes.Clientset, error) { - if a.stsClient == nil { - return nil, trace.BadParameter("STS client is not set") +func (a *eksFetcher) createKubeClient(ctx context.Context, cluster *ekstypes.Cluster) (*kubernetes.Clientset, error) { + if a.stsPresignClient == nil { + return nil, trace.BadParameter("STS presign client is not set") } - token, _, err := kubeutils.GenAWSEKSToken(a.stsClient, aws.StringValue(cluster.Name), a.Clock) + token, _, err := kubeutils.GenAWSEKSToken(ctx, a.stsPresignClient, aws.ToString(cluster.Name), a.Clock) if err != nil { - return nil, trace.Wrap(err, "unable to generate EKS token for cluster %q", aws.StringValue(cluster.Name)) + return nil, trace.Wrap(err, "unable to generate EKS token for cluster %q", aws.ToString(cluster.Name)) } - ca, err := base64.StdEncoding.DecodeString(aws.StringValue(cluster.CertificateAuthority.Data)) + ca, err := base64.StdEncoding.DecodeString(aws.ToString(cluster.CertificateAuthority.Data)) if err != nil { - return nil, trace.Wrap(err, "unable to decode EKS cluster %q certificate authority", aws.StringValue(cluster.Name)) + return nil, trace.Wrap(err, "unable to decode EKS cluster %q certificate authority", aws.ToString(cluster.Name)) } - apiEndpoint := aws.StringValue(cluster.Endpoint) + apiEndpoint := aws.ToString(cluster.Endpoint) if len(apiEndpoint) == 0 { - return nil, trace.BadParameter("invalid api endpoint for cluster %q", aws.StringValue(cluster.Name)) + return nil, trace.BadParameter("invalid api endpoint for cluster %q", aws.ToString(cluster.Name)) } client, err := kubernetes.NewForConfig( @@ -611,7 +631,7 @@ func (a *eksFetcher) createKubeClient(cluster *eks.Cluster) (*kubernetes.Clients }, }, ) - return client, trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.StringValue(cluster.Name)) + return client, trace.Wrap(err, "unable to create Kubernetes client for cluster %q", aws.ToString(cluster.Name)) } // upsertClusterRoleWithAdminCredentials tries to upsert the ClusterRole using admin credentials. @@ -664,13 +684,13 @@ func (a *eksFetcher) upsertClusterRoleBindingWithAdminCredentials(ctx context.Co } // upsertAccessEntry upserts the access entry for the specified ARN with the teleportKubernetesGroup. -func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client eksiface.EKSAPI, cluster *eks.Cluster) error { +func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client EKSClient, cluster *ekstypes.Cluster) error { _, err := convertAWSError( - client.CreateAccessEntryWithContext(ctx, + client.CreateAccessEntry(ctx, &eks.CreateAccessEntryInput{ ClusterName: cluster.Name, PrincipalArn: aws.String(a.SetupAccessForARN), - KubernetesGroups: aws.StringSlice([]string{teleportKubernetesGroup}), + KubernetesGroups: []string{teleportKubernetesGroup}, }, )) if err == nil || !trace.IsAlreadyExists(err) { @@ -678,11 +698,11 @@ func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client eksiface.EKSA } _, err = convertAWSError( - client.UpdateAccessEntryWithContext(ctx, + client.UpdateAccessEntry(ctx, &eks.UpdateAccessEntryInput{ ClusterName: cluster.Name, PrincipalArn: aws.String(a.SetupAccessForARN), - KubernetesGroups: aws.StringSlice([]string{teleportKubernetesGroup}), + KubernetesGroups: []string{teleportKubernetesGroup}, }, )) @@ -690,35 +710,35 @@ func (a *eksFetcher) upsertAccessEntry(ctx context.Context, client eksiface.EKSA } func (a *eksFetcher) setCallerIdentity(ctx context.Context) error { - var err error - a.stsClient, err = a.ClientGetter.GetAWSSTSClient( - ctx, + cfg, err := a.ClientGetter.GetConfig(ctx, a.Region, a.getAWSOpts()..., ) if err != nil { return trace.Wrap(err) } - + a.stsPresignClient = a.ClientGetter.GetAWSSTSPresignClient(cfg) if a.AssumeRole.RoleARN != "" { a.callerIdentity = a.AssumeRole.RoleARN return nil } - identity, err := a.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + + stsClient := a.ClientGetter.GetAWSSTSClient(cfg) + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { return trace.Wrap(err) } - a.callerIdentity = convertAssumedRoleToIAMRole(aws.StringValue(identity.Arn)) + a.callerIdentity = convertAssumedRoleToIAMRole(aws.ToString(identity.Arn)) return nil } -func (a *eksFetcher) getAWSOpts() []cloud.AWSOptionsFn { - return []cloud.AWSOptionsFn{ - cloud.WithAssumeRole( +func (a *eksFetcher) getAWSOpts() []awsconfig.OptionsFn { + return []awsconfig.OptionsFn{ + awsconfig.WithAssumeRole( a.AssumeRole.RoleARN, a.AssumeRole.ExternalID, ), - cloud.WithCredentialsMaybeIntegration(a.Integration), + awsconfig.WithCredentialsMaybeIntegration(a.Integration), } } @@ -734,6 +754,7 @@ func convertAssumedRoleToIAMRole(callerIdentity string) string { const ( assumeRolePrefix = "assumed-role/" roleResource = "role" + serviceName = "iam" ) a, err := arn.Parse(callerIdentity) if err != nil { @@ -742,7 +763,7 @@ func convertAssumedRoleToIAMRole(callerIdentity string) string { if !strings.HasPrefix(a.Resource, assumeRolePrefix) { return callerIdentity } - a.Service = iam.ServiceName + a.Service = serviceName split := strings.Split(a.Resource, "/") if len(split) <= 2 { return callerIdentity diff --git a/lib/srv/discovery/fetchers/eks_test.go b/lib/srv/discovery/fetchers/eks_test.go index d7b9c6b4cac47..ad8c8667d2862 100644 --- a/lib/srv/discovery/fetchers/eks_test.go +++ b/lib/srv/discovery/fetchers/eks_test.go @@ -23,16 +23,16 @@ import ( "errors" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/service/eks" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/mocks" + kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/srv/discovery/common" "github.com/gravitational/teleport/lib/utils" ) @@ -43,9 +43,10 @@ func TestEKSFetcher(t *testing.T) { filterLabels types.Labels } tests := []struct { - name string - args args - want types.ResourcesWithLabels + name string + args args + assumeRole types.AssumeRole + want types.ResourcesWithLabels }{ { name: "list everything", @@ -57,6 +58,17 @@ func TestEKSFetcher(t *testing.T) { }, want: eksClustersToResources(t, eksMockClusters...), }, + { + name: "list everything with assumed role", + args: args{ + region: types.Wildcard, + filterLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + assumeRole: types.AssumeRole{RoleARN: "arn:aws:iam::123456789012:role/test-role", ExternalID: "extID123"}, + want: eksClustersToResources(t, eksMockClusters...), + }, { name: "list prod clusters", args: args{ @@ -88,7 +100,6 @@ func TestEKSFetcher(t *testing.T) { }, want: eksClustersToResources(t), }, - { name: "list everything with specified values", args: args{ @@ -102,14 +113,24 @@ func TestEKSFetcher(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + stsClt := &mocks.STSClient{} cfg := EKSFetcherConfig{ - ClientGetter: &mockEKSClientGetter{}, + ClientGetter: &mockEKSClientGetter{ + AWSConfigProvider: mocks.AWSConfigProvider{ + STSClient: stsClt, + }, + }, + AssumeRole: tt.assumeRole, FilterLabels: tt.args.filterLabels, Region: tt.args.region, Logger: utils.NewSlogLoggerForTests(), } fetcher, err := NewEKSFetcher(cfg) require.NoError(t, err) + if tt.assumeRole.RoleARN != "" { + require.Contains(t, stsClt.GetAssumedRoleARNs(), tt.assumeRole.RoleARN) + stsClt.ResetAssumeRoleHistory() + } resources, err := fetcher.Get(context.Background()) require.NoError(t, err) @@ -123,54 +144,68 @@ func TestEKSFetcher(t *testing.T) { } require.Equal(t, tt.want.ToMap(), clusters.ToMap()) + if tt.assumeRole.RoleARN != "" { + require.Contains(t, stsClt.GetAssumedRoleARNs(), tt.assumeRole.RoleARN) + } }) } } -type mockEKSClientGetter struct{} +type mockEKSClientGetter struct { + mocks.AWSConfigProvider +} + +func (e *mockEKSClientGetter) GetAWSEKSClient(cfg aws.Config) EKSClient { + return newPopulatedEKSMock() +} -func (e *mockEKSClientGetter) GetAWSEKSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (eksiface.EKSAPI, error) { - return newPopulatedEKSMock(), nil +func (e *mockEKSClientGetter) GetAWSSTSClient(aws.Config) STSClient { + return &mockSTSAPI{} } -func (e *mockEKSClientGetter) GetAWSSTSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (stsiface.STSAPI, error) { - return &mockSTSAPI{}, nil +func (e *mockEKSClientGetter) GetAWSSTSPresignClient(aws.Config) kubeutils.STSPresignClient { + return &mockSTSPresignAPI{} +} + +type mockSTSPresignAPI struct{} + +func (a *mockSTSPresignAPI) PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) { + panic("not implemented") } type mockSTSAPI struct { - stsiface.STSAPI arn string } -func (a *mockSTSAPI) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) { +func (a *mockSTSAPI) GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return &sts.GetCallerIdentityOutput{ Arn: aws.String(a.arn), }, nil } +func (a *mockSTSAPI) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + panic("not implemented") +} + type mockEKSAPI struct { - eksiface.EKSAPI - clusters []*eks.Cluster + EKSClient + + clusters []*ekstypes.Cluster } -func (m *mockEKSAPI) ListClustersPagesWithContext(ctx aws.Context, req *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error { - var names []*string +func (m *mockEKSAPI) ListClusters(ctx context.Context, req *eks.ListClustersInput, _ ...func(*eks.Options)) (*eks.ListClustersOutput, error) { + var names []string for _, cluster := range m.clusters { - names = append(names, cluster.Name) + names = append(names, aws.ToString(cluster.Name)) } - f(&eks.ListClustersOutput{ - Clusters: names[:len(names)/2], - }, false) - - f(&eks.ListClustersOutput{ - Clusters: names[len(names)/2:], - }, true) - return nil + return &eks.ListClustersOutput{ + Clusters: names, + }, nil } -func (m *mockEKSAPI) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) { +func (m *mockEKSAPI) DescribeCluster(_ context.Context, req *eks.DescribeClusterInput, _ ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) { for _, cluster := range m.clusters { - if aws.StringValue(cluster.Name) == aws.StringValue(req.Name) { + if aws.ToString(cluster.Name) == aws.ToString(req.Name) { return &eks.DescribeClusterOutput{ Cluster: cluster, }, nil @@ -185,51 +220,50 @@ func newPopulatedEKSMock() *mockEKSAPI { } } -var eksMockClusters = []*eks.Cluster{ - +var eksMockClusters = []*ekstypes.Cluster{ { Name: aws.String("cluster1"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("prod"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "prod", + "location": "eu-west-1", }, }, { Name: aws.String("cluster2"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster2"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("prod"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "prod", + "location": "eu-west-1", }, }, { Name: aws.String("cluster3"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster3"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("stg"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "stg", + "location": "eu-west-1", }, }, { Name: aws.String("cluster4"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"), - Status: aws.String(eks.ClusterStatusActive), - Tags: map[string]*string{ - "env": aws.String("stg"), - "location": aws.String("eu-west-1"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "env": "stg", + "location": "eu-west-1", }, }, } -func eksClustersToResources(t *testing.T, clusters ...*eks.Cluster) types.ResourcesWithLabels { +func eksClustersToResources(t *testing.T, clusters ...*ekstypes.Cluster) types.ResourcesWithLabels { var kubeClusters types.KubeClusters for _, cluster := range clusters { - kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.StringValue(cluster.Name), aws.StringValue(cluster.Arn), cluster.Tags) + kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(cluster.Name), aws.ToString(cluster.Arn), cluster.Tags) require.NoError(t, err) require.True(t, kubeCluster.IsAWS()) common.ApplyEKSNameSuffix(kubeCluster) diff --git a/lib/srv/discovery/kube_integration_watcher_test.go b/lib/srv/discovery/kube_integration_watcher_test.go index 423339678ae8d..3c7cbd57731fd 100644 --- a/lib/srv/discovery/kube_integration_watcher_test.go +++ b/lib/srv/discovery/kube_integration_watcher_test.go @@ -26,9 +26,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/eks" - eksTypes "github.com/aws/aws-sdk-go-v2/service/eks/types" + ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/aws/aws-sdk-go-v2/service/sts" - eksV1 "github.com/aws/aws-sdk-go/service/eks" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/assert" @@ -45,7 +44,6 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/services" @@ -56,22 +54,24 @@ import ( func TestServer_getKubeFetchers(t *testing.T) { eks1, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{ - ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, + ClientGetter: &mockFetchersClients{}, FilterLabels: types.Labels{"l1": []string{"v1"}}, Region: "region1", }) require.NoError(t, err) eks2, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{ - ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, + ClientGetter: &mockFetchersClients{}, FilterLabels: types.Labels{"l1": []string{"v1"}}, Region: "region1", - Integration: "aws1"}) + Integration: "aws1", + }) require.NoError(t, err) eks3, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{ - ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, + ClientGetter: &mockFetchersClients{}, FilterLabels: types.Labels{"l1": []string{"v1"}}, Region: "region1", - Integration: "aws1"}) + Integration: "aws1", + }) require.NoError(t, err) aks1, err := fetchers.NewAKSFetcher(fetchers.AKSFetcherConfig{ @@ -139,20 +139,51 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { testCAData = "VGVzdENBREFUQQ==" ) - testEKSClusters := []eksTypes.Cluster{ + // Create and start test auth server. + testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) }) + + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{ + Name: "integration1", + }, &types.AWSOIDCIntegrationSpecV1{ + RoleARN: roleArn, + }) + require.NoError(t, err) + testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{ + proxies: nil, + integrations: map[string]types.Integration{ + awsOIDCIntegration.GetName(): awsOIDCIntegration, + }, + } + + ctx := context.Background() + tlsServer, err := testAuthServer.NewTestTLSServer() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, tlsServer.Close()) }) + _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration) + require.NoError(t, err) + + fakeConfigProvider := mocks.AWSConfigProvider{ + OIDCIntegrationClient: tlsServer.Auth(), + } + + testEKSClusters := []ekstypes.Cluster{ { Name: aws.String("eks-cluster1"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster1"), Tags: map[string]string{"env": "prod", "location": "eu-west-1"}, - CertificateAuthority: &eksTypes.Certificate{Data: aws.String(testCAData)}, - Status: eksTypes.ClusterStatusActive, + CertificateAuthority: &ekstypes.Certificate{Data: aws.String(testCAData)}, + Status: ekstypes.ClusterStatusActive, }, { Name: aws.String("eks-cluster2"), Arn: aws.String("arn:aws:eks:eu-west-1:accountID:cluster/cluster2"), Tags: map[string]string{"env": "prod", "location": "eu-west-1"}, - CertificateAuthority: &eksTypes.Certificate{Data: aws.String(testCAData)}, - Status: eksTypes.ClusterStatusActive, + CertificateAuthority: &ekstypes.Certificate{Data: aws.String(testCAData)}, + Status: ekstypes.ClusterStatusActive, }, } @@ -173,7 +204,7 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { return dc } - clusterFinder := func(clusterName string) *eksTypes.Cluster { + clusterFinder := func(clusterName string) *ekstypes.Cluster { for _, c := range testEKSClusters { if aws.ToString(c.Name) == clusterName { return &c @@ -309,17 +340,9 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - EKS: &mockEKSAPI{ - clusters: eksMockClusters[:2], - }, - } - ctx := context.Background() // Create and start test auth server. testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ @@ -372,7 +395,10 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { discServer, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ - CloudClients: testCloudClients, + AWSFetchersClients: &mockFetchersClients{ + AWSConfigProvider: fakeConfigProvider, + eksClusters: eksMockClusters[:2], + }, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: tc.accessPoint(t, tlsServer.Auth(), authClient), @@ -391,7 +417,7 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { _, err := tlsServer.Auth().DiscoveryConfigs.CreateDiscoveryConfig(ctx, dc) require.NoError(t, err) - // Wait for the DiscoveryConfig to be added to the dynamic fetchers + // Wait for the DiscoveryConfig to be added to the dynamic fetchers. require.Eventually(t, func() bool { discServer.muDynamicKubeFetchers.RLock() defer discServer.muDynamicKubeFetchers.RUnlock() @@ -425,9 +451,9 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { } } -func mustConvertEKSToKubeServerV1(t *testing.T, eksCluster *eksV1.Cluster, resourceID, discoveryGroup string) types.KubeServer { - eksCluster.Tags[types.OriginLabel] = aws.String(types.OriginCloud) - eksCluster.Tags[types.InternalResourceIDLabel] = aws.String(resourceID) +func mustConvertEKSToKubeServerV1(t *testing.T, eksCluster *ekstypes.Cluster, resourceID, _ string) types.KubeServer { + eksCluster.Tags[types.OriginLabel] = types.OriginCloud + eksCluster.Tags[types.InternalResourceIDLabel] = resourceID kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksCluster.Name), aws.ToString(eksCluster.Arn), eksCluster.Tags) assert.NoError(t, err) @@ -440,13 +466,13 @@ func mustConvertEKSToKubeServerV1(t *testing.T, eksCluster *eksV1.Cluster, resou return kubeServer } -func mustConvertEKSToKubeServerV2(t *testing.T, eksCluster *eksTypes.Cluster, resourceID, discoveryGroup string) types.KubeServer { - eksTags := make(map[string]*string, len(eksCluster.Tags)) +func mustConvertEKSToKubeServerV2(t *testing.T, eksCluster *ekstypes.Cluster, resourceID, _ string) types.KubeServer { + eksTags := make(map[string]string, len(eksCluster.Tags)) for k, v := range eksCluster.Tags { - eksTags[k] = aws.String(v) + eksTags[k] = v } - eksTags[types.OriginLabel] = aws.String(types.OriginCloud) - eksTags[types.InternalResourceIDLabel] = aws.String(resourceID) + eksTags[types.OriginLabel] = types.OriginCloud + eksTags[types.InternalResourceIDLabel] = resourceID kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(eksCluster.Name), aws.ToString(eksCluster.Arn), eksTags) assert.NoError(t, err) @@ -476,9 +502,8 @@ func (a *accessPointWrapper) EnrollEKSClusters(ctx context.Context, req *integra } type mockIntegrationsTokenGenerator struct { - proxies []types.Server - integrations map[string]types.Integration - tokenCallsCount int + proxies []types.Server + integrations map[string]types.Integration } // GetIntegration returns the specified integration resources. @@ -497,7 +522,6 @@ func (m *mockIntegrationsTokenGenerator) GetProxies() ([]types.Server, error) { // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action. func (m *mockIntegrationsTokenGenerator) GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) { - m.tokenCallsCount++ return uuid.NewString(), nil } @@ -509,7 +533,7 @@ type mockEnrollEKSClusterClient struct { describeCluster func(context.Context, *eks.DescribeClusterInput, ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) getCallerIdentity func(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) checkAgentAlreadyInstalled func(context.Context, genericclioptions.RESTClientGetter, *slog.Logger) (bool, error) - installKubeAgent func(context.Context, *eksTypes.Cluster, string, string, string, genericclioptions.RESTClientGetter, *slog.Logger, awsoidc.EnrollEKSClustersRequest) error + installKubeAgent func(context.Context, *ekstypes.Cluster, string, string, string, genericclioptions.RESTClientGetter, *slog.Logger, awsoidc.EnrollEKSClustersRequest) error createToken func(context.Context, types.ProvisionToken) error presignGetCallerIdentityURL func(ctx context.Context, clusterName string) (string, error) } @@ -563,7 +587,7 @@ func (m *mockEnrollEKSClusterClient) CheckAgentAlreadyInstalled(ctx context.Cont return false, nil } -func (m *mockEnrollEKSClusterClient) InstallKubeAgent(ctx context.Context, eksCluster *eksTypes.Cluster, proxyAddr, joinToken, resourceId string, kubeconfig genericclioptions.RESTClientGetter, log *slog.Logger, req awsoidc.EnrollEKSClustersRequest) error { +func (m *mockEnrollEKSClusterClient) InstallKubeAgent(ctx context.Context, eksCluster *ekstypes.Cluster, proxyAddr, joinToken, resourceId string, kubeconfig genericclioptions.RESTClientGetter, log *slog.Logger, req awsoidc.EnrollEKSClustersRequest) error { if m.installKubeAgent != nil { return m.installKubeAgent(ctx, eksCluster, proxyAddr, joinToken, resourceId, kubeconfig, log, req) }