1use serde::{Deserialize, Serialize};
4use utils::iter::DuplicateCheck;
5
6use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
7
8#[derive(Debug, thiserror::Error)]
10#[error("merkle error: {0}")]
11pub struct MerkleError(String);
12
13impl MerkleError {
14 fn new(msg: impl Into<String>) -> Self {
15 Self(msg.into())
16 }
17}
18
19#[derive(Clone, Serialize, Deserialize)]
21pub struct MerkleProof {
22 alg: HashAlgId,
23 leaf_count: usize,
24 proof: rs_merkle::MerkleProof<Hash>,
25}
26
27opaque_debug::implement!(MerkleProof);
28
29impl MerkleProof {
30 pub fn verify(
38 &self,
39 hasher: &dyn HashAlgorithm,
40 root: &TypedHash,
41 leaves: impl IntoIterator<Item = (usize, Hash)>,
42 ) -> Result<(), MerkleError> {
43 let mut leaves = leaves.into_iter().collect::<Vec<_>>();
44
45 leaves.sort_by_key(|(idx, _)| *idx);
47
48 let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
49
50 if indices.iter().contains_dups() {
51 return Err(MerkleError::new("duplicate leaf indices provided"));
52 }
53
54 if !self.proof.verify(
55 &RsMerkleHasher(hasher),
56 root.value,
57 &indices,
58 &leaves,
59 self.leaf_count,
60 ) {
61 return Err(MerkleError::new("invalid merkle proof"));
62 }
63
64 Ok(())
65 }
66
67 pub(crate) fn leaf_count(&self) -> usize {
69 self.leaf_count
70 }
71}
72
73#[derive(Clone)]
74struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
75
76impl rs_merkle::Hasher for RsMerkleHasher<'_> {
77 type Hash = Hash;
78
79 fn hash(&self, data: &[u8]) -> Hash {
80 self.0.hash(data)
81 }
82}
83
84#[derive(Clone, Serialize, Deserialize)]
86pub struct MerkleTree {
87 alg: HashAlgId,
88 tree: rs_merkle::MerkleTree<Hash>,
89}
90
91impl MerkleTree {
92 pub fn new(alg: HashAlgId) -> Self {
94 Self {
95 alg,
96 tree: Default::default(),
97 }
98 }
99
100 pub fn algorithm(&self) -> HashAlgId {
102 self.alg
103 }
104
105 pub fn root(&self) -> TypedHash {
107 TypedHash {
108 alg: self.alg,
109 value: self.tree.root().expect("tree should not be empty"),
110 }
111 }
112
113 pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
120 assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
121
122 self.tree.append(&mut leaves);
123 self.tree.commit(&RsMerkleHasher(hasher))
124 }
125
126 pub fn proof(&self, indices: &[usize]) -> MerkleProof {
133 assert!(
134 indices.windows(2).all(|w| w[0] < w[1]),
135 "indices must be unique and sorted"
136 );
137
138 assert!(
139 *indices.last().unwrap() < self.tree.leaves_len(),
140 "one or more provided indices are out of bounds"
141 );
142
143 MerkleProof {
144 alg: self.alg,
145 leaf_count: self.tree.leaves_len(),
146 proof: self.tree.proof(indices),
147 }
148 }
149}
150
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    #[derive(Serialize)]
    struct T(u64);

    /// Hashes the big-endian bytes of each value to produce leaf hashes.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        leaves
            .into_iter()
            .map(|x| hasher.hash(&x.0.to_be_bytes()))
            .collect()
    }

    /// Pairs each of the given indices with the leaf hash at that index.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices.into_iter().map(|i| (i, leaves[i])).collect()
    }

    /// Builds a tree over the leaves for the values `0..=4` and returns it
    /// together with the leaf hashes.
    fn five_leaf_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let mut tree = MerkleTree::new(hasher.id());
        let hashes = leaves(hasher, [T(0), T(1), T(2), T(3), T(4)]);

        tree.insert(hasher, hashes.clone());

        (tree, hashes)
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);
        let result = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes));

        assert!(result.is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        // Replace the hash claimed for index 3 with the hash of leaf 0.
        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].1 = hashes[0];

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    // The `expected` strings pin the panic to the intended precondition in
    // `MerkleTree::proof`, so an unrelated panic does not make the test pass.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "indices must be unique and sorted")]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "one or more provided indices are out of bounds")]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic(expected = "indices must be unique and sorted")]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);

        _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);
        let result = proof.verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &hashes));

        assert!(result.is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let mut too_many = tree.proof(&[2, 3, 4]);
        let mut too_few = too_many.clone();

        // Tamper with the recorded leaf count in both directions.
        too_many.leaf_count += 1;
        too_few.leaf_count -= 1;

        for proof in [too_many, too_few] {
            let result = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &hashes));

            assert!(result.is_err());
        }
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);

        // Claim that the hash of leaf 3 sits at index 1.
        let mut choices = choose_leaves([2, 3, 4], &hashes);
        choices[1].0 = 1;

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = five_leaf_tree(&hasher);

        let proof = tree.proof(&[2, 3, 4]);
        let result = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3], &hashes));

        assert!(result.is_err());
    }
}