From 0a225ca34bcc4fc4cf6e7f713e747b260ef2db89 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 21 Nov 2024 16:35:06 +0900 Subject: [PATCH] remove some use of errBadConnNoWrite --- buffer.go | 5 +++++ connection.go | 10 +++++++--- packets.go | 44 ++++++++++++++------------------------------ 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/buffer.go b/buffer.go index 0774c5c8c..2dc898724 100644 --- a/buffer.go +++ b/buffer.go @@ -43,6 +43,11 @@ func newBuffer(nc net.Conn) buffer { } } +// busy retruns true if the buffer contains some read data. +func (b *buffer) busy() bool { + return b.length > 0 +} + // flip replaces the active buffer with the background buffer // this is a delayed flip that simply increases the buffer counter; // the actual flip will be performed the next time we call `buffer.fill` diff --git a/connection.go b/connection.go index ef6fc9e40..95f05c92f 100644 --- a/connection.go +++ b/connection.go @@ -121,10 +121,14 @@ func (mc *mysqlConn) Close() (err error) { if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } + mc.close() + return +} +// close closes the network connection and cleare results without sending COM_QUIT. +func (mc *mysqlConn) close() { mc.cleanup() mc.clearResult() - return } // Closes the network connection and unsets internal variables. Do not call this @@ -637,7 +641,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.Load() { + if mc.closed.Load() || mc.buf.busy() { return driver.ErrBadConn } @@ -671,7 +675,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.Load() + return !mc.closed.Load() && !mc.buf.busy() } var _ driver.SessionResetter = &mysqlConn{} diff --git a/packets.go b/packets.go index a2e7ef95c..4695fb81a 100644 --- a/packets.go +++ b/packets.go @@ -32,11 +32,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet header data, err := mc.buf.readNext(4) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } mc.log(err) - mc.Close() return nil, ErrInvalidConn } @@ -45,7 +45,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { - mc.Close() + mc.close() if data[3] > mc.sequence { return nil, ErrPktSyncMul } @@ -59,7 +59,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // there was no previous packet if prevData == nil { mc.log(ErrMalformPkt) - mc.Close() + mc.close() return nil, ErrInvalidConn } @@ -69,11 +69,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } mc.log(err) - mc.Close() return nil, ErrInvalidConn } @@ -125,10 +125,10 @@ func (mc *mysqlConn) writePacket(data []byte) error { n, err := mc.netConn.Write(data[:4+size]) if err != nil { + mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { return cerr } - mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet mc.log(err) @@ -162,11 +162,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { data, err = mc.readPacket() if err != nil { - // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since - // in connection initialization we don't risk retrying non-idempotent actions. - if err == ErrInvalidConn { - return nil, "", driver.ErrBadConn - } return } @@ -312,9 +307,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Calculate packet length and get buffer with that size data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // ClientFlags [32 bit] @@ -404,9 +398,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) data, err := mc.buf.takeBuffer(pktLen) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add the auth data [EOF] @@ -424,9 +417,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -443,9 +434,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -464,9 +453,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -1007,9 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In this case the len(data) == cap(data) which is used to optimise the flow below. } if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // command [1 byte] @@ -1207,8 +1192,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.log(err) - return errBadConnNoWrite + return err } }