Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refresh token rotation configuration & github and facebook providers #1298

Merged
merged 12 commits into from
Nov 14, 2023
70 changes: 60 additions & 10 deletions config/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,46 @@ package config
import (
"fmt"
"net/url"
"strings"
"time"

"github.com/samber/lo"
)

const (
// 24 hours is the default access token expiry period
DefaultAccessTokenExpiry time.Duration = time.Hour * 24
// 3 months is the default refresh token expiry period
DefaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90
)

const (
GoogleProvider = "google"
FacebookProvider = "facebook"
GitLabProvider = "gitlab"
OpenIdConnectProvider = "oidc"
OAuthProvider = "oauth"
)

var (
SupportedProviderTypes = []string{
GoogleProvider,
FacebookProvider,
GitLabProvider,
OpenIdConnectProvider,
OAuthProvider,
}
)

type AuthConfig struct {
Tokens *TokensConfig `yaml:"tokens"`
Providers []Provider `yaml:"providers"`
Tokens TokensConfig `yaml:"tokens"`
Providers []Provider `yaml:"providers"`
}

type TokensConfig struct {
AccessTokenExpiry int `yaml:"accessTokenExpiry"`
RefreshTokenExpiry int `yaml:"refreshTokenExpiry"`
AccessTokenExpiry *int `yaml:"accessTokenExpiry,omitempty"`
RefreshTokenExpiry *int `yaml:"refreshTokenExpiry,omitempty"`
RefreshTokenRotationEnabled *bool `yaml:"refreshTokenRotationEnabled,omitempty"`
}

