tlsn_core/
merkle.rs

1//! Merkle tree types.
2
3use serde::{Deserialize, Serialize};
4use utils::iter::DuplicateCheck;
5
6use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
7
8/// Errors that can occur during operations with Merkle tree and Merkle proof.
9#[derive(Debug, thiserror::Error)]
10#[error("merkle error: {0}")]
11pub(crate) struct MerkleError(String);
12
13impl MerkleError {
14    fn new(msg: impl Into<String>) -> Self {
15        Self(msg.into())
16    }
17}
18
/// A Merkle inclusion proof for a set of leaves, as produced by
/// [`MerkleTree::proof`].
///
/// NOTE: field order is part of the serialized representation; do not
/// reorder fields without considering wire compatibility.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct MerkleProof {
    /// Hash algorithm of the tree the proof was created from.
    alg: HashAlgId,
    /// Total number of leaves in the tree at the time the proof was created.
    leaf_count: usize,
    /// The underlying `rs_merkle` proof (sibling hashes).
    proof: rs_merkle::MerkleProof<Hash>,
}

// `Debug` is implemented via `opaque_debug`, so the proof's contents are not
// printed (presumably to keep them out of logs).
opaque_debug::implement!(MerkleProof);
27
28impl MerkleProof {
29    /// Checks if the counts of indices, hashes, and leaves are valid for the
30    /// provided root.
31    ///
32    /// # Panics
33    ///
34    /// - If the length of `leaf_indices` and `leaf_hashes` does not match.
35    /// - If `leaf_indices` contains duplicates.
36    pub(crate) fn verify(
37        &self,
38        hasher: &dyn HashAlgorithm,
39        root: &TypedHash,
40        leaves: impl IntoIterator<Item = (usize, Hash)>,
41    ) -> Result<(), MerkleError> {
42        let mut leaves = leaves.into_iter().collect::<Vec<_>>();
43
44        // Sort by index
45        leaves.sort_by_key(|(idx, _)| *idx);
46
47        let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
48
49        if indices.iter().contains_dups() {
50            return Err(MerkleError::new("duplicate leaf indices provided"));
51        }
52
53        if !self.proof.verify(
54            &RsMerkleHasher(hasher),
55            root.value,
56            &indices,
57            &leaves,
58            self.leaf_count,
59        ) {
60            return Err(MerkleError::new("invalid merkle proof"));
61        }
62
63        Ok(())
64    }
65
66    /// Returns the leaf count of the Merkle tree associated with the proof.
67    pub(crate) fn leaf_count(&self) -> usize {
68        self.leaf_count
69    }
70}
71
72#[derive(Clone)]
73struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
74
75impl rs_merkle::Hasher for RsMerkleHasher<'_> {
76    type Hash = Hash;
77
78    fn hash(&self, data: &[u8]) -> Hash {
79        self.0.hash(data)
80    }
81}
82
/// A Merkle tree over [`Hash`] leaves, built with the hash algorithm
/// identified by `alg`.
///
/// NOTE: field order is part of the serialized representation; do not
/// reorder fields without considering wire compatibility.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct MerkleTree {
    /// Hash algorithm used to build the tree.
    alg: HashAlgId,
    /// The underlying `rs_merkle` tree.
    tree: rs_merkle::MerkleTree<Hash>,
}
88
89impl MerkleTree {
90    pub(crate) fn new(alg: HashAlgId) -> Self {
91        Self {
92            alg,
93            tree: Default::default(),
94        }
95    }
96
97    pub(crate) fn algorithm(&self) -> HashAlgId {
98        self.alg
99    }
100
101    pub(crate) fn root(&self) -> TypedHash {
102        TypedHash {
103            alg: self.alg,
104            value: self.tree.root().expect("tree should not be empty"),
105        }
106    }
107
108    /// Inserts leaves into the tree.
109    ///
110    /// # Panics
111    ///
112    /// - If the provided hasher is not the same as the one used to create the
113    ///   tree.
114    pub(crate) fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
115        assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
116
117        self.tree.append(&mut leaves);
118        self.tree.commit(&RsMerkleHasher(hasher))
119    }
120
121    /// Returns a Merkle proof for the provided indices.
122    ///
123    /// # Panics
124    ///
125    /// - If the provided indices are not unique and sorted.
126    /// - If the provided indices are out of bounds.
127    pub(crate) fn proof(&self, indices: &[usize]) -> MerkleProof {
128        assert!(
129            indices.windows(2).all(|w| w[0] < w[1]),
130            "indices must be unique and sorted"
131        );
132
133        assert!(
134            *indices.last().unwrap() < self.tree.leaves_len(),
135            "one or more provided indices are out of bounds"
136        );
137
138        MerkleProof {
139            alg: self.alg,
140            leaf_count: self.tree.leaves_len(),
141            proof: self.tree.proof(indices),
142        }
143    }
144}
145
#[cfg(test)]
mod test {
    use crate::hash::{impl_domain_separator, Blake3, HashAlgorithmExt, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    #[derive(Serialize)]
    struct T(u64);

    impl_domain_separator!(T);

    /// Hashes each item with `hasher` to produce a list of leaves.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        leaves
            .into_iter()
            .map(|x| hasher.hash_canonical(&x))
            .collect()
    }

    /// Pairs each index with its leaf, in the form expected by
    /// `MerkleProof::verify`.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices.into_iter().map(|i| (i, leaves[i])).collect()
    }

    /// Builds a tree containing five leaves and returns it together with the
    /// leaves that were inserted.
    fn five_leaf_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let hashes = leaves(hasher, [T(0), T(1), T(2), T(3), T(4)]);

        let mut tree = MerkleTree::new(hasher.id());
        tree.insert(hasher, hashes.clone());

        (tree, hashes)
    }

    // Expect Merkle proof verification to succeed.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_ok());
    }

    // Expect verification to succeed regardless of the order of the leaves.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success_unordered<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([4, 2, 3], &leaves))
            .is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].1 = leaves[0];

        // Fail because the leaf is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[test]
    #[should_panic(expected = "hash algorithm mismatch")]
    fn test_insert_fail_hasher_mismatch() {
        let sha256 = Sha256::default();
        let blake3 = Blake3::default();

        let mut tree = MerkleTree::new(sha256.id());

        // Fail because the tree was created for a different algorithm.
        tree.insert(&blake3, leaves(&blake3, [T(0)]));
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "indices must be unique and sorted")]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "one or more provided indices are out of bounds")]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        // Index 5 is the first index past the end of a five-leaf tree.
        _ = tree.proof(&[2, 3, 4, 5]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "indices must be unique and sorted")]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &leaves))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let mut proof1 = tree.proof(&[2, 3, 4]);
        let mut proof2 = proof1.clone();

        // Increment the `leaf_count` field.
        proof1.leaf_count += 1;
        // Decrement the `leaf_count` field.
        proof2.leaf_count -= 1;

        // Fail because leaf count is wrong.
        assert!(proof1
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_err());

        assert!(proof2
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].0 = 1;

        // Fail because leaf index is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        // Trying to verify fewer leaves than were included in the proof.
        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3], &leaves))
            .is_err());
    }
}