Skip to content

Commit

Permalink
Update awsconfig (#50561)
Browse files Browse the repository at this point in the history
* Add a Cache for caching credentials, similar to SDK v1 session cache.
* Add a Provider interface that provides aws.Config
* Simplified role chaining options

Unlike our SDK v1 session cache, the SDK v2 implementation in this PR
does not include region as a cache key.
There are regional AWS STS endpoints for lower latency calls, but the
lowest latency path is to just grab credentials from the cache if we
already have them - the region they were originally taken from doesn't
matter.
  • Loading branch information
GavinFrazar authored Jan 6, 2025
1 parent dbf8fcd commit 40e597e
Show file tree
Hide file tree
Showing 4 changed files with 415 additions and 74 deletions.
174 changes: 106 additions & 68 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,23 @@ const (
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)

// AssumeRoleClientProviderFunc provides an AWS STS assume role API client.
type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient

// AssumeRole is an AWS role to assume, optionally with an external ID.
type AssumeRole struct {
// RoleARN is the ARN of the role to assume.
RoleARN string `json:"role_arn"`
// ExternalID is an optional ID to include when assuming the role.
ExternalID string `json:"external_id"`
}

// options is a struct of additional options for assuming an AWS role
// when construction an underlying AWS config.
type options struct {
// baseConfigis a config to use instead of the default config for an
// AWS region, which is used to enable role chaining.
baseConfig *aws.Config
// assumeRoleARN is the AWS IAM Role ARN to assume.
assumeRoleARN string
// assumeRoleExternalID is used to assume an external AWS IAM Role.
assumeRoleExternalID string
// assumeRoles are AWS IAM roles that should be assumed one by one in order,
// as a chain of assumed roles.
assumeRoles []AssumeRole
// credentialsSource describes which source to use to fetch credentials.
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
Expand All @@ -67,22 +74,45 @@ type options struct {
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
// assumeRoleClientProvider sets the STS assume role client provider func.
assumeRoleClientProvider AssumeRoleClientProviderFunc
}

func (a *options) checkAndSetDefaults() error {
switch a.credentialsSource {
func buildOptions(optFns ...OptionsFn) (*options, error) {
var opts options
for _, optFn := range optFns {
optFn(&opts)
}
if err := opts.checkAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return &opts, nil
}

func (o *options) checkAndSetDefaults() error {
switch o.credentialsSource {
case credentialsSourceAmbient:
if a.integration != "" {
if o.integration != "" {
return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
}
case credentialsSourceIntegration:
if a.integration == "" {
if o.integration == "" {
return trace.BadParameter("missing integration name")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}
if len(o.assumeRoles) > 2 {
return trace.BadParameter("role chain contains more than 2 roles")
}

if o.assumeRoleClientProvider == nil {
o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
return sts.NewFromConfig(cfg, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
}
}
return nil
}

Expand All @@ -93,8 +123,14 @@ type OptionsFn func(*options)
// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) OptionsFn {
return func(options *options) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
if roleARN == "" {
// ignore empty role ARN for caller convenience.
return
}
options.assumeRoles = append(options.assumeRoles, AssumeRole{
RoleARN: roleARN,
ExternalID: externalID,
})
}
}

Expand Down Expand Up @@ -146,96 +182,98 @@ func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) O
}
}

// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to
// assume roles.
func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn {
return func(options *options) {
options.assumeRoleClientProvider = fn
}
}

// GetConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) {
var options options
for _, opt := range opts {
opt(&options)
}
if options.baseConfig == nil {
cfg, err := getConfigForRegion(ctx, region, options)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
options.baseConfig = &cfg
func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
opts, err := buildOptions(optFns...)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
if options.assumeRoleARN == "" {
return *options.baseConfig, nil

cfg, err := getBaseConfig(ctx, region, opts)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
return getConfigForRole(ctx, region, options)
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider)
}

// ambientConfigProvider loads a new config using the environment variables.
func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) {
opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
// loadDefaultConfig loads a new config.
func loadDefaultConfig(ctx context.Context, region string, cred aws.CredentialsProvider, opts *options) (aws.Config, error) {
configOpts := buildConfigOptions(region, cred, opts)
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
return cfg, trace.Wrap(err)
}

func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error {
opts := []func(*config.LoadOptions) error{
func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *options) []func(*config.LoadOptions) error {
configOpts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
}
if modules.GetModules().IsBoringBinary() {
opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
}
if options.customRetryer != nil {
opts = append(opts, config.WithRetryer(options.customRetryer))
if opts.customRetryer != nil {
configOpts = append(configOpts, config.WithRetryer(opts.customRetryer))
}
if options.maxRetries != nil {
opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
if opts.maxRetries != nil {
configOpts = append(configOpts, config.WithRetryMaxAttempts(*opts.maxRetries))
}
return opts
return configOpts
}

// getConfigForRegion returns AWS config for the specified region.
func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

// getBaseConfig returns an AWS config without assuming any roles.
func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
if options.credentialsSource == credentialsSourceIntegration {
if options.integrationCredentialsProvider == nil {
if opts.credentialsSource == credentialsSourceIntegration {
if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
cred, err = options.integrationCredentialsProvider(ctx, region, options.integration)
cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
}

cfg, err := ambientConfigProvider(region, cred, options)
cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
}

// getConfigForRole returns an AWS config for the specified region and role.
func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) {
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}

stsClient := sts.NewFromConfig(*options.baseConfig, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
if options.assumeRoleExternalID != "" {
aro.ExternalID = aws.String(options.assumeRoleExternalID)
if len(roles) > 0 {
// no point caching every assumed role in the chain, we can just cache
// the last one.
cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, awsCredentialsCacheOptions)
if _, err := cfg.Credentials.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}
})
if _, err := cred.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}
return cfg, nil
}

opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(ctx, opts...)
return cfg, trace.Wrap(err)
func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient, role AssumeRole) aws.CredentialsProvider {
slog.DebugContext(ctx, "Initializing AWS session for assumed role",
"assumed_role", role.RoleARN,
)
return stscreds.NewAssumeRoleProvider(clt, role.RoleARN, func(aro *stscreds.AssumeRoleOptions) {
if role.ExternalID != "" {
aro.ExternalID = aws.String(role.ExternalID)
}
})
}
Loading

0 comments on commit 40e597e

Please sign in to comment.