Skip to content

Commit

Permalink
fix: prevent infinite recursion for binding (#1016)
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF authored Dec 7, 2023
1 parent 21aae1d commit 03768f4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
34 changes: 34 additions & 0 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,40 @@ func Test_Issue964(t *testing.T) {
}
}

type reqSameType struct {
Parent *reqSameType `json:"parent"`
Children []reqSameType `json:"children"`
Foo1 reqSameType2 `json:"foo1"`
A string `json:"a"`
}

type reqSameType2 struct {
Foo1 *reqSameType `json:"foo1"`
}

func TestBind_Issue1015(t *testing.T) {
req := newMockRequest().
SetJSONContentType().
SetBody([]byte(`{"parent":{"parent":{}, "children":[{},{}], "foo1":{"foo1":{}}}, "children":[{},{}], "a":"asd"}`))

var result reqSameType

err := DefaultBinder().Bind(req.Req, &result, nil)
if err != nil {
t.Error(err)
}
assert.NotNil(t, result.Parent)
assert.NotNil(t, result.Parent.Parent)
assert.Nil(t, result.Parent.Parent.Parent)
assert.NotNil(t, result.Parent.Children)
assert.DeepEqual(t, 2, len(result.Parent.Children))
assert.NotNil(t, result.Parent.Foo1.Foo1)
assert.DeepEqual(t, "", result.Parent.A)
assert.DeepEqual(t, 2, len(result.Children))
assert.Nil(t, result.Foo1.Foo1)
assert.DeepEqual(t, "asd", result.A)
}

func Benchmark_Binding(b *testing.B) {
type Req struct {
Version string `path:"v"`
Expand Down
49 changes: 37 additions & 12 deletions pkg/app/server/binding/internal/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder
continue
}

dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag, config)
dec, needValidate2, err := getFieldDecoder(parentInfos{[]reflect.Type{el}, []int{}, ""}, el.Field(i), i, byTag, config)
if err != nil {
return nil, false, err
}
Expand All @@ -103,7 +103,13 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder
}, needValidate, nil
}

func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
type parentInfos struct {
Types []reflect.Type
Indexes []int
JSONName string
}

func getFieldDecoder(pInfo parentInfos, field reflect.StructField, index int, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
for field.Type.Kind() == reflect.Ptr {
field.Type = field.Type.Elem()
}
Expand All @@ -116,7 +122,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare
}

// JSONName is like 'a.b.c' for 'required validate'
fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, config)
fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, pInfo.JSONName, config)
if len(fieldTagInfos) == 0 && !config.DisableDefaultTag {
fieldTagInfos = getDefaultFieldTags(field)
}
Expand All @@ -126,19 +132,19 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare

// customized type decoder has the highest priority
if customizedFunc, exist := config.TypeUnmarshalFuncs[field.Type]; exist {
dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc, config)
dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, customizedFunc, config)
return dec, needValidate, err
}

// slice/array field decoder
if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array {
dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx, config)
dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, config)
return dec, needValidate, err
}

// map filed decoder
if field.Type.Kind() == reflect.Map {
dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config)
dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, pInfo.Indexes, config)
return dec, needValidate, err
}

Expand All @@ -149,11 +155,11 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare
// todo: more built-in common struct binding, ex. time...
switch el {
case reflect.TypeOf(multipart.FileHeader{}): // file binding
dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx, config)
dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, pInfo.Indexes, config)
return dec, needValidate, err
}
if !config.DisableStructFieldResolve { // decode struct type separately
structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx, config)
structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, config)
if err != nil {
return nil, needValidate, err
}
Expand All @@ -162,17 +168,26 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare
}
}

// prevent infinite recursion when struct field with the same name as a struct
if hasSameType(pInfo.Types, el) {
return decoders, needValidate, nil
}

pIdx := pInfo.Indexes
for i := 0; i < el.NumField(); i++ {
if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous {
// ignore unexported field
continue
}
var idxes []int
if len(parentIdx) > 0 {
idxes = append(idxes, parentIdx...)
if len(pInfo.Indexes) > 0 {
idxes = append(idxes, pIdx...)
}
idxes = append(idxes, index)
dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag, config)
pInfo.Indexes = idxes
pInfo.Types = append(pInfo.Types, el)
pInfo.JSONName = newParentJSONName
dec, needValidate2, err := getFieldDecoder(pInfo, el.Field(i), i, byTag, config)
needValidate = needValidate || needValidate2
if err != nil {
return nil, false, err
Expand All @@ -186,6 +201,16 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare
}

// base type decoder
dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config)
dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, pInfo.Indexes, config)
return dec, needValidate, err
}

// hasSameType determine if the same type is present in the parent-child relationship
func hasSameType(pts []reflect.Type, ft reflect.Type) bool {
for _, pt := range pts {
if reflect.DeepEqual(getElemType(pt), getElemType(ft)) {
return true
}
}
return false
}

0 comments on commit 03768f4

Please sign in to comment.