Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(merkle_tree): initial batch membership proof implementation #740

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions merkle_tree/src/batch_proof.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
use crate::{MerkleTreeError, MerkleTreeProof, NodeValue};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use tagged::Tagged;

/// A compact batch proof for multiple elements in a Merkle tree.
/// Optimizes proof size through:
/// 1. Identifying and storing shared nodes between proofs
/// 2. Storing only differing nodes for individual proofs
/// 3. Efficient node reuse during verification
///
/// # Example
/// ```ignore
/// let proofs = vec![proof1, proof2, proof3];
/// let batch_proof = CompactBatchProof::new(proofs)?;
///
/// // Verify all proofs
/// tree.batch_verify(positions, elements, &batch_proof)?;
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
#[tagged("COMPACT_BATCH_PROOF")]
pub struct CompactBatchProof<T: NodeValue> {
/// Shared nodes at each tree level that are common across multiple proofs.
/// Vector index corresponds to the tree level.
shared_nodes: Vec<Vec<T>>,

/// Individual nodes for each proof.
/// Contains only nodes that differ from shared nodes.
/// Structure: [proof][level][nodes at level]
individual_proofs: Vec<Vec<Vec<T>>>,

/// Height of the Merkle tree
height: usize,
}

impl<T: NodeValue> CompactBatchProof<T> {
/// Creates a new compact batch proof from a set of individual proofs.
///
/// # Arguments
/// * `proofs` - Vector of individual MerkleTreeProof instances
///
/// # Errors
/// * `ParametersError` - if input proofs are empty or have inconsistent heights
///
/// # Example
/// ```ignore
/// let proofs = vec![proof1, proof2, proof3];
/// let batch_proof = CompactBatchProof::new(proofs)?;
/// ```
pub fn new(proofs: Vec<MerkleTreeProof<T>>) -> Result<Self, MerkleTreeError> {
// Check for empty vector
if proofs.is_empty() {
return Err(MerkleTreeError::ParametersError(
"Empty proofs vector".to_string(),
));
}

let height = proofs[0].height();

// Check height consistency
if proofs.iter().any(|p| p.height() != height) {
return Err(MerkleTreeError::ParametersError(
"Inconsistent proof heights".to_string(),
));
}

let mut shared = vec![vec![]; height];
let mut individual = Vec::with_capacity(proofs.len());

// Convert proofs to the required format
let proof_values: Vec<Vec<Vec<T>>> = proofs.iter()
.map(|p| p.path_values().to_vec())
.collect();

// Find shared nodes at each level
for level in 0..height {
let mut level_nodes = proof_values[0][level].clone();

// Compare nodes from all proofs at current level
for proof in proof_values.iter().skip(1) {
for (idx, node) in proof[level].iter().enumerate() {
if level_nodes[idx] != *node {
level_nodes[idx] = T::default();
}
}
}
shared[level] = level_nodes;
}

// Create individual proofs with only differing nodes
for proof in proof_values {
let mut indiv = Vec::with_capacity(height);
for (level, nodes) in proof.iter().enumerate() {
let mut level_nodes = vec![T::default(); nodes.len()];
for (idx, node) in nodes.iter().enumerate() {
if *node != shared[level][idx] {
level_nodes[idx] = *node;
}
}
indiv.push(level_nodes);
}
individual.push(indiv);
}

Ok(Self {
shared_nodes: shared,
individual_proofs: individual,
height,
})
}

/// Returns shared nodes common to all proofs
pub fn get_shared_nodes(&self) -> &[Vec<T>] {
&self.shared_nodes
}

/// Returns individual parts of the proofs
pub fn get_individual_proofs(&self) -> &[Vec<Vec<T>>] {
&self.individual_proofs
}

/// Returns the height of the Merkle tree
pub fn height(&self) -> usize {
self.height
}

/// Reconstructs a complete proof for a specific index
///
/// # Arguments
/// * `proof_idx` - Index of the proof in the batch
///
/// # Errors
/// * `ParametersError` - if index is out of bounds
pub fn get_proof(&self, proof_idx: usize) -> Result<MerkleTreeProof<T>, MerkleTreeError> {
if proof_idx >= self.individual_proofs.len() {
return Err(MerkleTreeError::ParametersError(
"Proof index out of bounds".to_string(),
));
}

let mut proof_path = Vec::with_capacity(self.height);

for level in 0..self.height {
let mut level_nodes = self.shared_nodes[level].clone();
let indiv_nodes = &self.individual_proofs[proof_idx][level];

for (idx, node) in indiv_nodes.iter().enumerate() {
if !node.is_empty() {
level_nodes[idx] = *node;
}
}
proof_path.push(level_nodes);
}

Ok(MerkleTreeProof(proof_path))
}
}

