1use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8use crate::serialize::CanonicalSerialize;
9
/// Identifiers of the hash algorithms in the default supported set:
/// SHA-256, BLAKE3 and Keccak-256.
pub(crate) const DEFAULT_SUPPORTED_HASH_ALGS: &[HashAlgId] =
    &[HashAlgId::SHA256, HashAlgId::BLAKE3, HashAlgId::KECCAK256];

/// Maximum length, in bytes, of a value stored in a [`Hash`].
const MAX_LEN: usize = 64;
15
/// Error returned by [`HashProvider::get`] when no algorithm is registered
/// under the requested identifier.
///
/// Wraps the identifier that could not be resolved; it is included in the
/// error message.
#[derive(Debug, thiserror::Error)]
#[error("unknown hash algorithm id: {}", self.0)]
pub struct HashProviderError(HashAlgId);
20
/// Registry of hash algorithm implementations keyed by [`HashAlgId`].
pub struct HashProvider {
    /// Registered implementations, at most one per identifier.
    algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
}
25
26impl Default for HashProvider {
27 fn default() -> Self {
28 let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
29
30 algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
31 algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
32 algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
33
34 Self { algs }
35 }
36}
37
38impl HashProvider {
39 pub fn set_algorithm(
43 &mut self,
44 id: HashAlgId,
45 algorithm: Box<dyn HashAlgorithm + Send + Sync>,
46 ) {
47 self.algs.insert(id, algorithm);
48 }
49
50 pub fn get(
53 &self,
54 id: &HashAlgId,
55 ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
56 self.algs
57 .get(id)
58 .map(|alg| &**alg)
59 .ok_or(HashProviderError(*id))
60 }
61}
62
/// Identifier of a hash algorithm, stored as a single byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HashAlgId(u8);
66
67impl HashAlgId {
68 pub const SHA256: Self = Self(1);
70 pub const BLAKE3: Self = Self(2);
72 pub const KECCAK256: Self = Self(3);
74
75 pub const fn new(id: u8) -> Self {
85 assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
86
87 Self(id)
88 }
89
90 pub const fn as_u8(&self) -> u8 {
92 self.0
93 }
94}
95
96impl Display for HashAlgId {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 write!(f, "{:02x}", self.0)
99 }
100}
101
/// A hash value tagged with a hash algorithm identifier.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TypedHash {
    /// Identifier of the hash algorithm.
    pub alg: HashAlgId,
    /// The hash value.
    pub value: Hash,
}
110
/// A hash value of at most `MAX_LEN` bytes, stored inline in a fixed-size
/// buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Hash {
    /// Backing buffer; only the first `len` bytes belong to the hash value.
    value: [u8; MAX_LEN],
    /// Number of meaningful bytes at the start of `value`.
    len: usize,
}
119
120impl Default for Hash {
121 fn default() -> Self {
122 Self {
123 value: [0u8; MAX_LEN],
124 len: 0,
125 }
126 }
127}
128
129impl Serialize for Hash {
130 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
131 where
132 S: Serializer,
133 {
134 serializer.collect_seq(&self.value[..self.len])
135 }
136}
137
138impl<'de> Deserialize<'de> for Hash {
139 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140 where
141 D: Deserializer<'de>,
142 {
143 use core::marker::PhantomData;
144 use serde::de::{Error, SeqAccess, Visitor};
145
146 struct HashVisitor<'de>(PhantomData<&'de ()>);
147
148 impl<'de> Visitor<'de> for HashVisitor<'de> {
149 type Value = Hash;
150
151 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
152 write!(formatter, "an array at most 64 bytes long")
153 }
154
155 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
156 where
157 A: SeqAccess<'de>,
158 {
159 let mut value = [0; MAX_LEN];
160 let mut len = 0;
161
162 while let Some(byte) = seq.next_element()? {
163 if len >= MAX_LEN {
164 return Err(A::Error::invalid_length(len, &self));
165 }
166
167 value[len] = byte;
168 len += 1;
169 }
170
171 Ok(Hash { value, len })
172 }
173 }
174
175 deserializer.deserialize_seq(HashVisitor(PhantomData))
176 }
177}
178
179impl Hash {
180 fn new(value: &[u8]) -> Self {
186 assert!(
187 value.len() <= MAX_LEN,
188 "hash value must be at most 64 bytes"
189 );
190
191 let mut bytes = [0; MAX_LEN];
192 bytes[..value.len()].copy_from_slice(value);
193
194 Self {
195 value: bytes,
196 len: value.len(),
197 }
198 }
199}
200
/// Lets [`Hash`] be used as an `rs_merkle` hash type.
impl rs_merkle::Hash for Hash {
    // Reported size is the capacity of the backing buffer (`MAX_LEN`), not the
    // length of any particular hash value.
    const SIZE: usize = MAX_LEN;
}
204
205impl TryFrom<Vec<u8>> for Hash {
206 type Error = &'static str;
207
208 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
209 if value.len() > MAX_LEN {
210 return Err("hash value must be at most 64 bytes");
211 }
212
213 let mut bytes = [0; MAX_LEN];
214 bytes[..value.len()].copy_from_slice(&value);
215
216 Ok(Self {
217 value: bytes,
218 len: value.len(),
219 })
220 }
221}
222
223impl From<Hash> for Vec<u8> {
224 fn from(value: Hash) -> Self {
225 value.value[..value.len].to_vec()
226 }
227}
228
/// A hash algorithm that can be registered with a [`HashProvider`].
pub trait HashAlgorithm {
    /// Returns the identifier of this hash algorithm.
    fn id(&self) -> HashAlgId;

