tlsn_core/
merkle.rs

1//! Merkle tree types.
2
3use serde::{Deserialize, Serialize};
4use utils::iter::DuplicateCheck;
5
6use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
7
8/// Merkle tree error.
9#[derive(Debug, thiserror::Error)]
10#[error("merkle error: {0}")]
11pub struct MerkleError(String);
12
13impl MerkleError {
14    fn new(msg: impl Into<String>) -> Self {
15        Self(msg.into())
16    }
17}
18
19/// Merkle proof.
20#[derive(Clone, Serialize, Deserialize)]
21pub struct MerkleProof {
22    alg: HashAlgId,
23    leaf_count: usize,
24    proof: rs_merkle::MerkleProof<Hash>,
25}
26
27opaque_debug::implement!(MerkleProof);
28
29impl MerkleProof {
30    /// Checks if the counts of indices, hashes, and leaves are valid for the
31    /// provided root.
32    ///
33    /// # Panics
34    ///
35    /// - If the length of `leaf_indices` and `leaf_hashes` does not match.
36    /// - If `leaf_indices` contains duplicates.
37    pub fn verify(
38        &self,
39        hasher: &dyn HashAlgorithm,
40        root: &TypedHash,
41        leaves: impl IntoIterator<Item = (usize, Hash)>,
42    ) -> Result<(), MerkleError> {
43        let mut leaves = leaves.into_iter().collect::<Vec<_>>();
44
45        // Sort by index
46        leaves.sort_by_key(|(idx, _)| *idx);
47
48        let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
49
50        if indices.iter().contains_dups() {
51            return Err(MerkleError::new("duplicate leaf indices provided"));
52        }
53
54        if !self.proof.verify(
55            &RsMerkleHasher(hasher),
56            root.value,
57            &indices,
58            &leaves,
59            self.leaf_count,
60        ) {
61            return Err(MerkleError::new("invalid merkle proof"));
62        }
63
64        Ok(())
65    }
66
67    /// Returns the leaf count of the Merkle tree associated with the proof.
68    pub(crate) fn leaf_count(&self) -> usize {
69        self.leaf_count
70    }
71}
72
73#[derive(Clone)]
74struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
75
76impl rs_merkle::Hasher for RsMerkleHasher<'_> {
77    type Hash = Hash;
78
79    fn hash(&self, data: &[u8]) -> Hash {
80        self.0.hash(data)
81    }
82}
83
84/// Merkle tree.
85#[derive(Clone, Serialize, Deserialize)]
86pub struct MerkleTree {
87    alg: HashAlgId,
88    tree: rs_merkle::MerkleTree<Hash>,
89}
90
91impl MerkleTree {
92    /// Creates a new Merkle tree.
93    pub fn new(alg: HashAlgId) -> Self {
94        Self {
95            alg,
96            tree: Default::default(),
97        }
98    }
99
100    /// Returns the hash algorithm used to create the tree.
101    pub fn algorithm(&self) -> HashAlgId {
102        self.alg
103    }
104
105    /// Returns the root of the tree.
106    pub fn root(&self) -> TypedHash {
107        TypedHash {
108            alg: self.alg,
109            value: self.tree.root().expect("tree should not be empty"),
110        }
111    }
112
113    /// Inserts leaves into the tree.
114    ///
115    /// # Panics
116    ///
117    /// - If the provided hasher is not the same as the one used to create the
118    ///   tree.
119    pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
120        assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
121
122        self.tree.append(&mut leaves);
123        self.tree.commit(&RsMerkleHasher(hasher))
124    }
125
126    /// Returns a Merkle proof for the provided indices.
127    ///
128    /// # Panics
129    ///
130    /// - If the provided indices are not unique and sorted.
131    /// - If the provided indices are out of bounds.
132    pub fn proof(&self, indices: &[usize]) -> MerkleProof {
133        assert!(
134            indices.windows(2).all(|w| w[0] < w[1]),
135            "indices must be unique and sorted"
136        );
137
138        assert!(
139            *indices.last().unwrap() < self.tree.leaves_len(),
140            "one or more provided indices are out of bounds"
141        );
142
143        MerkleProof {
144            alg: self.alg,
145            leaf_count: self.tree.leaves_len(),
146            proof: self.tree.proof(indices),
147        }
148    }
149}
150
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    /// Hashes the big-endian encoding of each value with `hasher`.
    fn hash_leaves<H: HashAlgorithm>(
        hasher: &H,
        values: impl IntoIterator<Item = u64>,
    ) -> Vec<Hash> {
        values
            .into_iter()
            .map(|value| hasher.hash(&value.to_be_bytes()))
            .collect()
    }

    /// Builds a tree over the hashes of `0..5` and returns it together with
    /// its leaf hashes.
    fn build_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let leaves = hash_leaves(hasher, 0..5);
        let mut tree = MerkleTree::new(hasher.id());
        tree.insert(hasher, leaves.clone());
        (tree, leaves)
    }

    /// Pairs each index with the leaf hash at that index.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices.into_iter().map(|i| (i, leaves[i])).collect()
    }

    // Expect Merkle proof verification to succeed.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_ok());
    }

    // Leaves may be provided to `verify` in any order.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success_unordered_leaves<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([4, 2, 3], &leaves))
            .is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].1 = leaves[0];

        // Fail because the leaf is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_empty_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &leaves))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let mut proof1 = tree.proof(&[2, 3, 4]);
        let mut proof2 = proof1.clone();

        // Increment the `leaf_count` field.
        proof1.leaf_count += 1;
        // Decrement the `leaf_count` field.
        proof2.leaf_count -= 1;

        // Fail because leaf count is wrong.
        assert!(proof1
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_err());

        assert!(proof2
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].0 = 1;

        // Fail because leaf index is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Trying to verify fewer leaves than were included in the proof.
        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3], &leaves))
            .is_err());
    }

    #[test]
    fn test_verify_fail_wrong_hasher() {
        let hasher = Sha256::default();
        let (tree, leaves) = build_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Fail because the proof was built with SHA-256 but is verified with
        // BLAKE3.
        assert!(proof
            .verify(
                &Blake3::default(),
                &tree.root(),
                choose_leaves([2, 3, 4], &leaves)
            )
            .is_err());
    }
}