From e75a27ca043369bb1ab8d485cf8eb1a33e13afe9 Mon Sep 17 00:00:00 2001 From: Dave New Date: Mon, 13 Nov 2023 16:58:04 +0200 Subject: [PATCH] feat: keelconfig auth configuration (#1296) --- cmd/program/model.go | 1 + config/auth.go | 205 +++++++++++++ config/config.go | 103 ++++++- config/config_test.go | 103 +++++++ config/fixtures/test_auth.yaml | 28 ++ .../fixtures/test_auth_duplicate_names.yaml | 17 ++ .../fixtures/test_auth_invalid_auth_url.yaml | 22 ++ config/fixtures/test_auth_invalid_issuer.yaml | 30 ++ .../fixtures/test_auth_invalid_token_url.yaml | 22 ++ config/fixtures/test_auth_invalid_types.yaml | 20 ++ .../test_auth_missing_client_ids.yaml | 13 + config/fixtures/test_auth_missing_names.yaml | 11 + .../test_auth_negative_token_lifespan.yaml | 4 + runtime/apis/authapi/token_endpoint_test.go | 71 +++-- runtime/oauth/access_token.go | 22 +- runtime/oauth/access_token_test.go | 278 ++++++++++++++++++ runtime/oauth/id_token.go | 14 +- runtime/oauth/id_token_test.go | 75 +++-- runtime/oauth/refresh_token.go | 20 +- runtime/oauth/refresh_token_test.go | 51 ++-- runtime/runtimectx/oauth.go | 31 ++ 21 files changed, 1035 insertions(+), 106 deletions(-) create mode 100644 config/auth.go create mode 100644 config/fixtures/test_auth.yaml create mode 100644 config/fixtures/test_auth_duplicate_names.yaml create mode 100644 config/fixtures/test_auth_invalid_auth_url.yaml create mode 100644 config/fixtures/test_auth_invalid_issuer.yaml create mode 100644 config/fixtures/test_auth_invalid_token_url.yaml create mode 100644 config/fixtures/test_auth_invalid_types.yaml create mode 100644 config/fixtures/test_auth_missing_client_ids.yaml create mode 100644 config/fixtures/test_auth_missing_names.yaml create mode 100644 config/fixtures/test_auth_negative_token_lifespan.yaml create mode 100644 runtime/oauth/access_token_test.go create mode 100644 runtime/runtimectx/oauth.go diff --git a/cmd/program/model.go b/cmd/program/model.go index 53fe64bee..f491dd567 100644 --- a/cmd/program/model.go +++ b/cmd/program/model.go @@ -400,6 +400,7 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { ctx = db.WithDatabase(ctx, m.Database) ctx = runtimectx.WithSecrets(ctx, m.Secrets) + ctx = runtimectx.WithOAuthConfig(ctx, &m.Config.Auth) mailClient := mail.NewSMTPClientFromEnv() if mailClient != nil { diff --git a/config/auth.go b/config/auth.go new file mode 100644 index 000000000..368cf72d5 --- /dev/null +++ b/config/auth.go @@ -0,0 +1,205 @@ +package config + +import ( + "fmt" + "net/url" + + "github.com/samber/lo" +) + +const ( + GoogleProvider = "google" + OpenIdConnectProvider = "oidc" + OAuthProvider = "oauth" +) + +var ( + SupportedProviderTypes = []string{ + GoogleProvider, + OpenIdConnectProvider, + OAuthProvider, + } +) + +type AuthConfig struct { + Tokens *TokensConfig `yaml:"tokens"` + Providers []Provider `yaml:"providers"` +} + +type TokensConfig struct { + AccessTokenExpiry int `yaml:"accessTokenExpiry"` + RefreshTokenExpiry int `yaml:"refreshTokenExpiry"` +} + +type Provider struct { + Type string `yaml:"type"` + Name string `yaml:"name"` + ClientId string `yaml:"clientId"` + IssuerUrl string `yaml:"issuerUrl"` + TokenUrl string `yaml:"tokenUrl"` + AuthorizationUrl string `yaml:"authorizationUrl"` +} + +func (c *AuthConfig) GetOidcProviders() []Provider { + oidcProviders := []Provider{} + for _, p := range c.Providers { + if p.Type == OpenIdConnectProvider { + oidcProviders = append(oidcProviders, p) + } + } + return oidcProviders +} + +func (c *AuthConfig) HasOidcIssuer(issuer string) (bool, error) { + for _, p := range c.Providers { + if p.Type == OAuthProvider { + continue + } + + issuerUrl, err := p.GetIssuer() + if err != nil { + return false, err + } + if issuerUrl == issuer { + return true, nil + } + } + return false, nil +} + +func (c *Provider) GetIssuer() (string, error) { + switch c.Type { + case GoogleProvider: + return "https://accounts.google.com/", nil + case OpenIdConnectProvider: + return c.IssuerUrl, nil + default: + return "", fmt.Errorf("the provider type '%s' should not have an issuer url configured", c.Type) + } +} + +func (c *AuthConfig) GetOAuthProviders() []Provider { + oidcProviders := []Provider{} + for _, p := range c.Providers { + if p.Type == OAuthProvider { + oidcProviders = append(oidcProviders, p) + } + } + return oidcProviders +} + +func (c *Provider) GetTokenUrl() (string, error) { + switch c.Type { + case GoogleProvider: + return "https://accounts.google.com/o/oauth2/token", nil + case OAuthProvider: + return c.TokenUrl, nil + default: + return "", fmt.Errorf("the provider type '%s' should not have a token url configured", c.Type) + } +} + +func (c *Provider) GetAuthorizationUrl() (string, error) { + switch c.Type { + case GoogleProvider: + return "https://accounts.google.com/o/oauth2/auth", nil + case OAuthProvider: + return c.AuthorizationUrl, nil + default: + return "", fmt.Errorf("the provider type '%s' should not have a token url configured", c.Type) + } +} + +// findAuthProviderMissingName checks for missing provider names +func findAuthProviderMissingName(providers []Provider) []Provider { + invalid := []Provider{} + for _, p := range providers { + if p.Name == "" { + invalid = append(invalid, p) + } + } + + return invalid +} + +// findAuthProviderDuplicateName checks for duplicate auth provider names +func findAuthProviderDuplicateName(providers []Provider) []Provider { + keys := make(map[string]bool) + + duplicates := []Provider{} + for _, p := range providers { + if _, value := keys[p.Name]; !value { + keys[p.Name] = true + } else { + duplicates = append(duplicates, p) + } + } + + return duplicates +} + +// findAuthProviderInvalidType checks for invalid provider types +func findAuthProviderInvalidType(providers []Provider) []Provider { + invalid := []Provider{} + for _, p := range providers { + if !lo.Contains(SupportedProviderTypes, p.Type) { + invalid = append(invalid, p) + } + } + + return invalid +} + +// findAuthProviderMissingClientId checks for missing client IDs +func findAuthProviderMissingClientId(providers []Provider) []Provider { + invalid := []Provider{} + for _, p := range providers { + if p.ClientId == "" { + invalid = append(invalid, p) + } + } + + return invalid +} + +// findAuthProviderMissingIssuerUrl checks for missing or invalid issuer URLs +func findAuthProviderMissingOrInvalidIssuerUrl(providers []Provider) []Provider { + invalid := []Provider{} + for _, p := range providers { + u, err := url.Parse(p.IssuerUrl) + if err != nil || u.Scheme != "https" { + invalid = append(invalid, p) + continue + } + } + + return invalid +} + +// findAuthProviderMissingOrInvalidTokenUrl checks for missing or invalid token URLs +func findAuthProviderMissingOrInvalidTokenUrl(providers []Provider) []Provider { + invalid := []Provider{} + for _, p := range providers { + u, err := url.Parse(p.TokenUrl) + if err != nil || u.Scheme != "https" { + invalid = append(invalid, p) + continue + } + } + + return invalid +} + +// findAuthProviderMissingOrInvalidAuthorizationUrl checks for missing or invalid authorization URLs +func findAuthProviderMissingOrInvalidAuthorizationUrl(providers []Provider) []Provider { + invalid := []Provider{} + for _, p := range providers { + u, err := url.Parse(p.AuthorizationUrl) + if err != nil || u.Scheme != "https" { + invalid = append(invalid, p) + continue + } + } + + return invalid +} diff --git a/config/config.go b/config/config.go index 6b106ebd9..537e2f64f 100644 --- a/config/config.go +++ b/config/config.go @@ -19,6 +19,7 @@ type Config struct{} type ProjectConfig struct { Environment EnvironmentConfig `yaml:"environment"` Secrets []Input `yaml:"secrets"` + Auth AuthConfig `yaml:"auth"` DisableAuth bool `yaml:"disableKeelAuth"` } @@ -117,10 +118,16 @@ type ConfigError struct { } const ( - ConfigDuplicateErrorString = "environment variable %s has a duplicate set in environment: %s" - ConfigRequiredErrorString = "environment variable %s is required but not defined in the following environments: %s" - ConfigIncorrectNamingErrorString = "%s must be written in upper snakecase" - ConfigReservedNameErrorString = "environment variable %s cannot start with %s as it is reserved" + ConfigDuplicateErrorString = "environment variable %s has a duplicate set in environment: %s" + ConfigRequiredErrorString = "environment variable %s is required but not defined in the following environments: %s" + ConfigIncorrectNamingErrorString = "%s must be written in upper snakecase" + ConfigReservedNameErrorString = "environment variable %s cannot start with %s as it is reserved" + ConfigAuthTokenExpiryMustBePositive = "%s token lifespan cannot be negative or zero for field: %s" + ConfigAuthProviderMissingFieldAtIndexErrorString = "auth provider at index %v is missing field: %s" + ConfigAuthProviderMissingFieldErrorString = "auth provider '%s' is missing field: %s" + ConfigAuthProviderInvalidTypeErrorString = "auth provider '%s' has invalid type '%s' which must be one of: %s" + ConfigAuthProviderDuplicateErrorString = "auth provider name '%s' has been defined more than once, but must be unique" + ConfigAuthProviderInvalidHttpUrlErrorString = "auth provider '%s' has missing or invalid https url for field: %s" ) type ConfigErrors struct { @@ -232,6 +239,94 @@ func Validate(config *ProjectConfig) *ConfigErrors { } } + if config.Auth.Tokens != nil && config.Auth.Tokens.AccessTokenExpiry <= 0 { + errors = append(errors, &ConfigError{ + Type: "invalid", + Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "access", "accessTokenExpiry"), + }) + } + + if config.Auth.Tokens != nil && config.Auth.Tokens.RefreshTokenExpiry <= 0 { + errors = append(errors, &ConfigError{ + Type: "invalid", + Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "refresh", "refreshTokenExpiry"), + }) + } + + missingProviderNames := findAuthProviderMissingName(config.Auth.Providers) + for i := range missingProviderNames { + errors = append(errors, &ConfigError{ + Type: "missing", + Message: fmt.Sprintf(ConfigAuthProviderMissingFieldAtIndexErrorString, i, "name"), + }) + } + + invalidProviderTypes := findAuthProviderInvalidType(config.Auth.Providers) + for _, p := range invalidProviderTypes { + if p.Name == "" { + continue + } + errors = append(errors, &ConfigError{ + Type: "missing", + Message: fmt.Sprintf(ConfigAuthProviderInvalidTypeErrorString, p.Name, p.Type, strings.Join(SupportedProviderTypes, ", ")), + }) + } + + duplicateProviders := findAuthProviderDuplicateName(config.Auth.Providers) + for _, p := range duplicateProviders { + if p.Name == "" { + continue + } + errors = append(errors, &ConfigError{ + Type: "duplicate", + Message: fmt.Sprintf(ConfigAuthProviderDuplicateErrorString, p.Name), + }) + } + + missingClientIds := findAuthProviderMissingClientId(config.Auth.Providers) + for _, p := range missingClientIds { + if p.Name == "" { + continue + } + errors = append(errors, &ConfigError{ + Type: "missing", + Message: fmt.Sprintf(ConfigAuthProviderMissingFieldErrorString, p.Name, "clientId"), + }) + } + + missingOrInvalidIssuerUrls := findAuthProviderMissingOrInvalidIssuerUrl(config.Auth.GetOidcProviders()) + for _, p := range missingOrInvalidIssuerUrls { + if p.Name == "" { + continue + } + errors = append(errors, &ConfigError{ + Type: "invalid", + Message: fmt.Sprintf(ConfigAuthProviderInvalidHttpUrlErrorString, p.Name, "issuerUrl"), + }) + } + + missingOrInvalidTokenUrls := findAuthProviderMissingOrInvalidTokenUrl(config.Auth.GetOAuthProviders()) + for _, p := range missingOrInvalidTokenUrls { + if p.Name == "" { + continue + } + errors = append(errors, &ConfigError{ + Type: "invalid", + Message: fmt.Sprintf(ConfigAuthProviderInvalidHttpUrlErrorString, p.Name, "tokenUrl"), + }) + } + + missingOrInvalidAuthUrls := findAuthProviderMissingOrInvalidAuthorizationUrl(config.Auth.GetOAuthProviders()) + for _, p := range missingOrInvalidAuthUrls { + if p.Name == "" { + continue + } + errors = append(errors, &ConfigError{ + Type: "invalid", + Message: fmt.Sprintf(ConfigAuthProviderInvalidHttpUrlErrorString, p.Name, "authorizationUrl"), + }) + } + if len(errors) == 0 { return nil } diff --git a/config/config_test.go b/config/config_test.go index 467fc9407..5055e9ef9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -114,3 +114,106 @@ func TestReservedNameValidateFormat(t *testing.T) { assert.Contains(t, err.Error(), "environment variable OPENCOLLECTOR_CONFIG_NOT_ALLOWED4 cannot start with OPENCOLLECTOR_CONFIG as it is reserved\n") assert.Contains(t, err.Error(), "environment variable _NOT_ALLOWED5 cannot start with _ as it is reserved\n") } + +func TestAuthTokens(t *testing.T) { + config, err := Load("fixtures/test_auth.yaml") + assert.NoError(t, err) + + assert.Equal(t, 3600, config.Auth.Tokens.AccessTokenExpiry) + assert.Equal(t, 604800, config.Auth.Tokens.RefreshTokenExpiry) +} + +func TestAuthNegativeTokenLifespan(t *testing.T) { + _, err := Load("fixtures/test_auth_negative_token_lifespan.yaml") + + assert.Contains(t, err.Error(), "access token lifespan cannot be negative or zero for field: accessTokenExpiry\n") + assert.Contains(t, err.Error(), "refresh token lifespan cannot be negative or zero for field: refreshTokenExpiry\n") +} + +func TestAuthProviders(t *testing.T) { + config, err := Load("fixtures/test_auth.yaml") + assert.NoError(t, err) + + assert.Equal(t, "google", config.Auth.Providers[0].Type) + assert.Equal(t, "google-1", config.Auth.Providers[0].Name) + assert.Equal(t, "foo_1", config.Auth.Providers[0].ClientId) + + assert.Equal(t, "google", config.Auth.Providers[1].Type) + assert.Equal(t, "google_2", config.Auth.Providers[1].Name) + assert.Equal(t, "foo_2", config.Auth.Providers[1].ClientId) + + assert.Equal(t, "oidc", config.Auth.Providers[2].Type) + assert.Equal(t, "Baidu", config.Auth.Providers[2].Name) + assert.Equal(t, "https://dev-skhlutl45lbqkvhv.us.auth0.com", config.Auth.Providers[2].IssuerUrl) + assert.Equal(t, "kasj28fnq09ak", config.Auth.Providers[2].ClientId) + + assert.Equal(t, "oauth", config.Auth.Providers[3].Type) + assert.Equal(t, "Github", config.Auth.Providers[3].Name) + assert.Equal(t, "hfjuw983h1hfsdf", config.Auth.Providers[3].ClientId) + assert.Equal(t, "https://github.com/auth", config.Auth.Providers[3].AuthorizationUrl) + assert.Equal(t, "https://github.com/token", config.Auth.Providers[3].TokenUrl) +} + +func TestMissingProviderName(t *testing.T) { + _, err := Load("fixtures/test_auth_missing_names.yaml") + + assert.Contains(t, err.Error(), "auth provider at index 0 is missing field: name\n") + assert.Contains(t, err.Error(), "auth provider at index 1 is missing field: name\n") + assert.Contains(t, err.Error(), "auth provider at index 2 is missing field: name\n") +} + +func TestDuplicateProviderName(t *testing.T) { + _, err := Load("fixtures/test_auth_duplicate_names.yaml") + + assert.Equal(t, "auth provider name 'my_google' has been defined more than once, but must be unique\n", err.Error()) +} + +func TestInvalidProviderTypes(t *testing.T) { + _, err := Load("fixtures/test_auth_invalid_types.yaml") + + assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, oidc, oauth\n") + assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, oidc, oauth\n") + assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, oidc, oauth\n") +} + +func TestMissingClientId(t *testing.T) { + _, err := Load("fixtures/test_auth_missing_client_ids.yaml") + + assert.Contains(t, err.Error(), "auth provider 'google_1' is missing field: clientId\n") + assert.Contains(t, err.Error(), "auth provider 'Baidu' is missing field: clientId\n") + assert.Contains(t, err.Error(), "auth provider 'Github' is missing field: clientId\n") +} + +func TestMissingOrInvalidIssuerUrl(t *testing.T) { + _, err := Load("fixtures/test_auth_invalid_issuer.yaml") + + assert.Contains(t, err.Error(), "auth provider 'not-https' has missing or invalid https url for field: issuerUrl\n") + assert.Contains(t, err.Error(), "auth provider 'missing-issuer' has missing or invalid https url for field: issuerUrl\n") + assert.Contains(t, err.Error(), "auth provider 'no-schema' has missing or invalid https url for field: issuerUrl\n") + assert.Contains(t, err.Error(), "auth provider 'invalid-url' has missing or invalid https url for field: issuerUrl\n") +} + +func TestMissingOrInvalidTokenEndpoint(t *testing.T) { + _, err := Load("fixtures/test_auth_invalid_token_url.yaml") + + assert.Contains(t, err.Error(), "auth provider 'not-https' has missing or invalid https url for field: tokenUrl\n") + assert.Contains(t, err.Error(), "auth provider 'missing-schema' has missing or invalid https url for field: tokenUrl\n") + assert.Contains(t, err.Error(), "auth provider 'missing-endpoint' has missing or invalid https url for field: tokenUrl\n") +} + +func TestHasIssuer(t *testing.T) { + config, err := Load("fixtures/test_auth.yaml") + assert.NoError(t, err) + + hasGoogleIssuer, err := config.Auth.HasOidcIssuer("https://accounts.google.com/") + assert.NoError(t, err) + assert.True(t, hasGoogleIssuer) + + hasCustomIssuer, err := config.Auth.HasOidcIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com") + assert.NoError(t, err) + assert.True(t, hasCustomIssuer) + + hasUnknownIssuer, err := config.Auth.HasOidcIssuer("https://nope.com") + assert.NoError(t, err) + assert.False(t, hasUnknownIssuer) +} diff --git a/config/fixtures/test_auth.yaml b/config/fixtures/test_auth.yaml new file mode 100644 index 000000000..7ec4ef8df --- /dev/null +++ b/config/fixtures/test_auth.yaml @@ -0,0 +1,28 @@ +auth: + tokens: + accessTokenExpiry: 3600 + refreshTokenExpiry: 604800 + + providers: + # Built-in Google provider + - type: google + name: google-1 + clientId: foo_1 + + # Built-in Google provider + - type: google + name: google_2 + clientId: foo_2 + + # Custom OIDC + - type: oidc + name: Baidu + issuerUrl: 'https://dev-skhlutl45lbqkvhv.us.auth0.com' + clientId: 'kasj28fnq09ak' + + # Custom OAuth + - type: oauth + name: Github + clientId: hfjuw983h1hfsdf + authorizationUrl: https://github.com/auth + tokenUrl: https://github.com/token \ No newline at end of file diff --git a/config/fixtures/test_auth_duplicate_names.yaml b/config/fixtures/test_auth_duplicate_names.yaml new file mode 100644 index 000000000..e9568d500 --- /dev/null +++ b/config/fixtures/test_auth_duplicate_names.yaml @@ -0,0 +1,17 @@ +auth: + providers: + - type: google + name: my_google + clientId: foo_1 + + - type: oidc + name: Baidu + issuerUrl: 'https://dev-skhlutl45lbqkvhv.us.auth0.com' + clientId: 'kasj28fnq09ak' + + # Duplicate name + - type: oauth + name: my_google + clientId: hfjuw983h1hfsdf + authorizationUrl: https://github.com/auth + tokenUrl: https://github.com/token \ No newline at end of file diff --git a/config/fixtures/test_auth_invalid_auth_url.yaml b/config/fixtures/test_auth_invalid_auth_url.yaml new file mode 100644 index 000000000..8ed0e3535 --- /dev/null +++ b/config/fixtures/test_auth_invalid_auth_url.yaml @@ -0,0 +1,22 @@ +auth: + tokens: + accessTokenExpiry: 3600 + refreshTokenExpiry: 604800 + + providers: + - type: oauth + name: not-https + clientId: hfjuw983h1hfsdf + authorizationUrl: http://github.com/auth + tokenUrl: http://github.com/token + + - type: oauth + name: missing-schema + clientId: hfjuw983h1hfsdf + authorizationUrl: github.com/auth + tokenUrl: https://github.com/token + + - type: oauth + name: missing-endpoint + clientId: hfjuw983h1hfsdf + tokenUrl: https://github.com/token \ No newline at end of file diff --git a/config/fixtures/test_auth_invalid_issuer.yaml b/config/fixtures/test_auth_invalid_issuer.yaml new file mode 100644 index 000000000..9989b2bf6 --- /dev/null +++ b/config/fixtures/test_auth_invalid_issuer.yaml @@ -0,0 +1,30 @@ +auth: + tokens: + accessTokenExpiry: 3600 + refreshTokenExpiry: 604800 + + providers: + - type: oidc + name: not-https + issuerUrl: 'http://not-https.com' + clientId: 'kasj28fnq09ak' + + - type: oidc + name: missing-issuer + clientId: 'kasj28fnq09ak' + + - type: oidc + name: no-schema + issuerUrl: 'not-https.com' + clientId: 'kasj28fnq09ak' + + - type: oidc + name: invalid-url + issuerUrl: 'whoops' + clientId: 'kasj28fnq09ak' + + - type: oauth + name: myOAuthProvider + clientId: 'kasj28fnq09ak' + issuerUrl: 'not an error' + diff --git a/config/fixtures/test_auth_invalid_token_url.yaml b/config/fixtures/test_auth_invalid_token_url.yaml new file mode 100644 index 000000000..94dbd6601 --- /dev/null +++ b/config/fixtures/test_auth_invalid_token_url.yaml @@ -0,0 +1,22 @@ +auth: + tokens: + accessTokenExpiry: 3600 + refreshTokenExpiry: 604800 + + providers: + - type: oauth + name: not-https + clientId: hfjuw983h1hfsdf + authorizationUrl: https://github.com/auth + tokenUrl: http://github.com/token + + - type: oauth + name: missing-schema + clientId: hfjuw983h1hfsdf + authorizationUrl: https://github.com/auth + tokenUrl: github.com/token + + - type: oauth + name: missing-endpoint + clientId: hfjuw983h1hfsdf + authorizationUrl: https://github.com/auth \ No newline at end of file diff --git a/config/fixtures/test_auth_invalid_types.yaml b/config/fixtures/test_auth_invalid_types.yaml new file mode 100644 index 000000000..4d10d4eb7 --- /dev/null +++ b/config/fixtures/test_auth_invalid_types.yaml @@ -0,0 +1,20 @@ +auth: + providers: + - type: google_1 + name: google_1 + clientId: foo_1 + + - type: Google + name: google_2 + clientId: foo_2 + + - type: oauth + name: Github + clientId: hfjuw983h1hfsdf + authorizationUrl: https://github.com/auth + tokenUrl: https://github.com/token + + - type: whoops + name: Baidu + issuerUrl: 'https://dev-skhlutl45lbqkvhv.us.auth0.com' + clientId: 'kasj28fnq09ak' \ No newline at end of file diff --git a/config/fixtures/test_auth_missing_client_ids.yaml b/config/fixtures/test_auth_missing_client_ids.yaml new file mode 100644 index 000000000..5f0749644 --- /dev/null +++ b/config/fixtures/test_auth_missing_client_ids.yaml @@ -0,0 +1,13 @@ +auth: + providers: + - type: google + name: google_1 + + - type: oidc + name: Baidu + issuerUrl: 'https://dev-skhlutl45lbqkvhv.us.auth0.com' + + - type: oauth + name: Github + authorizationUrl: https://github.com/auth + tokenUrl: https://github.com/token \ No newline at end of file diff --git a/config/fixtures/test_auth_missing_names.yaml b/config/fixtures/test_auth_missing_names.yaml new file mode 100644 index 000000000..96b132d81 --- /dev/null +++ b/config/fixtures/test_auth_missing_names.yaml @@ -0,0 +1,11 @@ +auth: + tokens: + accessTokenExpiry: 3600 + refreshTokenExpiry: 604800 + + providers: + - type: google + + - type: oidc + + - type: oauth diff --git a/config/fixtures/test_auth_negative_token_lifespan.yaml b/config/fixtures/test_auth_negative_token_lifespan.yaml new file mode 100644 index 000000000..559c396d9 --- /dev/null +++ b/config/fixtures/test_auth_negative_token_lifespan.yaml @@ -0,0 +1,4 @@ +auth: + tokens: + accessTokenExpiry: -1 + refreshTokenExpiry: -1 \ No newline at end of file diff --git a/runtime/apis/authapi/token_endpoint_test.go b/runtime/apis/authapi/token_endpoint_test.go index 8449e8c5f..a696aa770 100644 --- a/runtime/apis/authapi/token_endpoint_test.go +++ b/runtime/apis/authapi/token_endpoint_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/config" "github.com/teamkeel/keel/proto" "github.com/teamkeel/keel/runtime" "github.com/teamkeel/keel/runtime/apis/authapi" @@ -28,15 +29,21 @@ func TestTokenExchange_ValidNewIdentity(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + server.SetUser("id|285620", &oauth.UserClaims{ Email: "keelson@keel.so", }) @@ -59,7 +66,7 @@ func TestTokenExchange_ValidNewIdentity(t *testing.T) { require.NotEmpty(t, validResponse.RefreshToken) require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) - sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken, "") + sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken) require.NoError(t, err) require.Equal(t, "https://keel.so", iss) @@ -88,15 +95,21 @@ func TestTokenExchange_ValidNewIdentityAllUserInfo(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + server.SetUser("id|285620", &oauth.UserClaims{ Email: "keelson@keel.so", EmailVerified: true, @@ -132,7 +145,7 @@ func TestTokenExchange_ValidNewIdentityAllUserInfo(t *testing.T) { require.NotEmpty(t, validResponse.ExpiresIn) require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) - sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken, "") + sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken) require.NoError(t, err) require.Equal(t, "https://keel.so", iss) @@ -163,15 +176,21 @@ func TestTokenExchange_ValidUpdatedIdentity(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + var inserted []map[string]any database.GetDB().Raw(fmt.Sprintf("INSERT INTO identity (external_id, issuer, email) VALUES ('id|285620','%s','weaveton@keel.so') RETURNING *", server.Issuer)).Scan(&inserted) require.Len(t, inserted, 1) @@ -196,7 +215,7 @@ func TestTokenExchange_ValidUpdatedIdentity(t *testing.T) { require.NotEmpty(t, validResponse.ExpiresIn) require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) - sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken, "") + sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken) require.NoError(t, err) require.Equal(t, "https://keel.so", iss) @@ -426,15 +445,21 @@ func TestRefreshToken_Valid(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + server.SetUser("id|285620", &oauth.UserClaims{ Email: "keelson@keel.so", }) diff --git a/runtime/oauth/access_token.go b/runtime/oauth/access_token.go index b712e3ded..889ee9e5f 100644 --- a/runtime/oauth/access_token.go +++ b/runtime/oauth/access_token.go @@ -7,7 +7,6 @@ import ( "time" "github.com/golang-jwt/jwt/v4" - "github.com/samber/lo" "github.com/segmentio/ksuid" "github.com/teamkeel/keel/runtime/common" "github.com/teamkeel/keel/runtime/runtimectx" @@ -31,11 +30,14 @@ type AccessTokenClaims struct { func GenerateAccessToken(ctx context.Context, identityId string) (string, time.Duration, error) { expiry := DefaultAccessTokenExpiry - config, err := runtimectx.GetAuthConfig(ctx) - if err == nil { - if config != nil && config.Keel != nil { - expiry = time.Duration(config.Keel.TokenDuration) * time.Second - } + + config, err := runtimectx.GetOAuthConfig(ctx) + if err != nil { + return "", 0, err + } + + if config != nil && config.Tokens != nil && config.Tokens.AccessTokenExpiry != 0 { + expiry = time.Duration(config.Tokens.AccessTokenExpiry) * time.Second } token, err := generateToken(ctx, identityId, []string{}, expiry) @@ -75,7 +77,7 @@ func generateToken(ctx context.Context, sub string, aud []string, expiresIn time return tokenString, nil } -func ValidateAccessToken(ctx context.Context, tokenString string, audienceClaim string) (string, string, error) { +func ValidateAccessToken(ctx context.Context, tokenString string) (string, string, error) { ctx, span := tracer.Start(ctx, "Validate access token") defer span.End() @@ -108,12 +110,6 @@ func ValidateAccessToken(ctx context.Context, tokenString string, audienceClaim return "", "", ErrTokenExpired } - if audienceClaim != "" { - if !lo.Contains(claims.Audience, audienceClaim) { - return "", "", ErrInvalidToken - } - } - if err != nil || !token.Valid { return "", "", ErrInvalidToken } diff --git a/runtime/oauth/access_token_test.go b/runtime/oauth/access_token_test.go new file mode 100644 index 000000000..ac2351bbe --- /dev/null +++ b/runtime/oauth/access_token_test.go @@ -0,0 +1,278 @@ +package oauth_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/segmentio/ksuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/config" + "github.com/teamkeel/keel/runtime/oauth" + "github.com/teamkeel/keel/runtime/runtimectx" + "github.com/teamkeel/keel/testhelpers" +) + +func newContextWithPK() context.Context { + ctx := context.Background() + + pk, _ := testhelpers.GetEmbeddedPrivateKey() + ctx = runtimectx.WithPrivateKey(ctx, pk) + + return ctx +} + +func TestAccessTokenGenerationAndParsingWithoutPrivateKey(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) + require.NoError(t, err) + require.Equal(t, identityId.String(), parsedId) + require.Equal(t, oauth.KeelIssuer, iss) +} + +func TestAccessTokenGenerationAndParsingWithSamePrivateKey(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) + require.NoError(t, err) + require.Equal(t, identityId.String(), parsedId) + require.Equal(t, oauth.KeelIssuer, iss) +} + +func TestAccessTokenGenerationWithPrivateKeyAndParsingWithoutPrivateKey(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + ctx = newContextWithPK() + parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) + require.ErrorIs(t, oauth.ErrInvalidToken, err) + require.Empty(t, parsedId) + require.Empty(t, iss) +} + +func TestAccessTokenGenerationWithoutPrivateKeyAndParsingWithPrivateKey(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) + require.ErrorIs(t, oauth.ErrInvalidToken, err) + require.Empty(t, parsedId) + require.Empty(t, iss) +} + +func TestAccessTokenGenerationAndParsingWithDifferentPrivateKeys(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey1, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey1) + require.NoError(t, err) + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + privateKey2, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey2) + require.NoError(t, err) + + parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) + require.ErrorIs(t, oauth.ErrInvalidToken, err) + require.Empty(t, parsedId) + require.Empty(t, iss) +} + +func TestAccessTokenIsRSAMethodWithPrivateKey(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + jwtToken, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, jwtToken) + + _, err = jwt.ParseWithClaims(jwtToken, &oauth.AccessTokenClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + assert.Fail(t, "Invalid signing method. Expected RSA.") + } + return &privateKey.PublicKey, nil + }) + require.NoError(t, err) +} + +func TestAccessTokenHasDefaultExpiryClaims(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + jwtToken, lifespan, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, jwtToken) + + claims := &oauth.AccessTokenClaims{} + _, err = jwt.ParseWithClaims(jwtToken, claims, func(token *jwt.Token) (interface{}, error) { + return &privateKey.PublicKey, nil + }) + require.NoError(t, err) + + issuedAt := claims.IssuedAt.Time + expiry := claims.ExpiresAt.Time + + require.Greater(t, expiry, time.Now()) + require.Equal(t, issuedAt.Add(oauth.DefaultAccessTokenExpiry), expiry) + require.Equal(t, oauth.DefaultAccessTokenExpiry, lifespan) +} + +func TestAccessTokenHasCustomClaims(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Tokens: &config.TokensConfig{ + AccessTokenExpiry: 360, + }, + }) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + jwtToken, lifespan, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, jwtToken) + + claims := &oauth.IdTokenClaims{} + _, err = jwt.ParseWithClaims(jwtToken, claims, func(token *jwt.Token) (interface{}, error) { + return &privateKey.PublicKey, nil + }) + require.NoError(t, err) + + issuedAt := claims.IssuedAt.Time + expiry := claims.ExpiresAt.Time + + require.Greater(t, expiry, time.Now()) + require.Equal(t, issuedAt.Add(time.Second*360), expiry) + require.Equal(t, time.Second*360, lifespan) +} + +func TestShortExpiredAccessTokenIsInvalid(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Tokens: &config.TokensConfig{ + AccessTokenExpiry: 1, + }, + }) + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + time.Sleep(1100 * time.Millisecond) + + parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) + require.ErrorIs(t, oauth.ErrTokenExpired, err) + require.Empty(t, parsedId) + require.Empty(t, iss) +} + +func TestExpiredAccessTokenIsInvalid(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + require.NoError(t, err) + + // Create the jwt 1 second expired. + now := time.Now().UTC().Add(-oauth.DefaultAccessTokenExpiry).Add(time.Second * -1) + claims := oauth.AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: identityId.String(), + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour * 24)), + IssuedAt: jwt.NewNumericDate(now), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + parsedId, iss, err := oauth.ValidateAccessToken(ctx, tokenString) + require.ErrorIs(t, oauth.ErrTokenExpired, err) + require.Empty(t, parsedId) + require.Empty(t, iss) +} + +func TestAccessTokenIssueClaimIsKeel(t *testing.T) { + ctx := newContextWithPK() + identityId := ksuid.New() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) + + bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) + require.NoError(t, err) + require.NotEmpty(t, bearerJwt) + + claims := &oauth.AccessTokenClaims{} + _, err = jwt.ParseWithClaims(bearerJwt, claims, func(token *jwt.Token) (interface{}, error) { + return &privateKey.PublicKey, nil + }) + require.NoError(t, err) + + issuedAt := claims.Issuer + require.Equal(t, oauth.KeelIssuer, issuedAt) +} diff --git a/runtime/oauth/id_token.go b/runtime/oauth/id_token.go index ff441abc1..307dcad89 100644 --- a/runtime/oauth/id_token.go +++ b/runtime/oauth/id_token.go @@ -61,21 +61,17 @@ func VerifyIdToken(ctx context.Context, idTokenRaw string) (*oidc.IDToken, error span.SetAttributes(attribute.String("issuer", issuer)) - authConfig, err := runtimectx.GetAuthConfig(ctx) + authConfig, err := runtimectx.GetOAuthConfig(ctx) if err != nil { return nil, err } - issuerPermitted := authConfig.AllowAnyIssuers - if !issuerPermitted { - for _, e := range authConfig.Issuers { - if e.Iss == issuer { - issuerPermitted = true - } - } + hasIssuer, err := authConfig.HasOidcIssuer(issuer) + if err != nil { + return nil, err } - if !issuerPermitted { + if !hasIssuer { return nil, fmt.Errorf("issuer %s not registered to authenticate on this server", issuer) } diff --git a/runtime/oauth/id_token_test.go b/runtime/oauth/id_token_test.go index 591e57262..8db68fafb 100644 --- a/runtime/oauth/id_token_test.go +++ b/runtime/oauth/id_token_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/config" "github.com/teamkeel/keel/runtime/oauth" "github.com/teamkeel/keel/runtime/oauth/oauthtest" "github.com/teamkeel/keel/runtime/runtimectx" @@ -15,15 +16,21 @@ import ( func TestIdTokenAuth_Valid(t *testing.T) { ctx := context.Background() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + server.SetUser("id|285620", &oauth.UserClaims{ Email: "keelson@keel.so", }) @@ -41,15 +48,21 @@ func TestIdTokenAuth_Valid(t *testing.T) { func TestIdTokenAuth_IncorrectlySigned(t *testing.T) { ctx := context.Background() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + server.SetUser("id|285620", &oauth.UserClaims{ Email: "keelson@keel.so", }) @@ -72,15 +85,21 @@ func TestIdTokenAuth_IncorrectlySigned(t *testing.T) { func TestIdTokenAuth_IssuerMismatch(t *testing.T) { ctx := context.Background() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + issuer := server.Config["issuer"] server.SetUser("id|285620", &oauth.UserClaims{ @@ -100,18 +119,16 @@ func TestIdTokenAuth_IssuerMismatch(t *testing.T) { require.Nil(t, idToken) } -func TestIdTokenAuth_IssuerNotRegistered(t *testing.T) { +func TestIdTokenAuth_IssuerNotConfigured(t *testing.T) { ctx := context.Background() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: false, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config with no issuer + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{}) + server.SetUser("id|285620", &oauth.UserClaims{ Email: "keelson@keel.so", }) @@ -130,15 +147,21 @@ func TestIdTokenAuth_IssuerNotRegistered(t *testing.T) { func TestIdTokenAuth_ExpiredIdToken(t *testing.T) { ctx := context.Background() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - // OIDC test server server, err := oauthtest.NewOIDCServer() require.NoError(t, err) + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + IssuerUrl: server.Issuer, + }, + }, + }) + server.IdTokenLifespan = 0 * time.Second server.SetUser("id|285620", &oauth.UserClaims{ diff --git a/runtime/oauth/refresh_token.go b/runtime/oauth/refresh_token.go index eaee505b1..c1c911313 100644 --- a/runtime/oauth/refresh_token.go +++ b/runtime/oauth/refresh_token.go @@ -8,12 +8,13 @@ import ( "github.com/dchest/uniuri" "github.com/teamkeel/keel/db" + "github.com/teamkeel/keel/runtime/runtimectx" "golang.org/x/crypto/sha3" ) const ( refreshTokenLength = 64 - defaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90 // 3 months is the default + DefaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90 // 3 months is the default ) // NewRefreshToken generates a new refresh token for the identity using the @@ -38,7 +39,18 @@ func NewRefreshToken(ctx context.Context, identityId string) (string, error) { } now := time.Now().UTC() - expiry := now.Add(defaultRefreshTokenExpiry) + var expiresAt time.Time + + authConfig, err := runtimectx.GetOAuthConfig(ctx) + if err != nil { + return "", err + } + + if authConfig != nil && authConfig.Tokens != nil && authConfig.Tokens.RefreshTokenExpiry != 0 { + expiresAt = now.Add(time.Duration(authConfig.Tokens.RefreshTokenExpiry) * time.Second) + } else { + expiresAt = now.Add(DefaultRefreshTokenExpiry) + } sql := ` INSERT INTO @@ -46,7 +58,7 @@ func NewRefreshToken(ctx context.Context, identityId string) (string, error) { VALUES (?, ?, ?, ?)` - db := database.GetDB().Exec(sql, hash, identityId, expiry, now) + db := database.GetDB().Exec(sql, hash, identityId, expiresAt, now) if db.Error != nil { return "", db.Error } @@ -98,6 +110,8 @@ func RotateRefreshToken(ctx context.Context, refreshTokenRaw string) (isValid bo ?, identity_id, expires_at, now() FROM revoked_token + WHERE + expires_at >= now() RETURNING *;` rows := []map[string]any{} diff --git a/runtime/oauth/refresh_token_test.go b/runtime/oauth/refresh_token_test.go index dbea44493..1355d76db 100644 --- a/runtime/oauth/refresh_token_test.go +++ b/runtime/oauth/refresh_token_test.go @@ -3,8 +3,10 @@ package oauth_test import ( "context" "testing" + "time" "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/config" "github.com/teamkeel/keel/runtime/oauth" "github.com/teamkeel/keel/runtime/runtimectx" keeltesting "github.com/teamkeel/keel/testing" @@ -16,11 +18,6 @@ func TestNewRefreshToken_NotEmpty(t *testing.T) { ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") require.NoError(t, err) require.NotEmpty(t, refreshToken) @@ -29,11 +26,6 @@ func TestNewRefreshToken_NotEmpty(t *testing.T) { func TestNewRefreshToken_ErrorOnEmptyIdentityId(t *testing.T) { ctx := context.Background() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - _, err := oauth.NewRefreshToken(ctx, "") require.Error(t, err) } @@ -42,11 +34,6 @@ func TestRotateRefreshToken_Valid(t *testing.T) { ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") require.NoError(t, err) @@ -64,18 +51,36 @@ func TestRotateRefreshToken_Valid(t *testing.T) { require.NotEqual(t, newRefreshToken2, newRefreshToken1) } -func TestRotateRefreshToken_ReuseRefreshTokenNotValid(t *testing.T) { +func TestRotateRefreshToken_Expired(t *testing.T) { ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Tokens: &config.TokensConfig{ + RefreshTokenExpiry: 1, + }, }) refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") require.NoError(t, err) + time.Sleep(1100 * time.Millisecond) + + isValid, newRefreshToken, identityId, err := oauth.RotateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.False(t, isValid) + require.Empty(t, identityId) + require.Empty(t, newRefreshToken) +} + +func TestRotateRefreshToken_ReuseRefreshTokenNotValid(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + isValid, newRefreshToken, identityId, err := oauth.RotateRefreshToken(ctx, refreshToken) require.NoError(t, err) require.True(t, isValid) @@ -93,11 +98,6 @@ func TestRevokeRefreshToken_Unauthorised(t *testing.T) { ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") require.NoError(t, err) @@ -113,11 +113,6 @@ func TestRevokeRefreshToken_MultipleForIdentity(t *testing.T) { ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // Set up auth config - ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ - AllowAnyIssuers: true, - }) - refreshToken1, err := oauth.NewRefreshToken(ctx, "identity_id") require.NoError(t, err) diff --git a/runtime/runtimectx/oauth.go b/runtime/runtimectx/oauth.go new file mode 100644 index 000000000..f2630624f --- /dev/null +++ b/runtime/runtimectx/oauth.go @@ -0,0 +1,31 @@ +package runtimectx + +import ( + "context" + "errors" + + "github.com/teamkeel/keel/config" +) + +const ( + oauthContextKey contextKey = "oauthConfig" +) + +func WithOAuthConfig(ctx context.Context, config *config.AuthConfig) context.Context { + ctx = context.WithValue(ctx, oauthContextKey, config) + return ctx +} + +func GetOAuthConfig(ctx context.Context) (*config.AuthConfig, error) { + v := ctx.Value(oauthContextKey) + if v == nil { + return &config.AuthConfig{}, nil + } + + config, ok := v.(*config.AuthConfig) + + if !ok { + return nil, errors.New("auth config in the context has wrong value type") + } + return config, nil +}