Skip to content

Commit

Permalink
Ensure connection is closed at all error points
Browse files Browse the repository at this point in the history
Closes #191
  • Loading branch information
nhooyr committed Feb 20, 2020
1 parent 43c4dc0 commit 1200707
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
26 changes: 12 additions & 14 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
defer c.readMu.unlock()

if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
err = errors.New("previous message not read to completion")
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}

h, err := c.readLoop(ctx)
Expand Down Expand Up @@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
}

func (mr *msgReader) Read(p []byte) (n int, err error) {
defer func() {
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
err = io.EOF
}
if errors.Is(err, io.EOF) {
err = io.EOF
mr.putFlateReader()
return
}
errd.Wrap(&err, "failed to read")
}()

err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()

Expand All @@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.close(err)
}
return n, err
}

Expand Down
19 changes: 15 additions & 4 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,16 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write")

mw.writeMu.Lock()
defer mw.writeMu.Unlock()

defer func() {
err = fmt.Errorf("failed to write: %w", err)
if err != nil {
mw.c.close(err)
}
}()

if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
Expand Down Expand Up @@ -230,8 +235,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
}

// frame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
err := c.writeFrameMu.lock(ctx)
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
Expand All @@ -243,6 +248,12 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
case c.writeTimeout <- ctx:
}

defer func() {
if err != nil {
c.close(fmt.Errorf("failed to write frame: %w", err))
}
}()

c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
Expand Down

0 comments on commit 1200707

Please sign in to comment.