1use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8use crate::serialize::CanonicalSerialize;
9
10pub(crate) const DEFAULT_SUPPORTED_HASH_ALGS: &[HashAlgId] =
11 &[HashAlgId::SHA256, HashAlgId::BLAKE3, HashAlgId::KECCAK256];
12
13const MAX_LEN: usize = 64;
15
16#[derive(Debug, thiserror::Error)]
18#[error("unknown hash algorithm id: {}", self.0)]
19pub struct HashProviderError(HashAlgId);
20
21pub struct HashProvider {
23 algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
24}
25
26impl Default for HashProvider {
27 fn default() -> Self {
28 let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
29
30 algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
31 algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
32 algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
33
34 Self { algs }
35 }
36}
37
38impl HashProvider {
39 pub fn set_algorithm(
43 &mut self,
44 id: HashAlgId,
45 algorithm: Box<dyn HashAlgorithm + Send + Sync>,
46 ) {
47 self.algs.insert(id, algorithm);
48 }
49
50 pub fn get(
53 &self,
54 id: &HashAlgId,
55 ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
56 self.algs
57 .get(id)
58 .map(|alg| &**alg)
59 .ok_or(HashProviderError(*id))
60 }
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct HashAlgId(u8);
66
67impl HashAlgId {
68 pub const SHA256: Self = Self(1);
70 pub const BLAKE3: Self = Self(2);
72 pub const KECCAK256: Self = Self(3);
74
75 pub const fn new(id: u8) -> Self {
85 assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
86
87 Self(id)
88 }
89
90 pub const fn as_u8(&self) -> u8 {
92 self.0
93 }
94}
95
96impl Display for HashAlgId {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 write!(f, "{:02x}", self.0)
99 }
100}
101
102#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
104pub struct TypedHash {
105 pub alg: HashAlgId,
107 pub value: Hash,
109}
110
111#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
113pub struct Hash {
114 value: [u8; MAX_LEN],
117 len: usize,
118}
119
120impl Default for Hash {
121 fn default() -> Self {
122 Self {
123 value: [0u8; MAX_LEN],
124 len: 0,
125 }
126 }
127}
128
129impl Serialize for Hash {
130 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
131 where
132 S: Serializer,
133 {
134 serializer.collect_seq(&self.value[..self.len])
135 }
136}
137
138impl<'de> Deserialize<'de> for Hash {
139 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140 where
141 D: Deserializer<'de>,
142 {
143 use core::marker::PhantomData;
144 use serde::de::{Error, SeqAccess, Visitor};
145
146 struct HashVisitor<'de>(PhantomData<&'de ()>);
147
148 impl<'de> Visitor<'de> for HashVisitor<'de> {
149 type Value = Hash;
150
151 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
152 write!(formatter, "an array at most 64 bytes long")
153 }
154
155 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
156 where
157 A: SeqAccess<'de>,
158 {
159 let mut value = [0; MAX_LEN];
160 let mut len = 0;
161
162 while let Some(byte) = seq.next_element()? {
163 if len >= MAX_LEN {
164 return Err(A::Error::invalid_length(len, &self));
165 }
166
167 value[len] = byte;
168 len += 1;
169 }
170
171 Ok(Hash { value, len })
172 }
173 }
174
175 deserializer.deserialize_seq(HashVisitor(PhantomData))
176 }
177}
178
179impl Hash {
180 fn new(value: &[u8]) -> Self {
186 assert!(
187 value.len() <= MAX_LEN,
188 "hash value must be at most 64 bytes"
189 );
190
191 let mut bytes = [0; MAX_LEN];
192 bytes[..value.len()].copy_from_slice(value);
193
194 Self {
195 value: bytes,
196 len: value.len(),
197 }
198 }
199}
200
201impl rs_merkle::Hash for Hash {
202 const SIZE: usize = MAX_LEN;
203}
204
205impl TryFrom<Vec<u8>> for Hash {
206 type Error = &'static str;
207
208 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
209 if value.len() > MAX_LEN {
210 return Err("hash value must be at most 64 bytes");
211 }
212
213 let mut bytes = [0; MAX_LEN];
214 bytes[..value.len()].copy_from_slice(&value);
215
216 Ok(Self {
217 value: bytes,
218 len: value.len(),
219 })
220 }
221}
222
223impl From<Hash> for Vec<u8> {
224 fn from(value: Hash) -> Self {
225 value.value[..value.len].to_vec()
226 }
227}
228
229pub trait HashAlgorithm {
231 fn id(&self) -> HashAlgId;
233
234 fn hash(&self, data: &[u8]) -> Hash;
236
237 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
239}
240
241pub(crate) trait HashAlgorithmExt: HashAlgorithm {
242 #[allow(dead_code)]
243 fn hash_canonical<T: CanonicalSerialize>(&self, data: &T) -> Hash {
244 self.hash(&data.serialize())
245 }
246
247 fn hash_separated<T: DomainSeparator + CanonicalSerialize>(&self, data: &T) -> Hash {
248 self.hash_prefixed(data.domain(), &data.serialize())
249 }
250}
251
252impl<T: HashAlgorithm + ?Sized> HashAlgorithmExt for T {}
253
254#[derive(Clone, Serialize, Deserialize)]
256pub struct Blinder([u8; 16]);
257
258opaque_debug::implement!(Blinder);
259
260impl Blinder {
261 pub fn as_bytes(&self) -> &[u8] {
263 &self.0
264 }
265}
266
267impl Distribution<Blinder> for StandardUniform {
268 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
269 let mut blinder = [0; 16];
270 rng.fill(&mut blinder);
271 Blinder(blinder)
272 }
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub(crate) struct Blinded<T> {
278 data: T,
279 blinder: Blinder,
280}
281
282impl<T> Blinded<T> {
283 pub(crate) fn new(data: T) -> Self {
285 Self {
286 data,
287 blinder: rand::random(),
288 }
289 }
290
291 pub(crate) fn data(&self) -> &T {
292 &self.data
293 }
294}
295
/// Supplies a byte prefix that is hashed in front of a value's serialization,
/// so values of different types are hashed in separate domains.
///
/// The `impl_domain_separator!` macro derives the prefix from the type's name.
pub(crate) trait DomainSeparator {
    /// Returns the domain separation prefix for this value.
    fn domain(&self) -> &[u8];
}
302
/// Implements [`DomainSeparator`] for `$type`.
///
/// The domain is the first 16 bytes of the BLAKE3 hash of the type's token
/// string (`stringify!($type)`). It is computed once per type and cached in a
/// `LazyLock`.
macro_rules! impl_domain_separator {
    ($type:ty) => {
        impl $crate::hash::DomainSeparator for $type {
            fn domain(&self) -> &[u8] {
                // Absolute paths on purpose. `macro_rules!` resolves every path
                // except `$crate` at the invocation site, so a bare `blake3::...`
                // would pick up any local item named `blake3` (this very module
                // declares `mod blake3`) instead of the `blake3` crate.
                use ::std::sync::LazyLock;

                static DOMAIN: LazyLock<[u8; 16]> = LazyLock::new(|| {
                    let digest: [u8; 32] =
                        ::blake3::hash(stringify!($type).as_bytes()).into();
                    let mut tag = [0u8; 16];
                    tag.copy_from_slice(&digest[..16]);
                    tag
                });

                &*DOMAIN
            }
        }
    };
}

