1use serde::{Deserialize, Serialize};
4use utils::iter::DuplicateCheck;
5
6use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash};
7
8#[derive(Debug, thiserror::Error)]
10#[error("merkle error: {0}")]
11pub(crate) struct MerkleError(String);
12
13impl MerkleError {
14 fn new(msg: impl Into<String>) -> Self {
15 Self(msg.into())
16 }
17}
18
19#[derive(Clone, Serialize, Deserialize)]
20pub(crate) struct MerkleProof {
21 alg: HashAlgId,
22 leaf_count: usize,
23 proof: rs_merkle::MerkleProof<Hash>,
24}
25
26opaque_debug::implement!(MerkleProof);
27
28impl MerkleProof {
29 pub(crate) fn verify(
37 &self,
38 hasher: &dyn HashAlgorithm,
39 root: &TypedHash,
40 leaves: impl IntoIterator<Item = (usize, Hash)>,
41 ) -> Result<(), MerkleError> {
42 let mut leaves = leaves.into_iter().collect::<Vec<_>>();
43
44 leaves.sort_by_key(|(idx, _)| *idx);
46
47 let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip();
48
49 if indices.iter().contains_dups() {
50 return Err(MerkleError::new("duplicate leaf indices provided"));
51 }
52
53 if !self.proof.verify(
54 &RsMerkleHasher(hasher),
55 root.value,
56 &indices,
57 &leaves,
58 self.leaf_count,
59 ) {
60 return Err(MerkleError::new("invalid merkle proof"));
61 }
62
63 Ok(())
64 }
65
66 pub(crate) fn leaf_count(&self) -> usize {
68 self.leaf_count
69 }
70}
71
72#[derive(Clone)]
73struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm);
74
75impl rs_merkle::Hasher for RsMerkleHasher<'_> {
76 type Hash = Hash;
77
78 fn hash(&self, data: &[u8]) -> Hash {
79 self.0.hash(data)
80 }
81}
82
83#[derive(Clone, Serialize, Deserialize)]
84pub(crate) struct MerkleTree {
85 alg: HashAlgId,
86 tree: rs_merkle::MerkleTree<Hash>,
87}
88
89impl MerkleTree {
90 pub(crate) fn new(alg: HashAlgId) -> Self {
91 Self {
92 alg,
93 tree: Default::default(),
94 }
95 }
96
97 pub(crate) fn algorithm(&self) -> HashAlgId {
98 self.alg
99 }
100
101 pub(crate) fn root(&self) -> TypedHash {
102 TypedHash {
103 alg: self.alg,
104 value: self.tree.root().expect("tree should not be empty"),
105 }
106 }
107
108 pub(crate) fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec<Hash>) {
115 assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch");
116
117 self.tree.append(&mut leaves);
118 self.tree.commit(&RsMerkleHasher(hasher))
119 }
120
121 pub(crate) fn proof(&self, indices: &[usize]) -> MerkleProof {
128 assert!(
129 indices.windows(2).all(|w| w[0] < w[1]),
130 "indices must be unique and sorted"
131 );
132
133 assert!(
134 *indices.last().unwrap() < self.tree.leaves_len(),
135 "one or more provided indices are out of bounds"
136 );
137
138 MerkleProof {
139 alg: self.alg,
140 leaf_count: self.tree.leaves_len(),
141 proof: self.tree.proof(indices),
142 }
143 }
144}
145
#[cfg(test)]
mod test {
    use crate::hash::{impl_domain_separator, Blake3, HashAlgorithmExt, Keccak256, Sha256};

    use super::*;
    use rstest::*;

    /// Minimal serializable payload used as leaf content in the tests.
    #[derive(Serialize)]
    struct T(u64);

    impl_domain_separator!(T);

    /// Leaf indices proven by most of the tests below.
    const PROVEN: [usize; 3] = [2, 3, 4];

