tlsn_core/
hash.rs

1//! Hash types.
2
3use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8use crate::serialize::CanonicalSerialize;
9
10pub(crate) const DEFAULT_SUPPORTED_HASH_ALGS: &[HashAlgId] =
11    &[HashAlgId::SHA256, HashAlgId::BLAKE3, HashAlgId::KECCAK256];
12
13/// Maximum length of a hash value.
14const MAX_LEN: usize = 64;
15
16/// An error for [`HashProvider`].
17#[derive(Debug, thiserror::Error)]
18#[error("unknown hash algorithm id: {}", self.0)]
19pub struct HashProviderError(HashAlgId);
20
21/// Hash provider.
22pub struct HashProvider {
23    algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
24}
25
26impl Default for HashProvider {
27    fn default() -> Self {
28        let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
29
30        algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
31        algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
32        algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
33
34        Self { algs }
35    }
36}
37
38impl HashProvider {
39    /// Sets a hash algorithm.
40    ///
41    /// This can be used to add or override implementations of hash algorithms.
42    pub fn set_algorithm(
43        &mut self,
44        id: HashAlgId,
45        algorithm: Box<dyn HashAlgorithm + Send + Sync>,
46    ) {
47        self.algs.insert(id, algorithm);
48    }
49
50    /// Returns the hash algorithm with the given identifier, or an error if the
51    /// hash algorithm does not exist.
52    pub fn get(
53        &self,
54        id: &HashAlgId,
55    ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
56        self.algs
57            .get(id)
58            .map(|alg| &**alg)
59            .ok_or(HashProviderError(*id))
60    }
61}
62
63/// A hash algorithm identifier.
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct HashAlgId(u8);
66
67impl HashAlgId {
68    /// SHA-256 hash algorithm.
69    pub const SHA256: Self = Self(1);
70    /// BLAKE3 hash algorithm.
71    pub const BLAKE3: Self = Self(2);
72    /// Keccak-256 hash algorithm.
73    pub const KECCAK256: Self = Self(3);
74
75    /// Creates a new hash algorithm identifier.
76    ///
77    /// # Panics
78    ///
79    /// Panics if the identifier is in the reserved range 0-127.
80    ///
81    /// # Arguments
82    ///
83    /// * id - Unique identifier for the hash algorithm.
84    pub const fn new(id: u8) -> Self {
85        assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
86
87        Self(id)
88    }
89
90    /// Returns the id as a `u8`.
91    pub const fn as_u8(&self) -> u8 {
92        self.0
93    }
94}
95
96impl Display for HashAlgId {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        write!(f, "{:02x}", self.0)
99    }
100}
101
102/// A typed hash value.
103#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
104pub struct TypedHash {
105    /// The algorithm of the hash.
106    pub alg: HashAlgId,
107    /// The hash value.
108    pub value: Hash,
109}
110
111/// A hash value.
112#[derive(Debug, Clone, Copy, PartialEq, Eq)]
113pub struct Hash {
114    // To avoid heap allocation, we use a fixed-size array.
115    // 64 bytes should be sufficient for most hash algorithms.
116    value: [u8; MAX_LEN],
117    len: usize,
118}
119
120impl Default for Hash {
121    fn default() -> Self {
122        Self {
123            value: [0u8; MAX_LEN],
124            len: 0,
125        }
126    }
127}
128
129impl Serialize for Hash {
130    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
131    where
132        S: Serializer,
133    {
134        serializer.collect_seq(&self.value[..self.len])
135    }
136}
137
138impl<'de> Deserialize<'de> for Hash {
139    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140    where
141        D: Deserializer<'de>,
142    {
143        use core::marker::PhantomData;
144        use serde::de::{Error, SeqAccess, Visitor};
145
146        struct HashVisitor<'de>(PhantomData<&'de ()>);
147
148        impl<'de> Visitor<'de> for HashVisitor<'de> {
149            type Value = Hash;
150
151            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
152                write!(formatter, "an array at most 64 bytes long")
153            }
154
155            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
156            where
157                A: SeqAccess<'de>,
158            {
159                let mut value = [0; MAX_LEN];
160                let mut len = 0;
161
162                while let Some(byte) = seq.next_element()? {
163                    if len >= MAX_LEN {
164                        return Err(A::Error::invalid_length(len, &self));
165                    }
166
167                    value[len] = byte;
168                    len += 1;
169                }
170
171                Ok(Hash { value, len })
172            }
173        }
174
175        deserializer.deserialize_seq(HashVisitor(PhantomData))
176    }
177}
178
179impl Hash {
180    /// Creates a new hash value.
181    ///
182    /// # Panics
183    ///
184    /// Panics if the length of the value is greater than 64 bytes.
185    fn new(value: &[u8]) -> Self {
186        assert!(
187            value.len() <= MAX_LEN,
188            "hash value must be at most 64 bytes"
189        );
190
191        let mut bytes = [0; MAX_LEN];
192        bytes[..value.len()].copy_from_slice(value);
193
194        Self {
195            value: bytes,
196            len: value.len(),
197        }
198    }
199}
200
201impl rs_merkle::Hash for Hash {
202    const SIZE: usize = MAX_LEN;
203}
204
205impl TryFrom<Vec<u8>> for Hash {
206    type Error = &'static str;
207
208    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
209        if value.len() > MAX_LEN {
210            return Err("hash value must be at most 64 bytes");
211        }
212
213        let mut bytes = [0; MAX_LEN];
214        bytes[..value.len()].copy_from_slice(&value);
215
216        Ok(Self {
217            value: bytes,
218            len: value.len(),
219        })
220    }
221}
222
223impl From<Hash> for Vec<u8> {
224    fn from(value: Hash) -> Self {
225        value.value[..value.len].to_vec()
226    }
227}
228
229/// A hashing algorithm.
230pub trait HashAlgorithm {
231    /// Returns the hash algorithm identifier.
232    fn id(&self) -> HashAlgId;
233
234    /// Computes the hash of the provided data.
235    fn hash(&self, data: &[u8]) -> Hash;
236
237    /// Computes the hash of the provided data with a prefix.
238    fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
239}
240
241pub(crate) trait HashAlgorithmExt: HashAlgorithm {
242    #[allow(dead_code)]
243    fn hash_canonical<T: CanonicalSerialize>(&self, data: &T) -> Hash {
244        self.hash(&data.serialize())
245    }
246
247    fn hash_separated<T: DomainSeparator + CanonicalSerialize>(&self, data: &T) -> Hash {
248        self.hash_prefixed(data.domain(), &data.serialize())
249    }
250}
251
252impl<T: HashAlgorithm + ?Sized> HashAlgorithmExt for T {}
253
254/// A hash blinder.
255#[derive(Clone, Serialize, Deserialize)]
256pub(crate) struct Blinder([u8; 16]);
257
258opaque_debug::implement!(Blinder);
259
260impl Blinder {
261    pub(crate) fn as_bytes(&self) -> &[u8] {
262        &self.0
263    }
264}
265
266impl Distribution<Blinder> for StandardUniform {
267    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
268        let mut blinder = [0; 16];
269        rng.fill(&mut blinder);
270        Blinder(blinder)
271    }
272}
273
274/// A blinded pre-image of a hash.
275#[derive(Debug, Clone, Serialize, Deserialize)]
276pub(crate) struct Blinded<T> {
277    data: T,
278    blinder: Blinder,
279}
280
281impl<T> Blinded<T> {
282    /// Creates a new blinded pre-image.
283    pub(crate) fn new(data: T) -> Self {
284        Self {
285            data,
286            blinder: rand::random(),
287        }
288    }
289
290    pub(crate) fn data(&self) -> &T {
291        &self.data
292    }
293}
294
/// Implemented by types that carry a domain separator, which is mixed into
/// their hash to mitigate type confusion attacks.
pub(crate) trait DomainSeparator {
    /// Returns the domain separator for the type.
    fn domain(&self) -> &[u8];
}

