diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index 7c1429554e..c4c1910276 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -538,7 +538,7 @@ func (c *Codec) decodeWhoareyou(head *Header, headerData []byte) (Packet, error) func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData, msgData []byte) (n *enode.Node, p Packet, err error) { node, auth, session, err := c.decodeHandshake(fromAddr, head) if err != nil { - if auth != nil { + if auth != nil && auth.isHandshakeAuthDataValid() { c.sc.deleteHandshake(auth.h.SrcID, fromAddr) } return nil, nil, err @@ -547,7 +547,9 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData // Decrypt the message using the new session keys. msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, session.readKey) if err != nil { - c.sc.deleteHandshake(auth.h.SrcID, fromAddr) + if auth != nil && auth.isHandshakeAuthDataValid() { + c.sc.deleteHandshake(auth.h.SrcID, fromAddr) + } return node, msg, err } @@ -559,39 +561,39 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData } func (c *Codec) decodeHandshake(fromAddr string, head *Header) (n *enode.Node, auth *handshakeAuthData, s *session, err error) { - tempAuth := &handshakeAuthData{} - if *tempAuth, err = c.decodeHandshakeAuthData(head); err != nil { + var tempAuth handshakeAuthData + if tempAuth, err = c.decodeHandshakeAuthData(head); err != nil { return nil, nil, nil, err } + auth = &tempAuth // Verify against our last WHOAREYOU. - challenge := c.sc.getHandshake(tempAuth.h.SrcID, fromAddr) + challenge := c.sc.getHandshake(auth.h.SrcID, fromAddr) if challenge == nil { - return nil, nil, nil, errUnexpectedHandshake + return nil, auth, nil, errUnexpectedHandshake } // Get node record. - n, err = c.decodeHandshakeRecord(challenge.Node, tempAuth.h.SrcID, tempAuth.record) + n, err = c.decodeHandshakeRecord(challenge.Node, auth.h.SrcID, auth.record) if err != nil { - return nil, nil, nil, err + return nil, auth, nil, err } // Verify ID nonce signature. - sig := tempAuth.signature + sig := auth.signature cdata := challenge.ChallengeData - - err = verifyIDSignature(c.sha256, sig, n, cdata, tempAuth.pubkey, c.localnode.ID()) + err = verifyIDSignature(c.sha256, sig, n, cdata, auth.pubkey, c.localnode.ID()) if err != nil { - return nil, nil, nil, err + return nil, auth, nil, err } // Verify ephemeral key is on curve - ephkey, err := DecodePubkey(c.privkey.Curve, tempAuth.pubkey) + ephkey, err := DecodePubkey(c.privkey.Curve, auth.pubkey) if err != nil { - return nil, nil, nil, errInvalidAuthKey + return nil, auth, nil, errInvalidAuthKey } // Derive session keys. - session := deriveKeys(sha256.New, c.privkey, ephkey, tempAuth.h.SrcID, c.localnode.ID(), cdata) + session := deriveKeys(sha256.New, c.privkey, ephkey, auth.h.SrcID, c.localnode.ID(), cdata) session = session.keysFlipped() - return n, tempAuth, session, nil + return n, auth, session, nil } // decodeHandshakeAuthData reads the authdata section of a handshake packet. @@ -732,3 +734,9 @@ func bytesCopy(r *bytes.Buffer) []byte { return b } + +// isValid checks if handshakeAuthData is valid +func (auth *handshakeAuthData) isHandshakeAuthDataValid() bool { + // Conditions for the auth to be considered valid + return auth != nil && len(auth.signature) > 0 && len(auth.pubkey) > 0 && auth.h.SrcID != (enode.ID{}) +}