    /// Hashes every item with `hash_canonical`, preserving the input order.
    fn leaves<H: HashAlgorithm>(hasher: &H, items: impl IntoIterator<Item = T>) -> Vec<Hash> {
        let mut hashes = Vec::new();
        for item in items {
            hashes.push(hasher.hash_canonical(&item));
        }
        hashes
    }

    /// Pairs each requested index with the leaf hash stored at that index.
    fn choose_leaves(
        indices: impl IntoIterator<Item = usize>,
        leaves: &[Hash],
    ) -> Vec<(usize, Hash)> {
        let mut chosen = Vec::new();
        for index in indices {
            chosen.push((index, leaves[index]));
        }
        chosen
    }

    /// Builds a tree over the five leaves `T(0)` to `T(4)` and returns it along
    /// with the leaf hashes that were inserted.
    fn build_tree<H: HashAlgorithm>(hasher: &H) -> (MerkleTree, Vec<Hash>) {
        let mut tree = MerkleTree::new(hasher.id());
        let hashes = leaves(hasher, [T(0), T(1), T(2), T(3), T(4)]);
        tree.insert(hasher, hashes.clone());
        (tree, hashes)
    }

    // A proof verifies against the exact leaves it was generated for.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_success<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = build_tree(&hasher);
        let proof = tree.proof(&PROVEN);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves(PROVEN, &hashes));

        assert!(outcome.is_ok());
    }

    // Substituting a different leaf hash at a proven index must be rejected.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_wrong_leaf<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = build_tree(&hasher);
        let proof = tree.proof(&PROVEN);

        let mut choices = choose_leaves(PROVEN, &hashes);
        choices[1].1 = hashes[0];

        let outcome = proof.verify(&hasher, &tree.root(), choices);

        assert!(outcome.is_err());
    }

    // Requesting a proof with unsorted indices panics.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_unsorted<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _hashes) = build_tree(&hasher);

        let _ = tree.proof(&[2, 4, 3]);
    }

    // Requesting a proof for an index past the last leaf panics.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_index_out_of_bounds<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _hashes) = build_tree(&hasher);

        let _ = tree.proof(&[2, 3, 4, 6]);
    }

    // Requesting a proof with a repeated index panics.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    #[should_panic]
    fn test_proof_fail_length_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, _hashes) = build_tree(&hasher);

        let _ = tree.proof(&[2, 2, 3]);
    }

    // Verification rejects leaves that contain a repeated index.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_duplicates<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = build_tree(&hasher);
        let proof = tree.proof(&PROVEN);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &hashes));

        assert!(outcome.is_err());
    }

    // Tampering with the recorded leaf count, in either direction, must make
    // verification fail.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_leaf_count<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = build_tree(&hasher);

        let mut too_many = tree.proof(&PROVEN);
        let mut too_few = too_many.clone();

        too_many.leaf_count += 1;
        too_few.leaf_count -= 1;

        let outcome_many = too_many.verify(&hasher, &tree.root(), choose_leaves(PROVEN, &hashes));
        assert!(outcome_many.is_err());

        let outcome_few = too_few.verify(&hasher, &tree.root(), choose_leaves(PROVEN, &hashes));
        assert!(outcome_few.is_err());
    }

    // Claiming a proven leaf hash sits at a different index must be rejected.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_incorrect_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = build_tree(&hasher);
        let proof = tree.proof(&PROVEN);

        let mut choices = choose_leaves(PROVEN, &hashes);
        choices[1].0 = 1;

        let outcome = proof.verify(&hasher, &tree.root(), choices);

        assert!(outcome.is_err());
    }

    // Supplying only a subset of the proven leaves must be rejected.
    #[rstest]
    #[case::sha2(Sha256::default())]
    #[case::blake3(Blake3::default())]
    #[case::keccak(Keccak256::default())]
    fn test_verify_fail_fewer_indices<H: HashAlgorithm>(#[case] hasher: H) {
        let (tree, hashes) = build_tree(&hasher);
        let proof = tree.proof(&PROVEN);

        let outcome = proof.verify(&hasher, &tree.root(), choose_leaves([2, 3], &hashes));

        assert!(outcome.is_err());
    }
}