From 44d7223b6466d5cab9fadf851d4830e7d8ae0062 Mon Sep 17 00:00:00 2001 From: barco Date: Mon, 8 Apr 2024 12:10:13 +0200 Subject: [PATCH 01/13] feat: add constructor for validator + use json tags for validation errors --- internal/validation/types.go | 18 ++++++++++++++++++ pkg/clients/handlers.go | 2 +- pkg/groups/handlers.go | 2 +- pkg/identities/handlers.go | 2 +- pkg/idp/handlers.go | 2 +- pkg/roles/handlers.go | 2 +- pkg/rules/handlers.go | 2 +- pkg/schemas/handlers.go | 2 +- 8 files changed, 25 insertions(+), 7 deletions(-) diff --git a/internal/validation/types.go b/internal/validation/types.go index 7d5e3e384..3221b3ecf 100644 --- a/internal/validation/types.go +++ b/internal/validation/types.go @@ -5,6 +5,8 @@ package validation import ( "net/http" + "reflect" + "strings" "github.com/go-playground/validator/v10" @@ -22,3 +24,19 @@ func NewValidationError(errors validator.ValidationErrors) *types.Response { func buildErrorData(err validator.ValidationErrors) []any { return nil } + +func NewValidator() *validator.Validate { + validate := validator.New(validator.WithRequiredStructEnabled()) + + // register a function to make 3rd party validator's errors reference json field names instead of Go struct field + // these errors will be used by frontend code + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) + + return validate +} diff --git a/pkg/clients/handlers.go b/pkg/clients/handlers.go index 94977be2f..b63d4f212 100644 --- a/pkg/clients/handlers.go +++ b/pkg/clients/handlers.go @@ -205,7 +205,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger return a diff --git a/pkg/groups/handlers.go b/pkg/groups/handlers.go index e605b103b..dddb7f18f 100644 --- a/pkg/groups/handlers.go +++ b/pkg/groups/handlers.go @@ -700,7 +700,7 @@ func NewAPI(service ServiceInterface, tracer tracing.TracingInterface, monitor m a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger a.tracer = tracer a.monitor = monitor diff --git a/pkg/identities/handlers.go b/pkg/identities/handlers.go index d9635e5e5..b15a6720f 100644 --- a/pkg/identities/handlers.go +++ b/pkg/identities/handlers.go @@ -271,7 +271,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger return a diff --git a/pkg/idp/handlers.go b/pkg/idp/handlers.go index 36403d403..60c4ab4c2 100644 --- a/pkg/idp/handlers.go +++ b/pkg/idp/handlers.go @@ -260,7 +260,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger return a diff --git a/pkg/roles/handlers.go b/pkg/roles/handlers.go index 617b3ef43..809b3d236 100644 --- a/pkg/roles/handlers.go +++ b/pkg/roles/handlers.go @@ -472,7 +472,7 @@ func NewAPI(service ServiceInterface, tracer tracing.TracingInterface, monitor m a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger a.tracer = tracer a.monitor = monitor diff --git a/pkg/rules/handlers.go b/pkg/rules/handlers.go index a374bf2ac..19527c57c 100644 --- a/pkg/rules/handlers.go +++ b/pkg/rules/handlers.go @@ -256,7 +256,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger return a diff --git a/pkg/schemas/handlers.go b/pkg/schemas/handlers.go index 3c4fcaaea..03b0f97de 100644 --- a/pkg/schemas/handlers.go +++ b/pkg/schemas/handlers.go @@ -366,7 +366,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) a.service = service - a.validator = validator.New(validator.WithRequiredStructEnabled()) + a.validator = validation.NewValidator() a.logger = logger return a From a21462c78249d83961ad19a167ceeb57e5366e1f Mon Sep 17 00:00:00 2001 From: barco Date: Mon, 8 Apr 2024 12:22:08 +0200 Subject: [PATCH 02/13] feat: enhanced ValidationError with specific field errors and common errors --- internal/validation/types.go | 38 ++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/internal/validation/types.go b/internal/validation/types.go index 3221b3ecf..f469c864b 100644 --- a/internal/validation/types.go +++ b/internal/validation/types.go @@ -4,6 +4,8 @@ package validation import ( + "errors" + "fmt" "net/http" "reflect" "strings" @@ -13,16 +15,44 @@ import ( "github.com/canonical/identity-platform-admin-ui/internal/http/types" ) -func NewValidationError(errors validator.ValidationErrors) *types.Response { +var ( + NoBodyError = errors.New("request body is not present") +) + +func NoMatchError(apiKey string) error { + return fmt.Errorf("can't find matching validation process for '%s' endpoint", apiKey) +} + +func NewValidationError(msg string, errors validator.ValidationErrors) *types.Response { return &types.Response{ Status: http.StatusBadRequest, - Message: "validation errors", + Message: msg, Data: buildErrorData(errors), } } -func buildErrorData(err validator.ValidationErrors) []any { - return nil +func buildErrorData(errors validator.ValidationErrors) map[string][]string { + if errors == nil { + return nil + } + + failedValidations := make(map[string][]string) + for _, e := range errors { + field := e.Field() + + failures, ok := failedValidations[field] + if !ok { + failedValidations[field] = make([]string, 0) + } + + failures = append( + failures, + fmt.Sprintf("value '%s' fails validation of type `%s`", e.Value(), e.Tag()), + ) + failedValidations[field] = failures + } + + return failedValidations } func NewValidator() *validator.Validate { From 313617a7faaf8292df5b0a5cfc509f9e40188290 Mon Sep 17 00:00:00 2001 From: barco Date: Thu, 11 Apr 2024 15:21:23 +0200 Subject: [PATCH 03/13] feat: enhance ValidationRegistry with PayloadValidator and adjust in handlers + enhance Middleware + add func for ApiKey retrieval from endpoint --- internal/validation/registry.go | 81 ++++++++++++++--- internal/validation/registry_test.go | 126 +++++++++++++++++++-------- pkg/clients/handlers.go | 21 ++--- pkg/identities/handlers.go | 21 +++-- pkg/idp/handlers.go | 15 ++-- pkg/roles/handlers.go | 17 ++-- pkg/rules/handlers.go | 17 ++-- 7 files changed, 196 insertions(+), 102 deletions(-) diff --git a/internal/validation/registry.go b/internal/validation/registry.go index 1ad0ae73f..6e6225b0d 100644 --- a/internal/validation/registry.go +++ b/internal/validation/registry.go @@ -4,13 +4,17 @@ package validation import ( + "bytes" + "context" "encoding/json" "fmt" + "io" "net/http" "strings" "github.com/go-playground/validator/v10" + "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" "github.com/canonical/identity-platform-admin-ui/internal/monitoring" "github.com/canonical/identity-platform-admin-ui/internal/tracing" @@ -20,13 +24,16 @@ const apiVersion = "v0" type ValidationRegistryInterface interface { ValidationMiddleware(next http.Handler) http.Handler - RegisterValidatingFunc(key string, vf ValidatingFunc) error + RegisterPayloadValidator(key string, vf PayloadValidatorInterface) error } -type ValidatingFunc func(r *http.Request) validator.ValidationErrors +type PayloadValidatorInterface interface { + NeedsValidation(r *http.Request) bool + Validate(ctx context.Context, method, endpoint string, body []byte) (context.Context, validator.ValidationErrors, error) +} type ValidationRegistry struct { - validatingFuncs map[string]ValidatingFunc + validators map[string]PayloadValidatorInterface tracer tracing.TracingInterface monitor monitoring.MonitorInterface @@ -41,34 +48,69 @@ func (v *ValidationRegistry) ValidationMiddleware(next http.Handler) http.Handle r = r.WithContext(ctx) key := v.apiKey(r.URL.Path) - vf, ok := v.validatingFuncs[key] - if !ok { + payloadValidator, ok := v.validators[key] + if !ok || !payloadValidator.NeedsValidation(r) { next.ServeHTTP(w, r) return } - if validationErr := vf(r); validationErr != nil { - e := NewValidationError(validationErr) + reqBody := r.Body + defer reqBody.Close() + body, err := io.ReadAll(reqBody) + + if err != nil { + badRequestFromError(w, NoBodyError) + return + } + + // don't break existing handlers, replace the body that was consumed + r.Body = io.NopCloser(bytes.NewReader(body)) + + endpoint, _ := ApiEndpoint(r.URL.Path, key) + var validationErr validator.ValidationErrors + + ctx, validationErr, err = payloadValidator.Validate(r.Context(), r.Method, endpoint, body) + + if err != nil { + badRequestFromError(w, err) + return + } + + // handler validation errors + if validationErr != nil { + e := NewValidationError("validation errors", validationErr) w.WriteHeader(e.Status) _ = json.NewEncoder(w).Encode(e) return } - next.ServeHTTP(w, r) + // if no errors, proceed with the request + next.ServeHTTP(w, r.WithContext(ctx)) }) } -func (v *ValidationRegistry) RegisterValidatingFunc(key string, vf ValidatingFunc) error { - if vf == nil { - return fmt.Errorf("validatingFunc can't be null") +func badRequestFromError(w http.ResponseWriter, err error) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode( + types.Response{ + Message: err.Error(), + Status: http.StatusBadRequest, + }, + ) + return +} + +func (v *ValidationRegistry) RegisterPayloadValidator(key string, payloadValidator PayloadValidatorInterface) error { + if payloadValidator == nil { + return fmt.Errorf("payloadValidator can't be null") } - if _, ok := v.validatingFuncs[key]; ok { + if _, ok := v.validators[key]; ok { return fmt.Errorf("key is already registered") } - v.validatingFuncs[key] = vf + v.validators[key] = payloadValidator return nil } @@ -81,9 +123,20 @@ func (v *ValidationRegistry) apiKey(endpoint string) string { return strings.SplitN(after, "/", 1)[0] } +// ApiEndpoint returns the endpoint string stripped from the api and version prefix, and the apikey +// it doesn't strip away trailing slash if there is one +func ApiEndpoint(endpoint, apiKey string) (string, bool) { + after, found := strings.CutPrefix(endpoint, fmt.Sprintf("/api/%s/", apiVersion)) + if !found { + return "", false + } + + return strings.CutPrefix(after, apiKey) +} + func NewRegistry(tracer tracing.TracingInterface, monitor monitoring.MonitorInterface, logger logging.LoggerInterface) *ValidationRegistry { v := new(ValidationRegistry) - v.validatingFuncs = make(map[string]ValidatingFunc) + v.validators = make(map[string]PayloadValidatorInterface) v.tracer = tracer v.monitor = monitor diff --git a/internal/validation/registry_test.go b/internal/validation/registry_test.go index d12dccab6..ef48b9f55 100644 --- a/internal/validation/registry_test.go +++ b/internal/validation/registry_test.go @@ -7,6 +7,8 @@ import ( "context" "net/http" "net/http/httptest" + "reflect" + "strings" "testing" "github.com/go-playground/validator/v10" @@ -18,6 +20,30 @@ import ( //go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_monitor.go -source=../monitoring/interfaces.go //go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer +type payloadValidator struct{} + +func (_ *payloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) { + e := mockValidationErrors() + if e == nil { + return ctx, nil, nil + } + return ctx, e, nil +} + +func (_ *payloadValidator) NeedsValidation(r *http.Request) bool { + return true +} + +type noopPayloadValidator struct{} + +func (_ *noopPayloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) { + return ctx, nil, nil +} + +func (_ *noopPayloadValidator) NeedsValidation(r *http.Request) bool { + return true +} + func TestValidator_Middleware(t *testing.T) { ctrl := gomock.NewController(t) tracer := NewMockTracer(ctrl) @@ -34,17 +60,7 @@ func TestValidator_Middleware(t *testing.T) { }) vld := NewRegistry(tracer, monitor, logger) - vld.validatingFuncs["mock-key"] = func(r *http.Request) validator.ValidationErrors { - type InvalidStruct struct { - FirstName string `validate:"required"` - } - - e := validator.New(validator.WithRequiredStructEnabled()).Struct(InvalidStruct{}) - if e == nil { - return nil - } - return e.(validator.ValidationErrors) - } + vld.validators["mock-key"] = &payloadValidator{} for _, tt := range []struct { name string @@ -88,6 +104,26 @@ func TestValidator_Middleware(t *testing.T) { } } +func mockValidationErrors() validator.ValidationErrors { + type InvalidStruct struct { + FirstName string `json:"first_name" validate:"required"` + } + + validate := validator.New(validator.WithRequiredStructEnabled()) + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) + e := validate.Struct(InvalidStruct{}) + if e == nil { + return nil + } + return e.(validator.ValidationErrors) +} + func TestValidator_RegisterValidator(t *testing.T) { ctrl := gomock.NewController(t) tracer := NewMockTracer(ctrl) @@ -95,57 +131,56 @@ func TestValidator_RegisterValidator(t *testing.T) { logger := NewMockLoggerInterface(ctrl) emptyValidator := &ValidationRegistry{ - validatingFuncs: make(map[string]ValidatingFunc), - tracer: tracer, - monitor: monitor, - logger: logger, + validators: make(map[string]PayloadValidatorInterface), + tracer: tracer, + monitor: monitor, + logger: logger, } - noopVf := ValidatingFunc(func(r *http.Request) validator.ValidationErrors { - return nil - }) - validatingFuncs := make(map[string]ValidatingFunc) - validatingFuncs["mock-key-1"] = noopVf + noopValidator := &noopPayloadValidator{} + + validators := make(map[string]PayloadValidatorInterface) + validators["mock-key-1"] = noopValidator nonEmptyValidator := &ValidationRegistry{ - validatingFuncs: validatingFuncs, - tracer: tracer, - monitor: monitor, - logger: logger, + validators: validators, + tracer: tracer, + monitor: monitor, + logger: logger, } for _, tt := range []struct { name string validator *ValidationRegistry prefix string - vf ValidatingFunc + v PayloadValidatorInterface expected string }{ { name: "Nil middleware", validator: emptyValidator, prefix: "", - vf: nil, - expected: "validatingFunc can't be null", + v: nil, + expected: "payloadValidator can't be null", }, { name: "Existing key", validator: nonEmptyValidator, prefix: "mock-key-1", - vf: noopVf, + v: noopValidator, expected: "key is already registered", }, { name: "Success", validator: emptyValidator, prefix: "mock-key", - vf: noopVf, + v: noopValidator, expected: "", }, } { tt := tt t.Run(tt.name, func(t *testing.T) { - result := tt.validator.RegisterValidatingFunc(tt.prefix, tt.vf) + result := tt.validator.RegisterPayloadValidator(tt.prefix, tt.v) if tt.expected == "" && nil == result { return @@ -178,11 +213,34 @@ func TestNewValidator(t *testing.T) { t.FailNow() } - if v.validatingFuncs == nil { - t.Fatalf("validatingFuncs map expected not empty") + if v.validators == nil { + t.Fatalf("validators map expected not empty") + } + + if len(v.validators) != 0 { + t.Fatalf("validators map expected not populated") + } +} + +func TestNewValidationError(t *testing.T) { + ve := mockValidationErrors() + response := NewValidationError("validation errors", ve) + + if response.Status != http.StatusBadRequest { + t.Fatalf("response status does not match expected") + } + + if response.Message != "validation errors" { + t.Fatalf("response message does not match expected") + } + + expectedData := map[string][]string{ + "first_name": { + "value '' fails validation of type `required`", + }, } - if len(v.validatingFuncs) != 0 { - t.Fatalf("validatingFuncs map expected not populated") + if !reflect.DeepEqual(expectedData, response.Data) { + t.Fatalf("response data does not match expected validation errors") } } diff --git a/pkg/clients/handlers.go b/pkg/clients/handlers.go index b63d4f212..a1accef46 100644 --- a/pkg/clients/handlers.go +++ b/pkg/clients/handlers.go @@ -10,7 +10,6 @@ import ( "strconv" "github.com/go-chi/chi/v5" - "github.com/go-playground/validator/v10" "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" @@ -18,8 +17,9 @@ import ( ) type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface } @@ -27,22 +27,18 @@ type API struct { func (a *API) RegisterEndpoints(mux *chi.Mux) { mux.Get("/api/v0/clients", a.ListClients) mux.Post("/api/v0/clients", a.CreateClient) - mux.Get("/api/v0/clients/{id:.+}", a.GetClient) - mux.Put("/api/v0/clients/{id:.+}", a.UpdateClient) - mux.Delete("/api/v0/clients/{id:.+}", a.DeleteClient) + mux.Get("/api/v0/clients/{id}", a.GetClient) + mux.Put("/api/v0/clients/{id}", a.UpdateClient) + mux.Delete("/api/v0/clients/{id}", a.DeleteClient) } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("clients", a.validatingFunc) + err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator) if err != nil { a.logger.Fatal("unexpected validatingFunc already registered for clients") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} - func (a *API) WriteJSONResponse(w http.ResponseWriter, data interface{}, msg string, status int, links interface{}, meta *types.Pagination) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) @@ -203,9 +199,10 @@ func (a *API) parseListClientsRequest(r *http.Request) (*ListClientsRequest, err func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) + a.apiKey = "clients" a.service = service - a.validator = validation.NewValidator() + //a.payloadValidator = NewClientsPayloadValidator(a.apiKey) a.logger = logger return a diff --git a/pkg/identities/handlers.go b/pkg/identities/handlers.go index b15a6720f..29b177499 100644 --- a/pkg/identities/handlers.go +++ b/pkg/identities/handlers.go @@ -9,7 +9,6 @@ import ( "net/http" "github.com/go-chi/chi/v5" - "github.com/go-playground/validator/v10" kClient "github.com/ory/kratos-client-go" "github.com/canonical/identity-platform-admin-ui/internal/http/types" @@ -28,8 +27,9 @@ type UpdateIdentityRequest struct { } type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface } @@ -46,16 +46,13 @@ func (a *API) RegisterEndpoints(mux *chi.Mux) { } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("identities", a.validatingFunc) + err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator) + if err != nil { - a.logger.Fatal("unexpected validatingFunc already registered for identities") + a.logger.Fatal("unexpected PayloadValidator already registered for identities") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} - func (a *API) handleList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -269,9 +266,11 @@ func (a *API) error(e *kClient.GenericError) types.Response { func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) - + a.apiKey = "identities" a.service = service - a.validator = validation.NewValidator() + + a.payloadValidator = NewIdentitiesPayloadValidator(a.apiKey) + a.logger = logger return a diff --git a/pkg/idp/handlers.go b/pkg/idp/handlers.go index 60c4ab4c2..5a1929fcb 100644 --- a/pkg/idp/handlers.go +++ b/pkg/idp/handlers.go @@ -9,7 +9,6 @@ import ( "net/http" "github.com/go-chi/chi/v5" - "github.com/go-playground/validator/v10" "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" @@ -19,8 +18,9 @@ import ( const okValue = "ok" type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface } @@ -34,16 +34,12 @@ func (a *API) RegisterEndpoints(mux *chi.Mux) { } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("idps", a.validatingFunc) + err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator) if err != nil { a.logger.Fatal("unexpected validatingFunc already registered for idps") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} - func (a *API) handleList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -258,9 +254,10 @@ func (a *API) handleRemove(w http.ResponseWriter, r *http.Request) { func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) + a.apiKey = "idps" + //a.payloadValidator = NewIdPPayloadValidator(a.apiKey) a.service = service - a.validator = validation.NewValidator() a.logger = logger return a diff --git a/pkg/roles/handlers.go b/pkg/roles/handlers.go index 809b3d236..5ccb6c750 100644 --- a/pkg/roles/handlers.go +++ b/pkg/roles/handlers.go @@ -10,8 +10,6 @@ import ( "io" "net/http" - "github.com/go-playground/validator/v10" - "github.com/canonical/identity-platform-admin-ui/internal/authorization" "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" @@ -42,8 +40,9 @@ type RoleRequest struct { // API is the core HTTP object that implements all the HTTP and business logic for the roles // HTTP API functionality type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface tracer tracing.TracingInterface @@ -63,16 +62,12 @@ func (a *API) RegisterEndpoints(mux *chi.Mux) { mux.Get("/api/v0/roles/{id:.+}/groups", a.handleListRoleGroup) } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("roles", a.validatingFunc) + err := v.RegisterPayloadValidator("roles", a.payloadValidator) if err != nil { - a.logger.Fatal("unexpected validatingFunc already registered for roles") + a.logger.Fatal("unexpected PayloadValidator already registered for roles") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} - func (a *API) userFromContext(ctx context.Context) *authorization.User { // TODO @shipperizer implement the FromContext and NewContext in authorization package // see snippet below copied from https://pkg.go.dev/context#Context @@ -472,7 +467,7 @@ func NewAPI(service ServiceInterface, tracer tracing.TracingInterface, monitor m a := new(API) a.service = service - a.validator = validation.NewValidator() + //a.payloadValidator = NewRolesPayloadValidator(a.apiKey) a.logger = logger a.tracer = tracer a.monitor = monitor diff --git a/pkg/rules/handlers.go b/pkg/rules/handlers.go index 19527c57c..37d1ce85f 100644 --- a/pkg/rules/handlers.go +++ b/pkg/rules/handlers.go @@ -9,8 +9,6 @@ import ( "io" "net/http" - "github.com/go-playground/validator/v10" - "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" "github.com/canonical/identity-platform-admin-ui/internal/validation" @@ -20,8 +18,9 @@ import ( ) type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface } @@ -35,16 +34,12 @@ func (a *API) RegisterEndpoints(mux *chi.Mux) { } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("rules", a.validatingFunc) + err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator) if err != nil { - a.logger.Fatal("unexpected validatingFunc already registered for rules") + a.logger.Fatal("unexpected PayloadValidator already registered for rules") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} - func (a *API) handleList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -254,9 +249,9 @@ func (a *API) handleRemove(w http.ResponseWriter, r *http.Request) { func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) + a.apiKey = "rules" a.service = service - a.validator = validation.NewValidator() a.logger = logger return a From 5a30836583b3278b18692b7259f0aace9429b167 Mon Sep 17 00:00:00 2001 From: barco Date: Mon, 8 Apr 2024 12:22:56 +0200 Subject: [PATCH 04/13] test: add test for enhanced ValidationError + extract mockValidationErrors function --- internal/validation/registry_test.go | 77 +++++++++++----------------- 1 file changed, 30 insertions(+), 47 deletions(-) diff --git a/internal/validation/registry_test.go b/internal/validation/registry_test.go index ef48b9f55..de77726cd 100644 --- a/internal/validation/registry_test.go +++ b/internal/validation/registry_test.go @@ -20,30 +20,6 @@ import ( //go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_monitor.go -source=../monitoring/interfaces.go //go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer -type payloadValidator struct{} - -func (_ *payloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) { - e := mockValidationErrors() - if e == nil { - return ctx, nil, nil - } - return ctx, e, nil -} - -func (_ *payloadValidator) NeedsValidation(r *http.Request) bool { - return true -} - -type noopPayloadValidator struct{} - -func (_ *noopPayloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) { - return ctx, nil, nil -} - -func (_ *noopPayloadValidator) NeedsValidation(r *http.Request) bool { - return true -} - func TestValidator_Middleware(t *testing.T) { ctrl := gomock.NewController(t) tracer := NewMockTracer(ctrl) @@ -60,7 +36,13 @@ func TestValidator_Middleware(t *testing.T) { }) vld := NewRegistry(tracer, monitor, logger) - vld.validators["mock-key"] = &payloadValidator{} + vld.validatingFuncs["mock-key"] = func(r *http.Request) (validator.ValidationErrors, error) { + e := mockValidationErrors() + if e == nil { + return nil, nil + } + return e, nil + } for _, tt := range []struct { name string @@ -131,56 +113,57 @@ func TestValidator_RegisterValidator(t *testing.T) { logger := NewMockLoggerInterface(ctrl) emptyValidator := &ValidationRegistry{ - validators: make(map[string]PayloadValidatorInterface), - tracer: tracer, - monitor: monitor, - logger: logger, + validatingFuncs: make(map[string]ValidatingFunc), + tracer: tracer, + monitor: monitor, + logger: logger, } - noopValidator := &noopPayloadValidator{} - - validators := make(map[string]PayloadValidatorInterface) - validators["mock-key-1"] = noopValidator + noopVf := ValidatingFunc(func(r *http.Request) (validator.ValidationErrors, error) { + return nil, nil + }) + validatingFuncs := make(map[string]ValidatingFunc) + validatingFuncs["mock-key-1"] = noopVf nonEmptyValidator := &ValidationRegistry{ - validators: validators, - tracer: tracer, - monitor: monitor, - logger: logger, + validatingFuncs: validatingFuncs, + tracer: tracer, + monitor: monitor, + logger: logger, } for _, tt := range []struct { name string validator *ValidationRegistry prefix string - v PayloadValidatorInterface + vf ValidatingFunc expected string }{ { name: "Nil middleware", validator: emptyValidator, prefix: "", - v: nil, - expected: "payloadValidator can't be null", + vf: nil, + expected: "validatingFunc can't be null", }, { name: "Existing key", validator: nonEmptyValidator, prefix: "mock-key-1", - v: noopValidator, + vf: noopVf, expected: "key is already registered", }, { name: "Success", validator: emptyValidator, prefix: "mock-key", - v: noopValidator, + vf: noopVf, expected: "", }, } { tt := tt t.Run(tt.name, func(t *testing.T) { - result := tt.validator.RegisterPayloadValidator(tt.prefix, tt.v) + result := tt.validator.RegisterValidatingFunc(tt.prefix, tt.vf) if tt.expected == "" && nil == result { return @@ -213,12 +196,12 @@ func TestNewValidator(t *testing.T) { t.FailNow() } - if v.validators == nil { - t.Fatalf("validators map expected not empty") + if v.validatingFuncs == nil { + t.Fatalf("validatingFuncs map expected not empty") } - if len(v.validators) != 0 { - t.Fatalf("validators map expected not populated") + if len(v.validatingFuncs) != 0 { + t.Fatalf("validatingFuncs map expected not populated") } } From 8c5e17319243cc44dbe3d353acb2df57819334ac Mon Sep 17 00:00:00 2001 From: barco Date: Thu, 11 Apr 2024 15:43:35 +0200 Subject: [PATCH 05/13] feat: add validation setup for `schemas` endpoint --- pkg/schemas/handlers.go | 66 +++++++++++++++++++++++++++++++++------ pkg/schemas/service.go | 2 +- pkg/schemas/validation.go | 34 ++++++++++++++++++++ 3 files changed, 91 insertions(+), 11 deletions(-) create mode 100644 pkg/schemas/validation.go diff --git a/pkg/schemas/handlers.go b/pkg/schemas/handlers.go index 03b0f97de..613b0832f 100644 --- a/pkg/schemas/handlers.go +++ b/pkg/schemas/handlers.go @@ -9,7 +9,6 @@ import ( "net/http" "github.com/go-chi/chi/v5" - "github.com/go-playground/validator/v10" kClient "github.com/ory/kratos-client-go" "github.com/canonical/identity-platform-admin-ui/internal/http/types" @@ -17,11 +16,10 @@ import ( "github.com/canonical/identity-platform-admin-ui/internal/validation" ) -const okValue = "ok" - type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface } @@ -37,15 +35,62 @@ func (a *API) RegisterEndpoints(mux *chi.Mux) { } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("schemas", a.validatingFunc) + err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator) if err != nil { a.logger.Fatal("unexpected validatingFunc already registered for schemas") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} +/*func (a *API) validatingFunc(r *http.Request) (validator.ValidationErrors, error) { + if !shouldValidate(r) { + return nil, nil + } + + defer r.Body.Close() + body, err := io.ReadAll(r.Body) + + if err != nil { + return nil, validation.NoBodyError + } + + // don't break existing handlers, replace the body that was consumed + r.Body = io.NopCloser(bytes.NewReader(body)) + + // key "schemas" must be there since we registered it in the setup func + endpoint, _ := validation.ApiEndpoint(r.URL.Path, a.apiKey) + + validated := false + + if isCreateOrUpdateSchema(r, endpoint) { + schema := new(kClient.IdentitySchemaContainer) + if err := json.Unmarshal(body, schema); err != nil { + return nil, err + } + + err = a.validator.Struct(schema) + validated = true + } + + if isPartialUpdate(r, endpoint) { + schema := new(DefaultSchema) + if err := json.Unmarshal(body, schema); err != nil { + return nil, err + } + + err = a.validator.Struct(schema) + validated = true + } + + if !validated { + return nil, validation.NoMatchError(a.apiKey) + } + + if err == nil { + return nil, nil + } + + return err.(validator.ValidationErrors), nil +}*/ func (a *API) handleList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -365,8 +410,9 @@ func (a *API) error(e *kClient.GenericError) types.Response { func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a := new(API) + a.apiKey = "schemas" a.service = service - a.validator = validation.NewValidator() + //a.payloadValidator = NewSchemasPayloadValidator() a.logger = logger return a diff --git a/pkg/schemas/service.go b/pkg/schemas/service.go index aaacfc1ab..40df0fd59 100644 --- a/pkg/schemas/service.go +++ b/pkg/schemas/service.go @@ -38,7 +38,7 @@ type IdentitySchemaData struct { } type DefaultSchema struct { - ID string `json:"schema_id"` + ID string `json:"schema_id" validate:"required"` } // TODO @shipperizer verify during integration test if this is actually the format diff --git a/pkg/schemas/validation.go b/pkg/schemas/validation.go new file mode 100644 index 000000000..6ed5bbd92 --- /dev/null +++ b/pkg/schemas/validation.go @@ -0,0 +1,34 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package schemas + +import ( + "net/http" + "strings" + + "github.com/go-playground/validator/v10" +) + +var ( + identitySchemaContainerRules = map[string]string{ + "Schema": "required", + } +) + +type PayloadValidator struct { + apiKey string + validator *validator.Validate +} + +func isPartialUpdate(r *http.Request, endpoint string) bool { + return strings.HasPrefix(endpoint, "/") && r.Method == http.MethodPatch +} + +func isCreateOrUpdateSchema(r *http.Request, endpoint string) bool { + return (endpoint == "" && r.Method == http.MethodPost) || (endpoint == "/default" && r.Method == http.MethodPut) +} + +func shouldValidate(r *http.Request) bool { + return r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch +} From b4178c95c2771b2149fb92cc80d43431b6c7028b Mon Sep 17 00:00:00 2001 From: barco Date: Thu, 11 Apr 2024 15:45:18 +0200 Subject: [PATCH 06/13] feat: add validation setup for `identities` endpoint --- pkg/identities/handlers_test.go | 26 ++++ pkg/identities/validation.go | 97 +++++++++++++ pkg/identities/validation_test.go | 218 ++++++++++++++++++++++++++++++ 3 files changed, 341 insertions(+) create mode 100644 pkg/identities/validation.go create mode 100644 pkg/identities/validation_test.go diff --git a/pkg/identities/handlers_test.go b/pkg/identities/handlers_test.go index 1ee05f473..44d8efa95 100644 --- a/pkg/identities/handlers_test.go +++ b/pkg/identities/handlers_test.go @@ -27,6 +27,7 @@ import ( //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_monitor.go -source=../../internal/monitoring/interfaces.go //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_kratos.go github.com/ory/kratos-client-go IdentityAPI +//go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_validation.go -source=../../internal/validation/registry.go func TestHandleListSuccess(t *testing.T) { ctrl := gomock.NewController(t) @@ -717,3 +718,28 @@ func TestHandleRemoveFailAndPropagatesKratosError(t *testing.T) { t.Errorf("expected code to be %v got %v", *gerr.Code, rr.Status) } } + +func TestRegisterValidation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockService := NewMockServiceInterface(ctrl) + mockValidationRegistry := NewMockValidationRegistryInterface(ctrl) + + apiKey := "identities" + mockValidationRegistry.EXPECT(). + RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()). + Return(nil) + mockValidationRegistry.EXPECT(). + RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()). + Return(fmt.Errorf("key is already registered")) + + // first registration of `apiKey` is successful + NewAPI(mockService, mockLogger).RegisterValidation(mockValidationRegistry) + + mockLogger.EXPECT().Fatal(gomock.Any()).Times(1) + + // second registration of `apiKey` causes logger.Fatal invocation + NewAPI(mockService, mockLogger).RegisterValidation(mockValidationRegistry) +} diff --git a/pkg/identities/validation.go b/pkg/identities/validation.go new file mode 100644 index 000000000..1156a060f --- /dev/null +++ b/pkg/identities/validation.go @@ -0,0 +1,97 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package identities + +import ( + "context" + "encoding/json" + "net/http" + "strings" + + "github.com/go-playground/validator/v10" + + "github.com/canonical/identity-platform-admin-ui/internal/validation" +) + +var ( + identityRules = map[string]string{ + "Credentials": "required", + } + + // mutually exclusive fields + credentialsRules = map[string]string{ + "Oidc": "required_without=Password,excluded_with=Password", + "Password": "required_without=Oidc,excluded_with=Oidc", + } +) + +type PayloadValidator struct { + apiKey string + validator *validator.Validate +} + +func (p *PayloadValidator) setupValidator() { + p.validator.RegisterStructValidationMapRules(identityRules, CreateIdentityRequest{}.CreateIdentityBody) + p.validator.RegisterStructValidationMapRules(credentialsRules, CreateIdentityRequest{}.CreateIdentityBody.Credentials) + + p.validator.RegisterStructValidationMapRules(identityRules, UpdateIdentityRequest{}.UpdateIdentityBody) + p.validator.RegisterStructValidationMapRules(credentialsRules, UpdateIdentityRequest{}.UpdateIdentityBody.Credentials) +} + +func (_ *PayloadValidator) NeedsValidation(req *http.Request) bool { + return req.Method == http.MethodPost || req.Method == http.MethodPut +} + +func (p *PayloadValidator) Validate(ctx context.Context, method, endpoint string, body []byte) (context.Context, validator.ValidationErrors, error) { + validated := false + var err error + + if p.isCreateIdentity(method, endpoint) { + createIdentity := new(CreateIdentityRequest) + if err := json.Unmarshal(body, createIdentity); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(createIdentity) + validated = true + + } else if p.isUpdateIdentity(method, endpoint) { + updateIdentity := new(UpdateIdentityRequest) + if err := json.Unmarshal(body, updateIdentity); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(updateIdentity) + validated = true + + } + + if !validated { + return ctx, nil, validation.NoMatchError(p.apiKey) + } + + if err == nil { + return ctx, nil, nil + } + + return ctx, err.(validator.ValidationErrors), nil +} + +func (_ *PayloadValidator) isCreateIdentity(method, endpoint string) bool { + return endpoint == "" && method == http.MethodPost +} + +func (_ *PayloadValidator) isUpdateIdentity(method, endpoint string) bool { + return strings.HasPrefix(endpoint, "/") && method == http.MethodPut +} + +func NewIdentitiesPayloadValidator(apiKey string) *PayloadValidator { + p := new(PayloadValidator) + p.apiKey = apiKey + p.validator = validation.NewValidator() + + p.setupValidator() + + return p +} diff --git a/pkg/identities/validation_test.go b/pkg/identities/validation_test.go new file mode 100644 index 000000000..96b20712a --- /dev/null +++ b/pkg/identities/validation_test.go @@ -0,0 +1,218 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package identities + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-playground/validator/v10" + client "github.com/ory/kratos-client-go" + + "github.com/canonical/identity-platform-admin-ui/internal/validation" +) + +func TestNeedsValidation(t *testing.T) { + p := new(PayloadValidator) + p.validator = validation.NewValidator() + p.setupValidator() + + for _, tt := range []struct { + name string + req *http.Request + expectedResult bool + }{ + { + name: http.MethodPost, + req: httptest.NewRequest(http.MethodPost, "/", nil), + expectedResult: true, + }, + { + name: http.MethodPut, + req: httptest.NewRequest(http.MethodPut, "/", nil), + expectedResult: true, + }, + { + name: http.MethodGet, + req: httptest.NewRequest(http.MethodGet, "/", nil), + expectedResult: false, + }, + { + name: http.MethodPatch, + req: httptest.NewRequest(http.MethodPatch, "/", nil), + expectedResult: false, + }, + { + name: http.MethodDelete, + req: httptest.NewRequest(http.MethodDelete, "/", nil), + expectedResult: false, + }, + { + name: http.MethodConnect, + req: httptest.NewRequest(http.MethodConnect, "/", nil), + expectedResult: false, + }, + { + name: http.MethodHead, + req: httptest.NewRequest(http.MethodHead, "/", nil), + expectedResult: false, + }, + { + name: http.MethodTrace, + req: httptest.NewRequest(http.MethodTrace, "/", nil), + expectedResult: false, + }, + { + name: http.MethodOptions, + req: httptest.NewRequest(http.MethodOptions, "/", nil), + expectedResult: false, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + result := p.NeedsValidation(tt.req) + + if result != tt.expectedResult { + t.Fatalf("Result doesn't match expected one, obtained %t instead of %t", result, tt.expectedResult) + } + }) + } + +} + +func TestValidate(t *testing.T) { + p := new(PayloadValidator) + p.apiKey = "identities" + p.validator = validation.NewValidator() + p.setupValidator() + + for _, tt := range []struct { + name string + method string + endpoint string + body func() []byte + expectedResult validator.ValidationErrors + expectedError error + }{ + { + name: "CreateIdentitySuccessOidc", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + identity := client.NewCreateIdentityBodyWithDefaults() + identity.Credentials = &client.IdentityWithCredentials{ + Oidc: &client.IdentityWithCredentialsOidc{}, + } + marshal, _ := json.Marshal(identity) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "CreateIdentitySuccessPassword", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + identity := client.NewCreateIdentityBodyWithDefaults() + identity.Credentials = &client.IdentityWithCredentials{ + Password: &client.IdentityWithCredentialsPassword{}, + } + marshal, _ := json.Marshal(identity) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "UpdateIdentitySuccessOidc", + method: http.MethodPut, + endpoint: "/identity-id", + body: func() []byte { + identity := client.NewUpdateIdentityBodyWithDefaults() + identity.Credentials = &client.IdentityWithCredentials{ + Oidc: &client.IdentityWithCredentialsOidc{}, + } + marshal, _ := json.Marshal(identity) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "UpdateIdentitySuccessPassword", + method: http.MethodPut, + endpoint: "/identity-id", + body: func() []byte { + identity := client.NewUpdateIdentityBodyWithDefaults() + identity.Credentials = &client.IdentityWithCredentials{ + Password: &client.IdentityWithCredentialsPassword{}, + } + marshal, _ := json.Marshal(identity) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "NoMatch", + method: http.MethodPost, + endpoint: "no-match-endpoint", + body: func() []byte { + return nil + }, + expectedResult: nil, + expectedError: validation.NoMatchError(p.apiKey), + }, + { + name: "CreateIdentityValidationError", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + identity := client.NewCreateIdentityBodyWithDefaults() + identity.Credentials = &client.IdentityWithCredentials{ + Password: nil, + Oidc: nil, + } + marshal, _ := json.Marshal(identity) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + { + name: "UpdateIdentityValidationError", + method: http.MethodPut, + endpoint: "/identity-id", + body: func() []byte { + identity := client.NewUpdateIdentityBodyWithDefaults() + identity.Credentials = &client.IdentityWithCredentials{ + Oidc: nil, + Password: nil, + } + marshal, _ := json.Marshal(identity) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + _, result, err := p.Validate(context.TODO(), tt.method, tt.endpoint, tt.body()) + + if err != nil && err.Error() != tt.expectedError.Error() { + t.Fatalf("Returned error doesn't match expected, obtained '%v' instead of '%v'", err, tt.expectedError) + } else if result != nil && errors.Is(result, tt.expectedResult) { + t.Fatalf("Returned validation errors don't match expected, obtained '%v' instead of '%v'", result, tt.expectedResult) + } else { + return + } + }) + } +} From 06fb9f4c777b880b4be1fb646360e9cf6b805095 Mon Sep 17 00:00:00 2001 From: barco Date: Thu, 11 Apr 2024 18:31:32 +0200 Subject: [PATCH 07/13] feat: add validation setup for `groups` endpoint --- pkg/groups/handlers.go | 96 +++++++++++++++++++++++++++++++++++----- pkg/groups/validation.go | 34 ++++++++++++++ 2 files changed, 118 insertions(+), 12 deletions(-) create mode 100644 pkg/groups/validation.go diff --git a/pkg/groups/handlers.go b/pkg/groups/handlers.go index dddb7f18f..4eb0a62ea 100644 --- a/pkg/groups/handlers.go +++ b/pkg/groups/handlers.go @@ -10,8 +10,6 @@ import ( "io" "net/http" - "github.com/go-playground/validator/v10" - "github.com/canonical/identity-platform-admin-ui/internal/authorization" "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" @@ -28,7 +26,7 @@ const ( ) type UpdateRolesRequest struct { - Roles []string `json:"roles" validate:"required"` + Roles []string `json:"roles" validate:"required,dive,required"` } type Permission struct { @@ -37,7 +35,7 @@ type Permission struct { } type UpdatePermissionsRequest struct { - Permissions []Permission `json:"permissions" validate:"required"` + Permissions []Permission `json:"permissions" validate:"required,dive,required"` } type GroupRequest struct { @@ -45,14 +43,15 @@ type GroupRequest struct { } type UpdateIdentitiesRequest struct { - Identities []string `json:"identities" validate:"required"` + Identities []string `json:"identities" validate:"required,dive,required"` } // API is the core HTTP object that implements all the HTTP and business logic for the groups // HTTP API functionality type API struct { - service ServiceInterface - validator *validator.Validate + apiKey string + service ServiceInterface + payloadValidator validation.PayloadValidatorInterface logger logging.LoggerInterface tracer tracing.TracingInterface @@ -78,15 +77,87 @@ func (a *API) RegisterEndpoints(mux *chi.Mux) { } func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { - err := v.RegisterValidatingFunc("groups", a.validatingFunc) + err := v.RegisterPayloadValidator(a.apiKey, a.payloadValidator) if err != nil { a.logger.Fatal("unexpected validatingFunc already registered for groups") } } -func (a *API) validatingFunc(r *http.Request) validator.ValidationErrors { - return nil -} +/*func (a *API) validatingFunc(r *http.Request) (validator.ValidationErrors, error) { + if !shouldValidate(r) { + return nil, nil + } + + defer r.Body.Close() + body, err := io.ReadAll(r.Body) + + if err != nil { + return nil, validation.NoBodyError + } + + // don't break existing handlers, replace the body that was consumed + r.Body = io.NopCloser(bytes.NewReader(body)) + + // key "identities" must be there since we registered it in the setup func + endpoint, _ := validation.ApiEndpoint(r.URL.Path, a.apiKey) + + validated := false + + if isCreateGroup(r, endpoint) { + group := new(GroupRequest) + if err := json.Unmarshal(body, group); err != nil { + return nil, err + } + + err = a.validator.Struct(group) + validated = true + } + + if isUpdateGroup(r, endpoint) { + // TODO: @barco to implement when the UpdateGroup is implemented + validated = true + } + + if isAssignRoles(r, endpoint) { + updateRoles := new(UpdateRolesRequest) + if err := json.Unmarshal(body, updateRoles); err != nil { + return nil, err + } + + err = a.validator.Struct(updateRoles) + validated = true + } + + if isAssignPermissions(r, endpoint) { + updatePermissions := new(UpdatePermissionsRequest) + if err := json.Unmarshal(body, updatePermissions); err != nil { + return nil, err + } + + err = a.validator.Struct(updatePermissions) + validated = true + } + + if isAssignIdentities(r, endpoint) { + updateIdentities := new(UpdateIdentitiesRequest) + if err := json.Unmarshal(body, updateIdentities); err != nil { + return nil, err + } + + err = a.validator.Struct(updateIdentities) + validated = true + } + + if !validated { + return nil, validation.NoMatchError(a.apiKey) + } + + if err == nil { + return nil, nil + } + + return err.(validator.ValidationErrors), nil +}*/ func (a *API) userFromContext(ctx context.Context) *authorization.User { // TODO @shipperizer implement the FromContext and NewContext in authorization package @@ -699,8 +770,9 @@ func (a *API) handleRemoveIdentities(w http.ResponseWriter, r *http.Request) { func NewAPI(service ServiceInterface, tracer tracing.TracingInterface, monitor monitoring.MonitorInterface, logger logging.LoggerInterface) *API { a := new(API) + a.apiKey = "groups" a.service = service - a.validator = validation.NewValidator() + //a.payloadValidator = NewGroupsPayloadValidator() a.logger = logger a.tracer = tracer a.monitor = monitor diff --git a/pkg/groups/validation.go b/pkg/groups/validation.go new file mode 100644 index 000000000..c52793e2d --- /dev/null +++ b/pkg/groups/validation.go @@ -0,0 +1,34 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package groups + +import ( + "net/http" + "strings" +) + +func shouldValidate(r *http.Request) bool { + return r.Method == http.MethodPost || r.Method == http.MethodPatch +} + +func isCreateGroup(r *http.Request, endpoint string) bool { + return r.Method == http.MethodPost && endpoint == "" +} + +func isUpdateGroup(r *http.Request, endpoint string) bool { + // make sure at least one character is present for the Group ID URL Param + return r.Method == http.MethodPatch && strings.HasPrefix(endpoint, "/") && len(endpoint) > 1 +} + +func isAssignRoles(r *http.Request, endpoint string) bool { + return r.Method == http.MethodPost && strings.HasSuffix(endpoint, "/roles") +} + +func isAssignPermissions(r *http.Request, endpoint string) bool { + return r.Method == http.MethodPatch && strings.HasSuffix(endpoint, "/entitlements") +} + +func isAssignIdentities(r *http.Request, endpoint string) bool { + return r.Method == http.MethodPatch && strings.HasSuffix(endpoint, "/identities") +} From 24c8d99319e1782cd742451d9b09f6846bd6fa3e Mon Sep 17 00:00:00 2001 From: barco Date: Fri, 12 Apr 2024 10:14:02 +0200 Subject: [PATCH 08/13] feat: add URL param validation for groups handlers --- pkg/groups/validation.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/groups/validation.go b/pkg/groups/validation.go index c52793e2d..6713fcf24 100644 --- a/pkg/groups/validation.go +++ b/pkg/groups/validation.go @@ -17,8 +17,7 @@ func isCreateGroup(r *http.Request, endpoint string) bool { } func isUpdateGroup(r *http.Request, endpoint string) bool { - // make sure at least one character is present for the Group ID URL Param - return r.Method == http.MethodPatch && strings.HasPrefix(endpoint, "/") && len(endpoint) > 1 + return r.Method == http.MethodPatch && strings.HasPrefix(endpoint, "/") } func isAssignRoles(r *http.Request, endpoint string) bool { From 1c479df763c8a38734024f2ea6161981cc78677e Mon Sep 17 00:00:00 2001 From: barco Date: Tue, 16 Apr 2024 14:28:26 +0200 Subject: [PATCH 09/13] refactor: adjust tests for registry --- internal/validation/registry_test.go | 77 +++++++++++++++++----------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/internal/validation/registry_test.go b/internal/validation/registry_test.go index de77726cd..ef48b9f55 100644 --- a/internal/validation/registry_test.go +++ b/internal/validation/registry_test.go @@ -20,6 +20,30 @@ import ( //go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_monitor.go -source=../monitoring/interfaces.go //go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer +type payloadValidator struct{} + +func (_ *payloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) { + e := mockValidationErrors() + if e == nil { + return ctx, nil, nil + } + return ctx, e, nil +} + +func (_ *payloadValidator) NeedsValidation(r *http.Request) bool { + return true +} + +type noopPayloadValidator struct{} + +func (_ *noopPayloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) { + return ctx, nil, nil +} + +func (_ *noopPayloadValidator) NeedsValidation(r *http.Request) bool { + return true +} + func TestValidator_Middleware(t *testing.T) { ctrl := gomock.NewController(t) tracer := NewMockTracer(ctrl) @@ -36,13 +60,7 @@ func TestValidator_Middleware(t *testing.T) { }) vld := NewRegistry(tracer, monitor, logger) - vld.validatingFuncs["mock-key"] = func(r *http.Request) (validator.ValidationErrors, error) { - e := mockValidationErrors() - if e == nil { - return nil, nil - } - return e, nil - } + vld.validators["mock-key"] = &payloadValidator{} for _, tt := range []struct { name string @@ -113,57 +131,56 @@ func TestValidator_RegisterValidator(t *testing.T) { logger := NewMockLoggerInterface(ctrl) emptyValidator := &ValidationRegistry{ - validatingFuncs: make(map[string]ValidatingFunc), - tracer: tracer, - monitor: monitor, - logger: logger, + validators: make(map[string]PayloadValidatorInterface), + tracer: tracer, + monitor: monitor, + logger: logger, } - noopVf := ValidatingFunc(func(r *http.Request) (validator.ValidationErrors, error) { - return nil, nil - }) - validatingFuncs := make(map[string]ValidatingFunc) - validatingFuncs["mock-key-1"] = noopVf + noopValidator := &noopPayloadValidator{} + + validators := make(map[string]PayloadValidatorInterface) + validators["mock-key-1"] = noopValidator nonEmptyValidator := &ValidationRegistry{ - validatingFuncs: validatingFuncs, - tracer: tracer, - monitor: monitor, - logger: logger, + validators: validators, + tracer: tracer, + monitor: monitor, + logger: logger, } for _, tt := range []struct { name string validator *ValidationRegistry prefix string - vf ValidatingFunc + v PayloadValidatorInterface expected string }{ { name: "Nil middleware", validator: emptyValidator, prefix: "", - vf: nil, - expected: "validatingFunc can't be null", + v: nil, + expected: "payloadValidator can't be null", }, { name: "Existing key", validator: nonEmptyValidator, prefix: "mock-key-1", - vf: noopVf, + v: noopValidator, expected: "key is already registered", }, { name: "Success", validator: emptyValidator, prefix: "mock-key", - vf: noopVf, + v: noopValidator, expected: "", }, } { tt := tt t.Run(tt.name, func(t *testing.T) { - result := tt.validator.RegisterValidatingFunc(tt.prefix, tt.vf) + result := tt.validator.RegisterPayloadValidator(tt.prefix, tt.v) if tt.expected == "" && nil == result { return @@ -196,12 +213,12 @@ func TestNewValidator(t *testing.T) { t.FailNow() } - if v.validatingFuncs == nil { - t.Fatalf("validatingFuncs map expected not empty") + if v.validators == nil { + t.Fatalf("validators map expected not empty") } - if len(v.validatingFuncs) != 0 { - t.Fatalf("validatingFuncs map expected not populated") + if len(v.validators) != 0 { + t.Fatalf("validators map expected not populated") } } From 45993ed14506cd90f9f019d5317b4df29d726e22 Mon Sep 17 00:00:00 2001 From: barco Date: Tue, 16 Apr 2024 15:23:11 +0200 Subject: [PATCH 10/13] feat: add full validation implementation for schemas --- pkg/schemas/handlers.go | 53 +----------------------------- pkg/schemas/validation.go | 69 +++++++++++++++++++++++++++++++++++---- 2 files changed, 64 insertions(+), 58 deletions(-) diff --git a/pkg/schemas/handlers.go b/pkg/schemas/handlers.go index 613b0832f..1bb150c4c 100644 --- a/pkg/schemas/handlers.go +++ b/pkg/schemas/handlers.go @@ -41,57 +41,6 @@ func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { } } -/*func (a *API) validatingFunc(r *http.Request) (validator.ValidationErrors, error) { - if !shouldValidate(r) { - return nil, nil - } - - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - - if err != nil { - return nil, validation.NoBodyError - } - - // don't break existing handlers, replace the body that was consumed - r.Body = io.NopCloser(bytes.NewReader(body)) - - // key "schemas" must be there since we registered it in the setup func - endpoint, _ := validation.ApiEndpoint(r.URL.Path, a.apiKey) - - validated := false - - if isCreateOrUpdateSchema(r, endpoint) { - schema := new(kClient.IdentitySchemaContainer) - if err := json.Unmarshal(body, schema); err != nil { - return nil, err - } - - err = a.validator.Struct(schema) - validated = true - } - - if isPartialUpdate(r, endpoint) { - schema := new(DefaultSchema) - if err := json.Unmarshal(body, schema); err != nil { - return nil, err - } - - err = a.validator.Struct(schema) - validated = true - } - - if !validated { - return nil, validation.NoMatchError(a.apiKey) - } - - if err == nil { - return nil, nil - } - - return err.(validator.ValidationErrors), nil -}*/ - func (a *API) handleList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -412,7 +361,7 @@ func NewAPI(service ServiceInterface, logger logging.LoggerInterface) *API { a.apiKey = "schemas" a.service = service - //a.payloadValidator = NewSchemasPayloadValidator() + a.payloadValidator = NewSchemasPayloadValidator(a.apiKey) a.logger = logger return a diff --git a/pkg/schemas/validation.go b/pkg/schemas/validation.go index 6ed5bbd92..da625a564 100644 --- a/pkg/schemas/validation.go +++ b/pkg/schemas/validation.go @@ -4,10 +4,15 @@ package schemas import ( + "context" + "encoding/json" "net/http" "strings" "github.com/go-playground/validator/v10" + kClient "github.com/ory/kratos-client-go" + + "github.com/canonical/identity-platform-admin-ui/internal/validation" ) var ( @@ -21,14 +26,66 @@ type PayloadValidator struct { validator *validator.Validate } -func isPartialUpdate(r *http.Request, endpoint string) bool { - return strings.HasPrefix(endpoint, "/") && r.Method == http.MethodPatch +func (p *PayloadValidator) setupValidator() { + p.validator.RegisterStructValidationMapRules( + identitySchemaContainerRules, + kClient.IdentitySchemaContainer{}, + ) } -func isCreateOrUpdateSchema(r *http.Request, endpoint string) bool { - return (endpoint == "" && r.Method == http.MethodPost) || (endpoint == "/default" && r.Method == http.MethodPut) +func (_ *PayloadValidator) NeedsValidation(r *http.Request) bool { + return r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch } -func shouldValidate(r *http.Request) bool { - return r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch +func (p *PayloadValidator) isPartialUpdate(method, endpoint string) bool { + return strings.HasPrefix(endpoint, "/") && method == http.MethodPatch +} + +func (p *PayloadValidator) isCreateOrUpdateSchema(method, endpoint string) bool { + return (endpoint == "" && method == http.MethodPost) || (endpoint == "/default" && method == http.MethodPut) +} + +func (p *PayloadValidator) Validate(ctx context.Context, method, endpoint string, body []byte) (context.Context, validator.ValidationErrors, error) { + validated := false + var err error + + if p.isCreateOrUpdateSchema(method, endpoint) { + schema := new(kClient.IdentitySchemaContainer) + if err := json.Unmarshal(body, schema); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(schema) + validated = true + } + + if p.isPartialUpdate(method, endpoint) { + schema := new(DefaultSchema) + if err := json.Unmarshal(body, schema); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(schema) + validated = true + } + + if !validated { + return ctx, nil, validation.NoMatchError(p.apiKey) + } + + if err == nil { + return ctx, nil, nil + } + + return ctx, err.(validator.ValidationErrors), nil +} + +func NewSchemasPayloadValidator(apiKey string) *PayloadValidator { + p := new(PayloadValidator) + p.apiKey = apiKey + p.validator = validation.NewValidator() + + p.setupValidator() + + return p } From 81b1f8c89f9e8e29f612f952c6d4b52a31e3426d Mon Sep 17 00:00:00 2001 From: barco Date: Tue, 16 Apr 2024 15:23:28 +0200 Subject: [PATCH 11/13] test: add test for schemas validation --- pkg/schemas/handlers_test.go | 26 ++++ pkg/schemas/validation_test.go | 227 +++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 pkg/schemas/validation_test.go diff --git a/pkg/schemas/handlers_test.go b/pkg/schemas/handlers_test.go index 9c80558b3..09d74ac12 100644 --- a/pkg/schemas/handlers_test.go +++ b/pkg/schemas/handlers_test.go @@ -29,6 +29,7 @@ import ( //go:generate mockgen -build_flags=--mod=mod -package schemas -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer //go:generate mockgen -build_flags=--mod=mod -package schemas -destination ./mock_corev1.go k8s.io/client-go/kubernetes/typed/core/v1 CoreV1Interface,ConfigMapInterface //go:generate mockgen -build_flags=--mod=mod -package schemas -destination ./mock_kratos.go github.com/ory/kratos-client-go IdentityAPI +//go:generate mockgen -build_flags=--mod=mod -package schemas -destination ./mock_validation.go -source=../../internal/validation/registry.go func TestHandleListSuccess(t *testing.T) { ctrl := gomock.NewController(t) @@ -1198,3 +1199,28 @@ func TestHandleUpdateDefaultFail(t *testing.T) { t.Errorf("expected code to be %v got %v", http.StatusInternalServerError, rr.Status) } } + +func TestRegisterValidation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockService := NewMockServiceInterface(ctrl) + mockValidationRegistry := NewMockValidationRegistryInterface(ctrl) + + apiKey := "schemas" + mockValidationRegistry.EXPECT(). + RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()). + Return(nil) + mockValidationRegistry.EXPECT(). + RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()). + Return(fmt.Errorf("key is already registered")) + + // first registration of `apiKey` is successful + NewAPI(mockService, mockLogger).RegisterValidation(mockValidationRegistry) + + mockLogger.EXPECT().Fatal(gomock.Any()).Times(1) + + // second registration of `apiKey` causes logger.Fatal invocation + NewAPI(mockService, mockLogger).RegisterValidation(mockValidationRegistry) +} diff --git a/pkg/schemas/validation_test.go b/pkg/schemas/validation_test.go new file mode 100644 index 000000000..efc9bbf42 --- /dev/null +++ b/pkg/schemas/validation_test.go @@ -0,0 +1,227 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package schemas + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-playground/validator/v10" + kClient "github.com/ory/kratos-client-go" + + "github.com/canonical/identity-platform-admin-ui/internal/validation" +) + +func TestNeedsValidation(t *testing.T) { + p := new(PayloadValidator) + p.validator = validation.NewValidator() + p.setupValidator() + + for _, tt := range []struct { + name string + req *http.Request + expectedResult bool + }{ + { + name: http.MethodPost, + req: httptest.NewRequest(http.MethodPost, "/", nil), + expectedResult: true, + }, + { + name: http.MethodPut, + req: httptest.NewRequest(http.MethodPut, "/", nil), + expectedResult: true, + }, + { + name: http.MethodPatch, + req: httptest.NewRequest(http.MethodPatch, "/", nil), + expectedResult: true, + }, + { + name: http.MethodGet, + req: httptest.NewRequest(http.MethodGet, "/", nil), + expectedResult: false, + }, + { + name: http.MethodDelete, + req: httptest.NewRequest(http.MethodDelete, "/", nil), + expectedResult: false, + }, + { + name: http.MethodConnect, + req: httptest.NewRequest(http.MethodConnect, "/", nil), + expectedResult: false, + }, + { + name: http.MethodHead, + req: httptest.NewRequest(http.MethodHead, "/", nil), + expectedResult: false, + }, + { + name: http.MethodTrace, + req: httptest.NewRequest(http.MethodTrace, "/", nil), + expectedResult: false, + }, + { + name: http.MethodOptions, + req: httptest.NewRequest(http.MethodOptions, "/", nil), + expectedResult: false, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + result := p.NeedsValidation(tt.req) + + if result != tt.expectedResult { + t.Fatalf("Result doesn't match expected one, obtained %t instead of %t", result, tt.expectedResult) + } + }) + } +} + +var mockSchema = map[string]interface{}{ + "$id": "https://schemas.canonical.com/presets/kratos/test_v1.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Admin Account", + "type": "object", + "properties": map[string]interface{}{ + "traits": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "username": map[string]interface{}{ + "type": "string", + "title": "Username", + "ory.sh/kratos": map[string]interface{}{ + "credentials": map[string]interface{}{ + "password": map[string]interface{}{ + "identifier": true, + }, + }, + }, + }, + }, + }, + }, + "additionalProperties": true, +} + +func TestValidate(t *testing.T) { + p := new(PayloadValidator) + p.apiKey = "schemas" + p.validator = validation.NewValidator() + p.setupValidator() + + for _, tt := range []struct { + name string + method string + endpoint string + body func() []byte + expectedResult validator.ValidationErrors + expectedError error + }{ + { + name: "CreateOrUpdateSchemaSuccessCreate", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + updateRequest := new(kClient.IdentitySchemaContainer) + updateRequest.Schema = mockSchema + + marshal, _ := json.Marshal(updateRequest) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "CreateOrUpdateSchemaSuccessUpdateDefault", + method: http.MethodPut, + endpoint: "/default", + body: func() []byte { + id := "default" + updateRequest := new(kClient.IdentitySchemaContainer) + updateRequest.Schema = mockSchema + updateRequest.Id = &id + + marshal, _ := json.Marshal(updateRequest) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "PartialUpdateSchemaSuccess", + method: http.MethodPatch, + endpoint: "/mock-id", + body: func() []byte { + id := "mock-id" + updateRequest := new(kClient.IdentitySchemaContainer) + updateRequest.Schema = mockSchema + updateRequest.Id = &id + + marshal, _ := json.Marshal(updateRequest) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "NoMatch", + method: http.MethodPost, + endpoint: "no-match-endpoint", + body: func() []byte { + return nil + }, + expectedResult: nil, + expectedError: validation.NoMatchError(p.apiKey), + }, + { + name: "CreateOrUpdateSchemaFailure", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + id := "mock-id" + updateRequest := new(kClient.IdentitySchemaContainer) + updateRequest.Id = &id + + marshal, _ := json.Marshal(updateRequest) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + { + name: "PartialUpdateSchemaFailure", + method: http.MethodPatch, + endpoint: "/mock-id", + body: func() []byte { + id := "mock-id" + updateRequest := new(kClient.IdentitySchemaContainer) + updateRequest.Id = &id + + marshal, _ := json.Marshal(updateRequest) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + _, result, err := p.Validate(context.TODO(), tt.method, tt.endpoint, tt.body()) + + if err != nil && err.Error() != tt.expectedError.Error() { + t.Fatalf("Returned error doesn't match expected, obtained '%v' instead of '%v'", err, tt.expectedError) + } else if result != nil && errors.Is(result, tt.expectedResult) { + t.Fatalf("Returned validation errors don't match expected, obtained '%v' instead of '%v'", result, tt.expectedResult) + } else { + return + } + }) + } +} From 700cf0401d657a771e56511bd04f95cea93675e6 Mon Sep 17 00:00:00 2001 From: barco Date: Tue, 16 Apr 2024 16:07:55 +0200 Subject: [PATCH 12/13] feat: add validation implementation for `groups` --- pkg/groups/handlers.go | 78 +----------------------------- pkg/groups/validation.go | 101 ++++++++++++++++++++++++++++++++++----- 2 files changed, 91 insertions(+), 88 deletions(-) diff --git a/pkg/groups/handlers.go b/pkg/groups/handlers.go index 4eb0a62ea..eb9291dcd 100644 --- a/pkg/groups/handlers.go +++ b/pkg/groups/handlers.go @@ -83,82 +83,6 @@ func (a *API) RegisterValidation(v validation.ValidationRegistryInterface) { } } -/*func (a *API) validatingFunc(r *http.Request) (validator.ValidationErrors, error) { - if !shouldValidate(r) { - return nil, nil - } - - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - - if err != nil { - return nil, validation.NoBodyError - } - - // don't break existing handlers, replace the body that was consumed - r.Body = io.NopCloser(bytes.NewReader(body)) - - // key "identities" must be there since we registered it in the setup func - endpoint, _ := validation.ApiEndpoint(r.URL.Path, a.apiKey) - - validated := false - - if isCreateGroup(r, endpoint) { - group := new(GroupRequest) - if err := json.Unmarshal(body, group); err != nil { - return nil, err - } - - err = a.validator.Struct(group) - validated = true - } - - if isUpdateGroup(r, endpoint) { - // TODO: @barco to implement when the UpdateGroup is implemented - validated = true - } - - if isAssignRoles(r, endpoint) { - updateRoles := new(UpdateRolesRequest) - if err := json.Unmarshal(body, updateRoles); err != nil { - return nil, err - } - - err = a.validator.Struct(updateRoles) - validated = true - } - - if isAssignPermissions(r, endpoint) { - updatePermissions := new(UpdatePermissionsRequest) - if err := json.Unmarshal(body, updatePermissions); err != nil { - return nil, err - } - - err = a.validator.Struct(updatePermissions) - validated = true - } - - if isAssignIdentities(r, endpoint) { - updateIdentities := new(UpdateIdentitiesRequest) - if err := json.Unmarshal(body, updateIdentities); err != nil { - return nil, err - } - - err = a.validator.Struct(updateIdentities) - validated = true - } - - if !validated { - return nil, validation.NoMatchError(a.apiKey) - } - - if err == nil { - return nil, nil - } - - return err.(validator.ValidationErrors), nil -}*/ - func (a *API) userFromContext(ctx context.Context) *authorization.User { // TODO @shipperizer implement the FromContext and NewContext in authorization package // see snippet below copied from https://pkg.go.dev/context#Context @@ -772,7 +696,7 @@ func NewAPI(service ServiceInterface, tracer tracing.TracingInterface, monitor m a.apiKey = "groups" a.service = service - //a.payloadValidator = NewGroupsPayloadValidator() + a.payloadValidator = NewGroupsPayloadValidator(a.apiKey) a.logger = logger a.tracer = tracer a.monitor = monitor diff --git a/pkg/groups/validation.go b/pkg/groups/validation.go index 6713fcf24..e37dd08e4 100644 --- a/pkg/groups/validation.go +++ b/pkg/groups/validation.go @@ -4,30 +4,109 @@ package groups import ( + "context" + "encoding/json" "net/http" "strings" + + "github.com/go-playground/validator/v10" + + "github.com/canonical/identity-platform-admin-ui/internal/validation" ) -func shouldValidate(r *http.Request) bool { +type PayloadValidator struct { + apiKey string + validator *validator.Validate +} + +func (_ *PayloadValidator) NeedsValidation(r *http.Request) bool { return r.Method == http.MethodPost || r.Method == http.MethodPatch } -func isCreateGroup(r *http.Request, endpoint string) bool { - return r.Method == http.MethodPost && endpoint == "" +func (p *PayloadValidator) Validate(ctx context.Context, method, endpoint string, body []byte) (context.Context, validator.ValidationErrors, error) { + validated := false + var err error + + if p.isCreateGroup(method, endpoint) { + group := new(GroupRequest) + if err := json.Unmarshal(body, group); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(group) + validated = true + } + + if p.isUpdateGroup(method, endpoint) { + // TODO: @barco to implement when the UpdateGroup is implemented + validated = true + } + + if p.isAssignRoles(method, endpoint) { + updateRoles := new(UpdateRolesRequest) + if err := json.Unmarshal(body, updateRoles); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(updateRoles) + validated = true + } + + if p.isAssignPermissions(method, endpoint) { + updatePermissions := new(UpdatePermissionsRequest) + if err := json.Unmarshal(body, updatePermissions); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(updatePermissions) + validated = true + } + + if p.isAssignIdentities(method, endpoint) { + updateIdentities := new(UpdateIdentitiesRequest) + if err := json.Unmarshal(body, updateIdentities); err != nil { + return ctx, nil, err + } + + err = p.validator.Struct(updateIdentities) + validated = true + } + + if !validated { + return ctx, nil, validation.NoMatchError(p.apiKey) + } + + if err == nil { + return ctx, nil, nil + } + + return ctx, err.(validator.ValidationErrors), nil } -func isUpdateGroup(r *http.Request, endpoint string) bool { - return r.Method == http.MethodPatch && strings.HasPrefix(endpoint, "/") +func (p *PayloadValidator) isCreateGroup(method, endpoint string) bool { + return method == http.MethodPost && endpoint == "" } -func isAssignRoles(r *http.Request, endpoint string) bool { - return r.Method == http.MethodPost && strings.HasSuffix(endpoint, "/roles") +func (p *PayloadValidator) isUpdateGroup(method, endpoint string) bool { + return method == http.MethodPatch && strings.HasPrefix(endpoint, "/") } -func isAssignPermissions(r *http.Request, endpoint string) bool { - return r.Method == http.MethodPatch && strings.HasSuffix(endpoint, "/entitlements") +func (p *PayloadValidator) isAssignRoles(method, endpoint string) bool { + return method == http.MethodPost && strings.HasSuffix(endpoint, "/roles") } -func isAssignIdentities(r *http.Request, endpoint string) bool { - return r.Method == http.MethodPatch && strings.HasSuffix(endpoint, "/identities") +func (p *PayloadValidator) isAssignPermissions(method, endpoint string) bool { + return method == http.MethodPatch && strings.HasSuffix(endpoint, "/entitlements") +} + +func (p *PayloadValidator) isAssignIdentities(method, endpoint string) bool { + return method == http.MethodPatch && strings.HasSuffix(endpoint, "/identities") +} + +func NewGroupsPayloadValidator(apiKey string) *PayloadValidator { + p := new(PayloadValidator) + p.apiKey = apiKey + p.validator = validation.NewValidator() + + return p } From 0ed85198c03189461d160b15309e64e9f67d2d57 Mon Sep 17 00:00:00 2001 From: barco Date: Tue, 16 Apr 2024 16:08:17 +0200 Subject: [PATCH 13/13] test: test: add test for groups validation --- pkg/groups/handlers_test.go | 28 ++++ pkg/groups/validation_test.go | 279 ++++++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+) create mode 100644 pkg/groups/validation_test.go diff --git a/pkg/groups/handlers_test.go b/pkg/groups/handlers_test.go index 692574d86..d45a77ae2 100644 --- a/pkg/groups/handlers_test.go +++ b/pkg/groups/handlers_test.go @@ -29,6 +29,7 @@ import ( //go:generate mockgen -build_flags=--mod=mod -package groups -destination ./mock_interfaces.go -source=./interfaces.go //go:generate mockgen -build_flags=--mod=mod -package groups -destination ./mock_monitor.go -source=../../internal/monitoring/interfaces.go //go:generate mockgen -build_flags=--mod=mod -package groups -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer +//go:generate mockgen -build_flags=--mod=mod -package groups -destination ./mock_validation.go -source=../../internal/validation/registry.go // + http :8000/api/v0/groups X-Authorization:c2hpcHBlcml6ZXI= // HTTP/1.1 200 OK @@ -1985,3 +1986,30 @@ func TestHandleListIdentitiesSuccess(t *testing.T) { }) } } + +func TestRegisterValidation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockService := NewMockServiceInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockValidationRegistry := NewMockValidationRegistryInterface(ctrl) + + apiKey := "groups" + mockValidationRegistry.EXPECT(). + RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()). + Return(nil) + mockValidationRegistry.EXPECT(). + RegisterPayloadValidator(gomock.Eq(apiKey), gomock.Any()). + Return(fmt.Errorf("key is already registered")) + + // first registration of `apiKey` is successful + NewAPI(mockService, mockTracer, mockMonitor, mockLogger).RegisterValidation(mockValidationRegistry) + + mockLogger.EXPECT().Fatal(gomock.Any()).Times(1) + + // second registration of `apiKey` causes logger.Fatal invocation + NewAPI(mockService, mockTracer, mockMonitor, mockLogger).RegisterValidation(mockValidationRegistry) +} diff --git a/pkg/groups/validation_test.go b/pkg/groups/validation_test.go new file mode 100644 index 000000000..d1bfa4053 --- /dev/null +++ b/pkg/groups/validation_test.go @@ -0,0 +1,279 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package groups + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-playground/validator/v10" + + "github.com/canonical/identity-platform-admin-ui/internal/validation" +) + +func TestNeedsValidation(t *testing.T) { + p := new(PayloadValidator) + p.validator = validation.NewValidator() + + for _, tt := range []struct { + name string + req *http.Request + expectedResult bool + }{ + { + name: http.MethodPost, + req: httptest.NewRequest(http.MethodPost, "/", nil), + expectedResult: true, + }, + { + name: http.MethodPut, + req: httptest.NewRequest(http.MethodPut, "/", nil), + expectedResult: false, + }, + { + name: http.MethodPatch, + req: httptest.NewRequest(http.MethodPatch, "/", nil), + expectedResult: true, + }, + { + name: http.MethodGet, + req: httptest.NewRequest(http.MethodGet, "/", nil), + expectedResult: false, + }, + { + name: http.MethodDelete, + req: httptest.NewRequest(http.MethodDelete, "/", nil), + expectedResult: false, + }, + { + name: http.MethodConnect, + req: httptest.NewRequest(http.MethodConnect, "/", nil), + expectedResult: false, + }, + { + name: http.MethodHead, + req: httptest.NewRequest(http.MethodHead, "/", nil), + expectedResult: false, + }, + { + name: http.MethodTrace, + req: httptest.NewRequest(http.MethodTrace, "/", nil), + expectedResult: false, + }, + { + name: http.MethodOptions, + req: httptest.NewRequest(http.MethodOptions, "/", nil), + expectedResult: false, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + result := p.NeedsValidation(tt.req) + + if result != tt.expectedResult { + t.Fatalf("Result doesn't match expected one, obtained %t instead of %t", result, tt.expectedResult) + } + }) + } +} + +func TestValidate(t *testing.T) { + p := new(PayloadValidator) + p.apiKey = "groups" + p.validator = validation.NewValidator() + + for _, tt := range []struct { + name string + method string + endpoint string + body func() []byte + expectedResult validator.ValidationErrors + expectedError error + }{ + { + name: "CreateGroup", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + r := new(GroupRequest) + r.ID = "mock-id" + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "UpdateGroup", + method: http.MethodPatch, + endpoint: "/mock-id", + body: func() []byte { + id := "mock-id" + r := new(GroupRequest) + r.ID = id + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "AssignRoles", + method: http.MethodPost, + endpoint: "/mock-id/roles", + body: func() []byte { + r := new(UpdateRolesRequest) + r.Roles = []string{ + "viewer", "writer", + } + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "AssignPermissions", + method: http.MethodPatch, + endpoint: "/mock-id/entitlements", + body: func() []byte { + r := new(UpdatePermissionsRequest) + r.Permissions = []Permission{ + { + Relation: "mock-relation", + Object: "mock-object", + }, + } + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "AssignIdentities", + method: http.MethodPatch, + endpoint: "/mock-id/identities", + body: func() []byte { + r := new(UpdateIdentitiesRequest) + r.Identities = []string{ + "mock-identity", + } + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: nil, + expectedError: nil, + }, + { + name: "NoMatch", + method: http.MethodPost, + endpoint: "no-match-endpoint", + body: func() []byte { + return nil + }, + expectedResult: nil, + expectedError: validation.NoMatchError(p.apiKey), + }, + { + name: "CreateGroupFailure", + method: http.MethodPost, + endpoint: "", + body: func() []byte { + r := new(GroupRequest) + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + { + name: "UpdateGroupFailure", + method: http.MethodPatch, + endpoint: "/mock-id", + body: func() []byte { + r := new(GroupRequest) + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + { + name: "AssignRolesFailure", + method: http.MethodPost, + endpoint: "/mock-id/roles", + body: func() []byte { + r := new(UpdateRolesRequest) + r.Roles = []string{ + "viewer", "", + } + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + { + name: "AssignPermissionsFailure", + method: http.MethodPatch, + endpoint: "/mock-id/entitlements", + body: func() []byte { + r := new(UpdatePermissionsRequest) + r.Permissions = []Permission{ + { + Relation: "", + Object: "mock-object", + }, + } + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + { + name: "AssignIdentitiesFailure", + method: http.MethodPatch, + endpoint: "/mock-id/identities", + body: func() []byte { + r := new(UpdateIdentitiesRequest) + r.Identities = []string{ + "", + } + + marshal, _ := json.Marshal(r) + return marshal + }, + expectedResult: validator.ValidationErrors{}, + expectedError: nil, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + _, result, err := p.Validate(context.TODO(), tt.method, tt.endpoint, tt.body()) + + if err != nil && err.Error() != tt.expectedError.Error() { + t.Fatalf("Returned error doesn't match expected, obtained '%v' instead of '%v'", err, tt.expectedError) + } else if result != nil && errors.Is(result, tt.expectedResult) { + t.Fatalf("Returned validation errors don't match expected, obtained '%v' instead of '%v'", result, tt.expectedResult) + } else { + return + } + }) + } +}