Skip to content

Commit

Permalink
Make writeMu a channel based mutex
Browse files Browse the repository at this point in the history
Will prevent deadlock if a writer is used after close.
  • Loading branch information
nhooyr committed Feb 21, 2020
1 parent 1200707 commit 2e0dd1c
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"sync"
"time"

"github.com/klauspost/compress/flate"
Expand Down Expand Up @@ -71,7 +70,7 @@ type msgWriterState struct {
c *Conn

mu *mu
writeMu sync.Mutex
writeMu *mu

ctx context.Context
opcode opcode
Expand All @@ -83,8 +82,9 @@ type msgWriterState struct {

func newMsgWriterState(c *Conn) *msgWriterState {
mw := &msgWriterState{
c: c,
mu: newMu(c),
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
Expand Down Expand Up @@ -155,12 +155,15 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
mw.writeMu.Lock()
defer mw.writeMu.Unlock()
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()

defer func() {
err = fmt.Errorf("failed to write: %w", err)
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
mw.c.close(err)
}
}()
Expand Down Expand Up @@ -198,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
func (mw *msgWriterState) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")

mw.writeMu.Lock()
defer mw.writeMu.Unlock()
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()

_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
Expand All @@ -219,7 +225,7 @@ func (mw *msgWriterState) close() {
putBufioWriter(mw.c.bw)
}

mw.writeMu.Lock()
mw.writeMu.forceLock()
mw.dict.close()
}

Expand Down Expand Up @@ -250,7 +256,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco

defer func() {
if err != nil {
c.close(fmt.Errorf("failed to write frame: %w", err))
err = fmt.Errorf("failed to write frame: %w", err)
c.close(err)
}
}()

Expand Down

0 comments on commit 2e0dd1c

Please sign in to comment.