Skip to content

Commit

Permalink
feat: redirect to OIDC providers only once in registration flows
Browse files Browse the repository at this point in the history
test(e2e): ensure there is only one OIDC redirect

Co-authored-by: Jakub Fijałkowski <[email protected]>
  • Loading branch information
2 people authored and David-Wobrock committed Dec 28, 2024
1 parent 74ae377 commit b8fd734
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 4 deletions.
40 changes: 40 additions & 0 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config[[]byte, []byte]{

type MetadataType string

type OIDCProviderData struct {
Provider string `json:"provider"`
Tokens *identity.CredentialsOIDCEncryptedTokens `json:"tokens"`
Claims Claims `json:"claims"`
}

type VerifiedAddress struct {
Value string `json:"value"`
Via identity.VerifiableAddressType `json:"via"`
Expand All @@ -53,6 +59,8 @@ const (

PublicMetadata MetadataType = "identity.metadata_public"
AdminMetadata MetadataType = "identity.metadata_admin"

InternalContextKeyProviderData = "provider_data"
)

func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) {
Expand Down Expand Up @@ -216,6 +224,27 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return errors.WithStack(flow.ErrCompletedByStrategy)
}

providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)
if oidcProviderData := gjson.GetBytes(f.InternalContext, providerDataKey); oidcProviderData.IsObject() {
var providerData OIDCProviderData
if err = json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil {
return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %s", err)))
}
if pid != providerData.Provider {
return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider)))
}
container := &AuthCodeContainer{
FlowID: f.ID.String(),
Traits: p.Traits,
TransientPayload: f.TransientPayload,
}
_, err = s.processRegistration(ctx, w, r, f, providerData.Tokens, &providerData.Claims, provider, container)
if err != nil {
return s.handleError(ctx, w, r, f, pid, container.Traits, err)
}
return errors.WithStack(flow.ErrCompletedByStrategy)
}

state, pkce, err := s.GenerateState(ctx, provider, f.ID)
if err != nil {
return s.handleError(ctx, w, r, f, pid, nil, err)
Expand Down Expand Up @@ -313,6 +342,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
return nil, nil
}

providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)
if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData {
if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Tokens: token, Claims: *claims}); err == nil {
rf.InternalContext = internalContext
}
}

fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(ctx)), fetcher.WithCache(jsonnetCache, 60*time.Minute))
jsonnetMapperSnippet, err := fetch.FetchContext(ctx, provider.Config().Mapper)
if err != nil {
Expand Down Expand Up @@ -351,6 +387,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err)
}

if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil {
rf.InternalContext = internalContext
}

return nil, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,21 @@ context("Social Sign Up Successes", () => {

cy.triggerOidc(app)

// Email verification, for API call.
if (app === "react") {
cy.url().should("contain", "verification")
cy.getVerificationCodeFromEmail(email).then((code) => {
cy.get("input[name=code]").type(code)
cy.get("button[name=method][value=code]").click()
})
cy.get('[data-testid="ui/message/1080002"]').should(
"have.text",
"You successfully verified your email address.",
)
cy.get("[data-testid='node/anchor/continue']").click()
}

// Connected.
cy.location("pathname").should((loc) => {
expect(loc).to.be.oneOf(["/welcome", "/", "/sessions"])
})
Expand All @@ -103,6 +118,52 @@ context("Social Sign Up Successes", () => {
})
})

it("should redirect to oidc provider only once", () => {
const email = gen.email()

cy.registerOidc({
app,
email,
expectSession: false,
route: registration,
})

cy.get(appPrefix(app) + '[name="traits.email"]').should(
"have.value",
email,
)

cy.get('[name="traits.consent"][type="checkbox"]')
.siblings("label")
.click()
cy.get('[name="traits.newsletter"][type="checkbox"]')
.siblings("label")
.click()
cy.get('[name="traits.website"]').type(website)

cy.intercept("GET", "http://*/oauth2/auth*", {
forceNetworkError: true,
}).as("additionalRedirect")

cy.triggerOidc(app)

cy.get("@additionalRedirect").should("not.exist")

cy.location("pathname").should((loc) => {
expect(loc).to.be.oneOf([
"/welcome",
"/",
"/sessions",
"/verification",
])
})

cy.getSession().should((session) => {
shouldSession(email)(session)
expect(session.identity.traits.consent).to.equal(true)
})
})

it("should pass transient_payload to webhook", () => {
testFlowWebhook(
(hooks) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,25 @@ context("Social Sign In Settings Success", () => {
cy.get("#accept").click()

cy.get('input[name="traits.website"]').clear().type(website)

cy.intercept({
url: "http://localhost:4433/self-service/registration*",
query: { flow: "*" },
}).as("registrationCall")
cy.triggerOidc(app, "hydra")

cy.get('[data-testid="ui/message/1010016"]').should(
"contain.text",
"as another way to sign in.",
)
if (app === "react") {
cy.wait("@registrationCall").should((intercept) => {
expect(intercept.response.body.ui.messages[0].text).contain(
"as another way to sign in.",
)
})
} else {
cy.get('[data-testid="ui/message/1010016"]').should(
"contain.text",
"as another way to sign in.",
)
}

cy.noSession()
}
Expand Down

0 comments on commit b8fd734

Please sign in to comment.