Skip to content

Commit

Permalink
test: improve hertz/pkg/routeut unit test coverage (#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
HzTTT authored Nov 26, 2023
1 parent 9c3f0b7 commit 89e9721
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/app/server/binding/internal/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder
}, needValidate, nil
}

func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
for field.Type.Kind() == reflect.Ptr {
field.Type = field.Type.Elem()
}
Expand Down
175 changes: 175 additions & 0 deletions pkg/route/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server/binding"
"github.com/cloudwego/hertz/pkg/app/server/registry"
"github.com/cloudwego/hertz/pkg/common/config"
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
Expand All @@ -63,6 +64,7 @@ import (
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/cloudwego/hertz/pkg/protocol/suite"
"github.com/cloudwego/hertz/pkg/route/param"
)

Expand Down Expand Up @@ -854,3 +856,176 @@ func TestCustomValidator(t *testing.T) {
})
performRequest(e, "GET", "/validate?a=2")
}

var errTestDeregsitry = fmt.Errorf("test deregsitry error")

type mockDeregsitryErr struct{}

var _ registry.Registry = &mockDeregsitryErr{}

func (e mockDeregsitryErr) Register(*registry.Info) error {
return nil
}

func (e mockDeregsitryErr) Deregister(*registry.Info) error {
return errTestDeregsitry
}

func TestEngineShutdown(t *testing.T) {
defaultTransporter = standard.NewTransporter
mockCtxCallback := func(ctx context.Context) {}
// Test case 1: serve not running error
engine := NewEngine(config.NewOptions(nil))
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
err := engine.Shutdown(ctx1)
assert.DeepEqual(t, errStatusNotRunning, err)

// Test case 2: serve successfully running and shutdown
engine = NewEngine(config.NewOptions(nil))
engine.OnShutdown = []CtxCallback{mockCtxCallback}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)

ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
err = engine.Shutdown(ctx2)
assert.Nil(t, err)
assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status))

// Test case 3: serve successfully running and shutdown with deregistry error
engine = NewEngine(config.NewOptions(nil))
engine.OnShutdown = []CtxCallback{mockCtxCallback}
engine.options.Registry = &mockDeregsitryErr{}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)

ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second)
defer cancel3()
err = engine.Shutdown(ctx3)
assert.DeepEqual(t, errTestDeregsitry, err)
assert.DeepEqual(t, statusShutdown, atomic.LoadUint32(&engine.status))
}

type mockStreamer struct{}

type mockProtocolServer struct{}

func (s *mockStreamer) Serve(c context.Context, conn network.StreamConn) error {
return nil
}

func (s *mockProtocolServer) Serve(c context.Context, conn network.Conn) error {
return nil
}

type mockStreamConn struct {
network.StreamConn
version string
}

var _ network.StreamConn = &mockStreamConn{}

func (m *mockStreamConn) GetVersion() uint32 {
return network.Version1
}

func TestEngineServeStream(t *testing.T) {
engine := &Engine{
options: &config.Options{
ALPN: true,
TLS: &tls.Config{},
},
protocolStreamServers: map[string]protocol.StreamServer{
suite.HTTP3: &mockStreamer{},
},
}

// Test ALPN path
conn := &mockStreamConn{version: suite.HTTP3}
err := engine.ServeStream(context.Background(), conn)
assert.Nil(t, err)

// Test default path
engine.options.ALPN = false
conn = &mockStreamConn{}
err = engine.ServeStream(context.Background(), conn)
assert.Nil(t, err)

// Test unsupported protocol
engine.protocolStreamServers = map[string]protocol.StreamServer{}
conn = &mockStreamConn{}
err = engine.ServeStream(context.Background(), conn)
assert.DeepEqual(t, errs.ErrNotSupportProtocol, err)
}

func TestEngineServe(t *testing.T) {
engine := NewEngine(config.NewOptions(nil))
engine.protocolServers[suite.HTTP1] = &mockProtocolServer{}
engine.protocolServers[suite.HTTP2] = &mockProtocolServer{}

// test H2C path
ctx := context.Background()
conn := mock.NewConn("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
engine.options.H2C = true
err := engine.Serve(ctx, conn)
assert.Nil(t, err)

// test ALPN path
ctx = context.Background()
conn = mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
engine.options.H2C = false
engine.options.ALPN = true
engine.options.TLS = &tls.Config{}
err = engine.Serve(ctx, conn)
assert.Nil(t, err)

// test HTTP1 path
engine.options.ALPN = false
err = engine.Serve(ctx, conn)
assert.Nil(t, err)
}

func TestOndata(t *testing.T) {
ctx := context.Background()
engine := NewEngine(config.NewOptions(nil))

// test stream conn
streamConn := &mockStreamConn{version: suite.HTTP3}
engine.protocolStreamServers[suite.HTTP3] = &mockStreamer{}
err := engine.onData(ctx, streamConn)
assert.Nil(t, err)

// test conn
conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
engine.protocolServers[suite.HTTP1] = &mockProtocolServer{}
err = engine.onData(ctx, conn)
assert.Nil(t, err)
}

func TestAcquireHijackConn(t *testing.T) {
engine := &Engine{
NoHijackConnPool: false,
}
// test conn pool
conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
hijackConn := engine.acquireHijackConn(conn)
assert.NotNil(t, hijackConn)
assert.NotNil(t, hijackConn.Conn)
assert.DeepEqual(t, engine, hijackConn.e)
assert.DeepEqual(t, conn, hijackConn.Conn)

// test no conn pool
engine.NoHijackConnPool = true
hijackConn = engine.acquireHijackConn(conn)
assert.NotNil(t, hijackConn)
assert.NotNil(t, hijackConn.Conn)
assert.DeepEqual(t, engine, hijackConn.e)
assert.DeepEqual(t, conn, hijackConn.Conn)
}

0 comments on commit 89e9721

Please sign in to comment.