Skip to content

Commit

Permalink
Merge pull request #275 from canonical/IAM-776-validating-funcs
Browse files Browse the repository at this point in the history
IAM 776 Implement validation for `groups`, `schemas`, `identities` handlers
  • Loading branch information
BarcoMasile authored Apr 17, 2024
2 parents c8bad87 + 0ed8519 commit 6cc4e95
Show file tree
Hide file tree
Showing 20 changed files with 1,368 additions and 131 deletions.
81 changes: 67 additions & 14 deletions internal/validation/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
package validation

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/go-playground/validator/v10"

"github.com/canonical/identity-platform-admin-ui/internal/http/types"
"github.com/canonical/identity-platform-admin-ui/internal/logging"
"github.com/canonical/identity-platform-admin-ui/internal/monitoring"
"github.com/canonical/identity-platform-admin-ui/internal/tracing"
Expand All @@ -20,13 +24,16 @@ const apiVersion = "v0"

type ValidationRegistryInterface interface {
ValidationMiddleware(next http.Handler) http.Handler
RegisterValidatingFunc(key string, vf ValidatingFunc) error
RegisterPayloadValidator(key string, vf PayloadValidatorInterface) error
}

type ValidatingFunc func(r *http.Request) validator.ValidationErrors
type PayloadValidatorInterface interface {
NeedsValidation(r *http.Request) bool
Validate(ctx context.Context, method, endpoint string, body []byte) (context.Context, validator.ValidationErrors, error)
}

type ValidationRegistry struct {
validatingFuncs map[string]ValidatingFunc
validators map[string]PayloadValidatorInterface

tracer tracing.TracingInterface
monitor monitoring.MonitorInterface
Expand All @@ -41,34 +48,69 @@ func (v *ValidationRegistry) ValidationMiddleware(next http.Handler) http.Handle
r = r.WithContext(ctx)
key := v.apiKey(r.URL.Path)

vf, ok := v.validatingFuncs[key]
if !ok {
payloadValidator, ok := v.validators[key]
if !ok || !payloadValidator.NeedsValidation(r) {
next.ServeHTTP(w, r)
return
}

if validationErr := vf(r); validationErr != nil {
e := NewValidationError(validationErr)
reqBody := r.Body
defer reqBody.Close()
body, err := io.ReadAll(reqBody)

if err != nil {
badRequestFromError(w, NoBodyError)
return
}

// don't break existing handlers, replace the body that was consumed
r.Body = io.NopCloser(bytes.NewReader(body))

endpoint, _ := ApiEndpoint(r.URL.Path, key)
var validationErr validator.ValidationErrors

ctx, validationErr, err = payloadValidator.Validate(r.Context(), r.Method, endpoint, body)

if err != nil {
badRequestFromError(w, err)
return
}

// handler validation errors
if validationErr != nil {
e := NewValidationError("validation errors", validationErr)

w.WriteHeader(e.Status)
_ = json.NewEncoder(w).Encode(e)
return
}

next.ServeHTTP(w, r)
// if no errors, proceed with the request
next.ServeHTTP(w, r.WithContext(ctx))
})
}