#[cfg(test)]
mod tests {
use super::*;
use ark_ed25519::Fr;

fn create_test_proof(height: usize, value: u8) -> MerkleTreeProof<Fr> {
let nodes = (0..height)
.map(|_| vec![Fr::from(value as u64); 2])
.collect();
MerkleTreeProof(nodes)
}

#[test]
fn test_new_batch_proof() {
let proofs = vec![
create_test_proof(3, 1),
create_test_proof(3, 2),
create_test_proof(3, 1),
];

let batch_proof = CompactBatchProof::new(proofs).unwrap();
assert_eq!(batch_proof.height(), 3);
assert_eq!(batch_proof.get_shared_nodes().len(), 3);
assert_eq!(batch_proof.get_individual_proofs().len(), 3);
}

#[test]
fn test_empty_proofs() {
let proofs: Vec<MerkleTreeProof<Fr>> = vec![];
assert!(CompactBatchProof::new(proofs).is_err());
}

#[test]
fn test_inconsistent_heights() {
let proofs = vec![
create_test_proof(3, 1),
create_test_proof(2, 2),
];
assert!(CompactBatchProof::new(proofs).is_err());
}

#[test]
fn test_get_proof() {
let proofs = vec![
create_test_proof(3, 1),
create_test_proof(3, 2),
create_test_proof(3, 1),
];

let batch_proof = CompactBatchProof::new(proofs.clone()).unwrap();

// Test that we can recover original proofs
let recovered = batch_proof.get_proof(0).unwrap();
assert_eq!(recovered, proofs[0]);

let recovered = batch_proof.get_proof(1).unwrap();
assert_eq!(recovered, proofs[1]);

// Test error on invalid index
assert!(batch_proof.get_proof(5).is_err());
}

#[test]
fn test_compression_efficiency() {
let proofs = vec![
create_test_proof(3, 1),
create_test_proof(3, 1), // Identical proof
create_test_proof(3, 2),
];

let batch_proof = CompactBatchProof::new(proofs).unwrap();

// Check that identical proofs don't duplicate nodes
let individual = batch_proof.get_individual_proofs();
assert!(individual[0].iter().all(|level| level.iter().all(|node| node.is_empty())));
assert!(individual[1].iter().all(|level| level.iter().all(|node| node.is_empty())));
}
}
12 changes: 8 additions & 4 deletions merkle_tree/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub mod macros;
pub mod universal_merkle_tree;

pub(crate) mod internal;
pub mod merkle_proof;
pub mod batch_proof;

pub mod prelude;
pub use crate::errors::MerkleTreeError;
Expand Down Expand Up @@ -130,7 +132,7 @@ where
fn digest_leaf(pos: &I, elem: &E) -> Result<T, MerkleTreeError>;
}

/// An trait for Merkle tree index type.
/// A trait for Merkle tree index type.
pub trait ToTraversalPath<const ARITY: usize> {
/// Convert the given index to a vector of branch indices given tree height
/// and ARITY.
Expand Down Expand Up @@ -170,8 +172,8 @@ pub trait MerkleProof<T: NodeValue>:
}

/// Basic functionalities for a merkle tree implementation. Abstracted as an
/// accumulator for fixed-length array. Supports generate membership proof at a
/// given position and verify a membership proof.
/// accumulator for fixed-length array. Supports generating membership proof at
/// a given position and verify a membership proof.
pub trait MerkleTreeScheme: Sized {
/// Merkle tree element type
type Element: Element;
Expand Down Expand Up @@ -231,7 +233,7 @@ pub trait MerkleTreeScheme: Sized {
// ) -> Result<(), MerkleTreeError>;

/// Return an iterator that iterates through all element that are not
/// forgetton
/// forgotten
fn iter(&self) -> MerkleTreeIter<Self::Element, Self::Index, Self::NodeValue>;
}

Expand Down Expand Up @@ -439,3 +441,5 @@ pub trait PersistentUniversalMerkleTreeScheme: UniversalMerkleTreeScheme {
where
F: FnOnce(Option<&Self::Element>) -> Option<Self::Element>;
}

pub use batch_proof::CompactBatchProof;
61 changes: 59 additions & 2 deletions merkle_tree/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ macro_rules! impl_merkle_tree_scheme {
type Index = I;
type NodeValue = T;
type MembershipProof = MerkleTreeProof<T>;
// TODO(Chengyu): implement batch membership proof
type BatchMembershipProof = ();
type BatchMembershipProof = CompactBatchProof<T>;
type Commitment = T;

const ARITY: usize = ARITY;
Expand Down Expand Up @@ -90,6 +89,64 @@ macro_rules! impl_merkle_tree_scheme {
fn iter(&self) -> MerkleTreeIter<E, I, T> {
MerkleTreeIter::new(&self.root)
}

fn batch_lookup(
&self,
positions: impl IntoIterator<Item = impl Borrow<Self::Index>>,
) -> Vec<LookupResult<&Self::Element, Self::MembershipProof, ()>> {
positions
.into_iter()
.map(|pos| self.lookup(pos))
.collect()
}

fn batch_verify(
commitment: impl Borrow<Self::Commitment>,
positions: impl IntoIterator<Item = impl Borrow<Self::Index>>,
elements: impl IntoIterator<Item = impl Borrow<Self::Element>>,
proof: impl Borrow<Self::BatchMembershipProof>,
) -> Result<Vec<VerificationResult>, MerkleTreeError> {
let commitment = commitment.borrow();
let proof = proof.borrow();
let mut results = Vec::new();

// Get shared path information
let shared_nodes = proof.get_shared_nodes();
let individual_proofs = proof.get_individual_proofs();

// Verify each element using shared path information
for ((pos, elem), indiv_proof) in positions
.into_iter()
.zip(elements.into_iter())
.zip(individual_proofs.iter())
{
let mut proof_path = Vec::new();
// Combine shared nodes with individual proof nodes
for (level, shared) in shared_nodes.iter().enumerate() {
let mut level_nodes = shared[level].clone();
if let Some(indiv_nodes) = indiv_proof.get(level) {
// Merge individual nodes with shared nodes
for (idx, node) in indiv_nodes.iter().enumerate() {
if !node.is_empty() {
level_nodes[idx] = node.clone();
}
}
}
proof_path.push(level_nodes);
}

// Verify individual proof
let result = Self::verify(
commitment,
pos,
elem,
&MerkleTreeProof(proof_path),
)?;
results.push(result);
}

Ok(results)
}
}

impl<'a, E, H, I, const ARITY: usize, T> IntoIterator for &'a $name<E, H, I, ARITY, T>
Expand Down
Loading