Skip to content

Commit

Permalink
fix: allow CredentialCache.Retrieve to return before initialization (#…
Browse files Browse the repository at this point in the history
…50748)

* fix: allow CredentialsCache.Retrieve to return before initialization

Fixes #50747

This PR adds an option to allow CredentialsCache.Retrieve to return with
an error before CredentialsCache.SetGenerateOIDCTokenFn has been called.
We have been relying on this behaviour for Auth service startup, because
the credential cache is a dependency for the audit storage, which is a
dependency for the auth server, which is a dependency for the credential
cache's GenerateOIDCTokenFn.
Some non-critical checks during auth startup must fail fast instead of
hanging and blocking the process from starting successfully.

* use default io timeout

* synchronize access to generateOIDCTokenFn
  • Loading branch information
nklaassen authored Jan 6, 2025
1 parent cea0e73 commit c5b4101
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 140 deletions.
9 changes: 7 additions & 2 deletions lib/events/s3sessions/s3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"go.opentelemetry.io/otel"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -423,7 +424,11 @@ func (h *Handler) fromPath(p string) session.ID {

// ensureBucket makes sure bucket exists, and if it does not, creates it
func (h *Handler) ensureBucket(ctx context.Context) error {
_, err := h.client.HeadBucket(ctx, &s3.HeadBucketInput{
// Use a short timeout for the HeadBucket call in case it takes too long, in
// #50747 this call would hang.
shortCtx, cancel := context.WithTimeout(ctx, apidefaults.DefaultIOTimeout)
defer cancel()
_, err := h.client.HeadBucket(shortCtx, &s3.HeadBucketInput{
Bucket: aws.String(h.Bucket),
})
err = awsutils.ConvertS3Error(err)
Expand All @@ -434,7 +439,7 @@ func (h *Handler) ensureBucket(ctx context.Context) error {
case trace.IsBadParameter(err):
return trace.Wrap(err)
case !trace.IsNotFound(err):
h.logger.ErrorContext(ctx, "Failed to ensure that S3 bucket exists. S3 session uploads may fail. If you've set up the bucket already and gave Teleport write-only access, feel free to ignore this error.", "bucket", h.Bucket, "error", err)
h.logger.ErrorContext(ctx, "Failed to ensure that S3 bucket exists. This is expected if External Audit Storage is enabled or if Teleport has write-only access to the bucket, otherwise S3 session uploads may fail.", "bucket", h.Bucket, "error", err)
return nil
}

Expand Down
91 changes: 66 additions & 25 deletions lib/integrations/awsoidc/credprovider/credentialscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,21 @@ type CredentialsCache struct {
roleARN arn.ARN
integration string

// generateOIDCTokenFn is dynamically set after auth is initialized.
generateOIDCTokenFn GenerateOIDCTokenFn

// initialized communicates (via closing channel) that generateOIDCTokenFn is set.
initialized chan struct{}
closeInitialized func()
// generateOIDCTokenFn can be dynamically set after creating the credential
// cache, this is a workaround for a dependency cycle where audit storage
// depends on the credential cache, the auth server depends on audit
// storage, and the credential cache depends on the auth server for a
// GenerateOIDCTokenFn.
generateOIDCTokenFn GenerateOIDCTokenFn
generateOIDCTokenFnMu sync.Mutex
// gotGenerateOIDCTokenFn communicates (via closing channel) that
// generateOIDCTokenFn is set.
gotGenerateOIDCTokenFn chan struct{}
closeGotGenerateOIDCTokenFn func()
// allowRetrieveBeforeInit allows the Retrieve method to return an error if
// [gotGenerateOIDCTokenFn] has not been closed yet, instead of waiting for it to be
// closed.
allowRetrieveBeforeInit bool

// gotFirstCredsOrErr communicates (via closing channel) that the first
// credsOrErr has been set.
Expand All @@ -92,6 +101,15 @@ type CredentialsCacheOptions struct {
// with AWS
STSClient stscreds.AssumeRoleWithWebIdentityAPIClient

// GenerateOIDCTokenFn is a function that should return a valid, signed JWT for
// authenticating to AWS via OIDC.
GenerateOIDCTokenFn GenerateOIDCTokenFn

// AllowRetrieveBeforeInit allows the Retrieve method to return with an
// error before the cache has been initialized, instead of waiting for the
// first credentials to be generated.
AllowRetrieveBeforeInit bool

// Log is the logger to use. A default will be supplied if no logger is
// explicitly set
Log *slog.Logger
Expand Down Expand Up @@ -124,36 +142,59 @@ func NewCredentialsCache(options CredentialsCacheOptions) (*CredentialsCache, er
return nil, trace.Wrap(err, "creating credentials cache")
}

initialized := make(chan struct{})
gotGenerateOIDCTokenFn := make(chan struct{})
closeGotGenerateOIDCTokenFn := sync.OnceFunc(func() { close(gotGenerateOIDCTokenFn) })
if options.GenerateOIDCTokenFn != nil {
closeGotGenerateOIDCTokenFn()
}

gotFirstCredsOrErr := make(chan struct{})
closeGotFirstCredsOrErr := sync.OnceFunc(func() { close(gotFirstCredsOrErr) })

return &CredentialsCache{
roleARN: options.RoleARN,
integration: options.Integration,
log: options.Log.With("integration", options.Integration),
initialized: initialized,
closeInitialized: sync.OnceFunc(func() { close(initialized) }),
gotFirstCredsOrErr: gotFirstCredsOrErr,
closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }),
credsOrErr: credsOrErr{err: errNotReady},
clock: options.Clock,
stsClient: options.STSClient,
roleARN: options.RoleARN,
integration: options.Integration,
generateOIDCTokenFn: options.GenerateOIDCTokenFn,
gotGenerateOIDCTokenFn: gotGenerateOIDCTokenFn,
closeGotGenerateOIDCTokenFn: closeGotGenerateOIDCTokenFn,
allowRetrieveBeforeInit: options.AllowRetrieveBeforeInit,
log: options.Log.With("integration", options.Integration),
gotFirstCredsOrErr: gotFirstCredsOrErr,
closeGotFirstCredsOrErr: closeGotFirstCredsOrErr,
credsOrErr: credsOrErr{err: errNotReady},
clock: options.Clock,
stsClient: options.STSClient,
}, nil
}

// SetGenerateOIDCTokenFn can be used to set a GenerateOIDCTokenFn after
// creating the credential cache, when dependencies require the credential cache
// to be created before a valid GenerateOIDCTokenFn can be created.
//
// It stores fn under generateOIDCTokenFnMu and then signals, by closing
// gotGenerateOIDCTokenFn, that a token function is available.
//
// It is safe to call more than once, and safe to call when a
// GenerateOIDCTokenFn was already supplied through CredentialsCacheOptions:
// the most recent fn wins and the channel is closed exactly once.
func (cc *CredentialsCache) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) {
	cc.generateOIDCTokenFnMu.Lock()
	defer cc.generateOIDCTokenFnMu.Unlock()
	cc.generateOIDCTokenFn = fn
	// Close through the once-guarded closer rather than calling close() on the
	// channel directly. NewCredentialsCache has already closed the channel when
	// a GenerateOIDCTokenFn was passed in the options, and closing a closed
	// channel panics.
	cc.closeGotGenerateOIDCTokenFn()
}

// getGenerateOIDCTokenFn returns the currently configured token function,
// reading it while holding generateOIDCTokenFnMu. Callers must wait for
// [cc.gotGenerateOIDCTokenFn] to be closed first; before that point the
// returned function is nil.
func (cc *CredentialsCache) getGenerateOIDCTokenFn() GenerateOIDCTokenFn {
	cc.generateOIDCTokenFnMu.Lock()
	fn := cc.generateOIDCTokenFn
	cc.generateOIDCTokenFnMu.Unlock()
	return fn
}

// Retrieve implements [aws.CredentialsProvider] and returns the latest cached
// credentials, or an error if no credentials have been generated yet or the
// last generated credentials have expired.
func (cc *CredentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) {
select {
case <-cc.gotFirstCredsOrErr:
case <-ctx.Done():
return aws.Credentials{}, ctx.Err()
if !cc.allowRetrieveBeforeInit {
select {
case <-cc.gotFirstCredsOrErr:
case <-ctx.Done():
return aws.Credentials{}, ctx.Err()
}
}
creds, err := cc.retrieve(ctx)
return creds, trace.Wrap(err)
Expand All @@ -169,9 +210,9 @@ func (cc *CredentialsCache) retrieve(ctx context.Context) (aws.Credentials, erro
}

func (cc *CredentialsCache) Run(ctx context.Context) {
// Wait for initialized signal before running loop.
// Wait for a generateOIDCTokenFn before running loop.
select {
case <-cc.initialized:
case <-cc.gotGenerateOIDCTokenFn:
case <-ctx.Done():
cc.log.DebugContext(ctx, "Context canceled before initialized.")
return
Expand Down Expand Up @@ -241,7 +282,7 @@ func (cc *CredentialsCache) refresh(ctx context.Context) (aws.Credentials, error
defer cc.log.InfoContext(ctx, "Exiting AWS credentials refresh")

cc.log.InfoContext(ctx, "Generating Token")
oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration)
oidcToken, err := cc.getGenerateOIDCTokenFn()(ctx, cc.integration)
if err != nil {
cc.log.ErrorContext(ctx, "Token generation failed", errorValue(err))
return aws.Credentials{}, trace.Wrap(err)
Expand Down
60 changes: 51 additions & 9 deletions lib/integrations/awsoidc/credprovider/credentialscache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -34,13 +33,13 @@ import (

"github.com/gravitational/teleport/entitlements"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/utils"
)

type fakeSTSClient struct {
clock clockwork.Clock
err error
sync.Mutex
called int32
}

func (f *fakeSTSClient) setError(err error) {
Expand All @@ -56,7 +55,6 @@ func (f *fakeSTSClient) getError() error {
}

func (f *fakeSTSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
atomic.AddInt32(&f.called, 1)
if err := f.getError(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -95,6 +93,9 @@ func TestCredentialsCache(t *testing.T) {
STSClient: stsClient,
Integration: "test",
Clock: clock,
GenerateOIDCTokenFn: func(ctx context.Context, integration string) (string, error) {
return uuid.NewString(), nil
},
})
require.NoError(t, err)
require.NotNil(t, cacheUnderTest)
Expand All @@ -107,12 +108,6 @@ func TestCredentialsCache(t *testing.T) {
clock.Advance(d)
}

// Set the GenerateOIDCTokenFn to a dumb faked function.
cacheUnderTest.SetGenerateOIDCTokenFn(
func(ctx context.Context, integration string) (string, error) {
return uuid.NewString(), nil
})

checkRetrieveCredentials := func(t require.TestingT, expectErr error) {
_, err := cacheUnderTest.Retrieve(ctx)
assert.ErrorIs(t, err, expectErr)
Expand Down Expand Up @@ -224,3 +219,50 @@ func TestCredentialsCache(t *testing.T) {
}
})
}

// TestCredentialsCacheRetrieveBeforeInit covers a cache created with
// AllowRetrieveBeforeInit and no GenerateOIDCTokenFn: Retrieve must fail with
// errNotReady instead of blocking, and must succeed once a token function has
// been set and the first credentials have been awaited.
func TestCredentialsCacheRetrieveBeforeInit(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	fakeClock := clockwork.NewFakeClock()
	credCache, err := NewCredentialsCache(CredentialsCacheOptions{
		STSClient:               &fakeSTSClient{clock: fakeClock},
		Integration:             "test",
		Clock:                   fakeClock,
		AllowRetrieveBeforeInit: true,
	})
	require.NoError(t, err)

	// Drive the cache's Run loop in the background; it is stopped by
	// canceling the test context.
	runLoop := &utils.TestBackgroundTask{
		Name: "cache.Run",
		Task: func(taskCtx context.Context) error {
			credCache.Run(taskCtx)
			return nil
		},
		Terminate: func() error {
			cancel()
			return nil
		},
	}
	utils.RunTestBackgroundTask(ctx, t, runLoop)

	// No token function has been set yet, so Retrieve is expected to return
	// errNotReady right away rather than wait.
	_, err = credCache.Retrieve(ctx)
	require.ErrorIs(t, err, errNotReady)

	// Supply the token function only now, after the cache already exists.
	tokenFn := func(ctx context.Context, integration string) (string, error) {
		return uuid.NewString(), nil
	}
	credCache.SetGenerateOIDCTokenFn(tokenFn)

	// Callers normally follow SetGenerateOIDCTokenFn with
	// WaitForFirstCredsOrErr so that credentials exist before they are used.
	credCache.WaitForFirstCredsOrErr(ctx)

	// With credentials generated, Retrieve should succeed.
	creds, err := credCache.Retrieve(ctx)
	require.NoError(t, err)
	require.NotEmpty(t, creds.SecretAccessKey)
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,17 @@ import (
)

// Options represents additional options for configuring the AWS credentials provider.
type Options struct {
// WaitForFirstInit indicates whether to wait for the initial credential
// generation before returning from CreateAWSConfigForIntegration.
WaitForFirstInit bool
}
// There are currently no options but this type is still referenced from
// teleport.e.
type Options struct{}

// Option is a function that modifies the Options struct for the AWS configuration.
type Option func(*Options)

// WithWaitForFirstInit configures the provider to wait until the first set of
// credentials is generated before proceeding. This is useful in cases where
// immediate credential availability is necessary.
func WithWaitForFirstInit(wait bool) Option {
return func(o *Options) {
o.WaitForFirstInit = wait
}
}

// CreateAWSConfigForIntegration returns a new AWS credentials provider that
// uses the AWS OIDC integration to generate temporary credentials.
// The provider will periodically refresh the credentials before they expire.
func CreateAWSConfigForIntegration(ctx context.Context, config Config, option ...Option) (*aws.Config, error) {
options := Options{}
for _, opt := range option {
opt(&options)
}
if err := config.checkAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -76,10 +61,6 @@ func CreateAWSConfigForIntegration(ctx context.Context, config Config, option ..
}
go credCache.Run(ctx)

if options.WaitForFirstInit {
credCache.WaitForFirstCredsOrErr(ctx)
}

awsCfg, err := newAWSConfig(ctx, config.Region, awsConfig.WithCredentialsProvider(credCache))
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -152,17 +133,17 @@ func newAWSCredCache(ctx context.Context, cfg Config, stsClient stscreds.AssumeR

credCache, err := NewCredentialsCache(
CredentialsCacheOptions{
Log: cfg.Logger,
Clock: cfg.Clock,
STSClient: stsClient,
RoleARN: roleARN,
Integration: cfg.IntegrationName,
Log: cfg.Logger,
Clock: cfg.Clock,
STSClient: stsClient,
RoleARN: roleARN,
Integration: cfg.IntegrationName,
GenerateOIDCTokenFn: cfg.AWSOIDCTokenGenerator.GenerateAWSOIDCToken,
},
)
if err != nil {
return nil, trace.Wrap(err, "creating OIDC credentials cache")
}
credCache.SetGenerateOIDCTokenFn(cfg.AWSOIDCTokenGenerator.GenerateAWSOIDCToken)
return credCache, nil
}

Expand Down
Loading

0 comments on commit c5b4101

Please sign in to comment.