From 2fe59d4b6a03904196be2bd1697d7bbe5d4cca42 Mon Sep 17 00:00:00 2001 From: marcello33 Date: Fri, 6 Dec 2024 10:48:02 +0100 Subject: [PATCH] fix: test --- p2p/discover/v5wire/encoding.go | 40 ++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index 7c1429554e..c4c1910276 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -538,7 +538,7 @@ func (c *Codec) decodeWhoareyou(head *Header, headerData []byte) (Packet, error) func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData, msgData []byte) (n *enode.Node, p Packet, err error) { node, auth, session, err := c.decodeHandshake(fromAddr, head) if err != nil { - if auth != nil { + if auth != nil && auth.isHandshakeAuthDataValid() { c.sc.deleteHandshake(auth.h.SrcID, fromAddr) } return nil, nil, err @@ -547,7 +547,9 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData // Decrypt the message using the new session keys. msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, session.readKey) if err != nil { - c.sc.deleteHandshake(auth.h.SrcID, fromAddr) + if auth != nil && auth.isHandshakeAuthDataValid() { + c.sc.deleteHandshake(auth.h.SrcID, fromAddr) + } return node, msg, err } @@ -559,39 +561,39 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData } func (c *Codec) decodeHandshake(fromAddr string, head *Header) (n *enode.Node, auth *handshakeAuthData, s *session, err error) { - tempAuth := &handshakeAuthData{} - if *tempAuth, err = c.decodeHandshakeAuthData(head); err != nil { + var tempAuth handshakeAuthData + if tempAuth, err = c.decodeHandshakeAuthData(head); err != nil { return nil, nil, nil, err } + auth = &tempAuth // Verify against our last WHOAREYOU. - challenge := c.sc.getHandshake(tempAuth.h.SrcID, fromAddr) + challenge := c.sc.getHandshake(auth.h.SrcID, fromAddr) if challenge == nil { - return nil, nil, nil, errUnexpectedHandshake + return nil, auth, nil, errUnexpectedHandshake } // Get node record. - n, err = c.decodeHandshakeRecord(challenge.Node, tempAuth.h.SrcID, tempAuth.record) + n, err = c.decodeHandshakeRecord(challenge.Node, auth.h.SrcID, auth.record) if err != nil { - return nil, nil, nil, err + return nil, auth, nil, err } // Verify ID nonce signature. - sig := tempAuth.signature + sig := auth.signature cdata := challenge.ChallengeData - - err = verifyIDSignature(c.sha256, sig, n, cdata, tempAuth.pubkey, c.localnode.ID()) + err = verifyIDSignature(c.sha256, sig, n, cdata, auth.pubkey, c.localnode.ID()) if err != nil { - return nil, nil, nil, err + return nil, auth, nil, err } // Verify ephemeral key is on curve - ephkey, err := DecodePubkey(c.privkey.Curve, tempAuth.pubkey) + ephkey, err := DecodePubkey(c.privkey.Curve, auth.pubkey) if err != nil { - return nil, nil, nil, errInvalidAuthKey + return nil, auth, nil, errInvalidAuthKey } // Derive session keys. - session := deriveKeys(sha256.New, c.privkey, ephkey, tempAuth.h.SrcID, c.localnode.ID(), cdata) + session := deriveKeys(sha256.New, c.privkey, ephkey, auth.h.SrcID, c.localnode.ID(), cdata) session = session.keysFlipped() - return n, tempAuth, session, nil + return n, auth, session, nil } // decodeHandshakeAuthData reads the authdata section of a handshake packet. @@ -732,3 +734,9 @@ func bytesCopy(r *bytes.Buffer) []byte { return b } + +// isValid checks if handshakeAuthData is valid +func (auth *handshakeAuthData) isHandshakeAuthDataValid() bool { + // Conditions for the auth to be considered valid + return auth != nil && len(auth.signature) > 0 && len(auth.pubkey) > 0 && auth.h.SrcID != (enode.ID{}) +}