type Provider struct {
Expand All @@ -40,6 +54,33 @@ type Provider struct {
AuthorizationUrl string `yaml:"authorizationUrl"`
}

// AccessTokenExpiryOrDefault retrieves the configured or default access token expiry
func (c *AuthConfig) AccessTokenExpiryOrDefault() time.Duration {
jonbretman marked this conversation as resolved.
Show resolved Hide resolved
if c.Tokens.AccessTokenExpiry != nil {
return time.Duration(*c.Tokens.AccessTokenExpiry) * time.Second
} else {
return DefaultAccessTokenExpiry
}
}

// RefreshTokenExpiryOrDefault retrieves the configured or default refresh token expiry
func (c *AuthConfig) RefreshTokenExpiryOrDefault() time.Duration {
if c.Tokens.RefreshTokenExpiry != nil {
return time.Duration(*c.Tokens.RefreshTokenExpiry) * time.Second
} else {
return DefaultRefreshTokenExpiry
}
}

// RefreshTokenRotationEnabled retrieves the configured or default refresh token rotation
func (c *AuthConfig) RefreshTokenRotationEnabled() bool {
if c.Tokens.RefreshTokenRotationEnabled != nil {
return *c.Tokens.RefreshTokenRotationEnabled
} else {
return true
}
}

func (c *AuthConfig) GetOidcProviders() []Provider {
oidcProviders := []Provider{}
for _, p := range c.Providers {
Expand All @@ -50,27 +91,36 @@ func (c *AuthConfig) GetOidcProviders() []Provider {
return oidcProviders
}

func (c *AuthConfig) HasOidcIssuer(issuer string) (bool, error) {
// GetOidcProvidersByIssuer gets all OpenID Connect providers by issuer url.
// It's possible that multiple providers from the same issuer are configured.
func (c *AuthConfig) GetOidcProvidersByIssuer(issuer string) ([]Provider, error) {
providers := []Provider{}

for _, p := range c.Providers {
if p.Type == OAuthProvider {
continue
}

issuerUrl, err := p.GetIssuer()
if err != nil {
return false, err
return nil, err
}
if issuerUrl == issuer {
return true, nil
if strings.TrimSuffix(issuerUrl, "/") == strings.TrimSuffix(issuer, "/") {
providers = append(providers, p)
}
}
return false, nil

return providers, nil
}

func (c *Provider) GetIssuer() (string, error) {
switch c.Type {
case GoogleProvider:
return "https://accounts.google.com/", nil
return "https://accounts.google.com", nil
case FacebookProvider:
return "https://www.facebook.com", nil
case GitLabProvider:
return "https://gitlab.com", nil
case OpenIdConnectProvider:
return c.IssuerUrl, nil
default:
Expand Down
4 changes: 2 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,14 @@ func Validate(config *ProjectConfig) *ConfigErrors {
}
}

if config.Auth.Tokens != nil && config.Auth.Tokens.AccessTokenExpiry <= 0 {
if config.Auth.AccessTokenExpiryOrDefault() <= 0 {
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "access", "accessTokenExpiry"),
})
}

if config.Auth.Tokens != nil && config.Auth.Tokens.RefreshTokenExpiry <= 0 {
if config.Auth.RefreshTokenExpiryOrDefault() <= 0 {
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "refresh", "refreshTokenExpiry"),
Expand Down
52 changes: 40 additions & 12 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -119,8 +120,26 @@ func TestAuthTokens(t *testing.T) {
config, err := Load("fixtures/test_auth.yaml")
assert.NoError(t, err)

assert.Equal(t, 3600, config.Auth.Tokens.AccessTokenExpiry)
assert.Equal(t, 604800, config.Auth.Tokens.RefreshTokenExpiry)
assert.Equal(t, 3600, *config.Auth.Tokens.AccessTokenExpiry)
assert.Equal(t, 604800, *config.Auth.Tokens.RefreshTokenExpiry)
assert.Equal(t, false, *config.Auth.Tokens.RefreshTokenRotationEnabled)

assert.Equal(t, time.Duration(3600)*time.Second, config.Auth.AccessTokenExpiryOrDefault())
assert.Equal(t, time.Duration(604800)*time.Second, config.Auth.RefreshTokenExpiryOrDefault())
assert.Equal(t, false, config.Auth.RefreshTokenRotationEnabled())
}

func TestAuthDefaults(t *testing.T) {
config, err := Load("fixtures/test_auth_empty.yaml")
assert.NoError(t, err)

assert.Nil(t, config.Auth.Tokens.AccessTokenExpiry)
assert.Nil(t, config.Auth.Tokens.RefreshTokenExpiry)
assert.Nil(t, config.Auth.Tokens.RefreshTokenRotationEnabled)

assert.Equal(t, time.Duration(24)*time.Hour, config.Auth.AccessTokenExpiryOrDefault())
assert.Equal(t, time.Duration(24)*time.Hour*90, config.Auth.RefreshTokenExpiryOrDefault())
assert.Equal(t, true, config.Auth.RefreshTokenRotationEnabled())
}

func TestAuthNegativeTokenLifespan(t *testing.T) {
Expand Down Expand Up @@ -171,9 +190,9 @@ func TestDuplicateProviderName(t *testing.T) {
func TestInvalidProviderTypes(t *testing.T) {
_, err := Load("fixtures/test_auth_invalid_types.yaml")

assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, facebook, gitlab, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, facebook, gitlab, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, facebook, gitlab, oidc, oauth\n")
}

func TestMissingClientId(t *testing.T) {
Expand Down Expand Up @@ -201,19 +220,28 @@ func TestMissingOrInvalidTokenEndpoint(t *testing.T) {
assert.Contains(t, err.Error(), "auth provider 'missing-endpoint' has missing or invalid https url for field: tokenUrl\n")
}

func TestHasIssuer(t *testing.T) {
func TestGetOidcIssuer(t *testing.T) {
config, err := Load("fixtures/test_auth.yaml")
assert.NoError(t, err)

hasGoogleIssuer, err := config.Auth.HasOidcIssuer("https://accounts.google.com/")
googleIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://accounts.google.com/")
assert.NoError(t, err)
assert.True(t, hasGoogleIssuer)
assert.Len(t, googleIssuer, 2)

auth0Issuer, err := config.Auth.GetOidcProvidersByIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com")
assert.NoError(t, err)
assert.Len(t, auth0Issuer, 1)

nopeIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://nope.com")
assert.NoError(t, err)
assert.Len(t, nopeIssuer, 0)
}

hasCustomIssuer, err := config.Auth.HasOidcIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com")
func TestGetOidcSameIssuers(t *testing.T) {
config, err := Load("fixtures/test_auth_same_issuers.yaml")
assert.NoError(t, err)
assert.True(t, hasCustomIssuer)

hasUnknownIssuer, err := config.Auth.HasOidcIssuer("https://nope.com")
googleIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://accounts.google.com/")
assert.NoError(t, err)
assert.False(t, hasUnknownIssuer)
assert.Len(t, googleIssuer, 3)
}
3 changes: 2 additions & 1 deletion config/fixtures/test_auth.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ auth:
tokens:
accessTokenExpiry: 3600
refreshTokenExpiry: 604800

refreshTokenRotationEnabled: false

providers:
# Built-in Google provider
- type: google
Expand Down
Empty file.
18 changes: 18 additions & 0 deletions config/fixtures/test_auth_same_issuers.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
auth:
tokens:
accessTokenExpiry: 3600
refreshTokenExpiry: 604800

providers:
- type: google
name: google_1
clientId: 1234

- type: google
name: google_2
clientId: 1234

- type: oidc
name: google_3
issuerUrl: https://accounts.google.com/
clientId: 1234
83 changes: 47 additions & 36 deletions runtime/apis/authapi/token_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/teamkeel/keel/runtime/actions"
"github.com/teamkeel/keel/runtime/common"
"github.com/teamkeel/keel/runtime/oauth"
"github.com/teamkeel/keel/runtime/runtimectx"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
)
Expand Down Expand Up @@ -66,6 +67,15 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc {
ctx, span := tracer.Start(r.Context(), "Token Endpoint")
defer span.End()

var identityId string
var refreshToken string

config, err := runtimectx.GetOAuthConfig(ctx)
if err != nil {
span.RecordError(err)
jonbretman marked this conversation as resolved.
Show resolved Hide resolved
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}

if r.Method != http.MethodPost {
return common.NewJsonResponse(http.StatusMethodNotAllowed, &TokenErrorResponse{
Error: TokenEndpointInvalidRequest,
Expand Down Expand Up @@ -102,7 +112,7 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc {
}, nil)
}

refreshTokenRaw := r.Form.Get(ArgRefreshToken)
refreshTokenRaw := r.FormValue(ArgRefreshToken)

if refreshTokenRaw == "" {
return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{
Expand All @@ -111,10 +121,24 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc {
}, nil)
}

isValid, newRefreshToken, identityId, err := oauth.RotateRefreshToken(ctx, refreshTokenRaw)
if err != nil {
span.RecordError(err)
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
var isValid bool
if config.RefreshTokenRotationEnabled() {
// Rotate and revoke this refresh token, and mint a new one.
isValid, refreshToken, identityId, err = oauth.RotateRefreshToken(ctx, refreshTokenRaw)
if err != nil {
span.RecordError(err)
jonbretman marked this conversation as resolved.
Show resolved Hide resolved
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}
} else {
// Response with the same refresh token when refresh token rotation is disabled
refreshToken = refreshTokenRaw

// Check that the refresh token exists and has not expired.
isValid, identityId, err = oauth.ValidateRefreshToken(ctx, refreshToken)
if err != nil {
span.RecordError(err)
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}
}

if !isValid {
Expand All @@ -124,21 +148,6 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc {
}, nil)
}

// Generate an access token for this identity.
accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identityId)
if err != nil {
span.RecordError(err)
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}

response := &TokenResponse{
AccessToken: accessTokenRaw,
TokenType: TokenType,
ExpiresIn: int(expiresIn.Seconds()),
RefreshToken: newRefreshToken,
}

return common.NewJsonResponse(http.StatusOK, response, nil)
case GrantTypeTokenExchange:
if !r.Form.Has(ArgSubjectToken) {
return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{
Expand Down Expand Up @@ -217,35 +226,37 @@ func TokenEndpointHandler(schema *proto.Schema) common.HandlerFunc {
}
}

// Generate an access token for this identity.
accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identity.Id)
if err != nil {
span.RecordError(err)
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}

// Generate a refresh token.
refreshTokenRaw, err := oauth.NewRefreshToken(ctx, identity.Id)
refreshToken, err = oauth.NewRefreshToken(ctx, identity.Id)
if err != nil {
span.RecordError(err)
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}

response := &TokenResponse{
AccessToken: accessTokenRaw,
TokenType: TokenType,
ExpiresIn: int(expiresIn.Seconds()),
RefreshToken: refreshTokenRaw,
}

return common.NewJsonResponse(http.StatusOK, response, nil)
identityId = identity.Id

default:
return common.NewJsonResponse(http.StatusBadRequest, &TokenErrorResponse{
Error: TokenEndpointUnsupportedGrantType,
ErrorDescription: "the only supported grants are 'refresh_token' and 'token_exchange'",
}, nil)
}

// Generate a new access token for this identity.
accessTokenRaw, expiresIn, err := oauth.GenerateAccessToken(ctx, identityId)
if err != nil {
span.RecordError(err)
jonbretman marked this conversation as resolved.
Show resolved Hide resolved
return common.NewJsonResponse(http.StatusInternalServerError, nil, nil)
}

response := &TokenResponse{
AccessToken: accessTokenRaw,
TokenType: TokenType,
ExpiresIn: int(expiresIn.Seconds()),
RefreshToken: refreshToken,
}

return common.NewJsonResponse(http.StatusOK, response, nil)
}
}

Expand Down
Loading
Loading