tlsn_core/
hash.rs

1//! Hash types.
2
3use std::{collections::HashMap, fmt::Display};
4
5use rand::{distr::StandardUniform, prelude::Distribution};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
/// Upper bound, in bytes, on the length of any [`Hash`] value.
///
/// [`Hash`] keeps its bytes inline in an array of exactly this size, so a
/// longer digest cannot be represented.
const MAX_LEN: usize = 64;
10
11/// An error for [`HashProvider`].
12#[derive(Debug, thiserror::Error)]
13#[error("unknown hash algorithm id: {}", self.0)]
14pub struct HashProviderError(HashAlgId);
15
16/// Hash provider.
17pub struct HashProvider {
18    algs: HashMap<HashAlgId, Box<dyn HashAlgorithm + Send + Sync>>,
19}
20
21impl Default for HashProvider {
22    fn default() -> Self {
23        let mut algs: HashMap<_, Box<dyn HashAlgorithm + Send + Sync>> = HashMap::new();
24
25        algs.insert(HashAlgId::SHA256, Box::new(Sha256::default()));
26        algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default()));
27        algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default()));
28
29        Self { algs }
30    }
31}
32
33impl HashProvider {
34    /// Sets a hash algorithm.
35    ///
36    /// This can be used to add or override implementations of hash algorithms.
37    pub fn set_algorithm(
38        &mut self,
39        id: HashAlgId,
40        algorithm: Box<dyn HashAlgorithm + Send + Sync>,
41    ) {
42        self.algs.insert(id, algorithm);
43    }
44
45    /// Returns the hash algorithm with the given identifier, or an error if the
46    /// hash algorithm does not exist.
47    pub fn get(
48        &self,
49        id: &HashAlgId,
50    ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> {
51        self.algs
52            .get(id)
53            .map(|alg| &**alg)
54            .ok_or(HashProviderError(*id))
55    }
56}
57
58/// A hash algorithm identifier.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct HashAlgId(u8);
61
62impl HashAlgId {
63    /// SHA-256 hash algorithm.
64    pub const SHA256: Self = Self(1);
65    /// BLAKE3 hash algorithm.
66    pub const BLAKE3: Self = Self(2);
67    /// Keccak-256 hash algorithm.
68    pub const KECCAK256: Self = Self(3);
69
70    /// Creates a new hash algorithm identifier.
71    ///
72    /// # Panics
73    ///
74    /// Panics if the identifier is in the reserved range 0-127.
75    ///
76    /// # Arguments
77    ///
78    /// * id - Unique identifier for the hash algorithm.
79    pub const fn new(id: u8) -> Self {
80        assert!(id >= 128, "hash algorithm id range 0-127 is reserved");
81
82        Self(id)
83    }
84
85    /// Returns the id as a `u8`.
86    pub const fn as_u8(&self) -> u8 {
87        self.0
88    }
89}
90
91impl Display for HashAlgId {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "{:02x}", self.0)
94    }
95}
96
97/// A typed hash value.
98#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
99pub struct TypedHash {
100    /// The algorithm of the hash.
101    pub alg: HashAlgId,
102    /// The hash value.
103    pub value: Hash,
104}
105
106/// A hash value.
107#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
108pub struct Hash {
109    // To avoid heap allocation, we use a fixed-size array.
110    // 64 bytes should be sufficient for most hash algorithms.
111    value: [u8; MAX_LEN],
112    len: usize,
113}
114
115impl Default for Hash {
116    fn default() -> Self {
117        Self {
118            value: [0u8; MAX_LEN],
119            len: 0,
120        }
121    }
122}
123
124impl Serialize for Hash {
125    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126    where
127        S: Serializer,
128    {
129        serializer.collect_seq(&self.value[..self.len])
130    }
131}
132
133impl<'de> Deserialize<'de> for Hash {
134    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135    where
136        D: Deserializer<'de>,
137    {
138        use core::marker::PhantomData;
139        use serde::de::{Error, SeqAccess, Visitor};
140
141        struct HashVisitor<'de>(PhantomData<&'de ()>);
142
143        impl<'de> Visitor<'de> for HashVisitor<'de> {
144            type Value = Hash;
145
146            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
147                write!(formatter, "an array at most 64 bytes long")
148            }
149
150            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
151            where
152                A: SeqAccess<'de>,
153            {
154                let mut value = [0; MAX_LEN];
155                let mut len = 0;
156
157                while let Some(byte) = seq.next_element()? {
158                    if len >= MAX_LEN {
159                        return Err(A::Error::invalid_length(len, &self));
160                    }
161
162                    value[len] = byte;
163                    len += 1;
164                }
165
166                Ok(Hash { value, len })
167            }
168        }
169
170        deserializer.deserialize_seq(HashVisitor(PhantomData))
171    }
172}
173
174impl Hash {
175    /// Creates a new hash value.
176    ///
177    /// # Panics
178    ///
179    /// Panics if the length of the value is greater than 64 bytes.
180    fn new(value: &[u8]) -> Self {
181        assert!(
182            value.len() <= MAX_LEN,
183            "hash value must be at most 64 bytes"
184        );
185
186        let mut bytes = [0; MAX_LEN];
187        bytes[..value.len()].copy_from_slice(value);
188
189        Self {
190            value: bytes,
191            len: value.len(),
192        }
193    }
194}
195
196impl rs_merkle::Hash for Hash {
197    const SIZE: usize = MAX_LEN;
198}
199
200impl TryFrom<Vec<u8>> for Hash {
201    type Error = &'static str;
202
203    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
204        if value.len() > MAX_LEN {
205            return Err("hash value must be at most 64 bytes");
206        }
207
208        let mut bytes = [0; MAX_LEN];
209        bytes[..value.len()].copy_from_slice(&value);
210
211        Ok(Self {
212            value: bytes,
213            len: value.len(),
214        })
215    }
216}
217
218impl From<Hash> for Vec<u8> {
219    fn from(value: Hash) -> Self {
220        value.value[..value.len].to_vec()
221    }
222}
223
224/// A hashing algorithm.
225pub trait HashAlgorithm {
226    /// Returns the hash algorithm identifier.
227    fn id(&self) -> HashAlgId;
228
229    /// Computes the hash of the provided data.
230    fn hash(&self, data: &[u8]) -> Hash;
231
232    /// Computes the hash of the provided data with a prefix.
233    fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash;
234}
235
236/// A hash blinder.
237#[derive(Clone, Serialize, Deserialize)]
238pub struct Blinder([u8; 16]);
239
240opaque_debug::implement!(Blinder);
241
242impl Blinder {
243    /// Returns the blinder as a byte slice.
244    pub fn as_bytes(&self) -> &[u8] {
245        &self.0
246    }
247}
248
249impl Distribution<Blinder> for StandardUniform {
250    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Blinder {
251        let mut blinder = [0; 16];
252        rng.fill(&mut blinder);
253        Blinder(blinder)
254    }
255}
256
257/// A blinded pre-image of a hash.
258#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct Blinded<T> {
260    data: T,
261    blinder: Blinder,
262}
263
264impl<T> Blinded<T> {
265    /// Creates a new blinded pre-image.
266    pub fn new(data: T) -> Self {
267        Self {
268            data,
269            blinder: rand::random(),
270        }
271    }
272
273    /// Returns the data.
274    pub fn data(&self) -> &T {
275        &self.data
276    }
277}
278
279mod sha2 {
280    use ::sha2::Digest;
281
282    /// SHA-256 hash algorithm.
283    #[derive(Default, Clone)]
284    pub struct Sha256 {}
285
286    impl super::HashAlgorithm for Sha256 {
287        fn id(&self) -> super::HashAlgId {
288            super::HashAlgId::SHA256
289        }
290
291        fn hash(&self, data: &[u8]) -> super::Hash {
292            let mut hasher = ::sha2::Sha256::default();
293            hasher.update(data);
294            super::Hash::new(hasher.finalize().as_slice())
295        }
296
297        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
298            let mut hasher = ::sha2::Sha256::default();
299            hasher.update(prefix);
300            hasher.update(data);
301            super::Hash::new(hasher.finalize().as_slice())
302        }
303    }
304}
305
306pub use sha2::Sha256;
307
308mod blake3 {
309
310    /// BLAKE3 hash algorithm.
311    #[derive(Default, Clone)]
312    pub struct Blake3 {}
313
314    impl super::HashAlgorithm for Blake3 {
315        fn id(&self) -> super::HashAlgId {
316            super::HashAlgId::BLAKE3
317        }
318
319        fn hash(&self, data: &[u8]) -> super::Hash {
320            super::Hash::new(::blake3::hash(data).as_bytes())
321        }
322
323        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
324            let mut hasher = ::blake3::Hasher::new();
325            hasher.update(prefix);
326            hasher.update(data);
327            super::Hash::new(hasher.finalize().as_bytes())
328        }
329    }
330}
331
332pub use blake3::Blake3;
333
334mod keccak {
335    use tiny_keccak::Hasher;
336
337    /// Keccak-256 hash algorithm.
338    #[derive(Default, Clone)]
339    pub struct Keccak256 {}
340
341    impl super::HashAlgorithm for Keccak256 {
342        fn id(&self) -> super::HashAlgId {
343            super::HashAlgId::KECCAK256
344        }
345
346        fn hash(&self, data: &[u8]) -> super::Hash {
347            let mut hasher = tiny_keccak::Keccak::v256();
348            hasher.update(data);
349            let mut output = vec![0; 32];
350            hasher.finalize(&mut output);
351            super::Hash::new(&output)
352        }
353
354        fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
355            let mut hasher = tiny_keccak::Keccak::v256();
356            hasher.update(prefix);
357            hasher.update(data);
358            let mut output = vec![0; 32];
359            hasher.finalize(&mut output);
360            super::Hash::new(&output)
361        }
362    }
363}
364
365pub use keccak::Keccak256;