diff --git a/selfservice/flow/registration/handler_test.go b/selfservice/flow/registration/handler_test.go index 1ed093609400..27df3e0868a5 100644 --- a/selfservice/flow/registration/handler_test.go +++ b/selfservice/flow/registration/handler_test.go @@ -4,6 +4,7 @@ package registration_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -11,9 +12,11 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" + "github.com/bxcodec/faker/v3" "github.com/gofrs/uuid" "github.com/ory/kratos/corpx" @@ -30,6 +33,8 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/password" "github.com/ory/kratos/x" ) @@ -376,3 +381,107 @@ func TestGetFlow(t *testing.T) { assert.EqualValues(t, http.StatusNotFound, res.StatusCode) }) } + +// This test verifies that the password method is still executed even if the +// oidc strategy is ordered before the password strategy +// when submitting the form with both `method=password` and `provider=google`. +func TestOIDCStrategyOrder(t *testing.T) { + t.Logf("This test has been set up to validate the current incorrect `oidc` behaviour. When submitting the form, the `oidc` strategy is executed first, even if the method is set to `password`.") + + ctx := context.Background() + conf, reg := internal.NewFastRegistryWithMocks(t) + + // reorder the strategies + reg.WithSelfserviceStrategies(t, []any{ + oidc.NewStrategy(reg), + password.NewStrategy(reg), + }) + + conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationEnabled, true) + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/registration.schema.json") + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypePassword), + map[string]interface{}{"enabled": true}) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeOIDC), + map[string]interface{}{"enabled": true}) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeCodeAuth), + map[string]interface{}{"passwordless_enabled": true}) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeOIDC)+".config", &oidc.ConfigurationCollection{Providers: []oidc.Configuration{ + { + ID: "google", + Provider: "google", + ClientID: "1234", + ClientSecret: "1234", + }, + }}) + + public, _ := testhelpers.NewKratosServerWithCSRF(t, reg) + _ = testhelpers.NewErrorTestServer(t, reg) + _ = testhelpers.NewRedirTS(t, "", conf) + + setupRegistrationUI := func(t *testing.T, c *http.Client) *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write(x.EasyGetBody(t, c, public.URL+registration.RouteGetFlow+"?id="+r.URL.Query().Get("flow"))) + require.NoError(t, err) + })) + t.Cleanup(ts.Close) + conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationUI, ts.URL) + return ts + } + + t.Run("case=accept `password` method while including `provider:google`", func(t *testing.T) { + client := testhelpers.NewClientWithCookies(t) + _ = setupRegistrationUI(t, client) + body := x.EasyGetBody(t, client, public.URL+registration.RouteInitBrowserFlow) + + flow := gjson.GetBytes(body, "id").String() + + csrfToken := gjson.GetBytes(body, "ui.nodes.#(attributes.name==csrf_token).attributes.value").String() + email := faker.Email() + payload := json.RawMessage(`{"traits": {"email": "` + email + `"},"method": "password","password": "asdasdasdsa21312@#!@%","provider": "google","csrf_token": "` + csrfToken + `"}`) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, public.URL+registration.RouteSubmitFlow+"?flow="+flow, bytes.NewBuffer(payload)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, resp.StatusCode) + + verifiableAddress, err := reg.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, email) + require.NoError(t, err) + require.Equal(t, strings.ToLower(email), verifiableAddress.Value) + + id, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, verifiableAddress.IdentityID) + require.NoError(t, err) + require.NotNil(t, id.ID) + + _, ok := id.GetCredentials(identity.CredentialsTypePassword) + require.True(t, ok) + }) + + t.Run("case=accept oidc flow with just `provider:google`", func(t *testing.T) { + client := testhelpers.NewClientWithCookies(t) + _ = setupRegistrationUI(t, client) + body := x.EasyGetBody(t, client, public.URL+registration.RouteInitBrowserFlow) + + flow := gjson.GetBytes(body, "id").String() + + csrfToken := gjson.GetBytes(body, "ui.nodes.#(attributes.name==csrf_token).attributes.value").String() + + payload := json.RawMessage(`{"provider": "google","csrf_token": "` + csrfToken + `"}`) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, public.URL+registration.RouteSubmitFlow+"?flow="+flow, bytes.NewBuffer(payload)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Containsf(t, gjson.GetBytes(b, "error.reason").String(), "In order to complete this flow please redirect the browser to: https://accounts.google.com/o/oauth2/v2/auth", "accounts.google.com", "%s", b) + }) +} diff --git a/selfservice/flow/registration/stub/registration.schema.json b/selfservice/flow/registration/stub/registration.schema.json index 78fcb22c6587..e591660812a3 100644 --- a/selfservice/flow/registration/stub/registration.schema.json +++ b/selfservice/flow/registration/stub/registration.schema.json @@ -16,6 +16,10 @@ "credentials": { "password": { "identifier": true + }, + "code": { + "identifier": true, + "via": "email" } }, "verification": { diff --git a/selfservice/strategy/oidc/.schema/link.schema.json b/selfservice/strategy/oidc/.schema/link.schema.json index 1c174841244e..aae8e86f2626 100644 --- a/selfservice/strategy/oidc/.schema/link.schema.json +++ b/selfservice/strategy/oidc/.schema/link.schema.json @@ -17,6 +17,9 @@ "type": "object", "additionalProperties": true }, + "method": { + "type": "string" + }, "upstream_parameters": { "type": "object", "$comment": "Only the defined parameters are allowed. This is to prevent users from sending arbitrary parameters or craft URLs that cause unexpected behavior.", diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 5150eabcda52..c81537015bc9 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/json" "net/http" + "strings" "time" "github.com/julienschmidt/httprouter" @@ -181,13 +182,16 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ uuid.UUID) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Login") + defer span.End() + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil { return nil, err } var p UpdateLoginFlowWithOidcMethod if err := s.newLinkDecoder(&p, r); err != nil { - return nil, s.handleError(w, r, f, "", nil, errors.WithStack(herodot.ErrBadRequest.WithDebug(err.Error()).WithReasonf("Unable to parse HTTP form request: %s", err.Error()))) + return nil, s.handleError(w, r, f, "", nil, err) } f.IDToken = p.IDToken @@ -198,21 +202,31 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" { + // the user is sending a method that is not oidc, but the payload includes a provider + s.d.Audit(). + WithRequest(r). + WithField("provider", p.Provider). + WithField("method", p.Method). + Warn("The payload includes a `provider` field but is using a method other than `oidc`. Therefore, social sign in will not be executed.") + return nil, errors.WithStack(flow.ErrStrategyNotResponsible) + } + + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return nil, s.handleError(w, r, f, pid, nil, err) } - provider, err := s.provider(r.Context(), r, pid) + provider, err := s.provider(ctx, r, pid) if err != nil { return nil, s.handleError(w, r, f, pid, nil, err) } - c, err := provider.OAuth2(r.Context()) + c, err := provider.OAuth2(ctx) if err != nil { return nil, s.handleError(w, r, f, pid, nil, err) } - req, err := s.validateFlow(r.Context(), r, f.ID) + req, err := s.validateFlow(ctx, r, f.ID) if err != nil { return nil, s.handleError(w, r, f, pid, nil, err) } @@ -239,10 +253,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } state := generateState(f.ID.String()) - if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode { + if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode { state.setCode(code.InitCode) } - if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName, + if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName, continuity.WithPayload(&AuthCodeContainer{ State: state.String(), FlowID: f.ID.String(), @@ -253,7 +267,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } f.Active = s.ID() - if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 13c6a448c047..b6747ef9ad0a 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -13,6 +13,7 @@ import ( "github.com/gofrs/uuid" "github.com/julienschmidt/httprouter" + "github.com/ory/x/otelx" "github.com/ory/x/sqlxx" "github.com/ory/herodot" @@ -151,6 +152,9 @@ func (s *Strategy) newLinkDecoder(p interface{}, r *http.Request) error { } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Register") + defer otelx.End(span, &err) + var p UpdateRegistrationFlowWithOidcMethod if err := s.newLinkDecoder(&p, r); err != nil { return s.handleError(w, r, f, "", nil, err) @@ -165,21 +169,31 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return errors.WithStack(flow.ErrStrategyNotResponsible) } - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" { + // the user is sending a method that is not oidc, but the payload includes a provider + s.d.Audit(). + WithRequest(r). + WithField("provider", p.Provider). + WithField("method", p.Method). + Warn("The payload includes a `provider` field but is using a method other than `oidc`. Therefore, social sign in will not be executed.") + return errors.WithStack(flow.ErrStrategyNotResponsible) + } + + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return s.handleError(w, r, f, pid, nil, err) } - provider, err := s.provider(r.Context(), r, pid) + provider, err := s.provider(ctx, r, pid) if err != nil { return s.handleError(w, r, f, pid, nil, err) } - c, err := provider.OAuth2(r.Context()) + c, err := provider.OAuth2(ctx) if err != nil { return s.handleError(w, r, f, pid, nil, err) } - req, err := s.validateFlow(r.Context(), r, f.ID) + req, err := s.validateFlow(ctx, r, f.ID) if err != nil { return s.handleError(w, r, f, pid, nil, err) } @@ -207,10 +221,10 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat } state := generateState(f.ID.String()) - if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode { + if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode { state.setCode(code.InitCode) } - if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName, + if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName, continuity.WithPayload(&AuthCodeContainer{ State: state.String(), FlowID: f.ID.String(), @@ -321,9 +335,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r } var it string = idToken - var ( - cat, crt string - ) + var cat, crt string if token != nil { if idToken, ok := token.Extra("id_token").(string); ok { if it, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(idToken)); err != nil {