1use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
/// Upper bound, in bytes, on the length of a hash value stored in [`Hash`].
const MAX_LEN: usize = 64;
10
11#[derive(Debug, thiserror::Error)]
13#[error("unknown hash algorithm id: {}", self.0)]
14pub struct HashProviderError(HashAlgId);
15
16pub struct HashProvider {
18 algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
19}
20
21impl Default for HashProvider {
22 fn default() -> Self {
23 let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
24
25 algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
26 algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
27 algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
28
29 Self { algs }
30 }
31}
32
33impl HashProvider {
34 pub fn set_algorithm(
38 &mut self,
39 id: HashAlgId,
40 algorithm: Box<dyn HashAlgorithm + Send + Sync>,
41 ) {
42 self.algs.insert(id, algorithm);
43 }
44
45 pub fn get(
48 &self,
49 id: &HashAlgId,
50 ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
51 self.algs
52 .get(id)
53 .map(|alg| &**alg)
54 .ok_or(HashProviderError(*id))
55 }
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct HashAlgId(u8);
61
62impl HashAlgId {
63 pub const SHA256: Self = Self(1);
65 pub const BLAKE3: Self = Self(2);
67 pub const KECCAK256: Self = Self(3);
69
70 pub const fn new(id: u8) -> Self {
80 assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
81
82 Self(id)
83 }
84
85 pub const fn as_u8(&self) -> u8 {
87 self.0
88 }
89}
90
91impl Display for HashAlgId {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{:02x}", self.0)
94 }
95}
96
97#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
99pub struct TypedHash {
100 pub alg: HashAlgId,
102 pub value: Hash,
104}
105
106#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
108pub struct Hash {
109 value: [u8; MAX_LEN],
112 len: usize,
113}
114
115impl Default for Hash {
116 fn default() -> Self {
117 Self {
118 value: [0u8; MAX_LEN],
119 len: 0,
120 }
121 }
122}
123
124impl Serialize for Hash {
125 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126 where
127 S: Serializer,
128 {
129 serializer.collect_seq(&self.value[..self.len])
130 }
131}
132
133impl<'de> Deserialize<'de> for Hash {
134 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135 where
136 D: Deserializer<'de>,
137 {
138 use core::marker::PhantomData;
139 use serde::de::{Error, SeqAccess, Visitor};
140
141 struct HashVisitor<'de>(PhantomData<&'de ()>);
142
143 impl<'de> Visitor<'de> for HashVisitor<'de> {
144 type Value = Hash;
145
146 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
147 write!(formatter, "an array at most 64 bytes long")
148 }
149
150 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
151 where
152 A: SeqAccess<'de>,
153 {
154 let mut value = [0; MAX_LEN];
155 let mut len = 0;
156
157 while let Some(byte) = seq.next_element()? {
158 if len >= MAX_LEN {
159 return Err(A::Error::invalid_length(len, &self));
160 }
161
162 value[len] = byte;
163 len += 1;
164 }
165
166 Ok(Hash { value, len })
167 }
168 }
169
170 deserializer.deserialize_seq(HashVisitor(PhantomData))
171 }
172}
173
174impl Hash {
175 fn new(value: &[u8]) -> Self {
181 assert!(
182 value.len() <= MAX_LEN,
183 "hash value must be at most 64 bytes"
184 );
185
186 let mut bytes = [0; MAX_LEN];
187 bytes[..value.len()].copy_from_slice(value);
188
189 Self {
190 value: bytes,
191 len: value.len(),
192 }
193 }
194}
195
196impl rs_merkle::Hash for Hash {
197 const SIZE: usize = MAX_LEN;
198}
199
200impl TryFrom<Vec<u8>> for Hash {
201 type Error = &'static str;
202
203 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
204 if value.len() > MAX_LEN {
205 return Err("hash value must be at most 64 bytes");
206 }
207
208 let mut bytes = [0; MAX_LEN];
209 bytes[..value.len()].copy_from_slice(&value);
210
211 Ok(Self {
212 value: bytes,
213 len: value.len(),
214 })
215 }
216}
217
218impl From<Hash> for Vec<u8> {
219 fn from(value: Hash) -> Self {
220 value.value[..value.len].to_vec()
221 }
222}
223
224pub trait HashAlgorithm {
226 fn id(&self) -> HashAlgId;
228
229 fn hash(&self, data: &[u8]) -> Hash;
231
232 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
234}
235
236#[derive(Clone, Serialize, Deserialize)]
238pub struct Blinder([u8; 16]);
239
240opaque_debug::implement!(Blinder);
241
242impl Blinder {
243 pub fn as_bytes(&self) -> &[u8] {
245 &self.0
246 }
247}
248
249impl Distribution<Blinder> for StandardUniform {
250 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
251 let mut blinder = [0; 16];
252 rng.fill(&mut blinder);
253 Blinder(blinder)
254 }
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct Blinded<T> {
260 data: T,
261 blinder: Blinder,
262}
263
264impl<T> Blinded<T> {
265 pub fn new(data: T) -> Self {
267 Self {
268 data,
269 blinder: rand::random(),
270 }
271 }
272
273 pub fn data(&self) -> &T {
275 &self.data
276 }
277}
278
279mod sha2 {
280 use ::sha2::Digest;
281
282 #[derive(Default, Clone)]
284 pub struct Sha256 {}
285
286 impl super::HashAlgorithm for Sha256 {
287 fn id(&self) -> super::HashAlgId {
288 super::HashAlgId::SHA256
289 }
290
291 fn hash(&self, data: &[u8]) -> super::Hash {
292 let mut hasher = ::sha2::Sha256::default();
293 hasher.update(data);
294 super::Hash::new(hasher.finalize().as_slice())
295 }
296
297 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
298 let mut hasher = ::sha2::Sha256::default();
299 hasher.update(prefix);
300 hasher.update(data);
301 super::Hash::new(hasher.finalize().as_slice())
302 }
303 }
304}
305
306pub use sha2::Sha256;
307
308mod blake3 {
309
310 #[derive(Default, Clone)]
312 pub struct Blake3 {}
313
314 impl super::HashAlgorithm for Blake3 {
315 fn id(&self) -> super::HashAlgId {
316 super::HashAlgId::BLAKE3
317 }
318
319 fn hash(&self, data: &[u8]) -> super::Hash {
320 super::Hash::new(::blake3::hash(data).as_bytes())
321 }
322
323 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
324 let mut hasher = ::blake3::Hasher::new();
325 hasher.update(prefix);
326 hasher.update(data);
327 super::Hash::new(hasher.finalize().as_bytes())
328 }
329 }
330}
331
332pub use blake3::Blake3;
333
334mod keccak {
335 use tiny_keccak::Hasher;
336
337 #[derive(Default, Clone)]
339 pub struct Keccak256 {}
340
341 impl super::HashAlgorithm for Keccak256 {
342 fn id(&self) -> super::HashAlgId {
343 super::HashAlgId::KECCAK256
344 }
345
346 fn hash(&self, data: &[u8]) -> super::Hash {
347 let mut hasher = tiny_keccak::Keccak::v256();
348 hasher.update(data);
349 let mut output = vec![0; 32];
350 hasher.finalize(&mut output);
351 super::Hash::new(&output)
352 }
353
354 fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
355 let mut hasher = tiny_keccak::Keccak::v256();
356 hasher.update(prefix);
357 hasher.update(data);
358 let mut output = vec![0; 32];
359 hasher.finalize(&mut output);
360 super::Hash::new(&output)
361 }
362 }
363}
364
365pub use keccak::Keccak256;