From 2e0dd1c74967c77af530fb33e0762344d044033c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 20 Feb 2020 19:03:09 -0500 Subject: [PATCH] Make writeMu a channel based mutex Will prevent deadlock if a writer is used after close. --- write.go | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/write.go b/write.go index d94486e2..baa5e6e2 100644 --- a/write.go +++ b/write.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "sync" "time" "github.com/klauspost/compress/flate" @@ -71,7 +70,7 @@ type msgWriterState struct { c *Conn mu *mu - writeMu sync.Mutex + writeMu *mu ctx context.Context opcode opcode @@ -83,8 +82,9 @@ type msgWriterState struct { func newMsgWriterState(c *Conn) *msgWriterState { mw := &msgWriterState{ - c: c, - mu: newMu(c), + c: c, + mu: newMu(c), + writeMu: newMu(c), } return mw } @@ -155,12 +155,15 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { - mw.writeMu.Lock() - defer mw.writeMu.Unlock() + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return 0, fmt.Errorf("failed to write: %w", err) + } + defer mw.writeMu.unlock() defer func() { - err = fmt.Errorf("failed to write: %w", err) if err != nil { + err = fmt.Errorf("failed to write: %w", err) mw.c.close(err) } }() @@ -198,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) { func (mw *msgWriterState) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") - mw.writeMu.Lock() - defer mw.writeMu.Unlock() + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return err + } + defer mw.writeMu.unlock() _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { @@ -219,7 +225,7 @@ func (mw *msgWriterState) close() { putBufioWriter(mw.c.bw) } - mw.writeMu.Lock() + mw.writeMu.forceLock() mw.dict.close() } @@ -250,7 +256,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco defer func() { if err != nil { - c.close(fmt.Errorf("failed to write frame: %w", err)) + err = fmt.Errorf("failed to write frame: %w", err) + c.close(err) } }()