diff --git a/internal/http/types/generic.go b/internal/http/types/generic.go index 8abca225c..115d0230b 100644 --- a/internal/http/types/generic.go +++ b/internal/http/types/generic.go @@ -1,5 +1,5 @@ -// Copyright 2024 Canonical Ltd -// SPDX-License-Identifier: AGPL +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 package types @@ -21,7 +21,7 @@ type Response struct { // NavigationTokens are parameters used to navigate `list` result endpoints type NavigationTokens struct { - // deserialization only + // serialization only Next string `json:"next,omitempty"` Prev string `json:"prev,omitempty"` } @@ -29,11 +29,11 @@ type NavigationTokens struct { // Pagination object is used to serialize and deserialize pagination parameters // it will populate the `meta` part for the `Response` struct type Pagination struct { - PageToken string `json:"page_token,omitempty"` // serialization only + PageToken string `json:"page_token,omitempty"` // deserialization only Size int64 `json:"size"` Page int64 `json:"page"` // to be deprecated - // deserialization only + // serialization only NavigationTokens } @@ -51,6 +51,7 @@ func ParsePagination(q url.Values) *Pagination { p := NewPaginationWithDefaults() + // TODO @barco deprecate `page` if page, err := strconv.ParseInt(q.Get("page"), 10, 64); err == nil && page > 1 { p.Page = page } diff --git a/pkg/rules/handlers.go b/pkg/rules/handlers.go index 74cdac4b8..eb3d2c8d1 100644 --- a/pkg/rules/handlers.go +++ b/pkg/rules/handlers.go @@ -4,6 +4,7 @@ package rules import ( + "encoding/base64" "encoding/json" "fmt" "io" @@ -17,6 +18,8 @@ import ( oathkeeper "github.com/ory/oathkeeper-client-go" ) +const DEFAULT_OFFSET int64 = 0 + type API struct { apiKey string service ServiceInterface @@ -25,6 +28,10 @@ type API struct { logger logging.LoggerInterface } +type PageToken struct { + Offset int64 `json:"offset" validate:"required"` +} + func (a *API) RegisterEndpoints(mux *chi.Mux) { mux.Get("/api/v0/rules", a.handleList) mux.Get("/api/v0/rules/{id:.+}", a.handleDetail) @@ -45,11 +52,9 @@ func (a *API) handleList(w http.ResponseWriter, r *http.Request) { pagination := types.ParsePagination(r.URL.Query()) - if pagination.Page < 1 { - pagination.Page = 1 - } + offset := a.offsetDecode(pagination.PageToken) - rules, err := a.service.ListRules(r.Context(), pagination.Page, pagination.Size) + rules, err := a.service.ListRules(r.Context(), offset, pagination.Size) if err != nil { @@ -59,21 +64,71 @@ func (a *API) handleList(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(rr) + _ = json.NewEncoder(w).Encode(rr) return } w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode( + _ = json.NewEncoder(w).Encode( types.Response{ Data: rules, Message: "List of rules", Status: http.StatusOK, + Meta: &types.Pagination{ + NavigationTokens: types.NavigationTokens{ + Next: a.offsetTokenEncode(offset + pagination.Size), + Prev: a.offsetTokenEncode(offset - pagination.Size), + }, + Size: pagination.Size, + }, }, ) } +func (a *API) offsetTokenEncode(offset int64) string { + if offset < DEFAULT_OFFSET { + return "" + } + + pt := new(PageToken) + pt.Offset = offset + + token, err := json.Marshal(pt) + if err != nil { + a.logger.Warnf("bad page token encoding, defaulting to an empty one: %s", err) + return "" + } + + return base64.RawURLEncoding.EncodeToString(token) +} + +func (a *API) offsetDecode(pageToken string) int64 { + if pageToken == "" { + return DEFAULT_OFFSET + } + + pt := new(PageToken) + + rawPt, err := base64.RawURLEncoding.DecodeString(pageToken) + if err != nil { + a.logger.Warnf("bad page token encoding, defaulting to an empty one: %s", err) + return DEFAULT_OFFSET + } + + if err := json.Unmarshal(rawPt, pt); err != nil { + a.logger.Warnf("bad page token format, defaulting to an empty one: %s", err) + return DEFAULT_OFFSET + } + + if err != nil || pt.Offset < DEFAULT_OFFSET { + a.logger.Warnf("invalid offset, default to %d %s", DEFAULT_OFFSET, err) + return DEFAULT_OFFSET + } + + return pt.Offset +} + func (a *API) handleDetail(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/pkg/rules/handlers_test.go b/pkg/rules/handlers_test.go index 0e2446249..43f517f0d 100644 --- a/pkg/rules/handlers_test.go +++ b/pkg/rules/handlers_test.go @@ -82,12 +82,12 @@ func TestHandleListSuccess(t *testing.T) { }, } - var page int64 = 1 + var offset int64 = 0 var size int64 = 100 - mockService.EXPECT().ListRules(gomock.Any(), page, size).Return(serviceOutput, nil) + mockService.EXPECT().ListRules(gomock.Any(), offset, size).Return(serviceOutput, nil) - req := httptest.NewRequest(http.MethodGet, "/api/v0/rules?page=1&size=100", nil) + req := httptest.NewRequest(http.MethodGet, "/api/v0/rules?page_token=eyJvZmZzZXQiOjB9&size=100", nil) w := httptest.NewRecorder() mux := chi.NewMux() NewAPI(mockService, mockLogger).RegisterEndpoints(mux) @@ -130,10 +130,10 @@ func TestHandleListFailed(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockService := NewMockServiceInterface(ctrl) - var page int64 = 1 + var offset int64 = 0 var size int64 = 100 - mockService.EXPECT().ListRules(gomock.Any(), page, size).Return(nil, fmt.Errorf("mock_error")) + mockService.EXPECT().ListRules(gomock.Any(), offset, size).Return(nil, fmt.Errorf("mock_error")) req := httptest.NewRequest(http.MethodGet, "/api/v0/rules?page=0&offset=100", nil) w := httptest.NewRecorder() diff --git a/pkg/rules/service.go b/pkg/rules/service.go index 9c45df642..aa44b553f 100644 --- a/pkg/rules/service.go +++ b/pkg/rules/service.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical Ltd +// Copyright 2024 Canonical Ltd. // SPDX-License-Identifier: AGPL-3.0 package rules @@ -52,14 +52,11 @@ type Service struct { logger logging.LoggerInterface } -func (s *Service) ListRules(ctx context.Context, page, size int64) ([]oathkeeper.Rule, error) { +func (s *Service) ListRules(ctx context.Context, offset, size int64) ([]oathkeeper.Rule, error) { ctx, span := s.tracer.Start(ctx, "rules.Service.ListRules") defer span.End() - limit := size - offset := (page - 1) * size - - rules, _, err := s.oathkeeper.ListRules(ctx).Limit(limit).Offset(offset).Execute() + rules, _, err := s.oathkeeper.ListRules(ctx).Limit(size).Offset(offset).Execute() if err != nil { s.logger.Error(err.Error())