Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: oidc does not require a method in the payload #3564

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions selfservice/flow/registration/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
package registration_test

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/bxcodec/faker/v3"
"github.com/gofrs/uuid"

"github.com/ory/kratos/corpx"
Expand All @@ -30,6 +33,8 @@ import (
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/strategy/oidc"
"github.com/ory/kratos/selfservice/strategy/password"
"github.com/ory/kratos/x"
)

Expand Down Expand Up @@ -376,3 +381,107 @@ func TestGetFlow(t *testing.T) {
assert.EqualValues(t, http.StatusNotFound, res.StatusCode)
})
}

// TODO(Benehiko): this test will be updated when the `oidc` strategy is fixed.
// the OIDC strategy incorrectly assumes that is should continue if no
// method is specified but the provider is set.
func TestMultipleStrategies(t *testing.T) {
t.Logf("This test has been set up to validate the current incorrect `oidc` behaviour. When submitting the form, the `oidc` strategy is executed first, even if the method is set to `password`.")

ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)

// we need to replicate the oidc strategy before the password strategy
reg.WithSelfserviceStrategies(t, []any{
oidc.NewStrategy(reg),
password.NewStrategy(reg),
})

conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationEnabled, true)
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/registration.schema.json")
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypePassword),
map[string]interface{}{"enabled": true})
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeOIDC),
map[string]interface{}{"enabled": true})
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeCodeAuth),
map[string]interface{}{"passwordless_enabled": true})
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeOIDC)+".config", &oidc.ConfigurationCollection{Providers: []oidc.Configuration{
{
ID: "google",
Provider: "google",
ClientID: "1234",
ClientSecret: "1234",
},
}})

public, _ := testhelpers.NewKratosServerWithCSRF(t, reg)
_ = testhelpers.NewErrorTestServer(t, reg)
_ = testhelpers.NewRedirTS(t, "", conf)

setupRegistrationUI := func(t *testing.T, c *http.Client) *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write(x.EasyGetBody(t, c, public.URL+registration.RouteGetFlow+"?id="+r.URL.Query().Get("flow")))
require.NoError(t, err)
}))
t.Cleanup(ts.Close)
conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationUI, ts.URL)
return ts
}

t.Run("case=accept `password` method while including `provider:google`", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
_ = setupRegistrationUI(t, client)
body := x.EasyGetBody(t, client, public.URL+registration.RouteInitBrowserFlow)

flow := gjson.GetBytes(body, "id").String()

csrfToken := gjson.GetBytes(body, "ui.nodes.#(attributes.name==csrf_token).attributes.value").String()
email := faker.Email()
payload := json.RawMessage(`{"traits": {"email": "` + email + `"},"method": "password","password": "asdasdasdsa21312@#!@%","provider": "google","csrf_token": "` + csrfToken + `"}`)

req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, public.URL+registration.RouteSubmitFlow+"?flow="+flow, bytes.NewBuffer(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")

require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusOK, resp.StatusCode)

verifiableAddress, err := reg.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, email)
require.NoError(t, err)
require.Equal(t, strings.ToLower(email), verifiableAddress.Value)

id, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, verifiableAddress.IdentityID)
require.NoError(t, err)
require.NotNil(t, id.ID)

_, ok := id.GetCredentials(identity.CredentialsTypePassword)
require.True(t, ok)
})

t.Run("case=accept oidc flow with just `provider:google`", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
_ = setupRegistrationUI(t, client)
body := x.EasyGetBody(t, client, public.URL+registration.RouteInitBrowserFlow)

flow := gjson.GetBytes(body, "id").String()

csrfToken := gjson.GetBytes(body, "ui.nodes.#(attributes.name==csrf_token).attributes.value").String()

payload := json.RawMessage(`{"provider": "google","csrf_token": "` + csrfToken + `"}`)

req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, public.URL+registration.RouteSubmitFlow+"?flow="+flow, bytes.NewBuffer(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Containsf(t, gjson.GetBytes(b, "error.reason").String(), "In order to complete this flow please redirect the browser to: https://accounts.google.com/o/oauth2/v2/auth", "accounts.google.com", "%s", b)
})
}
4 changes: 4 additions & 0 deletions selfservice/flow/registration/stub/registration.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"credentials": {
"password": {
"identifier": true
},
"code": {
"identifier": true,
"via": "email"
}
},
"verification": {
Expand Down
3 changes: 3 additions & 0 deletions selfservice/strategy/oidc/.schema/link.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"type": "object",
"additionalProperties": true
},
"method": {
"type": "string"
},
"upstream_parameters": {
"type": "object",
"$comment": "Only the defined parameters are allowed. This is to prevent users from sending arbitrary parameters or craft URLs that cause unexpected behavior.",
Expand Down
43 changes: 35 additions & 8 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"encoding/json"
"net/http"
"strings"
"time"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -181,13 +182,16 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo
}

