1use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
/// Capacity, in bytes, of the inline storage inside [`Hash`]; no digest
/// longer than this can be represented.
const MAX_LEN: usize = 64;
10
11#[derive(Debug, thiserror::Error)]
13#[error("unknown hash algorithm id: {}", self.0)]
14pub struct HashProviderError(HashAlgId);
15
16pub struct HashProvider {
18 algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
19}
20
21impl Default for HashProvider {
22 fn default() -> Self {
23 let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
24
25 algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
26 algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
27 algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
28
29 Self { algs }
30 }
31}
32
33impl HashProvider {
34 pub fn set_algorithm(
38 &mut self,
39 id: HashAlgId,
40 algorithm: Box<dyn HashAlgorithm + Send + Sync>,
41 ) {
42 self.algs.insert(id, algorithm);
43 }
44
45 pub fn get(
48 &self,
49 id: &HashAlgId,
50 ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
51 self.algs
52 .get(id)
53 .map(|alg| &**alg)
54 .ok_or(HashProviderError(*id))
55 }
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct HashAlgId(u8);
61
62impl HashAlgId {
63 pub const SHA256: Self = Self(1);
65 pub const BLAKE3: Self = Self(2);
67 pub const KECCAK256: Self = Self(3);
69
70 pub const fn new(id: u8) -> Self {
80 assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
81
82 Self(id)
83 }
84
85 pub const fn as_u8(&self) -> u8 {
87 self.0
88 }
89}
90
91impl Display for HashAlgId {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{:02x}", self.0)
94 }
95}
96
97#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
99pub struct TypedHash {
100 pub alg: HashAlgId,
102 pub value: Hash,
104}
105
106#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
108pub struct Hash {
109 value: [u8; MAX_LEN],
112 len: usize,
113}
114
115impl Default for Hash {
116 fn default() -> Self {
117 Self {
118 value: [0u8; MAX_LEN],
119 len: 0,
120 }
121 }
122}
123
124impl Serialize for Hash {
125 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126 where
127 S: Serializer,
128 {
129 serializer.collect_seq(&self.value[..self.len])
130 }
131}
132
133impl<'de> Deserialize<'de> for Hash {
134 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135 where
136 D: Deserializer<'de>,
137 {
138 use core::marker::PhantomData;
139 use serde::de::{Error, SeqAccess, Visitor};
140
141 struct HashVisitor<'de>(PhantomData<&'de ()>);
142
143 impl<'de> Visitor<'de> for HashVisitor<'de> {
144 type Value = Hash;
145
146 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
147 write!(formatter, "an array at most 64 bytes long")
148 }
149
150 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
151 where
152 A: SeqAccess<'de>,
153 {
154 let mut value = [0; MAX_LEN];
155 let mut len = 0;
156
157 while let Some(byte) = seq.next_element()? {
158 if len >= MAX_LEN {
159 return Err(A::Error::invalid_length(len, &self));
160 }
161
162 value[len] = byte;
163 len += 1;
164 }
165
166 Ok(Hash { value, len })
167 }
168 }
169
170 deserializer.deserialize_seq(HashVisitor(PhantomData))
171 }
172}
173
174impl Hash {
175 fn new(value: &[u8]) -> Self {
181 assert!(
182 value.len() <= MAX_LEN,
183 "hash value must be at most 64 bytes"
184 );
185
186 let mut bytes = [0; MAX_LEN];
187 bytes[..value.len()].copy_from_slice(value);
188
189 Self {
190 value: bytes,
191 len: value.len(),
192 }
193 }
194
195 pub fn as_bytes(&self) -> &[u8] {
197 &self.value[..self.len]
198 }
199}
200
201impl rs_merkle::Hash for Hash {
202 const SIZE: usize = MAX_LEN;
203}
204
205impl TryFrom<Vec<u8>> for Hash {
206 type Error = &'static str;
207
208 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
209 if value.len() > MAX_LEN {
210 return Err("hash value must be at most 64 bytes");
211 }
212
213 let mut bytes = [0; MAX_LEN];
214 bytes[..value.len()].copy_from_slice(&value);
215
216 Ok(Self {
217 value: bytes,
218 len: value.len(),
219 })
220 }
221}
222
223impl From<Hash> for Vec<u8> {
224 fn from(value: Hash) -> Self {
225 value.value[..value.len].to_vec()
226 }
227}
228
229pub trait HashAlgorithm {
231 fn id(&self) -> HashAlgId;
233
234 fn hash(&self, data: &[u8]) -> Hash;
236
237 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
239}
240
241#[derive(Clone, Serialize, Deserialize)]
243pub struct Blinder([u8; 16]);
244
245opaque_debug::implement!(Blinder);
246
247impl Blinder {
248 pub fn as_bytes(&self) -> &[u8] {
250 &self.0
251 }
252}
253
254impl Distribution<Blinder> for StandardUniform {
255 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
256 let mut blinder = [0; 16];
257 rng.fill(&mut blinder);
258 Blinder(blinder)
259 }
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct Blinded<T> {
265 data: T,
266 blinder: Blinder,
267}
268
269impl<T> Blinded<T> {
270 pub fn new(data: T) -> Self {
272 Self {
273 data,
274 blinder: rand::random(),
275 }
276 }
277
278 pub fn data(&self) -> &T {
280 &self.data
281 }
282}
283
284mod sha2 {
285 use ::sha2::Digest;
286
287 #[derive(Default, Clone)]
289 pub struct Sha256 {}
290
291 impl super::HashAlgorithm for Sha256 {
292 fn id(&self) -> super::HashAlgId {
293 super::HashAlgId::SHA256
294 }
295
296 fn hash(&self, data: &[u8]) -> super::Hash {
297 let mut hasher = ::sha2::Sha256::default();
298 hasher.update(data);
299 super::Hash::new(hasher.finalize().as_slice())
300 }
301
302 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
303 let mut hasher = ::sha2::Sha256::default();
304 hasher.update(prefix);
305 hasher.update(data);
306 super::Hash::new(hasher.finalize().as_slice())
307 }
308 }
309}
310
311pub use sha2::Sha256;
312
313mod blake3 {
314
315 #[derive(Default, Clone)]
317 pub struct Blake3 {}
318
319 impl super::HashAlgorithm for Blake3 {
320 fn id(&self) -> super::HashAlgId {
321 super::HashAlgId::BLAKE3
322 }
323
324 fn hash(&self, data: &[u8]) -> super::Hash {
325 super::Hash::new(::blake3::hash(data).as_bytes())
326 }
327
328 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
329 let mut hasher = ::blake3::Hasher::new();
330 hasher.update(prefix);
331 hasher.update(data);
332 super::Hash::new(hasher.finalize().as_bytes())
333 }
334 }
335}
336
337pub use blake3::Blake3;
338
339mod keccak {
340 use tiny_keccak::Hasher;
341
342 #[derive(Default, Clone)]
344 pub struct Keccak256 {}
345
346 impl super::HashAlgorithm for Keccak256 {
347 fn id(&self) -> super::HashAlgId {
348 super::HashAlgId::KECCAK256
349 }
350
351 fn hash(&self, data: &[u8]) -> super::Hash {
352 let mut hasher = tiny_keccak::Keccak::v256();
353 hasher.update(data);
354 let mut output = vec![0; 32];
355 hasher.finalize(&mut output);
356 super::Hash::new(&output)
357 }
358
359 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
360 let mut hasher = tiny_keccak::Keccak::v256();
361 hasher.update(prefix);
362 hasher.update(data);
363 let mut output = vec![0; 32];
364 hasher.finalize(&mut output);
365 super::Hash::new(&output)
366 }
367 }
368}
369
370pub use keccak::Keccak256;