From a35e78e364a26c4f87f37d9f545ef10b3ffa468a Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 27 Nov 2024 17:52:48 +0100 Subject: [PATCH] feat: handle concurrent refreshes and improve graceful refreshing This patch improves Ory Hydra's ability to deal with refresh flows which, for example, concurrently refresh the same token. Furthermore, graceful token refresh has been improved to handle a variety of edge cases and scenarios. --- .schema/config.schema.json | 12 +- aead/aead_test.go | 17 +- client/handler_test.go | 3 +- client/sdk_test.go | 8 +- client/validator_test.go | 17 +- cmd/cmd_helper_test.go | 3 +- consent/handler_test.go | 19 +- consent/sdk_test.go | 7 +- consent/strategy_logout_test.go | 3 +- consent/strategy_oauth_test.go | 3 +- consent/test/manager_test_helpers.go | 2 + cypress/integration/oauth2/refresh_token.js | 8 +- driver/config/provider.go | 12 +- driver/config/provider_test.go | 2 +- go.mod | 2 +- go.sum | 4 +- health/handler_test.go | 7 +- internal/{ => testhelpers}/driver.go | 18 +- internal/testhelpers/janitor_test_helper.go | 5 +- internal/testhelpers/oauth2.go | 19 +- jwk/handler_test.go | 7 +- jwk/helper_test.go | 5 +- jwk/jwt_strategy_test.go | 7 +- jwk/sdk_test.go | 7 +- ...TestHandlerWellKnown-hsm_enabled=true.json | 102 - oauth2/equalKeys.go | 55 - oauth2/equalKeys_test.go | 20 - ...elpers.go => fosite_store_helpers_test.go} | 654 +++-- oauth2/fosite_store_test.go | 56 +- oauth2/handler.go | 7 +- oauth2/handler_fallback_endpoints_test.go | 7 +- oauth2/handler_test.go | 20 +- oauth2/helper_test.go | 62 + oauth2/helpers.go | 51 + oauth2/introspector_test.go | 8 +- oauth2/oauth2_auth_code_bench_test.go | 3 +- oauth2/oauth2_auth_code_test.go | 2424 +++++++++-------- .../oauth2_client_credentials_bench_test.go | 3 +- oauth2/oauth2_client_credentials_test.go | 3 +- oauth2/oauth2_jwt_bearer_test.go | 3 +- oauth2/oauth2_refresh_token_test.go | 9 +- oauth2/oauth2_rop_test.go | 3 +- oauth2/revocator_test.go | 8 +- oauth2/session_custom_claims_test.go | 5 +- oauth2/trust/handler_test.go | 9 +- persistence/sql/migratest/migration_test.go | 5 +- ...oken_access_token_link.autocommit.down.sql | 1 + ..._token_access_token_link.autocommit.up.sql | 1 + persistence/sql/persister.go | 1 - persistence/sql/persister_nid_test.go | 134 +- persistence/sql/persister_nonce_test.go | 5 +- persistence/sql/persister_oauth2.go | 210 +- persistence/sql/persister_test.go | 13 +- spec/config.json | 4 +- x/oauth2cors/cors_test.go | 7 +- x/tls_termination_test.go | 9 +- 56 files changed, 2289 insertions(+), 1810 deletions(-) rename internal/{ => testhelpers}/driver.go (84%) delete mode 100644 oauth2/.snapshots/TestHandlerWellKnown-hsm_enabled=true.json delete mode 100644 oauth2/equalKeys.go delete mode 100644 oauth2/equalKeys_test.go rename oauth2/{fosite_store_helpers.go => fosite_store_helpers_test.go} (65%) create mode 100644 oauth2/helpers.go create mode 100644 persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql create mode 100644 persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql diff --git a/.schema/config.schema.json b/.schema/config.schema.json index bc1d1476c08..804e6b6024f 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -1101,11 +1101,11 @@ "examples": ["https://my-example.app/token-refresh-hook"], "oneOf": [ { - "type": "string", - "format": "uri" + "$ref": "#/definitions/webhook_config" }, { - "$ref": "#/definitions/webhook_config" + "type": "string", + "format": "uri" } ] }, @@ -1114,11 +1114,11 @@ "examples": ["https://my-example.app/token-hook"], "oneOf": [ { - "type": "string", - "format": "uri" + "$ref": "#/definitions/webhook_config" }, { - "$ref": "#/definitions/webhook_config" + "type": "string", + "format": "uri" } ] } diff --git a/aead/aead_test.go b/aead/aead_test.go index 4cb93f5c3e7..d1b614710a2 100644 --- a/aead/aead_test.go +++ b/aead/aead_test.go @@ -10,13 +10,14 @@ import ( "io" "testing" - "github.com/ory/hydra/v2/aead" - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/internal/testhelpers" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/ory/hydra/v2/aead" + "github.com/ory/hydra/v2/driver/config" ) func secret(t *testing.T) string { @@ -43,7 +44,7 @@ func TestAEAD(t *testing.T) { t.Run("case=without-rotation", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) @@ -63,7 +64,7 @@ func TestAEAD(t *testing.T) { t.Run("case=wrong-secret", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) @@ -78,7 +79,7 @@ func TestAEAD(t *testing.T) { t.Run("case=with-rotation", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() old := secret(t) c.MustSet(ctx, config.KeyGetSystemSecret, []string{old}) a := NewCipher(c) @@ -106,7 +107,7 @@ func TestAEAD(t *testing.T) { t.Run("case=with-rotation-wrong-secret", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) @@ -123,7 +124,7 @@ func TestAEAD(t *testing.T) { t.Run("suite=with additional data", func(t *testing.T) { t.Parallel() ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) a := NewCipher(c) diff --git a/client/handler_test.go b/client/handler_test.go index 3047ad4c87b..8e27caea754 100644 --- a/client/handler_test.go +++ b/client/handler_test.go @@ -35,7 +35,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/client" - "github.com/ory/hydra/v2/internal" ) type responseSnapshot struct { @@ -56,7 +55,7 @@ func getClientID(body string) string { func TestHandler(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) h := client.NewHandler(reg) reg.WithContextualizer(&contextx.TestContextualizer{}) diff --git a/client/sdk_test.go b/client/sdk_test.go index 9db7ab7cddb..ad3193108ad 100644 --- a/client/sdk_test.go +++ b/client/sdk_test.go @@ -9,6 +9,8 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/assertx" "github.com/ory/x/ioutilx" @@ -26,8 +28,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/internal" - hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" ) @@ -63,11 +63,11 @@ var defaultIgnoreFields = []string{"client_id", "registration_access_token", "re func TestClientSDK(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeySubjectTypesSupported, []string{"public"}) conf.MustSet(ctx, config.KeyDefaultClientScope, []string{"foo", "bar"}) conf.MustSet(ctx, config.KeyPublicAllowDynamicRegistration, true) - r := internal.NewRegistryMemory(t, conf, &contextx.Static{C: conf.Source(ctx)}) + r := testhelpers.NewRegistryMemory(t, conf, &contextx.Static{C: conf.Source(ctx)}) routerAdmin := x.NewRouterAdmin(conf.AdminURL) routerPublic := x.NewRouterPublic() diff --git a/client/validator_test.go b/client/validator_test.go index 09f69b26e30..4efe866d5a9 100644 --- a/client/validator_test.go +++ b/client/validator_test.go @@ -12,6 +12,8 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/hashicorp/go-retryablehttp" "github.com/ory/fosite" @@ -24,17 +26,16 @@ import ( . "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" ) func TestValidate(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeySubjectTypesSupported, []string{"pairwise", "public"}) c.MustSet(ctx, config.KeyDefaultClientScope, []string{"openid"}) - reg := internal.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) + reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) v := NewValidator(reg) testCtx := context.TODO() @@ -186,7 +187,7 @@ func (f *fakeHTTP) HTTPClient(ctx context.Context, opts ...httpx.ResilientOption } func TestValidateSectorIdentifierURL(t *testing.T) { - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) var payload string var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { @@ -268,8 +269,8 @@ const validJWKS = ` func TestValidateIPRanges(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) + c := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) v := NewValidator(reg) c.MustSet(ctx, config.KeyClientHTTPNoPrivateIPRanges, true) @@ -287,10 +288,10 @@ func TestValidateIPRanges(t *testing.T) { func TestValidateDynamicRegistration(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() c.MustSet(ctx, config.KeySubjectTypesSupported, []string{"pairwise", "public"}) c.MustSet(ctx, config.KeyDefaultClientScope, []string{"openid"}) - reg := internal.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) + reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)}) testCtx := context.TODO() v := NewValidator(reg) diff --git a/cmd/cmd_helper_test.go b/cmd/cmd_helper_test.go index da386b4865d..4953f6b4321 100644 --- a/cmd/cmd_helper_test.go +++ b/cmd/cmd_helper_test.go @@ -19,7 +19,6 @@ import ( "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/x/cmdx" "github.com/ory/x/contextx" @@ -40,7 +39,7 @@ func setupRoutes(t *testing.T, cmd *cobra.Command) (*httptest.Server, *httptest. ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) public, admin := testhelpers.NewOAuth2Server(ctx, t, reg) cmdx.RegisterHTTPClientFlags(cmd.Flags()) diff --git a/consent/handler_test.go b/consent/handler_test.go index d5dfe5254ad..45ba2b7733a 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -13,13 +13,14 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/require" hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" . "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/x/pointerx" @@ -42,8 +43,8 @@ func TestGetLogoutRequest(t *testing.T) { challenge := "challenge" + key requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) if tc.exists { cl := &client.Client{ID: "client" + key} @@ -97,8 +98,8 @@ func TestGetLoginRequest(t *testing.T) { challenge := "challenge" + key requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) if tc.exists { cl := &client.Client{ID: "client" + key} @@ -163,8 +164,8 @@ func TestGetConsentRequest(t *testing.T) { challenge := "challenge" + key requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) if tc.exists { cl := &client.Client{ID: "client" + key} @@ -238,8 +239,8 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { challenge := "challenge" requestURL := "http://192.0.2.1" - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) cl := &client.Client{ID: "client"} require.NoError(t, reg.ClientManager().CreateClient(ctx, cl)) diff --git a/consent/sdk_test.go b/consent/sdk_test.go index f749428d5d8..0f30d16e7c8 100644 --- a/consent/sdk_test.go +++ b/consent/sdk_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/hydra/v2/consent/test" hydra "github.com/ory/hydra-client-go/v2" @@ -23,7 +25,6 @@ import ( . "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" ) @@ -35,10 +36,10 @@ func makeID(base string, network string, key string) string { func TestSDK(t *testing.T) { ctx := context.Background() network := "t1" - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyIssuerURL, "https://www.ory.sh") conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) consentChallenge := func(f *Flow) string { return x.Must(f.ToConsentChallenge(ctx, reg)) } consentVerifier := func(f *Flow) string { return x.Must(f.ToConsentVerifier(ctx, reg)) } diff --git a/consent/strategy_logout_test.go b/consent/strategy_logout_test.go index 6432a3e13a0..80e633e7bf6 100644 --- a/consent/strategy_logout_test.go +++ b/consent/strategy_logout_test.go @@ -28,7 +28,6 @@ import ( hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/x/contextx" "github.com/ory/x/ioutilx" @@ -37,7 +36,7 @@ import ( func TestLogoutFlows(t *testing.T) { ctx := context.Background() fakeKratos := kratos.NewFake() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour) diff --git a/consent/strategy_oauth_test.go b/consent/strategy_oauth_test.go index 370a3378074..a2e39d5b6ec 100644 --- a/consent/strategy_oauth_test.go +++ b/consent/strategy_oauth_test.go @@ -37,12 +37,11 @@ import ( hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" ) func TestStrategyLoginConsentNext(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour) reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour) diff --git a/consent/test/manager_test_helpers.go b/consent/test/manager_test_helpers.go index a5b141f5359..986b4f3144c 100644 --- a/consent/test/manager_test_helpers.go +++ b/consent/test/manager_test_helpers.go @@ -683,6 +683,7 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo require.NoError(t, fositeManager.CreateRefreshTokenSession( ctx, makeID("", network, "rrva1"), + "", &fosite.Request{Client: cr1.Client, ID: crr1.ID, RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}}, )) require.NoError(t, fositeManager.CreateAccessTokenSession( @@ -693,6 +694,7 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo require.NoError(t, fositeManager.CreateRefreshTokenSession( ctx, makeID("", network, "rrva2"), + "", &fosite.Request{Client: cr2.Client, ID: crr2.ID, RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}}, )) diff --git a/cypress/integration/oauth2/refresh_token.js b/cypress/integration/oauth2/refresh_token.js index fbbbf36e80b..2ddf7d30f19 100644 --- a/cypress/integration/oauth2/refresh_token.js +++ b/cypress/integration/oauth2/refresh_token.js @@ -87,13 +87,13 @@ describe("The OAuth 2.0 Refresh Token Grant", function () { return cy .refreshTokenBrowser(client, originalToken) .then((response) => { - expect(response.status).to.eq(401) - expect(response.body.error).to.eq("token_inactive") + expect(response.status).to.eq(400) + expect(response.body.error).to.eq("invalid_grant") }) .then(() => cy.refreshTokenBrowser(client, refreshedToken)) .then((response) => { - expect(response.status).to.eq(401) - expect(response.body.error).to.eq("token_inactive") + expect(response.status).to.eq(400) + expect(response.body.error).to.eq("invalid_grant") }) }, ) diff --git a/driver/config/provider.go b/driver/config/provider.go index 52b9ee45a3f..b02d0ae1da4 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -213,6 +213,10 @@ func (p *DefaultProvider) MustSet(ctx context.Context, key string, value interfa } } +func (p *DefaultProvider) Delete(ctx context.Context, key string) { + p.getProvider(ctx).Delete(key) +} + func (p *DefaultProvider) Source(ctx context.Context) *configx.Provider { return p.getProvider(ctx) } @@ -517,6 +521,10 @@ type ( ) func (p *DefaultProvider) getHookConfig(ctx context.Context, key string) *HookConfig { + if p.getProvider(ctx).String(key) == "" { + return nil + } + if hookURL := p.getProvider(ctx).RequestURIF(key, nil); hookURL != nil { return &HookConfig{ URL: hookURL.String(), @@ -673,8 +681,8 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { gracePeriod := p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) - if gracePeriod > time.Hour { - return time.Hour + if gracePeriod > time.Minute*5 { + return time.Minute * 5 } return gracePeriod } diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index 168ca81d69f..7ec1dce8df9 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -296,7 +296,7 @@ func TestViperProviderValidates(t *testing.T) { require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s")) assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod(ctx)) require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h")) - assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod(ctx)) + assert.Equal(t, time.Minute*5, c.RefreshTokenRotationGracePeriod(ctx)) // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx)) diff --git a/go.mod b/go.mod index 0c9b9277cdf..210341d99d0 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/oleiade/reflections v1.0.1 github.com/ory/analytics-go/v5 v5.0.1 - github.com/ory/fosite v0.48.0 + github.com/ory/fosite v0.48.1-0.20241204100720-b57570a26c3e github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe github.com/ory/graceful v0.1.3 github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88 diff --git a/go.sum b/go.sum index c29c141d383..4761d71ae8e 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,8 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.48.0 h1:zxNPNrCBsFwujviVPhbHZzSHZNzjBFZ36MeBFz6tCuU= -github.com/ory/fosite v0.48.0/go.mod h1:M+C+Ng1UDNgwX4SaErnuZwEw26uDN7I3kNUt0WyValI= +github.com/ory/fosite v0.48.1-0.20241204100720-b57570a26c3e h1:C55B0tN1yuintGQ0N+nTnFlrHlxidM3vagM/+7xQrio= +github.com/ory/fosite v0.48.1-0.20241204100720-b57570a26c3e/go.mod h1:M+C+Ng1UDNgwX4SaErnuZwEw26uDN7I3kNUt0WyValI= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= diff --git a/health/handler_test.go b/health/handler_test.go index 4b717a02c79..b7821d3cab4 100644 --- a/health/handler_test.go +++ b/health/handler_test.go @@ -9,6 +9,8 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/ory/x/contextx" @@ -16,7 +18,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/healthx" ) @@ -71,12 +72,12 @@ func TestPublicHealthHandler(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() for k, v := range tc.config { conf.MustSet(ctx, config.PublicInterface.Key(k), v) } - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) public := x.NewRouterPublic() reg.RegisterRoutes(ctx, x.NewRouterAdmin(conf.AdminURL), public) diff --git a/internal/driver.go b/internal/testhelpers/driver.go similarity index 84% rename from internal/driver.go rename to internal/testhelpers/driver.go index 38a8d8144d4..34a3f40b8bd 100644 --- a/internal/driver.go +++ b/internal/testhelpers/driver.go @@ -1,13 +1,15 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package internal +package testhelpers import ( "context" "sync" "testing" + "github.com/ory/x/dbal" + "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/require" @@ -44,24 +46,28 @@ func NewConfigurationWithDefaultsAndHTTPS() *config.DefaultProvider { } func NewRegistryMemory(t testing.TB, c *config.DefaultProvider, ctxer contextx.Contextualizer) driver.Registry { - return newRegistryDefault(t, "memory", c, true, ctxer) + return registryFactory(t, dbal.NewSQLiteTestDatabase(t), c, true, ctxer) } func NewMockedRegistry(t testing.TB, ctxer contextx.Contextualizer) driver.Registry { - return newRegistryDefault(t, "memory", NewConfigurationWithDefaults(), true, ctxer) + return registryFactory(t, dbal.NewSQLiteTestDatabase(t), NewConfigurationWithDefaults(), true, ctxer) } func NewRegistrySQLFromURL(t testing.TB, url string, migrate bool, ctxer contextx.Contextualizer) driver.Registry { - return newRegistryDefault(t, url, NewConfigurationWithDefaults(), migrate, ctxer) + return registryFactory(t, url, NewConfigurationWithDefaults(), migrate, ctxer) +} + +func registryFactory(t testing.TB, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry { + return RegistryFactory(t, url, c, !migrate, migrate, ctxer) } -func newRegistryDefault(t testing.TB, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry { +func RegistryFactory(t testing.TB, url string, c *config.DefaultProvider, networkInit, migrate bool, ctxer contextx.Contextualizer) driver.Registry { ctx := context.Background() c.MustSet(ctx, config.KeyLogLevel, "trace") c.MustSet(ctx, config.KeyDSN, url) c.MustSet(ctx, "dev", true) - r, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("test_hydra", "master"), false, migrate, ctxer) + r, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("test_hydra", "master"), networkInit, migrate, ctxer) require.NoError(t, err) return r diff --git a/internal/testhelpers/janitor_test_helper.go b/internal/testhelpers/janitor_test_helper.go index f70d7c27495..c452b3248f1 100644 --- a/internal/testhelpers/janitor_test_helper.go +++ b/internal/testhelpers/janitor_test_helper.go @@ -21,7 +21,6 @@ import ( "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" "github.com/ory/hydra/v2/x" @@ -50,7 +49,7 @@ type createGrantRequest struct { const lifespan = time.Hour func NewConsentJanitorTestHelper(uniqueName string) *JanitorConsentTestHelper { - conf := internal.NewConfigurationWithDefaults() + conf := NewConfigurationWithDefaults() conf.MustSet(context.Background(), config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(context.Background(), config.KeyIssuerURL, "http://hydra.localhost") conf.MustSet(context.Background(), config.KeyAccessTokenLifespan, lifespan) @@ -126,7 +125,7 @@ func (j *JanitorConsentTestHelper) RefreshTokenNotAfterSetup(ctx context.Context // Create refresh token clients and session for _, fr := range j.flushRefreshRequests { require.NoError(t, cl.CreateClient(ctx, fr.Client.(*client.Client))) - require.NoError(t, store.CreateRefreshTokenSession(ctx, fr.ID, fr)) + require.NoError(t, store.CreateRefreshTokenSession(ctx, fr.ID, "", fr)) } } } diff --git a/internal/testhelpers/oauth2.go b/internal/testhelpers/oauth2.go index 41f0ddaec8e..4a7b5bc696e 100644 --- a/internal/testhelpers/oauth2.go +++ b/internal/testhelpers/oauth2.go @@ -32,7 +32,6 @@ import ( "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) @@ -67,8 +66,8 @@ func NewOAuth2Server(ctx context.Context, t testing.TB, reg driver.Registry) (pu public, admin := x.NewRouterPublic(), x.NewRouterAdmin(reg.Config().AdminURL) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) reg.RegisterRoutes(ctx, admin, public) @@ -111,6 +110,20 @@ func IntrospectToken(t testing.TB, conf *oauth2.Config, token string, adminTS *h return gjson.ParseBytes(ioutilx.MustReadAll(res.Body)) } +func RevokeToken(t testing.TB, conf *oauth2.Config, token string, publicTS *httptest.Server) gjson.Result { + require.NotEmpty(t, token) + + req := httpx.MustNewRequest("POST", publicTS.URL+"/oauth2/revoke", + strings.NewReader((url.Values{"token": {token}}).Encode()), + "application/x-www-form-urlencoded") + + req.SetBasicAuth(conf.ClientID, conf.ClientSecret) + res, err := publicTS.Client().Do(req) + require.NoError(t, err) + defer res.Body.Close() + return gjson.ParseBytes(ioutilx.MustReadAll(res.Body)) +} + func UpdateClientTokenLifespans(t *testing.T, conf *oauth2.Config, clientID string, lifespans client.Lifespans, adminTS *httptest.Server) { b, err := json.Marshal(lifespans) require.NoError(t, err) diff --git a/jwk/handler_test.go b/jwk/handler_test.go index 5df8182de60..0dc8f6afcdc 100644 --- a/jwk/handler_test.go +++ b/jwk/handler_test.go @@ -10,6 +10,8 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/httprouterx" "github.com/ory/hydra/v2/jwk" @@ -20,15 +22,14 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) func TestHandlerWellKnown(t *testing.T) { t.Parallel() - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) conf.MustSet(context.Background(), config.KeyWellKnownKeys, []string{x.OpenIDConnectKeyName, x.OpenIDConnectKeyName}) router := x.NewRouterPublic() h := reg.KeyHandler() diff --git a/jwk/helper_test.go b/jwk/helper_test.go index c1a5ee46387..5a6dabd6a60 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -17,6 +17,8 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + hydra "github.com/ory/hydra-client-go/v2" "github.com/go-jose/go-jose/v3" @@ -27,7 +29,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" @@ -210,7 +211,7 @@ func TestExcludeOpaquePrivateKeys(t *testing.T) { func TestGetOrGenerateKeys(t *testing.T) { t.Parallel() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) setId := uuid.NewUUID().String() keyId := uuid.NewUUID().String() diff --git a/jwk/jwt_strategy_test.go b/jwk/jwt_strategy_test.go index 8389d20a610..b4def161005 100644 --- a/jwk/jwt_strategy_test.go +++ b/jwk/jwt_strategy_test.go @@ -9,12 +9,13 @@ import ( "strings" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "github.com/ory/fosite/token/jwt" - "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/jwk" "github.com/ory/x/contextx" ) @@ -22,8 +23,8 @@ import ( func TestJWTStrategy(t *testing.T) { for _, alg := range []string{"RS256", "ES256", "ES512"} { t.Run("case="+alg, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) m := reg.KeyManager() _, err := m.GenerateAndPersistKeySet(context.Background(), "foo-set", "foo", alg, "sig") diff --git a/jwk/sdk_test.go b/jwk/sdk_test.go index f7f7d6a21e8..b2088239884 100644 --- a/jwk/sdk_test.go +++ b/jwk/sdk_test.go @@ -9,12 +9,13 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" @@ -23,8 +24,8 @@ import ( func TestJWKSDK(t *testing.T) { t.Parallel() ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) router := x.NewRouterAdmin(conf.AdminURL) h := NewHandler(reg) diff --git a/oauth2/.snapshots/TestHandlerWellKnown-hsm_enabled=true.json b/oauth2/.snapshots/TestHandlerWellKnown-hsm_enabled=true.json deleted file mode 100644 index 5bc92ec79a5..00000000000 --- a/oauth2/.snapshots/TestHandlerWellKnown-hsm_enabled=true.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "authorization_endpoint": "http://hydra.localhost/oauth2/auth", - "backchannel_logout_session_supported": true, - "backchannel_logout_supported": true, - "claims_parameter_supported": false, - "claims_supported": [ - "sub" - ], - "code_challenge_methods_supported": [ - "plain", - "S256" - ], - "credentials_endpoint_draft_00": "http://hydra.localhost/credentials", - "credentials_supported_draft_00": [ - { - "cryptographic_binding_methods_supported": [ - "jwk" - ], - "cryptographic_suites_supported": [ - "PS256", - "RS256", - "ES256", - "PS384", - "RS384", - "ES384", - "PS512", - "RS512", - "ES512", - "EdDSA" - ], - "format": "jwt_vc_json", - "types": [ - "VerifiableCredential", - "UserInfoCredential" - ] - } - ], - "end_session_endpoint": "http://hydra.localhost/oauth2/sessions/logout", - "frontchannel_logout_session_supported": true, - "frontchannel_logout_supported": true, - "grant_types_supported": [ - "authorization_code", - "implicit", - "client_credentials", - "refresh_token" - ], - "id_token_signed_response_alg": [ - "RS256" - ], - "id_token_signing_alg_values_supported": [ - "RS256" - ], - "issuer": "http://hydra.localhost", - "jwks_uri": "http://hydra.localhost/.well-known/jwks.json", - "registration_endpoint": "http://client-register/registration", - "request_object_signing_alg_values_supported": [ - "none", - "RS256", - "ES256" - ], - "request_parameter_supported": true, - "request_uri_parameter_supported": true, - "require_request_uri_registration": true, - "response_modes_supported": [ - "query", - "fragment", - "form_post" - ], - "response_types_supported": [ - "code", - "code id_token", - "id_token", - "token id_token", - "token", - "token id_token code" - ], - "revocation_endpoint": "http://hydra.localhost/oauth2/revoke", - "scopes_supported": [ - "offline_access", - "offline", - "openid" - ], - "subject_types_supported": [ - "pairwise", - "public" - ], - "token_endpoint": "http://hydra.localhost/oauth2/token", - "token_endpoint_auth_methods_supported": [ - "client_secret_post", - "client_secret_basic", - "private_key_jwt", - "none" - ], - "userinfo_endpoint": "/userinfo", - "userinfo_signed_response_alg": [ - "RS256" - ], - "userinfo_signing_alg_values_supported": [ - "none", - "RS256" - ] -} diff --git a/oauth2/equalKeys.go b/oauth2/equalKeys.go deleted file mode 100644 index e16568e078a..00000000000 --- a/oauth2/equalKeys.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package oauth2 - -import ( - "testing" - - "github.com/oleiade/reflections" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func AssertObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - assert.Equal(t, c, d, "%s", k) - } -} - -func AssertObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - assert.NotEqual(t, c, d, "%s", k) - } -} - -func RequireObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - require.Equal(t, c, d, "%s", k) - } -} -func RequireObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { - assert.True(t, len(keys) > 0, "No keys provided.") - for _, k := range keys { - c, err := reflections.GetField(a, k) - assert.Nil(t, err) - d, err := reflections.GetField(b, k) - assert.Nil(t, err) - require.NotEqual(t, c, d, "%s", k) - } -} diff --git a/oauth2/equalKeys_test.go b/oauth2/equalKeys_test.go deleted file mode 100644 index 13243a94bf3..00000000000 --- a/oauth2/equalKeys_test.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package oauth2 - -import "testing" - -func TestAssertObjectsAreEqualByKeys(t *testing.T) { - type foo struct { - Name string - Body int - } - a := &foo{"foo", 1} - b := &foo{"bar", 1} - c := &foo{"baz", 3} - - AssertObjectKeysEqual(t, a, a, "Name", "Body") - AssertObjectKeysNotEqual(t, a, b, "Name") - AssertObjectKeysNotEqual(t, a, c, "Name", "Body") -} diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers_test.go similarity index 65% rename from oauth2/fosite_store_helpers.go rename to oauth2/fosite_store_helpers_test.go index 553a6bae62b..739c432a2ca 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers_test.go @@ -1,85 +1,41 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package oauth2 +package oauth2_test import ( "context" - "crypto/sha256" "fmt" "net/url" "slices" "testing" "time" - "github.com/ory/x/assertx" - - "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/jwk" + "github.com/ory/hydra/v2/persistence/sql" "github.com/go-jose/go-jose/v3" - "github.com/gobuffalo/pop/v6" - "github.com/pborman/uuid" - - "github.com/ory/fosite/handler/rfc7523" - - "github.com/ory/hydra/v2/oauth2/trust" - - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/x" - - "github.com/ory/fosite/storage" - "github.com/ory/x/sqlxx" - gofrsuuid "github.com/gofrs/uuid" + "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" - "github.com/ory/x/sqlcon" - + "github.com/ory/fosite/handler/rfc7523" + "github.com/ory/fosite/storage" "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/jwk" + "github.com/ory/hydra/v2/oauth2" + "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/hydra/v2/x" + "github.com/ory/x/assertx" + "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" ) -func signatureFromJTI(jti string) string { - return fmt.Sprintf("%x", sha256.Sum256([]byte(jti))) -} - -type BlacklistedJTI struct { - JTI string `db:"-"` - ID string `db:"signature"` - Expiry time.Time `db:"expires_at"` - NID gofrsuuid.UUID `db:"nid"` -} - -func (j *BlacklistedJTI) AfterFind(_ *pop.Connection) error { - j.Expiry = j.Expiry.UTC() - return nil -} - -func (BlacklistedJTI) TableName() string { - return "hydra_oauth2_jti_blacklist" -} - -func NewBlacklistedJTI(jti string, exp time.Time) *BlacklistedJTI { - return &BlacklistedJTI{ - JTI: jti, - ID: signatureFromJTI(jti), - // because the database timestamp types are not as accurate as time.Time we truncate to seconds (which should always work) - Expiry: exp.UTC().Truncate(time.Second), - } -} - -type AssertionJWTReader interface { - x.FositeStorer - - GetClientAssertionJWT(ctx context.Context, jti string) (*BlacklistedJTI, error) - - SetClientAssertionJWTRaw(context.Context, *BlacklistedJTI) error -} - var defaultIgnoreKeys = []string{ "id", "session", @@ -94,29 +50,33 @@ var defaultIgnoreKeys = []string{ "client.client_secret", } -var defaultRequest = fosite.Request{ - ID: "blank", - RequestedAt: time.Now().UTC().Round(time.Second), - Client: &client.Client{ - ID: "foobar", - Contacts: []string{}, - RedirectURIs: []string{}, - Audience: []string{}, - AllowedCORSOrigins: []string{}, - ResponseTypes: []string{}, - GrantTypes: []string{}, - JSONWebKeys: &x.JoseJSONWebKeySet{}, - Metadata: sqlxx.JSONRawMessage("{}"), - }, - RequestedScope: fosite.Arguments{"fa", "ba"}, - GrantedScope: fosite.Arguments{"fa", "ba"}, - RequestedAudience: fosite.Arguments{"ad1", "ad2"}, - GrantedAudience: fosite.Arguments{"ad1", "ad2"}, - Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: NewSession("bar"), +func newDefaultRequest(id string) fosite.Request { + return fosite.Request{ + ID: id, + RequestedAt: time.Now().UTC().Round(time.Second), + Client: &client.Client{ + ID: "foobar", + Contacts: []string{}, + RedirectURIs: []string{}, + Audience: []string{}, + AllowedCORSOrigins: []string{}, + ResponseTypes: []string{}, + GrantTypes: []string{}, + JSONWebKeys: &x.JoseJSONWebKeySet{}, + Metadata: sqlxx.JSONRawMessage("{}"), + }, + RequestedScope: fosite.Arguments{"fa", "ba"}, + GrantedScope: fosite.Arguments{"fa", "ba"}, + RequestedAudience: fosite.Arguments{"ad1", "ad2"}, + GrantedAudience: fosite.Arguments{"ad1", "ad2"}, + Form: url.Values{"foo": []string{"bar", "baz"}}, + Session: oauth2.NewSession("bar"), + } } -var lifespan = time.Hour +var defaultRequest = newDefaultRequest("blank") + +// var lifespan = time.Hour var flushRequests = []*fosite.Request{ { ID: "flush-1", @@ -125,7 +85,7 @@ var flushRequests = []*fosite.Request{ RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, { ID: "flush-2", @@ -134,7 +94,7 @@ var flushRequests = []*fosite.Request{ RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, { ID: "flush-3", @@ -143,11 +103,11 @@ var flushRequests = []*fosite.Request{ RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, } -func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry) { +func mockRequestForeignKey(t *testing.T, id string, x oauth2.InternalRegistry) { cl := &client.Client{ID: "foobar"} cr := &flow.OAuth2ConsentRequest{ Client: cl, @@ -193,43 +153,10 @@ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry) { require.NoError(t, err) } -// TestHelperRunner is used to run the database suite of tests in this package. -// KEEP EXPORTED AND AVAILABLE FOR THIRD PARTIES TO TEST PLUGINS! -func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { - t.Helper() - if k != "memory" { - t.Run(fmt.Sprintf("case=testHelperUniqueConstraints/db=%s", k), testHelperRequestIDMultiples(store, k)) - t.Run("case=testFositeSqlStoreTransactionsCommitAccessToken", testFositeSqlStoreTransactionCommitAccessToken(store)) - t.Run("case=testFositeSqlStoreTransactionsRollbackAccessToken", testFositeSqlStoreTransactionRollbackAccessToken(store)) - t.Run("case=testFositeSqlStoreTransactionCommitRefreshToken", testFositeSqlStoreTransactionCommitRefreshToken(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackRefreshToken", testFositeSqlStoreTransactionRollbackRefreshToken(store)) - t.Run("case=testFositeSqlStoreTransactionCommitAuthorizeCode", testFositeSqlStoreTransactionCommitAuthorizeCode(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackAuthorizeCode", testFositeSqlStoreTransactionRollbackAuthorizeCode(store)) - t.Run("case=testFositeSqlStoreTransactionCommitPKCERequest", testFositeSqlStoreTransactionCommitPKCERequest(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackPKCERequest", testFositeSqlStoreTransactionRollbackPKCERequest(store)) - t.Run("case=testFositeSqlStoreTransactionCommitOpenIdConnectSession", testFositeSqlStoreTransactionCommitOpenIdConnectSession(store)) - t.Run("case=testFositeSqlStoreTransactionRollbackOpenIdConnectSession", testFositeSqlStoreTransactionRollbackOpenIdConnectSession(store)) - - } - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAuthorizeCodes/db=%s", k), testHelperCreateGetDeleteAuthorizeCodes(store)) - t.Run(fmt.Sprintf("case=testHelperExpiryFields/db=%s", k), testHelperExpiryFields(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAccessTokenSession/db=%s", k), testHelperCreateGetDeleteAccessTokenSession(store)) - t.Run(fmt.Sprintf("case=testHelperNilAccessToken/db=%s", k), testHelperNilAccessToken(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteOpenIDConnectSession/db=%s", k), testHelperCreateGetDeleteOpenIDConnectSession(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteRefreshTokenSession/db=%s", k), testHelperCreateGetDeleteRefreshTokenSession(store)) - t.Run(fmt.Sprintf("case=testHelperRevokeRefreshToken/db=%s", k), testHelperRevokeRefreshToken(store)) - t.Run(fmt.Sprintf("case=testHelperCreateGetDeletePKCERequestSession/db=%s", k), testHelperCreateGetDeletePKCERequestSession(store)) - t.Run(fmt.Sprintf("case=testHelperFlushTokens/db=%s", k), testHelperFlushTokens(store, time.Hour)) - t.Run(fmt.Sprintf("case=testHelperFlushTokensWithLimitAndBatchSize/db=%s", k), testHelperFlushTokensWithLimitAndBatchSize(store, 3, 2)) - t.Run(fmt.Sprintf("case=testFositeStoreSetClientAssertionJWT/db=%s", k), testFositeStoreSetClientAssertionJWT(store)) - t.Run(fmt.Sprintf("case=testFositeStoreClientAssertionJWTValid/db=%s", k), testFositeStoreClientAssertionJWTValid(store)) - t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) - t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) - t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) - t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) +func TestHelperRunner(t *testing.T) { } -func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { +func testHelperRequestIDMultiples(m oauth2.InternalRegistry, _ string) func(t *testing.T) { return func(t *testing.T) { ctx := context.Background() requestID := uuid.New() @@ -240,12 +167,13 @@ func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing. ID: requestID, Client: cl, RequestedAt: time.Now().UTC().Round(time.Second), - Session: NewSession("bar"), + Session: oauth2.NewSession("bar"), } for i := 0; i < 4; i++ { signature := uuid.New() - err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, fositeRequest) + accessSignature := uuid.New() + err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, accessSignature, fositeRequest) assert.NoError(t, err) err = m.OAuth2Storage().CreateAccessTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) @@ -259,58 +187,60 @@ func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing. } } -func testHelperCreateGetDeleteOpenIDConnectSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteOpenIDConnectSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetOpenIDConnectSession(ctx, "4321", &fosite.Request{Session: NewSession("bar")}) + _, err := m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewSession("bar")}) assert.NotNil(t, err) - err = m.CreateOpenIDConnectSession(ctx, "4321", &defaultRequest) + err = m.CreateOpenIDConnectSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err := m.GetOpenIDConnectSession(ctx, "4321", &fosite.Request{Session: NewSession("bar")}) + res, err := m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewSession("bar")}) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeleteOpenIDConnectSession(ctx, "4321") + err = m.DeleteOpenIDConnectSession(ctx, code) require.NoError(t, err) - _, err = m.GetOpenIDConnectSession(ctx, "4321", &fosite.Request{Session: NewSession("bar")}) + _, err = m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewSession("bar")}) assert.NotNil(t, err) } } -func testHelperCreateGetDeleteRefreshTokenSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteRefreshTokenSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetRefreshTokenSession(ctx, "4321", NewSession("bar")) + _, err := m.GetRefreshTokenSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) - err = m.CreateRefreshTokenSession(ctx, "4321", &defaultRequest) + err = m.CreateRefreshTokenSession(ctx, code, "", &defaultRequest) require.NoError(t, err) - res, err := m.GetRefreshTokenSession(ctx, "4321", NewSession("bar")) + res, err := m.GetRefreshTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeleteRefreshTokenSession(ctx, "4321") + err = m.DeleteRefreshTokenSession(ctx, code) require.NoError(t, err) - _, err = m.GetRefreshTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetRefreshTokenSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) } } -func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { +func testHelperRevokeRefreshToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := context.Background() - _, err := m.GetRefreshTokenSession(ctx, "1111", NewSession("bar")) + _, err := m.GetRefreshTokenSession(ctx, "1111", oauth2.NewSession("bar")) assert.Error(t, err) reqIdOne := uuid.New() @@ -319,23 +249,23 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { mockRequestForeignKey(t, reqIdOne, x) mockRequestForeignKey(t, reqIdTwo, x) - err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{ + err = m.CreateRefreshTokenSession(ctx, "1111", "", &fosite.Request{ ID: reqIdOne, Client: &client.Client{ID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), - Session: NewSession("user"), + Session: oauth2.NewSession("user"), }) require.NoError(t, err) - err = m.CreateRefreshTokenSession(ctx, "1122", &fosite.Request{ + err = m.CreateRefreshTokenSession(ctx, "1122", "", &fosite.Request{ ID: reqIdTwo, Client: &client.Client{ID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), - Session: NewSession("user"), + Session: oauth2.NewSession("user"), }) require.NoError(t, err) - _, err = m.GetRefreshTokenSession(ctx, "1111", NewSession("bar")) + _, err = m.GetRefreshTokenSession(ctx, "1111", oauth2.NewSession("bar")) require.NoError(t, err) err = m.RevokeRefreshToken(ctx, reqIdOne) @@ -344,39 +274,40 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { err = m.RevokeRefreshToken(ctx, reqIdTwo) require.NoError(t, err) - req, err := m.GetRefreshTokenSession(ctx, "1111", NewSession("bar")) - assert.NotNil(t, req) - assert.EqualError(t, err, fosite.ErrInactiveToken.Error()) - - req, err = m.GetRefreshTokenSession(ctx, "1122", NewSession("bar")) - assert.NotNil(t, req) - assert.EqualError(t, err, fosite.ErrInactiveToken.Error()) + req, err := m.GetRefreshTokenSession(ctx, "1111", oauth2.NewSession("bar")) + assert.Nil(t, req) + assert.EqualError(t, err, fosite.ErrNotFound.Error()) + req, err = m.GetRefreshTokenSession(ctx, "1122", oauth2.NewSession("bar")) + assert.Nil(t, req) + assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } -func testHelperCreateGetDeleteAuthorizeCodes(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteAuthorizeCodes(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() mockRequestForeignKey(t, "blank", x) + code := uuid.New() + ctx := context.Background() - res, err := m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) + res, err := m.GetAuthorizeCodeSession(ctx, code, oauth2.NewSession("bar")) assert.Error(t, err) assert.Nil(t, res) - err = m.CreateAuthorizeCodeSession(ctx, "4321", &defaultRequest) + err = m.CreateAuthorizeCodeSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err = m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) + res, err = m.GetAuthorizeCodeSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.InvalidateAuthorizeCodeSession(ctx, "4321") + err = m.InvalidateAuthorizeCodeSession(ctx, code) require.NoError(t, err) - res, err = m.GetAuthorizeCodeSession(ctx, "4321", NewSession("bar")) + res, err = m.GetAuthorizeCodeSession(ctx, code, oauth2.NewSession("bar")) require.Error(t, err) assert.EqualError(t, err, fosite.ErrInvalidatedAuthorizeCode.Error()) assert.NotNil(t, res) @@ -392,7 +323,7 @@ func (r testHelperExpiryFieldsResult) TableName() string { return "hydra_oauth2_" + r.name } -func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { +func testHelperExpiryFields(reg oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := reg.OAuth2Storage() t.Parallel() @@ -401,7 +332,7 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { ctx := context.Background() - s := NewSession("bar") + s := oauth2.NewSession("bar") s.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour).Round(time.Minute)) s.SetExpiresAt(fosite.RefreshToken, time.Now().Add(time.Hour*2).Round(time.Minute)) s.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(time.Hour*3).Round(time.Minute)) @@ -433,7 +364,7 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { t.Run("case=CreateRefreshTokenSession", func(t *testing.T) { id := uuid.New() - err := m.CreateRefreshTokenSession(ctx, id, &request) + err := m.CreateRefreshTokenSession(ctx, id, "", &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "refresh"} @@ -473,12 +404,12 @@ func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) { } } -func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) { +func testHelperNilAccessToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() - c := &client.Client{ID: "nil-request-client-id-123"} + c := &client.Client{ID: uuid.New()} require.NoError(t, x.ClientManager().CreateClient(context.Background(), c)) - err := m.CreateAccessTokenSession(context.Background(), "nil-request-id", &fosite.Request{ + err := m.CreateAccessTokenSession(context.Background(), uuid.New(), &fosite.Request{ ID: "", RequestedAt: time.Now().UTC().Round(time.Second), Client: c, @@ -487,158 +418,251 @@ func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) { RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: NewSession("bar"), + Session: oauth2.NewSession("bar"), }) require.NoError(t, err) } } -func testHelperCreateGetDeleteAccessTokenSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeleteAccessTokenSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Error(t, err) - err = m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + err = m.CreateAccessTokenSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + res, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeleteAccessTokenSession(ctx, "4321") + err = m.DeleteAccessTokenSession(ctx, code) require.NoError(t, err) - _, err = m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Error(t, err) } } -func testHelperDeleteAccessTokens(x InternalRegistry) func(t *testing.T) { +func testHelperDeleteAccessTokens(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := context.Background() - err := m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + code := uuid.New() + err := m.CreateAccessTokenSession(ctx, code, &defaultRequest) require.NoError(t, err) - _, err = m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) err = m.DeleteAccessTokens(ctx, defaultRequest.Client.GetID()) require.NoError(t, err) - req, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + req, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } -func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { +func testHelperRevokeAccessToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := context.Background() - err := m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + code := uuid.New() + err := m.CreateAccessTokenSession(ctx, code, &defaultRequest) require.NoError(t, err) - _, err = m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) err = m.RevokeAccessToken(ctx, defaultRequest.GetID()) require.NoError(t, err) - req, err := m.GetAccessTokenSession(ctx, "4321", NewSession("bar")) + req, err := m.GetAccessTokenSession(ctx, code, oauth2.NewSession("bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } -func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { +func testHelperRotateRefreshToken(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { ctx := context.Background() + createTokens := func(t *testing.T, r *fosite.Request) (refreshTokenSession string, accessTokenSession string) { + refreshTokenSession = fmt.Sprintf("refresh_token_%s", uuid.New()) + accessTokenSession = fmt.Sprintf("access_token_%s", uuid.New()) + err := x.OAuth2Storage().CreateAccessTokenSession(ctx, accessTokenSession, r) + require.NoError(t, err) + + err = x.OAuth2Storage().CreateRefreshTokenSession(ctx, refreshTokenSession, accessTokenSession, r) + require.NoError(t, err) + + // Sanity check + req, err := x.OAuth2Storage().GetRefreshTokenSession(ctx, refreshTokenSession, nil) + require.NoError(t, err) + require.EqualValues(t, r.GetID(), req.GetID()) + + req, err = x.OAuth2Storage().GetAccessTokenSession(ctx, accessTokenSession, nil) + require.NoError(t, err) + require.EqualValues(t, r.GetID(), req.GetID()) + return + } + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { - // SETUP m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) + refreshTokenSession, accessTokenSession := createTokens(t, &r) - refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) - err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) - require.NoError(t, err, "precondition failed: could not create refresh token session") - - // ACT - err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + err := m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession) require.NoError(t, err) - tmpSession := new(fosite.Session) - _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + _, err = m.GetAccessTokenSession(ctx, accessTokenSession, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound, "Token is no longer active because it was refreshed") + + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.ErrorIs(t, err, fosite.ErrInactiveToken, "Token is no longer active because it was refreshed") + }) + + t.Run("refresh token is valid until the grace period has ended", func(t *testing.T) { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + + // By setting this to one hour we ensure that using the refresh token triggers the start of the grace period. + x.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1h") + t.Cleanup(func() { + x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) + }) + + m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) + refreshTokenSession, accessTokenSession1 := createTokens(t, &r) + accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.New()) + require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, &r)) + + // Create a second access token + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + req, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound) + + req, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil) + assert.NoError(t, err, "The second access token is still valid.") + + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.NoError(t, err) + assert.Equal(t, r.GetID(), req.GetID()) - // ASSERT - // a revoked refresh token returns an error when getting the token again - assert.ErrorIs(t, err, fosite.ErrInactiveToken) + // We only wait a second, meaning that the token is theoretically still within TTL, but since the + // grace period was issued, the token is still valid. + time.Sleep(time.Second) + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.Error(t, err) }) - t.Run("refresh token enters grace period when configured,", func(t *testing.T) { - // SETUP - x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1m") + t.Run("the used at time does not change", func(t *testing.T) { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + + // By setting this to one hour we ensure that using the refresh token triggers the start of the grace period. + x.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1h") + t.Cleanup(func() { + x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) + }) + + m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) + + refreshTokenSession, _ := createTokens(t, &r) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + var expected sql.OAuth2RefreshTable + require.NoError(t, x.Persister().Connection(ctx).Where("signature=?", refreshTokenSession).First(&expected)) + assert.False(t, expected.FirstUsedAt.Time.IsZero()) + assert.True(t, expected.FirstUsedAt.Valid) + + // Refresh does not change the time + time.Sleep(time.Second * 2) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + var actual sql.OAuth2RefreshTable + require.NoError(t, x.Persister().Connection(ctx).Where("signature=?", refreshTokenSession).First(&actual)) + assert.Equal(t, expected.FirstUsedAt.Time, actual.FirstUsedAt.Time) + }) - // always reset back to the default + t.Run("refresh token revokes all access tokens from the request if the access token signature is not found", func(t *testing.T) { + x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") t.Cleanup(func() { - x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "0m") + x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) }) m := x.OAuth2Storage() + r := newDefaultRequest(uuid.New()) - refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) - err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) - require.NoError(t, err, "precondition failed: could not create refresh token session") + refreshTokenSession := fmt.Sprintf("refresh_token_%s", uuid.New()) + accessTokenSession1 := fmt.Sprintf("access_token_%s", uuid.New()) + accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.New()) + require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession1, &r)) + require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, &r)) + + require.NoError(t, m.CreateRefreshTokenSession(ctx, refreshTokenSession, "", &r), + "precondition failed: could not create refresh token session") // ACT - require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) - require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) - require.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) + + req, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound) - req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + req, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil) + assert.ErrorIs(t, err, fosite.ErrNotFound) - // ASSERT - // when grace period is configured the refresh token can be obtained within - // the grace period without error + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.NoError(t, err) + assert.Equal(t, r.GetID(), req.GetID()) + + time.Sleep(time.Second) - assert.Equal(t, defaultRequest.GetID(), req.GetID()) + req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) + assert.Error(t, err) }) } - } -func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { +func testHelperCreateGetDeletePKCERequestSession(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() + code := uuid.New() ctx := context.Background() - _, err := m.GetPKCERequestSession(ctx, "4321", NewSession("bar")) + _, err := m.GetPKCERequestSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) - err = m.CreatePKCERequestSession(ctx, "4321", &defaultRequest) + err = m.CreatePKCERequestSession(ctx, code, &defaultRequest) require.NoError(t, err) - res, err := m.GetPKCERequestSession(ctx, "4321", NewSession("bar")) + res, err := m.GetPKCERequestSession(ctx, code, oauth2.NewSession("bar")) require.NoError(t, err) AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") - err = m.DeletePKCERequestSession(ctx, "4321") + err = m.DeletePKCERequestSession(ctx, code) require.NoError(t, err) - _, err = m.GetPKCERequestSession(ctx, "4321", NewSession("bar")) + _, err = m.GetPKCERequestSession(ctx, code, oauth2.NewSession("bar")) assert.NotNil(t, err) } } -func testHelperFlushTokens(x InternalRegistry, lifespan time.Duration) func(t *testing.T) { +func testHelperFlushTokens(x oauth2.InternalRegistry, lifespan time.Duration) func(t *testing.T) { m := x.OAuth2Storage() - ds := &Session{} + ds := &oauth2.Session{} return func(t *testing.T) { ctx := context.Background() @@ -676,9 +700,9 @@ func testHelperFlushTokens(x InternalRegistry, lifespan time.Duration) func(t *t } } -func testHelperFlushTokensWithLimitAndBatchSize(x InternalRegistry, limit int, batchSize int) func(t *testing.T) { +func testHelperFlushTokensWithLimitAndBatchSize(x oauth2.InternalRegistry, limit int, batchSize int) func(t *testing.T) { m := x.OAuth2Storage() - ds := &Session{} + ds := &oauth2.Session{} return func(t *testing.T) { ctx := context.Background() @@ -712,7 +736,7 @@ func testHelperFlushTokensWithLimitAndBatchSize(x InternalRegistry, limit int, b } } -func testFositeSqlStoreTransactionCommitAccessToken(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitAccessToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { { doTestCommit(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().RevokeAccessToken) @@ -721,7 +745,7 @@ func testFositeSqlStoreTransactionCommitAccessToken(m InternalRegistry) func(t * } } -func testFositeSqlStoreTransactionRollbackAccessToken(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackAccessToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { { doTestRollback(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().RevokeAccessToken) @@ -730,42 +754,41 @@ func testFositeSqlStoreTransactionRollbackAccessToken(m InternalRegistry) func(t } } -func testFositeSqlStoreTransactionCommitRefreshToken(m InternalRegistry) func(t *testing.T) { - +func testFositeSqlStoreTransactionCommitRefreshToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { - doTestCommit(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) - doTestCommit(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) + doTestCommitRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) + doTestCommitRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) } } -func testFositeSqlStoreTransactionRollbackRefreshToken(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackRefreshToken(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { - doTestRollback(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) - doTestRollback(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) + doTestRollbackRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) + doTestRollbackRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) } } -func testFositeSqlStoreTransactionCommitAuthorizeCode(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitAuthorizeCode(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestCommit(m, t, m.OAuth2Storage().CreateAuthorizeCodeSession, m.OAuth2Storage().GetAuthorizeCodeSession, m.OAuth2Storage().InvalidateAuthorizeCodeSession) } } -func testFositeSqlStoreTransactionRollbackAuthorizeCode(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackAuthorizeCode(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestRollback(m, t, m.OAuth2Storage().CreateAuthorizeCodeSession, m.OAuth2Storage().GetAuthorizeCodeSession, m.OAuth2Storage().InvalidateAuthorizeCodeSession) } } -func testFositeSqlStoreTransactionCommitPKCERequest(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitPKCERequest(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestCommit(m, t, m.OAuth2Storage().CreatePKCERequestSession, m.OAuth2Storage().GetPKCERequestSession, m.OAuth2Storage().DeletePKCERequestSession) } } -func testFositeSqlStoreTransactionRollbackPKCERequest(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackPKCERequest(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { doTestRollback(m, t, m.OAuth2Storage().CreatePKCERequestSession, m.OAuth2Storage().GetPKCERequestSession, m.OAuth2Storage().DeletePKCERequestSession) } @@ -773,7 +796,7 @@ func testFositeSqlStoreTransactionRollbackPKCERequest(m InternalRegistry) func(t // OpenIdConnect tests can't use the helper functions, due to the signature of GetOpenIdConnectSession being // different from the other getter methods -func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { txnStore, ok := m.OAuth2Storage().(storage.Transactional) require.True(t, ok) @@ -808,7 +831,7 @@ func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m InternalRegistry) } } -func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m InternalRegistry) func(t *testing.T) { +func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { txnStore, ok := m.OAuth2Storage().(storage.Transactional) require.True(t, ok) @@ -849,12 +872,12 @@ func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m InternalRegistr } } -func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { +func testFositeStoreSetClientAssertionJWT(m oauth2.InternalRegistry) func(*testing.T) { return func(t *testing.T) { t.Run("case=basic setting works", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("basic jti", time.Now().Add(time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) @@ -866,20 +889,20 @@ func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { }) t.Run("case=errors when the JTI is blacklisted", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("already set jti", time.Now().Add(time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) assert.ErrorIs(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown) }) t.Run("case=deletes expired JTIs", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - expiredJTI := NewBlacklistedJTI("expired jti", time.Now().Add(-time.Minute)) + expiredJTI := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expiredJTI)) - newJTI := NewBlacklistedJTI("some new jti", time.Now().Add(time.Minute)) + newJTI := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry)) @@ -893,9 +916,9 @@ func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { }) t.Run("case=inserts same JTI if expired", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("going to be reused jti", time.Now().Add(-time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) jti.Expiry = jti.Expiry.Add(2 * time.Minute) @@ -907,19 +930,19 @@ func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) { } } -func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) { +func testFositeStoreClientAssertionJWTValid(m oauth2.InternalRegistry) func(*testing.T) { return func(t *testing.T) { t.Run("case=returns valid on unknown JTI", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), "unknown jti")) + assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), uuid.New())) }) t.Run("case=returns invalid on known JTI", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("known jti", time.Now().Add(time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) @@ -927,9 +950,9 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) }) t.Run("case=returns valid on expired JTI", func(t *testing.T) { - store, ok := m.OAuth2Storage().(AssertionJWTReader) + store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - jti := NewBlacklistedJTI("expired jti 2", time.Now().Add(-time.Minute)) + jti := oauth2.NewBlacklistedJTI(uuid.New(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) @@ -938,7 +961,7 @@ func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) } } -func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { +func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.T) { return func(t *testing.T) { ctx := context.Background() grantManager := x.GrantManager() @@ -946,12 +969,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage) t.Run("case=associated key added with grant", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "token-service-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "token-service" - subject := "bob@example.com" + issuer := uuid.New() + subject := "bob+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -992,14 +1015,14 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=only associated key returns", func(t *testing.T) { - keySetToNotReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "some-key", "sig") + keySetToNotReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.New(), "sig") require.NoError(t, err) - require.NoError(t, keyManager.AddKeySet(context.Background(), "some-set", keySetToNotReturn), "adding a random key should not fail") + require.NoError(t, keyManager.AddKeySet(context.Background(), uuid.New(), keySetToNotReturn), "adding a random key should not fail") - issuer := "maria" - subject := "maria@example.com" + issuer := uuid.New() + subject := "maria+" + uuid.New() + "@example.com" - keySet1ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-1", "sig") + keySet1ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.New(), "sig") require.NoError(t, err) require.NoError(t, grantManager.CreateGrant(context.Background(), trust.Grant{ ID: uuid.New(), @@ -1012,7 +1035,7 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), }, keySet1ToReturn.Keys[0].Public())) - keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, "maria-key-2", "sig") + keySet2ToReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.New(), "sig") require.NoError(t, err) require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{ ID: uuid.New(), @@ -1055,12 +1078,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=associated key is deleted, when granted is deleted", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "hackerman-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "aeneas" - subject := "aeneas@example.com" + issuer := uuid.New() + subject := "aeneas+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1092,12 +1115,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=associated grant is deleted, when key is deleted", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "vladimir-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "vladimir" - subject := "vladimir@example.com" + issuer := uuid.New() + subject := "vladimir+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1129,12 +1152,12 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=only returns the key when subject matches", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "issuer-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "limited-issuer" - subject := "jagoba" + issuer := uuid.New() + subject := "jagoba+" + uuid.New() + "@example.com" grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1171,11 +1194,11 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=returns the key when any subject is allowed", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "issuer-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "unlimited-issuer" + issuer := uuid.New() grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1204,11 +1227,11 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { }) t.Run("case=does not return expired values", func(t *testing.T) { - keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, "issuer-expired-key", "sig") + keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() - issuer := "expired-issuer" + issuer := uuid.New() grant := trust.Grant{ ID: uuid.New(), Issuer: issuer, @@ -1230,12 +1253,11 @@ func testFositeJWTBearerGrantStorage(x InternalRegistry) func(t *testing.T) { } } -func doTestCommit(m InternalRegistry, t *testing.T, +func doTestCommit(m oauth2.InternalRegistry, t *testing.T, createFn func(context.Context, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, ) { - txnStore, ok := m.OAuth2Storage().(storage.Transactional) require.True(t, ok) ctx := context.Background() @@ -1248,7 +1270,44 @@ func doTestCommit(m InternalRegistry, t *testing.T, require.NoError(t, err) // Require a new context, since the old one contains the transaction. - res, err := getFn(context.Background(), signature, NewSession("bar")) + res, err := getFn(context.Background(), signature, oauth2.NewSession("bar")) + // token should have been created successfully because Commit did not return an error + require.NoError(t, err) + assertx.EqualAsJSONExcept(t, &defaultRequest, res, defaultIgnoreKeys) + // AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") + + // testrevoke within a transaction + ctx, err = txnStore.BeginTX(context.Background()) + require.NoError(t, err) + err = revokeFn(ctx, signature) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + _, err = getFn(context.Background(), signature, oauth2.NewSession("bar")) + // Since commit worked for revoke, we should get an error here. + require.Error(t, err) +} + +func doTestCommitRefresh(m oauth2.InternalRegistry, t *testing.T, + createFn func(context.Context, string, string, fosite.Requester) error, + getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), + revokeFn func(context.Context, string) error, +) { + txnStore, ok := m.OAuth2Storage().(storage.Transactional) + require.True(t, ok) + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + err = createFn(ctx, signature, "", createTestRequest(signature)) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + res, err := getFn(context.Background(), signature, oauth2.NewSession("bar")) // token should have been created successfully because Commit did not return an error require.NoError(t, err) assertx.EqualAsJSONExcept(t, &defaultRequest, res, defaultIgnoreKeys) @@ -1263,12 +1322,12 @@ func doTestCommit(m InternalRegistry, t *testing.T, require.NoError(t, err) // Require a new context, since the old one contains the transaction. - _, err = getFn(context.Background(), signature, NewSession("bar")) + _, err = getFn(context.Background(), signature, oauth2.NewSession("bar")) // Since commit worked for revoke, we should get an error here. require.Error(t, err) } -func doTestRollback(m InternalRegistry, t *testing.T, +func doTestRollback(m oauth2.InternalRegistry, t *testing.T, createFn func(context.Context, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, @@ -1287,7 +1346,7 @@ func doTestRollback(m InternalRegistry, t *testing.T, // Require a new context, since the old one contains the transaction. ctx = context.Background() - _, err = getFn(ctx, signature, NewSession("bar")) + _, err = getFn(ctx, signature, oauth2.NewSession("bar")) // Since we rolled back above, the token should not exist and getting it should result in an error require.Error(t, err) @@ -1295,7 +1354,48 @@ func doTestRollback(m InternalRegistry, t *testing.T, signature2 := uuid.New() err = createFn(ctx, signature2, createTestRequest(signature2)) require.NoError(t, err) - _, err = getFn(ctx, signature2, NewSession("bar")) + _, err = getFn(ctx, signature2, oauth2.NewSession("bar")) + require.NoError(t, err) + + ctx, err = txnStore.BeginTX(context.Background()) + require.NoError(t, err) + err = revokeFn(ctx, signature2) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + _, err = getFn(context.Background(), signature2, oauth2.NewSession("bar")) + require.NoError(t, err) +} + +func doTestRollbackRefresh(m oauth2.InternalRegistry, t *testing.T, + createFn func(context.Context, string, string, fosite.Requester) error, + getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), + revokeFn func(context.Context, string) error, +) { + txnStore, ok := m.OAuth2Storage().(storage.Transactional) + require.True(t, ok) + + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + err = createFn(ctx, signature, "", createTestRequest(signature)) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + ctx = context.Background() + _, err = getFn(ctx, signature, oauth2.NewSession("bar")) + // Since we rolled back above, the token should not exist and getting it should result in an error + require.Error(t, err) + + // create a new token, revoke it, then rollback the revoke. We should be able to then get it successfully. + signature2 := uuid.New() + err = createFn(ctx, signature2, "", createTestRequest(signature2)) + require.NoError(t, err) + _, err = getFn(ctx, signature2, oauth2.NewSession("bar")) require.NoError(t, err) ctx, err = txnStore.BeginTX(context.Background()) @@ -1305,7 +1405,7 @@ func doTestRollback(m InternalRegistry, t *testing.T, err = txnStore.Rollback(ctx) require.NoError(t, err) - _, err = getFn(context.Background(), signature2, NewSession("bar")) + _, err = getFn(context.Background(), signature2, oauth2.NewSession("bar")) require.NoError(t, err) } @@ -1319,6 +1419,6 @@ func createTestRequest(id string) *fosite.Request { RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, } } diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 2a48a52f8e7..e3f3c6a13ec 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -7,16 +7,13 @@ import ( "context" "flag" "testing" + "time" - "github.com/stretchr/testify/require" + "github.com/ory/hydra/v2/internal/testhelpers" - "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" - . "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" - "github.com/ory/x/networkx" "github.com/ory/x/sqlcon/dockertest" ) @@ -29,7 +26,7 @@ func TestMain(m *testing.M) { var registries = make(map[string]driver.Registry) var cleanRegistries = func(t *testing.T) { - registries["memory"] = internal.NewRegistryMemory(t, internal.NewConfigurationWithDefaults(), &contextx.Default{}) + registries["memory"] = testhelpers.NewRegistryMemory(t, testhelpers.NewConfigurationWithDefaults(), &contextx.Default{}) } // returns clean registries that can safely be used for one test @@ -38,7 +35,7 @@ func setupRegistries(t *testing.T) { if len(registries) == 0 && !testing.Short() { // first time called and sql tests var cleanSQL func(*testing.T) - registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = internal.ConnectDatabases(t, true, &contextx.Default{}) + registries["postgres"], registries["mysql"], registries["cockroach"], cleanSQL = testhelpers.ConnectDatabases(t, false, &contextx.Default{}) cleanMem := cleanRegistries cleanMem(t) cleanRegistries = func(t *testing.T) { @@ -52,6 +49,8 @@ func setupRegistries(t *testing.T) { } func TestManagers(t *testing.T) { + setupRegistries(t) + ctx := context.Background() tests := []struct { name string @@ -68,18 +67,43 @@ func TestManagers(t *testing.T) { } for _, tc := range tests { t.Run("suite="+tc.name, func(t *testing.T) { - setupRegistries(t) + for k, r := range registries { + t.Run("database="+k, func(t *testing.T) { + store := testhelpers.NewRegistrySQLFromURL(t, r.Config().DSN(), true, &contextx.Default{}) + store.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) - require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{ID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers. + if k != "memory" { + t.Run("testHelperUniqueConstraints", testHelperRequestIDMultiples(store, k)) + t.Run("case=testFositeSqlStoreTransactionsCommitAccessToken", testFositeSqlStoreTransactionCommitAccessToken(store)) + t.Run("case=testFositeSqlStoreTransactionsRollbackAccessToken", testFositeSqlStoreTransactionRollbackAccessToken(store)) + t.Run("case=testFositeSqlStoreTransactionCommitRefreshToken", testFositeSqlStoreTransactionCommitRefreshToken(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackRefreshToken", testFositeSqlStoreTransactionRollbackRefreshToken(store)) + t.Run("case=testFositeSqlStoreTransactionCommitAuthorizeCode", testFositeSqlStoreTransactionCommitAuthorizeCode(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackAuthorizeCode", testFositeSqlStoreTransactionRollbackAuthorizeCode(store)) + t.Run("case=testFositeSqlStoreTransactionCommitPKCERequest", testFositeSqlStoreTransactionCommitPKCERequest(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackPKCERequest", testFositeSqlStoreTransactionRollbackPKCERequest(store)) + t.Run("case=testFositeSqlStoreTransactionCommitOpenIdConnectSession", testFositeSqlStoreTransactionCommitOpenIdConnectSession(store)) + t.Run("case=testFositeSqlStoreTransactionRollbackOpenIdConnectSession", testFositeSqlStoreTransactionRollbackOpenIdConnectSession(store)) + } - for k, store := range registries { - net := &networkx.Network{} - require.NoError(t, store.Persister().Connection(context.Background()).First(net)) - store.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) - store.WithContextualizer(&contextx.Static{NID: net.ID, C: store.Config().Source(ctx)}) - TestHelperRunner(t, store, k) + t.Run("testHelperCreateGetDeleteAuthorizeCodes", testHelperCreateGetDeleteAuthorizeCodes(store)) + t.Run("testHelperExpiryFields", testHelperExpiryFields(store)) + t.Run("testHelperCreateGetDeleteAccessTokenSession", testHelperCreateGetDeleteAccessTokenSession(store)) + t.Run("testHelperNilAccessToken", testHelperNilAccessToken(store)) + t.Run("testHelperCreateGetDeleteOpenIDConnectSession", testHelperCreateGetDeleteOpenIDConnectSession(store)) + t.Run("testHelperCreateGetDeleteRefreshTokenSession", testHelperCreateGetDeleteRefreshTokenSession(store)) + t.Run("testHelperRevokeRefreshToken", testHelperRevokeRefreshToken(store)) + t.Run("testHelperCreateGetDeletePKCERequestSession", testHelperCreateGetDeletePKCERequestSession(store)) + t.Run("testHelperFlushTokens", testHelperFlushTokens(store, time.Hour)) + t.Run("testHelperFlushTokensWithLimitAndBatchSize", testHelperFlushTokensWithLimitAndBatchSize(store, 3, 2)) + t.Run("testFositeStoreSetClientAssertionJWT", testFositeStoreSetClientAssertionJWT(store)) + t.Run("testFositeStoreClientAssertionJWTValid", testFositeStoreClientAssertionJWTValid(store)) + t.Run("testHelperDeleteAccessTokens", testHelperDeleteAccessTokens(store)) + t.Run("testHelperRevokeAccessToken", testHelperRevokeAccessToken(store)) + t.Run("testFositeJWTBearerGrantStorage", testFositeJWTBearerGrantStorage(store)) + t.Run("testHelperRevokeRefreshTokenMaybeGracePeriod", testHelperRotateRefreshToken(store)) + }) } }) - } } diff --git a/oauth2/handler.go b/oauth2/handler.go index 288ed1f16f0..3f1a633038d 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -727,11 +727,14 @@ type revokeOAuth2Token struct { // default: errorOAuth2 func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - events.Trace(ctx, events.AccessTokenRevoked) - err := h.r.OAuth2Provider().NewRevocationRequest(ctx, r) + err := h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { + return h.r.OAuth2Provider().NewRevocationRequest(ctx, r) + }) if err != nil { x.LogError(r, err, h.r.Logger()) + } else { + events.Trace(ctx, events.AccessTokenRevoked) } h.r.OAuth2Provider().WriteRevocationResponse(ctx, w, err) diff --git a/oauth2/handler_fallback_endpoints_test.go b/oauth2/handler_fallback_endpoints_test.go index 191cd15a03a..9e3107b722b 100644 --- a/oauth2/handler_fallback_endpoints_test.go +++ b/oauth2/handler_fallback_endpoints_test.go @@ -10,22 +10,23 @@ import ( "net/http/httptest" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/httprouterx" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/stretchr/testify/assert" ) func TestHandlerConsent(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(context.Background(), config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) h := reg.OAuth2Handler() r := x.NewRouterAdmin(conf.AdminURL) diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index f2d159af614..50705fad6bf 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -15,6 +15,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/x/httprouterx" @@ -31,13 +33,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" - "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/oauth2" ) @@ -45,9 +45,9 @@ var lifespan = time.Hour func TestHandlerDeleteHandler(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) cm := reg.ClientManager() store := reg.OAuth2Storage() @@ -88,12 +88,12 @@ func TestHandlerDeleteHandler(t *testing.T) { func TestUserinfo(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyScopeStrategy, "") conf.MustSet(ctx, config.KeyAuthCodeLifespan, lifespan) conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) ctrl := gomock.NewController(t) op := NewMockOAuth2Provider(ctrl) @@ -340,7 +340,7 @@ func TestUserinfo(t *testing.T) { func TestHandlerWellKnown(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() t.Run(fmt.Sprintf("hsm_enabled=%v", conf.HSMEnabled()), func(t *testing.T) { conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") @@ -348,7 +348,7 @@ func TestHandlerWellKnown(t *testing.T) { conf.MustSet(ctx, config.KeyOIDCDiscoverySupportedClaims, []string{"sub"}) conf.MustSet(ctx, config.KeyOAuth2ClientRegistrationURL, "http://client-register/registration") conf.MustSet(ctx, config.KeyOIDCDiscoveryUserinfoEndpoint, "/userinfo") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) h := oauth2.NewHandler(reg, conf) diff --git a/oauth2/helper_test.go b/oauth2/helper_test.go index 3a40592bfdd..04f41298b71 100644 --- a/oauth2/helper_test.go +++ b/oauth2/helper_test.go @@ -5,6 +5,11 @@ package oauth2_test import ( "context" + "testing" + + "github.com/oleiade/reflections" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" @@ -20,3 +25,60 @@ func Tokens(c fosite.Configurator, length int) (res [][]string) { } return res } + +func AssertObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + assert.Equal(t, c, d, "%s", k) + } +} + +func AssertObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + assert.NotEqual(t, c, d, "%s", k) + } +} + +func RequireObjectKeysEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + require.Equal(t, c, d, "%s", k) + } +} +func RequireObjectKeysNotEqual(t *testing.T, a, b interface{}, keys ...string) { + assert.True(t, len(keys) > 0, "No keys provided.") + for _, k := range keys { + c, err := reflections.GetField(a, k) + assert.Nil(t, err) + d, err := reflections.GetField(b, k) + assert.Nil(t, err) + require.NotEqual(t, c, d, "%s", k) + } +} + +func TestAssertObjectsAreEqualByKeys(t *testing.T) { + type foo struct { + Name string + Body int + } + a := &foo{"foo", 1} + b := &foo{"bar", 1} + c := &foo{"baz", 3} + + AssertObjectKeysEqual(t, a, a, "Name", "Body") + AssertObjectKeysNotEqual(t, a, b, "Name") + AssertObjectKeysNotEqual(t, a, c, "Name", "Body") +} diff --git a/oauth2/helpers.go b/oauth2/helpers.go new file mode 100644 index 00000000000..4db4bf84d8e --- /dev/null +++ b/oauth2/helpers.go @@ -0,0 +1,51 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + + "github.com/gobuffalo/pop/v6" + gofrsuuid "github.com/gofrs/uuid" + + "github.com/ory/hydra/v2/x" +) + +func signatureFromJTI(jti string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(jti))) +} + +type BlacklistedJTI struct { + JTI string `db:"-"` + ID string `db:"signature"` + Expiry time.Time `db:"expires_at"` + NID gofrsuuid.UUID `db:"nid"` +} + +func (j *BlacklistedJTI) AfterFind(_ *pop.Connection) error { + j.Expiry = j.Expiry.UTC() + return nil +} + +func (BlacklistedJTI) TableName() string { + return "hydra_oauth2_jti_blacklist" +} + +func NewBlacklistedJTI(jti string, exp time.Time) *BlacklistedJTI { + return &BlacklistedJTI{ + JTI: jti, + ID: signatureFromJTI(jti), + // because the database timestamp types are not as accurate as time.Time we truncate to seconds (which should always work) + Expiry: exp.UTC().Truncate(time.Second), + } +} + +type AssertionJWTReader interface { + x.FositeStorer + GetClientAssertionJWT(ctx context.Context, jti string) (*BlacklistedJTI, error) + SetClientAssertionJWTRaw(context.Context, *BlacklistedJTI) error +} diff --git a/oauth2/introspector_test.go b/oauth2/introspector_test.go index 16b279f036f..43b565d2f58 100644 --- a/oauth2/introspector_test.go +++ b/oauth2/introspector_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/x/httprouterx" @@ -30,12 +32,12 @@ import ( func TestIntrospectorSDK(t *testing.T) { ctx := context.Background() - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyScopeStrategy, "wildcard") conf.MustSet(ctx, config.KeyIssuerURL, "https://foobariss") - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) internal.AddFositeExamples(reg) tokens := Tokens(reg.OAuth2ProviderConfig(), 4) diff --git a/oauth2/oauth2_auth_code_bench_test.go b/oauth2/oauth2_auth_code_bench_test.go index 568ff00287c..9347982630a 100644 --- a/oauth2/oauth2_auth_code_bench_test.go +++ b/oauth2/oauth2_auth_code_bench_test.go @@ -33,7 +33,6 @@ import ( hydra "github.com/ory/hydra-client-go/v2" hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" @@ -79,7 +78,7 @@ func BenchmarkAuthCode(b *testing.B) { dsn := stringsx.Coalesce(os.Getenv("DSN"), "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable&max_conns=20&max_idle_conns=20") // dsn := "mysql://root:secret@tcp(localhost:3444)/mysql?max_conns=16&max_idle_conns=16" // dsn := "cockroach://root@localhost:3446/defaultdb?sslmode=disable&max_conns=16&max_idle_conns=16" - reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) + reg := testhelpers.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) reg.Config().MustSet(ctx, config.KeyLogLevel, "error") reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index feea2451e27..9cc687708a7 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net/http" "net/http/httptest" "net/url" @@ -20,6 +21,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/jwk" + "github.com/go-jose/go-jose/v3" "github.com/golang-jwt/jwt/v5" "github.com/julienschmidt/httprouter" @@ -36,7 +39,6 @@ import ( "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" hydraoauth2 "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/x" @@ -62,1299 +64,1480 @@ type clientCreator interface { CreateClient(context.Context, *client.Client) error } -// TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically -// we test: -// -// - [x] If the flow - in general - works -// - [x] If `authenticatedAt` is properly managed across the lifecycle -// - [x] The value `authenticatedAt` should be an old time if no user interaction wrt login was required -// - [x] The value `authenticatedAt` should be a recent time if user interaction wrt login was required -// -// - [x] If `requestedAt` is properly managed across the lifecycle -// - [x] The value of `requestedAt` must be the initial request time, not some other time (e.g. when accepting login) -// -// - [x] If `id_token_hint` is handled properly -// - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") -func TestAuthCodeWithDefaultStrategy(t *testing.T) { - ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") - publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) - - publicClient := hydra.NewAPIClient(hydra.NewConfiguration()) - publicClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: publicTS.URL}} - adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) - adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} - - getAuthorizeCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { - if c == nil { - c = testhelpers.NewEmptyJarClient(t) - } +func getAuthorizeCode(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { + if c == nil { + c = testhelpers.NewEmptyJarClient(t) + } - state := uuid.New() - resp, err := c.Get(conf.AuthCodeURL(state, params...)) - require.NoError(t, err) - defer resp.Body.Close() + state := uuid.New() + resp, err := c.Get(conf.AuthCodeURL(state, params...)) + require.NoError(t, err) + defer resp.Body.Close() - q := resp.Request.URL.Query() - require.EqualValues(t, state, q.Get("state")) - return q.Get("code"), resp - } + q := resp.Request.URL.Query() + require.EqualValues(t, state, q.Get("state")) + return q.Get("code"), resp +} - acceptLoginHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - - acceptBody := hydra.AcceptOAuth2LoginRequest{ - Subject: subject, - Remember: pointerx.Ptr(!rr.Skip), - Acr: pointerx.Ptr("1"), - Amr: []string{"pwd"}, - Context: map[string]interface{}{"context": "bar"}, - } - if checkRequestPayload != nil { - if b := checkRequestPayload(rr); b != nil { - acceptBody = *b - } - } +func acceptLoginHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() + require.NoError(t, err) - v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(acceptBody). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) + + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Remember: pointerx.Ptr(!rr.Skip), + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, } - } - - acceptConsentHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) - assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - if checkRequestPayload != nil { - checkRequestPayload(rr) + if checkRequestPayload != nil { + if b := checkRequestPayload(rr); b != nil { + acceptBody = *b } - - assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) - v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). - ConsentChallenge(r.URL.Query().Get("consent_challenge")). - AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ - GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), - GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, - Session: &hydra.AcceptOAuth2ConsentRequestSession{ - AccessToken: map[string]interface{}{"foo": "bar"}, - IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, - }, - }). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) } - } - - assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { - introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) - actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) - require.NoError(t, err, "%s", introspect) - requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) - } - assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result { - idt, ok := token.Extra("id_token").(string) - require.True(t, ok) - assert.NotEmpty(t, idt) - - body, err := x.DecodeSegment(strings.Split(idt, ".")[1]) + v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() require.NoError(t, err) - - claims := gjson.ParseBytes(body) - assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims) - assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims) - assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims) - requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 2*time.Second) - assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims) - assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims) - assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) - assert.Equal(t, "1", claims.Get("acr").String(), "%s", claims) - require.Len(t, claims.Get("amr").Array(), 1, "%s", claims) - assert.EqualValues(t, "pwd", claims.Get("amr").Array()[0].String(), "%s", claims) - - require.Len(t, claims.Get("aud").Array(), 1, "%s", claims) - assert.EqualValues(t, c.ClientID, claims.Get("aud").Array()[0].String(), "%s", claims) - assert.EqualValues(t, expectedSubject, claims.Get("sub").String(), "%s", claims) - assert.EqualValues(t, expectedNonce, claims.Get("nonce").String(), "%s", claims) - assert.EqualValues(t, `baz`, claims.Get("bar").String(), "%s", claims) - assert.EqualValues(t, `foo@bar.com`, claims.Get("email").String(), "%s", claims) - assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) - - return claims + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) } +} - introspectAccessToken := func(t *testing.T, conf *oauth2.Config, token *oauth2.Token, expectedSubject string) gjson.Result { - require.NotEmpty(t, token.AccessToken) - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.True(t, i.Get("active").Bool(), "%s", i) - assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) - assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) - return i - } +func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) - assertJWTAccessToken := func(t *testing.T, strat string, conf *oauth2.Config, token *oauth2.Token, expectedSubject string, expectedExp time.Time, scopes string) gjson.Result { - require.NotEmpty(t, token.AccessToken) - parts := strings.Split(token.AccessToken, ".") - if strat != "jwt" { - require.Len(t, parts, 2) - return gjson.Parse("null") + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) + assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(r.Context()).String()) + if checkRequestPayload != nil { + checkRequestPayload(rr) } - require.Len(t, parts, 3) - body, err := x.DecodeSegment(parts[1]) + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, + }, + }). + Execute() require.NoError(t, err) - - i := gjson.ParseBytes(body) - assert.NotEmpty(t, i.Get("jti").String()) - assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) - assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), i.Get("iss").String(), "%s", i) - assert.True(t, time.Now().After(time.Unix(i.Get("iat").Int(), 0)), "%s", i) - assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) - assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) - requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) - assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) - assert.EqualValues(t, scopes, i.Get("scp").Raw, "%s", i) - return i - } - - waitForRefreshTokenExpiry := func() { - time.Sleep(reg.Config().GetRefreshTokenLifespan(ctx) + time.Second) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) } +} - t.Run("case=checks if request fails when audience does not match", func(t *testing.T) { - testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) - _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("audience", "https://not-ory-api/")) - require.Empty(t, code) - }) - - subject := "aeneas-rekkas" - nonce := uuid.New() - t.Run("case=perform authorize code flow with ID token and refresh tokens", func(t *testing.T) { - run := func(t *testing.T, strategy string) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - iat := time.Now() - require.NoError(t, err) - - assert.Empty(t, token.Extra("c_nonce_draft_00"), "should not be set if not requested") - assert.Empty(t, token.Extra("c_nonce_expires_in_draft_00"), "should not be set if not requested") - introspectAccessToken(t, conf, token, subject) - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) - iat = time.Now() - refreshedToken, err := conf.TokenSource(context.Background(), token).Token() - require.NoError(t, err) - - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) - require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - introspectAccessToken(t, conf, refreshedToken, subject) - - t.Run("followup=refreshed tokens contain valid tokens", func(t *testing.T) { - assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - }) - - t.Run("followup=original access token is no longer valid", func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - - t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - }) - - t.Run("followup=but fail subsequent refresh because expiry was reached", func(t *testing.T) { - waitForRefreshTokenExpiry() +// TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically +// we test: +// +// - [x] If the flow - in general - works +// - [x] If `authenticatedAt` is properly managed across the lifecycle +// - [x] The value `authenticatedAt` should be an old time if no user interaction wrt login was required +// - [x] The value `authenticatedAt` should be a recent time if user interaction wrt login was required +// +// - [x] If `requestedAt` is properly managed across the lifecycle +// - [x] The value of `requestedAt` must be the initial request time, not some other time (e.g. when accepting login) +// +// - [x] If `id_token_hint` is handled properly +// - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") +func TestAuthCodeWithDefaultStrategy(t *testing.T) { + setupRegistries(t) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + ctx := context.Background() - // Force golang to refresh token - refreshedToken.Expiry = refreshedToken.Expiry.Add(-time.Hour * 24) - _, err := conf.TokenSource(context.Background(), refreshedToken).Token() - require.Error(t, err) - }) - }) - } + for dbName, reg := range registries { + t.Run("registry="+dbName, func(t *testing.T) { + reg := testhelpers.NewRegistrySQLFromURL(t, reg.Config().DSN(), true, &contextx.Default{}) - t.Run("strategy=jwt", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt") - }) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) - t.Run("strategy=opaque", func(t *testing.T) { reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque") - }) - }) - - t.Run("case=graceful token rotation", func(t *testing.T) { - run := func(t *testing.T, strategy string) { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") - t.Cleanup(func() { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) - }) + reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") + publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) + + publicClient := hydra.NewAPIClient(hydra.NewConfiguration()) + publicClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: publicTS.URL}} + adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) + adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} + + assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { + introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) + actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) + require.NoError(t, err, "%s", introspect) + requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second*2) + } - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) + assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result { + idt, ok := token.Extra("id_token").(string) + require.True(t, ok) + assert.NotEmpty(t, idt) - issueTokens := func(t *testing.T) *oauth2.Token { - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - iat := time.Now() + body, err := x.DecodeSegment(strings.Split(idt, ".")[1]) require.NoError(t, err) - introspectAccessToken(t, conf, token, subject) - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - return token + claims := gjson.ParseBytes(body) + assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims) + assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims) + assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims) + requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 3*time.Second) + assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims) + assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) + assert.Equal(t, "1", claims.Get("acr").String(), "%s", claims) + require.Len(t, claims.Get("amr").Array(), 1, "%s", claims) + assert.EqualValues(t, "pwd", claims.Get("amr").Array()[0].String(), "%s", claims) + + require.Len(t, claims.Get("aud").Array(), 1, "%s", claims) + assert.EqualValues(t, c.ClientID, claims.Get("aud").Array()[0].String(), "%s", claims) + assert.EqualValues(t, expectedSubject, claims.Get("sub").String(), "%s", claims) + assert.EqualValues(t, expectedNonce, claims.Get("nonce").String(), "%s", claims) + assert.EqualValues(t, `baz`, claims.Get("bar").String(), "%s", claims) + assert.EqualValues(t, `foo@bar.com`, claims.Get("email").String(), "%s", claims) + assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) + + return claims } - refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { - require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) - iat := time.Now() - refreshedToken, err := conf.TokenSource(context.Background(), token).Token() - require.NoError(t, err) - - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) - require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - - introspectAccessToken(t, conf, refreshedToken, subject) - assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - return refreshedToken + introspectAccessToken := func(t *testing.T, conf *oauth2.Config, token *oauth2.Token, expectedSubject string) gjson.Result { + require.NotEmpty(t, token.AccessToken) + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.True(t, i.Get("active").Bool(), "%s", i) + assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) + assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + return i } - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - start := time.Now() - - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + assertJWTAccessToken := func(t *testing.T, strat string, conf *oauth2.Config, token *oauth2.Token, expectedSubject string, expectedExp time.Time, scopes string) gjson.Result { + require.NotEmpty(t, token.AccessToken) + parts := strings.Split(token.AccessToken, ".") + if strat != "jwt" { + require.Len(t, parts, 2) + return gjson.Parse("null") + } + require.Len(t, parts, 3) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + body, err := x.DecodeSegment(parts[1]) + require.NoError(t, err) - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i := gjson.ParseBytes(body) + assert.NotEmpty(t, i.Get("jti").String()) + assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) + assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), i.Get("iss").String(), "%s", i) + assert.True(t, time.Now().After(time.Unix(i.Get("iat").Int(), 0)), "%s", i) + assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) + assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) + requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + assert.EqualValues(t, scopes, i.Get("scp").Raw, "%s", i) + return i + } - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + waitForRefreshTokenExpiry := func() { + time.Sleep(reg.Config().GetRefreshTokenLifespan(ctx) + time.Second) + } - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + subject := "aeneas-rekkas" + nonce := uuid.New() - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) + t.Run("case=checks if request fails when audience does not match", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) + _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("audience", "https://not-ory-api/")) + require.Empty(t, code) }) - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - start := time.Now() - - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + t.Run("case=perform authorize code flow with ID token and refresh tokens", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() require.NoError(t, err) - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + assert.Empty(t, token.Extra("c_nonce_draft_00"), "should not be set if not requested") + assert.Empty(t, token.Extra("c_nonce_expires_in_draft_00"), "should not be set if not requested") + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat = time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + introspectAccessToken(t, conf, refreshedToken, subject) - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=refreshed tokens contain valid tokens", func(t *testing.T) { + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + }) - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=original access token is no longer valid", func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - }) + t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + }) - t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - a1Refresh = refreshTokens(t, token) - }) + t.Run("followup=but fail subsequent refresh because expiry was reached", func(t *testing.T) { + waitForRefreshTokenExpiry() - t.Run("followup=second refresh", func(t *testing.T) { - b1Refresh = refreshTokens(t, token) - }) + // Force golang to refresh token + refreshedToken.Expiry = refreshedToken.Expiry.Add(-time.Hour * 24) + _, err := conf.TokenSource(context.Background(), refreshedToken).Token() + require.Error(t, err) + }) + }) + } - t.Run("followup=first refresh from first refresh", func(t *testing.T) { - a2RefreshA = refreshTokens(t, a1Refresh) + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") }) - t.Run("followup=second refresh from first refresh", func(t *testing.T) { - a2RefreshB = refreshTokens(t, a1Refresh) + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") }) + }) - t.Run("followup=first refresh from second refresh", func(t *testing.T) { - b2RefreshA = refreshTokens(t, b1Refresh) - }) + t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { + // Make sure we test against all crypto suites that we advertise. + cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() + require.NoError(t, err) + supportedCryptoSuites := cfg.CredentialsSupportedDraft00[0].CryptographicSuitesSupported + + run := func(t *testing.T, strategy string) { + _, conf := newOAuth2Client( + t, + reg, + testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler), + withScope("openid userinfo_credential_draft_00"), + ) + testhelpers.NewLoginConsentUI(t, reg.Config(), + func(w http.ResponseWriter, r *http.Request) { + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, + } + v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + }, + func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) - t.Run("followup=second refresh from second refresh", func(t *testing.T) { - b2RefreshB = refreshTokens(t, b1Refresh) - }) + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"openid", "userinfo_credential_draft_00"}, + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"email": "foo@bar.com", "bar": "baz"}, + }, + }). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + }, + ) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("scope", "openid userinfo_credential_draft_00"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + iat := time.Now() - for k, token := range []*oauth2.Token{ - a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + vcNonce := token.Extra("c_nonce_draft_00").(string) + assert.NotEmpty(t, vcNonce) + expiry := token.Extra("c_nonce_expires_in_draft_00") + assert.NotEmpty(t, expiry) + assert.NoError(t, reg.Persister().IsNonceValid(ctx, token.AccessToken, vcNonce)) - i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=successfully create a verifiable credential", func(t *testing.T) { + t.Parallel() - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + for _, alg := range supportedCryptoSuites { + alg := alg + t.Run(fmt.Sprintf("alg=%s", alg), func(t *testing.T) { + t.Parallel() + assertCreateVerifiableCredential(t, reg, vcNonce, token, jose.SignatureAlgorithm(alg)) + }) + } + }) - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("followup=get new nonce from priming request", func(t *testing.T) { + t.Parallel() + // Assert that we can fetch a verifiable credential with the nonce. + res, err := doPrimingRequest(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ + Format: "jwt_vc_json", + Types: []string{"VerifiableCredential", "UserInfoCredential"}, }) - } - }) - }) - } + assert.NoError(t, err) - t.Run("strategy=jwt", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt") - }) + t.Run("followup=successfully create a verifiable credential from fresh nonce", func(t *testing.T) { + assertCreateVerifiableCredential(t, reg, res.Nonce, token, jose.ES256) + }) + }) - t.Run("strategy=opaque", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque") - }) - }) + t.Run("followup=rejects proof signed by another key", func(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + name string + format string + proofType string + proof func() string + }{ + { + name: "proof=mismatching keys", + proof: func() string { + // Create mismatching public and private keys. + pubKey, _, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + _, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) + }, + }, + { + name: "proof=invalid format", + format: "invalid_format", + proof: func() string { + // Create mismatching public and private keys. + pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) + }, + }, + { + name: "proof=invalid type", + proofType: "invalid", + proof: func() string { + // Create mismatching public and private keys. + pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) + }, + }, + { + name: "proof=invalid nonce", + proof: func() string { + // Create mismatching public and private keys. + pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) + require.NoError(t, err) + pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} + return createVCProofJWT(t, pubKeyJWK, privKey, "invalid nonce") + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := createVerifiableCredential(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ + Format: stringsx.Coalesce(tc.format, "jwt_vc_json"), + Types: []string{"VerifiableCredential", "UserInfoCredential"}, + Proof: &hydraoauth2.VerifiableCredentialProof{ + ProofType: stringsx.Coalesce(tc.proofType, "jwt"), + JWT: tc.proof(), + }, + }) + require.Error(t, err) + assert.Equal(t, "invalid_request", err.Error()) + }) + } - t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { - // Make sure we test against all crypto suites that we advertise. - cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() - require.NoError(t, err) - supportedCryptoSuites := cfg.CredentialsSupportedDraft00[0].CryptographicSuitesSupported - - run := func(t *testing.T, strategy string) { - _, conf := newOAuth2Client( - t, - reg, - testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler), - withScope("openid userinfo_credential_draft_00"), - ) - testhelpers.NewLoginConsentUI(t, reg.Config(), - func(w http.ResponseWriter, r *http.Request) { - acceptBody := hydra.AcceptOAuth2LoginRequest{ - Subject: subject, - Acr: pointerx.Ptr("1"), - Amr: []string{"pwd"}, - Context: map[string]interface{}{"context": "bar"}, - } - v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(acceptBody). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - }, - func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() - require.NoError(t, err) + }) - assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) - v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). - ConsentChallenge(r.URL.Query().Get("consent_challenge")). - AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ - GrantScope: []string{"openid", "userinfo_credential_draft_00"}, - GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, - Session: &hydra.AcceptOAuth2ConsentRequestSession{ - AccessToken: map[string]interface{}{"foo": "bar"}, - IdToken: map[string]interface{}{"email": "foo@bar.com", "bar": "baz"}, - }, - }). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - }, - ) - - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("scope", "openid userinfo_credential_draft_00"), - ) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - iat := time.Now() - - vcNonce := token.Extra("c_nonce_draft_00").(string) - assert.NotEmpty(t, vcNonce) - expiry := token.Extra("c_nonce_expires_in_draft_00") - assert.NotEmpty(t, expiry) - assert.NoError(t, reg.Persister().IsNonceValid(ctx, token.AccessToken, vcNonce)) - - t.Run("followup=successfully create a verifiable credential", func(t *testing.T) { - t.Parallel() - - for _, alg := range supportedCryptoSuites { - alg := alg - t.Run(fmt.Sprintf("alg=%s", alg), func(t *testing.T) { - t.Parallel() - assertCreateVerifiableCredential(t, reg, vcNonce, token, jose.SignatureAlgorithm(alg)) + t.Run("followup=access token and id token are valid", func(t *testing.T) { + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["openid","userinfo_credential_draft_00"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) }) } - }) - t.Run("followup=get new nonce from priming request", func(t *testing.T) { - t.Parallel() - // Assert that we can fetch a verifiable credential with the nonce. - res, err := doPrimingRequest(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ - Format: "jwt_vc_json", - Types: []string{"VerifiableCredential", "UserInfoCredential"}, + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") }) - assert.NoError(t, err) - t.Run("followup=successfully create a verifiable credential from fresh nonce", func(t *testing.T) { - assertCreateVerifiableCredential(t, reg, res.Nonce, token, jose.ES256) + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") }) }) - t.Run("followup=rejects proof signed by another key", func(t *testing.T) { - t.Parallel() - for _, tc := range []struct { - name string - format string - proofType string - proof func() string - }{ - { - name: "proof=mismatching keys", - proof: func() string { - // Create mismatching public and private keys. - pubKey, _, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - _, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) - }, + t.Run("suite=invalid query params", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + otherClient, _ := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + withWrongClientAfterLogin := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("login_verifier") { + return nil + } + q.Set("client_id", otherClient.GetID()) + req.URL.RawQuery = q.Encode() + return nil }, - { - name: "proof=invalid format", - format: "invalid_format", - proof: func() string { - // Create mismatching public and private keys. - pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) - }, + } + withWrongClientAfterConsent := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("consent_verifier") { + return nil + } + q.Set("client_id", otherClient.GetID()) + req.URL.RawQuery = q.Encode() + return nil }, - { - name: "proof=invalid type", - proofType: "invalid", - proof: func() string { - // Create mismatching public and private keys. - pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, vcNonce) - }, + } + + withWrongScopeAfterLogin := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("login_verifier") { + return nil + } + q.Set("scope", "invalid scope") + req.URL.RawQuery = q.Encode() + return nil }, - { - name: "proof=invalid nonce", - proof: func() string { - // Create mismatching public and private keys. - pubKey, privKey, err := josex.NewSigningKey(jose.ES256, 0) - require.NoError(t, err) - pubKeyJWK := &jose.JSONWebKey{Key: pubKey, Algorithm: string(jose.ES256)} - return createVCProofJWT(t, pubKeyJWK, privKey, "invalid nonce") - }, + } + + withWrongScopeAfterConsent := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("consent_verifier") { + return nil + } + q.Set("scope", "invalid scope") + req.URL.RawQuery = q.Encode() + return nil }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - _, res := createVerifiableCredential(t, reg, token, &hydraoauth2.CreateVerifiableCredentialRequestBody{ - Format: stringsx.Coalesce(tc.format, "jwt_vc_json"), - Types: []string{"VerifiableCredential", "UserInfoCredential"}, - Proof: &hydraoauth2.VerifiableCredentialProof{ - ProofType: stringsx.Coalesce(tc.proofType, "jwt"), - JWT: tc.proof(), - }, - }) + } + for _, tc := range []struct { + name string + client *http.Client + expectedResponse string + }{{ + name: "fails with wrong client ID after login", + client: withWrongClientAfterLogin, + expectedResponse: "invalid_client", + }, { + name: "fails with wrong client ID after consent", + client: withWrongClientAfterConsent, + expectedResponse: "invalid_client", + }, { + name: "fails with wrong scopes after login", + client: withWrongScopeAfterLogin, + expectedResponse: "invalid_scope", + }, { + name: "fails with wrong scopes after consent", + client: withWrongScopeAfterConsent, + expectedResponse: "invalid_scope", + }} { + t.Run("case="+tc.name, func(t *testing.T) { + state := uuid.New() + resp, err := tc.client.Get(conf.AuthCodeURL(state)) require.NoError(t, err) - require.NotNil(t, res) - assert.Equal(t, "invalid_request", res.Error()) + assert.Equal(t, tc.expectedResponse, resp.Request.URL.Query().Get("error"), "%s", resp.Request.URL.RawQuery) + resp.Body.Close() }) } - }) - t.Run("followup=access token and id token are valid", func(t *testing.T) { - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["openid","userinfo_credential_draft_00"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + t.Run("case=checks if request fails when subject is empty", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), func(w http.ResponseWriter, r *http.Request) { + _, res, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(ctx). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(hydra.AcceptOAuth2LoginRequest{Subject: "", Remember: pointerx.Ptr(true)}).Execute() + require.Error(t, err) // expects 400 + body := string(ioutilx.MustReadAll(res.Body)) + assert.Contains(t, body, "Field 'subject' must not be empty", "%s", body) + }, testhelpers.HTTPServerNoExpectedCallHandler(t)) + _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + + _, err := testhelpers.NewEmptyJarClient(t).Get(conf.AuthCodeURL(uuid.New())) + require.NoError(t, err) }) - } - t.Run("strategy=jwt", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt") - }) + t.Run("case=perform flow with prompt=registration", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - t.Run("strategy=opaque", func(t *testing.T) { - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque") - }) - }) + regUI := httptest.NewServer(acceptLoginHandler(t, c, adminClient, reg, subject, nil)) + t.Cleanup(regUI.Close) + reg.Config().MustSet(ctx, config.KeyRegistrationURL, regUI.URL) - t.Run("suite=invalid query params", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - otherClient, _ := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - withWrongClientAfterLogin := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("login_verifier") { - return nil - } - q.Set("client_id", otherClient.GetID()) - req.URL.RawQuery = q.Encode() - return nil - }, - } - withWrongClientAfterConsent := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("consent_verifier") { - return nil - } - q.Set("client_id", otherClient.GetID()) - req.URL.RawQuery = q.Encode() - return nil - }, - } + testhelpers.NewLoginConsentUI(t, reg.Config(), + nil, + acceptConsentHandler(t, c, adminClient, reg, subject, nil)) - withWrongScopeAfterLogin := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("login_verifier") { - return nil - } - q.Set("scope", "invalid scope") - req.URL.RawQuery = q.Encode() - return nil - }, - } + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("prompt", "registration"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - withWrongScopeAfterConsent := &http.Client{ - Jar: testhelpers.NewEmptyCookieJar(t), - CheckRedirect: func(req *http.Request, _ []*http.Request) error { - if req.URL.Path != "/oauth2/auth" { - return nil - } - q := req.URL.Query() - if !q.Has("consent_verifier") { - return nil - } - q.Set("scope", "invalid scope") - req.URL.RawQuery = q.Encode() - return nil - }, - } - for _, tc := range []struct { - name string - client *http.Client - expectedResponse string - }{{ - name: "fails with wrong client ID after login", - client: withWrongClientAfterLogin, - expectedResponse: "invalid_client", - }, { - name: "fails with wrong client ID after consent", - client: withWrongClientAfterConsent, - expectedResponse: "invalid_client", - }, { - name: "fails with wrong scopes after login", - client: withWrongScopeAfterLogin, - expectedResponse: "invalid_scope", - }, { - name: "fails with wrong scopes after consent", - client: withWrongScopeAfterConsent, - expectedResponse: "invalid_scope", - }} { - t.Run("case="+tc.name, func(t *testing.T) { - state := uuid.New() - resp, err := tc.client.Get(conf.AuthCodeURL(state)) + token, err := conf.Exchange(context.Background(), code) require.NoError(t, err) - assert.Equal(t, tc.expectedResponse, resp.Request.URL.Query().Get("error"), "%s", resp.Request.URL.RawQuery) - resp.Body.Close() + + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) }) - } - }) - t.Run("case=checks if request fails when subject is empty", func(t *testing.T) { - testhelpers.NewLoginConsentUI(t, reg.Config(), func(w http.ResponseWriter, r *http.Request) { - _, res, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(ctx). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(hydra.AcceptOAuth2LoginRequest{Subject: "", Remember: pointerx.Ptr(true)}).Execute() - require.Error(t, err) // expects 400 - body := string(ioutilx.MustReadAll(res.Body)) - assert.Contains(t, body, "Field 'subject' must not be empty", "%s", body) - }, testhelpers.HTTPServerNoExpectedCallHandler(t)) - _, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - - _, err := testhelpers.NewEmptyJarClient(t).Get(conf.AuthCodeURL(uuid.New())) - require.NoError(t, err) - }) + t.Run("case=perform flow with audience", func(t *testing.T) { + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) - t.Run("case=perform flow with prompt=registration", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - regUI := httptest.NewServer(acceptLoginHandler(t, c, subject, nil)) - t.Cleanup(regUI.Close) - reg.Config().MustSet(ctx, config.KeyRegistrationURL, regUI.URL) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - testhelpers.NewLoginConsentUI(t, reg.Config(), - nil, - acceptConsentHandler(t, c, subject, nil)) + claims := introspectAccessToken(t, conf, token, subject) + aud := claims.Get("aud").Array() + require.Len(t, aud, 1) + assert.EqualValues(t, aud[0].String(), expectAud) - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("prompt", "registration"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + }) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + t.Run("case=respects client token lifespan configuration", func(t *testing.T) { + run := func(t *testing.T, strategy string, c *client.Client, conf *oauth2.Config, expectedLifespans client.Lifespans) { + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) - assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - }) + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) - t.Run("case=perform flow with audience", func(t *testing.T) { - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + body := introspectAccessToken(t, conf, token, subject) + requirex.EqualTime(t, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(expectedLifespans.AuthorizationCodeGrantIDTokenLifespan.Duration)) + assertRefreshToken(t, token, conf, iat.Add(expectedLifespans.AuthorizationCodeGrantRefreshTokenLifespan.Duration)) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + iat = time.Now() + require.NoError(t, err) + assertRefreshToken(t, refreshedToken, conf, iat.Add(expectedLifespans.RefreshTokenGrantRefreshTokenLifespan.Duration)) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(expectedLifespans.RefreshTokenGrantIDTokenLifespan.Duration)) - claims := introspectAccessToken(t, conf, token, subject) - aud := claims.Get("aud").Array() - require.Len(t, aud, 1) - assert.EqualValues(t, aud[0].String(), expectAud) + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - }) + body := introspectAccessToken(t, conf, refreshedToken, subject) + requirex.EqualTime(t, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) - t.Run("case=respects client token lifespan configuration", func(t *testing.T) { - run := func(t *testing.T, strategy string, c *client.Client, conf *oauth2.Config, expectedLifespans client.Lifespans) { - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - iat := time.Now() - require.NoError(t, err) - - body := introspectAccessToken(t, conf, token, subject) - requirex.EqualTime(t, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) - - assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(expectedLifespans.AuthorizationCodeGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) - assertIDToken(t, token, conf, subject, nonce, iat.Add(expectedLifespans.AuthorizationCodeGrantIDTokenLifespan.Duration)) - assertRefreshToken(t, token, conf, iat.Add(expectedLifespans.AuthorizationCodeGrantRefreshTokenLifespan.Duration)) - - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) - refreshedToken, err := conf.TokenSource(context.Background(), token).Token() - iat = time.Now() - require.NoError(t, err) - assertRefreshToken(t, refreshedToken, conf, iat.Add(expectedLifespans.RefreshTokenGrantRefreshTokenLifespan.Duration)) - assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), `["hydra","offline","openid"]`) - assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(expectedLifespans.RefreshTokenGrantIDTokenLifespan.Duration)) + t.Run("followup=original access token is no longer valid", func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) - require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + }) + }) + } - body := introspectAccessToken(t, conf, refreshedToken, subject) - requirex.EqualTime(t, iat.Add(expectedLifespans.RefreshTokenGrantAccessTokenLifespan.Duration), time.Unix(body.Get("exp").Int(), 0), time.Second) + t.Run("case=custom-lifespans-active-jwt", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + ls := testhelpers.TestLifespans + ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} + testhelpers.UpdateClientTokenLifespans( + t, + &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, + c.GetID(), + ls, adminTS, + ) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt", c, conf, ls) + }) - t.Run("followup=original access token is no longer valid", func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + t.Run("case=custom-lifespans-active-opaque", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + ls := testhelpers.TestLifespans + ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} + testhelpers.UpdateClientTokenLifespans( + t, + &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, + c.GetID(), + ls, adminTS, + ) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque", c, conf, ls) }) - t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + t.Run("case=custom-lifespans-unset", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), testhelpers.TestLifespans, adminTS) + testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), client.Lifespans{}, adminTS) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + + //goland:noinspection GoDeprecation + expectedLifespans := client.Lifespans{ + AuthorizationCodeGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + AuthorizationCodeGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, + AuthorizationCodeGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, + ClientCredentialsGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + ImplicitGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + ImplicitGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, + JwtBearerGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + PasswordGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + PasswordGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, + RefreshTokenGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, + RefreshTokenGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, + RefreshTokenGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, + } + run(t, "opaque", c, conf, expectedLifespans) }) }) - } - t.Run("case=custom-lifespans-active-jwt", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - ls := testhelpers.TestLifespans - ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} - testhelpers.UpdateClientTokenLifespans( - t, - &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, - c.GetID(), - ls, adminTS, - ) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") - run(t, "jwt", c, conf, ls) - }) + t.Run("case=use remember feature and prompt=none", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) - t.Run("case=custom-lifespans-active-opaque", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - ls := testhelpers.TestLifespans - ls.AuthorizationCodeGrantAccessTokenLifespan = x.NullDuration{Valid: true, Duration: 6 * time.Second} - testhelpers.UpdateClientTokenLifespans( - t, - &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, - c.GetID(), - ls, adminTS, - ) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") - run(t, "opaque", c, conf, ls) - }) + oc := testhelpers.NewEmptyJarClient(t) + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "login consent"), + oauth2.SetAuthURLParam("max_age", "1"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + introspectAccessToken(t, conf, token, subject) - t.Run("case=custom-lifespans-unset", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), testhelpers.TestLifespans, adminTS) - testhelpers.UpdateClientTokenLifespans(t, &oauth2.Config{ClientID: c.GetID(), ClientSecret: conf.ClientSecret}, c.GetID(), client.Lifespans{}, adminTS) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + // Reset UI to check for skip values + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + require.True(t, r.Skip) + require.EqualValues(t, subject, r.Subject) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + require.True(t, *r.Skip) + require.EqualValues(t, subject, *r.Subject) + }), + ) - //goland:noinspection GoDeprecation - expectedLifespans := client.Lifespans{ - AuthorizationCodeGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - AuthorizationCodeGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, - AuthorizationCodeGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, - ClientCredentialsGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - ImplicitGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - ImplicitGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, - JwtBearerGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - PasswordGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - PasswordGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, - RefreshTokenGrantIDTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetIDTokenLifespan(ctx)}, - RefreshTokenGrantAccessTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetAccessTokenLifespan(ctx)}, - RefreshTokenGrantRefreshTokenLifespan: x.NullDuration{Valid: true, Duration: reg.Config().GetRefreshTokenLifespan(ctx)}, - } - run(t, "opaque", c, conf, expectedLifespans) - }) - }) + t.Run("followup=checks if authenticatedAt/requestedAt is properly forwarded across the lifecycle by checking if prompt=none works", func(t *testing.T) { + // In order to check if authenticatedAt/requestedAt works, we'll sleep first in order to ensure that authenticatedAt is in the past + // if handled correctly. + time.Sleep(time.Second + time.Nanosecond) - t.Run("case=use remember feature and prompt=none", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - oc := testhelpers.NewEmptyJarClient(t) - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("prompt", "login consent"), - oauth2.SetAuthURLParam("max_age", "1"), - ) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - introspectAccessToken(t, conf, token, subject) - - // Reset UI to check for skip values - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - require.True(t, r.Skip) - require.EqualValues(t, subject, r.Subject) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - require.True(t, *r.Skip) - require.EqualValues(t, subject, *r.Subject) - }), - ) - - t.Run("followup=checks if authenticatedAt/requestedAt is properly forwarded across the lifecycle by checking if prompt=none works", func(t *testing.T) { - // In order to check if authenticatedAt/requestedAt works, we'll sleep first in order to ensure that authenticatedAt is in the past - // if handled correctly. - time.Sleep(time.Second + time.Nanosecond) - - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("prompt", "none"), - oauth2.SetAuthURLParam("max_age", "60"), - ) - require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - original := introspectAccessToken(t, conf, token, subject) - - t.Run("followup=run the flow three more times", func(t *testing.T) { - for i := 0; i < 3; i++ { - t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "none"), + oauth2.SetAuthURLParam("max_age", "60"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + original := introspectAccessToken(t, conf, token, subject) + + t.Run("followup=run the flow three more times", func(t *testing.T) { + for i := 0; i < 3; i++ { + t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "none"), + oauth2.SetAuthURLParam("max_age", "60"), + ) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + followup := introspectAccessToken(t, conf, token, subject) + assert.Equal(t, original.Get("auth_time").Int(), followup.Get("auth_time").Int()) + }) + } + }) + + t.Run("followup=fails when max age is reached and prompt is none", func(t *testing.T) { code, _ := getAuthorizeCode(t, conf, oc, oauth2.SetAuthURLParam("nonce", nonce), oauth2.SetAuthURLParam("prompt", "none"), - oauth2.SetAuthURLParam("max_age", "60"), + oauth2.SetAuthURLParam("max_age", "1"), + ) + require.Empty(t, code) + }) + + t.Run("followup=passes and resets skip when prompt=login", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + require.False(t, r.Skip) + require.Empty(t, r.Subject) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + require.True(t, *r.Skip) + require.EqualValues(t, subject, *r.Subject) + }), + ) + code, _ := getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("prompt", "login"), + oauth2.SetAuthURLParam("max_age", "1"), ) require.NotEmpty(t, code) token, err := conf.Exchange(context.Background(), code) require.NoError(t, err) - followup := introspectAccessToken(t, conf, token, subject) - assert.Equal(t, original.Get("auth_time").Int(), followup.Get("auth_time").Int()) + introspectAccessToken(t, conf, token, subject) + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) }) - } + }) }) - t.Run("followup=fails when max age is reached and prompt is none", func(t *testing.T) { + t.Run("case=should fail if prompt=none but no auth session given", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + oc := testhelpers.NewEmptyJarClient(t) code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), oauth2.SetAuthURLParam("prompt", "none"), - oauth2.SetAuthURLParam("max_age", "1"), ) require.Empty(t, code) }) - t.Run("followup=passes and resets skip when prompt=login", func(t *testing.T) { + t.Run("case=requires re-authentication when id_token_hint is set to a user 'patrik-neu' but the session is 'aeneas-rekkas' and then fails because the user id from the log in endpoint is 'aeneas-rekkas'", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.False(t, r.Skip) require.Empty(t, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - require.True(t, *r.Skip) - require.EqualValues(t, subject, *r.Subject) + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + oc := testhelpers.NewEmptyJarClient(t) + + // Create login session for aeneas-rekkas + code, _ := getAuthorizeCode(t, conf, oc) + require.NotEmpty(t, code) + + // Perform authentication for aeneas-rekkas which fails because id_token_hint is patrik-neu + code, _ = getAuthorizeCode(t, conf, oc, + oauth2.SetAuthURLParam("id_token_hint", testhelpers.NewIDToken(t, reg, "patrik-neu")), + ) + require.Empty(t, code) + }) + + t.Run("case=should not cause issues if max_age is very low and consent takes a long time", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + time.Sleep(time.Second * 2) + return nil }), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("nonce", nonce), - oauth2.SetAuthURLParam("prompt", "login"), - oauth2.SetAuthURLParam("max_age", "1"), + + code, _ := getAuthorizeCode(t, conf, nil) + require.NotEmpty(t, code) + }) + + t.Run("case=ensure consistent claims returned for userinfo", func(t *testing.T) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) + + code, _ := getAuthorizeCode(t, conf, nil) require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) require.NoError(t, err) - introspectAccessToken(t, conf, token, subject) - assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + + idClaims := assertIDToken(t, token, conf, subject, "", time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + + uiClaims := testhelpers.Userinfo(t, token, publicTS) + + for _, f := range []string{ + "sub", + "iss", + "aud", + "bar", + "auth_time", + } { + assert.NotEmpty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) + assert.EqualValues(t, idClaims.Get(f).Raw, uiClaims.Get(f).Raw, "%s\nuserinfo: %s\nidtoken: %s", f, uiClaims, idClaims) + } + + for _, f := range []string{ + "at_hash", + "c_hash", + "nonce", + "sid", + "jti", + } { + assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) + } }) - }) - }) - t.Run("case=should fail if prompt=none but no auth session given", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) - - oc := testhelpers.NewEmptyJarClient(t) - code, _ := getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("prompt", "none"), - ) - require.Empty(t, code) - }) + t.Run("case=add ext claims from hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value") + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, map[string]interface{}{"foo": "bar"}, hookReq.Session.Extra) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, []string{}, hookReq.Request.GrantedAudience) + require.Equal(t, map[string][]string{"grant_type": {"authorization_code"}}, hookReq.Request.Payload) + + claims := map[string]interface{}{ + "hooked": true, + } - t.Run("case=requires re-authentication when id_token_hint is set to a user 'patrik-neu' but the session is 'aeneas-rekkas' and then fails because the user id from the log in endpoint is 'aeneas-rekkas'", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - require.False(t, r.Skip) - require.Empty(t, r.Subject) - return nil - }), - acceptConsentHandler(t, c, subject, nil), - ) - - oc := testhelpers.NewEmptyJarClient(t) - - // Create login session for aeneas-rekkas - code, _ := getAuthorizeCode(t, conf, oc) - require.NotEmpty(t, code) - - // Perform authentication for aeneas-rekkas which fails because id_token_hint is patrik-neu - code, _ = getAuthorizeCode(t, conf, oc, - oauth2.SetAuthURLParam("id_token_hint", testhelpers.NewIDToken(t, reg, "patrik-neu")), - ) - require.Empty(t, code) - }) + hookResp := hydraoauth2.TokenHookResponse{ + Session: flow.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } - t.Run("case=should not cause issues if max_age is very low and consent takes a long time", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - time.Sleep(time.Second * 2) - return nil - }), - acceptConsentHandler(t, c, subject, nil), - ) - - code, _ := getAuthorizeCode(t, conf, nil) - require.NotEmpty(t, code) - }) + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{ + URL: hs.URL, + Auth: &config.Auth{ + Type: "api_key", + Config: config.AuthConfig{ + In: "header", + Name: "Authorization", + Value: "Bearer secret value", + }, + }, + }) - t.Run("case=ensure consistent claims returned for userinfo", func(t *testing.T) { - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), - ) + t.Cleanup(func() { + reg.Config().Delete(ctx, config.KeyTokenHook) + }) - code, _ := getAuthorizeCode(t, conf, nil) - require.NotEmpty(t, code) + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - idClaims := assertIDToken(t, token, conf, subject, "", time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - time.Sleep(time.Second) - uiClaims := testhelpers.Userinfo(t, token, publicTS) + // NOTE: using introspect to cover both jwt and opaque strategies + accessTokenClaims := introspectAccessToken(t, conf, token, subject) + require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) - for _, f := range []string{ - "sub", - "iss", - "aud", - "bar", - "auth_time", - } { - assert.NotEmpty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) - assert.EqualValues(t, idClaims.Get(f).Raw, uiClaims.Get(f).Raw, "%s\nuserinfo: %s\nidtoken: %s", f, uiClaims, idClaims) - } + idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + require.True(t, idTokenClaims.Get("hooked").Bool()) + } + } - for _, f := range []string{ - "at_hash", - "c_hash", - "nonce", - "sid", - "jti", - } { - assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) - } - }) + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook fails", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) - t.Run("case=add ext claims from hook if configured", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value") - - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, map[string]interface{}{"foo": "bar"}, hookReq.Session.Extra) - require.NotEmpty(t, hookReq.Request) - require.ElementsMatch(t, []string{}, hookReq.Request.GrantedAudience) - require.Equal(t, map[string][]string{"grant_type": {"authorization_code"}}, hookReq.Request.Payload) - - claims := map[string]interface{}{ - "hooked": true, + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) } + } - hookResp := hydraoauth2.TokenHookResponse{ - Session: flow.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) } + } - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{ - URL: hs.URL, - Auth: &config.Auth{ - Type: "api_key", - Config: config.AuthConfig{ - In: "header", - Name: "Authorization", - Value: "Bearer secret value", - }, - }, + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=graceful token rotation", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "2s") + reg.Config().Delete(ctx, config.KeyTokenHook) + reg.Config().Delete(ctx, config.KeyRefreshTokenHook) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, "1m") + t.Cleanup(func() { + reg.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) + reg.Config().Delete(ctx, config.KeyRefreshTokenLifespan) + reg.Config().Delete(ctx, config.KeyAccessTokenLifespan) }) - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + // This is an essential and complex test suite. We need to cover the following cases: + // + // * Graceful refresh token rotation invalidates the previous access token. + // * An expired refresh token cannot be used even if grace period is active. + // * A used refresh token cannot be re-used once the grace period ends, and it triggers re-use detection. + // * A test suite with a variety of concurrent refresh token chains. + run := func(t *testing.T, strategy string) { + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), + ) + + issueTokens := func(t *testing.T) *oauth2.Token { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return token + } - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + iat := time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) - assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + introspectAccessToken(t, conf, refreshedToken, subject) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return refreshedToken + } - // NOTE: using introspect to cover both jwt and opaque strategies - accessTokenClaims := introspectAccessToken(t, conf, token, subject) - require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) + assertInactive := func(t *testing.T, token string, c *oauth2.Config) { + t.Helper() + at := testhelpers.IntrospectToken(t, conf, token, adminTS) + assert.False(t, at.Get("active").Bool(), "%s", at) + } - idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - require.True(t, idTokenClaims.Get("hooked").Bool()) - } - } + t.Run("gracefully refreshing a token does invalidate the previous access token", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "2s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + token := issueTokens(t) + _ = refreshTokens(t, token) - t.Run("case=fail token exchange if hook fails", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() + assertInactive(t, token.AccessToken, conf) // Original access token is invalid - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + _ = refreshTokens(t, token) + assertInactive(t, token.AccessToken, conf) // Original access token is still invalid + }) - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + t.Run("an expired refresh token can not be used even if we are in the grace period", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1s") - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + token := issueTokens(t) + time.Sleep(time.Second * 2) // Let token expire - we need 2 seconds to reliably be longer than TTL - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err, "Rotating an expired token is not possible even when we are in the grace period") - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } + // The access token is still valid because using an expired refresh token has no effect on the access token. + assertInactive(t, token.RefreshToken, conf) + }) - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + t.Run("a used refresh token can not be re-used once the grace period ends and it triggers re-use detection", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") - t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() + token := issueTokens(t) + refreshed := refreshTokens(t, token) - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + time.Sleep(time.Second + time.Millisecond*100) // Wait for the grace period to end - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err, "Rotating a used refresh token is not possible after the grace period") - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + assertInactive(t, refreshed.AccessToken, conf) + assertInactive(t, refreshed.RefreshToken, conf) + }) - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } + // This test suite covers complex scenarios where we have multiple generations of tokens and we need to ensure + // that key security mitigations are in place: + // + // - Token re-use detection clears all tokens if a refresh token is re-used after the grace period. + // - Revoking consent clears all tokens. + // - Token revokation clears all tokens. + // + // The test creates 4 token generations, where each generations has twice as many tokens as the previous generation. + // The generations are created like this: + // + // - In the first scenario, all token generations are created at the same time. + // - In the second scenario, we create token generations with a delay that is longer than the grace period between them. + // + // Tokens for each generation are created in parallel to ensure we have no state leak anywhere.0 + t.Run("token generations", func(t *testing.T) { + + gracePeriod := time.Second + aboveGracePeriod := time.Second + time.Millisecond*100 + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, gracePeriod.String()) + reg.Config().Delete(ctx, config.KeyTokenHook) + reg.Config().Delete(ctx, config.KeyRefreshTokenHook) + + createTokenGenerations := func(t *testing.T, count int, withSleep time.Duration) [][]*oauth2.Token { + generations := make([][]*oauth2.Token, count) + generations[0] = []*oauth2.Token{issueTokens(t)} + // Start from the first generation. For every next generation, we refresh all the tokens of the previous generation twice. + for i := 1; i < len(generations); i++ { + generations[i] = make([]*oauth2.Token, 0, len(generations[i-1])*2) + + var wg sync.WaitGroup + gen := func(i int, token *oauth2.Token) { + defer wg.Done() + generations[i] = append(generations[i], refreshTokens(t, token)) + } - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + for _, token := range generations[i-1] { + wg.Add(2) + if dbName == "memory" { + // SQLite can not handle concurrency + gen(i, token) + gen(i, token) + } else { + go gen(i, token) + go gen(i, token) + } + } - t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() + wg.Wait() + if withSleep > 0 { + time.Sleep(withSleep) + } + } + return generations + } - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Run("re-using an old graceful refresh token invalidates all tokens", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + // This test only works if the refresh token lifespan is longer than the grace period. + generations := createTokenGenerations(t, 4, time.Second+time.Millisecond*100) + + generationIndex := rng.Intn(len(generations) - 1) // Exclude the last generation + tokenIndex := rng.Intn(len(generations[generationIndex])) + + token := generations[generationIndex][tokenIndex] + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err) + + // Now all tokens are inactive + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) - defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + for _, withSleep := range []time.Duration{0, aboveGracePeriod} { + t.Run(fmt.Sprintf("withSleep=%s", withSleep), func(t *testing.T) { + createTokenGenerations := func(t *testing.T, count int) [][]*oauth2.Token { + return createTokenGenerations(t, count, withSleep) + } - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) + t.Run("only the most recent token generation is valid across the board", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + // All generations except the last one are valid. + for i, generation := range generations[:len(generations)-1] { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + }) + } + }) + } - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) + // The last generation is valid: + t.Run(fmt.Sprintf("generation=%d", len(generations)-1), func(t *testing.T) { + for j, token := range generations[len(generations)-1] { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + introspectAccessToken(t, conf, token, subject) + assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, time.Now().Add(reg.Config().GetRefreshTokenLifespan(ctx))) + }) + } + }) + }) + + t.Run("revoking consent revokes all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + // After revoking consent, all generations are invalid. + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + t.Run("re-using the a recent refresh token after the grace period has ended invalidates all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + token := generations[len(generations)-1][0] + + finalToken := refreshTokens(t, token) + time.Sleep(aboveGracePeriod) // Wait for the grace period to end + + token.Expiry = time.Now().Add(-time.Hour * 24) + _, err := conf.TokenSource(ctx, token).Token() + require.Error(t, err) + + // Now all tokens are inactive + for i, generation := range append(generations, []*oauth2.Token{finalToken}) { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + + t.Run("revoking a refresh token in the chain revokes all tokens", func(t *testing.T) { + generations := createTokenGenerations(t, 4) + + testhelpers.RevokeToken(t, conf, generations[len(generations)-1][0].RefreshToken, publicTS) + + for i, generation := range generations { + t.Run(fmt.Sprintf("generation=%d", i), func(t *testing.T) { + for j, token := range generation { + token := token + t.Run(fmt.Sprintf("token=%d", j), func(t *testing.T) { + assertInactive(t, token.AccessToken, conf) + assertInactive(t, token.RefreshToken, conf) + }) + } + }) + } + }) + }) + } + }) - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } + t.Run("it is possible to refresh tokens concurrently", func(t *testing.T) { + // SQLite can not handle concurrency + if dbName == "memory" { + t.Skip("Skipping test because SQLite can not handle concurrency") + } - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1m") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + + token := issueTokens(t) + + var wg sync.WaitGroup + refresh := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + tt, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + return tt + } + + refreshes := make([]*oauth2.Token, 5) + for k := range refreshes { + wg.Add(1) + go func(k int) { + defer wg.Done() + refreshes[k] = refresh(t, token) + }(k) + } + wg.Wait() + + // All tokens are valid. + for k, actual := range refreshes { + refresh := actual + require.NotEmpty(t, refresh.RefreshToken, "token %d:\ntoken:%+v", k, refresh) + require.NotEmpty(t, refresh.AccessToken, "token %d:\ntoken:%+v", k, refresh) + require.NotEmpty(t, refresh.Extra("id_token"), "token %d:\ntoken:%+v", k, refresh) + + i := testhelpers.IntrospectToken(t, conf, refresh.AccessToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + + i = testhelpers.IntrospectToken(t, conf, refresh.RefreshToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + } + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) + }) + } } func assertCreateVerifiableCredential(t *testing.T, reg driver.Registry, nonce string, accessToken *oauth2.Token, alg jose.SignatureAlgorithm) { @@ -1365,7 +1548,7 @@ func assertCreateVerifiableCredential(t *testing.T, reg driver.Registry, nonce s proofJWT := createVCProofJWT(t, pubKeyJWK, privKey, nonce) // Assert that we can fetch a verifiable credential with the nonce. - verifiableCredential, _ := createVerifiableCredential(t, reg, accessToken, &hydraoauth2.CreateVerifiableCredentialRequestBody{ + verifiableCredential, err := createVerifiableCredential(t, reg, accessToken, &hydraoauth2.CreateVerifiableCredentialRequestBody{ Format: "jwt_vc_json", Types: []string{"VerifiableCredential", "UserInfoCredential"}, Proof: &hydraoauth2.VerifiableCredentialProof{ @@ -1414,7 +1597,7 @@ func createVerifiableCredential( reg driver.Registry, token *oauth2.Token, createVerifiableCredentialReq *hydraoauth2.CreateVerifiableCredentialRequestBody, -) (vcRes *hydraoauth2.VerifiableCredentialResponse, vcErr *fosite.RFC6749Error) { +) (vcRes *hydraoauth2.VerifiableCredentialResponse, vcErr error) { var ( ctx = context.Background() body bytes.Buffer @@ -1486,18 +1669,18 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { ctx := context.Background() for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { t.Run("strategy="+strat.d, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2) conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + testhelpers.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) consentStrategy := &consentMock{} router := x.NewRouterPublic() ts := httptest.NewServer(router) - defer ts.Close() + t.Cleanup(ts.Close) reg.WithConsentStrategy(consentStrategy) handler := reg.OAuth2Handler() @@ -1511,7 +1694,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) var mutex sync.Mutex - require.NoError(t, reg.ClientManager().CreateClient(context.TODO(), &client.Client{ + require.NoError(t, reg.ClientManager().CreateClient(ctx, &client.Client{ ID: "app-client", Secret: "secret", RedirectURIs: []string{ts.URL + "/callback"}, @@ -1874,6 +2057,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) defer conf.MustSet(ctx, config.KeyTokenHook, nil) @@ -2033,13 +2217,13 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { t.Run("refreshing old token should no longer work", func(t *testing.T) { res, err := testRefresh(t, token, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + assert.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + assert.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("duplicate code exchange fails", func(t *testing.T) { diff --git a/oauth2/oauth2_client_credentials_bench_test.go b/oauth2/oauth2_client_credentials_bench_test.go index 310727f34cc..560925ffb3e 100644 --- a/oauth2/oauth2_client_credentials_bench_test.go +++ b/oauth2/oauth2_client_credentials_bench_test.go @@ -22,7 +22,6 @@ import ( hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" @@ -36,7 +35,7 @@ func BenchmarkClientCredentials(b *testing.B) { tracer := trace.NewTracerProvider(trace.WithSpanProcessor(spans)).Tracer("") dsn := "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable" - reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) + reg := testhelpers.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") public, admin := testhelpers.NewOAuth2Server(ctx, b, reg) diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 9d5067dafb1..a93ea067716 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -28,14 +28,13 @@ import ( hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" "github.com/ory/x/requirex" ) func TestClientCredentials(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") public, admin := testhelpers.NewOAuth2Server(ctx, t, reg) diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index e9e7ddf9120..0b1a862ba05 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -35,13 +35,12 @@ import ( hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) func TestJWTBearer(t *testing.T) { ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") _, admin := testhelpers.NewOAuth2Server(ctx, t, reg) diff --git a/oauth2/oauth2_refresh_token_test.go b/oauth2/oauth2_refresh_token_test.go index ffabb0dd2a0..018af343104 100644 --- a/oauth2/oauth2_refresh_token_test.go +++ b/oauth2/oauth2_refresh_token_test.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" @@ -22,7 +24,6 @@ import ( "github.com/ory/fosite" hc "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" "github.com/ory/x/dbal" @@ -89,12 +90,12 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { } net := &networkx.Network{} require.NoError(t, dbRegistry.Persister().Connection(context.Background()).First(net)) - dbRegistry.WithContextualizer(&contextx.Static{NID: net.ID, C: internal.NewConfigurationWithDefaults().Source(context.Background())}) + dbRegistry.WithContextualizer(&contextx.Static{NID: net.ID, C: testhelpers.NewConfigurationWithDefaults().Source(context.Background())}) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) t.Cleanup(cancel) require.NoError(t, dbRegistry.OAuth2Storage().(clientCreator).CreateClient(ctx, &testClient)) - require.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, request)) + require.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, "", request)) _, err := dbRegistry.OAuth2Storage().GetRefreshTokenSession(ctx, tokenSignature, nil) require.NoError(t, err) provider := dbRegistry.OAuth2Provider() @@ -250,7 +251,7 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { // reset state for the next test iteration assert.NoError(t, dbRegistry.OAuth2Storage().DeleteRefreshTokenSession(ctx, tokenSignature)) - assert.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, request)) + assert.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, "", request)) } } } diff --git a/oauth2/oauth2_rop_test.go b/oauth2/oauth2_rop_test.go index 4adb4904452..0428e86e7a1 100644 --- a/oauth2/oauth2_rop_test.go +++ b/oauth2/oauth2_rop_test.go @@ -22,7 +22,6 @@ import ( "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/fositex" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/kratos" "github.com/ory/hydra/v2/internal/testhelpers" hydraoauth2 "github.com/ory/hydra/v2/oauth2" @@ -34,7 +33,7 @@ import ( func TestResourceOwnerPasswordGrant(t *testing.T) { ctx := context.Background() fakeKratos := kratos.NewFake() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) reg.WithKratos(fakeKratos) reg.WithExtraFositeFactories([]fositex.Factory{compose.OAuth2ResourceOwnerPasswordCredentialsFactory}) publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) diff --git a/oauth2/revocator_test.go b/oauth2/revocator_test.go index 4ad0be8cac7..32283730fa9 100644 --- a/oauth2/revocator_test.go +++ b/oauth2/revocator_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/gobuffalo/pop/v6" "github.com/ory/x/httprouterx" @@ -60,10 +62,10 @@ func countAccessTokens(t *testing.T, c *pop.Connection) int { } func TestRevoke(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + conf := testhelpers.NewConfigurationWithDefaults() + reg := testhelpers.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(context.Background(), reg, x.OpenIDConnectKeyName) + testhelpers.MustEnsureRegistryKeys(context.Background(), reg, x.OpenIDConnectKeyName) internal.AddFositeExamples(reg) tokens := Tokens(reg.OAuth2ProviderConfig(), 4) diff --git a/oauth2/session_custom_claims_test.go b/oauth2/session_custom_claims_test.go index 5fbe3c5c1a5..5594df88021 100644 --- a/oauth2/session_custom_claims_test.go +++ b/oauth2/session_custom_claims_test.go @@ -7,11 +7,12 @@ import ( "context" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/stretchr/testify/assert" @@ -39,7 +40,7 @@ func createSessionWithCustomClaims(ctx context.Context, p *config.DefaultProvide func TestCustomClaimsInSession(t *testing.T) { ctx := context.Background() - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() t.Run("no_custom_claims", func(t *testing.T) { c.MustSet(ctx, config.KeyAllowedTopLevelClaims, []string{}) diff --git a/oauth2/trust/handler_test.go b/oauth2/trust/handler_test.go index daacc8ed282..e93066eac97 100644 --- a/oauth2/trust/handler_test.go +++ b/oauth2/trust/handler_test.go @@ -15,6 +15,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/go-jose/go-jose/v3" "github.com/tidwall/gjson" @@ -33,7 +35,6 @@ import ( hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/x" ) @@ -50,10 +51,10 @@ type HandlerTestSuite struct { // Setup will run before the tests in the suite are run. func (s *HandlerTestSuite) SetupSuite() { - conf := internal.NewConfigurationWithDefaults() + conf := testhelpers.NewConfigurationWithDefaults() conf.MustSet(context.Background(), config.KeySubjectTypesSupported, []string{"public"}) conf.MustSet(context.Background(), config.KeyDefaultClientScope, []string{"foo", "bar"}) - s.registry = internal.NewRegistryMemory(s.T(), conf, &contextx.Default{}) + s.registry = testhelpers.NewRegistryMemory(s.T(), conf, &contextx.Default{}) router := x.NewRouterAdmin(conf.AdminURL) handler := trust.NewHandler(s.registry) @@ -80,7 +81,7 @@ func (s *HandlerTestSuite) TearDownSuite() { // Will run after each test in the suite. func (s *HandlerTestSuite) TearDownTest() { - internal.CleanAndMigrate(s.registry)(s.T()) + testhelpers.CleanAndMigrate(s.registry)(s.T()) } // In order for 'go test' to run this suite, we need to create diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 8564cfab969..71435d95687 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -13,7 +13,8 @@ import ( "testing" "time" - "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/x/contextx" "github.com/bradleyjkemp/cupaloy/v2" @@ -64,7 +65,7 @@ func TestMigrations(t *testing.T) { connections := make(map[string]*pop.Connection, 1) if testing.Short() { - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) require.NoError(t, reg.Persister().MigrateUp(context.Background())) c := reg.Persister().Connection(context.Background()) connections["sqlite"] = c diff --git a/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql new file mode 100644 index 00000000000..46db0f98db5 --- /dev/null +++ b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN access_token_signature; diff --git a/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql new file mode 100644 index 00000000000..3b389709bc7 --- /dev/null +++ b/persistence/sql/migrations/20241129111700000000_add_refresh_token_access_token_link.autocommit.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD access_token_signature VARCHAR(255) DEFAULT NULL; diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index ba2647393a5..413e40a8eaa 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -33,7 +33,6 @@ var _ persistence.Persister = new(Persister) var _ storage.Transactional = new(Persister) var ( - ErrTransactionOpen = errors.New("There is already a Transaction in this context.") ErrNoTransactionOpen = errors.New("There is no Transaction in this context.") ) diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 5d556d44b4d..93bccdcfe58 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/fosite/handler/openid" "github.com/stretchr/testify/assert" @@ -29,7 +31,6 @@ import ( "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" @@ -57,11 +58,11 @@ var _ interface { func (s *PersisterTestSuite) SetupSuite() { s.registries = map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}), } if !testing.Short() { - s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) + s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = testhelpers.ConnectDatabases(s.T(), true, &contextx.Default{}) } s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) @@ -533,7 +534,7 @@ func (s *PersisterTestSuite) TestCreateRefreshTokenSession() { authorizeCode := uuid.Must(uuid.NewV4()).String() actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, authorizeCode, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, authorizeCode, "", request)) require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, authorizeCode)) require.Equal(t, s.t1NID, actual.NID) }) @@ -727,7 +728,7 @@ func (s *PersisterTestSuite) TestDeleteRefreshTokenSession() { request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} signature := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request)) actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} @@ -933,7 +934,7 @@ func (s *PersisterTestSuite) TestFlushInactiveRefreshTokens() { signature := uuid.Must(uuid.NewV4()).String() require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request)) actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} @@ -1392,7 +1393,7 @@ func (s *PersisterTestSuite) TestGetRefreshTokenSession() { request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} sig := uuid.Must(uuid.NewV4()).String() require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, sig, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, sig, "", request)) actual, err := r.Persister().GetRefreshTokenSession(s.t2, sig, &fosite.DefaultSession{}) require.Error(t, err) @@ -1777,47 +1778,114 @@ func (s *PersisterTestSuite) TestRevokeRefreshToken() { request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} signature := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) - - actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request)) + var actualt2 persistencesql.OAuth2RefreshTable require.NoError(t, r.Persister().RevokeRefreshToken(s.t2, request.ID)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, true, actual.Active) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actualt2, signature)) + require.Equal(t, true, actualt2.Active) + require.NoError(t, r.Persister().RevokeRefreshToken(s.t1, request.ID)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, false, actual.Active) + require.ErrorIs(t, r.Persister().Connection(context.Background()).Find(new(persistencesql.OAuth2RefreshTable), signature), sql.ErrNoRows) }) } } -func (s *PersisterTestSuite) TestRevokeRefreshTokenMaybeGracePeriod() { +func (s *PersisterTestSuite) TestRotateRefreshToken() { t := s.T() for k, r := range s.registries { t.Run(k, func(t *testing.T) { - client := &client.Client{ID: "client-id"} - require.NoError(t, r.Persister().CreateClient(s.t1, client)) + t.Run("with access signature", func(t *testing.T) { + clientID := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateClient(s.t1, &client.Client{ID: clientID})) + require.NoError(t, r.Persister().CreateClient(s.t2, &client.Client{ID: clientID})) - request := fosite.NewRequest() - request.Client = &fosite.DefaultClient{ID: "client-id"} - request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} + request := fosite.NewRequest() + request.Client = &fosite.DefaultClient{ID: clientID} + request.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} - signature := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, request)) + // Create token T1 + signatureT1 := uuid.Must(uuid.NewV4()).String() + accessSignatureT1 := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t1, accessSignatureT1, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signatureT1, accessSignatureT1, request)) - actual := persistencesql.OAuth2RequestSQL{Table: "refresh"} + // Create token T2 + signatureT2 := uuid.Must(uuid.NewV4()).String() + accessSignatureT2 := uuid.Must(uuid.NewV4()).String() + require.ErrorIs(t, r.Persister().RotateRefreshToken(s.t2, request.ID, signatureT2), fosite.ErrNotFound, "Rotation fails as token is non-existent.") + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t2, accessSignatureT2, request)) + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t2, signatureT2, accessSignatureT2, request)) - store, ok := r.Persister().(*persistencesql.Persister) - if !ok { - t.Fatal("type assertion failed") - } + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + assert.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", x.SignatureHash(accessSignatureT2)).First(&accessT2)) + require.Equal(t, true, accessT2.Active) - require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t2, request.ID, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, true, actual.Active) - require.NoError(t, store.RevokeRefreshTokenMaybeGracePeriod(s.t1, request.ID, signature)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, signature)) - require.Equal(t, false, actual.Active) + accessT1 := persistencesql.OAuth2RequestSQL{Table: "access"} + assert.NoError(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignatureT1)).First(&accessT1)) + require.Equal(t, true, accessT2.Active) + + // Rotate token T1 + require.NoError(t, r.Persister().RotateRefreshToken(s.t1, request.ID, signatureT1)) + { + refreshT1 := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t1).Where("signature = ?", signatureT1).First(&refreshT1)) + require.Equal(t, false, refreshT1.Active) + + accessT1 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignatureT1)).First(&accessT1), sql.ErrNoRows) + + refreshT2 := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", signatureT2).First(&refreshT2)) + require.Equal(t, true, refreshT2.Active) + + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", x.SignatureHash(accessSignatureT2)).First(&accessT2)) + require.Equal(t, true, accessT2.Active) + } + + require.NoError(t, r.Persister().RotateRefreshToken(s.t2, request.ID, signatureT2)) + { + refreshT2 := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t2).Where("signature = ?", signatureT2).First(&refreshT2)) + require.Equal(t, false, refreshT2.Active) + + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t2).Where("signature = ?", x.SignatureHash(accessSignatureT2)).First(&accessT2), sql.ErrNoRows) + require.Equal(t, false, accessT2.Active) + } + }) + + t.Run("without access signature", func(t *testing.T) { + clientID := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateClient(s.t1, &client.Client{ID: clientID})) + + request1 := fosite.NewRequest() + request1.Client = &fosite.DefaultClient{ID: clientID} + request1.Session = &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "sub"}} + + signature := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateRefreshTokenSession(s.t1, signature, "", request1)) + + accessSignature1 := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t1, accessSignature1, request1)) + + accessSignature2 := uuid.Must(uuid.NewV4()).String() + require.NoError(t, r.Persister().CreateAccessTokenSession(s.t1, accessSignature2, request1)) + + require.NoError(t, r.Persister().RotateRefreshToken(s.t1, request1.ID, signature)) + { + accessT1 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignature1)).First(&accessT1), sql.ErrNoRows) + + refresh := persistencesql.OAuth2RequestSQL{Table: "refresh"} + require.NoError(t, r.Persister().Connection(s.t1).Where("signature = ?", signature).First(&refresh)) + require.Equal(t, false, refresh.Active) + + accessT2 := persistencesql.OAuth2RequestSQL{Table: "access"} + require.ErrorIs(t, r.Persister().Connection(s.t1).Where("signature = ?", x.SignatureHash(accessSignature2)).First(&accessT2), sql.ErrNoRows) + } + }) }) } } diff --git a/persistence/sql/persister_nonce_test.go b/persistence/sql/persister_nonce_test.go index 1de7eda543a..933af0a9a7a 100644 --- a/persistence/sql/persister_nonce_test.go +++ b/persistence/sql/persister_nonce_test.go @@ -8,18 +8,19 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/fosite" - "github.com/ory/hydra/v2/internal" "github.com/ory/x/contextx" "github.com/ory/x/randx" ) func TestPersister_Nonce(t *testing.T) { ctx := context.Background() - p := internal.NewMockedRegistry(t, new(contextx.Default)).Persister() + p := testhelpers.NewMockedRegistry(t, new(contextx.Default)).Persister() accessToken := randx.MustString(100, randx.AlphaNum) anotherToken := randx.MustString(100, randx.AlphaNum) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index b67b6ae17ed..58eb107306b 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -14,23 +14,22 @@ import ( "strings" "time" - "github.com/ory/hydra/v2/x" - - "github.com/ory/x/sqlxx" - - "go.opentelemetry.io/otel/trace" - "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/ory/fosite" "github.com/ory/fosite/storage" "github.com/ory/hydra/v2/oauth2" + "github.com/ory/hydra/v2/x" "github.com/ory/hydra/v2/x/events" + "github.com/ory/x/dbal" "github.com/ory/x/errorsx" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" "github.com/ory/x/stringsx" ) @@ -60,7 +59,8 @@ type ( } OAuth2RefreshTable struct { OAuth2RequestSQL - FirstUsedAt sql.NullTime `db:"first_used_at"` + FirstUsedAt sql.NullTime `db:"first_used_at"` + AccessTokenSignature sql.NullString `db:"access_token_signature"` } ) @@ -445,41 +445,61 @@ func toEventOptions(requester fosite.Requester) []trace.EventOption { } } -func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { +func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, accessTokenSignature string, requester fosite.Requester) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRefreshTokenSession") defer otelx.End(span, &err) events.Trace(ctx, events.RefreshTokenIssued, toEventOptions(requester)...) - return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) + + req, err := p.sqlSchemaFromRequest(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) + if err != nil { + return err + } + + var sig sql.NullString + if len(accessTokenSignature) > 0 { + sig = sql.NullString{ + Valid: true, + String: x.SignatureHash(accessTokenSignature), + } + } + + if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, &OAuth2RefreshTable{ + OAuth2RequestSQL: *req, + AccessTokenSignature: sig, + })); errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } else if err != nil { + return err + } + + return nil } func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - r := OAuth2RefreshTable{OAuth2RequestSQL: OAuth2RequestSQL{Table: sqlTableRefresh}} - err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) - if errors.Is(err, sql.ErrNoRows) { + var row OAuth2RefreshTable + if err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&row); errors.Is(err, sql.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrNotFound) } else if err != nil { return nil, sqlcon.HandleError(err) } - fositeRequest, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, err - } - - if r.Active { - return fositeRequest, nil + gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx) + if row.Active { + // Token is active + return row.toRequest(ctx, session, p) + } else if gracePeriod > 0 && + row.FirstUsedAt.Valid && + row.FirstUsedAt.Time.Add(gracePeriod).After(time.Now()) { + // We return the request as is, which indicates that the token is active (because we are in the grace period still). + return row.toRequest(ctx, session, p) } - if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid { - if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) { - return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) - } - - r.Active = true // We set active to true because we are in the grace period. - return r.toRequest(ctx, session, p) // And re-generate the request + fositeRequest, err := row.toRequest(ctx, session, p) + if err != nil { + return nil, err } return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) @@ -533,23 +553,7 @@ func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature stri func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshToken") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) -} - -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), - id, - p.NetworkID(ctx), - ). - Exec(), - ) + return p.deleteSessionByRequestID(ctx, id, sqlTableRefresh) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { @@ -612,3 +616,123 @@ func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (er p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), ) } + +func handleRetryError(err error) error { + if err == nil { + return nil + } + + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err +} + +// strictRefreshRotation implements the strict refresh token rotation strategy. In strict rotation, we disable all +// refresh and access tokens associated with a request ID and subsequently create the only valid, new token pair. +func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.strictRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("network_id", p.NetworkID(ctx).String()))) + defer otelx.End(span, &err) + + c := p.Connection(ctx) + + // In strict rotation we only have one token chain for every request. Therefore, we remove all + // access tokens associated with the request ID. + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + + // The same applies to refresh tokens in strict mode. We disable all old refresh tokens when rotating. + count, err := c.RawQuery( + "UPDATE hydra_oauth2_refresh SET active=false WHERE request_id=? AND nid = ? AND active", + requestID, + p.NetworkID(ctx), + ).ExecWithCount() + if err != nil { + return sqlcon.HandleError(err) + } else if count == 0 { + return errorsx.WithStack(fosite.ErrNotFound) + } + + return nil +} + +func (p *Persister) gracefulRefreshRotation(ctx context.Context, requestID string, refreshSignature string, period time.Duration) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("network_id", p.NetworkID(ctx).String()))) + defer otelx.End(span, &err) + + c := p.Connection(ctx) + now := time.Now().UTC().Round(time.Millisecond) + + var accessTokenSignature sql.NullString + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var tokenToRevoke OAuth2RefreshTable + if err := c. + Select("access_token_signature"). + // Filtering by "active" status would break graceful token rotation. We know and trust (with tests) + // that Fosite is dealing with the refresh token reuse detection business logic without + // relying on the active filter her. + Where("signature=? AND nid = ?", refreshSignature, p.NetworkID(ctx)). + First(&tokenToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + if count, err := c.RawQuery( + // Signature is the primary key so no limit needed. We only update first_used_at if it is not set yet (otherwise + // we would "refresh" the grace period again and again, and the refresh token would never "expire"). + `UPDATE hydra_oauth2_refresh SET active=false, first_used_at = COALESCE(first_used_at, ?) WHERE signature=? AND nid = ?`, + now, refreshSignature, p.NetworkID(ctx), + ).ExecWithCount(); err != nil { + return sqlcon.HandleError(err) + } else if count == 0 { + return errorsx.WithStack(fosite.ErrNotFound) + } + + accessTokenSignature = tokenToRevoke.AccessTokenSignature + } else { + var tokenToRevoke OAuth2RefreshTable + if err := c.RawQuery( + // Same query like in the MySQL case, but we can return the access token signature directly. + `UPDATE hydra_oauth2_refresh SET active=false, first_used_at = COALESCE(first_used_at, ?) WHERE signature=? AND nid = ? RETURNING access_token_signature`, + now, refreshSignature, p.NetworkID(ctx), + ).First(&tokenToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + accessTokenSignature = tokenToRevoke.AccessTokenSignature + } + + if !accessTokenSignature.Valid { + // If the access token is not found, we fall back to deleting all access tokens associated with the request ID. + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + return nil + } + + // We have the signature and we will only remove that specific access token as part of the rotation. + return p.deleteSessionBySignature(ctx, accessTokenSignature.String, sqlTableAccess) +} + +func (p *Persister) RotateRefreshToken(ctx context.Context, requestID string, refreshTokenSignature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken") + defer otelx.End(span, &err) + + // If we end up here, we have a valid refresh token and can proceed with the rotation. + gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx) + if gracePeriod > 0 { + return handleRetryError(p.gracefulRefreshRotation(ctx, requestID, refreshTokenSignature, gracePeriod)) + } + + return handleRetryError(p.strictRefreshRotation(ctx, requestID)) +} diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index a4818a3e69d..b4c88ef01c3 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -28,7 +28,6 @@ import ( "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/internal" ) func init() { @@ -120,11 +119,11 @@ func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registr func TestManagersNextGen(t *testing.T) { regs := map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), } if !testing.Short() { - regs["postgres"], regs["mysql"], regs["cockroach"], _ = internal.ConnectDatabases(t, true, &contextx.Default{}) + regs["postgres"], regs["mysql"], regs["cockroach"], _ = testhelpers.ConnectDatabases(t, true, &contextx.Default{}) } ctx := context.Background() @@ -153,16 +152,16 @@ func TestManagersNextGen(t *testing.T) { func TestManagers(t *testing.T) { ctx := context.TODO() t1registries := map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), true, &contextx.Default{}), } t2registries := map[string]driver.Registry{ - "memory": internal.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), false, &contextx.Default{}), + "memory": testhelpers.NewRegistrySQLFromURL(t, dbal.NewSQLiteTestDatabase(t), false, &contextx.Default{}), } if !testing.Short() { - t2registries["postgres"], t2registries["mysql"], t2registries["cockroach"], _ = internal.ConnectDatabases(t, false, &contextx.Default{}) - t1registries["postgres"], t1registries["mysql"], t1registries["cockroach"], _ = internal.ConnectDatabases(t, true, &contextx.Default{}) + t2registries["postgres"], t2registries["mysql"], t2registries["cockroach"], _ = testhelpers.ConnectDatabases(t, false, &contextx.Default{}) + t1registries["postgres"], t1registries["mysql"], t1registries["cockroach"], _ = testhelpers.ConnectDatabases(t, true, &contextx.Default{}) } network1NID, _ := uuid.NewV4() diff --git a/spec/config.json b/spec/config.json index 72f81534c66..effd1cc866d 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1071,9 +1071,9 @@ "refresh_token": { "type": "object", "properties": { - "grace_period": { + "rotation_grace_period": { "title": "Refresh Token Rotation Grace Period", - "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is 5 minutes.", "default": "0s", "allOf": [ { diff --git a/x/oauth2cors/cors_test.go b/x/oauth2cors/cors_test.go index d450fe308ab..dee215eae77 100644 --- a/x/oauth2cors/cors_test.go +++ b/x/oauth2cors/cors_test.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/hydra/v2/driver" "github.com/ory/x/contextx" @@ -24,13 +26,12 @@ import ( "github.com/ory/fosite" "github.com/ory/hydra/v2/client" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" ) func TestOAuth2AwareCORSMiddleware(t *testing.T) { ctx := context.Background() - r := internal.NewRegistryMemory(t, internal.NewConfigurationWithDefaults(), &contextx.Default{}) + r := testhelpers.NewRegistryMemory(t, testhelpers.NewConfigurationWithDefaults(), &contextx.Default{}) token, signature, _ := r.OAuth2HMACStrategy().GenerateAccessToken(ctx, nil) for k, tc := range []struct { @@ -275,7 +276,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - r.WithConfig(internal.NewConfigurationWithDefaults()) + r.WithConfig(testhelpers.NewConfigurationWithDefaults()) if tc.prep != nil { tc.prep(t, r) diff --git a/x/tls_termination_test.go b/x/tls_termination_test.go index bdb5581ce91..0c7be56f549 100644 --- a/x/tls_termination_test.go +++ b/x/tls_termination_test.go @@ -10,10 +10,11 @@ import ( "net/url" "testing" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/stretchr/testify/assert" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" ) @@ -27,8 +28,8 @@ func noopHandler(w http.ResponseWriter, r *http.Request) { } func TestDoesRequestSatisfyTermination(t *testing.T) { - c := internal.NewConfigurationWithDefaultsAndHTTPS() - r := internal.NewRegistryMemory(t, c, &contextx.Default{}) + c := testhelpers.NewConfigurationWithDefaultsAndHTTPS() + r := testhelpers.NewRegistryMemory(t, c, &contextx.Default{}) t.Run("case=tls-termination-disabled", func(t *testing.T) { c.MustSet(context.Background(), config.KeyTLSAllowTerminationFrom, "") @@ -178,7 +179,7 @@ func TestDoesRequestSatisfyTermination(t *testing.T) { // test: in case http is forced request should be accepted t.Run("case=forced-http", func(t *testing.T) { - c := internal.NewConfigurationWithDefaults() + c := testhelpers.NewConfigurationWithDefaults() res := httptest.NewRecorder() RejectInsecureRequests(r, c.TLS(context.Background(), config.PublicInterface))(res, &http.Request{Header: http.Header{}, URL: new(url.URL)}, noopHandler) assert.EqualValues(t, http.StatusNoContent, res.Code)