// tlsn_core/transcript/encoding/tree.rs

1use std::collections::HashMap;
2
3use bimap::BiMap;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
8    merkle::MerkleTree,
9    transcript::{
10        encoding::{
11            proof::{EncodingProof, Opening},
12            EncodingProvider,
13        },
14        Direction, Idx,
15    },
16};
17
18/// Encoding tree builder error.
19#[derive(Debug, thiserror::Error)]
20pub enum EncodingTreeError {
21    /// Index is out of bounds of the transcript.
22    #[error("index is out of bounds of the transcript")]
23    OutOfBounds {
24        /// The index.
25        index: Idx,
26        /// The transcript length.
27        transcript_length: usize,
28    },
29    /// Encoding provider is missing an encoding for an index.
30    #[error("encoding provider is missing an encoding for an index")]
31    MissingEncoding {
32        /// The index which is missing.
33        index: Idx,
34    },
35    /// Index is missing from the tree.
36    #[error("index is missing from the tree")]
37    MissingLeaf {
38        /// The index which is missing.
39        index: Idx,
40    },
41}
42
43/// A merkle tree of transcript encodings.
44#[derive(Clone, Serialize, Deserialize)]
45pub struct EncodingTree {
46    /// Merkle tree of the commitments.
47    tree: MerkleTree,
48    /// Nonces used to blind the hashes.
49    blinders: Vec<Blinder>,
50    /// Mapping between the index of a leaf and the transcript index it
51    /// corresponds to.
52    idxs: BiMap<usize, (Direction, Idx)>,
53    /// Union of all transcript indices in the sent direction.
54    sent_idx: Idx,
55    /// Union of all transcript indices in the received direction.
56    received_idx: Idx,
57}
58
59opaque_debug::implement!(EncodingTree);
60
61impl EncodingTree {
62    /// Creates a new encoding tree.
63    ///
64    /// # Arguments
65    ///
66    /// * `hasher` - The hash algorithm to use.
67    /// * `idxs` - The subsequence indices to commit to.
68    /// * `provider` - The encoding provider.
69    pub fn new<'idx>(
70        hasher: &dyn HashAlgorithm,
71        idxs: impl IntoIterator<Item = &'idx (Direction, Idx)>,
72        provider: &dyn EncodingProvider,
73    ) -> Result<Self, EncodingTreeError> {
74        let mut this = Self {
75            tree: MerkleTree::new(hasher.id()),
76            blinders: Vec::new(),
77            idxs: BiMap::new(),
78            sent_idx: Idx::empty(),
79            received_idx: Idx::empty(),
80        };
81
82        let mut leaves = Vec::new();
83        let mut encoding = Vec::new();
84        for dir_idx in idxs {
85            let direction = dir_idx.0;
86            let idx = &dir_idx.1;
87
88            // Ignore empty indices.
89            if idx.is_empty() {
90                continue;
91            }
92
93            if this.idxs.contains_right(dir_idx) {
94                // The subsequence is already in the tree.
95                continue;
96            }
97
98            let blinder: Blinder = rand::random();
99
100            encoding.clear();
101            for range in idx.iter_ranges() {
102                provider
103                    .provide_encoding(direction, range, &mut encoding)
104                    .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
105            }
106            encoding.extend_from_slice(blinder.as_bytes());
107
108            let leaf = hasher.hash(&encoding);
109
110            leaves.push(leaf);
111            this.blinders.push(blinder);
112            this.idxs.insert(this.idxs.len(), dir_idx.clone());
113            match direction {
114                Direction::Sent => this.sent_idx.union_mut(idx),
115                Direction::Received => this.received_idx.union_mut(idx),
116            }
117        }
118
119        this.tree.insert(hasher, leaves);
120
121        Ok(this)
122    }
123
124    /// Returns the root of the tree.
125    pub fn root(&self) -> TypedHash {
126        self.tree.root()
127    }
128
129    /// Returns the hash algorithm of the tree.
130    pub fn algorithm(&self) -> HashAlgId {
131        self.tree.algorithm()
132    }
133
134    /// Generates a proof for the given indices.
135    ///
136    /// # Arguments
137    ///
138    /// * `idxs` - The transcript indices to prove.
139    pub fn proof<'idx>(
140        &self,
141        idxs: impl Iterator<Item = &'idx (Direction, Idx)>,
142    ) -> Result<EncodingProof, EncodingTreeError> {
143        let mut openings = HashMap::new();
144        for dir_idx in idxs {
145            let direction = dir_idx.0;
146            let idx = &dir_idx.1;
147
148            let leaf_idx = *self
149                .idxs
150                .get_by_right(dir_idx)
151                .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
152            let blinder = self.blinders[leaf_idx].clone();
153
154            openings.insert(
155                leaf_idx,
156                Opening {
157                    direction,
158                    idx: idx.clone(),
159                    blinder,
160                },
161            );
162        }
163
164        let mut indices = openings.keys().copied().collect::<Vec<_>>();
165        indices.sort();
166
167        Ok(EncodingProof {
168            inclusion_proof: self.tree.proof(&indices),
169            openings,
170        })
171    }
172
173    /// Returns whether the tree contains the given transcript index.
174    pub fn contains(&self, idx: &(Direction, Idx)) -> bool {
175        self.idxs.contains_right(idx)
176    }
177
178    pub(crate) fn idx(&self, direction: Direction) -> &Idx {
179        match direction {
180            Direction::Sent => &self.sent_idx,
181            Direction::Received => &self.received_idx,
182        }
183    }
184
185    /// Returns the committed transcript indices.
186    pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, Idx)> {
187        self.idxs.right_values()
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fixtures::{encoder_secret, encoding_provider},
        hash::{Blake3, HashProvider},
        transcript::{encoding::EncodingCommitment, Transcript},
    };
    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    /// Builds a Blake3 encoding tree over `transcript` for the given indices.
    fn new_tree<'seq>(
        transcript: &Transcript,
        idxs: impl Iterator<Item = &'seq (Direction, Idx)>,
    ) -> Result<EncodingTree, EncodingTreeError> {
        let provider = encoding_provider(transcript.sent(), transcript.received());

        EncodingTree::new(&Blake3::default(), idxs, &provider)
    }

    #[test]
    fn test_encoding_tree() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));

        let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        assert_eq!(auth_sent, idx_0.1);
        assert_eq!(auth_recv, idx_1.1);
    }

    #[test]
    fn test_encoding_tree_multiple_ranges() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..1));
        let idx_1 = (Direction::Sent, Idx::new(1..POST_JSON.len()));
        let idx_2 = (Direction::Received, Idx::new(0..1));
        let idx_3 = (Direction::Received, Idx::new(1..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));
        assert!(tree.contains(&idx_2));
        assert!(tree.contains(&idx_3));

        let proof = tree
            .proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
            .unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        let mut expected_auth_sent = Idx::default();
        expected_auth_sent.union_mut(&idx_0.1);
        expected_auth_sent.union_mut(&idx_1.1);

        let mut expected_auth_recv = Idx::default();
        expected_auth_recv.union_mut(&idx_2.1);
        expected_auth_recv.union_mut(&idx_3.1);

        assert_eq!(auth_sent, expected_auth_sent);
        assert_eq!(auth_recv, expected_auth_recv);
    }

    #[test]
    fn test_encoding_tree_proof_missing_leaf() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, Idx::new(0..4));
        let idx_2 = (Direction::Received, Idx::new(4..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        let result = tree
            .proof([&idx_0, &idx_1, &idx_2].into_iter())
            .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
    }

    #[test]
    fn test_encoding_tree_out_of_bounds() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len() + 1));
        let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len() + 1));

        let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));

        let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }

    #[test]
    fn test_encoding_tree_missing_encoding() {
        let provider = encoding_provider(&[], &[]);

        let result = EncodingTree::new(
            &Blake3::default(),
            [(Direction::Sent, Idx::new(0..8))].iter(),
            &provider,
        )
        .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }

    /// Empty indices must be skipped, repeated indices committed once, and the
    /// per-direction unions must reflect only the committed leaves.
    #[test]
    fn test_encoding_tree_ignores_empty_and_duplicate_indices() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
        let idx_empty = (Direction::Received, Idx::empty());

        let tree = new_tree(&transcript, [&idx_0, &idx_empty, &idx_0].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(!tree.contains(&idx_empty));
        assert_eq!(tree.transcript_indices().count(), 1);

        let mut expected_sent = Idx::empty();
        expected_sent.union_mut(&idx_0.1);
        assert_eq!(tree.idx(Direction::Sent), &expected_sent);
        assert!(tree.idx(Direction::Received).is_empty());

        let result = tree.proof([&idx_empty].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
    }
}