/// Implements [`DomainSeparator`] for the given type.
///
/// The separator is the first 16 bytes of the BLAKE3 hash of the type as
/// rendered by `stringify!`. It is computed once, on first use, and cached
/// in a static.
macro_rules! impl_domain_separator {
    ($type:ty) => {
        impl $crate::hash::DomainSeparator for $type {
            fn domain(&self) -> &[u8] {
                use std::sync::LazyLock;

                // Truncated 16 byte hash of the type's name.
                static DOMAIN: LazyLock<[u8; 16]> = LazyLock::new(|| {
                    let digest = blake3::hash(stringify!($type).as_bytes());
                    let mut truncated = [0u8; 16];
                    truncated.copy_from_slice(&digest.as_bytes()[..16]);
                    truncated
                });

                DOMAIN.as_slice()
            }
        }
    };
}

pub(crate) use impl_domain_separator;
321
322mod sha2 {
323    use ::sha2::Digest;
324
325    /// SHA-256 hash algorithm.
326    #[derive(Default, Clone)]
327    pub struct Sha256 {}
328
329    impl super::HashAlgorithm for Sha256 {
330        fn id(&self) -> super::HashAlgId {
331            super::HashAlgId::SHA256
332        }
333
334        fn hash(&self, data: &[u8]) -> super::Hash {
335            let mut hasher = ::sha2::Sha256::default();
336            hasher.update(data);
337            super::Hash::new(hasher.finalize().as_slice())
338        }
339
340        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
341            let mut hasher = ::sha2::Sha256::default();
342            hasher.update(prefix);
343            hasher.update(data);
344            super::Hash::new(hasher.finalize().as_slice())
345        }
346    }
347}
348
349pub use sha2::Sha256;
350
351mod blake3 {
352
353    /// BLAKE3 hash algorithm.
354    #[derive(Default, Clone)]
355    pub struct Blake3 {}
356
357    impl super::HashAlgorithm for Blake3 {
358        fn id(&self) -> super::HashAlgId {
359            super::HashAlgId::BLAKE3
360        }
361
362        fn hash(&self, data: &[u8]) -> super::Hash {
363            super::Hash::new(::blake3::hash(data).as_bytes())
364        }
365
366        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
367            let mut hasher = ::blake3::Hasher::new();
368            hasher.update(prefix);
369            hasher.update(data);
370            super::Hash::new(hasher.finalize().as_bytes())
371        }
372    }
373}
374
375pub use blake3::Blake3;
376
377mod keccak {
378    use tiny_keccak::Hasher;
379
380    /// Keccak-256 hash algorithm.
381    #[derive(Default, Clone)]
382    pub struct Keccak256 {}
383
384    impl super::HashAlgorithm for Keccak256 {
385        fn id(&self) -> super::HashAlgId {
386            super::HashAlgId::KECCAK256
387        }
388
389        fn hash(&self, data: &[u8]) -> super::Hash {
390            let mut hasher = tiny_keccak::Keccak::v256();
391            hasher.update(data);
392            let mut output = vec![0; 32];
393            hasher.finalize(&mut output);
394            super::Hash::new(&output)
395        }
396
397        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
398            let mut hasher = tiny_keccak::Keccak::v256();
399            hasher.update(prefix);
400            hasher.update(data);
401            let mut output = vec![0; 32];
402            hasher.finalize(&mut output);
403            super::Hash::new(&output)
404        }
405    }
406}
407
408pub use keccak::Keccak256;