diff --git a/providers/openidConnect/openidConnect.go b/providers/openidConnect/openidConnect.go index 4a721594..5b22d3d7 100644 --- a/providers/openidConnect/openidConnect.go +++ b/providers/openidConnect/openidConnect.go @@ -51,13 +51,14 @@ const ( // Provider is the implementation of `goth.Provider` for accessing OpenID Connect provider type Provider struct { - ClientKey string - Secret string - CallbackURL string - HTTPClient *http.Client - OpenIDConfig *OpenIDConfig - config *oauth2.Config - providerName string + ClientKey string + Secret string + CallbackURL string + HTTPClient *http.Client + OpenIDConfig *OpenIDConfig + config *oauth2.Config + authCodeOptions []oauth2.AuthCodeOption + providerName string UserIdClaims []string NameClaims []string @@ -186,6 +187,14 @@ func (p *Provider) SetName(name string) { p.providerName = name } +// SetAuthCodeOptions sets additional parameters for the authentication URL. +// It takes a map of string key-value pairs and appends them to the provider's authCodeOptions. +func (p *Provider) SetAuthCodeOptions(params map[string]string) { + for k, v := range params { + p.authCodeOptions = append(p.authCodeOptions, oauth2.SetAuthURLParam(k, v)) + } +} + func (p *Provider) Client() *http.Client { return goth.HTTPClientWithFallBack(p.HTTPClient) } @@ -195,7 +204,7 @@ func (p *Provider) Debug(debug bool) {} // BeginAuth asks the OpenID Connect provider for an authentication end-point. func (p *Provider) BeginAuth(state string) (goth.Session, error) { - url := p.config.AuthCodeURL(state) + url := p.config.AuthCodeURL(state, p.authCodeOptions...) session := &Session{ AuthURL: url, } diff --git a/providers/openidConnect/openidConnect_test.go b/providers/openidConnect/openidConnect_test.go index 3e844359..7dd76e04 100644 --- a/providers/openidConnect/openidConnect_test.go +++ b/providers/openidConnect/openidConnect_test.go @@ -78,6 +78,24 @@ func Test_BeginAuth(t *testing.T) { a.Contains(s.AuthURL, "scope=openid") } +func Test_BeginAuth_AuthCodeOptions(t *testing.T) { + t.Parallel() + a := assert.New(t) + + provider := openidConnectProvider() + provider.SetAuthCodeOptions(map[string]string{"domain_hint": "test_domain.com", "prompt": "none"}) + session, err := provider.BeginAuth("test_state") + s := session.(*Session) + a.NoError(err) + a.Contains(s.AuthURL, "https://accounts.google.com/o/oauth2/v2/auth") + a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", os.Getenv("OPENID_CONNECT_KEY"))) + a.Contains(s.AuthURL, "state=test_state") + a.Contains(s.AuthURL, "redirect_uri=http%3A%2F%2Flocalhost%2Ffoo") + a.Contains(s.AuthURL, "scope=openid") + a.Contains(s.AuthURL, "domain_hint=test_domain.com") + a.Contains(s.AuthURL, "prompt=none") +} + func Test_Implements_Provider(t *testing.T) { t.Parallel() a := assert.New(t)