From e16fed1f8563509aac30886386668bb85e6dc797 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Mon, 9 Oct 2023 16:22:42 +0200 Subject: [PATCH] fix: change ListIdentities to keyset pagination --- .vscode/settings.json | 6 + go.mod | 5 +- go.sum | 6 +- identity/handler.go | 48 ++++--- identity/handler_test.go | 82 +++++------ identity/identity.go | 9 ++ identity/pool.go | 9 +- identity/test/pool.go | 82 +++++------ internal/driver.go | 7 +- .../sql/identity/persister_identity.go | 129 +++++++++++------- persistence/sql/migratest/migration_test.go | 5 +- x/pagination.go | 42 ++++++ 12 files changed, 258 insertions(+), 172 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000000..9b853f315519 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "gopls": { + "formatting.gofumpt": true, + "formatting.local": "github.com/ory" + } +} diff --git a/go.mod b/go.mod index 3b98b3d283ac..d40791136ea9 100644 --- a/go.mod +++ b/go.mod @@ -77,7 +77,8 @@ require ( github.com/ory/jsonschema/v3 v3.0.8 github.com/ory/mail/v3 v3.0.0 github.com/ory/nosurf v1.2.7 - github.com/ory/x v0.0.590 + github.com/ory/x v0.0.591 + github.com/peterhellberg/link v1.2.0 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.4.0 @@ -92,7 +93,6 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.14.3 github.com/tidwall/sjson v1.2.5 - github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/urfave/negroni v1.0.0 github.com/zmb3/spotify/v2 v2.0.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 @@ -261,7 +261,6 @@ require ( github.com/openzipkin/zipkin-go v0.4.1 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.0.7 // indirect - github.com/peterhellberg/link v1.2.0 // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/pkg/profile v1.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 1493216aec55..3c3a6cbe2ef7 100644 --- a/go.sum +++ b/go.sum @@ -847,8 +847,8 @@ github.com/ory/nosurf v1.2.7 h1:YrHrbSensQyU6r6HT/V5+HPdVEgrOTMJiLoJABSBOp4= github.com/ory/nosurf v1.2.7/go.mod h1:d4L3ZBa7Amv55bqxCBtCs63wSlyaiCkWVl4vKf3OUxA= github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 h1:zm6sDvHy/U9XrGpixwHiuAwpp0Ock6khSVHkrv6lQQU= github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/ory/x v0.0.590 h1:t0+XlSlDw5pzZhdAxOB8uFp1Dp+MStPRTG8Nn/fm1PE= -github.com/ory/x v0.0.590/go.mod h1:ksLBEd6iW6czGpE6eNA0gCIxO1FFeqIxCZgsgwNrzMM= +github.com/ory/x v0.0.591 h1:a3hyQZIwokuRCeoPzMxbewY/y6C6r1NgX4Jn3csVZv0= +github.com/ory/x v0.0.591/go.mod h1:ksLBEd6iW6czGpE6eNA0gCIxO1FFeqIxCZgsgwNrzMM= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -1032,8 +1032,6 @@ github.com/timtadh/lexmachine v0.2.2 h1:g55RnjdYazm5wnKv59pwFcBJHOyvTPfDEoz21s4P github.com/timtadh/lexmachine v0.2.2/go.mod h1:GBJvD5OAfRn/gnp92zb9KTgHLB7akKyxmVivoYCcjQI= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= -github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= -github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ= github.com/toqueteos/webbrowser v1.2.0/go.mod h1:XWoZq4cyp9WeUeak7w7LXRUQf1F1ATJMir8RTqb4ayM= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= diff --git a/identity/handler.go b/identity/handler.go index c7505b0dae6c..b4eceae3a515 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -10,7 +10,9 @@ import ( "net/http" "time" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/pagination/migrationpagination" + "github.com/ory/x/pagination/pagepagination" "github.com/ory/x/sqlcon" "github.com/ory/kratos/hash" @@ -169,32 +171,45 @@ type listIdentitiesParameters struct { // 200: listIdentities // default: errorGeneric func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - page, itemsPerPage := x.ParsePagination(r) - - params := ListIdentityParameters{ - Expand: ExpandDefault, - Page: page, - PerPage: itemsPerPage, - CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"), - CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"), + var ( + err error + params = ListIdentityParameters{ + Expand: ExpandDefault, + CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"), + CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"), + } + ) + if params.CredentialsIdentifier != "" && params.CredentialsIdentifierSimilar != "" { + h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithReason("Cannot pass both credentials_identifier and preview_credentials_identifier_similar.")) + return } - if params.CredentialsIdentifier != "" { + if params.CredentialsIdentifier != "" || params.CredentialsIdentifierSimilar != "" { params.Expand = ExpandEverything } + params.KeySetPagination, params.PagePagination, err = x.ParseKeysetOrPagePagination(r) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } - is, err := h.r.IdentityPool().ListIdentities(r.Context(), params) + is, nextPage, err := h.r.IdentityPool().ListIdentities(r.Context(), params) if err != nil { h.r.Writer().WriteError(w, r, err) return } - total := int64(len(is)) - if params.CredentialsIdentifier == "" { - total, err = h.r.IdentityPool().CountIdentities(r.Context()) - if err != nil { - h.r.Writer().WriteError(w, r, err) - return + if params.PagePagination != nil { + total := int64(len(is)) + if params.CredentialsIdentifier == "" { + total, err = h.r.IdentityPool().CountIdentities(r.Context()) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } } + pagepagination.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, params.PagePagination.Page, params.PagePagination.ItemsPerPage) + } else { + keysetpagination.Header(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), nextPage) } // Identities using the marshaler for including metadata_admin @@ -203,7 +218,6 @@ func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Para isam[i] = WithCredentialsMetadataAndAdminMetadataInJSON(identity) } - migrationpagination.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, itemsPerPage) h.r.Writer().Write(w, r, isam) } diff --git a/identity/handler_test.go b/identity/handler_test.go index 934ca62946b0..f237f6643301 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -20,12 +20,11 @@ import ( "github.com/bxcodec/faker/v3" "github.com/gofrs/uuid" + "github.com/peterhellberg/link" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - "github.com/tomnomnom/linkheader" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/hash" "github.com/ory/kratos/identity" @@ -683,7 +682,6 @@ func TestHandler(t *testing.T) { req := &identity.BatchPatchIdentitiesBody{Identities: validPatches} send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) }) - }) t.Run("case=ignores create nil bodies", func(t *testing.T) { @@ -787,7 +785,6 @@ func TestHandler(t *testing.T) { for name, ts := range map[string]*httptest.Server{"public": publicTS, "admin": adminTS} { t.Run("endpoint="+name, func(t *testing.T) { - email := "UPPER" + x.NewUUID().String() + "@ory.sh" lowercaseEmail := strings.ToLower(email) var cr identity.CreateIdentityBody @@ -820,7 +817,6 @@ func TestHandler(t *testing.T) { assert.EqualValues(t, identity.StateActive, res.Get("state").String(), "%s", res.Raw) }) } - }) t.Run("case=PATCH update should not persist if schema id is invalid", func(t *testing.T) { @@ -1490,53 +1486,46 @@ func TestHandler(t *testing.T) { perPage := perPage t.Run(fmt.Sprintf("perPage=%d", perPage), func(t *testing.T) { t.Parallel() - body, res := getFull(t, ts, fmt.Sprintf("/identities?per_page=%d", perPage), http.StatusOK) + body, _ := getFull(t, ts, fmt.Sprintf("/identities?per_page=%d", perPage), http.StatusOK) assert.Len(t, body.Array(), perPage) - assert.Equal(t, strconv.Itoa(count), res.Header.Get("X-Total-Count")) }) } t.Run("iterate over next page", func(t *testing.T) { perPage := 10 - pagePath := fmt.Sprintf("/identities?per_page=%d", perPage) - run := func(t *testing.T, path string, knownIDs map[string]struct{}) (isLast bool, parsed *url.URL) { - var err error + run := func(t *testing.T, path string, knownIDs map[string]struct{}) (next *url.URL, res *http.Response) { t.Logf("Requesting %s", path) body, res := getFull(t, ts, path, http.StatusOK) - for _, link := range linkheader.Parse(res.Header.Get("Link")) { - if link.Rel != "next" { - isLast = true - continue - } - parsed, err = url.Parse(link.URL) - require.NoError(t, err) - isLast = false - break - } - for _, i := range body.Array() { - assert.NotContains(t, knownIDs, i.Get("id").String()) - knownIDs[i.Get("id").String()] = struct{}{} + id := i.Get("id").String() + _, seen := knownIDs[id] + require.Falsef(t, seen, "ID %s was previously returned from the API", id) + knownIDs[id] = struct{}{} + } + links := link.ParseResponse(res) + if link, ok := links["next"]; ok { + next, err := url.Parse(link.URI) + require.NoError(t, err) + return next, res } - return isLast, parsed + return nil, res } t.Run("using token pagination", func(t *testing.T) { knownIDs := make(map[string]struct{}) - var isLast bool var pages int - path := pagePath - for !isLast { - t.Run(fmt.Sprintf("page=%d", pages), func(t *testing.T) { - var res *url.URL - pages++ - isLast, res = run(t, path, knownIDs) - if isLast { - return - } - path = fmt.Sprintf("/identities?page_size=%s&page_token=%s", res.Query().Get("page_size"), res.Query().Get("page_token")) - }) + path := fmt.Sprintf("/identities?page_size=%d", perPage) + for { + pages++ + next, res := run(t, path, knownIDs) + assert.NotContains(t, res.Header, "X-Total-Count", "not supported in token pagination") + if next == nil { + break + } + assert.NotContains(t, next.Query(), "page") + assert.NotContains(t, next.Query(), "per_page") + path = next.Path + "?" + next.Query().Encode() } assert.Len(t, knownIDs, count) @@ -1545,19 +1534,16 @@ func TestHandler(t *testing.T) { t.Run("using page pagination", func(t *testing.T) { knownIDs := make(map[string]struct{}) - var isLast bool var pages int - path := pagePath - for !isLast { - t.Run(fmt.Sprintf("page=%d", pages), func(t *testing.T) { - var res *url.URL - pages++ - isLast, res = run(t, path, knownIDs) - if isLast { - return - } - path = fmt.Sprintf("/identities?per_page=%s&page=%s", res.Query().Get("per_page"), res.Query().Get("page")) - }) + path := fmt.Sprintf("/identities?page=0&per_page=%d", perPage) + for { + pages++ + next, res := run(t, path, knownIDs) + assert.Equal(t, strconv.Itoa(count), res.Header.Get("X-Total-Count")) + if next == nil { + break + } + path = next.Path + "?" + next.Query().Encode() } assert.Len(t, knownIDs, count) diff --git a/identity/identity.go b/identity/identity.go index 064c21121fc1..85030bd41e92 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/cipher" "github.com/ory/herodot" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/sqlxx" "github.com/ory/kratos/driver/config" @@ -132,6 +133,14 @@ type Identity struct { OrganizationID uuid.NullUUID `json:"organization_id,omitempty" faker:"-" db:"organization_id"` } +func (i *Identity) PageToken() keysetpagination.PageToken { + return keysetpagination.StringPageToken(i.ID.String()) +} + +func DefaultPageToken() keysetpagination.PageToken { + return keysetpagination.StringPageToken(uuid.Nil.String()) +} + // Traits represent an identity's traits. The identity is able to create, modify, and delete traits // in a self-service manner. The input will always be validated against the JSON Schema defined // in `schema_url`. diff --git a/identity/pool.go b/identity/pool.go index 1cf4888e52ae..61665f056abd 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -6,6 +6,8 @@ package identity import ( "context" + "github.com/ory/kratos/x" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/sqlxx" "github.com/gofrs/uuid" @@ -16,13 +18,14 @@ type ( Expand Expandables CredentialsIdentifier string CredentialsIdentifierSimilar string - Page int - PerPage int + KeySetPagination []keysetpagination.Option + // DEPRECATED + PagePagination *x.Page } Pool interface { // ListIdentities lists all identities in the store given the page and itemsPerPage. - ListIdentities(ctx context.Context, params ListIdentityParameters) ([]Identity, error) + ListIdentities(ctx context.Context, params ListIdentityParameters) ([]Identity, *keysetpagination.Paginator, error) // CountIdentities counts the number of identities in the store. CountIdentities(ctx context.Context) (int64, error) diff --git a/identity/test/pool.go b/identity/test/pool.go index 5adf9fb364b6..b51a88f242ca 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -13,33 +13,25 @@ import ( "testing" "time" - "github.com/ory/x/randx" - - "github.com/tidwall/gjson" - - "github.com/ory/x/assertx" - - "github.com/ory/kratos/internal/testhelpers" - - "github.com/ory/kratos/identity" - "github.com/ory/kratos/persistence" - "github.com/bxcodec/faker/v3" - - "github.com/ory/x/sqlxx" - - "github.com/ory/x/errorsx" - "github.com/ory/x/sqlcon" - "github.com/ory/x/urlx" - - "github.com/ory/kratos/schema" - "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/identity" + "github.com/ory/kratos/internal/testhelpers" + "github.com/ory/kratos/persistence" + "github.com/ory/kratos/schema" "github.com/ory/kratos/x" + "github.com/ory/x/assertx" + "github.com/ory/x/errorsx" + "github.com/ory/x/pagination/keysetpagination" + "github.com/ory/x/randx" + "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" + "github.com/ory/x/urlx" ) func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, m *identity.Manager) func(t *testing.T) { @@ -88,7 +80,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, }) t.Run("case=expand", func(t *testing.T) { - require.NoError(t, p.GetConnection(ctx).RawQuery("DELETE FROM identities WHERE nid = ?", nid).Exec()) t.Cleanup(func() { require.NoError(t, p.GetConnection(ctx).RawQuery("DELETE FROM identities WHERE nid = ?", nid).Exec()) @@ -120,12 +111,20 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, assertion(t, actual) }) - t.Run("list", func(t *testing.T) { - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: expand, Page: 0, PerPage: 10}) + t.Run("list/page-pagination", func(t *testing.T) { + actual, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: expand, PagePagination: &x.Page{Page: 0, ItemsPerPage: 10}}) require.NoError(t, err) require.Len(t, actual, 1) assertion(t, &actual[0]) }) + + t.Run("list/token-pagination", func(t *testing.T) { + actual, next, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: expand, KeySetPagination: []keysetpagination.Option{keysetpagination.WithSize(10)}}) + require.NoError(t, err) + require.Len(t, actual, 1) + require.True(t, next.IsLast()) + assertion(t, &actual[0]) + }) } t.Run("expand=nothing", func(t *testing.T) { @@ -170,7 +169,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, t.Run("expand=default", func(t *testing.T) { runner(t, identity.ExpandDefault, func(t *testing.T, actual *identity.Identity) { - assert.Empty(t, actual.Credentials) require.Len(t, actual.RecoveryAddresses, 1) @@ -183,7 +181,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, t.Run("expand=everything", func(t *testing.T) { runner(t, identity.ExpandEverything, func(t *testing.T, actual *identity.Identity) { - require.Len(t, actual.Credentials, 2) assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"}) @@ -235,7 +232,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, }) var createdIDs []uuid.UUID - var passwordIdentity = func(schemaID string, credentialsID string) *identity.Identity { + passwordIdentity := func(schemaID string, credentialsID string) *identity.Identity { i := identity.NewIdentity(schemaID) i.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{ Type: identity.CredentialsTypePassword, Identifiers: []string{credentialsID}, @@ -244,7 +241,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, return i } - var webAuthnIdentity = func(schemaID string, credentialsID string) *identity.Identity { + webAuthnIdentity := func(schemaID string, credentialsID string) *identity.Identity { i := identity.NewIdentity(schemaID) i.SetCredentials(identity.CredentialsTypeWebAuthn, identity.Credentials{ Type: identity.CredentialsTypeWebAuthn, Identifiers: []string{credentialsID}, @@ -253,7 +250,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, return i } - var oidcIdentity = func(schemaID string, credentialsID string) *identity.Identity { + oidcIdentity := func(schemaID string, credentialsID string) *identity.Identity { i := identity.NewIdentity(schemaID) i.SetCredentials(identity.CredentialsTypeOIDC, identity.Credentials{ Type: identity.CredentialsTypeOIDC, Identifiers: []string{credentialsID}, @@ -262,7 +259,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, return i } - var assertEqual = func(t *testing.T, expected, actual *identity.Identity) { + assertEqual := func(t *testing.T, expected, actual *identity.Identity) { assert.Empty(t, actual.Credentials) require.Equal(t, expected.Traits, actual.Traits) require.Equal(t, expected.ID, actual.ID) @@ -634,9 +631,9 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, }) t.Run("case=list", func(t *testing.T) { - is, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, Page: 0, PerPage: 25}) + is, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault}) require.NoError(t, err) - require.NotZero(t, len(is)) + require.NotEmpty(t, is) require.Len(t, is, len(createdIDs)) for _, id := range createdIDs { var found bool @@ -653,7 +650,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, t.Run("no results on other network", func(t *testing.T) { _, p := testhelpers.NewNetwork(t, ctx, p) - is, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, Page: 0, PerPage: 25}) + is, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault}) require.NoError(t, err) assert.Len(t, is, 0) }) @@ -683,7 +680,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, create.SetCredentials(identity.CredentialsTypeWebAuthn, identity.Credentials{Type: identity.CredentialsTypeWebAuthn, Identifiers: []string{"find-identity-by-identifier-common@ory.sh"}, Config: sqlxx.JSONRawMessage(`{}`)}) require.NoError(t, p.CreateIdentity(ctx, create)) - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ Expand: identity.ExpandEverything, }) require.NoError(t, err) @@ -694,7 +691,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, identity.CredentialsTypeWebAuthn, } { t.Run(ct.String(), func(t *testing.T) { - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ // Match is normalized CredentialsIdentifier: expectedIdentifiers[c], }) @@ -707,7 +704,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, } t.Run("similarity search", func(t *testing.T) { - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifierSimilar: "find-identity-by-identifier", Expand: identity.ExpandCredentials, }) @@ -731,40 +728,44 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, }) t.Run("only webauthn and password", func(t *testing.T) { - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, next, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifier: "find-identity-by-identifier-oidc@ory.sh", Expand: identity.ExpandEverything, }) require.NoError(t, err) assert.Len(t, actual, 0) + assert.True(t, next.IsLast()) }) t.Run("one result set even if multiple matches", func(t *testing.T) { - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, next, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifier: "find-identity-by-identifier-common@ory.sh", Expand: identity.ExpandEverything, }) require.NoError(t, err) assert.Len(t, actual, 1) + assert.True(t, next.IsLast()) }) t.Run("non existing identifier", func(t *testing.T) { - actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, next, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifier: "find-identity-by-identifier-non-existing@ory.sh", Expand: identity.ExpandEverything, }) require.NoError(t, err) assert.Len(t, actual, 0) + assert.True(t, next.IsLast()) }) t.Run("not if on another network", func(t *testing.T) { _, on := testhelpers.NewNetwork(t, ctx, p) - actual, err := on.ListIdentities(ctx, identity.ListIdentityParameters{ + actual, next, err := on.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifier: expectedIdentifiers[0], Expand: identity.ExpandEverything, }) require.NoError(t, err) assert.Len(t, actual, 0) + assert.True(t, next.IsLast()) }) }) @@ -1243,7 +1244,8 @@ func NewTestIdentity(numAddresses int, prefix string, i int) *identity.Identity id.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{ Type: identity.CredentialsTypePassword, Identifiers: []string{traits.Username}, - Config: sqlxx.JSONRawMessage(`{}`)}) + Config: sqlxx.JSONRawMessage(`{}`), + }) return id } diff --git a/internal/driver.go b/internal/driver.go index 739afe03dc3a..d17f2b0ca62b 100644 --- a/internal/driver.go +++ b/internal/driver.go @@ -8,6 +8,8 @@ import ( "os" "testing" + "github.com/sirupsen/logrus" + "github.com/ory/x/contextx" "github.com/ory/x/jsonnetsecure" @@ -37,7 +39,7 @@ func NewConfigurationWithDefaults(t testing.TB) *config.Config { c := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.WithValues(map[string]interface{}{ - "log.level": "trace", + "log.level": "error", config.ViperKeyDSN: dbal.NewSQLiteTestDatabase(t), config.ViperKeyHasherArgon2ConfigMemory: 16384, config.ViperKeyHasherArgon2ConfigIterations: 1, @@ -77,8 +79,7 @@ func NewRegistryDefaultWithDSN(t testing.TB, dsn string) (*config.Config, *drive ctx := context.Background() c := NewConfigurationWithDefaults(t) c.MustSet(ctx, config.ViperKeyDSN, stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t))) - - reg, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("", "")) + reg, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("", "", logrusx.ForceLevel(logrus.ErrorLevel))) require.NoError(t, err) reg.Config().MustSet(ctx, "dev", true) require.NoError(t, reg.Init(context.Background(), &contextx.Default{}, driver.SkipNetworkInit, driver.WithDisabledMigrationLogging())) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 5a31d7bd6f3d..fadac678951f 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -12,19 +12,15 @@ import ( "sync" "time" - "github.com/ory/x/contextx" - "github.com/ory/x/pointerx" - "github.com/ory/x/popx" - - "golang.org/x/sync/errgroup" - + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" - "github.com/ory/x/otelx" - + "github.com/ory/herodot" "github.com/ory/jsonschema/v3" - "github.com/ory/x/sqlxx" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/identity" "github.com/ory/kratos/otp" @@ -32,14 +28,14 @@ import ( "github.com/ory/kratos/persistence/sql/update" "github.com/ory/kratos/schema" "github.com/ory/kratos/x" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/pkg/errors" - - "github.com/ory/herodot" + "github.com/ory/x/contextx" "github.com/ory/x/errorsx" + "github.com/ory/x/otelx" + "github.com/ory/x/pagination/keysetpagination" + "github.com/ory/x/pointerx" + "github.com/ory/x/popx" "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" ) var ( @@ -649,17 +645,35 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma return credentialsPerIdentity, nil } -func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities") - defer otelx.End(span, &err) - - span.SetAttributes( - attribute.Int("page", params.Page), - attribute.Int("per_page", params.PerPage), +func paginationAttributes(params *identity.ListIdentityParameters, paginator *keysetpagination.Paginator) []attribute.KeyValue { + attrs := []attribute.KeyValue{ attribute.StringSlice("expand", params.Expand.ToEager()), attribute.Bool("use:credential_identifier_filter", params.CredentialsIdentifier != ""), - attribute.String("network.id", p.NetworkID(ctx).String()), - ) + attribute.Bool("use:credential_identifier_similar_filter", params.CredentialsIdentifierSimilar != ""), + } + if params.PagePagination != nil { + attrs = append(attrs, + attribute.Int("page", params.PagePagination.Page), + attribute.Int("per_page", params.PagePagination.ItemsPerPage)) + } else { + attrs = append(attrs, + attribute.String("page_token", paginator.Token().Encode()), + attribute.Int("page_size", paginator.Size())) + } + return attrs +} + +func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (_ []identity.Identity, nextPage *keysetpagination.Paginator, err error) { + paginator := keysetpagination.GetPaginator(append( + params.KeySetPagination, + keysetpagination.WithDefaultToken(identity.DefaultPageToken()), + keysetpagination.WithDefaultSize(250), + keysetpagination.WithColumn("id", "ASC"))...) + + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities", trace.WithAttributes(append( + paginationAttributes(¶ms, paginator), + attribute.String("network.id", p.NetworkID(ctx).String()))...)) + defer otelx.End(span, &err) is := make([]identity.Identity, 0) @@ -667,8 +681,15 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. nid := p.NetworkID(ctx) joins := "" - wheres := "" - args := []any{nid} + wheres := "identities.nid = ? AND identities.id > ?" + args := []any{nid, paginator.Token().Encode()} + limit := fmt.Sprintf("LIMIT %d", paginator.Size()+1) + if params.PagePagination != nil { + wheres = "identities.nid = ?" + args = []any{nid} + paginator := pop.NewPaginator(params.PagePagination.Page+1, params.PagePagination.ItemsPerPage) + limit = fmt.Sprintf("LIMIT %d OFFSET %d", paginator.PerPage, paginator.Offset) + } identifier := params.CredentialsIdentifier identifierOperator := "=" if identifier == "" && params.CredentialsIdentifierSimilar != "" { @@ -688,31 +709,35 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. identifier = NormalizeIdentifier(identity.CredentialsTypePassword, identifier) joins = ` -INNER JOIN identity_credentials ic ON ic.identity_id = identities.id -INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id -INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id` - wheres = fmt.Sprintf(` -AND (ic.nid = ? AND ici.nid = ? AND ici.identifier %s ?) -AND ict.name IN (?, ?)`, identifierOperator) + INNER JOIN identity_credentials ic ON ic.identity_id = identities.id + INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id + INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id` + wheres += fmt.Sprintf(` + AND (ic.nid = ? AND ici.nid = ? AND ici.identifier %s ?) + AND ict.name IN (?, ?)`, identifierOperator) args = append(args, nid, nid, identifier, identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword) } - // Follow up: add page token support here, will be easy. - paginator := pop.NewPaginator(params.Page+1, params.PerPage) + query := fmt.Sprintf(` + SELECT DISTINCT identities.* + FROM identities AS identities + %s + WHERE + %s + ORDER BY identities.id ASC + %s`, + joins, wheres, limit) - if err := con.RawQuery(fmt.Sprintf(`SELECT DISTINCT identities.* -FROM identities AS identities -%s -WHERE identities.nid = ? -%s -ORDER BY identities.id DESC -LIMIT %d -OFFSET %d`, joins, wheres, paginator.PerPage, paginator.Offset), args...).All(&is); err != nil { - return nil, sqlcon.HandleError(err) + if err := con.RawQuery(query, args...).All(&is); err != nil { + return nil, nil, sqlcon.HandleError(err) + } + + if params.PagePagination == nil { + is, nextPage = keysetpagination.Result(is, paginator) } if len(is) == 0 { - return is, nil + return is, nextPage, nil } identitiesByID := make(map[uuid.UUID]*identity.Identity, len(is)) @@ -729,7 +754,7 @@ OFFSET %d`, joins, wheres, paginator.PerPage, paginator.Offset), args...).All(&i Where{"identity_credentials.nid = ?", []any{nid}}, Where{"identity_credentials.identity_id IN (?)", identityIDs}) if err != nil { - return nil, err + return nil, nil, err } for k := range is { is[k].Credentials = creds[is[k].ID] @@ -737,7 +762,7 @@ OFFSET %d`, joins, wheres, paginator.PerPage, paginator.Offset), args...).All(&i case identity.ExpandFieldVerifiableAddresses: addrs := make([]identity.VerifiableAddress, 0) if err := con.Where("nid = ?", nid).Where("identity_id IN (?)", identityIDs).Order("id").All(&addrs); err != nil { - return nil, sqlcon.HandleError(err) + return nil, nil, sqlcon.HandleError(err) } for _, addr := range addrs { identitiesByID[addr.IdentityID].VerifiableAddresses = append(identitiesByID[addr.IdentityID].VerifiableAddresses, addr) @@ -745,7 +770,7 @@ OFFSET %d`, joins, wheres, paginator.PerPage, paginator.Offset), args...).All(&i case identity.ExpandFieldRecoveryAddresses: addrs := make([]identity.RecoveryAddress, 0) if err := con.Where("nid = ?", nid).Where("identity_id IN (?)", identityIDs).Order("id").All(&addrs); err != nil { - return nil, sqlcon.HandleError(err) + return nil, nil, sqlcon.HandleError(err) } for _, addr := range addrs { identitiesByID[addr.IdentityID].RecoveryAddresses = append(identitiesByID[addr.IdentityID].RecoveryAddresses, addr) @@ -761,23 +786,23 @@ OFFSET %d`, joins, wheres, paginator.PerPage, paginator.Offset), args...).All(&i i.SchemaURL = u } else { if err := p.InjectTraitsSchemaURL(ctx, i); err != nil { - return nil, err + return nil, nil, err } schemaCache[i.SchemaID] = i.SchemaURL } if err := i.Validate(); err != nil { - return nil, err + return nil, nil, err } if err := identity.UpgradeCredentials(i); err != nil { - return nil, err + return nil, nil, err } is[k] = *i } - return is, nil + return is, nextPage, nil } func (p *IdentityPersister) UpdateIdentity(ctx context.Context, i *identity.Identity) (err error) { diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 798afd54dc79..afa101987742 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/servicelocatorx" "github.com/ory/kratos/identity" @@ -164,7 +165,7 @@ func testDatabase(t *testing.T, db string, c *pop.Connection) { defer wg.Done() t.Parallel() - ids, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) + ids, _, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, KeySetPagination: []keysetpagination.Option{keysetpagination.WithSize(1000)}}) require.NoError(t, err) require.NotEmpty(t, ids) @@ -192,7 +193,7 @@ func testDatabase(t *testing.T, db string, c *pop.Connection) { defer wg.Done() t.Parallel() - ids, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandNothing, Page: 0, PerPage: 1000}) + ids, _, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandNothing, KeySetPagination: []keysetpagination.Option{keysetpagination.WithSize(1000)}}) require.NoError(t, err) require.NotEmpty(t, ids) diff --git a/x/pagination.go b/x/pagination.go index c19ed06450e0..e0d6e18b01a4 100644 --- a/x/pagination.go +++ b/x/pagination.go @@ -7,7 +7,10 @@ import ( "net/http" "net/url" + "github.com/ory/herodot" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/pagination/migrationpagination" + "github.com/ory/x/pagination/pagepagination" ) // ParsePagination parses limit and page from *http.Request with given limits and defaults. @@ -18,3 +21,42 @@ func ParsePagination(r *http.Request) (page, itemsPerPage int) { func PaginationHeader(w http.ResponseWriter, u *url.URL, total int64, page, itemsPerPage int) { migrationpagination.PaginationHeader(w, u, total, page, itemsPerPage) } + +type Page struct { + Page, ItemsPerPage int +} + +var PagePaginationLimit = 1000 + +func ParseKeysetOrPagePagination(r *http.Request) ([]keysetpagination.Option, *Page, error) { + q := r.URL.Query() + // If we have any new-style pagination parameters, use those and ignore the rest. + if q.Has("page_token") || q.Has("page_size") { + keyset, err := keysetpagination.Parse(q, keysetpagination.NewStringPageToken) + if err != nil { + return nil, nil, herodot.ErrBadRequest.WithReason(err.Error()) + } + return keyset, nil, nil + } + // allow fallback page pagination with upper limit + if q.Has("page") { + paginator := pagepagination.PagePaginator{MaxItems: 500, DefaultItems: 250} + page, perPage := paginator.ParsePagination(r) + if page*perPage > PagePaginationLimit { + return nil, nil, herodot.ErrBadRequest.WithReasonf("Legacy pagination is not supported for enumerating over %d items. Please switch to using page_token and page_size.", PagePaginationLimit) + } + return nil, &Page{page, perPage}, nil + } + // Allow passing per_page instead of page_size if only the former is set... + if q.Has("per_page") && !q.Has("page_size") { + q.Set("page_size", q.Get("per_page")) + q.Del("per_page") + r.URL.RawQuery = q.Encode() + } + // ... and defaul to keyset pagination + keyset, err := keysetpagination.Parse(q, keysetpagination.NewStringPageToken) + if err != nil { + return nil, nil, herodot.ErrBadRequest.WithReason(err.Error()) + } + return keyset, nil, nil +}