tlsn_core/transcript/encoding/tree.rs

1use std::collections::HashMap;
2
3use bimap::BiMap;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    connection::TranscriptLength,
8    hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
9    merkle::MerkleTree,
10    transcript::{
11        encoding::{
12            proof::{EncodingProof, Opening},
13            EncodingProvider,
14        },
15        Direction, Idx,
16    },
17};
18
19/// Encoding tree builder error.
20#[derive(Debug, thiserror::Error)]
21pub enum EncodingTreeError {
22    /// Index is out of bounds of the transcript.
23    #[error("index is out of bounds of the transcript")]
24    OutOfBounds {
25        /// The index.
26        index: Idx,
27        /// The transcript length.
28        transcript_length: usize,
29    },
30    /// Encoding provider is missing an encoding for an index.
31    #[error("encoding provider is missing an encoding for an index")]
32    MissingEncoding {
33        /// The index which is missing.
34        index: Idx,
35    },
36    /// Index is missing from the tree.
37    #[error("index is missing from the tree")]
38    MissingLeaf {
39        /// The index which is missing.
40        index: Idx,
41    },
42}
43
44/// A merkle tree of transcript encodings.
45#[derive(Clone, Serialize, Deserialize)]
46pub struct EncodingTree {
47    /// Merkle tree of the commitments.
48    tree: MerkleTree,
49    /// Nonces used to blind the hashes.
50    blinders: Vec<Blinder>,
51    /// Mapping between the index of a leaf and the transcript index it
52    /// corresponds to.
53    idxs: BiMap<usize, (Direction, Idx)>,
54    /// Union of all transcript indices in the sent direction.
55    sent_idx: Idx,
56    /// Union of all transcript indices in the received direction.
57    received_idx: Idx,
58}
59
60opaque_debug::implement!(EncodingTree);
61
62impl EncodingTree {
63    /// Creates a new encoding tree.
64    ///
65    /// # Arguments
66    ///
67    /// * `hasher` - The hash algorithm to use.
68    /// * `idxs` - The subsequence indices to commit to.
69    /// * `provider` - The encoding provider.
70    /// * `transcript_length` - The length of the transcript.
71    pub fn new<'idx>(
72        hasher: &dyn HashAlgorithm,
73        idxs: impl IntoIterator<Item = &'idx (Direction, Idx)>,
74        provider: &dyn EncodingProvider,
75        transcript_length: &TranscriptLength,
76    ) -> Result<Self, EncodingTreeError> {
77        let mut this = Self {
78            tree: MerkleTree::new(hasher.id()),
79            blinders: Vec::new(),
80            idxs: BiMap::new(),
81            sent_idx: Idx::empty(),
82            received_idx: Idx::empty(),
83        };
84
85        let mut leaves = Vec::new();
86        let mut encoding = Vec::new();
87        for dir_idx in idxs {
88            let direction = dir_idx.0;
89            let idx = &dir_idx.1;
90
91            // Ignore empty indices.
92            if idx.is_empty() {
93                continue;
94            }
95
96            let len = match direction {
97                Direction::Sent => transcript_length.sent as usize,
98                Direction::Received => transcript_length.received as usize,
99            };
100
101            if idx.end() > len {
102                return Err(EncodingTreeError::OutOfBounds {
103                    index: idx.clone(),
104                    transcript_length: len,
105                });
106            }
107
108            if this.idxs.contains_right(dir_idx) {
109                // The subsequence is already in the tree.
110                continue;
111            }
112
113            let blinder: Blinder = rand::random();
114
115            encoding.clear();
116            for range in idx.iter_ranges() {
117                provider
118                    .provide_encoding(direction, range, &mut encoding)
119                    .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
120            }
121            encoding.extend_from_slice(blinder.as_bytes());
122
123            let leaf = hasher.hash(&encoding);
124
125            leaves.push(leaf);
126            this.blinders.push(blinder);
127            this.idxs.insert(this.idxs.len(), dir_idx.clone());
128            match direction {
129                Direction::Sent => this.sent_idx.union_mut(idx),
130                Direction::Received => this.received_idx.union_mut(idx),
131            }
132        }
133
134        this.tree.insert(hasher, leaves);
135
136        Ok(this)
137    }
138
139    /// Returns the root of the tree.
140    pub fn root(&self) -> TypedHash {
141        self.tree.root()
142    }
143
144    /// Returns the hash algorithm of the tree.
145    pub fn algorithm(&self) -> HashAlgId {
146        self.tree.algorithm()
147    }
148
149    /// Generates a proof for the given indices.
150    ///
151    /// # Arguments
152    ///
153    /// * `idxs` - The transcript indices to prove.
154    pub fn proof<'idx>(
155        &self,
156        idxs: impl Iterator<Item = &'idx (Direction, Idx)>,
157    ) -> Result<EncodingProof, EncodingTreeError> {
158        let mut openings = HashMap::new();
159        for dir_idx in idxs {
160            let direction = dir_idx.0;
161            let idx = &dir_idx.1;
162
163            let leaf_idx = *self
164                .idxs
165                .get_by_right(dir_idx)
166                .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
167            let blinder = self.blinders[leaf_idx].clone();
168
169            openings.insert(
170                leaf_idx,
171                Opening {
172                    direction,
173                    idx: idx.clone(),
174                    blinder,
175                },
176            );
177        }
178
179        let mut indices = openings.keys().copied().collect::<Vec<_>>();
180        indices.sort();
181
182        Ok(EncodingProof {
183            inclusion_proof: self.tree.proof(&indices),
184            openings,
185        })
186    }
187
188    /// Returns whether the tree contains the given transcript index.
189    pub fn contains(&self, idx: &(Direction, Idx)) -> bool {
190        self.idxs.contains_right(idx)
191    }
192
193    pub(crate) fn idx(&self, direction: Direction) -> &Idx {
194        match direction {
195            Direction::Sent => &self.sent_idx,
196            Direction::Received => &self.received_idx,
197        }
198    }
199
200    /// Returns the committed transcript indices.
201    pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, Idx)> {
202        self.idxs.right_values()
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fixtures::{encoder_secret, encoding_provider},
        hash::Blake3,
        transcript::{encoding::EncodingCommitment, Transcript},
        CryptoProvider,
    };
    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    /// Builds an encoding tree over `transcript` with the fixture encoding
    /// provider and a BLAKE3 hasher.
    fn new_tree<'seq>(
        transcript: &Transcript,
        idxs: impl Iterator<Item = &'seq (Direction, Idx)>,
    ) -> Result<EncodingTree, EncodingTreeError> {
        let provider = encoding_provider(transcript.sent(), transcript.received());
        let transcript_length = TranscriptLength {
            sent: transcript.sent().len() as u32,
            received: transcript.received().len() as u32,
        };
        EncodingTree::new(&Blake3::default(), idxs, &provider, &transcript_length)
    }

    #[test]
    fn test_encoding_tree() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));

        let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &CryptoProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        assert_eq!(auth_sent, idx_0.1);
        assert_eq!(auth_recv, idx_1.1);
    }

    #[test]
    fn test_encoding_tree_multiple_ranges() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..1));
        let idx_1 = (Direction::Sent, Idx::new(1..POST_JSON.len()));
        let idx_2 = (Direction::Received, Idx::new(0..1));
        let idx_3 = (Direction::Received, Idx::new(1..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));
        assert!(tree.contains(&idx_2));
        assert!(tree.contains(&idx_3));

        let proof = tree
            .proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
            .unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &CryptoProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        let mut expected_auth_sent = Idx::default();
        expected_auth_sent.union_mut(&idx_0.1);
        expected_auth_sent.union_mut(&idx_1.1);

        let mut expected_auth_recv = Idx::default();
        expected_auth_recv.union_mut(&idx_2.1);
        expected_auth_recv.union_mut(&idx_3.1);

        assert_eq!(auth_sent, expected_auth_sent);
        assert_eq!(auth_recv, expected_auth_recv);
    }

    #[test]
    fn test_encoding_tree_proof_missing_leaf() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, Idx::new(0..4));
        let idx_2 = (Direction::Received, Idx::new(4..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        let result = tree
            .proof([&idx_0, &idx_1, &idx_2].into_iter())
            .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
    }

    #[test]
    fn test_encoding_tree_out_of_bounds() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len() + 1));
        let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len() + 1));

        let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::OutOfBounds { .. }));

        let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::OutOfBounds { .. }));
    }

    #[test]
    fn test_encoding_tree_missing_encoding() {
        let provider = encoding_provider(&[], &[]);
        let transcript_length = TranscriptLength {
            sent: 8,
            received: 8,
        };

        let result = EncodingTree::new(
            &Blake3::default(),
            [(Direction::Sent, Idx::new(0..8))].iter(),
            &provider,
            &transcript_length,
        )
        .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }

    #[test]
    fn test_encoding_tree_skips_empty_and_duplicate_indices() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..4));
        let empty = (Direction::Received, Idx::empty());

        let tree = new_tree(&transcript, [&idx_0, &idx_0, &empty].into_iter()).unwrap();

        // The duplicate is committed once and the empty index not at all.
        assert_eq!(tree.transcript_indices().count(), 1);
        assert!(tree.contains(&idx_0));
        assert!(!tree.contains(&empty));

        assert_eq!(tree.idx(Direction::Sent), &idx_0.1);
        assert!(tree.idx(Direction::Received).is_empty());

        // Empty indices are never committed, so they cannot be proven.
        let result = tree.proof([&empty].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
    }
}