tlsn_core/
hash.rs

1//! Hash types.
2
3use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
/// Capacity, in bytes, of the inline storage backing a [`Hash`].
///
/// This leaves room for 512-bit digests. The algorithms registered in this
/// file (SHA-256, BLAKE3, Keccak-256) each produce 32-byte outputs.
const MAX_LEN: usize = 64;
10
11/// An error for [`HashProvider`].
12#[derive(Debug, thiserror::Error)]
13#[error("unknown hash algorithm id: {}", self.0)]
14pub struct HashProviderError(HashAlgId);
15
16/// Hash provider.
17pub struct HashProvider {
18    algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
19}
20
21impl Default for HashProvider {
22    fn default() -> Self {
23        let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
24
25        algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
26        algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
27        algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
28
29        Self { algs }
30    }
31}
32
33impl HashProvider {
34    /// Sets a hash algorithm.
35    ///
36    /// This can be used to add or override implementations of hash algorithms.
37    pub fn set_algorithm(
38        &mut self,
39        id: HashAlgId,
40        algorithm: Box<dyn HashAlgorithm + Send + Sync>,
41    ) {
42        self.algs.insert(id, algorithm);
43    }
44
45    /// Returns the hash algorithm with the given identifier, or an error if the
46    /// hash algorithm does not exist.
47    pub fn get(
48        &self,
49        id: &HashAlgId,
50    ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
51        self.algs
52            .get(id)
53            .map(|alg| &**alg)
54            .ok_or(HashProviderError(*id))
55    }
56}
57
58/// A hash algorithm identifier.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct HashAlgId(u8);
61
62impl HashAlgId {
63    /// SHA-256 hash algorithm.
64    pub const SHA256: Self = Self(1);
65    /// BLAKE3 hash algorithm.
66    pub const BLAKE3: Self = Self(2);
67    /// Keccak-256 hash algorithm.
68    pub const KECCAK256: Self = Self(3);
69
70    /// Creates a new hash algorithm identifier.
71    ///
72    /// # Panics
73    ///
74    /// Panics if the identifier is in the reserved range 0-127.
75    ///
76    /// # Arguments
77    ///
78    /// * id - Unique identifier for the hash algorithm.
79    pub const fn new(id: u8) -> Self {
80        assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
81
82        Self(id)
83    }
84
85    /// Returns the id as a `u8`.
86    pub const fn as_u8(&self) -> u8 {
87        self.0
88    }
89}
90
91impl Display for HashAlgId {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "{:02x}", self.0)
94    }
95}
96
97/// A typed hash value.
98#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
99pub struct TypedHash {
100    /// The algorithm of the hash.
101    pub alg: HashAlgId,
102    /// The hash value.
103    pub value: Hash,
104}
105
106/// A hash value.
107#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
108pub struct Hash {
109    // To avoid heap allocation, we use a fixed-size array.
110    // 64 bytes should be sufficient for most hash algorithms.
111    value: [u8; MAX_LEN],
112    len: usize,
113}
114
115impl Default for Hash {
116    fn default() -> Self {
117        Self {
118            value: [0u8; MAX_LEN],
119            len: 0,
120        }
121    }
122}
123
124impl Serialize for Hash {
125    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126    where
127        S: Serializer,
128    {
129        serializer.collect_seq(&self.value[..self.len])
130    }
131}
132
133impl<'de> Deserialize<'de> for Hash {
134    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135    where
136        D: Deserializer<'de>,
137    {
138        use core::marker::PhantomData;
139        use serde::de::{Error, SeqAccess, Visitor};
140
141        struct HashVisitor<'de>(PhantomData<&'de ()>);
142
143        impl<'de> Visitor<'de> for HashVisitor<'de> {
144            type Value = Hash;
145
146            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
147                write!(formatter, "an array at most 64 bytes long")
148            }
149
150            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
151            where
152                A: SeqAccess<'de>,
153            {
154                let mut value = [0; MAX_LEN];
155                let mut len = 0;
156
157                while let Some(byte) = seq.next_element()? {
158                    if len >= MAX_LEN {
159                        return Err(A::Error::invalid_length(len, &self));
160                    }
161
162                    value[len] = byte;
163                    len += 1;
164                }
165
166                Ok(Hash { value, len })
167            }
168        }
169
170        deserializer.deserialize_seq(HashVisitor(PhantomData))
171    }
172}
173
174impl Hash {
175    /// Creates a new hash value.
176    ///
177    /// # Panics
178    ///
179    /// Panics if the length of the value is greater than 64 bytes.
180    fn new(value: &[u8]) -> Self {
181        assert!(
182            value.len() <= MAX_LEN,
183            "hash value must be at most 64 bytes"
184        );
185
186        let mut bytes = [0; MAX_LEN];
187        bytes[..value.len()].copy_from_slice(value);
188
189        Self {
190            value: bytes,
191            len: value.len(),
192        }
193    }
194
195    /// Returns a byte slice of the hash value.
196    pub fn as_bytes(&self) -> &[u8] {
197        &self.value[..self.len]
198    }
199}
200
201impl rs_merkle::Hash for Hash {
202    const SIZE: usize = MAX_LEN;
203}
204
205impl TryFrom<Vec<u8>> for Hash {
206    type Error = &'static str;
207
208    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
209        if value.len() > MAX_LEN {
210            return Err("hash value must be at most 64 bytes");
211        }
212
213        let mut bytes = [0; MAX_LEN];
214        bytes[..value.len()].copy_from_slice(&value);
215
216        Ok(Self {
217            value: bytes,
218            len: value.len(),
219        })
220    }
221}
222
223impl From<Hash> for Vec<u8> {
224    fn from(value: Hash) -> Self {
225        value.value[..value.len].to_vec()
226    }
227}
228
229/// A hashing algorithm.
230pub trait HashAlgorithm {
231    /// Returns the hash algorithm identifier.
232    fn id(&self) -> HashAlgId;
233
234    /// Computes the hash of the provided data.
235    fn hash(&self, data: &[u8]) -> Hash;
236
237    /// Computes the hash of the provided data with a prefix.
238    fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
239}
240
241/// A hash blinder.
242#[derive(Clone, Serialize, Deserialize)]
243pub struct Blinder([u8; 16]);
244
245opaque_debug::implement!(Blinder);
246
247impl Blinder {
248    /// Returns the blinder as a byte slice.
249    pub fn as_bytes(&self) -> &[u8] {
250        &self.0
251    }
252}
253
254impl Distribution<Blinder> for StandardUniform {
255    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
256        let mut blinder = [0; 16];
257        rng.fill(&mut blinder);
258        Blinder(blinder)
259    }
260}
261
262/// A blinded pre-image of a hash.
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct Blinded<T> {
265    data: T,
266    blinder: Blinder,
267}
268
269impl<T> Blinded<T> {
270    /// Creates a new blinded pre-image.
271    pub fn new(data: T) -> Self {
272        Self {
273            data,
274            blinder: rand::random(),
275        }
276    }
277
278    /// Returns the data.
279    pub fn data(&self) -> &T {
280        &self.data
281    }
282}
283
284mod sha2 {
285    use ::sha2::Digest;
286
287    /// SHA-256 hash algorithm.
288    #[derive(Default, Clone)]
289    pub struct Sha256 {}
290
291    impl super::HashAlgorithm for Sha256 {
292        fn id(&self) -> super::HashAlgId {
293            super::HashAlgId::SHA256
294        }
295
296        fn hash(&self, data: &[u8]) -> super::Hash {
297            let mut hasher = ::sha2::Sha256::default();
298            hasher.update(data);
299            super::Hash::new(hasher.finalize().as_slice())
300        }
301
302        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
303            let mut hasher = ::sha2::Sha256::default();
304            hasher.update(prefix);
305            hasher.update(data);
306            super::Hash::new(hasher.finalize().as_slice())
307        }
308    }
309}
310
311pub use sha2::Sha256;
312
313mod blake3 {
314
315    /// BLAKE3 hash algorithm.
316    #[derive(Default, Clone)]
317    pub struct Blake3 {}
318
319    impl super::HashAlgorithm for Blake3 {
320        fn id(&self) -> super::HashAlgId {
321            super::HashAlgId::BLAKE3
322        }
323
324        fn hash(&self, data: &[u8]) -> super::Hash {
325            super::Hash::new(::blake3::hash(data).as_bytes())
326        }
327
328        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
329            let mut hasher = ::blake3::Hasher::new();
330            hasher.update(prefix);
331            hasher.update(data);
332            super::Hash::new(hasher.finalize().as_bytes())
333        }
334    }
335}
336
337pub use blake3::Blake3;
338
339mod keccak {
340    use tiny_keccak::Hasher;
341
342    /// Keccak-256 hash algorithm.
343    #[derive(Default, Clone)]
344    pub struct Keccak256 {}
345
346    impl super::HashAlgorithm for Keccak256 {
347        fn id(&self) -> super::HashAlgId {
348            super::HashAlgId::KECCAK256
349        }
350
351        fn hash(&self, data: &[u8]) -> super::Hash {
352            let mut hasher = tiny_keccak::Keccak::v256();
353            hasher.update(data);
354            let mut output = vec![0; 32];
355            hasher.finalize(&mut output);
356            super::Hash::new(&output)
357        }
358
359        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
360            let mut hasher = tiny_keccak::Keccak::v256();
361            hasher.update(prefix);
362            hasher.update(data);
363            let mut output = vec![0; 32];
364            hasher.finalize(&mut output);
365            super::Hash::new(&output)
366        }
367    }
368}
369
370pub use keccak::Keccak256;