Skip to content

Commit

Permalink
[fastws] fix reading large frames
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsannm committed May 27, 2024
1 parent a9e59fa commit 1603117
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 144 deletions.
2 changes: 1 addition & 1 deletion example/ex-01-rpc/bench.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
tcpkali --ws -c 1000 -m '{"hdr":{"cmd":"echoRequest"},"payload":{"randomID": 1234}}' -r 10 127.0.0.1:80 -T 30
tcpkali --ws -c 100 -m '{"hdr":{"cmd":"echoRequest"},"payload":{"randomID": 1234}}' -r 10 127.0.0.1:80 -T 30
3 changes: 3 additions & 0 deletions std/gateways/fastws/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/clubpay/ronykit/kit"
"github.com/clubpay/ronykit/kit/common"
"github.com/clubpay/ronykit/kit/errors"
"github.com/gobwas/ws"
"github.com/panjf2000/gnet/v2"
)

Expand All @@ -26,6 +27,7 @@ type bundle struct {
routes map[string]*routeData
rpcInFactory kit.IncomingRPCFactory
rpcOutFactory kit.OutgoingRPCFactory
writeMode ws.OpCode
}

var _ kit.Gateway = (*bundle)(nil)
Expand All @@ -37,6 +39,7 @@ func New(opts ...Option) (kit.Gateway, error) {
rpcInFactory: common.SimpleIncomingJSONRPC,
rpcOutFactory: common.SimpleOutgoingJSONRPC,
l: common.NewNopLogger(),
writeMode: ws.OpText,
}
b.eh = newGateway(b)

Expand Down
249 changes: 192 additions & 57 deletions std/gateways/fastws/conn.go
Original file line number Diff line number Diff line change
@@ -1,103 +1,238 @@
package fastws

import (
"bytes"
"io"

"github.com/clubpay/ronykit/kit"
"github.com/clubpay/ronykit/kit/errors"
"github.com/clubpay/ronykit/kit/utils"
"github.com/clubpay/ronykit/kit/utils/buf"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/panjf2000/gnet/v2"
"github.com/panjf2000/gnet/v2/pkg/buffer/ring"
)

const ringBufInitialSize = 4 << 10

type wsConn struct {
utils.SpinLock

id uint64
kv map[string]string
c gnet.Conn
r *wsutil.Reader
w *wsutil.Writer
handshakeDone bool
rpcOutFactory kit.OutgoingRPCFactory
msgs []wsutil.Message
clientIP string

// websocketCodec
handshakeDone bool
readBuff *ring.Buffer
msgBuff *buf.Bytes
currHead *ws.Header
w *wsutil.Writer
}

var _ kit.Conn = (*wsConn)(nil)

func newWebsocketConn(
id uint64, c gnet.Conn,
rpcOutFactory kit.OutgoingRPCFactory,
writeMode ws.OpCode, l kit.Logger,
) *wsConn {
wsc := &wsConn{
w: wsutil.NewWriter(c, ws.StateServerSide, ws.OpText),
w: wsutil.NewWriter(c, ws.StateServerSide, writeMode),
id: id,
kv: map[string]string{},
c: c,
readBuff: ring.New(ringBufInitialSize),
msgBuff: buf.GetCap(ringBufInitialSize),
rpcOutFactory: rpcOutFactory,
}

wsc.r = &wsutil.Reader{
Source: c,
State: ws.StateServerSide,
CheckUTF8: true,
OnIntermediate: func(hdr ws.Header, src io.Reader) error {
if hdr.OpCode.IsControl() {
return wsutil.ControlHandler{
Src: wsc.r,
Dst: wsc.c,
State: wsc.r.State,
DisableSrcCiphering: true,
}.Handle(hdr)
}
if addr := c.RemoteAddr(); addr != nil {
wsc.clientIP = addr.String()
}

bts, err := io.ReadAll(src)
if err != nil {
return err
}
wsc.msgs = append(wsc.msgs, wsutil.Message{OpCode: hdr.OpCode, Payload: bts})
return wsc
}

func (wsc *wsConn) readBuffer(c gnet.Conn) error {
buff, err := c.Next(c.InboundBuffered())
if err != nil {
return err
}

_, err = wsc.readBuff.Write(buff)

return err
}

func (wsc *wsConn) isUpgraded() bool {
return wsc.handshakeDone
}

func (wsc *wsConn) upgrade(c gnet.Conn) error {
sp := acquireSwitchProtocol()
if _, err := sp.Upgrade(c); err != nil {
return err
}
releaseSwitchProtocol(sp)

wsc.handshakeDone = true

return nil
}

func (wsc *wsConn) nextHeader() error {
if wsc.currHead != nil {
return nil
}

if wsc.readBuff.Buffered() < ws.MinHeaderSize {
return nil
}

// we can read the header for sure
if wsc.readBuff.Buffered() >= ws.MaxHeaderSize {
head, err := ws.ReadHeader(wsc.readBuff)
if err != nil {
return err
}

wsc.currHead = &head

return nil
}

// we need to check if there is header in the buffer
tmp := bytes.NewReader(wsc.readBuff.Bytes())
preLen := tmp.Len()
head, err := ws.ReadHeader(tmp)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return nil
}

return err
}

skipN := preLen - tmp.Len()
_, _ = wsc.readBuff.Discard(skipN)

wsc.currHead = &head

return nil
}

