diff --git a/example/ex-01-rpc/bench.sh b/example/ex-01-rpc/bench.sh index 5d9d0a51..31dd646f 100644 --- a/example/ex-01-rpc/bench.sh +++ b/example/ex-01-rpc/bench.sh @@ -1 +1 @@ -tcpkali --ws -c 1000 -m '{"hdr":{"cmd":"echoRequest"},"payload":{"randomID": 1234}}' -r 10 127.0.0.1:80 -T 30 +tcpkali --ws -c 100 -m '{"hdr":{"cmd":"echoRequest"},"payload":{"randomID": 1234}}' -r 10 127.0.0.1:80 -T 30 diff --git a/std/gateways/fastws/bundle.go b/std/gateways/fastws/bundle.go index 3ec58a79..6470d123 100644 --- a/std/gateways/fastws/bundle.go +++ b/std/gateways/fastws/bundle.go @@ -7,6 +7,7 @@ import ( "github.com/clubpay/ronykit/kit" "github.com/clubpay/ronykit/kit/common" "github.com/clubpay/ronykit/kit/errors" + "github.com/gobwas/ws" "github.com/panjf2000/gnet/v2" ) @@ -26,6 +27,7 @@ type bundle struct { routes map[string]*routeData rpcInFactory kit.IncomingRPCFactory rpcOutFactory kit.OutgoingRPCFactory + writeMode ws.OpCode } var _ kit.Gateway = (*bundle)(nil) @@ -37,6 +39,7 @@ func New(opts ...Option) (kit.Gateway, error) { rpcInFactory: common.SimpleIncomingJSONRPC, rpcOutFactory: common.SimpleOutgoingJSONRPC, l: common.NewNopLogger(), + writeMode: ws.OpText, } b.eh = newGateway(b) diff --git a/std/gateways/fastws/conn.go b/std/gateways/fastws/conn.go index ebe18b0b..7ca6be49 100644 --- a/std/gateways/fastws/conn.go +++ b/std/gateways/fastws/conn.go @@ -1,27 +1,35 @@ package fastws import ( + "bytes" "io" "github.com/clubpay/ronykit/kit" "github.com/clubpay/ronykit/kit/errors" "github.com/clubpay/ronykit/kit/utils" + "github.com/clubpay/ronykit/kit/utils/buf" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/panjf2000/gnet/v2" + "github.com/panjf2000/gnet/v2/pkg/buffer/ring" ) +const ringBufInitialSize = 4 << 10 + type wsConn struct { utils.SpinLock id uint64 kv map[string]string - c gnet.Conn - r *wsutil.Reader - w *wsutil.Writer - handshakeDone bool rpcOutFactory kit.OutgoingRPCFactory - msgs []wsutil.Message + clientIP string + + // websocketCodec + handshakeDone bool + readBuff *ring.Buffer + msgBuff *buf.Bytes + currHead *ws.Header + w *wsutil.Writer } var _ kit.Conn = (*wsConn)(nil) @@ -29,75 +37,202 @@ var _ kit.Conn = (*wsConn)(nil) func newWebsocketConn( id uint64, c gnet.Conn, rpcOutFactory kit.OutgoingRPCFactory, + writeMode ws.OpCode, l kit.Logger, ) *wsConn { wsc := &wsConn{ - w: wsutil.NewWriter(c, ws.StateServerSide, ws.OpText), + w: wsutil.NewWriter(c, ws.StateServerSide, writeMode), id: id, kv: map[string]string{}, - c: c, + readBuff: ring.New(ringBufInitialSize), + msgBuff: buf.GetCap(ringBufInitialSize), rpcOutFactory: rpcOutFactory, } - wsc.r = &wsutil.Reader{ - Source: c, - State: ws.StateServerSide, - CheckUTF8: true, - OnIntermediate: func(hdr ws.Header, src io.Reader) error { - if hdr.OpCode.IsControl() { - return wsutil.ControlHandler{ - Src: wsc.r, - Dst: wsc.c, - State: wsc.r.State, - DisableSrcCiphering: true, - }.Handle(hdr) - } + if addr := c.RemoteAddr(); addr != nil { + wsc.clientIP = addr.String() + } - bts, err := io.ReadAll(src) - if err != nil { - return err - } - wsc.msgs = append(wsc.msgs, wsutil.Message{OpCode: hdr.OpCode, Payload: bts}) + return wsc +} + +func (wsc *wsConn) readBuffer(c gnet.Conn) error { + buff, err := c.Next(c.InboundBuffered()) + if err != nil { + return err + } + + _, err = wsc.readBuff.Write(buff) + + return err +} +func (wsc *wsConn) isUpgraded() bool { + return wsc.handshakeDone +} + +func (wsc *wsConn) upgrade(c gnet.Conn) error { + sp := acquireSwitchProtocol() + if _, err := sp.Upgrade(c); err != nil { + return err + } + releaseSwitchProtocol(sp) + + wsc.handshakeDone = true + + return nil +} + +func (wsc *wsConn) nextHeader() error { + if wsc.currHead != nil { + return nil + } + + if wsc.readBuff.Buffered() < ws.MinHeaderSize { + return nil + } + + // we can read the header for sure + if wsc.readBuff.Buffered() >= ws.MaxHeaderSize { + head, err := ws.ReadHeader(wsc.readBuff) + if err != nil { + return err + } + + wsc.currHead = &head + + return nil + } + + // we need to check if there is header in the buffer + tmp := bytes.NewReader(wsc.readBuff.Bytes()) + preLen := tmp.Len() + head, err := ws.ReadHeader(tmp) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { return nil + } + + return err + } + + skipN := preLen - tmp.Len() + _, _ = wsc.readBuff.Discard(skipN) + + wsc.currHead = &head + + return nil +} + +func (wsc *wsConn) handleControlMessage(c gnet.Conn) error { + buff := buf.GetLen(int(wsc.currHead.Length)) + defer buff.Release() + + if wsc.currHead.Length > 0 { + _, err := wsc.readBuff.Read(*buff.Bytes()) + if err != nil { + return err + } + } + + err := wsutil.HandleClientControlMessage( + c, + wsutil.Message{ + OpCode: wsc.currHead.OpCode, + Payload: *buff.Bytes(), }, + ) + if err != nil { + return err } - return wsc + return nil } -func (c *wsConn) Close() { - _ = c.c.Close() +func (wsc *wsConn) executeMessages(c gnet.Conn, d kit.GatewayDelegate) error { + for { + err := wsc.nextHeader() + if err != nil { + return err + } + + if wsc.currHead == nil { + return nil + } + + // if it is a control message then let's handle it + if wsc.currHead.Fin && wsc.currHead.OpCode.IsControl() { + err = wsc.handleControlMessage(c) + if err != nil { + return err + } + + wsc.currHead = nil + + continue + } + + dataLen := int(wsc.currHead.Length) + if dataLen > 0 { + if dataLen > wsc.readBuff.Buffered() { + return nil + } + + tmpBuff := buf.GetLen(8192) + stPos := wsc.msgBuff.Len() + written, err := io.CopyBuffer(wsc.msgBuff, io.LimitReader(wsc.readBuff, wsc.currHead.Length), *tmpBuff.Bytes()) + tmpBuff.Release() + if err != nil { + return err + } + if written < wsc.currHead.Length && err == nil { + // src stopped early; must have been EOF. + return io.EOF + } + + endPos := wsc.msgBuff.Len() + ws.Cipher(utils.PtrVal(wsc.msgBuff.Bytes())[stPos:endPos], wsc.currHead.Mask, 0) + } + + if wsc.currHead.Fin { + go wsc.execMessage(d, wsc.msgBuff) + wsc.msgBuff = buf.GetCap(wsc.msgBuff.Cap()) + } + + // reset the current head + wsc.currHead = nil + } } -func (c *wsConn) ConnID() uint64 { - return c.id +func (wsc *wsConn) execMessage(d kit.GatewayDelegate, msgBuff *buf.Bytes) { + + d.OnMessage(wsc, *msgBuff.Bytes()) + msgBuff.Release() } -func (c *wsConn) ClientIP() string { - addr := c.c.RemoteAddr() - if addr == nil { - return "" - } +func (wsc *wsConn) ConnID() uint64 { + return wsc.id +} - return addr.String() +func (wsc *wsConn) ClientIP() string { + return wsc.clientIP } -func (c *wsConn) Write(data []byte) (int, error) { - c.Lock() - defer c.Unlock() +func (wsc *wsConn) Write(data []byte) (int, error) { + wsc.Lock() + defer wsc.Unlock() - n, err := c.w.Write(data) + n, err := wsc.w.Write(data) if err != nil { return n, err } - err = c.w.Flush() + err = wsc.w.Flush() return n, err } -func (c *wsConn) WriteEnvelope(e *kit.Envelope) error { - outC := c.rpcOutFactory() +func (wsc *wsConn) WriteEnvelope(e *kit.Envelope) error { + outC := wsc.rpcOutFactory() outC.InjectMessage(e.GetMsg()) outC.SetID(e.GetID()) e.WalkHdr(func(key string, val string) bool { @@ -111,37 +246,37 @@ func (c *wsConn) WriteEnvelope(e *kit.Envelope) error { return errors.Wrap(kit.ErrEncodeOutgoingMessageFailed, err) } - _, err = c.Write(data) + _, err = wsc.Write(data) outC.Release() return err } -func (c *wsConn) Stream() bool { +func (wsc *wsConn) Stream() bool { return true } -func (c *wsConn) Walk(f func(key string, val string) bool) { - c.Lock() - defer c.Unlock() +func (wsc *wsConn) Walk(f func(key string, val string) bool) { + wsc.Lock() + defer wsc.Unlock() - for k, v := range c.kv { + for k, v := range wsc.kv { if !f(k, v) { return } } } -func (c *wsConn) Get(key string) string { - c.Lock() - v := c.kv[key] - c.Unlock() +func (wsc *wsConn) Get(key string) string { + wsc.Lock() + v := wsc.kv[key] + wsc.Unlock() return v } -func (c *wsConn) Set(key string, val string) { - c.Lock() - c.kv[key] = val - c.Unlock() +func (wsc *wsConn) Set(key string, val string) { + wsc.Lock() + wsc.kv[key] = val + wsc.Unlock() } diff --git a/std/gateways/fastws/gateway.go b/std/gateways/fastws/gateway.go index 4dd04920..76b05e70 100644 --- a/std/gateways/fastws/gateway.go +++ b/std/gateways/fastws/gateway.go @@ -2,16 +2,13 @@ package fastws import ( "bytes" - builtinErr "errors" "io" "net/http" "sync" "sync/atomic" "time" - "github.com/clubpay/ronykit/kit" "github.com/clubpay/ronykit/kit/utils" - "github.com/clubpay/ronykit/kit/utils/buf" "github.com/gobwas/ws" "github.com/panjf2000/gnet/v2" ) @@ -33,11 +30,6 @@ func newGateway(b *bundle) *gateway { return gw } -func (gw *gateway) reactFunc(wsc kit.Conn, payload *buf.Bytes, n int) { - gw.b.d.OnMessage(wsc, (*payload.Bytes())[:n]) - payload.Release() -} - func (gw *gateway) getConnWrap(conn gnet.Conn) *wsConn { connID, ok := conn.Context().(uint64) if !ok { @@ -57,7 +49,13 @@ func (gw *gateway) OnBoot(_ gnet.Engine) (action gnet.Action) { func (gw *gateway) OnShutdown(_ gnet.Engine) {} func (gw *gateway) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { - wsc := newWebsocketConn(atomic.AddUint64(&gw.nextID, 1), c, gw.b.rpcOutFactory) + wsc := newWebsocketConn( + atomic.AddUint64(&gw.nextID, 1), + c, + gw.b.rpcOutFactory, + gw.b.writeMode, + gw.b.l, + ) c.SetContext(wsc.id) gw.Lock() @@ -92,91 +90,31 @@ func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } - if !wsc.handshakeDone { - sp := acquireSwitchProtocol() - _, err := sp.Upgrade(wsc.c) + if !wsc.isUpgraded() { + err := wsc.upgrade(c) if err != nil { - wsc.Close() gw.b.l.Debugf("faild to upgrade websocket connID(%d): %v", utils.TryCast[uint64](c.Context()), err) return gnet.Close } - releaseSwitchProtocol(sp) - - wsc.handshakeDone = true return gnet.None } - var ( - err error - hdr ws.Header - ) - - for { - hdr, err = wsc.r.NextFrame() - if err != nil { - if builtinErr.Is(err, io.EOF) { - return gnet.None - } - gw.b.l.Debugf("failed to read next frame of connID(%d): %v", utils.TryCast[int64](c.Context()), err) - - return gnet.Close - } - - if hdr.OpCode.IsControl() { - if err = wsc.r.OnIntermediate(hdr, wsc.r); err != nil { - gw.b.l.Debugf( - "failed to handle control message of connID(%d), opCode(%d): %v", - utils.TryCast[int64](c.Context()), hdr.OpCode, err, - ) - - return gnet.Close - } - - if err = wsc.r.Discard(); err != nil { - gw.b.l.Debugf( - "failed to discard on control message connID(%d): %v", - utils.TryCast[int64](c.Context()), err, - ) - - return gnet.Close - } - - return gnet.None - } - - if hdr.OpCode&(ws.OpText|ws.OpBinary) != hdr.OpCode { - if err = wsc.r.Discard(); err != nil { - return gnet.Close - } - - continue - } + err := wsc.readBuffer(c) + if err != nil { + gw.b.l.Debugf("faild to read buffer websocket connID(%d): %v", utils.TryCast[uint64](c.Context()), err) - break + return gnet.Close } - var pBuff *buf.Bytes - if hdr.Fin { - // No more frames will be read. Use fixed sized buffer to read payload. - // It is not possible to receive io.EOF here because Reader does not - // return EOF if frame payload was successfully fetched. - pBuff = buf.GetLen(int(hdr.Length)) - _, err = io.ReadFull(wsc.r, *pBuff.Bytes()) - } else { - // create a default buffer cap, since we don't know the exact size of payload - pBuff = buf.GetCap(8192) - buff := bytes.NewBuffer(*pBuff.Bytes()) - _, err = buff.ReadFrom(wsc.r) - pBuff.SetBytes(utils.ValPtr(buff.Bytes())) - } + err = wsc.executeMessages(c, gw.b.d) if err != nil { + gw.b.l.Debugf("failed to execute message connID(%d): %v", utils.TryCast[uint64](c.Context()), err) + return gnet.Close } - go gw.reactFunc(wsc, pBuff, pBuff.Len()) - return gnet.None } diff --git a/std/gateways/fastws/options.go b/std/gateways/fastws/options.go index 6ece946a..702f55ff 100644 --- a/std/gateways/fastws/options.go +++ b/std/gateways/fastws/options.go @@ -1,6 +1,9 @@ package fastws -import "github.com/clubpay/ronykit/kit" +import ( + "github.com/clubpay/ronykit/kit" + "github.com/gobwas/ws" +) type Option func(b *bundle) @@ -31,3 +34,18 @@ func WithCustomRPC(in kit.IncomingRPCFactory, out kit.OutgoingRPCFactory) Option b.rpcOutFactory = out } } + +// WithWebsocketTextMode sets the write mode of the websocket to text opCode +// This is the default behavior +func WithWebsocketTextMode() Option { + return func(b *bundle) { + b.writeMode = ws.OpText + } +} + +// WithWebsocketBinaryMode sets the write mode of the websocket to binary opCode +func WithWebsocketBinaryMode() Option { + return func(b *bundle) { + b.writeMode = ws.OpBinary + } +} diff --git a/testenv/fastws_test.go b/testenv/fastws_test.go index 07426795..1c565b03 100644 --- a/testenv/fastws_test.go +++ b/testenv/fastws_test.go @@ -24,8 +24,8 @@ func TestFastWS(t *testing.T) { Convey("Kit with FastWS", t, func(c C) { testCases := map[string]func(t *testing.T, opt fx.Option) func(c C){ "Edge Server With Huge Websocket Payload": fastwsWithHugePayload, - //"Edge Server With Ping and Small Payload": fastwsWithPingAndSmallPayload, - //"Edge Server With Ping Only": fastwsWithPingOnly, + "Edge Server With Ping and Small Payload": fastwsWithPingAndSmallPayload, + "Edge Server With Ping Only": fastwsWithPingOnly, } for title, fn := range testCases { Convey(title, @@ -100,7 +100,7 @@ func fastwsWithHugePayload(t *testing.T, opt fx.Option) func(c C) { c.So(wsCtx.Connect(ctx, "/"), ShouldBeNil) for i := 0; i < 10; i++ { - req := &services.EchoRequest{Input: utils.RandomID(8192)} + req := &services.EchoRequest{Input: utils.RandomID(10000)} res := &services.EchoResponse{} err := wsCtx.BinaryMessage( ctx, "echo", req, res, @@ -109,9 +109,7 @@ func fastwsWithHugePayload(t *testing.T, opt fx.Option) func(c C) { c.So(msg.(*services.EchoResponse).Output, ShouldEqual, req.Input) //nolint:forcetypeassert }, ) - _ = err - //_, _ = c.Println("Error: ", err) - // c.So(err, ShouldBeNil) + c.So(err, ShouldBeNil) time.Sleep(time.Second * 2) } } @@ -141,7 +139,7 @@ func fastwsWithPingAndSmallPayload(t *testing.T, opt fx.Option) func(c C) { c.So(wsCtx.Connect(ctx, "/"), ShouldBeNil) for i := 0; i < 10; i++ { - req := &services.EchoRequest{Input: utils.RandomID(1024)} + req := &services.EchoRequest{Input: utils.RandomID(32)} res := &services.EchoResponse{} err := wsCtx.BinaryMessage( ctx, "echo", req, res,