Skip to content

Commit

Permalink
refactor: only update strategies order in test
Browse files Browse the repository at this point in the history
  • Loading branch information
Benehiko committed Oct 11, 2023
1 parent 1575c1d commit c582928
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 19 deletions.
4 changes: 2 additions & 2 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ func (m *RegistryDefault) selfServiceStrategies() []any {
} else {
// Construct the default list of strategies
m.selfserviceStrategies = []any{
oidc.NewStrategy(m),
password.NewStrategy(m),
oidc.NewStrategy(m),
profile.NewStrategy(m),
code.NewStrategy(m),
link.NewStrategy(m),
Expand Down Expand Up @@ -680,7 +680,7 @@ func (m *RegistryDefault) Init(ctx context.Context, ctxer contextx.Contextualize
m.Logger().WithError(err).Warnf("Unable to open database, retrying.")
return errors.WithStack(err)
}
p, err := sql.NewPersister(ctx, m, c, o.extraMigrations...)
p, err := sql.NewPersister(ctx, m, c, sql.WithExtraMigrations(o.extraMigrations...), sql.WithDisabledLogging(o.disableMigrationLogging))
if err != nil {
m.Logger().WithError(err).Warnf("Unable to initialize persister, retrying.")
return err
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ func TestDefaultRegistry_AllStrategies(t *testing.T) {
_, reg := internal.NewVeryFastRegistryWithoutDB(t)

t.Run("case=all login strategies", func(t *testing.T) {
expects := []string{"oidc", "password", "code", "totp", "webauthn", "lookup_secret"}
expects := []string{"password", "oidc", "code", "totp", "webauthn", "lookup_secret"}
s := reg.AllLoginStrategies()
require.Len(t, s, len(expects))
for k, e := range expects {
Expand Down
10 changes: 10 additions & 0 deletions selfservice/flow/registration/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/strategy/oidc"
"github.com/ory/kratos/selfservice/strategy/password"
"github.com/ory/kratos/x"
)

Expand Down Expand Up @@ -382,11 +383,20 @@ func TestGetFlow(t *testing.T) {
}

// TODO(Benehiko): this test will be updated when the `oidc` strategy is fixed.
// the OIDC strategy incorrectly assumes that is should continue if no
// method is specified but the provider is set.
func TestMultipleStrategies(t *testing.T) {
t.Logf("This test has been set up to validate the current incorrect `oidc` behaviour. When submitting the form, the `oidc` strategy is executed first, even if the method is set to `password`.")

ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)

// we need to replicate the oidc strategy before the password strategy
reg.WithSelfserviceStrategies(t, []any{
oidc.NewStrategy(reg),
password.NewStrategy(reg),
})

conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationEnabled, true)
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/registration.schema.json")
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypePassword),
Expand Down
43 changes: 35 additions & 8 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"encoding/json"
"net/http"
"strings"
"time"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -181,13 +182,16 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo
}

func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ uuid.UUID) (i *identity.Identity, err error) {
ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Login")
defer span.End()

if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil {
return nil, err
}

var p UpdateLoginFlowWithOidcMethod
if err := s.newLinkDecoder(&p, r); err != nil {
return nil, s.handleError(w, r, f, "", nil, errors.WithStack(herodot.ErrBadRequest.WithDebug(err.Error()).WithReasonf("Unable to parse HTTP form request: %s", err.Error())))
return nil, s.handleError(w, r, f, "", nil, err)
}

f.IDToken = p.IDToken
Expand All @@ -198,21 +202,44 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, errors.WithStack(flow.ErrStrategyNotResponsible)
}

if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
if p.Method == "" {
// the user is sending a method that is not oidc, but the payload includes a provider
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
Warn("The payload includes a `provider` field but does not specify the `method` field. This is incorrect behavior and will be removed in the future.")
}

// This is a small check to ensure users do not encounter issues with the current "incorrect" oidc behavior.
// this will be removed in the future when the oidc method behavior is fixed.
if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" {
if pid != "" {
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
WithField("method", p.Method).
Warn("The payload includes a `provider` field but does not specify the `method` field or does not use the `oidc` method. This is incorrect behavior and will be removed in the future.")
}
return nil, errors.WithStack(flow.ErrStrategyNotResponsible)
}

// TODO(Benehiko): Change the following line to actually match the payload `method` field with the current strategy `oidc`.
// right now it matches itself so it will always be true
if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}

provider, err := s.provider(r.Context(), r, pid)
provider, err := s.provider(ctx, r, pid)
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}

c, err := provider.OAuth2(r.Context())
c, err := provider.OAuth2(ctx)
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}

req, err := s.validateFlow(r.Context(), r, f.ID)
req, err := s.validateFlow(ctx, r, f.ID)
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}
Expand All @@ -239,10 +266,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

state := generateState(f.ID.String())
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); hasCode {
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode {
state.setCode(code.InitCode)
}
if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName,
if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName,
continuity.WithPayload(&AuthCodeContainer{
State: state.String(),
FlowID: f.ID.String(),
Expand All @@ -253,7 +280,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

f.Active = s.ID()
if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil {
if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil {
return nil, s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error())))
}

Expand Down
22 changes: 14 additions & 8 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,23 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return errors.WithStack(flow.ErrStrategyNotResponsible)
}

if p.Method == "" {
// the user is sending a method that is not oidc, but the payload includes a provider
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
Warn("The payload includes a `provider` field but does not specify the `method` field. This is incorrect behavior and will be removed in the future.")
}

// This is a small check to ensure users do not encounter issues with the current "incorrect" oidc behavior.
// this will be removed in the future when the oidc method behavior is fixed.
if !strings.EqualFold(p.Method, s.SettingsStrategyID()) && p.Method != "" {
if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" {
// the user is sending a method that is not oidc, but the payload includes a provider
if p.Provider != "" {
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
WithField("method", p.Method).
Warn("The payload includes a `provider` field but does not specify the `method` field or does not use the `oidc` method. This is incorrect behavior and will be removed in the future.")
}
s.d.Audit().
WithRequest(r).
WithField("provider", p.Provider).
WithField("method", p.Method).
Warn("The payload includes a `provider` field but specifies a `method` other than `oidc`. This is incorrect behavior and will be removed in the future.")
return errors.WithStack(flow.ErrStrategyNotResponsible)
}

Expand Down

0 comments on commit c582928

Please sign in to comment.