diff --git a/rskj-core/src/main/java/co/rsk/trie/TrieDTO.java b/rskj-core/src/main/java/co/rsk/trie/TrieDTO.java index 64bfa97ba60..d7ca5fc3fdb 100644 --- a/rskj-core/src/main/java/co/rsk/trie/TrieDTO.java +++ b/rskj-core/src/main/java/co/rsk/trie/TrieDTO.java @@ -1,6 +1,6 @@ /* * This file is part of RskJ - * Copyright (C) 2017 RSK Labs Ltd. + * Copyright (C) 2023 RSK Labs Ltd. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by @@ -77,9 +77,6 @@ public class TrieDTO { private TrieDTO rightNode; private byte[] hash; - public TrieDTO() { - } - public static TrieDTO decodeFromMessage(byte[] src, TrieStore ds) { return decodeFromMessage(src, ds, false, null); } @@ -106,33 +103,10 @@ public static TrieDTO decodeFromMessage(byte[] src, TrieStore ds, boolean preloa result.path = pathTuple != null ? pathTuple.getValue() : null; //(*optional) 3.left - if present & !embedded => hash - - if (result.leftNodePresent && result.leftNodeEmbedded) { - result.leftNode = TrieDTO.decodeFromMessage(readChildEmbedded(srcWrap, decodeUint8(), Uint8.BYTES), ds, false, hash); - result.left = result.leftNode.getEncoded(); - result.leftHash = null; - encoder.write(encodeUint24(result.left.length)); - encoder.write(result.left); - } else if (result.leftNodePresent) { - byte[] valueHash = new byte[Keccak256Helper.DEFAULT_SIZE_BYTES]; - srcWrap.get(valueHash); - result.left = preloadChildren ? valueHash : null; - result.leftHash = valueHash; - } + handleLeft(result, srcWrap, encoder, ds, preloadChildren, hash); //(*optional) 3.right - if present & !embedded => hash - if (result.rightNodePresent && result.rightNodeEmbedded) { - result.rightNode = TrieDTO.decodeFromMessage(readChildEmbedded(srcWrap, decodeUint8(), Uint8.BYTES), ds, false, hash); - result.right = result.rightNode.getEncoded(); - result.rightHash = null; - encoder.write(encodeUint24(result.right.length)); - encoder.write(result.right); - } else if (result.rightNodePresent) { - byte[] valueHash = new byte[Keccak256Helper.DEFAULT_SIZE_BYTES]; - srcWrap.get(valueHash); - result.right = preloadChildren ? valueHash : null; - result.rightHash = valueHash; - } + handleRight(result, srcWrap, encoder, ds, preloadChildren, hash); result.childrenSize = new VarInt(0); if (result.leftNodePresent || result.rightNodePresent) { @@ -140,22 +114,7 @@ public static TrieDTO decodeFromMessage(byte[] src, TrieStore ds, boolean preloa result.childrenSize = readVarInt(srcWrap, encoder); } - if (result.hasLongVal) { - byte[] valueHashBytes = new byte[Keccak256Helper.DEFAULT_SIZE_BYTES]; - srcWrap.get(valueHashBytes); - byte[] lvalueBytes = new byte[Uint24.BYTES]; - srcWrap.get(lvalueBytes); - byte[] value = ds.retrieveValue(valueHashBytes); - encoder.write(value); - result.value = value; - } else { - int remaining = srcWrap.remaining(); - byte[] value = new byte[remaining]; - srcWrap.get(value); - //(*optional) 5.value - if !longValue => value - encoder.write(value); - result.value = value; - } + handleValue(result, srcWrap, encoder, ds); if (srcWrap.hasRemaining()) { throw new IllegalArgumentException("The srcWrap had more data than expected"); @@ -169,6 +128,55 @@ public static TrieDTO decodeFromMessage(byte[] src, TrieStore ds, boolean preloa return result; } + private static void handleLeft(TrieDTO result, ByteBuffer srcWrap, ByteArrayOutputStream encoder, TrieStore ds, boolean preloadChildren, byte[] hash) throws IOException { + if (result.leftNodePresent && result.leftNodeEmbedded) { + result.leftNode = TrieDTO.decodeFromMessage(readChildEmbedded(srcWrap, decodeUint8(), Uint8.BYTES), ds, false, hash); + result.left = result.leftNode.getEncoded(); + result.leftHash = null; + encoder.write(encodeUint24(result.left.length)); + encoder.write(result.left); + } else if (result.leftNodePresent) { + byte[] valueHash = new byte[Keccak256Helper.DEFAULT_SIZE_BYTES]; + srcWrap.get(valueHash); + result.left = preloadChildren ? valueHash : null; + result.leftHash = valueHash; + } + } + private static void handleRight(TrieDTO result, ByteBuffer srcWrap, ByteArrayOutputStream encoder, TrieStore ds, boolean preloadChildren, byte[] hash) throws IOException { + if (result.rightNodePresent && result.rightNodeEmbedded) { + result.rightNode = TrieDTO.decodeFromMessage(readChildEmbedded(srcWrap, decodeUint8(), Uint8.BYTES), ds, false, hash); + result.right = result.rightNode.getEncoded(); + result.rightHash = null; + encoder.write(encodeUint24(result.right.length)); + encoder.write(result.right); + } else if (result.rightNodePresent) { + byte[] valueHash = new byte[Keccak256Helper.DEFAULT_SIZE_BYTES]; + srcWrap.get(valueHash); + result.right = preloadChildren ? valueHash : null; + result.rightHash = valueHash; + } + + } + + private static void handleValue(TrieDTO result, ByteBuffer srcWrap, ByteArrayOutputStream encoder, TrieStore ds) throws IOException { + if (result.hasLongVal) { + byte[] valueHashBytes = new byte[Keccak256Helper.DEFAULT_SIZE_BYTES]; + srcWrap.get(valueHashBytes); + byte[] lvalueBytes = new byte[Uint24.BYTES]; + srcWrap.get(lvalueBytes); + byte[] value = ds.retrieveValue(valueHashBytes); + encoder.write(value); + result.value = value; + } else { + int remaining = srcWrap.remaining(); + byte[] value = new byte[remaining]; + srcWrap.get(value); + //(*optional) 5.value - if !longValue => value + encoder.write(value); + result.value = value; + } + + } public static TrieDTO decodeFromSync(byte[] src) { TrieDTO result = new TrieDTO(); try { @@ -215,7 +223,7 @@ public static TrieDTO decodeFromSync(byte[] src) { return result; } - private static byte[] readChildEmbedded(ByteBuffer srcWrap, Function decode, int uintBytes) throws IOException { + private static byte[] readChildEmbedded(ByteBuffer srcWrap, Function decode, int uintBytes) { byte[] lengthBytes = new byte[uintBytes]; srcWrap.get(lengthBytes); byte[] serializedNode = decode.apply(lengthBytes); @@ -238,9 +246,9 @@ public VarInt getChildrenSize() { public long getSize() { long externalValueLength = this.hasLongVal ? this.value.length : 0L; - final long left = getLeftSize(); - final long right = getRightSize(); - return externalValueLength + this.source.length + left + right; + final long leftSize = getLeftSize(); + final long rightSize = getRightSize(); + return externalValueLength + this.source.length + leftSize + rightSize; } public long getLeftSize() { @@ -299,10 +307,6 @@ public byte[] getValue() { } public boolean isTerminal() { - // old impl: -// return (!this.leftNodePresent && !this.rightNodePresent) || -// !((this.leftNodePresent && !this.leftNodeEmbedded) || -// (this.rightNodePresent && !this.rightNodeEmbedded)); if (!this.leftNodePresent && !this.rightNodePresent) { return true; } @@ -336,6 +340,7 @@ public TrieDTO getLeftNode() { public TrieDTO getRightNode() { return rightNode; } + public void setLeftHash(byte[] hash) { this.leftHash = hash; } @@ -343,6 +348,7 @@ public void setLeftHash(byte[] hash) { public void setRightHash(byte[] hash) { this.rightHash = hash; } + public boolean isLeftNodePresent() { return leftNodePresent; } @@ -362,6 +368,7 @@ public boolean isRightNodeEmbedded() { public boolean isSharedPrefixPresent() { return sharedPrefixPresent; } + public byte[] getPath() { return this.path; } @@ -373,6 +380,7 @@ public Integer getPathLength() { public boolean isHasLongVal() { return hasLongVal; } + @Override public String toString() { return "Node{" + HexUtils.toJsonHex(this.path) + "}:" + this.childrenSize.value; @@ -432,6 +440,24 @@ public byte[] toMessage() { if (this.sharedPrefixPresent) { SharedPathSerializer.serializeBytes(buffer, this.pathLength, this.path); } + + toMessageHandleLeftNode(buffer); + toMessageHandleRightNode(buffer); + + if (leftNodePresent || rightNodePresent) { + buffer.put(childrenSize.encode()); + } + if (hasLongVal) { + byte[] valueHash = new Keccak256(Keccak256Helper.keccak256(getValue())).getBytes(); + buffer.put(valueHash); + buffer.put(encodeUint24(value.length)); + } else if (this.getValue().length > 0) { + buffer.put(this.getValue()); + } + return buffer.array(); + } + + private void toMessageHandleLeftNode(ByteBuffer buffer){ if (leftNodePresent) { if (leftNodeEmbedded) { buffer.put(encodeUint8(this.left.length)); @@ -440,6 +466,9 @@ public byte[] toMessage() { buffer.put(this.leftHash); } } + } + + private void toMessageHandleRightNode(ByteBuffer buffer){ if (rightNodePresent) { if (rightNodeEmbedded) { buffer.put(encodeUint8(this.right.length)); @@ -448,19 +477,7 @@ public byte[] toMessage() { buffer.put(this.rightHash); } } - if (leftNodePresent || rightNodePresent) { - buffer.put(childrenSize.encode()); - } - if (hasLongVal) { - byte[] valueHash = new Keccak256(Keccak256Helper.keccak256(getValue())).getBytes(); - buffer.put(valueHash); - buffer.put(encodeUint24(value.length)); - } else if (this.getValue().length > 0) { - buffer.put(this.getValue()); - } - return buffer.array(); } - public int serializedLength(boolean isPresent, boolean isEmbeddable, byte[] value) { if (isPresent) { if (isEmbeddable) { @@ -474,8 +491,12 @@ public int serializedLength(boolean isPresent, boolean isEmbeddable, byte[] valu @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o){ + return true; + } + if (o == null || getClass() != o.getClass()){ + return false; + } TrieDTO trieDTO = (TrieDTO) o; return hasLongVal == trieDTO.hasLongVal && sharedPrefixPresent == trieDTO.sharedPrefixPresent diff --git a/rskj-core/src/main/java/co/rsk/trie/TrieDTOInOrderIterator.java b/rskj-core/src/main/java/co/rsk/trie/TrieDTOInOrderIterator.java index 98529880016..92113216a3d 100644 --- a/rskj-core/src/main/java/co/rsk/trie/TrieDTOInOrderIterator.java +++ b/rskj-core/src/main/java/co/rsk/trie/TrieDTOInOrderIterator.java @@ -53,7 +53,7 @@ private TrieDTO findByChildrenSize(long offset, TrieDTO nodeDTO, Deque // TODO poner los nodos padres intermedios en el stack, tenemos que serializarlos para poder validar el chunk completo. if (!nodeDTO.isTerminal()) { - if (nodeDTO.isLeftNodePresent() && !nodeDTO.isLeftNodeEmbedded()) { + if (isLeftNotEmbedded(nodeDTO)){ TrieDTO left = getNode(nodeDTO.getLeftHash()); if (left == null) { @@ -99,6 +99,10 @@ private TrieDTO findByChildrenSize(long offset, TrieDTO nodeDTO, Deque } } + private boolean isLeftNotEmbedded(TrieDTO nodeDTO){ + return nodeDTO.isLeftNodePresent() && !nodeDTO.isLeftNodeEmbedded(); + } + private TrieDTO pushAndReturn(TrieDTO nodeDTO, Deque visiting, long offset) { this.from -= offset; visiting.push(nodeDTO); diff --git a/rskj-core/src/test/java/co/rsk/trie/TrieDTOTest.java b/rskj-core/src/test/java/co/rsk/trie/TrieDTOTest.java new file mode 100644 index 00000000000..d8969693488 --- /dev/null +++ b/rskj-core/src/test/java/co/rsk/trie/TrieDTOTest.java @@ -0,0 +1,181 @@ +/* + * This file is part of RskJ + * Copyright (C) 2023 RSK Labs Ltd. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package co.rsk.trie; + +import co.rsk.crypto.Keccak256; +import org.ethereum.TestUtils; +import org.ethereum.datasource.HashMapDB; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Optional; + +import static org.bouncycastle.util.encoders.Hex.decode; +import static org.junit.jupiter.api.Assertions.*; + +@ExtendWith(MockitoExtension.class) +class TrieDTOTest { + + + private HashMapDB map; + private TrieStore trieStore; + + @BeforeEach + void setUp() { + this.map = new HashMapDB(); + this.trieStore = new TrieStoreImpl(map); + } + + @Test + void testDecodeDto() { + Trie trie = new Trie(trieStore) + .put("foo", "bar".getBytes()); + Keccak256 hash = trie.getHash(); + trieStore.save(trie); + + Optional optTrieDTO = trieStore.retrieveDTO(hash.getBytes()); + + assertTrue(optTrieDTO.isPresent()); + TrieDTO trieDTO = optTrieDTO.get(); + assertArrayEquals(trie.getValue(), trieDTO.getValue()); + String trieDtoDescription = trieDTO.toDescription(); + assertNotNull(trieDtoDescription); + assertNotNull(trieDTO.toString()); + } + + @Test + void testDecodeFromMessage() { + Trie trie = new Trie(trieStore).put("foo", "bar".getBytes()); + trieStore.save(trie); + byte[] message = trie.toMessage(); + TrieDTO decodedTrieDTO = TrieDTO.decodeFromMessage(message, trieStore); + + assertArrayEquals(trie.getValue(), decodedTrieDTO.getValue()); + } + + @Test + void testGetSideHash() { + Trie trie = buildTestTrie(); + TrieDTO trieDTO = TrieDTO.decodeFromMessage(trie.toMessage(), trieStore, true, null); + byte[] leftHash = trieDTO.getLeftHash(); + byte[] rightHash = trieDTO.getRightHash(); + assertEquals(trie.getLeft().getHash().get(), new Keccak256(leftHash)); + assertEquals(trie.getRight().getHash().get(), new Keccak256(rightHash)); + } + + + @Test + void testMessageDecoding() { + Trie trie = new Trie(trieStore).put("foo", "bar".getBytes()); + Keccak256 hash = trie.getHash(); + trieStore.save(trie); + byte[] message = trie.toMessage(); + TrieDTO decodedTrieDTO = TrieDTO.decodeFromMessage(message, trieStore); + TrieDTO retrievedDto = trieStore.retrieveDTO(hash.getBytes()).get(); + + assertEquals(decodedTrieDTO, retrievedDto); + assertNotEquals(decodedTrieDTO.hashCode(), retrievedDto.hashCode()); + } + + @Test + void testMessageEncoding() { + Trie trie = new Trie(trieStore) + .put("foo", "bar".getBytes()) + .put("abc", "bc".getBytes()) + .put("def", "ef".getBytes()); + Keccak256 hash = trie.getHash(); + trieStore.save(trie); + TrieDTO retrievedDto = trieStore.retrieveDTO(hash.getBytes()).get(); + byte[] message = retrievedDto.toMessage(); + TrieDTO decodedTrieDTO = TrieDTO.decodeFromMessage(message, trieStore); + assertEquals(retrievedDto, decodedTrieDTO); + } + + @Test + void retrieveWithEmbedded() { + Trie trie = new Trie(trieStore) + .put("bar", "foo".getBytes()) + .put("foo", "bar".getBytes()); + + trieStore.save(trie); + TrieDTO trieDTO = TrieDTO.decodeFromMessage(trie.toMessage(), trieStore, true, null); + assertNotNull(trieDTO); + assertTrue(trieDTO.isTerminal()); + assertTrue(trieDTO.isLeftNodeEmbedded()); + assertTrue(trieDTO.isRightNodeEmbedded()); + + assertTrue(trieDTO.isLeftNodePresent()); + assertTrue(trieDTO.isRightNodePresent()); + + TrieDTO rightNode = trieDTO.getRightNode(); + assertArrayEquals("bar".getBytes(), rightNode.getValue()); + TrieDTO leftNode = trieDTO.getLeftNode(); + assertArrayEquals("foo".getBytes(), leftNode.getValue()); + assertTrue(trieDTO.isSharedPrefixPresent()); + } + + @Test + void testBasicSetters() { + Trie trie = new Trie(trieStore) + .put("foo", "bar".getBytes()); + + TrieDTO trieDTO = TrieDTO.decodeFromMessage(trie.toMessage(), trieStore); + byte[] lBytes = TestUtils.generateBytes("left", 32); + trieDTO.setLeft(lBytes); + assertArrayEquals(lBytes, trieDTO.getLeft()); + byte[] rBytes = TestUtils.generateBytes("right", 32); + trieDTO.setRight(rBytes); + assertArrayEquals(rBytes, trieDTO.getRight()); + + + + trie = buildTestTrie(); + trieDTO = TrieDTO.decodeFromMessage(trie.toMessage(), trieStore, true, null); + byte[] leftHash = TestUtils.generateBytes("leftHash", 32); + trieDTO.setLeftHash(leftHash); + assertEquals(leftHash, trieDTO.getLeftHash()); + byte[] rightHash = TestUtils.generateBytes("rightHash", 32); + trieDTO.setRightHash(rightHash); + assertEquals(rightHash, trieDTO.getRightHash()); + } + + @Test + void testLongValue() { + Trie trie = new Trie(trieStore) + .put("foo", TrieValueTest.makeValue(200)); + Keccak256 hash = trie.getHash(); + trieStore.save(trie); + + TrieDTO trieDTO = trieStore.retrieveDTO(hash.getBytes()).get(); + assertTrue(trieDTO.isHasLongVal()); + } + + + private Trie buildTestTrie() { + Trie trie = new Trie(); + trie = trie.put(decode("0a"), new byte[]{0x06}); + trie = trie.put(decode("0a00"), new byte[]{0x02}); + trie = trie.put(decode("0a80"), new byte[]{0x07}); + trie = trie.put(decode("0a0000"), new byte[]{0x01}); + trie = trie.put(decode("0a8 080"), new byte[]{0x08}); + return trie; + } +} \ No newline at end of file