From 7fc95dedd972f285adbd7167cfaeb61350756b12 Mon Sep 17 00:00:00 2001 From: Dave New Date: Mon, 13 Nov 2023 12:38:16 +0200 Subject: [PATCH] feat: refresh tokens, refresh grant & rotation (#1294) --- cmd/program/commands.go | 4 - go.mod | 1 + go.sum | 2 + migrations/columns.sql | 2 +- migrations/migrations.go | 5 +- runtime/apis/authapi/revoke_endpoint.go | 16 + runtime/apis/authapi/token_endpoint.go | 64 +++- runtime/apis/authapi/token_endpoint_test.go | 295 ++++++++---------- runtime/oauth/authentication.go | 71 ----- runtime/oauth/id_token.go | 71 ++++- ...uthentication_test.go => id_token_test.go} | 0 runtime/oauth/refresh_token.go | 170 ++++++++++ runtime/oauth/refresh_token_test.go | 137 ++++++++ runtime/runtime.go | 3 + runtime/runtime_audit_test.go | 55 +--- runtime/runtime_events_test.go | 17 +- testing/util.go | 47 +++ 17 files changed, 668 insertions(+), 292 deletions(-) create mode 100644 runtime/apis/authapi/revoke_endpoint.go delete mode 100644 runtime/oauth/authentication.go rename runtime/oauth/{authentication_test.go => id_token_test.go} (100%) create mode 100644 runtime/oauth/refresh_token.go create mode 100644 runtime/oauth/refresh_token_test.go create mode 100644 testing/util.go diff --git a/cmd/program/commands.go b/cmd/program/commands.go index d667cfded..99533ba12 100644 --- a/cmd/program/commands.go +++ b/cmd/program/commands.go @@ -321,10 +321,6 @@ func RunMigrations(schema *proto.Schema, database db.Database) tea.Cmd { Changes: m.Changes, } - if !m.HasModelFieldChanges() { - return msg - } - err = m.Apply(context.Background()) if err != nil { msg.Err = &ApplyMigrationsError{ diff --git a/go.mod b/go.mod index e6eea0f1d..8ad4fbb47 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect + github.com/dchest/uniuri v1.2.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/fatih/color v1.13.0 // indirect github.com/felixge/httpsnoop v1.0.3 // indirect diff --git a/go.sum b/go.sum index d8f73299a..a916ea412 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g= +github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= diff --git a/migrations/columns.sql b/migrations/columns.sql index b80fe84e3..f8a51bb5c 100644 --- a/migrations/columns.sql +++ b/migrations/columns.sql @@ -16,7 +16,7 @@ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_catalog.pg_index i on i.indexrelid = a.attrelid WHERE n.nspname = 'public' - AND c.relname not in ('keel_schema', 'pg_stat_statements_info', 'pg_stat_statements') + AND c.relname not in ('keel_schema', 'keel_refresh_token', 'pg_stat_statements_info', 'pg_stat_statements') AND a.attnum > 0 AND NOT a.attisdropped AND i.indexrelid is null; -- no indexes \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 8c8f848cd..94c4ec3e4 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -101,7 +101,7 @@ func (m *Migrations) Apply(ctx context.Context) error { sql.WriteString(setUpdatedAt) sql.WriteString("\n") - sql.WriteString("CREATE TABLE IF NOT EXISTS keel_schema ( schema TEXT NOT NULL );\n") + sql.WriteString("CREATE TABLE IF NOT EXISTS keel_schema (schema TEXT NOT NULL);\n") sql.WriteString("DELETE FROM keel_schema;\n") b, err := protojson.Marshal(m.Schema) @@ -113,6 +113,9 @@ func (m *Migrations) Apply(ctx context.Context) error { sql.WriteString(fmt.Sprintf("INSERT INTO keel_schema (schema) VALUES (%s);", escapedJSON)) sql.WriteString("\n") + sql.WriteString("CREATE TABLE IF NOT EXISTS keel_refresh_token (token TEXT NOT NULL PRIMARY KEY, identity_id TEXT NOT NULL, created_at TIMESTAMP, expires_at TIMESTAMP);\n") + sql.WriteString("\n") + sql.WriteString(fmt.Sprintf("SELECT set_trace_id('%s');\n", span.SpanContext().TraceID().String())) sql.WriteString(m.SQL) diff --git a/runtime/apis/authapi/revoke_endpoint.go b/runtime/apis/authapi/revoke_endpoint.go new file mode 100644 index 000000000..d12a6d9a1 --- /dev/null +++ b/runtime/apis/authapi/revoke_endpoint.go @@ -0,0 +1,16 @@ +package authapi + +import ( + "net/http" + + "github.com/teamkeel/keel/proto" + "github.com/teamkeel/keel/runtime/common" +) + +func RevokeHandler(schema *proto.Schema) common.HandlerFunc { + return func(r *http.Request) common.Response { + return common.Response{ + Status: http.StatusNotImplemented, + } + } +} diff --git a/runtime/apis/authapi/token_endpoint.go b/runtime/apis/authapi/token_endpoint.go index 5ac4d93ba..e60ff8230 100644 --- a/runtime/apis/authapi/token_endpoint.go +++ b/runtime/apis/authapi/token_endpoint.go @@ -21,6 +21,11 @@ const ( ArgSubjectToken = "subject_token" ArgSubjectTokenType = "subject_token_type" ArgRequestedTokeType = "requested_token_type" + ArgRefreshToken = "refresh_token" +) + +const ( + TokenType = "bearer" ) // https://openid.net/specs/openid-connect-standard-1_0-21_orig.html#AccessTokenResponse @@ -90,7 +95,50 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { switch grantType { case GrantTypeRefreshToken: - return common.Response{Status: http.StatusNotImplemented} + if !r.Form.Has(ArgRefreshToken) { + return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ + Error: TokenEndpointInvalidRequest, + ErrorDescription: "the refresh token must be provided in the refresh_token field", + }, nil) + } + + refreshTokenRaw := r.Form.Get(ArgRefreshToken) + + if refreshTokenRaw == "" { + return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ + Error: TokenEndpointInvalidRequest, + ErrorDescription: "the refresh token in the refresh_token field cannot be an empty string", + }, nil) + } + + isValid, newRefreshToken, identityId, err := oauth.RotateRefreshToken(ctx, refreshTokenRaw) + if err != nil { + span.RecordError(err) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } + + if !isValid { + return common.NewJsonResponse(http.StatusUnauthorized, &TokenErrorResponse{ + Error: TokenEndpointInvalidClient, + ErrorDescription: "possible causes may be that the refresh token has been revoked or has expired", + }, nil) + } + + // Generate an access token for this identity. + accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identityId) + if err != nil { + span.RecordError(err) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } + + response := &TokenResponse{ + AccessToken: accessTokenRaw, + TokenType: TokenType, + ExpiresIn: int(expiresIn.Seconds()), + RefreshToken: newRefreshToken, + } + + return common.NewJsonResponse(http.StatusOK, response, nil) case GrantTypeTokenExchange: if !r.Form.Has(ArgSubjectToken) { return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{ @@ -176,10 +224,18 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc { return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) } + // Generate a refresh token. + refreshTokenRaw, err := oauth.NewRefreshToken(ctx, identity.Id) + if err != nil { + span.RecordError(err) + return common.NewJsonResponse(http.StatusInternalServerError, nil, nil) + } + response := &TokenResponse{ - AccessToken: accessTokenRaw, - TokenType: "bearer", - ExpiresIn: int(expiresIn.Seconds()), + AccessToken: accessTokenRaw, + TokenType: TokenType, + ExpiresIn: int(expiresIn.Seconds()), + RefreshToken: refreshTokenRaw, } return common.NewJsonResponse(http.StatusOK, response, nil) diff --git a/runtime/apis/authapi/token_endpoint_test.go b/runtime/apis/authapi/token_endpoint_test.go index 490988b13..8449e8c5f 100644 --- a/runtime/apis/authapi/token_endpoint_test.go +++ b/runtime/apis/authapi/token_endpoint_test.go @@ -9,23 +9,23 @@ import ( "net/http/httptest" "net/url" "testing" + "time" "github.com/stretchr/testify/require" - "github.com/teamkeel/keel/db" "github.com/teamkeel/keel/proto" "github.com/teamkeel/keel/runtime" "github.com/teamkeel/keel/runtime/apis/authapi" + "github.com/teamkeel/keel/runtime/auth" "github.com/teamkeel/keel/runtime/oauth" "github.com/teamkeel/keel/runtime/oauth/oauthtest" "github.com/teamkeel/keel/runtime/runtimectx" - "github.com/teamkeel/keel/schema" - "github.com/teamkeel/keel/testhelpers" + keeltesting "github.com/teamkeel/keel/testing" ) var authTestSchema = `model Post{}` func TestTokenExchange_ValidNewIdentity(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() // Set up auth config @@ -56,6 +56,7 @@ func TestTokenExchange_ValidNewIdentity(t *testing.T) { require.NotEmpty(t, validResponse.AccessToken) require.Equal(t, "bearer", validResponse.TokenType) require.NotEmpty(t, validResponse.ExpiresIn) + require.NotEmpty(t, validResponse.RefreshToken) require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) sub, iss, err := oauth.ValidateAccessToken(ctx, validResponse.AccessToken, "") @@ -84,7 +85,7 @@ func TestTokenExchange_ValidNewIdentity(t *testing.T) { } func TestTokenExchange_ValidNewIdentityAllUserInfo(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() // Set up auth config @@ -159,7 +160,7 @@ func TestTokenExchange_ValidNewIdentityAllUserInfo(t *testing.T) { } func TestTokenExchange_ValidUpdatedIdentity(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() // Set up auth config @@ -222,24 +223,11 @@ func TestTokenExchange_ValidUpdatedIdentity(t *testing.T) { } func TestTokenEndpoint_HttpGet(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) + request := makeTokenExchangeRequest(ctx, "mock_token") request.Method = http.MethodGet @@ -254,25 +242,11 @@ func TestTokenEndpoint_HttpGet(t *testing.T) { } func TestTokenEndpoint_ApplicationJsonRequest(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") request.Header = http.Header{} request.Header.Add("Content-Type", "application/json") @@ -287,27 +261,13 @@ func TestTokenEndpoint_ApplicationJsonRequest(t *testing.T) { } func TestTokenEndpoint_MissingGrantType(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") form := url.Values{} - form.Add("subject_token", idToken) + form.Add("subject_token", "mock_token") request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse @@ -321,28 +281,14 @@ func TestTokenEndpoint_MissingGrantType(t *testing.T) { } func TestTokenEndpoint_WrongGrantType(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") form := url.Values{} form.Add("grant_type", "password") - form.Add("subject_token", idToken) + form.Add("subject_token", "mock_token") request.URL.RawQuery = form.Encode() // Handle runtime request, expecting TokenErrorResponse @@ -356,25 +302,11 @@ func TestTokenEndpoint_WrongGrantType(t *testing.T) { } func TestTokenExchange_NoSubjectToken(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") form := url.Values{} form.Add("grant_type", "token_exchange") request.URL.RawQuery = form.Encode() @@ -390,25 +322,11 @@ func TestTokenExchange_NoSubjectToken(t *testing.T) { } func TestTokenExchange_EmptySubjectToken(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") form := url.Values{} form.Add("grant_type", "token_exchange") form.Add("subject_token", "") @@ -425,28 +343,14 @@ func TestTokenExchange_EmptySubjectToken(t *testing.T) { } func TestTokenExchange_WrongSubjectTokenType(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") form := url.Values{} form.Add("grant_type", "token_exchange") - form.Add("subject_token", idToken) + form.Add("subject_token", "mock_token") form.Add("subject_token_type", "access_token") form.Add("requested_token_type", "access_token") request.URL.RawQuery = form.Encode() @@ -462,28 +366,14 @@ func TestTokenExchange_WrongSubjectTokenType(t *testing.T) { } func TestTokenExchange_WrongRequestedTokenType(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() - // OIDC test server - server, err := oauthtest.NewOIDCServer() - require.NoError(t, err) - - server.SetUser("id|285620", &oauth.UserClaims{ - Email: "keelson@keel.so", - Name: "Keelson", - }) - - // Get ID token from server - idToken, err := server.FetchIdToken("id|285620", []string{}) - require.NoError(t, err) - // Make a token exchange grant request - request := makeTokenExchangeRequest(ctx, idToken) - + request := makeTokenExchangeRequest(ctx, "mock_token") form := url.Values{} form.Add("grant_type", "token_exchange") - form.Add("subject_token", idToken) + form.Add("subject_token", "mock_token") form.Add("subject_token_type", "id_token") form.Add("requested_token_type", "id_token") request.URL.RawQuery = form.Encode() @@ -499,7 +389,7 @@ func TestTokenExchange_WrongRequestedTokenType(t *testing.T) { } func TestTokenExchange_BadIdToken(t *testing.T) { - ctx, database, schema := newContext(t, authTestSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) defer database.Close() // OIDC test server @@ -517,7 +407,6 @@ func TestTokenExchange_BadIdToken(t *testing.T) { // Make a token exchange grant request request := makeTokenExchangeRequest(ctx, idToken) - form := url.Values{} form.Add("grant_type", "token_exchange") form.Add("subject_token", "this is not a jwt token") @@ -533,35 +422,113 @@ func TestTokenExchange_BadIdToken(t *testing.T) { require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } -func newContext(t *testing.T, keelSchema string, resetDatabase bool) (context.Context, db.Database, *proto.Schema) { - dbConnInfo := &db.ConnectionInfo{ - Host: "localhost", - Port: "8001", - Username: "postgres", - Password: "postgres", - Database: "keel", - } +func TestRefreshToken_Valid(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) - builder := &schema.Builder{} - schema, err := builder.MakeFromString(keelSchema) + // OIDC test server + server, err := oauthtest.NewOIDCServer() require.NoError(t, err) - ctx := context.Background() + server.SetUser("id|285620", &oauth.UserClaims{ + Email: "keelson@keel.so", + }) - // Add private key to context - pk, err := testhelpers.GetEmbeddedPrivateKey() + // Get ID token from server + idToken, err := server.FetchIdToken("id|285620", []string{}) require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, pk) - ctx, err = testhelpers.WithTracing(ctx) + // Make a token exchange grant request + request := makeTokenExchangeRequest(ctx, idToken) + + // Handle runtime request, expecting TokenErrorResponse + tokenExchangeResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) require.NoError(t, err) + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + + // We need 1 second to pass in order to get a different access token + time.Sleep(1000 * time.Millisecond) - // Add database to context - database, err := testhelpers.SetupDatabaseForTestCase(ctx, dbConnInfo, schema, "runtime_test", resetDatabase) + // Make a refresh token grant request + request = makeRefreshTokenRequest(ctx, tokenExchangeResponse.RefreshToken) + + // Handle runtime request, expecting TokenErrorResponse + refreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenResponse](schema, request) require.NoError(t, err) - ctx = db.WithDatabase(ctx, database) - return ctx, database, schema + require.Equal(t, http.StatusOK, httpResponse.StatusCode) + require.NotEmpty(t, refreshGrantResponse.AccessToken) + require.Equal(t, "bearer", refreshGrantResponse.TokenType) + require.NotEmpty(t, refreshGrantResponse.ExpiresIn) + require.NotEmpty(t, refreshGrantResponse.RefreshToken) + require.NotEqual(t, refreshGrantResponse.RefreshToken, tokenExchangeResponse.RefreshToken) + require.NotEqual(t, refreshGrantResponse.AccessToken, tokenExchangeResponse.AccessToken) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) + + accessToken1Issuer, err := auth.ExtractClaimFromToken(tokenExchangeResponse.AccessToken, "iss") + require.NoError(t, err) + accessToken2Issuer, err := auth.ExtractClaimFromToken(refreshGrantResponse.AccessToken, "iss") + require.NoError(t, err) + require.NotEmpty(t, accessToken1Issuer) + require.Equal(t, accessToken1Issuer, accessToken2Issuer) + + accessToken1Sub, err := auth.ExtractClaimFromToken(tokenExchangeResponse.AccessToken, "sub") + require.NoError(t, err) + accessToken2Sub, err := auth.ExtractClaimFromToken(refreshGrantResponse.AccessToken, "sub") + require.NoError(t, err) + require.NotEmpty(t, accessToken1Sub) + require.Equal(t, accessToken1Sub, accessToken2Sub) + + // Make a refresh token grant request using the original refresh token + request = makeRefreshTokenRequest(ctx, tokenExchangeResponse.RefreshToken) + + // Handle runtime request, expecting TokenErrorResponse + secondRefreshGrantResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, httpResponse.StatusCode) + require.Equal(t, "possible causes may be that the refresh token has been revoked or has expired", secondRefreshGrantResponse.ErrorDescription) +} + +func TestRefreshToken_NoRefreshToken(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Make a refresh token grant request + request := makeRefreshTokenRequest(ctx, "") + form := url.Values{} + form.Add("grant_type", "refresh_token") + request.URL.RawQuery = form.Encode() + + // Handle runtime request, expecting TokenErrorResponse + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) + require.Equal(t, "invalid_request", errorResponse.Error) + require.Equal(t, "the refresh token must be provided in the refresh_token field", errorResponse.ErrorDescription) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) +} + +func TestRefreshToken_EmptyRefreshToken(t *testing.T) { + ctx, database, schema := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Make a refresh token grant request + request := makeRefreshTokenRequest(ctx, "") + + // Handle runtime request, expecting TokenErrorResponse + errorResponse, httpResponse, err := handleRuntimeRequest[authapi.TokenErrorResponse](schema, request) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, httpResponse.StatusCode) + require.Equal(t, "invalid_request", errorResponse.Error) + require.Equal(t, "the refresh token in the refresh_token field cannot be an empty string", errorResponse.ErrorDescription) + require.True(t, authapi.HasContentType(httpResponse.Header, "application/json")) } func handleRuntimeRequest[T any](schema *proto.Schema, req *http.Request) (T, *http.Response, error) { @@ -595,7 +562,19 @@ func makeTokenExchangeRequest(ctx context.Context, token string) *http.Request { form.Add("subject_token_type", "id_token") form.Add("requested_token_type", "access_token") request.URL.RawQuery = form.Encode() + request = request.WithContext(ctx) + + return request +} +func makeRefreshTokenRequest(ctx context.Context, token string) *http.Request { + request := httptest.NewRequest(http.MethodPost, "http://mykeelapp.keel.so/auth/token", nil) + request.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + form := url.Values{} + form.Add("grant_type", "refresh_token") + form.Add("refresh_token", token) + request.URL.RawQuery = form.Encode() request = request.WithContext(ctx) return request diff --git a/runtime/oauth/authentication.go b/runtime/oauth/authentication.go deleted file mode 100644 index 9b16f1ac5..000000000 --- a/runtime/oauth/authentication.go +++ /dev/null @@ -1,71 +0,0 @@ -package oauth - -import ( - "context" - "errors" - "fmt" - - "github.com/coreos/go-oidc" - "github.com/teamkeel/keel/runtime/auth" - "github.com/teamkeel/keel/runtime/runtimectx" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" -) - -var tracer = otel.Tracer("github.com/teamkeel/keel/runtime") - -// VerifyIdToken will verify the ID token from an OpenID Connect provider. -func VerifyIdToken(ctx context.Context, idTokenRaw string) (*oidc.IDToken, error) { - ctx, span := tracer.Start(ctx, "Verify ID Token") - defer span.End() - - issuer, err := auth.ExtractClaimFromToken(idTokenRaw, "iss") - if err != nil { - return nil, err - } - if issuer == "" { - return nil, errors.New("iss claim cannot be an empty string") - } - span.AddEvent("Issuer extracted from ID Token") - - span.SetAttributes(attribute.String("issuer", issuer)) - - authConfig, err := runtimectx.GetAuthConfig(ctx) - if err != nil { - return nil, err - } - - issuerPermitted := authConfig.AllowAnyIssuers - if !issuerPermitted { - for _, e := range authConfig.Issuers { - if e.Iss == issuer { - issuerPermitted = true - } - } - } - - if !issuerPermitted { - return nil, fmt.Errorf("issuer %s not registered to authenticate on this server", issuer) - } - - // Establishes new OIDC provider. This will call the providers discovery endpoint - provider, err := oidc.NewProvider(ctx, issuer) - if err != nil { - return nil, err - } - span.AddEvent("Provider's ODIC config fetched") - - // TODO: Enable this check once we have the client ID as configurable - verifier := provider.Verifier(&oidc.Config{ - SkipClientIDCheck: true, - }) - - // Verify that the ID token legitimately was signed by the provider and that it has not expired - idToken, err := verifier.Verify(ctx, idTokenRaw) - if err != nil { - return nil, err - } - span.AddEvent("ID Token verified") - - return idToken, nil -} diff --git a/runtime/oauth/id_token.go b/runtime/oauth/id_token.go index 1330eefd0..ff441abc1 100644 --- a/runtime/oauth/id_token.go +++ b/runtime/oauth/id_token.go @@ -1,6 +1,17 @@ package oauth -import "github.com/golang-jwt/jwt/v4" +import ( + "context" + "errors" + "fmt" + + "github.com/coreos/go-oidc" + "github.com/golang-jwt/jwt/v4" + "github.com/teamkeel/keel/runtime/auth" + "github.com/teamkeel/keel/runtime/runtimectx" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" +) type IdTokenClaims struct { jwt.RegisteredClaims @@ -31,3 +42,61 @@ type UserClaims struct { PhoneNumber string `json:"phone_number,omitempty"` PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` } + +var tracer = otel.Tracer("github.com/teamkeel/keel/runtime") + +// VerifyIdToken will verify the ID token from an OpenID Connect provider. +func VerifyIdToken(ctx context.Context, idTokenRaw string) (*oidc.IDToken, error) { + ctx, span := tracer.Start(ctx, "Verify ID Token") + defer span.End() + + issuer, err := auth.ExtractClaimFromToken(idTokenRaw, "iss") + if err != nil { + return nil, err + } + if issuer == "" { + return nil, errors.New("iss claim cannot be an empty string") + } + span.AddEvent("Issuer extracted from ID Token") + + span.SetAttributes(attribute.String("issuer", issuer)) + + authConfig, err := runtimectx.GetAuthConfig(ctx) + if err != nil { + return nil, err + } + + issuerPermitted := authConfig.AllowAnyIssuers + if !issuerPermitted { + for _, e := range authConfig.Issuers { + if e.Iss == issuer { + issuerPermitted = true + } + } + } + + if !issuerPermitted { + return nil, fmt.Errorf("issuer %s not registered to authenticate on this server", issuer) + } + + // Establishes new OIDC provider. This will call the providers discovery endpoint + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, err + } + span.AddEvent("Provider's ODIC config fetched") + + // TODO: Enable this check once we have the client ID as configurable + verifier := provider.Verifier(&oidc.Config{ + SkipClientIDCheck: true, + }) + + // Verify that the ID token legitimately was signed by the provider and that it has not expired + idToken, err := verifier.Verify(ctx, idTokenRaw) + if err != nil { + return nil, err + } + span.AddEvent("ID Token verified") + + return idToken, nil +} diff --git a/runtime/oauth/authentication_test.go b/runtime/oauth/id_token_test.go similarity index 100% rename from runtime/oauth/authentication_test.go rename to runtime/oauth/id_token_test.go diff --git a/runtime/oauth/refresh_token.go b/runtime/oauth/refresh_token.go new file mode 100644 index 000000000..eaee505b1 --- /dev/null +++ b/runtime/oauth/refresh_token.go @@ -0,0 +1,170 @@ +package oauth + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/dchest/uniuri" + "github.com/teamkeel/keel/db" + "golang.org/x/crypto/sha3" +) + +const ( + refreshTokenLength = 64 + defaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90 // 3 months is the default +) + +// NewRefreshToken generates a new refresh token for the identity using the +// configured or default expiry time. +func NewRefreshToken(ctx context.Context, identityId string) (string, error) { + ctx, span := tracer.Start(ctx, "New Refresh Token") + defer span.End() + + if identityId == "" { + return "", errors.New("identity ID cannot be empty when generating new refresh token") + } + + token := uniuri.NewLen(refreshTokenLength) + hash, err := hashToken(token) + if err != nil { + return "", err + } + + database, err := db.GetDatabase(ctx) + if err != nil { + return "", err + } + + now := time.Now().UTC() + expiry := now.Add(defaultRefreshTokenExpiry) + + sql := ` + INSERT INTO + keel_refresh_token (token, identity_id, expires_at, created_at) + VALUES + (?, ?, ?, ?)` + + db := database.GetDB().Exec(sql, hash, identityId, expiry, now) + if db.Error != nil { + return "", db.Error + } + + if db.RowsAffected != 1 { + return "", errors.New("failed to insert refresh token into database") + } + + return token, nil +} + +// RotateRefreshToken validates that the provided refresh token has not expired, +// and then rotates it for a new refresh token with the exact same expiry time and +// identity. The original refresh token is then revoked from future use. +func RotateRefreshToken(ctx context.Context, refreshTokenRaw string) (isValid bool, refreshToken string, identityId string, err error) { + ctx, span := tracer.Start(ctx, "Rotate Refresh Token") + defer span.End() + + tokenHash, err := hashToken(refreshTokenRaw) + if err != nil { + return false, "", "", err + } + + newRefreshToken := uniuri.NewLen(refreshTokenLength) + newTokenHash, err := hashToken(newRefreshToken) + if err != nil { + return false, "", "", err + } + + database, err := db.GetDatabase(ctx) + if err != nil { + return false, "", "", err + } + + // This query has the following (important) characteristics: + // - find and delete the refresh token if it has not expired (the latter is for performance) + // - create a new refresh token with the identity_id and expire_at of the original token + // - only creates the new token if the original token had not expired + sql := ` + WITH revoked_token AS ( + DELETE FROM + keel_refresh_token + WHERE + token = ? + RETURNING *) + INSERT INTO + keel_refresh_token (token, identity_id, expires_at, created_at) + SELECT + ?, identity_id, expires_at, now() + FROM + revoked_token + RETURNING *;` + + rows := []map[string]any{} + err = database.GetDB().Raw(sql, tokenHash, newTokenHash).Scan(&rows).Error + if err != nil { + return false, "", "", err + } + + // There was no refresh token found, and thus nothing to rotate. + if len(rows) != 1 { + return false, "", "", nil + } + + identityId, ok := rows[0]["identity_id"].(string) + if !ok { + return false, "", "", errors.New("could not parse identity_id from database result") + } + + return true, newRefreshToken, identityId, nil +} + +// RevokeRefreshToken will delete (revoke) the provided refresh token, +// which will prevent it from being used again. +func RevokeRefreshToken(ctx context.Context, refreshTokenRaw string) error { + ctx, span := tracer.Start(ctx, "Revoke Refresh Token") + defer span.End() + + tokenHash, err := hashToken(refreshTokenRaw) + if err != nil { + return err + } + + database, err := db.GetDatabase(ctx) + if err != nil { + return err + } + + sql := ` + DELETE FROM + keel_refresh_token + WHERE + token = ? + RETURNING *` + + rows := []map[string]any{} + err = database.GetDB().Raw(sql, tokenHash).Scan(&rows).Error + if err != nil { + return err + } + + // There was no refresh token found, and thus none to revoke. + if len(rows) == 0 { + return nil + } + + return nil +} + +// hashToken will produce a 256-bit SHA3 hash without salt +func hashToken(input string) (string, error) { + hash := sha3.New256() + _, err := hash.Write([]byte(input)) + if err != nil { + return "", err + } + + sha3 := hash.Sum(nil) + + return fmt.Sprintf("%x", sha3), nil +} diff --git a/runtime/oauth/refresh_token_test.go b/runtime/oauth/refresh_token_test.go new file mode 100644 index 000000000..dbea44493 --- /dev/null +++ b/runtime/oauth/refresh_token_test.go @@ -0,0 +1,137 @@ +package oauth_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/runtime/oauth" + "github.com/teamkeel/keel/runtime/runtimectx" + keeltesting "github.com/teamkeel/keel/testing" +) + +var authTestSchema = `model Post{}` + +func TestNewRefreshToken_NotEmpty(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + require.NotEmpty(t, refreshToken) +} + +func TestNewRefreshToken_ErrorOnEmptyIdentityId(t *testing.T) { + ctx := context.Background() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) + + _, err := oauth.NewRefreshToken(ctx, "") + require.Error(t, err) +} + +func TestRotateRefreshToken_Valid(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + isValid1, newRefreshToken1, identityId1, err := oauth.RotateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.True(t, isValid1) + require.Equal(t, "identity_id", identityId1) + require.NotEmpty(t, newRefreshToken1) + + isValid2, newRefreshToken2, identityId2, err := oauth.RotateRefreshToken(ctx, newRefreshToken1) + require.NoError(t, err) + require.True(t, isValid2) + require.Equal(t, "identity_id", identityId2) + require.NotEmpty(t, newRefreshToken2) + require.NotEqual(t, newRefreshToken2, newRefreshToken1) +} + +func TestRotateRefreshToken_ReuseRefreshTokenNotValid(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + isValid, newRefreshToken, identityId, err := oauth.RotateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.True(t, isValid) + require.Equal(t, "identity_id", identityId) + require.NotEmpty(t, newRefreshToken) + + isValid2, newRefreshToken2, identityId2, err := oauth.RotateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.False(t, isValid2) + require.Empty(t, identityId2) + require.Empty(t, newRefreshToken2) +} + +func TestRevokeRefreshToken_Unauthorised(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) + + refreshToken, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + err = oauth.RevokeRefreshToken(ctx, refreshToken) + require.NoError(t, err) + + isValid, _, _, err := oauth.RotateRefreshToken(ctx, refreshToken) + require.NoError(t, err) + require.False(t, isValid) +} + +func TestRevokeRefreshToken_MultipleForIdentity(t *testing.T) { + ctx, database, _ := keeltesting.MakeContext(t, authTestSchema, true) + defer database.Close() + + // Set up auth config + ctx = runtimectx.WithAuthConfig(ctx, runtimectx.AuthConfig{ + AllowAnyIssuers: true, + }) + + refreshToken1, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + refreshToken2, err := oauth.NewRefreshToken(ctx, "identity_id") + require.NoError(t, err) + + err = oauth.RevokeRefreshToken(ctx, refreshToken1) + require.NoError(t, err) + + isValid1, _, _, err := oauth.RotateRefreshToken(ctx, refreshToken1) + require.NoError(t, err) + require.False(t, isValid1) + + isValid2, _, _, err := oauth.RotateRefreshToken(ctx, refreshToken2) + require.NoError(t, err) + require.True(t, isValid2) +} diff --git a/runtime/runtime.go b/runtime/runtime.go index 7a891827f..a7d8f1998 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -96,6 +96,7 @@ func NewHttpHandler(currSchema *proto.Schema) http.Handler { func NewAuthHandler(schema *proto.Schema) common.HandlerFunc { handleToken := authapi.TokenEndpointHandler(schema) handleOAuth := authapi.OAuthHandler(schema) + handleRevoke := authapi.RevokeHandler(schema) return func(r *http.Request) common.Response { switch { @@ -103,6 +104,8 @@ func NewAuthHandler(schema *proto.Schema) common.HandlerFunc { return handleToken(r) case strings.HasPrefix(r.URL.Path, "/auth/oauth"): return handleOAuth(r) + case r.URL.Path == "/auth/revoke": + return handleRevoke(r) default: return common.Response{ Status: http.StatusNotFound, diff --git a/runtime/runtime_audit_test.go b/runtime/runtime_audit_test.go index d25756215..c408612dc 100644 --- a/runtime/runtime_audit_test.go +++ b/runtime/runtime_audit_test.go @@ -11,14 +11,12 @@ import ( "github.com/nsf/jsondiff" "github.com/samber/lo" "github.com/stretchr/testify/require" - "github.com/teamkeel/keel/db" "github.com/teamkeel/keel/migrations" "github.com/teamkeel/keel/proto" "github.com/teamkeel/keel/runtime/actions" "github.com/teamkeel/keel/runtime/auth" - "github.com/teamkeel/keel/runtime/runtimectx" - "github.com/teamkeel/keel/schema" "github.com/teamkeel/keel/testhelpers" + keeltesting "github.com/teamkeel/keel/testing" "go.opentelemetry.io/otel/trace" ) @@ -44,37 +42,6 @@ model WeddingInvitee { } }` -func newContext(t *testing.T, keelSchema string, resetDatabase bool) (context.Context, db.Database, *proto.Schema) { - dbConnInfo := &db.ConnectionInfo{ - Host: "localhost", - Port: "8001", - Username: "postgres", - Password: "postgres", - Database: "keel", - } - - builder := &schema.Builder{} - schema, err := builder.MakeFromString(keelSchema) - require.NoError(t, err) - - ctx := context.Background() - - // Add private key to context - pk, err := testhelpers.GetEmbeddedPrivateKey() - require.NoError(t, err) - ctx = runtimectx.WithPrivateKey(ctx, pk) - - ctx, err = testhelpers.WithTracing(ctx) - require.NoError(t, err) - - // Add database to context - database, err := testhelpers.SetupDatabaseForTestCase(ctx, dbConnInfo, schema, "runtime_test", resetDatabase) - require.NoError(t, err) - ctx = db.WithDatabase(ctx, database) - - return ctx, database, schema -} - func withIdentity(t *testing.T, ctx context.Context, schema *proto.Schema) (context.Context, *auth.Identity) { identity, err := actions.CreateIdentity(ctx, schema, "dave.new@keel.xyz", "1234") require.NoError(t, err) @@ -82,7 +49,7 @@ func withIdentity(t *testing.T, ctx context.Context, schema *proto.Schema) (cont } func TestAuditCreateAction(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -126,7 +93,7 @@ func TestAuditCreateAction(t *testing.T) { } func TestAuditNestedCreateAction(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -233,7 +200,7 @@ func TestAuditNestedCreateAction(t *testing.T) { } func TestAuditUpdateAction(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -291,7 +258,7 @@ func TestAuditUpdateAction(t *testing.T) { } func TestAuditDeleteAction(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -347,7 +314,7 @@ func TestAuditDeleteAction(t *testing.T) { } func TestAuditTablesWithOnlyIdentity(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -378,7 +345,7 @@ func TestAuditTablesWithOnlyIdentity(t *testing.T) { } func TestAuditTablesWithOnlyTracing(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -404,7 +371,7 @@ func TestAuditTablesWithOnlyTracing(t *testing.T) { } func TestAuditOnStatementExecuteWithoutResult(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() db := database.GetDB() @@ -445,7 +412,7 @@ func TestAuditOnStatementExecuteWithoutResult(t *testing.T) { } func TestAuditFieldsAreDroppedOnCreate(t *testing.T) { - ctx, database, schema := newContext(t, auditSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, auditSchema, true) defer database.Close() ctx, _ = withIdentity(t, ctx, schema) @@ -474,7 +441,7 @@ func TestAuditDatabaseMigration(t *testing.T) { @permission(expression: true, actions: [create, update, delete]) }` - ctx, database, pSchema := newContext(t, keelSchema, true) + ctx, database, pSchema := keeltesting.MakeContext(t, keelSchema, true) create := proto.FindAction(pSchema, "createPerson") _, _, err := actions.Execute( @@ -496,7 +463,7 @@ func TestAuditDatabaseMigration(t *testing.T) { }` database.Close() - ctx, database, pSchema = newContext(t, updatedSchema, false) + ctx, database, pSchema = keeltesting.MakeContext(t, updatedSchema, false) db := database.GetDB() defer database.Close() diff --git a/runtime/runtime_events_test.go b/runtime/runtime_events_test.go index 71fb6cbc1..05a820cf8 100644 --- a/runtime/runtime_events_test.go +++ b/runtime/runtime_events_test.go @@ -11,6 +11,7 @@ import ( "github.com/teamkeel/keel/events" "github.com/teamkeel/keel/proto" "github.com/teamkeel/keel/runtime/actions" + keeltesting "github.com/teamkeel/keel/testing" ) var eventsSchema = ` @@ -80,7 +81,7 @@ func (handler *EventHandler) HandleEvent(ctx context.Context, subscriber string, } func TestCreateEvent(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() ctx, identity := withIdentity(t, ctx, schema) @@ -124,7 +125,7 @@ func TestCreateEvent(t *testing.T) { } func TestUpdateEvent(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() result, _, err := actions.Execute( @@ -182,7 +183,7 @@ func TestUpdateEvent(t *testing.T) { } func TestDeleteEvent(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() result, _, err := actions.Execute( @@ -231,7 +232,7 @@ func TestDeleteEvent(t *testing.T) { } func TestNoIdentityEvent(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() handler := NewEventHandler(t) @@ -252,7 +253,7 @@ func TestNoIdentityEvent(t *testing.T) { } func TestNestedCreateEvent(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() ctx, _ = withIdentity(t, ctx, schema) @@ -294,7 +295,7 @@ func TestNestedCreateEvent(t *testing.T) { } func TestMultipleEvents(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() ctx, _ = withIdentity(t, ctx, schema) @@ -342,7 +343,7 @@ func TestMultipleEvents(t *testing.T) { } func TestAuditTableEventCreatedAtUpdated(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() handler := NewEventHandler(t) @@ -381,7 +382,7 @@ func TestAuditTableEventCreatedAtUpdated(t *testing.T) { } func TestFailedEventHandling(t *testing.T) { - ctx, database, schema := newContext(t, eventsSchema, true) + ctx, database, schema := keeltesting.MakeContext(t, eventsSchema, true) defer database.Close() ctx, err := events.WithEventHandler(ctx, func(ctx context.Context, subscriber string, event *events.Event, traceparent string) error { diff --git a/testing/util.go b/testing/util.go new file mode 100644 index 000000000..5b6494949 --- /dev/null +++ b/testing/util.go @@ -0,0 +1,47 @@ +package testing + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/db" + "github.com/teamkeel/keel/proto" + "github.com/teamkeel/keel/runtime/runtimectx" + "github.com/teamkeel/keel/schema" + "github.com/teamkeel/keel/testhelpers" +) + +func MakeContext(t *testing.T, keelSchema string, resetDatabase bool) (context.Context, db.Database, *proto.Schema) { + dbConnInfo := &db.ConnectionInfo{ + Host: "localhost", + Port: "8001", + Username: "postgres", + Password: "postgres", + Database: "keel", + } + + builder := &schema.Builder{} + schema, err := builder.MakeFromString(keelSchema) + require.NoError(t, err) + + ctx := context.Background() + + // Add private key to context + pk, err := testhelpers.GetEmbeddedPrivateKey() + require.NoError(t, err) + ctx = runtimectx.WithPrivateKey(ctx, pk) + + ctx, err = testhelpers.WithTracing(ctx) + require.NoError(t, err) + + databaseName := strings.ToLower("keel_test_" + t.Name()) + + // Add database to context + database, err := testhelpers.SetupDatabaseForTestCase(ctx, dbConnInfo, schema, databaseName, resetDatabase) + require.NoError(t, err) + ctx = db.WithDatabase(ctx, database) + + return ctx, database, schema +}