diff --git a/lib/events/s3sessions/s3handler.go b/lib/events/s3sessions/s3handler.go index 71edb9a99e6c2..1608f39ef2808 100644 --- a/lib/events/s3sessions/s3handler.go +++ b/lib/events/s3sessions/s3handler.go @@ -41,6 +41,7 @@ import ( "go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws" "github.com/gravitational/teleport" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -419,7 +420,11 @@ func (h *Handler) fromPath(p string) session.ID { // ensureBucket makes sure bucket exists, and if it does not, creates it func (h *Handler) ensureBucket(ctx context.Context) error { - _, err := h.client.HeadBucket(ctx, &s3.HeadBucketInput{ + // Use a short timeout for the HeadBucket call in case it takes too long, in + // #50747 this call would hang. + shortCtx, cancel := context.WithTimeout(ctx, apidefaults.DefaultIOTimeout) + defer cancel() + _, err := h.client.HeadBucket(shortCtx, &s3.HeadBucketInput{ Bucket: aws.String(h.Bucket), }) err = awsutils.ConvertS3Error(err) @@ -430,7 +435,7 @@ func (h *Handler) ensureBucket(ctx context.Context) error { case trace.IsBadParameter(err): return trace.Wrap(err) case !trace.IsNotFound(err): - h.logger.ErrorContext(ctx, "Failed to ensure that S3 bucket exists. S3 session uploads may fail. If you've set up the bucket already and gave Teleport write-only access, feel free to ignore this error.", "bucket", h.Bucket, "error", err) + h.logger.ErrorContext(ctx, "Failed to ensure that S3 bucket exists. This is expected if External Audit Storage is enabled or if Teleport has write-only access to the bucket, otherwise S3 session uploads may fail.", "bucket", h.Bucket, "error", err) return nil } diff --git a/lib/integrations/awsoidc/credprovider/credentialscache.go b/lib/integrations/awsoidc/credprovider/credentialscache.go index 2711d0126b2ce..ad1db35e94bea 100644 --- a/lib/integrations/awsoidc/credprovider/credentialscache.go +++ b/lib/integrations/awsoidc/credprovider/credentialscache.go @@ -62,12 +62,21 @@ type CredentialsCache struct { roleARN arn.ARN integration string - // generateOIDCTokenFn is dynamically set after auth is initialized. - generateOIDCTokenFn GenerateOIDCTokenFn - - // initialized communicates (via closing channel) that generateOIDCTokenFn is set. - initialized chan struct{} - closeInitialized func() + // generateOIDCTokenFn can be dynamically set after creating the credential + // cache, this is a workaround for a dependency cycle where audit storage + // depends on the credential cache, the auth server depends on audit + // storage, and the credential cache depends on the auth server for a + // GenerateOIDCTokenFn. + generateOIDCTokenFn GenerateOIDCTokenFn + generateOIDCTokenFnMu sync.Mutex + // gotGenerateOIDCTokenFn communicates (via closing channel) that + // generateOIDCTokenFn is set. + gotGenerateOIDCTokenFn chan struct{} + closeGotGenerateOIDCTokenFn func() + // allowRetrieveBeforeInit allows the Retrieve method to return an error if + // [gotGenerateOIDCTokenFn] has not been closed yet, instead of waiting for it to be + // closed. + allowRetrieveBeforeInit bool // gotFirstCredsOrErr communicates (via closing channel) that the first // credsOrErr has been set. @@ -92,6 +101,15 @@ type CredentialsCacheOptions struct { // with AWS STSClient stscreds.AssumeRoleWithWebIdentityAPIClient + // GenerateOIDCTokenFn is a function that should return a valid, signed JWT for + // authenticating to AWS via OIDC. + GenerateOIDCTokenFn GenerateOIDCTokenFn + + // AllowRetrieveBeforeInit allows the Retrieve method to return with an + // error before the cache has been initialized, instead of waiting for the + // first credentials to be generated. + AllowRetrieveBeforeInit bool + // Log is the logger to use. A default will be supplied if no logger is // explicitly set Log *slog.Logger @@ -124,36 +142,59 @@ func NewCredentialsCache(options CredentialsCacheOptions) (*CredentialsCache, er return nil, trace.Wrap(err, "creating credentials cache") } - initialized := make(chan struct{}) + gotGenerateOIDCTokenFn := make(chan struct{}) + closeGotGenerateOIDCTokenFn := sync.OnceFunc(func() { close(gotGenerateOIDCTokenFn) }) + if options.GenerateOIDCTokenFn != nil { + closeGotGenerateOIDCTokenFn() + } + gotFirstCredsOrErr := make(chan struct{}) + closeGotFirstCredsOrErr := sync.OnceFunc(func() { close(gotFirstCredsOrErr) }) return &CredentialsCache{ - roleARN: options.RoleARN, - integration: options.Integration, - log: options.Log.With("integration", options.Integration), - initialized: initialized, - closeInitialized: sync.OnceFunc(func() { close(initialized) }), - gotFirstCredsOrErr: gotFirstCredsOrErr, - closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), - credsOrErr: credsOrErr{err: errNotReady}, - clock: options.Clock, - stsClient: options.STSClient, + roleARN: options.RoleARN, + integration: options.Integration, + generateOIDCTokenFn: options.GenerateOIDCTokenFn, + gotGenerateOIDCTokenFn: gotGenerateOIDCTokenFn, + closeGotGenerateOIDCTokenFn: closeGotGenerateOIDCTokenFn, + allowRetrieveBeforeInit: options.AllowRetrieveBeforeInit, + log: options.Log.With("integration", options.Integration), + gotFirstCredsOrErr: gotFirstCredsOrErr, + closeGotFirstCredsOrErr: closeGotFirstCredsOrErr, + credsOrErr: credsOrErr{err: errNotReady}, + clock: options.Clock, + stsClient: options.STSClient, }, nil } +// SetGenerateOIDCTokenFn can be used to set a GenerateOIDCTokenFn after +// creating the credential cache, when dependencies require the credential cache +// to be created before a valid GenerateOIDCTokenFn can be created. func (cc *CredentialsCache) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { + cc.generateOIDCTokenFnMu.Lock() + defer cc.generateOIDCTokenFnMu.Unlock() cc.generateOIDCTokenFn = fn - cc.closeInitialized() + close(cc.gotGenerateOIDCTokenFn) +} + +// getGenerateOIDCTokenFn must not be called before [cc.gotGenerateOIDCTokenFn] +// has been closed, or it will return nil. +func (cc *CredentialsCache) getGenerateOIDCTokenFn() GenerateOIDCTokenFn { + cc.generateOIDCTokenFnMu.Lock() + defer cc.generateOIDCTokenFnMu.Unlock() + return cc.generateOIDCTokenFn } // Retrieve implements [aws.CredentialsProvider] and returns the latest cached // credentials, or an error if no credentials have been generated yet or the // last generated credentials have expired. func (cc *CredentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) { - select { - case <-cc.gotFirstCredsOrErr: - case <-ctx.Done(): - return aws.Credentials{}, ctx.Err() + if !cc.allowRetrieveBeforeInit { + select { + case <-cc.gotFirstCredsOrErr: + case <-ctx.Done(): + return aws.Credentials{}, ctx.Err() + } } creds, err := cc.retrieve(ctx) return creds, trace.Wrap(err) @@ -169,9 +210,9 @@ func (cc *CredentialsCache) retrieve(ctx context.Context) (aws.Credentials, erro } func (cc *CredentialsCache) Run(ctx context.Context) { - // Wait for initialized signal before running loop. + // Wait for a generateOIDCTokenFn before running loop. select { - case <-cc.initialized: + case <-cc.gotGenerateOIDCTokenFn: case <-ctx.Done(): cc.log.DebugContext(ctx, "Context canceled before initialized.") return @@ -241,7 +282,7 @@ func (cc *CredentialsCache) refresh(ctx context.Context) (aws.Credentials, error defer cc.log.InfoContext(ctx, "Exiting AWS credentials refresh") cc.log.InfoContext(ctx, "Generating Token") - oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration) + oidcToken, err := cc.getGenerateOIDCTokenFn()(ctx, cc.integration) if err != nil { cc.log.ErrorContext(ctx, "Token generation failed", errorValue(err)) return aws.Credentials{}, trace.Wrap(err) diff --git a/lib/integrations/awsoidc/credprovider/credentialscache_test.go b/lib/integrations/awsoidc/credprovider/credentialscache_test.go index 6384bed0b8db0..359232da55503 100644 --- a/lib/integrations/awsoidc/credprovider/credentialscache_test.go +++ b/lib/integrations/awsoidc/credprovider/credentialscache_test.go @@ -20,7 +20,6 @@ import ( "context" "errors" "sync" - "sync/atomic" "testing" "time" @@ -35,13 +34,13 @@ import ( "github.com/gravitational/teleport/entitlements" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/utils" ) type fakeSTSClient struct { clock clockwork.Clock err error sync.Mutex - called int32 } func (f *fakeSTSClient) setError(err error) { @@ -57,7 +56,6 @@ func (f *fakeSTSClient) getError() error { } func (f *fakeSTSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { - atomic.AddInt32(&f.called, 1) if err := f.getError(); err != nil { return nil, err } @@ -98,6 +96,9 @@ func TestCredentialsCache(t *testing.T) { STSClient: stsClient, Integration: "test", Clock: clock, + GenerateOIDCTokenFn: func(ctx context.Context, integration string) (string, error) { + return uuid.NewString(), nil + }, }) require.NoError(t, err) require.NotNil(t, cacheUnderTest) @@ -110,12 +111,6 @@ func TestCredentialsCache(t *testing.T) { clock.Advance(d) } - // Set the GenerateOIDCTokenFn to a dumb faked function. - cacheUnderTest.SetGenerateOIDCTokenFn( - func(ctx context.Context, integration string) (string, error) { - return uuid.NewString(), nil - }) - checkRetrieveCredentials := func(t require.TestingT, expectErr error) { _, err := cacheUnderTest.Retrieve(ctx) assert.ErrorIs(t, err, expectErr) @@ -227,3 +222,50 @@ func TestCredentialsCache(t *testing.T) { } }) } + +func TestCredentialsCacheRetrieveBeforeInit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clock := clockwork.NewFakeClock() + stsClient := &fakeSTSClient{ + clock: clock, + } + cache, err := NewCredentialsCache(CredentialsCacheOptions{ + STSClient: stsClient, + Integration: "test", + Clock: clock, + AllowRetrieveBeforeInit: true, + }) + require.NoError(t, err) + + utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ + Name: "cache.Run", + Task: func(ctx context.Context) error { + cache.Run(ctx) + return nil + }, + Terminate: func() error { + cancel() + return nil + }, + }) + + // cache.Retrieve should return immediately with errNotReady if + // SetGenerateOIDCTokenFn has not been called yet. + _, err = cache.Retrieve(ctx) + require.ErrorIs(t, err, errNotReady) + + // The GenerateOIDCTokenFn can be set after the cache has been initialized. + cache.SetGenerateOIDCTokenFn(func(ctx context.Context, integration string) (string, error) { + return uuid.NewString(), nil + }) + // WaitForFirstCredsOrErr should usually be called after + // SetGenerateOIDCTokenFn to make sure credentials are ready before they + // will be relied upon. + cache.WaitForFirstCredsOrErr(ctx) + // Now cache.Retrieve should not return an error. + creds, err := cache.Retrieve(ctx) + require.NoError(t, err) + require.NotEmpty(t, creds.SecretAccessKey) +} diff --git a/lib/integrations/awsoidc/credprovider/integration_config_provider.go b/lib/integrations/awsoidc/credprovider/integration_config_provider.go index 204ab121cc133..76ed003113588 100644 --- a/lib/integrations/awsoidc/credprovider/integration_config_provider.go +++ b/lib/integrations/awsoidc/credprovider/integration_config_provider.go @@ -34,32 +34,17 @@ import ( ) // Options represents additional options for configuring the AWS credentials provider. -type Options struct { - // WaitForFirstInit indicates whether to wait for the initial credential - // generation before returning from CreateAWSConfigForIntegration. - WaitForFirstInit bool -} +// There are currently no options but this type is still referenced from +// teleport.e. +type Options struct{} // Option is a function that modifies the Options struct for the AWS configuration. type Option func(*Options) -// WithWaitForFirstInit configures the provider to wait until the first set of -// credentials is generated before proceeding. This is useful in cases where -// immediate credential availability is necessary. -func WithWaitForFirstInit(wait bool) Option { - return func(o *Options) { - o.WaitForFirstInit = wait - } -} - // CreateAWSConfigForIntegration returns a new AWS credentials provider that // uses the AWS OIDC integration to generate temporary credentials. // The provider will periodically refresh the credentials before they expire. func CreateAWSConfigForIntegration(ctx context.Context, config Config, option ...Option) (*aws.Config, error) { - options := Options{} - for _, opt := range option { - opt(&options) - } if err := config.checkAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } @@ -76,10 +61,6 @@ func CreateAWSConfigForIntegration(ctx context.Context, config Config, option .. } go credCache.Run(ctx) - if options.WaitForFirstInit { - credCache.WaitForFirstCredsOrErr(ctx) - } - awsCfg, err := newAWSConfig(ctx, config.Region, awsConfig.WithCredentialsProvider(credCache)) if err != nil { return nil, trace.Wrap(err) @@ -152,17 +133,17 @@ func newAWSCredCache(ctx context.Context, cfg Config, stsClient stscreds.AssumeR credCache, err := NewCredentialsCache( CredentialsCacheOptions{ - Log: cfg.Logger, - Clock: cfg.Clock, - STSClient: stsClient, - RoleARN: roleARN, - Integration: cfg.IntegrationName, + Log: cfg.Logger, + Clock: cfg.Clock, + STSClient: stsClient, + RoleARN: roleARN, + Integration: cfg.IntegrationName, + GenerateOIDCTokenFn: cfg.AWSOIDCTokenGenerator.GenerateAWSOIDCToken, }, ) if err != nil { return nil, trace.Wrap(err, "creating OIDC credentials cache") } - credCache.SetGenerateOIDCTokenFn(cfg.AWSOIDCTokenGenerator.GenerateAWSOIDCToken) return credCache, nil } diff --git a/lib/integrations/awsoidc/credprovider/integration_config_provider_test.go b/lib/integrations/awsoidc/credprovider/integration_config_provider_test.go index 03af684ce602b..7a9e42adb3d5f 100644 --- a/lib/integrations/awsoidc/credprovider/integration_config_provider_test.go +++ b/lib/integrations/awsoidc/credprovider/integration_config_provider_test.go @@ -19,7 +19,6 @@ package credprovider import ( "context" "crypto" - "sync/atomic" "testing" "github.com/gravitational/trace" @@ -55,24 +54,6 @@ func TestCreateAWSConfigForIntegration(t *testing.T) { STSClient: stsClient, }) require.NoError(t, err) - require.Equal(t, int32(0), atomic.LoadInt32(&stsClient.called)) - - creds, err := config.Credentials.Retrieve(ctx) - require.NoError(t, err) - require.NotEmpty(t, creds.SecretAccessKey) - }) - - t.Run("should init creds before retrieve call", func(t *testing.T) { - stsClient := &fakeSTSClient{clock: clockwork.NewFakeClock()} - config, err := CreateAWSConfigForIntegration(ctx, Config{ - Region: awsRegion, - IntegrationName: integrationName, - IntegrationGetter: deps, - AWSOIDCTokenGenerator: deps, - STSClient: stsClient, - }, WithWaitForFirstInit(true)) - require.NoError(t, err) - require.Equal(t, int32(1), atomic.LoadInt32(&stsClient.called)) creds, err := config.Credentials.Retrieve(ctx) require.NoError(t, err) @@ -88,19 +69,10 @@ type depsMock struct { proxies []types.Server } -func (d *depsMock) GenerateOIDCTokenFn(ctx context.Context, integration string) (string, error) { - token, err := awsoidc.GenerateAWSOIDCToken(ctx, d, d, awsoidc.GenerateAWSOIDCTokenRequest{ - Integration: integrationName, - Username: testUser, - Subject: types.IntegrationAWSOIDCSubject, - }) - return token, trace.Wrap(err) -} - func (d *depsMock) GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) { token, err := awsoidc.GenerateAWSOIDCToken(ctx, d, d, awsoidc.GenerateAWSOIDCTokenRequest{ Integration: integration, - Username: "test-user", + Username: testUser, Subject: types.IntegrationAWSOIDCSubject, }) return token, trace.Wrap(err) diff --git a/lib/integrations/externalauditstorage/configurator.go b/lib/integrations/externalauditstorage/configurator.go index 96c16c9dde133..739ee9d7342a3 100644 --- a/lib/integrations/externalauditstorage/configurator.go +++ b/lib/integrations/externalauditstorage/configurator.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -218,6 +217,10 @@ func newConfigurator(ctx context.Context, spec *externalauditstorage.ExternalAud RoleARN: awsRoleARN, STSClient: options.stsClient, Clock: options.clock, + // SetGenerateOIDCTokenFn will be called later, until then we must allow + // credentialsCache.Retrieve to return errors instead of blocking auth + // startup. + AllowRetrieveBeforeInit: true, }) if err != nil { return nil, trace.Wrap(err) @@ -263,13 +266,6 @@ func (p *Configurator) CredentialsProvider() aws.CredentialsProvider { return p.credentialsCache } -// CredentialsProviderSDKV1 returns a credentials.ProviderWithContext that can be used to -// authenticate with the customer AWS account via the configured AWS OIDC -// integration with aws-sdk-go. -func (p *Configurator) CredentialsProviderSDKV1() credentials.ProviderWithContext { - return &v1Adapter{cc: p.credentialsCache} -} - // WaitForFirstCredentials waits for the internal credentials cache to finish // fetching its first credentials (or getting an error attempting to do so). // This can be called after SetGenerateOIDCTokenFn to make sure any returned @@ -278,37 +274,3 @@ func (p *Configurator) CredentialsProviderSDKV1() credentials.ProviderWithContex func (p *Configurator) WaitForFirstCredentials(ctx context.Context) { p.credentialsCache.WaitForFirstCredsOrErr(ctx) } - -// v1Adapter wraps the credentialsCache to implement -// [credentials.ProviderWithContext] used by aws-sdk-go (v1). -type v1Adapter struct { - cc *credprovider.CredentialsCache -} - -var _ credentials.ProviderWithContext = (*v1Adapter)(nil) - -// RetrieveWithContext returns cached credentials. -func (a *v1Adapter) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { - credsV2, err := a.cc.Retrieve(ctx) - if err != nil { - return credentials.Value{}, trace.Wrap(err) - } - - return credentials.Value{ - AccessKeyID: credsV2.AccessKeyID, - SecretAccessKey: credsV2.SecretAccessKey, - SessionToken: credsV2.SessionToken, - ProviderName: credsV2.Source, - }, nil -} - -// Retrieve returns cached credentials. -func (a *v1Adapter) Retrieve() (credentials.Value, error) { - return a.RetrieveWithContext(context.Background()) -} - -// IsExpired always returns true in order to opt out of AWS SDK credential -// caching. Retrieve(WithContext) already returns cached credentials. -func (a *v1Adapter) IsExpired() bool { - return true -} diff --git a/lib/integrations/externalauditstorage/configurator_test.go b/lib/integrations/externalauditstorage/configurator_test.go index ba86e5f8e0c27..ffdf4ab543d31 100644 --- a/lib/integrations/externalauditstorage/configurator_test.go +++ b/lib/integrations/externalauditstorage/configurator_test.go @@ -233,17 +233,12 @@ func TestCredentialsCache(t *testing.T) { }) provider := c.CredentialsProvider() - providerV1 := c.CredentialsProviderSDKV1() checkRetrieveCredentials := func(t require.TestingT, expectErr error) { - _, err = providerV1.RetrieveWithContext(ctx) - assert.ErrorIs(t, err, expectErr) _, err := provider.Retrieve(ctx) assert.ErrorIs(t, err, expectErr) } checkRetrieveCredentialsWithExpiry := func(t require.TestingT, expectExpiry time.Time) { - _, err = providerV1.RetrieveWithContext(ctx) - assert.NoError(t, err) creds, err := provider.Retrieve(ctx) assert.NoError(t, err) if err == nil {