func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ uuid.UUID) (i *identity.Identity, err error) {
ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Login")
defer span.End()

if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil {
return nil, err
}

var p UpdateLoginFlowWithOidcMethod
if err := s.newLinkDecoder(&p, r); err != nil {
return nil, s.handleError(w, r, f, "", nil, errors.WithStack(herodot.ErrBadRequest.WithDebug(err.Error()).WithReasonf("Unable to parse HTTP form request: %s", err.Error())))
return nil, s.handleError(w, r, f, "", nil, err)
}

f.IDToken = p.IDToken
Expand All @@ -198,21 +202,44 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, errors.WithStack(flow.ErrStrategyNotResponsible)
}

if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
if p.Method == "" {
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
// the user is sending a method that is not oidc, but the payload includes a provider
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
Warn("The payload includes a `provider` field but does not specify the `method` field. This is incorrect behavior and will be removed in the future.")
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
}

// This is a small check to ensure users do not encounter issues with the current "incorrect" oidc behavior.
// this will be removed in the future when the oidc method behavior is fixed.
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" {
if pid != "" {
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
WithField("method", p.Method).
Warn("The payload includes a `provider` field but does not specify the `method` field or does not use the `oidc` method. This is incorrect behavior and will be removed in the future.")
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
}
return nil, errors.WithStack(flow.ErrStrategyNotResponsible)
}

// TODO(Benehiko): Change the following line to actually match the payload `method` field with the current strategy `oidc`.
// right now it matches itself so it will always be true
if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}

provider, err := s.provider(r.Context(), r, pid)
provider, err := s.provider(ctx, r, pid)
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}

c, err := provider.OAuth2(r.Context())
c, err := provider.OAuth2(ctx)
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}

req, err := s.validateFlow(r.Context(), r, f.ID)
req, err := s.validateFlow(ctx, r, f.ID)
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}
Expand All @@ -239,10 +266,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

state := generateState(f.ID.String())
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode {
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode {
state.setCode(code.InitCode)
}
if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName,
if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName,
continuity.WithPayload(&AuthCodeContainer{
State: state.String(),
FlowID: f.ID.String(),
Expand All @@ -253,7 +280,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

f.Active = s.ID()
if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil {
if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil {
return nil, s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error())))
}

Expand Down
42 changes: 33 additions & 9 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"

"github.com/ory/x/otelx"
"github.com/ory/x/sqlxx"

"github.com/ory/herodot"
Expand Down Expand Up @@ -151,6 +152,9 @@ func (s *Strategy) newLinkDecoder(p interface{}, r *http.Request) error {
}

func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) {
ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Register")
defer otelx.End(span, &err)

var p UpdateRegistrationFlowWithOidcMethod
if err := s.newLinkDecoder(&p, r); err != nil {
return s.handleError(w, r, f, "", nil, err)
Expand All @@ -165,21 +169,43 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return errors.WithStack(flow.ErrStrategyNotResponsible)
}

if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
if p.Method == "" {
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
// the user is sending a method that is not oidc, but the payload includes a provider
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
Warn("The payload includes a `provider` field but does not specify the `method` field. This is incorrect behavior and will be removed in the future.")
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
}

// This is a small check to ensure users do not encounter issues with the current "incorrect" oidc behavior.
// this will be removed in the future when the oidc method behavior is fixed.
if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" {
// the user is sending a method that is not oidc, but the payload includes a provider
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
WithField("method", p.Method).
Warn("The payload includes a `provider` field but specifies a `method` other than `oidc`. This is incorrect behavior and will be removed in the future.")
return errors.WithStack(flow.ErrStrategyNotResponsible)
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
}

// TODO(Benehiko): Change the following line to actually match the payload `method` field with the current strategy `oidc`.
// right now it matches itself so it will always be true
if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
return s.handleError(w, r, f, pid, nil, err)
}

provider, err := s.provider(r.Context(), r, pid)
provider, err := s.provider(ctx, r, pid)
if err != nil {
return s.handleError(w, r, f, pid, nil, err)
}

c, err := provider.OAuth2(r.Context())
c, err := provider.OAuth2(ctx)
if err != nil {
return s.handleError(w, r, f, pid, nil, err)
}

req, err := s.validateFlow(r.Context(), r, f.ID)
req, err := s.validateFlow(ctx, r, f.ID)
if err != nil {
return s.handleError(w, r, f, pid, nil, err)
}
Expand Down Expand Up @@ -207,10 +233,10 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
}

state := generateState(f.ID.String())
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode {
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode {
state.setCode(code.InitCode)
}
if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName,
if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName,
continuity.WithPayload(&AuthCodeContainer{
State: state.String(),
FlowID: f.ID.String(),
Expand Down Expand Up @@ -321,9 +347,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r
}

var it string = idToken
var (
cat, crt string
)
var cat, crt string
if token != nil {
if idToken, ok := token.Extra("id_token").(string); ok {
if it, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(idToken)); err != nil {
Expand Down
Loading