pub(crate) use impl_domain_separator;
322
323mod sha2 {
324 use ::sha2::Digest;
325
326 #[derive(Default, Clone)]
328 pub struct Sha256 {}
329
330 impl super::HashAlgorithm for Sha256 {
331 fn id(&self) -> super::HashAlgId {
332 super::HashAlgId::SHA256
333 }
334
335 fn hash(&self, data: &[u8]) -> super::Hash {
336 let mut hasher = ::sha2::Sha256::default();
337 hasher.update(data);
338 super::Hash::new(hasher.finalize().as_slice())
339 }
340
341 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
342 let mut hasher = ::sha2::Sha256::default();
343 hasher.update(prefix);
344 hasher.update(data);
345 super::Hash::new(hasher.finalize().as_slice())
346 }
347 }
348}
349
350pub use sha2::Sha256;
351
352mod blake3 {
353
354 #[derive(Default, Clone)]
356 pub struct Blake3 {}
357
358 impl super::HashAlgorithm for Blake3 {
359 fn id(&self) -> super::HashAlgId {
360 super::HashAlgId::BLAKE3
361 }
362
363 fn hash(&self, data: &[u8]) -> super::Hash {
364 super::Hash::new(::blake3::hash(data).as_bytes())
365 }
366
367 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
368 let mut hasher = ::blake3::Hasher::new();
369 hasher.update(prefix);
370 hasher.update(data);
371 super::Hash::new(hasher.finalize().as_bytes())
372 }
373 }
374}
375
376pub use blake3::Blake3;
377
378mod keccak {
379 use tiny_keccak::Hasher;
380
381 #[derive(Default, Clone)]
383 pub struct Keccak256 {}
384
385 impl super::HashAlgorithm for Keccak256 {
386 fn id(&self) -> super::HashAlgId {
387 super::HashAlgId::KECCAK256
388 }
389
390 fn hash(&self, data: &[u8]) -> super::Hash {
391 let mut hasher = tiny_keccak::Keccak::v256();
392 hasher.update(data);
393 let mut output = vec![0; 32];
394 hasher.finalize(&mut output);
395 super::Hash::new(&output)
396 }
397
398 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
399 let mut hasher = tiny_keccak::Keccak::v256();
400 hasher.update(prefix);
401 hasher.update(data);
402 let mut output = vec![0; 32];
403 hasher.finalize(&mut output);
404 super::Hash::new(&output)
405 }
406 }
407}
408
409pub use keccak::Keccak256;