tlsn_core/
hash.rs

1//! Hash types.
2
3use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8use crate::serialize::CanonicalSerialize;
9
10pub(crate) const DEFAULT_SUPPORTED_HASH_ALGS: &[HashAlgId] =
11    &[HashAlgId::SHA256, HashAlgId::BLAKE3, HashAlgId::KECCAK256];
12
/// Capacity, in bytes, of the inline buffer backing a [`Hash`]; no hash
/// value stored in this module can be longer than this.
const MAX_LEN: usize = 64;
15
16/// An error for [`HashProvider`].
17#[derive(Debug, thiserror::Error)]
18#[error("unknown hash algorithm id: {}", self.0)]
19pub struct HashProviderError(HashAlgId);
20
21/// Hash provider.
22pub struct HashProvider {
23    algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
24}
25
26impl Default for HashProvider {
27    fn default() -> Self {
28        let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
29
30        algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
31        algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
32        algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
33
34        Self { algs }
35    }
36}
37
38impl HashProvider {
39    /// Sets a hash algorithm.
40    ///
41    /// This can be used to add or override implementations of hash algorithms.
42    pub fn set_algorithm(
43        &mut self,
44        id: HashAlgId,
45        algorithm: Box<dyn HashAlgorithm + Send + Sync>,
46    ) {
47        self.algs.insert(id, algorithm);
48    }
49
50    /// Returns the hash algorithm with the given identifier, or an error if the
51    /// hash algorithm does not exist.
52    pub fn get(
53        &self,
54        id: &HashAlgId,
55    ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
56        self.algs
57            .get(id)
58            .map(|alg| &**alg)
59            .ok_or(HashProviderError(*id))
60    }
61}
62
63/// A hash algorithm identifier.
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct HashAlgId(u8);
66
67impl HashAlgId {
68    /// SHA-256 hash algorithm.
69    pub const SHA256: Self = Self(1);
70    /// BLAKE3 hash algorithm.
71    pub const BLAKE3: Self = Self(2);
72    /// Keccak-256 hash algorithm.
73    pub const KECCAK256: Self = Self(3);
74
75    /// Creates a new hash algorithm identifier.
76    ///
77    /// # Panics
78    ///
79    /// Panics if the identifier is in the reserved range 0-127.
80    ///
81    /// # Arguments
82    ///
83    /// * id - Unique identifier for the hash algorithm.
84    pub const fn new(id: u8) -> Self {
85        assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
86
87        Self(id)
88    }
89
90    /// Returns the id as a `u8`.
91    pub const fn as_u8(&self) -> u8 {
92        self.0
93    }
94}
95
96impl Display for HashAlgId {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        write!(f, "{:02x}", self.0)
99    }
100}
101
102/// A typed hash value.
103#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
104pub struct TypedHash {
105    /// The algorithm of the hash.
106    pub alg: HashAlgId,
107    /// The hash value.
108    pub value: Hash,
109}
110
111/// A hash value.
112#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
113pub struct Hash {
114    // To avoid heap allocation, we use a fixed-size array.
115    // 64 bytes should be sufficient for most hash algorithms.
116    value: [u8; MAX_LEN],
117    len: usize,
118}
119
120impl Default for Hash {
121    fn default() -> Self {
122        Self {
123            value: [0u8; MAX_LEN],
124            len: 0,
125        }
126    }
127}
128
129impl Serialize for Hash {
130    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
131    where
132        S: Serializer,
133    {
134        serializer.collect_seq(&self.value[..self.len])
135    }
136}
137
138impl<'de> Deserialize<'de> for Hash {
139    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140    where
141        D: Deserializer<'de>,
142    {
143        use core::marker::PhantomData;
144        use serde::de::{Error, SeqAccess, Visitor};
145
146        struct HashVisitor<'de>(PhantomData<&'de ()>);
147
148        impl<'de> Visitor<'de> for HashVisitor<'de> {
149            type Value = Hash;
150
151            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
152                write!(formatter, "an array at most 64 bytes long")
153            }
154
155            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
156            where
157                A: SeqAccess<'de>,
158            {
159                let mut value = [0; MAX_LEN];
160                let mut len = 0;
161
162                while let Some(byte) = seq.next_element()? {
163                    if len >= MAX_LEN {
164                        return Err(A::Error::invalid_length(len, &self));
165                    }
166
167                    value[len] = byte;
168                    len += 1;
169                }
170
171                Ok(Hash { value, len })
172            }
173        }
174
175        deserializer.deserialize_seq(HashVisitor(PhantomData))
176    }
177}
178
179impl Hash {
180    /// Creates a new hash value.
181    ///
182    /// # Panics
183    ///
184    /// Panics if the length of the value is greater than 64 bytes.
185    fn new(value: &[u8]) -> Self {
186        assert!(
187            value.len() <= MAX_LEN,
188            "hash value must be at most 64 bytes"
189        );
190
191        let mut bytes = [0; MAX_LEN];
192        bytes[..value.len()].copy_from_slice(value);
193
194        Self {
195            value: bytes,
196            len: value.len(),
197        }
198    }
199}
200
201impl rs_merkle::Hash for Hash {
202    const SIZE: usize = MAX_LEN;
203}
204
205impl TryFrom<Vec<u8>> for Hash {
206    type Error = &'static str;
207
208    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
209        if value.len() > MAX_LEN {
210            return Err("hash value must be at most 64 bytes");
211        }
212
213        let mut bytes = [0; MAX_LEN];
214        bytes[..value.len()].copy_from_slice(&value);
215
216        Ok(Self {
217            value: bytes,
218            len: value.len(),
219        })
220    }
221}
222
223impl From<Hash> for Vec<u8> {
224    fn from(value: Hash) -> Self {
225        value.value[..value.len].to_vec()
226    }
227}
228
229/// A hashing algorithm.
230pub trait HashAlgorithm {
231    /// Returns the hash algorithm identifier.
232    fn id(&self) -> HashAlgId;
233
234    /// Computes the hash of the provided data.
235    fn hash(&self, data: &[u8]) -> Hash;
236
237    /// Computes the hash of the provided data with a prefix.
238    fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
239}
240
241pub(crate) trait HashAlgorithmExt: HashAlgorithm {
242    #[allow(dead_code)]
243    fn hash_canonical<T: CanonicalSerialize>(&self, data: &T) -> Hash {
244        self.hash(&data.serialize())
245    }
246
247    fn hash_separated<T: DomainSeparator + CanonicalSerialize>(&self, data: &T) -> Hash {
248        self.hash_prefixed(data.domain(), &data.serialize())
249    }
250}
251
252impl<T: HashAlgorithm + ?Sized> HashAlgorithmExt for T {}
253
254/// A hash blinder.
255#[derive(Clone, Serialize, Deserialize)]
256pub struct Blinder([u8; 16]);
257
258opaque_debug::implement!(Blinder);
259
260impl Blinder {
261    /// Returns the blinder as a byte slice.
262    pub fn as_bytes(&self) -> &[u8] {
263        &self.0
264    }
265}
266
267impl Distribution<Blinder> for StandardUniform {
268    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
269        let mut blinder = [0; 16];
270        rng.fill(&mut blinder);
271        Blinder(blinder)
272    }
273}
274
275/// A blinded pre-image of a hash.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub(crate) struct Blinded<T> {
278    data: T,
279    blinder: Blinder,
280}
281
282impl<T> Blinded<T> {
283    /// Creates a new blinded pre-image.
284    pub(crate) fn new(data: T) -> Self {
285        Self {
286            data,
287            blinder: rand::random(),
288        }
289    }
290
291    pub(crate) fn data(&self) -> &T {
292        &self.data
293    }
294}
295
/// A type that supplies a domain separator.
///
/// The separator is used as the hash prefix for the type's serialization
/// (see `HashAlgorithmExt::hash_separated`) to mitigate type confusion
/// attacks.
pub(crate) trait DomainSeparator {
    /// Returns the domain separator bytes for this type.
    fn domain(&self) -> &[u8];
}
302
/// Implements [`DomainSeparator`] for a type, using the first 16 bytes of the
/// BLAKE3 hash of the stringified type (as passed to the macro) as its domain.
///
/// The separator is computed once per type and cached in a `LazyLock`.
/// Changing how the type is written at the invocation site changes the
/// separator, and therefore every hash computed with it.
macro_rules! impl_domain_separator {
    ($type:ty) => {
        impl $crate::hash::DomainSeparator for $type {
            fn domain(&self) -> &[u8] {
                use std::sync::LazyLock;

                // `::blake3` is fully qualified because paths inside a
                // `macro_rules!` body resolve at the expansion site, where a
                // local item named `blake3` (such as the `blake3` module in
                // the hash module) would otherwise shadow the crate.
                static DOMAIN: LazyLock<[u8; 16]> = LazyLock::new(|| {
                    let digest: [u8; 32] = ::blake3::hash(stringify!($type).as_bytes()).into();
                    let mut domain = [0u8; 16];
                    domain.copy_from_slice(&digest[..16]);
                    domain
                });

                &*DOMAIN
            }
        }
    };
}

