diff --git a/config/auth.go b/config/auth.go index 124188dd9..7c8586d71 100644 --- a/config/auth.go +++ b/config/auth.go @@ -3,12 +3,23 @@ package config import ( "fmt" "net/url" + "strings" + "time" "github.com/samber/lo" ) +const ( + // 24 hours is the default access token expiry period + DefaultAccessTokenExpiry time.Duration = time.Hour * 24 + // 3 months is the default refresh token expiry period + DefaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90 +) + const ( GoogleProvider = "google" + FacebookProvider = "facebook" + GitLabProvider = "gitlab" OpenIdConnectProvider = "oidc" OAuthProvider = "oauth" ) @@ -16,19 +27,22 @@ const ( var ( SupportedProviderTypes = []string{ GoogleProvider, + FacebookProvider, + GitLabProvider, OpenIdConnectProvider, OAuthProvider, } ) type AuthConfig struct { - Tokens *TokensConfig `yaml:"tokens"` - Providers []Provider `yaml:"providers"` + Tokens TokensConfig `yaml:"tokens"` + Providers []Provider `yaml:"providers"` } type TokensConfig struct { - AccessTokenExpiry int `yaml:"accessTokenExpiry"` - RefreshTokenExpiry int `yaml:"refreshTokenExpiry"` + AccessTokenExpiry *int `yaml:"accessTokenExpiry,omitempty"` + RefreshTokenExpiry *int `yaml:"refreshTokenExpiry,omitempty"` + RefreshTokenRotationEnabled *bool `yaml:"refreshTokenRotationEnabled,omitempty"` } type Provider struct { @@ -40,6 +54,33 @@ type Provider struct { AuthorizationUrl string `yaml:"authorizationUrl"` } +// AccessTokenExpiry retrieves the configured or default access token expiry +func (c *AuthConfig) AccessTokenExpiry() time.Duration { + if c.Tokens.AccessTokenExpiry != nil { + return time.Duration(*c.Tokens.AccessTokenExpiry) * time.Second + } else { + return DefaultAccessTokenExpiry + } +} + +// RefreshTokenExpiry retrieves the configured or default refresh token expiry +func (c *AuthConfig) RefreshTokenExpiry() time.Duration { + if c.Tokens.RefreshTokenExpiry != nil { + return time.Duration(*c.Tokens.RefreshTokenExpiry) * time.Second + } else { + return DefaultRefreshTokenExpiry + } +} + +// RefreshTokenRotationEnabled retrieves the configured or default refresh token rotation +func (c *AuthConfig) RefreshTokenRotationEnabled() bool { + if c.Tokens.RefreshTokenRotationEnabled != nil { + return *c.Tokens.RefreshTokenRotationEnabled + } else { + return true + } +} + func (c *AuthConfig) GetOidcProviders() []Provider { oidcProviders := []Provider{} for _, p := range c.Providers { @@ -50,9 +91,9 @@ func (c *AuthConfig) GetOidcProviders() []Provider { return oidcProviders } -// GetProvidersOidcIssuer gets all providers by issuer url. -// It's possible that multiple providers from the same issuer as configured. -func (c *AuthConfig) GetProvidersOidcIssuer(issuer string) ([]Provider, error) { +// GetOidcProvidersByIssuer gets all OpenID Connect providers by issuer url. +// It's possible that multiple providers from the same issuer are configured. +func (c *AuthConfig) GetOidcProvidersByIssuer(issuer string) ([]Provider, error) { providers := []Provider{} for _, p := range c.Providers { @@ -64,7 +105,7 @@ func (c *AuthConfig) GetProvidersOidcIssuer(issuer string) ([]Provider, error) { if err != nil { return nil, err } - if issuerUrl == issuer { + if strings.TrimSuffix(issuerUrl, "/") == strings.TrimSuffix(issuer, "/") { providers = append(providers, p) } } @@ -75,7 +116,11 @@ func (c *AuthConfig) GetProvidersOidcIssuer(issuer string) ([]Provider, error) { func (c *Provider) GetIssuer() (string, error) { switch c.Type { case GoogleProvider: - return "https://accounts.google.com/", nil + return "https://accounts.google.com", nil + case FacebookProvider: + return "https://www.facebook.com", nil + case GitLabProvider: + return "https://gitlab.com", nil case OpenIdConnectProvider: return c.IssuerUrl, nil default: diff --git a/config/config.go b/config/config.go index 57169b5dd..644bf8d71 100644 --- a/config/config.go +++ b/config/config.go @@ -231,7 +231,6 @@ func Validate(config *ProjectConfig) *ConfigErrors { if hasIncorrectNames { for incorrectName := range incorrectNames { startsWith := reservedEnvVarRegex.FindString(incorrectName) - errors = append(errors, &ConfigError{ Type: "reserved", Message: fmt.Sprintf(ConfigReservedNameErrorString, incorrectName, startsWith), @@ -239,14 +238,14 @@ func Validate(config *ProjectConfig) *ConfigErrors { } } - if config.Auth.Tokens != nil && config.Auth.Tokens.AccessTokenExpiry < 0 { + if config.Auth.AccessTokenExpiry() <= 0 { errors = append(errors, &ConfigError{ Type: "invalid", Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "access", "accessTokenExpiry"), }) } - if config.Auth.Tokens != nil && config.Auth.Tokens.RefreshTokenExpiry < 0 { + if config.Auth.RefreshTokenExpiry() <= 0 { errors = append(errors, &ConfigError{ Type: "invalid", Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "refresh", "refreshTokenExpiry"), diff --git a/config/config_test.go b/config/config_test.go index c23ed9012..c5d652774 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -119,8 +120,26 @@ func TestAuthTokens(t *testing.T) { config, err := Load("fixtures/test_auth.yaml") assert.NoError(t, err) - assert.Equal(t, 3600, config.Auth.Tokens.AccessTokenExpiry) - assert.Equal(t, 604800, config.Auth.Tokens.RefreshTokenExpiry) + assert.Equal(t, 3600, *config.Auth.Tokens.AccessTokenExpiry) + assert.Equal(t, 604800, *config.Auth.Tokens.RefreshTokenExpiry) + assert.Equal(t, false, *config.Auth.Tokens.RefreshTokenRotationEnabled) + + assert.Equal(t, time.Duration(3600)*time.Second, config.Auth.AccessTokenExpiry()) + assert.Equal(t, time.Duration(604800)*time.Second, config.Auth.RefreshTokenExpiry()) + assert.Equal(t, false, config.Auth.RefreshTokenRotationEnabled()) +} + +func TestAuthDefaults(t *testing.T) { + config, err := Load("fixtures/test_auth_empty.yaml") + assert.NoError(t, err) + + assert.Nil(t, config.Auth.Tokens.AccessTokenExpiry) + assert.Nil(t, config.Auth.Tokens.RefreshTokenExpiry) + assert.Nil(t, config.Auth.Tokens.RefreshTokenRotationEnabled) + + assert.Equal(t, time.Duration(24)*time.Hour, config.Auth.AccessTokenExpiry()) + assert.Equal(t, time.Duration(24)*time.Hour*90, config.Auth.RefreshTokenExpiry()) + assert.Equal(t, true, config.Auth.RefreshTokenRotationEnabled()) } func TestAuthNegativeTokenLifespan(t *testing.T) { @@ -171,9 +190,9 @@ func TestDuplicateProviderName(t *testing.T) { func TestInvalidProviderTypes(t *testing.T) { _, err := Load("fixtures/test_auth_invalid_types.yaml") - assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, oidc, oauth\n") - assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, oidc, oauth\n") - assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, oidc, oauth\n") + assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, facebook, gitlab, oidc, oauth\n") + assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, facebook, gitlab, oidc, oauth\n") + assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, facebook, gitlab, oidc, oauth\n") } func TestMissingClientId(t *testing.T) { @@ -205,15 +224,15 @@ func TestGetOidcIssuer(t *testing.T) { config, err := Load("fixtures/test_auth.yaml") assert.NoError(t, err) - googleIssuer, err := config.Auth.GetProvidersOidcIssuer("https://accounts.google.com/") + googleIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://accounts.google.com/") assert.NoError(t, err) assert.Len(t, googleIssuer, 2) - auth0Issuer, err := config.Auth.GetProvidersOidcIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com") + auth0Issuer, err := config.Auth.GetOidcProvidersByIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com") assert.NoError(t, err) assert.Len(t, auth0Issuer, 1) - nopeIssuer, err := config.Auth.GetProvidersOidcIssuer("https://nope.com") + nopeIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://nope.com") assert.NoError(t, err) assert.Len(t, nopeIssuer, 0) } @@ -222,7 +241,7 @@ func TestGetOidcSameIssuers(t *testing.T) { config, err := Load("fixtures/test_auth_same_issuers.yaml") assert.NoError(t, err) - googleIssuer, err := config.Auth.GetProvidersOidcIssuer("https://accounts.google.com/") + googleIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://accounts.google.com/") assert.NoError(t, err) assert.Len(t, googleIssuer, 3) } diff --git a/config/fixtures/test_auth.yaml b/config/fixtures/test_auth.yaml index 7ec4ef8df..f9bf4eb10 100644 --- a/config/fixtures/test_auth.yaml +++ b/config/fixtures/test_auth.yaml @@ -2,7 +2,8 @@ auth: tokens: accessTokenExpiry: 3600 refreshTokenExpiry: 604800 - + refreshTokenRotationEnabled: false + providers: # Built-in Google provider - type: google diff --git a/config/fixtures/test_auth_empty.yaml b/config/fixtures/test_auth_empty.yaml new file mode 100644 index 000000000..e69de29bb diff --git a/runtime/apis/authapi/revoke_endpoint.go b/runtime/apis/authapi/revoke_endpoint.go index d12a6d9a1..b273ace11 100644 --- a/runtime/apis/authapi/revoke_endpoint.go +++ b/runtime/apis/authapi/revoke_endpoint.go @@ -5,12 +5,53 @@ import ( "github.com/teamkeel/keel/proto" "github.com/teamkeel/keel/runtime/common" + "github.com/teamkeel/keel/runtime/oauth" + "go.opentelemetry.io/otel/trace" ) +type RevokeEndpointErrorResponse struct { + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + func RevokeHandler(schema *proto.Schema) common.HandlerFunc { return func(r *http.Request) common.Response { - return common.Response{ - Status: http.StatusNotImplemented, + ctx, span := tracer.Start(r.Context(), "Revoke Token") + defer span.End() + + if r.Method != http.MethodPost { + return common.NewJsonResponse(http.StatusMethodNotAllowed, &ErrorResponse{ + Error: InvalidRequest, + ErrorDescription: "the revoke endpoint only accepts POST", + }, nil) } + + if !HasContentType(r.Header, "application/x-www-form-urlencoded") { + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, + ErrorDescription: "the request must be an encoded form with Content-Type application/x-www-form-urlencoded", + }, nil) + } + + refreshTokenRaw := r.FormValue(ArgToken) + + if refreshTokenRaw == "" { + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, + ErrorDescription: "the refresh token must be provided in the token field", + }, nil) + } + + // Revoke the refresh token + err := oauth.RevokeRefreshToken(ctx, refreshTokenRaw) + if err != nil { + span.RecordError(err, trace.WithStackTrace(true)) + return common.NewJsonResponse(http.StatusUnauthorized, &ErrorResponse{ + Error: InvalidClient, + ErrorDescription: "possible causes may be that the id token is invalid, has expired, or has insufficient claims", + }, nil) + } + + return common.NewJsonResponse(http.StatusOK, nil, nil) } } diff --git a/runtime/apis/authapi/revoke_endpoint_test.go b/runtime/apis/authapi/revoke_endpoint_test.go new file mode 100644 index 000000000..892e4dde6 --- /dev/null +++ b/runtime/apis/authapi/revoke_endpoint_test.go @@ -0,0 +1,157 @@ +package authapi_test + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/config" + "github.com/teamkeel/keel/runtime/apis/authapi" + "github.com/teamkeel/keel/runtime/oauth" + "github.com/teamkeel/keel/runtime/oauth/oauthtest" + "github.com/teamkeel/keel/runtime/runtimectx" + keeltesting "github.com/teamkeel/keel/testing" +) + +func TestRevokeToken_Success(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // OIDC test server + server, err := oauthtest.NewOIDCServer() + require.NoError(t, err) + + // Set up auth config + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + ClientId: "oidc-client-id", + IssuerUrl: server.Issuer, + }, + }, + }) + + server.SetUser("id|285620", &oauth.UserClaims{ + Email: "keelson@keel.so", + }) + + // Get ID token from server + idToken, err := server.FetchIdToken("id|285620", []string{"oidc-client-id"}) + require.NoError(t, err) + + // Make a token exchange grant request + requestToken := makeTokenExchangeRequest(ctx, idToken) + + // Handle runtime request, expecting TokenResponse + validResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, requestToken) + require.NoError(t, err) + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + + // Make a token exchange grant request + requestRevoke := makeRevokeTokenRequest(ctx, validResponse.RefreshToken) + + // Handle runtime request, expecting TokenResponse + _, revokeHttpResponse, err := handleRuntimeRequest[any](schema, requestRevoke) + require.NoError(t, err) + require.Equal(t, http.StatusOK, revokeHttpResponse.StatusCode) + + // Make a token exchange grant request + refreshToken := makeRefreshTokenRequest(ctx, validResponse.RefreshToken) + + // Handle runtime request, expecting TokenResponse + refreshResponse, refreshHttpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, refreshToken) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, refreshHttpResponse.StatusCode) + require.Equal(t, "invalid_client", refreshResponse.Error) +} + +func TestRevokeEndpoint_HttpGet(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Make a token exchange grant request + request := makeRevokeTokenRequest(ctx, "mock_token") + + request.Method = http.MethodGet + + // Handle runtime request, expecting TokenErrorResponse + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, httpResponse.StatusCode) + require.Equal(t, "invalid_request", errorResponse.Error) + require.Equal(t, "the revoke endpoint only accepts POST", errorResponse.ErrorDescription) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) +} + +func TestRevokeEndpoint_EmptyToken(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Make a revoke request + request := makeRevokeTokenRequest(ctx, "mock_token") + form := url.Values{} + form.Add("mock_token", "") + request.URL.RawQuery = form.Encode() + + // Handle runtime request, expecting TokenErrorResponse + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) + require.Equal(t, "invalid_request", errorResponse.Error) + require.Equal(t, "the refresh token must be provided in the token field", errorResponse.ErrorDescription) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) +} + +func TestRevokeEndpoint_NoToken(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Make a revoke request + request := makeRevokeTokenRequest(ctx, "mock_token") + form := url.Values{} + form.Del("token") + request.URL.RawQuery = form.Encode() + + // Handle runtime request, expecting TokenErrorResponse + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) + require.Equal(t, "invalid_request", errorResponse.Error) + require.Equal(t, "the refresh token must be provided in the token field", errorResponse.ErrorDescription) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) +} + +func TestRevokeEndpoint_UnknownToken(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Make a revoke request + request := makeRevokeTokenRequest(ctx, "mock_token") + + // Handle runtime request, expecting TokenErrorResponse + _, httpResponse, err := handleRuntimeRequest[any](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) +} + +func makeRevokeTokenRequest(ctx context.Context, token string) *http.Request { + request := httptest.NewRequest(http.MethodPost, "http://mykeelapp.keel.so/auth/revoke", nil) + request.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + form := url.Values{} + form.Add("token", token) + request.URL.RawQuery = form.Encode() + request = request.WithContext(ctx) + + return request +} diff --git a/runtime/apis/authapi/token_endpoint.go b/runtime/apis/authapi/token_endpoint.go index e60ff8230..798c55f5a 100644 --- a/runtime/apis/authapi/token_endpoint.go +++ b/runtime/apis/authapi/token_endpoint.go @@ -9,19 +9,24 @@ import ( "github.com/teamkeel/keel/runtime/actions" "github.com/teamkeel/keel/runtime/common" "github.com/teamkeel/keel/runtime/oauth" + "github.com/teamkeel/keel/runtime/runtimectx" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var tracer = otel.Tracer("github.com/teamkeel/keel/runtime") // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +// https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 const ( ArgGrantType = "grant_type" ArgSubjectToken = "subject_token" ArgSubjectTokenType = "subject_token_type" ArgRequestedTokeType = "requested_token_type" ArgRefreshToken = "refresh_token" + ArgToken = "token" ) const ( @@ -37,16 +42,17 @@ type TokenResponse struct { } // https://openid.net/specs/openid-connect-standard-1_0-21_orig.html#AccessTokenErrorResponse -type TokenErrorResponse struct { +// https://datatracker.ietf.org/doc/html/rfc7009#section-2.2 +type ErrorResponse struct { Error string `json:"error,omitempty"` ErrorDescription string `json:"error_description,omitempty"` } // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 const ( - TokenEndpointUnsupportedGrantType = "unsupported_grant_type" - TokenEndpointInvalidClient = "invalid_client" - TokenEndpointInvalidRequest = "invalid_request" + UnsupportedGrantType = "unsupported_grant_type" + InvalidClient = "invalid_client" + InvalidRequest = "invalid_request" ) const ( @@ -66,16 +72,26 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { ctx, span := tracer.Start(r.Context(), "Token Endpoint") defer span.End() + var identityId string + var refreshToken string + + config, err := runtimectx.GetOAuthConfig(ctx) + if err != nil { + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } + if r.Method != http.MethodPost { - return common.NewJsonResponse(http.StatusMethodNotAllowed, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusMethodNotAllowed, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the token endpoint only accepts POST", }, nil) } if !HasContentType(r.Header, "application/x-www-form-urlencoded") { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the request must be an encoded form with Content-Type application/x-www-form-urlencoded", }, nil) } @@ -83,8 +99,8 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { grantType := r.FormValue(ArgGrantType) if grantType == "" { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the grant-type field is required with either 'refresh_token' or 'token_exchange'", }, nil) } @@ -96,69 +112,70 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { switch grantType { case GrantTypeRefreshToken: if !r.Form.Has(ArgRefreshToken) { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the refresh token must be provided in the refresh_token field", }, nil) } - refreshTokenRaw := r.Form.Get(ArgRefreshToken) + refreshTokenRaw := r.FormValue(ArgRefreshToken) if refreshTokenRaw == "" { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the refresh token in the refresh_token field cannot be an empty string", }, nil) } - isValid, newRefreshToken, identityId, err := oauth.RotateRefreshToken(ctx, refreshTokenRaw) - if err != nil { - span.RecordError(err) - return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + var isValid bool + if config.RefreshTokenRotationEnabled() { + // Rotate and revoke this refresh token, and mint a new one. + isValid, refreshToken, identityId, err = oauth.RotateRefreshToken(ctx, refreshTokenRaw) + if err != nil { + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } + } else { + // Response with the same refresh token when refresh token rotation is disabled + refreshToken = refreshTokenRaw + + // Check that the refresh token exists and has not expired. + isValid, identityId, err = oauth.ValidateRefreshToken(ctx, refreshToken) + if err != nil { + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } } if !isValid { - return common.NewJsonResponse(http.StatusUnauthorized, &TokenErrorResponse{ - Error: TokenEndpointInvalidClient, + return common.NewJsonResponse(http.StatusUnauthorized, &ErrorResponse{ + Error: InvalidClient, ErrorDescription: "possible causes may be that the refresh token has been revoked or has expired", }, nil) } - // Generate an access token for this identity. - accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identityId) - if err != nil { - span.RecordError(err) - return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) - } - - response := &TokenResponse{ - AccessToken: accessTokenRaw, - TokenType: TokenType, - ExpiresIn: int(expiresIn.Seconds()), - RefreshToken: newRefreshToken, - } - - return common.NewJsonResponse(http.StatusOK, response, nil) case GrantTypeTokenExchange: if !r.Form.Has(ArgSubjectToken) { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the ID token must be provided in the subject_token field", }, nil) } // We do not require subject_token_type, but if provided we only support 'id_token' if r.Form.Has(ArgSubjectTokenType) && r.Form.Get(ArgSubjectTokenType) != "id_token" { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the only supported subject_token_type is 'id_token'", }, nil) } // We do not require requested_token_type, but if provided we only support 'access_token' if r.Form.Has(ArgRequestedTokeType) && (r.Form.Get(ArgRequestedTokeType) != "urn:ietf:params:oauth:token-type:access_token" && r.Form.Get("requested_token_type") != "access_token") { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the only supported requested_token_type is 'access_token'", }, nil) } @@ -166,8 +183,8 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { idTokenRaw := r.Form.Get(ArgSubjectToken) if idTokenRaw == "" { - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "the ID token in the subject_token field cannot be an empty string", }, nil) } @@ -180,9 +197,9 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { // Verify the ID token with the OIDC provider idToken, err := oauth.VerifyIdToken(ctx, idTokenRaw) if err != nil { - span.RecordError(err) - return common.NewJsonResponse(http.StatusUnauthorized, &TokenErrorResponse{ - Error: TokenEndpointInvalidClient, + span.RecordError(err, trace.WithStackTrace(true)) + return common.NewJsonResponse(http.StatusUnauthorized, &ErrorResponse{ + Error: InvalidClient, ErrorDescription: "possible causes may be that the id token is invalid, has expired, or has insufficient claims", }, nil) } @@ -190,62 +207,69 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { // Extract claims var claims oauth.IdTokenClaims if err := idToken.Claims(&claims); err != nil { - span.RecordError(err) - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointInvalidRequest, + span.RecordError(err, trace.WithStackTrace(true)) + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: InvalidRequest, ErrorDescription: "insufficient claims on id_token", }, nil) } identity, err := actions.FindIdentityByExternalId(ctx, schema, idToken.Subject, idToken.Issuer) if err != nil { - span.RecordError(err) + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) } if identity == nil { identity, err = actions.CreateIdentityWithIdTokenClaims(ctx, schema, idToken.Subject, idToken.Issuer, claims) if err != nil { - span.RecordError(err) + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) } } else { identity, err = actions.UpdateIdentityWithIdTokenClaims(ctx, schema, idToken.Subject, idToken.Issuer, claims) if err != nil { - span.RecordError(err) + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) } } - // Generate an access token for this identity. - accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identity.Id) - if err != nil { - span.RecordError(err) - return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) - } - // Generate a refresh token. - refreshTokenRaw, err := oauth.NewRefreshToken(ctx, identity.Id) + refreshToken, err = oauth.NewRefreshToken(ctx, identity.Id) if err != nil { - span.RecordError(err) + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) } - response := &TokenResponse{ - AccessToken: accessTokenRaw, - TokenType: TokenType, - ExpiresIn: int(expiresIn.Seconds()), - RefreshToken: refreshTokenRaw, - } - - return common.NewJsonResponse(http.StatusOK, response, nil) + identityId = identity.Id default: - return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ - Error: TokenEndpointUnsupportedGrantType, + return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{ + Error: UnsupportedGrantType, ErrorDescription: "the only supported grants are 'refresh_token' and 'token_exchange'", }, nil) } + + // Generate a new access token for this identity. + accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identityId) + if err != nil { + span.RecordError(err, trace.WithStackTrace(true)) + span.SetStatus(codes.Error, err.Error()) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } + + response := &TokenResponse{ + AccessToken: accessTokenRaw, + TokenType: TokenType, + ExpiresIn: int(expiresIn.Seconds()), + RefreshToken: refreshToken, + } + + return common.NewJsonResponse(http.StatusOK, response, nil) } } diff --git a/runtime/apis/authapi/token_endpoint_test.go b/runtime/apis/authapi/token_endpoint_test.go index 8a1a95fca..7f238a714 100644 --- a/runtime/apis/authapi/token_endpoint_test.go +++ b/runtime/apis/authapi/token_endpoint_test.go @@ -254,7 +254,7 @@ func TestTokenEndpoint_HttpGet(t *testing.T) { request.Method = http.MethodGet // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusMethodNotAllowed, httpResponse.StatusCode) @@ -273,7 +273,7 @@ func TestTokenEndpoint_ApplicationJsonRequest(t *testing.T) { request.Header.Add("Content-Type", "application/json") // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -293,7 +293,7 @@ func TestTokenEndpoint_MissingGrantType(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -314,7 +314,7 @@ func TestTokenEndpoint_WrongGrantType(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -323,7 +323,7 @@ func TestTokenEndpoint_WrongGrantType(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestTokenExchange_NoSubjectToken(t *testing.T) { +func TestTokenExchangeGrant_NoSubjectToken(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -334,7 +334,7 @@ func TestTokenExchange_NoSubjectToken(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -343,7 +343,7 @@ func TestTokenExchange_NoSubjectToken(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestTokenExchange_EmptySubjectToken(t *testing.T) { +func TestTokenExchangeGrant_EmptySubjectToken(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -355,7 +355,7 @@ func TestTokenExchange_EmptySubjectToken(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -364,7 +364,7 @@ func TestTokenExchange_EmptySubjectToken(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestTokenExchange_WrongSubjectTokenType(t *testing.T) { +func TestTokenExchangeGrant_WrongSubjectTokenType(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -378,7 +378,7 @@ func TestTokenExchange_WrongSubjectTokenType(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -387,7 +387,7 @@ func TestTokenExchange_WrongSubjectTokenType(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestTokenExchange_WrongRequestedTokenType(t *testing.T) { +func TestTokenExchangeGrant_WrongRequestedTokenType(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -401,7 +401,7 @@ func TestTokenExchange_WrongRequestedTokenType(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -410,7 +410,7 @@ func TestTokenExchange_WrongRequestedTokenType(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestTokenExchange_BadIdToken(t *testing.T) { +func TestTokenExchangeGrant_BadIdToken(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -435,7 +435,7 @@ func TestTokenExchange_BadIdToken(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusUnauthorized, httpResponse.StatusCode) @@ -444,7 +444,7 @@ func TestTokenExchange_BadIdToken(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestRefreshToken_Valid(t *testing.T) { +func TestRefreshTokenGrantRotationEnabled_Valid(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -453,7 +453,11 @@ func TestRefreshToken_Valid(t *testing.T) { require.NoError(t, err) // Set up auth config + refreshTokenRotation := true ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Tokens: config.TokensConfig{ + RefreshTokenRotationEnabled: &refreshTokenRotation, + }, Providers: []config.Provider{ { Type: config.OpenIdConnectProvider, @@ -475,7 +479,7 @@ func TestRefreshToken_Valid(t *testing.T) { // Make a token exchange grant request request := makeTokenExchangeRequest(ctx, idToken) - // Handle runtime request, expecting TokenErrorResponse + // Handle runtime request, expecting TokenResponse tokenExchangeResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusOK, httpResponse.StatusCode) @@ -486,7 +490,7 @@ func TestRefreshToken_Valid(t *testing.T) { // Make a refresh token grant request request = makeRefreshTokenRequest(ctx, tokenExchangeResponse.RefreshToken) - // Handle runtime request, expecting TokenErrorResponse + // Handle runtime request, expecting TokenResponse refreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) require.NoError(t, err) @@ -517,13 +521,95 @@ func TestRefreshToken_Valid(t *testing.T) { request = makeRefreshTokenRequest(ctx, tokenExchangeResponse.RefreshToken) // Handle runtime request, expecting TokenErrorResponse - secondRefreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + secondRefreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusUnauthorized, httpResponse.StatusCode) require.Equal(t, "possible causes may be that the refresh token has been revoked or has expired", secondRefreshGrantResponse.ErrorDescription) } -func TestRefreshToken_NoRefreshToken(t *testing.T) { +func TestRefreshTokenGrantRotationDisabled_Valid(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // OIDC test server + server, err := oauthtest.NewOIDCServer() + require.NoError(t, err) + + // Set up auth config + refreshTokenRotation := false + ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ + Tokens: config.TokensConfig{ + RefreshTokenRotationEnabled: &refreshTokenRotation, + }, + Providers: []config.Provider{ + { + Type: config.OpenIdConnectProvider, + Name: "my-oidc", + ClientId: "oidc-client-id", + IssuerUrl: server.Issuer, + }, + }, + }) + + server.SetUser("id|285620", &oauth.UserClaims{ + Email: "keelson@keel.so", + }) + + // Get ID token from server + idToken, err := server.FetchIdToken("id|285620", []string{"oidc-client-id"}) + require.NoError(t, err) + + // Make a token exchange grant request + request := makeTokenExchangeRequest(ctx, idToken) + + // Handle runtime request, expecting TokenResponse + tokenExchangeResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) + require.NoError(t, err) + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + + // We need 1 second to pass in order to get a different access token + time.Sleep(1000 * time.Millisecond) + + // Make a refresh token grant request + request = makeRefreshTokenRequest(ctx, tokenExchangeResponse.RefreshToken) + + // Handle runtime request, expecting TokenResponse + refreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + require.NotEmpty(t, refreshGrantResponse.AccessToken) + require.Equal(t, "bearer", refreshGrantResponse.TokenType) + require.NotEmpty(t, refreshGrantResponse.ExpiresIn) + require.NotEmpty(t, refreshGrantResponse.RefreshToken) + require.Equal(t, refreshGrantResponse.RefreshToken, tokenExchangeResponse.RefreshToken) + require.NotEqual(t, refreshGrantResponse.AccessToken, tokenExchangeResponse.AccessToken) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) + + accessToken1Issuer, err := auth.ExtractClaimFromToken(tokenExchangeResponse.AccessToken, "iss") + require.NoError(t, err) + accessToken2Issuer, err := auth.ExtractClaimFromToken(refreshGrantResponse.AccessToken, "iss") + require.NoError(t, err) + require.NotEmpty(t, accessToken1Issuer) + require.Equal(t, accessToken1Issuer, accessToken2Issuer) + + accessToken1Sub, err := auth.ExtractClaimFromToken(tokenExchangeResponse.AccessToken, "sub") + require.NoError(t, err) + accessToken2Sub, err := auth.ExtractClaimFromToken(refreshGrantResponse.AccessToken, "sub") + require.NoError(t, err) + require.NotEmpty(t, accessToken1Sub) + require.Equal(t, accessToken1Sub, accessToken2Sub) + + // Make a refresh token grant request using the original refresh token + request = makeRefreshTokenRequest(ctx, tokenExchangeResponse.RefreshToken) + + secondRefreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) + require.NoError(t, err) + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + require.Equal(t, tokenExchangeResponse.RefreshToken, secondRefreshGrantResponse.RefreshToken) +} + +func TestRefreshTokenGrant_NoRefreshToken(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -534,7 +620,7 @@ func TestRefreshToken_NoRefreshToken(t *testing.T) { request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) @@ -543,7 +629,7 @@ func TestRefreshToken_NoRefreshToken(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func TestRefreshToken_EmptyRefreshToken(t *testing.T) { +func TestRefreshTokenGrant_EmptyRefreshToken(t *testing.T) { ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() @@ -551,7 +637,7 @@ func TestRefreshToken_EmptyRefreshToken(t *testing.T) { request := makeRefreshTokenRequest(ctx, "") // Handle runtime request, expecting TokenErrorResponse - errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.ErrorResponse](schema, request) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) diff --git a/runtime/oauth/access_token.go b/runtime/oauth/access_token.go index 889ee9e5f..7cf061271 100644 --- a/runtime/oauth/access_token.go +++ b/runtime/oauth/access_token.go @@ -13,8 +13,8 @@ import ( ) const ( - KeelIssuer = "https://keel.so" - DefaultAccessTokenExpiry time.Duration = time.Hour * 24 + // Issuer 'iss' claim for access tokens issued by Keel + KeelIssuer = "https://keel.so" ) var ( @@ -29,16 +29,16 @@ type AccessTokenClaims struct { } func GenerateAccessToken(ctx context.Context, identityId string) (string, time.Duration, error) { - expiry := DefaultAccessTokenExpiry + if identityId == "" { + return "", 0, errors.New("cannot generate access token with an empty identityId intended for the sub claim") + } config, err := runtimectx.GetOAuthConfig(ctx) if err != nil { return "", 0, err } - if config != nil && config.Tokens != nil && config.Tokens.AccessTokenExpiry != 0 { - expiry = time.Duration(config.Tokens.AccessTokenExpiry) * time.Second - } + expiry := config.AccessTokenExpiry() token, err := generateToken(ctx, identityId, []string{}, expiry) if err != nil { diff --git a/runtime/oauth/access_token_test.go b/runtime/oauth/access_token_test.go index ac2351bbe..b2b7bf030 100644 --- a/runtime/oauth/access_token_test.go +++ b/runtime/oauth/access_token_test.go @@ -26,7 +26,7 @@ func newContextWithPK() context.Context { return ctx } -func TestAccessTokenGenerationAndParsingWithoutPrivateKey(t *testing.T) { +func TestAccessTokenGeneration(t *testing.T) { ctx := newContextWithPK() identityId := ksuid.New() @@ -40,46 +40,23 @@ func TestAccessTokenGenerationAndParsingWithoutPrivateKey(t *testing.T) { require.Equal(t, oauth.KeelIssuer, iss) } -func TestAccessTokenGenerationAndParsingWithSamePrivateKey(t *testing.T) { +func TestAccessTokenValidationNoPrivateKey(t *testing.T) { ctx := newContextWithPK() identityId := ksuid.New() - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey) - require.NoError(t, err) - bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) require.NoError(t, err) require.NotEmpty(t, bearerJwt) - parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) - require.NoError(t, err) - require.Equal(t, identityId.String(), parsedId) - require.Equal(t, oauth.KeelIssuer, iss) -} - -func TestAccessTokenGenerationWithPrivateKeyAndParsingWithoutPrivateKey(t *testing.T) { - ctx := newContextWithPK() - identityId := ksuid.New() - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey) - require.NoError(t, err) - - bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) - require.NoError(t, err) - require.NotEmpty(t, bearerJwt) + ctx = runtimectx.WithPrivateKey(ctx, nil) - ctx = newContextWithPK() parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) - require.ErrorIs(t, oauth.ErrInvalidToken, err) + require.Error(t, err, "no private key set") require.Empty(t, parsedId) require.Empty(t, iss) } -func TestAccessTokenGenerationWithoutPrivateKeyAndParsingWithPrivateKey(t *testing.T) { +func TestAccessTokenGenerationAndParsingWithSamePrivateKey(t *testing.T) { ctx := newContextWithPK() identityId := ksuid.New() @@ -87,35 +64,26 @@ func TestAccessTokenGenerationWithoutPrivateKeyAndParsingWithPrivateKey(t *testi require.NoError(t, err) require.NotEmpty(t, bearerJwt) - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey) - require.NoError(t, err) - parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) - require.ErrorIs(t, oauth.ErrInvalidToken, err) - require.Empty(t, parsedId) - require.Empty(t, iss) + require.NoError(t, err) + require.Equal(t, identityId.String(), parsedId) + require.Equal(t, oauth.KeelIssuer, iss) } -func TestAccessTokenGenerationAndParsingWithDifferentPrivateKeys(t *testing.T) { +func TestAccessTokenValidationDifferentPrivateKey(t *testing.T) { ctx := newContextWithPK() identityId := ksuid.New() - privateKey1, err := rsa.GenerateKey(rand.Reader, 2048) + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey1) + ctx = runtimectx.WithPrivateKey(ctx, privateKey) require.NoError(t, err) bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) require.NoError(t, err) require.NotEmpty(t, bearerJwt) - privateKey2, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey2) - require.NoError(t, err) - + ctx = newContextWithPK() parsedId, iss, err := oauth.ValidateAccessToken(ctx, bearerJwt) require.ErrorIs(t, oauth.ErrInvalidToken, err) require.Empty(t, parsedId) @@ -123,7 +91,7 @@ func TestAccessTokenGenerationAndParsingWithDifferentPrivateKeys(t *testing.T) { } func TestAccessTokenIsRSAMethodWithPrivateKey(t *testing.T) { - ctx := newContextWithPK() + ctx := context.Background() identityId := ksuid.New() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) @@ -144,42 +112,17 @@ func TestAccessTokenIsRSAMethodWithPrivateKey(t *testing.T) { require.NoError(t, err) } -func TestAccessTokenHasDefaultExpiryClaims(t *testing.T) { - ctx := newContextWithPK() - identityId := ksuid.New() - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey) - require.NoError(t, err) - - jwtToken, lifespan, err := oauth.GenerateAccessToken(ctx, identityId.String()) - require.NoError(t, err) - require.NotEmpty(t, jwtToken) - - claims := &oauth.AccessTokenClaims{} - _, err = jwt.ParseWithClaims(jwtToken, claims, func(token *jwt.Token) (interface{}, error) { - return &privateKey.PublicKey, nil - }) - require.NoError(t, err) - - issuedAt := claims.IssuedAt.Time - expiry := claims.ExpiresAt.Time - - require.Greater(t, expiry, time.Now()) - require.Equal(t, issuedAt.Add(oauth.DefaultAccessTokenExpiry), expiry) - require.Equal(t, oauth.DefaultAccessTokenExpiry, lifespan) -} - -func TestAccessTokenHasCustomClaims(t *testing.T) { - ctx := newContextWithPK() +func TestAccessTokenClaims(t *testing.T) { + ctx := context.Background() identityId := ksuid.New() - ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ - Tokens: &config.TokensConfig{ - AccessTokenExpiry: 360, + seconds := 360 + config := config.AuthConfig{ + Tokens: config.TokensConfig{ + AccessTokenExpiry: &seconds, }, - }) + } + ctx = runtimectx.WithOAuthConfig(ctx, &config) privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) @@ -198,19 +141,25 @@ func TestAccessTokenHasCustomClaims(t *testing.T) { issuedAt := claims.IssuedAt.Time expiry := claims.ExpiresAt.Time + subject := claims.Subject + issuer := claims.Issuer require.Greater(t, expiry, time.Now()) require.Equal(t, issuedAt.Add(time.Second*360), expiry) require.Equal(t, time.Second*360, lifespan) + require.Equal(t, config.AccessTokenExpiry(), time.Second*360) + require.Equal(t, subject, identityId.String()) + require.Equal(t, issuer, "https://keel.so") } func TestShortExpiredAccessTokenIsInvalid(t *testing.T) { ctx := newContextWithPK() identityId := ksuid.New() + seconds := 1 ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ - Tokens: &config.TokensConfig{ - AccessTokenExpiry: 1, + Tokens: config.TokensConfig{ + AccessTokenExpiry: &seconds, }, }) @@ -236,7 +185,7 @@ func TestExpiredAccessTokenIsInvalid(t *testing.T) { require.NoError(t, err) // Create the jwt 1 second expired. - now := time.Now().UTC().Add(-oauth.DefaultAccessTokenExpiry).Add(time.Second * -1) + now := time.Now().UTC().Add(-config.DefaultAccessTokenExpiry).Add(time.Second * -1) claims := oauth.AccessTokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ Subject: identityId.String(), @@ -254,25 +203,3 @@ func TestExpiredAccessTokenIsInvalid(t *testing.T) { require.Empty(t, parsedId) require.Empty(t, iss) } - -func TestAccessTokenIssueClaimIsKeel(t *testing.T) { - ctx := newContextWithPK() - identityId := ksuid.New() - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, privateKey) - - bearerJwt, _, err := oauth.GenerateAccessToken(ctx, identityId.String()) - require.NoError(t, err) - require.NotEmpty(t, bearerJwt) - - claims := &oauth.AccessTokenClaims{} - _, err = jwt.ParseWithClaims(bearerJwt, claims, func(token *jwt.Token) (interface{}, error) { - return &privateKey.PublicKey, nil - }) - require.NoError(t, err) - - issuedAt := claims.Issuer - require.Equal(t, oauth.KeelIssuer, issuedAt) -} diff --git a/runtime/oauth/id_token.go b/runtime/oauth/id_token.go index 1650c91ba..158e0e9eb 100644 --- a/runtime/oauth/id_token.go +++ b/runtime/oauth/id_token.go @@ -66,7 +66,7 @@ func VerifyIdToken(ctx context.Context, idTokenRaw string) (*oidc.IDToken, error return nil, err } - providers, err := authConfig.GetProvidersOidcIssuer(issuer) + providers, err := authConfig.GetOidcProvidersByIssuer(issuer) if err != nil { return nil, err } @@ -84,7 +84,7 @@ func VerifyIdToken(ctx context.Context, idTokenRaw string) (*oidc.IDToken, error var verificationErrs error - // Verify against each configuired client ID for this issuer + // Verify against each configured provider with this issuer for _, p := range providers { // Checking the clientId during verification ensures that the ID token was intended for this client, // because it could have been stolen from any other application with an ID token from this same issuer. diff --git a/runtime/oauth/refresh_token.go b/runtime/oauth/refresh_token.go index c1c911313..f622331a3 100644 --- a/runtime/oauth/refresh_token.go +++ b/runtime/oauth/refresh_token.go @@ -13,8 +13,8 @@ import ( ) const ( - refreshTokenLength = 64 - DefaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90 // 3 months is the default + // Character length of crypo-generated refresh token + refreshTokenLength = 64 ) // NewRefreshToken generates a new refresh token for the identity using the @@ -38,19 +38,13 @@ func NewRefreshToken(ctx context.Context, identityId string) (string, error) { return "", err } - now := time.Now().UTC() - var expiresAt time.Time - - authConfig, err := runtimectx.GetOAuthConfig(ctx) + config, err := runtimectx.GetOAuthConfig(ctx) if err != nil { return "", err } - if authConfig != nil && authConfig.Tokens != nil && authConfig.Tokens.RefreshTokenExpiry != 0 { - expiresAt = now.Add(time.Duration(authConfig.Tokens.RefreshTokenExpiry) * time.Second) - } else { - expiresAt = now.Add(DefaultRefreshTokenExpiry) - } + now := time.Now().UTC() + expiresAt := now.Add(config.RefreshTokenExpiry()) sql := ` INSERT INTO @@ -94,7 +88,7 @@ func RotateRefreshToken(ctx context.Context, refreshTokenRaw string) (isValid bo } // This query has the following (important) characteristics: - // - find and delete the refresh token if it has not expired (the latter is for performance) + // - find and delete the refresh token // - create a new refresh token with the identity_id and expire_at of the original token // - only creates the new token if the original token had not expired sql := ` @@ -112,7 +106,7 @@ func RotateRefreshToken(ctx context.Context, refreshTokenRaw string) (isValid bo revoked_token WHERE expires_at >= now() - RETURNING *;` + RETURNING *` rows := []map[string]any{} err = database.GetDB().Raw(sql, tokenHash, newTokenHash).Scan(&rows).Error @@ -133,6 +127,50 @@ func RotateRefreshToken(ctx context.Context, refreshTokenRaw string) (isValid bo return true, newRefreshToken, identityId, nil } +// ValidateRefreshToken validates that the provided refresh token has no expired, +// and also returns the identity it is associated with. The refresh token is not revoked. +func ValidateRefreshToken(ctx context.Context, refreshTokenRaw string) (isValid bool, identityId string, err error) { + ctx, span := tracer.Start(ctx, "Validate Refresh Token") + defer span.End() + + tokenHash, err := hashToken(refreshTokenRaw) + if err != nil { + return false, "", err + } + + database, err := db.GetDatabase(ctx) + if err != nil { + return false, "", err + } + + sql := ` + SELECT + token, identity_id, expires_at, now() + FROM + keel_refresh_token + WHERE + token = ? AND + expires_at >= now()` + + rows := []map[string]any{} + err = database.GetDB().Raw(sql, tokenHash).Scan(&rows).Error + if err != nil { + return false, "", err + } + + // There was no refresh token found, and thus it is not valid + if len(rows) != 1 { + return false, "", nil + } + + identityId, ok := rows[0]["identity_id"].(string) + if !ok { + return false, "", errors.New("could not parse identity_id from database result") + } + + return true, identityId, nil +} + // RevokeRefreshToken will delete (revoke) the provided refresh token, // which will prevent it from being used again. func RevokeRefreshToken(ctx context.Context, refreshTokenRaw string) error { @@ -153,20 +191,13 @@ func RevokeRefreshToken(ctx context.Context, refreshTokenRaw string) error { DELETE FROM keel_refresh_token WHERE - token = ? - RETURNING *` + token = ?` - rows := []map[string]any{} - err = database.GetDB().Raw(sql, tokenHash).Scan(&rows).Error + err = database.GetDB().Exec(sql, tokenHash).Error if err != nil { return err } - // There was no refresh token found, and thus none to revoke. - if len(rows) == 0 { - return nil - } - return nil } diff --git a/runtime/oauth/refresh_token_test.go b/runtime/oauth/refresh_token_test.go index 1355d76db..77436097f 100644 --- a/runtime/oauth/refresh_token_test.go +++ b/runtime/oauth/refresh_token_test.go @@ -56,11 +56,13 @@ func TestRotateRefreshToken_Expired(t *testing.T) { defer database.Close() // Set up auth config - ctx = runtimectx.WithOAuthConfig(ctx, &config.AuthConfig{ - Tokens: &config.TokensConfig{ - RefreshTokenExpiry: 1, + seconds := 1 + config := config.AuthConfig{ + Tokens: config.TokensConfig{ + RefreshTokenExpiry: &seconds, }, - }) + } + ctx = runtimectx.WithOAuthConfig(ctx, &config) refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") require.NoError(t, err) @@ -94,6 +96,43 @@ func TestRotateRefreshToken_ReuseRefreshTokenNotValid(t *testing.T) { require.Empty(t, newRefreshToken2) } +func TestValidateRefreshToken_Valid(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + isValid, identityId, err := oauth.ValidateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.True(t, isValid) + require.Equal(t, "identity_id", identityId) +} + +func TestValidateRefreshToken_Expired(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + seconds := 1 + config := config.AuthConfig{ + Tokens: config.TokensConfig{ + RefreshTokenExpiry: &seconds, + }, + } + ctx = runtimectx.WithOAuthConfig(ctx, &config) + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + time.Sleep(1100 * time.Millisecond) + + isValid, identityId, err := oauth.ValidateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.False(t, isValid) + require.Empty(t, identityId) +} + func TestRevokeRefreshToken_Unauthorised(t *testing.T) { ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close()