    /// Computes the hash of `data`.
    fn hash(&self, data: &[u8]) -> Hash;

    /// Computes the hash of `data` with `prefix` applied.
    ///
    /// The implementations in this file feed `prefix` into the hasher first,
    /// followed by `data`.
    fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
}
240
241pub(crate) trait HashAlgorithmExt: HashAlgorithm {
242 #[allow(dead_code)]
243 fn hash_canonical<T: CanonicalSerialize>(&self, data: &T) -> Hash {
244 self.hash(&data.serialize())
245 }
246
247 fn hash_separated<T: DomainSeparator + CanonicalSerialize>(&self, data: &T) -> Hash {
248 self.hash_prefixed(data.domain(), &data.serialize())
249 }
250}
251
252impl<T: HashAlgorithm + ?Sized> HashAlgorithmExt for T {}
253
/// A 16-byte blinder value.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct Blinder([u8; 16]);

// `Debug` is supplied by `opaque_debug` instead of being derived —
// presumably so the blinder bytes are never printed; confirm against the
// `opaque_debug` documentation.
opaque_debug::implement!(Blinder);
259
260impl Blinder {
261 pub(crate) fn as_bytes(&self) -> &[u8] {
262 &self.0
263 }
264}
265
266impl Distribution<Blinder> for StandardUniform {
267 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
268 let mut blinder = [0; 16];
269 rng.fill(&mut blinder);
270 Blinder(blinder)
271 }
272}
273
/// A value of type `T` paired with a [`Blinder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Blinded<T> {
    /// The wrapped data.
    data: T,
    /// The blinder stored alongside the data.
    blinder: Blinder,
}
280
281impl<T> Blinded<T> {
282 pub(crate) fn new(data: T) -> Self {
284 Self {
285 data,
286 blinder: rand::random(),
287 }
288 }
289
290 pub(crate) fn data(&self) -> &T {
291 &self.data
292 }
293}
294
/// A type that supplies domain separator bytes for hashing.
///
/// `HashAlgorithmExt::hash_separated` passes these bytes as the prefix when
/// hashing a value of the implementing type.
pub(crate) trait DomainSeparator {
    /// Returns the domain separator bytes.
    fn domain(&self) -> &[u8];
}
301
/// Implements `DomainSeparator` for the given type.
///
/// The separator is the first 16 bytes of the BLAKE3 hash of the type as
/// written at the macro call site (`stringify!($type)`). It is computed on
/// first use and cached in a static for the rest of the program.
macro_rules! impl_domain_separator {
    ($type:ty) => {
        impl $crate::hash::DomainSeparator for $type {
            fn domain(&self) -> &[u8] {
                use std::sync::LazyLock;

                static DOMAIN: LazyLock<[u8; 16]> = LazyLock::new(|| {
                    let digest: [u8; 32] = blake3::hash(stringify!($type).as_bytes()).into();

                    // Keep only the leading half of the 32-byte digest.
                    let mut domain = [0u8; 16];
                    domain.copy_from_slice(&digest[..16]);
                    domain
                });

                DOMAIN.as_slice()
            }
        }
    };
}
319
320pub(crate) use impl_domain_separator;
321
322mod sha2 {
323 use ::sha2::Digest;
324
325 #[derive(Default, Clone)]
327 pub struct Sha256 {}
328
329 impl super::HashAlgorithm for Sha256 {
330 fn id(&self) -> super::HashAlgId {
331 super::HashAlgId::SHA256
332 }
333
334 fn hash(&self, data: &[u8]) -> super::Hash {
335 let mut hasher = ::sha2::Sha256::default();
336 hasher.update(data);
337 super::Hash::new(hasher.finalize().as_slice())
338 }
339
340 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
341 let mut hasher = ::sha2::Sha256::default();
342 hasher.update(prefix);
343 hasher.update(data);
344 super::Hash::new(hasher.finalize().as_slice())
345 }
346 }
347}
348
349pub use sha2::Sha256;
350
351mod blake3 {
352
353 #[derive(Default, Clone)]
355 pub struct Blake3 {}
356
357 impl super::HashAlgorithm for Blake3 {
358 fn id(&self) -> super::HashAlgId {
359 super::HashAlgId::BLAKE3
360 }
361
362 fn hash(&self, data: &[u8]) -> super::Hash {
363 super::Hash::new(::blake3::hash(data).as_bytes())
364 }
365
366 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
367 let mut hasher = ::blake3::Hasher::new();
368 hasher.update(prefix);
369 hasher.update(data);
370 super::Hash::new(hasher.finalize().as_bytes())
371 }
372 }
373}
374
375pub use blake3::Blake3;
376
377mod keccak {
378 use tiny_keccak::Hasher;
379
380 #[derive(Default, Clone)]
382 pub struct Keccak256 {}
383
384 impl super::HashAlgorithm for Keccak256 {
385 fn id(&self) -> super::HashAlgId {
386 super::HashAlgId::KECCAK256
387 }
388
389 fn hash(&self, data: &[u8]) -> super::Hash {
390 let mut hasher = tiny_keccak::Keccak::v256();
391 hasher.update(data);
392 let mut output = vec![0; 32];
393 hasher.finalize(&mut output);
394 super::Hash::new(&output)
395 }
396
397 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
398 let mut hasher = tiny_keccak::Keccak::v256();
399 hasher.update(prefix);
400 hasher.update(data);
401 let mut output = vec![0; 32];
402 hasher.finalize(&mut output);
403 super::Hash::new(&output)
404 }
405 }
406}
407
408pub use keccak::Keccak256;