From 0924825a43650818868635f5f9484ff62ab278df Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 18 Jan 2019 16:02:20 -0800 Subject: [PATCH] wire: cache the non-witness serialization of MsgTx to memoize part of TxHash In this commit, we add a new field to the `MsgTx` struct: `cachedSeralizedNoWitness`. As we decode the main transaction, we use an `io.TeeReader` to copy over the non-witness bytes into this new field. As a result, we can fully cache all tx serialization when computing the TxHash. This has been shown to show up on profiles during IBD. Caching this value allows us to optimize TxHash calculation across the entire daemon as a whole. --- wire/message_test.go | 24 ++++++++++++++ wire/msgtx.go | 74 +++++++++++++++++++++++++++++++++----------- wire/msgtx_test.go | 53 +++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 18 deletions(-) diff --git a/wire/message_test.go b/wire/message_test.go index 7ba2e0639f..8b0187e4cd 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -137,6 +137,18 @@ func TestMessage(t *testing.T) { spew.Sdump(msg)) continue } + + // Blank out the cached encoding for transactions to ensure the + // deep equality check doesn't fail. + if tx, ok := msg.(*MsgTx); ok { + tx.cachedSeralizedNoWitness = nil + } + if block, ok := msg.(*MsgBlock); ok { + for _, tx := range block.Transactions { + tx.cachedSeralizedNoWitness = nil + } + } + if !reflect.DeepEqual(msg, test.out) { t.Errorf("ReadMessage #%d\n got: %v want: %v", i, spew.Sdump(msg), spew.Sdump(test.out)) @@ -170,6 +182,18 @@ func TestMessage(t *testing.T) { spew.Sdump(msg)) continue } + + // Blank out the cached encoding for transactions to ensure the + // deep equality check doesn't fail. + if tx, ok := msg.(*MsgTx); ok { + tx.cachedSeralizedNoWitness = nil + } + if block, ok := msg.(*MsgBlock); ok { + for _, tx := range block.Transactions { + tx.cachedSeralizedNoWitness = nil + } + } + if !reflect.DeepEqual(msg, test.out) { t.Errorf("ReadMessage #%d\n got: %v want: %v", i, spew.Sdump(msg), spew.Sdump(test.out)) diff --git a/wire/msgtx.go b/wire/msgtx.go index 7705504cc8..21e191d37f 100644 --- a/wire/msgtx.go +++ b/wire/msgtx.go @@ -343,6 +343,12 @@ type MsgTx struct { TxIn []*TxIn TxOut []*TxOut LockTime uint32 + + // cachedSeralizedNoWitness is a cached version of the serialization of + // this transaction without witness data. When we decode a transaction, + // we'll write out the non-witness bytes to this so we can quickly + // calculate the TxHash later if needed. + cachedSeralizedNoWitness []byte } // AddTxIn adds a transaction input to the message. @@ -357,13 +363,19 @@ func (msg *MsgTx) AddTxOut(to *TxOut) { // TxHash generates the Hash for the transaction. func (msg *MsgTx) TxHash() chainhash.Hash { - // Encode the transaction and calculate double sha256 on the result. - // Ignore the error returns since the only way the encode could fail - // is being out of memory or due to nil pointers, both of which would - // cause a run-time panic. - buf := bytes.NewBuffer(make([]byte, 0, msg.SerializeSizeStripped())) - _ = msg.SerializeNoWitness(buf) - return chainhash.DoubleHashH(buf.Bytes()) + if msg.cachedSeralizedNoWitness == nil { + // Encode the transaction and calculate double sha256 on the + // result. Ignore the error returns since the only way the + // encode could fail is being out of memory or due to nil + // pointers, both of which would cause a run-time panic. + strippedSize := msg.SerializeSizeStripped() + buf := bytes.NewBuffer(make([]byte, 0, strippedSize)) + _ = msg.SerializeNoWitness(buf) + + msg.cachedSeralizedNoWitness = buf.Bytes() + } + + return chainhash.DoubleHashH(msg.cachedSeralizedNoWitness) } // WitnessHash generates the hash of the transaction serialized according to @@ -461,7 +473,14 @@ func (msg *MsgTx) Copy() *MsgTx { // See Deserialize for decoding transactions stored to disk, such as in a // database, as opposed to decoding transactions from the wire. func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error { - version, err := binarySerializer.Uint32(r, littleEndian) + // We'll use a tee reader in order to incrementally cache the raw + // non-witness serialization of this transaction. We'll then later + // cache this value as it allow to compute the TxHash more quickly, as + // we don't need to re-serialize the entire transaction. + var rawTxBuf bytes.Buffer + rawTxTeeReader := io.TeeReader(r, &rawTxBuf) + + version, err := binarySerializer.Uint32(rawTxTeeReader, littleEndian) if err != nil { return err } @@ -472,12 +491,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error return err } - // A count of zero (meaning no TxIn's to the uninitiated) means that the - // value is a TxFlagMarker, and hence indicates the presence of a flag. - var flag [1]TxFlag + // A count of zero (meaning no TxIn's to the uninitiated) indicates + // this is a transaction with witness data. Notice that we don't use + // the rawTxTeeReader here, as these are segwit specific bytes. + var ( + flag [1]byte + hasWitneess bool + ) if count == TxFlagMarker && enc == WitnessEncoding { - // The count varint was in fact the flag marker byte. Next, we need to - // read the flag value, which is a single byte. + // Next, we need to read the flag, which is a single byte. if _, err = io.ReadFull(r, flag[:]); err != nil { return err } @@ -495,6 +517,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error if err != nil { return err } + + hasWitneess = true + } + + // Write out the actual number of inputs as this won't be the very byte + // series after the versino of segwit transactions. + if WriteVarInt(&rawTxBuf, pver, count); err != nil { + str := fmt.Sprintf("unable to write txin count: %v", err) + return messageError("MsgTx.BtcDecode", str) } // Prevent more input transactions than could possibly fit into a @@ -545,7 +576,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error // and needs to be returned to the pool on error. ti := &txIns[i] msg.TxIn[i] = ti - err = readTxIn(r, pver, msg.Version, ti) + err = readTxIn(rawTxTeeReader, pver, msg.Version, ti) if err != nil { returnScriptBuffers() return err @@ -553,7 +584,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error totalScriptSize += uint64(len(ti.SignatureScript)) } - count, err = ReadVarInt(r, pver) + count, err = ReadVarInt(rawTxTeeReader, pver) if err != nil { returnScriptBuffers() return err @@ -578,7 +609,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error // and needs to be returned to the pool on error. to := &txOuts[i] msg.TxOut[i] = to - err = ReadTxOut(r, pver, msg.Version, to) + err = ReadTxOut(rawTxTeeReader, pver, msg.Version, to) if err != nil { returnScriptBuffers() return err @@ -588,7 +619,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error // If the transaction's flag byte isn't 0x00 at this point, then one or // more of its inputs has accompanying witness data. - if flag[0] != 0 && enc == WitnessEncoding { + if hasWitneess && enc == WitnessEncoding { for _, txin := range msg.TxIn { // For each input, the witness is encoded as a stack // with one or more items. Therefore, we first read a @@ -626,7 +657,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error } } - msg.LockTime, err = binarySerializer.Uint32(r, littleEndian) + msg.LockTime, err = binarySerializer.Uint32( + rawTxTeeReader, littleEndian, + ) if err != nil { returnScriptBuffers() return err @@ -700,6 +733,11 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error scriptPool.Return(pkScript) } + // Now that we've decoded the entire transaction without any issues, + // we'll cache the non-witness serialization so we can more quickly + // calculate the TxHash in the future. + msg.cachedSeralizedNoWitness = rawTxBuf.Bytes() + return nil } diff --git a/wire/msgtx_test.go b/wire/msgtx_test.go index 5ec753b62d..77a6986453 100644 --- a/wire/msgtx_test.go +++ b/wire/msgtx_test.go @@ -181,6 +181,13 @@ func TestTxHash(t *testing.T) { t.Errorf("TxHash: wrong hash - got %v, want %v", spew.Sprint(txHash), spew.Sprint(wantHash)) } + + // Compute it again to ensure any cached elements, are valid. + txHash = msgTx.TxHash() + if !txHash.IsEqual(wantHash) { + t.Errorf("TxHash: wrong hash - got %v, want %v", + spew.Sprint(txHash), spew.Sprint(wantHash)) + } } // TestTxSha tests the ability to generate the wtxid, and txid of a transaction @@ -258,6 +265,18 @@ func TestWTxSha(t *testing.T) { t.Errorf("WTxSha: wrong hash - got %v, want %v", spew.Sprint(wtxid), spew.Sprint(wantHashWTxid)) } + + // Compute the values again to ensure any cached elements are valid. + txid = msgTx.TxHash() + if !txid.IsEqual(wantHashTxid) { + t.Errorf("TxSha: wrong hash - got %v, want %v", + spew.Sprint(txid), spew.Sprint(wantHashTxid)) + } + wtxid = msgTx.WitnessHash() + if !wtxid.IsEqual(wantHashWTxid) { + t.Errorf("WTxSha: wrong hash - got %v, want %v", + spew.Sprint(wtxid), spew.Sprint(wantHashWTxid)) + } } // TestTxWire tests the MsgTx wire encode and decode for various numbers @@ -393,6 +412,23 @@ func TestTxWire(t *testing.T) { t.Errorf("BtcDecode #%d error %v", i, err) continue } + + // If this is the base encoding, then ensure that the cached + // serialization properly matches the raw encoding. + if test.enc == BaseEncoding { + if !bytes.Equal( + test.buf, msg.cachedSeralizedNoWitness, + ) { + t.Errorf("BtcdDecode #%d: cached encoding "+ + "is wrong, expected %x got %x", i, + test.buf, + msg.cachedSeralizedNoWitness) + continue + } + } + + msg.cachedSeralizedNoWitness = nil + if !reflect.DeepEqual(&msg, test.out) { t.Errorf("BtcDecode #%d\n got: %s want: %s", i, spew.Sdump(&msg), spew.Sdump(test.out)) @@ -539,6 +575,23 @@ func TestTxSerialize(t *testing.T) { t.Errorf("Deserialize #%d error %v", i, err) continue } + + // Ensure that the raw non-witness encoding matches the cached + // non-witness encoding bytes. + var b bytes.Buffer + if err := tx.SerializeNoWitness(&b); err != nil { + t.Errorf("Deserialize #%d: unable to encode: %v", i, err) + } + if !bytes.Equal(b.Bytes(), tx.cachedSeralizedNoWitness) { + t.Errorf("Deserialize #%d: cached encoding "+ + "is wrong, expected %x got %x", i, + b.Bytes(), + tx.cachedSeralizedNoWitness) + continue + } + + tx.cachedSeralizedNoWitness = nil + if !reflect.DeepEqual(&tx, test.out) { t.Errorf("Deserialize #%d\n got: %s want: %s", i, spew.Sdump(&tx), spew.Sdump(test.out))