1use std::collections::HashSet;
4
5use serde::{Deserialize, Serialize};
6
7use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
8
9#[derive(Debug, thiserror::Error)]
11#[error("merkle error: {0}")]
12pub struct MerkleError(String);
13
14impl MerkleError {
15 fn new(msg: impl Into<String>) -> Self {
16 Self(msg.into())
17 }
18}
19
20#[derive(Clone, Serialize, Deserialize)]
22pub struct MerkleProof {
23 alg: HashAlgId,
24 leaf_count: usize,
25 proof: rs_merkle::MerkleProof<Hash>,
26}
27
28opaque_debug::implement!(MerkleProof);
29
30impl MerkleProof {
31 pub fn verify(
39 &self,
40 hasher: &dyn HashAlgorithm,
41 root: &TypedHash,
42 leaves: impl IntoIterator<Item = (usize, Hash)>,
43 ) -> Result<(), MerkleError> {
44 let mut leaves = leaves.into_iter().collect::<Vec<_>>();
45
46 leaves.sort_by_key(|(idx, _)| *idx);
48
49 let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
50
51 if indices.len() != indices.iter().collect::<HashSet<_>>().len() {
52 return Err(MerkleError::new("duplicate leaf indices provided"));
53 }
54
55 if !self.proof.verify(
56 &RsMerkleHasher(hasher),
57 root.value,
58 &indices,
59 &leaves,
60 self.leaf_count,
61 ) {
62 return Err(MerkleError::new("invalid merkle proof"));
63 }
64
65 Ok(())
66 }
67}
68
69#[derive(Clone)]
70struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
71
72impl rs_merkle::Hasher for RsMerkleHasher<'_> {
73 type Hash = Hash;
74
75 fn hash(&self, data: &[u8]) -> Hash {
76 self.0.hash(data)
77 }
78}
79
80#[derive(Clone, Serialize, Deserialize)]
82pub struct MerkleTree {
83 alg: HashAlgId,
84 tree: rs_merkle::MerkleTree<Hash>,
85}
86
87impl MerkleTree {
88 pub fn new(alg: HashAlgId) -> Self {
90 Self {
91 alg,
92 tree: Default::default(),
93 }
94 }
95
96 pub fn algorithm(&self) -> HashAlgId {
98 self.alg
99 }
100
101 pub fn root(&self) -> TypedHash {
103 TypedHash {
104 alg: self.alg,
105 value: self.tree.root().expect("tree should not be empty"),
106 }
107 }
108
109 pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
116 assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
117
118 self.tree.append(&mut leaves);
119 self.tree.commit(&RsMerkleHasher(hasher))
120 }
121
122 pub fn proof(&self, indices: &[usize]) -> MerkleProof {
129 assert!(
130 indices.windows(2).all(|w| w[0] < w[1]),
131 "indices must be unique and sorted"
132 );
133
134 assert!(
135 *indices.last().unwrap() < self.tree.leaves_len(),
136 "one or more provided indices are out of bounds"
137 );
138
139 MerkleProof {
140 alg: self.alg,
141 leaf_count: self.tree.leaves_len(),
142 proof: self.tree.proof(indices),
143 }
144 }
145}
146
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    #[derive(Serialize)]
    struct T(u64);

    /// Hashes the big-endian byte encoding of each value with `hasher`.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        leaves
            .into_iter()
            .map(|T(value)| value.to_be_bytes())
            .map(|bytes| hasher.hash(&bytes))
            .collect()
    }

    /// Pairs each requested index with the leaf hash stored at that index.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices.into_iter().map(|idx| (idx, leaves[idx])).collect()
    }

    /// Builds a tree over the leaves `T(0)..=T(4)` and returns it with the leaf hashes.
    fn five_leaf_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let hashes = leaves(hasher, (0..5).map(T));
        let mut tree = MerkleTree::new(hasher.id());
        tree.insert(hasher, hashes.clone());
        (tree, hashes)
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes));
        assert!(outcome.is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Claim that leaf 3 has the hash of leaf 0.
        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].1 = hashes[0];

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &hashes));
        assert!(outcome.is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let root = tree.root();
        let original = tree.proof(&[2, 3, 4]);

        // One proof claims an extra leaf, the other claims one leaf too few.
        let mut too_many = original.clone();
        too_many.leaf_count += 1;
        let mut too_few = original;
        too_few.leaf_count -= 1;

        for tampered in [too_many, too_few] {
            assert!(tampered
                .verify(&hasher, &root, choose_leaves([2, 3, 4], &hashes))
                .is_err());
        }
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Keep the hash of leaf 3 but claim it sits at index 1.
        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].0 = 1;

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3], &hashes));
        assert!(outcome.is_err());
    }
}