1use serde::{Deserialize, Serialize};
4use utils::iter::DuplicateCheck;
5
6use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
7
8#[derive(Debug, thiserror::Error)]
10#[error("merkle error: {0}")]
11pub struct MerkleError(String);
12
13impl MerkleError {
14 fn new(msg: impl Into<String>) -> Self {
15 Self(msg.into())
16 }
17}
18
/// An inclusion proof for one or more leaves of a [`MerkleTree`].
///
/// Produced by [`MerkleTree::proof`] and checked with [`MerkleProof::verify`].
// NOTE(review): this type is serde-serialized; field order may be significant for
// non-self-describing formats, so avoid reordering fields.
#[derive(Clone, Serialize, Deserialize)]
pub struct MerkleProof {
    /// Hash algorithm of the tree this proof was generated from.
    alg: HashAlgId,
    /// Number of leaves in the tree at the time the proof was generated.
    leaf_count: usize,
    /// The underlying `rs_merkle` proof.
    proof: rs_merkle::MerkleProof<Hash>,
}

// Implements `Debug` opaquely so that proof contents are not printed.
opaque_debug::implement!(MerkleProof);
28
29impl MerkleProof {
30 pub fn verify(
38 &self,
39 hasher: &dyn HashAlgorithm,
40 root: &TypedHash,
41 leaves: impl IntoIterator<Item = (usize, Hash)>,
42 ) -> Result<(), MerkleError> {
43 let mut leaves = leaves.into_iter().collect::<Vec<_>>();
44
45 leaves.sort_by_key(|(idx, _)| *idx);
47
48 let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
49
50 if indices.iter().contains_dups() {
51 return Err(MerkleError::new("duplicate leaf indices provided"));
52 }
53
54 if !self.proof.verify(
55 &RsMerkleHasher(hasher),
56 root.value,
57 &indices,
58 &leaves,
59 self.leaf_count,
60 ) {
61 return Err(MerkleError::new("invalid merkle proof"));
62 }
63
64 Ok(())
65 }
66}
67
68#[derive(Clone)]
69struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
70
71impl rs_merkle::Hasher for RsMerkleHasher<'_> {
72 type Hash = Hash;
73
74 fn hash(&self, data: &[u8]) -> Hash {
75 self.0.hash(data)
76 }
77}
78
/// A Merkle tree whose nodes are all hashed with a single hash algorithm.
// NOTE(review): this type is serde-serialized; field order may be significant for
// non-self-describing formats, so avoid reordering fields.
#[derive(Clone, Serialize, Deserialize)]
pub struct MerkleTree {
    /// Hash algorithm used for this tree; `insert` asserts the hasher matches it.
    alg: HashAlgId,
    /// The underlying `rs_merkle` tree holding the leaves.
    tree: rs_merkle::MerkleTree<Hash>,
}
85
86impl MerkleTree {
87 pub fn new(alg: HashAlgId) -> Self {
89 Self {
90 alg,
91 tree: Default::default(),
92 }
93 }
94
95 pub fn algorithm(&self) -> HashAlgId {
97 self.alg
98 }
99
100 pub fn root(&self) -> TypedHash {
102 TypedHash {
103 alg: self.alg,
104 value: self.tree.root().expect("tree should not be empty"),
105 }
106 }
107
108 pub fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
115 assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
116
117 self.tree.append(&mut leaves);
118 self.tree.commit(&RsMerkleHasher(hasher))
119 }
120
121 pub fn proof(&self, indices: &[usize]) -> MerkleProof {
128 assert!(
129 indices.windows(2).all(|w| w[0] < w[1]),
130 "indices must be unique and sorted"
131 );
132
133 assert!(
134 *indices.last().unwrap() < self.tree.leaves_len(),
135 "one or more provided indices are out of bounds"
136 );
137
138 MerkleProof {
139 alg: self.alg,
140 leaf_count: self.tree.leaves_len(),
141 proof: self.tree.proof(indices),
142 }
143 }
144}
145
#[cfg(test)]
mod test {
    use crate::hash::{Blake3, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    #[derive(Serialize)]
    struct T(u64);

    /// Hashes the big-endian bytes of each value to produce leaf hashes.
    fn leaves<H: HashAlgorithm>(hasher: &H, leaves: impl IntoIterator<Item = T>) -> Vec<Hash> {
        leaves
            .into_iter()
            .map(|T(value)| hasher.hash(&value.to_be_bytes()))
            .collect()
    }

    /// Pairs each index with the leaf hash stored at that position.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        indices
            .into_iter()
            .map(|idx| (idx, leaves[idx]))
            .collect()
    }

    /// Builds a tree over the leaves `T(0)..=T(4)` and returns it with their hashes.
    fn five_leaf_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let hashes = leaves(hasher, (0..5).map(T));
        let mut tree = MerkleTree::new(hasher.id());
        tree.insert(hasher, hashes.clone());
        (tree, hashes)
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves));
        assert!(outcome.is_ok());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Swap in leaf 0's hash where leaf 3's hash belongs.
        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].1 = leaves[0];

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 4, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 3, 4, 6]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _) = five_leaf_tree(&hasher);
        let _ = tree.proof(&[2, 2, 3]);
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let duplicated = choose_leaves([2, 2, 3], &leaves);
        assert!(proof.verify(&hasher, &tree.root(), duplicated).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // One proof claims an extra leaf, the other claims one leaf too few.
        let mut inflated = proof.clone();
        inflated.leaf_count += 1;
        let mut deflated = proof;
        deflated.leaf_count -= 1;

        for tampered in [inflated, deflated] {
            let outcome =
                tampered.verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves));
            assert!(outcome.is_err());
        }
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        // Claim that leaf 3's hash sits at index 1.
        let mut choices = choose_leaves([2, 3, 4], &leaves);
        choices[1].0 = 1;

        assert!(proof.verify(&hasher, &tree.root(), choices).is_err());
    }

    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, leaves) = five_leaf_tree(&hasher);
        let proof = tree.proof(&[2, 3, 4]);

        let subset = choose_leaves([2, 3], &leaves);
        assert!(proof.verify(&hasher, &tree.root(), subset).is_err());
    }
}