func (v *ValidationRegistry) RegisterValidatingFunc(key string, vf ValidatingFunc) error {
if vf == nil {
return fmt.Errorf("validatingFunc can't be null")
func badRequestFromError(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(
types.Response{
Message: err.Error(),
Status: http.StatusBadRequest,
},
)
return
}

func (v *ValidationRegistry) RegisterPayloadValidator(key string, payloadValidator PayloadValidatorInterface) error {
if payloadValidator == nil {
return fmt.Errorf("payloadValidator can't be null")
}

if _, ok := v.validatingFuncs[key]; ok {
if _, ok := v.validators[key]; ok {
return fmt.Errorf("key is already registered")
}

v.validatingFuncs[key] = vf
v.validators[key] = payloadValidator

return nil
}
Expand All @@ -81,9 +123,20 @@ func (v *ValidationRegistry) apiKey(endpoint string) string {
return strings.SplitN(after, "/", 1)[0]
}

// ApiEndpoint returns the endpoint string stripped from the api and version prefix, and the apikey
// it doesn't strip away trailing slash if there is one
func ApiEndpoint(endpoint, apiKey string) (string, bool) {
after, found := strings.CutPrefix(endpoint, fmt.Sprintf("/api/%s/", apiVersion))
if !found {
return "", false
}

return strings.CutPrefix(after, apiKey)
}

func NewRegistry(tracer tracing.TracingInterface, monitor monitoring.MonitorInterface, logger logging.LoggerInterface) *ValidationRegistry {
v := new(ValidationRegistry)
v.validatingFuncs = make(map[string]ValidatingFunc)
v.validators = make(map[string]PayloadValidatorInterface)

v.tracer = tracer
v.monitor = monitor
Expand Down
126 changes: 92 additions & 34 deletions internal/validation/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"

"github.com/go-playground/validator/v10"
Expand All @@ -18,6 +20,30 @@ import (
//go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_monitor.go -source=../monitoring/interfaces.go
//go:generate mockgen -build_flags=--mod=mod -package validation -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer

type payloadValidator struct{}

func (_ *payloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) {
e := mockValidationErrors()
if e == nil {
return ctx, nil, nil
}
return ctx, e, nil
}

func (_ *payloadValidator) NeedsValidation(r *http.Request) bool {
return true
}

type noopPayloadValidator struct{}

func (_ *noopPayloadValidator) Validate(ctx context.Context, _, _ string, _ []byte) (context.Context, validator.ValidationErrors, error) {
return ctx, nil, nil
}

func (_ *noopPayloadValidator) NeedsValidation(r *http.Request) bool {
return true
}

func TestValidator_Middleware(t *testing.T) {
ctrl := gomock.NewController(t)
tracer := NewMockTracer(ctrl)
Expand All @@ -34,17 +60,7 @@ func TestValidator_Middleware(t *testing.T) {
})

vld := NewRegistry(tracer, monitor, logger)
vld.validatingFuncs["mock-key"] = func(r *http.Request) validator.ValidationErrors {
type InvalidStruct struct {
FirstName string `validate:"required"`
}

e := validator.New(validator.WithRequiredStructEnabled()).Struct(InvalidStruct{})
if e == nil {
return nil
}
return e.(validator.ValidationErrors)
}
vld.validators["mock-key"] = &payloadValidator{}

for _, tt := range []struct {
name string
Expand Down Expand Up @@ -88,64 +104,83 @@ func TestValidator_Middleware(t *testing.T) {
}
}

func mockValidationErrors() validator.ValidationErrors {
type InvalidStruct struct {
FirstName string `json:"first_name" validate:"required"`
}

validate := validator.New(validator.WithRequiredStructEnabled())
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
e := validate.Struct(InvalidStruct{})
if e == nil {
return nil
}
return e.(validator.ValidationErrors)
}

func TestValidator_RegisterValidator(t *testing.T) {
ctrl := gomock.NewController(t)
tracer := NewMockTracer(ctrl)
monitor := NewMockMonitorInterface(ctrl)
logger := NewMockLoggerInterface(ctrl)

emptyValidator := &ValidationRegistry{
validatingFuncs: make(map[string]ValidatingFunc),
tracer: tracer,
monitor: monitor,
logger: logger,
validators: make(map[string]PayloadValidatorInterface),
tracer: tracer,
monitor: monitor,
logger: logger,
}

noopVf := ValidatingFunc(func(r *http.Request) validator.ValidationErrors {
return nil
})
validatingFuncs := make(map[string]ValidatingFunc)
validatingFuncs["mock-key-1"] = noopVf
noopValidator := &noopPayloadValidator{}

validators := make(map[string]PayloadValidatorInterface)
validators["mock-key-1"] = noopValidator

nonEmptyValidator := &ValidationRegistry{
validatingFuncs: validatingFuncs,
tracer: tracer,
monitor: monitor,
logger: logger,
validators: validators,
tracer: tracer,
monitor: monitor,
logger: logger,
}

for _, tt := range []struct {
name string
validator *ValidationRegistry
prefix string
vf ValidatingFunc
v PayloadValidatorInterface
expected string
}{
{
name: "Nil middleware",
validator: emptyValidator,
prefix: "",
vf: nil,
expected: "validatingFunc can't be null",
v: nil,
expected: "payloadValidator can't be null",
},
{
name: "Existing key",
validator: nonEmptyValidator,
prefix: "mock-key-1",
vf: noopVf,
v: noopValidator,
expected: "key is already registered",
},
{
name: "Success",
validator: emptyValidator,
prefix: "mock-key",
vf: noopVf,
v: noopValidator,
expected: "",
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
result := tt.validator.RegisterValidatingFunc(tt.prefix, tt.vf)
result := tt.validator.RegisterPayloadValidator(tt.prefix, tt.v)

if tt.expected == "" && nil == result {
return
Expand Down Expand Up @@ -178,11 +213,34 @@ func TestNewValidator(t *testing.T) {
t.FailNow()
}

if v.validatingFuncs == nil {
t.Fatalf("validatingFuncs map expected not empty")
if v.validators == nil {
t.Fatalf("validators map expected not empty")
}

if len(v.validators) != 0 {
t.Fatalf("validators map expected not populated")
}
}

func TestNewValidationError(t *testing.T) {
ve := mockValidationErrors()
response := NewValidationError("validation errors", ve)

if response.Status != http.StatusBadRequest {
t.Fatalf("response status does not match expected")
}

if response.Message != "validation errors" {
t.Fatalf("response message does not match expected")
}

expectedData := map[string][]string{
"first_name": {
"value '' fails validation of type `required`",
},
}

if len(v.validatingFuncs) != 0 {
t.Fatalf("validatingFuncs map expected not populated")
if !reflect.DeepEqual(expectedData, response.Data) {
t.Fatalf("response data does not match expected validation errors")
}
}
Loading

0 comments on commit 6cc4e95

Please sign in to comment.