1use std::collections::HashSet;
4
5use serde::{Deserialize, Serialize};
6
7use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
8
9#[derive(Debug, thiserror::Error)]
11#[error("merkle error: {0}")]
12pub struct MerkleError(String);
13
14impl MerkleError {
15 fn new(msg: impl Into<String>) -> Self {
16 Self(msg.into())
17 }
18}
19
20#[derive(Clone, Serialize, Deserialize)]
22pub struct MerkleProof {
23 alg: HashAlgId,
24 leaf_count: usize,
25 proof: rs_merkle::MerkleProof<Hash>,
26}
27
28opaque_debug::implement!(MerkleProof);
29
30impl MerkleProof {
31 pub fn verify(
39 &self,
40 hasher: &dyn HashAlgorithm,
41 root: &TypedHash,
42 leaves: impl IntoIterator<Item = (usize, Hash)>,
43 ) -> Result<(), MerkleError> {
44 let mut leaves = leaves.into_iter().collect::<Vec<_>>();
45
46 leaves.sort_by_key(|(idx, _)| *idx);
48
49 let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
50
51 if indices.len() != indices.iter().collect::<HashSet<_>>().len() {
52 return Err(MerkleError::new("duplicate leaf indices provided"));
53 }
54
55 if !self.proof.verify(
56 &RsMerkleHasher(hasher),
57 root.value,
58 &indices,
59 &leaves,
60 self.leaf_count,
61 ) {
62 return Err(MerkleError::new("invalid merkle proof"));
63 }
64
65 Ok(())
66 }
67}
68
69#[derive(Clone)]
70struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
71
72impl rs_merkle::Hasher for RsMerkleHasher<'_> {
73 type Hash = Hash;
74
75 fn hash(&self, data: &[u8]) -> Hash {
76 self.0.hash(data)
77 }
78}
79
80#[derive(Clone, Serialize, Deserialize)]
82pub struct MerkleTree {
83 alg: HashAlgId,
84 tree: rs_merkle::MerkleTree<Hash>,
85}
86
87impl MerkleTree {
88 pub fn new(alg: HashAlgId) -> Self {
90 Self {
91 alg,
92 tree: Default::default(),
93 }
94 }
95
96 pub fn algorithm(&self) -> HashAlgId {
98 self.alg
99 }
100
101 pub fn root(&self) -> TypedHash {
103 TypedHash {
104 alg: self.alg,
105 value: self.tree.root().expect("tree should not be empty"),
106 }
107 }
108
109 pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
116 assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
117
118 self.tree.append(&mut leaves);
119 self.tree.commit(&RsMerkleHasher(hasher))
120 }
121
122 pub fn proof(&self, indices: &[usize]) -> MerkleProof {
129 assert!(
130 indices.windows(2).all(|w| w[0] < w[1]),
131 "indices must be unique and sorted"
132 );
133
134 assert!(
135 *indices.last().unwrap() < self.tree.leaves_len(),
136 "one or more provided indices are out of bounds"
137 );
138
139 MerkleProof {
140 alg: self.alg,
141 leaf_count: self.tree.leaves_len(),
142 proof: self.tree.proof(indices),
143 }
144 }
145}
146
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    #[derive(Serialize)]
    struct T(u64);

    /// Hashes the big-endian bytes of each value to produce leaf hashes.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        let mut hashes = Vec::new();
        for T(value) in leaves {
            hashes.push(hasher.hash(&value.to_be_bytes()));
        }
        hashes
    }

    /// Picks the `(index, hash)` pairs for `indices` out of `leaves`.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        let mut chosen = Vec::new();
        for idx in indices {
            chosen.push((idx, leaves[idx]));
        }
        chosen
    }

    /// Builds a tree over `T(0)..=T(4)` and returns it together with its leaf hashes.
    fn five_leaf_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let mut tree = MerkleTree::new(hasher.id());
        let hashes = leaves(hasher, [T(0), T(1), T(2), T(3), T(4)]);
        tree.insert(hasher, hashes.clone());
        (tree, hashes)
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes));
        assert!(outcome.is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Swap the hash claimed for index 3 with the hash of leaf 0.
        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].1 = hashes[0];

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let duplicated = choose_leaves([2, 2, 3], &hashes);
        assert!(proof.verify(&hasher, &tree.root(), duplicated).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let root = tree.root();

        let mut too_many = tree.proof(&[2, 3, 4]);
        let mut too_few = too_many.clone();
        too_many.leaf_count += 1;
        too_few.leaf_count -= 1;

        for tampered in [too_many, too_few] {
            let outcome = tampered.verify(&hasher, &root, choose_leaves([2, 3, 4], &hashes));
            assert!(outcome.is_err());
        }
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Claim that the hash of leaf 3 sits at index 1.
        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].0 = 1;

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let subset = choose_leaves([2, 3], &hashes);
        assert!(proof.verify(&hasher, &tree.root(), subset).is_err());
    }
}