Skip to main content

tlsn_core/
merkle.rs

1//! Merkle tree types.
2
3use std::collections::HashSet;
4
5use serde::{Deserialize, Serialize};
6
7use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
8
9/// Merkle tree error.
10#[derive(Debug, thiserror::Error)]
11#[error("merkle error: {0}")]
12pub struct MerkleError(String);
13
14impl MerkleError {
15    fn new(msg: impl Into<String>) -> Self {
16        Self(msg.into())
17    }
18}
19
/// Merkle proof.
///
/// Produced by `MerkleTree::proof` and checked with `MerkleProof::verify`.
/// This type derives `Serialize`/`Deserialize`, so renaming or reordering its
/// fields may change its serialized format.
#[derive(Clone, Serialize, Deserialize)]
pub struct MerkleProof {
    /// Identifier of the hash algorithm of the tree the proof was created from.
    alg: HashAlgId,
    /// Number of leaves in the tree at the time the proof was created.
    leaf_count: usize,
    /// Underlying `rs_merkle` proof for one or more leaves.
    proof: rs_merkle::MerkleProof<Hash>,
}

// `Debug` comes from the `opaque_debug` crate instead of being derived;
// presumably it prints only the type name so proof contents stay hidden — confirm.
opaque_debug::implement!(MerkleProof);
29
30impl MerkleProof {
31    /// Checks if the counts of indices, hashes, and leaves are valid for the
32    /// provided root.
33    ///
34    /// # Panics
35    ///
36    /// - If the length of `leaf_indices` and `leaf_hashes` does not match.
37    /// - If `leaf_indices` contains duplicates.
38    pub fn verify(
39        &self,
40        hasher: &dyn HashAlgorithm,
41        root: &TypedHash,
42        leaves: impl IntoIterator<Item = (usize, Hash)>,
43    ) -> Result<(), MerkleError> {
44        let mut leaves = leaves.into_iter().collect::<Vec<_>>();
45
46        // Sort by index
47        leaves.sort_by_key(|(idx, _)| *idx);
48
49        let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
50
51        if indices.len() != indices.iter().collect::<HashSet<_>>().len() {
52            return Err(MerkleError::new("duplicate leaf indices provided"));
53        }
54
55        if !self.proof.verify(
56            &RsMerkleHasher(hasher),
57            root.value,
58            &indices,
59            &leaves,
60            self.leaf_count,
61        ) {
62            return Err(MerkleError::new("invalid merkle proof"));
63        }
64
65        Ok(())
66    }
67}
68
69#[derive(Clone)]
70struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
71
72impl rs_merkle::Hasher for RsMerkleHasher<'_> {
73    type Hash = Hash;
74
75    fn hash(&self, data: &[u8]) -> Hash {
76        self.0.hash(data)
77    }
78}
79
/// Merkle tree.
///
/// Wraps an `rs_merkle` tree together with the identifier of the hash
/// algorithm used to build it. This type derives `Serialize`/`Deserialize`,
/// so renaming or reordering its fields may change its serialized format.
#[derive(Clone, Serialize, Deserialize)]
pub struct MerkleTree {
    /// Hash algorithm the tree was created with; `insert` asserts that the
    /// provided hasher matches it.
    alg: HashAlgId,
    /// Underlying `rs_merkle` tree holding the leaves.
    tree: rs_merkle::MerkleTree<Hash>,
}
86
87impl MerkleTree {
88    /// Creates a new Merkle tree.
89    pub fn new(alg: HashAlgId) -> Self {
90        Self {
91            alg,
92            tree: Default::default(),
93        }
94    }
95
96    /// Returns the hash algorithm used to create the tree.
97    pub fn algorithm(&self) -> HashAlgId {
98        self.alg
99    }
100
101    /// Returns the root of the tree.
102    pub fn root(&self) -> TypedHash {
103        TypedHash {
104            alg: self.alg,
105            value: self.tree.root().expect("tree should not be empty"),
106        }
107    }
108
109    /// Inserts leaves into the tree.
110    ///
111    /// # Panics
112    ///
113    /// - If the provided hasher is not the same as the one used to create the
114    ///   tree.
115    pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
116        assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
117
118        self.tree.append(&mut leaves);
119        self.tree.commit(&RsMerkleHasher(hasher))
120    }
121
122    /// Returns a Merkle proof for the provided indices.
123    ///
124    /// # Panics
125    ///
126    /// - If the provided indices are not unique and sorted.
127    /// - If the provided indices are out of bounds.
128    pub fn proof(&self, indices: &[usize]) -> MerkleProof {
129        assert!(
130            indices.windows(2).all(|w| w[0] < w[1]),
131            "indices must be unique and sorted"
132        );
133
134        assert!(
135            *indices.last().unwrap() < self.tree.leaves_len(),
136            "one or more provided indices are out of bounds"
137        );
138
139        MerkleProof {
140            alg: self.alg,
141            leaf_count: self.tree.leaves_len(),
142            proof: self.tree.proof(indices),
143        }
144    }
145}
146
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    /// Test leaf value; it is hashed via its big-endian byte encoding.
    struct T(u64);

    /// Hashes the big-endian bytes of each value with `hasher`.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        leaves
            .into_iter()
            .map(|x| hasher.hash(&x.0.to_be_bytes()))
            .collect()
    }

    /// Pairs each requested index with the corresponding leaf hash.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices.into_iter().map(|i| (i, leaves[i])).collect()
    }

    /// Builds a tree over `T(0)..=T(4)` and returns it with the inserted leaf
    /// hashes.
    fn five_leaf_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let hashes = leaves(hasher, [T(0), T(1), T(2), T(3), T(4)]);
        let mut tree = MerkleTree::new(hasher.id());
        tree.insert(hasher, hashes.clone());
        (tree, hashes)
    }

    // Expect Merkle proof verification to succeed.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes))
            .is_ok());
    }

    // `verify` sorts leaves by index, so callers may pass them in any order.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success_unordered_leaves<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([4, 2, 3], &hashes))
            .is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].1 = hashes[0];

        // Fail because the leaf is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_root<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let mut other = MerkleTree::new(hasher.id());
        other.insert(&hasher, leaves(&hasher, [T(5), T(6), T(7), T(8), T(9)]));

        // Fail because the proof belongs to a different tree.
        assert!(proof
            .verify(&hasher, &other.root(), choose_leaves([2, 3, 4], &hashes))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_empty_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &hashes))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let mut proof1 = tree.proof(&[2, 3, 4]);
        let mut proof2 = proof1.clone();

        // Increment the `leaf_count` field.
        proof1.leaf_count += 1;
        // Decrement the `leaf_count` field.
        proof2.leaf_count -= 1;

        // Fail because leaf count is wrong.
        assert!(proof1
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes))
            .is_err());

        assert!(proof2
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].0 = 1;

        // Fail because leaf index is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Trying to verify fewer leaves than were included in the proof.
        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3], &hashes))
            .is_err());
    }
}