Skip to content

Commit

Permalink
Merge pull request #279 from canonical/IAM-776-cover-remaining-handlers
Browse files Browse the repository at this point in the history
IAM 776 Implement validation for all handlers
  • Loading branch information
BarcoMasile authored Apr 19, 2024
2 parents ed3c804 + 7e58651 commit 77fcdfb
Show file tree
Hide file tree
Showing 28 changed files with 1,532 additions and 73 deletions.
4 changes: 2 additions & 2 deletions internal/validation/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ import (

type payloadValidator struct{}

func (_ *payloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) {
func (p *payloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) {
e := mockValidationErrors()
if e == nil {
return ctx, nil, nil
}
return ctx, e, nil
}

func (_ *payloadValidator) NeedsValidation(r *http.Request) bool {
func (p *payloadValidator) NeedsValidation(r *http.Request) bool {
return true
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/clients/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ type API struct {
func (a *API) RegisterEndpoints(mux *chi.Mux) {
mux.Get("/api/v0/clients", a.ListClients)
mux.Post("/api/v0/clients", a.CreateClient)
mux.Get("/api/v0/clients/{id}", a.GetClient)
mux.Put("/api/v0/clients/{id}", a.UpdateClient)
mux.Delete("/api/v0/clients/{id}", a.DeleteClient)
mux.Get("/api/v0/clients/{id:.+}", a.GetClient)
mux.Put("/api/v0/clients/{id:.+}", a.UpdateClient)
mux.Delete("/api/v0/clients/{id:.+}", a.DeleteClient)
}

func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) {
err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator)
if err != nil {
a.logger.Fatalf("unexpected validatingFunc already registered for clients, %s", err)
a.logger.Fatalf("unexpected error while registering PayloadValidator for clients, %s", err)
}
}

Expand Down Expand Up @@ -202,7 +202,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API {
a.apiKey = "clients"

a.service = service
//a.payloadValidator = NewClientsPayloadValidator(a.apiKey)
a.payloadValidator = NewClientsPayloadValidator(a.apiKey)
a.logger = logger

return a
Expand Down
26 changes: 26 additions & 0 deletions pkg/clients/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

//go:generate mockgen -build_flags=--mod=mod -package clients -destination ./mock_logger.go -source=../../internal/logging/interfaces.go
//go:generate mockgen -build_flags=--mod=mod -package clients -destination ./mock_clients.go -source=./interfaces.go
//go:generate mockgen -build_flags=--mod=mod -package clients -destination ./mock_validation.go -source=../../internal/validation/registry.go

func TestHandleGetClientSuccess(t *testing.T) {
ctrl := gomock.NewController(t)
Expand Down Expand Up @@ -573,3 +574,28 @@ func TestHandleListClientServiceError(t *testing.T) {
t.Fatalf("expected data to be %+v, got: %+v", expectedData, rr)
}
}

func TestRegisterValidation(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockLogger := NewMockLoggerInterface(ctrl)
mockService := NewMockServiceInterface(ctrl)
mockValidationRegistry := NewMockValidationRegistryInterface(ctrl)

apiKey := "clients"
mockValidationRegistry.EXPECT().
RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()).
Return(nil)
mockValidationRegistry.EXPECT().
RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()).
Return(fmt.Errorf("key is already registered"))

// first registration of `apiKey` is successful
NewAPI(mockService, mockLogger).RegisterValidation(mockValidationRegistry)

mockLogger.EXPECT().Fatalf(gomock.Any(), gomock.Any()).Times(1)

// second registration of `apiKey` causes logger.Fatal invocation
NewAPI(mockService, mockLogger).RegisterValidation(mockValidationRegistry)
}
87 changes: 87 additions & 0 deletions pkg/clients/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2024 Canonical Ltd.
// SPDX-License-Identifier: AGPL-3.0

package clients

import (
"context"
"encoding/json"
"net/http"
"strings"

"github.com/go-playground/validator/v10"
client "github.com/ory/hydra-client-go/v2"

"github.com/canonical/identity-platform-admin-ui/internal/validation"
)

var (
oauth2ClientRules = map[string]string{
// if not empy, validate every item is not nil and not empty
"AllowedCorsOrigins": "omitempty,dive,required",
"Audience": "omitempty,dive,required",
"GrantTypes": "omitempty,dive,required",
"ClientName": "required",
// if not empty, validate value is one of 'pairwise' and 'public'
"SubjectType": "omitempty,oneof=pairwise public",
// if not empty, validate value is one of 'client_secret_basic', 'client_secret_post', 'private_key_jwt' and 'none'
"TokenEndpointAuthMethod": "omitempty,oneof=client_secret_basic client_secret_post private_key_jwt none",
}
)

type PayloadValidator struct {
apiKey string
validator *validator.Validate
}

func (p *PayloadValidator) setupValidator() {
p.validator.RegisterStructValidationMapRules(oauth2ClientRules, client.OAuth2Client{})
}

func (p *PayloadValidator) NeedsValidation(req *http.Request) bool {
return req.Method == http.MethodPost || req.Method == http.MethodPut
}

func (p *PayloadValidator) Validate(ctx context.Context, method, endpoint string, body []byte) (context.Context, validator.ValidationErrors, error) {
validated := false
var err error

if p.isCreateClient(method, endpoint) || p.isUpdateClient(method, endpoint) {
clientRequest := new(client.OAuth2Client)
if err := json.Unmarshal(body, clientRequest); err != nil {
return ctx, nil, err
}

err = p.validator.Struct(clientRequest)
validated = true

}

if !validated {
return ctx, nil, validation.NoMatchError(p.apiKey)
}

if err == nil {
return ctx, nil, nil
}

return ctx, err.(validator.ValidationErrors), nil
}

func (p *PayloadValidator) isCreateClient(method string, endpoint string) bool {
return method == http.MethodPost && endpoint == ""
}

func (p *PayloadValidator) isUpdateClient(method string, endpoint string) bool {
return method == http.MethodPut && strings.HasPrefix(endpoint, "/")
}

func NewClientsPayloadValidator(apiKey string) *PayloadValidator {
p := new(PayloadValidator)
p.apiKey = apiKey
p.validator = validation.NewValidator()

p.setupValidator()

return p
}
Loading

0 comments on commit 77fcdfb

Please sign in to comment.