func (wsc *wsConn) handleControlMessage(c gnet.Conn) error {
buff := buf.GetLen(int(wsc.currHead.Length))
defer buff.Release()

if wsc.currHead.Length > 0 {
_, err := wsc.readBuff.Read(*buff.Bytes())
if err != nil {
return err
}
}

err := wsutil.HandleClientControlMessage(
c,
wsutil.Message{
OpCode: wsc.currHead.OpCode,
Payload: *buff.Bytes(),
},
)
if err != nil {
return err
}

return wsc
return nil
}

func (c *wsConn) Close() {
_ = c.c.Close()
func (wsc *wsConn) executeMessages(c gnet.Conn, d kit.GatewayDelegate) error {
for {
err := wsc.nextHeader()
if err != nil {
return err
}

if wsc.currHead == nil {
return nil
}

// if it is a control message then let's handle it
if wsc.currHead.Fin && wsc.currHead.OpCode.IsControl() {
err = wsc.handleControlMessage(c)
if err != nil {
return err
}

wsc.currHead = nil

continue
}

dataLen := int(wsc.currHead.Length)
if dataLen > 0 {
if dataLen > wsc.readBuff.Buffered() {
return nil
}

tmpBuff := buf.GetLen(8192)
stPos := wsc.msgBuff.Len()
written, err := io.CopyBuffer(wsc.msgBuff, io.LimitReader(wsc.readBuff, wsc.currHead.Length), *tmpBuff.Bytes())
tmpBuff.Release()
if err != nil {
return err
}
if written < wsc.currHead.Length && err == nil {
// src stopped early; must have been EOF.
return io.EOF
}

endPos := wsc.msgBuff.Len()
ws.Cipher(utils.PtrVal(wsc.msgBuff.Bytes())[stPos:endPos], wsc.currHead.Mask, 0)
}

if wsc.currHead.Fin {
go wsc.execMessage(d, wsc.msgBuff)
wsc.msgBuff = buf.GetCap(wsc.msgBuff.Cap())
}

// reset the current head
wsc.currHead = nil
}
}

func (c *wsConn) ConnID() uint64 {
return c.id
func (wsc *wsConn) execMessage(d kit.GatewayDelegate, msgBuff *buf.Bytes) {

d.OnMessage(wsc, *msgBuff.Bytes())
msgBuff.Release()
}

func (c *wsConn) ClientIP() string {
addr := c.c.RemoteAddr()
if addr == nil {
return ""
}
func (wsc *wsConn) ConnID() uint64 {
return wsc.id
}

return addr.String()
func (wsc *wsConn) ClientIP() string {
return wsc.clientIP
}

func (c *wsConn) Write(data []byte) (int, error) {
c.Lock()
defer c.Unlock()
func (wsc *wsConn) Write(data []byte) (int, error) {
wsc.Lock()
defer wsc.Unlock()

n, err := c.w.Write(data)
n, err := wsc.w.Write(data)
if err != nil {
return n, err
}

err = c.w.Flush()
err = wsc.w.Flush()

return n, err
}

func (c *wsConn) WriteEnvelope(e *kit.Envelope) error {
outC := c.rpcOutFactory()
func (wsc *wsConn) WriteEnvelope(e *kit.Envelope) error {
outC := wsc.rpcOutFactory()
outC.InjectMessage(e.GetMsg())
outC.SetID(e.GetID())
e.WalkHdr(func(key string, val string) bool {
Expand All @@ -111,37 +246,37 @@ func (c *wsConn) WriteEnvelope(e *kit.Envelope) error {
return errors.Wrap(kit.ErrEncodeOutgoingMessageFailed, err)
}

_, err = c.Write(data)
_, err = wsc.Write(data)
outC.Release()

return err
}

func (c *wsConn) Stream() bool {
func (wsc *wsConn) Stream() bool {
return true
}

func (c *wsConn) Walk(f func(key string, val string) bool) {
c.Lock()
defer c.Unlock()
func (wsc *wsConn) Walk(f func(key string, val string) bool) {
wsc.Lock()
defer wsc.Unlock()

for k, v := range c.kv {
for k, v := range wsc.kv {
if !f(k, v) {
return
}
}
}

func (c *wsConn) Get(key string) string {
c.Lock()
v := c.kv[key]
c.Unlock()
func (wsc *wsConn) Get(key string) string {
wsc.Lock()
v := wsc.kv[key]
wsc.Unlock()

return v
}

func (c *wsConn) Set(key string, val string) {
c.Lock()
c.kv[key] = val
c.Unlock()
func (wsc *wsConn) Set(key string, val string) {
wsc.Lock()
wsc.kv[key] = val
wsc.Unlock()
}
Loading

0 comments on commit 1603117

Please sign in to comment.