pub(crate) use impl_domain_separator;
322
323mod sha2 {
324    use ::sha2::Digest;
325
326    /// SHA-256 hash algorithm.
327    #[derive(Default, Clone)]
328    pub struct Sha256 {}
329
330    impl super::HashAlgorithm for Sha256 {
331        fn id(&self) -> super::HashAlgId {
332            super::HashAlgId::SHA256
333        }
334
335        fn hash(&self, data: &[u8]) -> super::Hash {
336            let mut hasher = ::sha2::Sha256::default();
337            hasher.update(data);
338            super::Hash::new(hasher.finalize().as_slice())
339        }
340
341        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
342            let mut hasher = ::sha2::Sha256::default();
343            hasher.update(prefix);
344            hasher.update(data);
345            super::Hash::new(hasher.finalize().as_slice())
346        }
347    }
348}
349
350pub use sha2::Sha256;
351
352mod blake3 {
353
354    /// BLAKE3 hash algorithm.
355    #[derive(Default, Clone)]
356    pub struct Blake3 {}
357
358    impl super::HashAlgorithm for Blake3 {
359        fn id(&self) -> super::HashAlgId {
360            super::HashAlgId::BLAKE3
361        }
362
363        fn hash(&self, data: &[u8]) -> super::Hash {
364            super::Hash::new(::blake3::hash(data).as_bytes())
365        }
366
367        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
368            let mut hasher = ::blake3::Hasher::new();
369            hasher.update(prefix);
370            hasher.update(data);
371            super::Hash::new(hasher.finalize().as_bytes())
372        }
373    }
374}
375
376pub use blake3::Blake3;
377
378mod keccak {
379    use tiny_keccak::Hasher;
380
381    /// Keccak-256 hash algorithm.
382    #[derive(Default, Clone)]
383    pub struct Keccak256 {}
384
385    impl super::HashAlgorithm for Keccak256 {
386        fn id(&self) -> super::HashAlgId {
387            super::HashAlgId::KECCAK256
388        }
389
390        fn hash(&self, data: &[u8]) -> super::Hash {
391            let mut hasher = tiny_keccak::Keccak::v256();
392            hasher.update(data);
393            let mut output = vec![0; 32];
394            hasher.finalize(&mut output);
395            super::Hash::new(&output)
396        }
397
398        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
399            let mut hasher = tiny_keccak::Keccak::v256();
400            hasher.update(prefix);
401            hasher.update(data);
402            let mut output = vec![0; 32];
403            hasher.finalize(&mut output);
404            super::Hash::new(&output)
405        }
406    }
407}
408
409pub use keccak::Keccak256;