Skip to main content

tlsn_core/
merkle.rs

1//! Merkle tree types.
2
3use serde::{Deserialize, Serialize};
4use utils::iter::DuplicateCheck;
5
6use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
7
8/// Merkle tree error.
9#[derive(Debug, thiserror::Error)]
10#[error("merkle error: {0}")]
11pub struct MerkleError(String);
12
13impl MerkleError {
14    fn new(msg: impl Into<String>) -> Self {
15        Self(msg.into())
16    }
17}
18
19/// Merkle proof.
20#[derive(Clone, Serialize, Deserialize)]
21pub struct MerkleProof {
22    alg: HashAlgId,
23    leaf_count: usize,
24    proof: rs_merkle::MerkleProof<Hash>,
25}
26
27opaque_debug::implement!(MerkleProof);
28
29impl MerkleProof {
30    /// Checks if the counts of indices, hashes, and leaves are valid for the
31    /// provided root.
32    ///
33    /// # Panics
34    ///
35    /// - If the length of `leaf_indices` and `leaf_hashes` does not match.
36    /// - If `leaf_indices` contains duplicates.
37    pub fn verify(
38        &self,
39        hasher: &dyn HashAlgorithm,
40        root: &TypedHash,
41        leaves: impl IntoIterator<Item = (usize, Hash)>,
42    ) -> Result<(), MerkleError> {
43        let mut leaves = leaves.into_iter().collect::<Vec<_>>();
44
45        // Sort by index
46        leaves.sort_by_key(|(idx, _)| *idx);
47
48        let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
49
50        if indices.iter().contains_dups() {
51            return Err(MerkleError::new("duplicate leaf indices provided"));
52        }
53
54        if !self.proof.verify(
55            &RsMerkleHasher(hasher),
56            root.value,
57            &indices,
58            &leaves,
59            self.leaf_count,
60        ) {
61            return Err(MerkleError::new("invalid merkle proof"));
62        }
63
64        Ok(())
65    }
66}
67
68#[derive(Clone)]
69struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
70
71impl rs_merkle::Hasher for RsMerkleHasher<'_> {
72    type Hash = Hash;
73
74    fn hash(&self, data: &[u8]) -> Hash {
75        self.0.hash(data)
76    }
77}
78
79/// Merkle tree.
80#[derive(Clone, Serialize, Deserialize)]
81pub struct MerkleTree {
82    alg: HashAlgId,
83    tree: rs_merkle::MerkleTree<Hash>,
84}
85
86impl MerkleTree {
87    /// Creates a new Merkle tree.
88    pub fn new(alg: HashAlgId) -> Self {
89        Self {
90            alg,
91            tree: Default::default(),
92        }
93    }
94
95    /// Returns the hash algorithm used to create the tree.
96    pub fn algorithm(&self) -> HashAlgId {
97        self.alg
98    }
99
100    /// Returns the root of the tree.
101    pub fn root(&self) -> TypedHash {
102        TypedHash {
103            alg: self.alg,
104            value: self.tree.root().expect("tree should not be empty"),
105        }
106    }
107
108    /// Inserts leaves into the tree.
109    ///
110    /// # Panics
111    ///
112    /// - If the provided hasher is not the same as the one used to create the
113    ///   tree.
114    pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
115        assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
116
117        self.tree.append(&mut leaves);
118        self.tree.commit(&RsMerkleHasher(hasher))
119    }
120
121    /// Returns a Merkle proof for the provided indices.
122    ///
123    /// # Panics
124    ///
125    /// - If the provided indices are not unique and sorted.
126    /// - If the provided indices are out of bounds.
127    pub fn proof(&self, indices: &[usize]) -> MerkleProof {
128        assert!(
129            indices.windows(2).all(|w| w[0] < w[1]),
130            "indices must be unique and sorted"
131        );
132
133        assert!(
134            *indices.last().unwrap() < self.tree.leaves_len(),
135            "one or more provided indices are out of bounds"
136        );
137
138        MerkleProof {
139            alg: self.alg,
140            leaf_count: self.tree.leaves_len(),
141            proof: self.tree.proof(indices),
142        }
143    }
144}
145
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    #[derive(Serialize)]
    struct T(u64);

    /// Hashes the big-endian bytes of each value into a leaf hash.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        leaves
            .into_iter()
            .map(|x| hasher.hash(&x.0.to_be_bytes()))
            .collect()
    }

    /// Pairs each index with the leaf hash stored at that position.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices.into_iter().map(|i| (i, leaves[i])).collect()
    }

    /// Builds a tree over the leaves `T(0)..=T(4)` and returns it together
    /// with the leaf hashes that were inserted.
    fn build_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let mut tree = MerkleTree::new(hasher.id());

        let leaves = leaves(hasher, [T(0), T(1), T(2), T(3), T(4)]);

        tree.insert(hasher, leaves.clone());

        (tree, leaves)
    }

    // Expect Merkle proof verification to succeed.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &leaves);

        choices[1].1 = leaves[0];

        // Fail because the leaf is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "indices must be unique and sorted")]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "one or more provided indices are out of bounds")]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "indices must be unique and sorted")]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[2, 2, 3]);
    }

    // Requesting a proof for no leaves is a caller bug and must panic.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_empty_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = build_tree(&hasher);

        _ = tree.proof(&[]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &leaves))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let mut proof1 = tree.proof(&[2, 3, 4]);
        let mut proof2 = proof1.clone();

        // Increment the `leaf_count` field.
        proof1.leaf_count += 1;
        // Decrement the `leaf_count` field.
        proof2.leaf_count -= 1;

        // Fail because leaf count is wrong.
        assert!(proof1
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_err());

        assert!(proof2
            .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves))
            .is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].0 = 1;

        // Fail because leaf index is wrong.
        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = build_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        // Trying to verify fewer leaves than were included in the proof.
        assert!(proof
            .verify(&hasher, &tree.root(), choose_leaves([2, 3], &leaves))
            .is_err());
    }
}