tlsn_core/transcript/encoding/
tree.rs

1use std::collections::HashMap;
2
3use bimap::BiMap;
4use rangeset::{RangeSet, UnionMut};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
9    merkle::MerkleTree,
10    transcript::{
11        encoding::{
12            proof::{EncodingProof, Opening},
13            EncodingProvider,
14        },
15        Direction,
16    },
17};
18
/// Encoding tree builder error.
///
/// This is the error type returned by both [`EncodingTree::new`] and
/// [`EncodingTree::proof`].
#[derive(Debug, thiserror::Error)]
pub enum EncodingTreeError {
    /// Index is out of bounds of the transcript.
    // NOTE(review): nothing in this file constructs this variant. Indices that
    // run past the end of the transcript surface from `EncodingTree::new` as
    // `MissingEncoding` instead (see `test_encoding_tree_out_of_bounds`).
    // Confirm whether other modules still produce it.
    #[error("index is out of bounds of the transcript")]
    OutOfBounds {
        /// The index.
        index: RangeSet<usize>,
        /// The transcript length.
        transcript_length: usize,
    },
    /// Encoding provider is missing an encoding for an index.
    ///
    /// Produced by [`EncodingTree::new`] when the provider returns an error
    /// for any range of a requested index.
    #[error("encoding provider is missing an encoding for an index")]
    MissingEncoding {
        /// The index which is missing.
        index: RangeSet<usize>,
    },
    /// Index is missing from the tree.
    ///
    /// Produced by [`EncodingTree::proof`] when asked to open an index that
    /// was not committed to when the tree was built.
    #[error("index is missing from the tree")]
    MissingLeaf {
        /// The index which is missing.
        index: RangeSet<usize>,
    },
}
43
/// A merkle tree of transcript encodings.
///
/// Every leaf commits to one `(Direction, RangeSet<usize>)` subsequence of the
/// transcript: the leaf is the hash of that subsequence's encoding followed by
/// the bytes of a random blinder.
#[derive(Clone, Serialize, Deserialize)]
pub struct EncodingTree {
    /// Merkle tree of the commitments.
    tree: MerkleTree,
    /// Nonces used to blind the hashes, stored in leaf order: the blinder at
    /// position `i` belongs to the leaf with index `i`.
    blinders: Vec<Blinder>,
    /// Mapping between the index of a leaf and the transcript index it
    /// corresponds to.
    idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
    /// Union of all transcript indices in the sent direction.
    sent_idx: RangeSet<usize>,
    /// Union of all transcript indices in the received direction.
    received_idx: RangeSet<usize>,
}

// `Debug` is supplied by `opaque_debug` rather than a derive.
// NOTE(review): presumably so the blinders are never printed — confirm against
// the `opaque_debug` documentation.
opaque_debug::implement!(EncodingTree);
61
62impl EncodingTree {
63    /// Creates a new encoding tree.
64    ///
65    /// # Arguments
66    ///
67    /// * `hasher` - The hash algorithm to use.
68    /// * `idxs` - The subsequence indices to commit to.
69    /// * `provider` - The encoding provider.
70    pub fn new<'idx>(
71        hasher: &dyn HashAlgorithm,
72        idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
73        provider: &dyn EncodingProvider,
74    ) -> Result<Self, EncodingTreeError> {
75        let mut this = Self {
76            tree: MerkleTree::new(hasher.id()),
77            blinders: Vec::new(),
78            idxs: BiMap::new(),
79            sent_idx: RangeSet::default(),
80            received_idx: RangeSet::default(),
81        };
82
83        let mut leaves = Vec::new();
84        let mut encoding = Vec::new();
85        for dir_idx in idxs {
86            let direction = dir_idx.0;
87            let idx = &dir_idx.1;
88
89            // Ignore empty indices.
90            if idx.is_empty() {
91                continue;
92            }
93
94            if this.idxs.contains_right(dir_idx) {
95                // The subsequence is already in the tree.
96                continue;
97            }
98
99            let blinder: Blinder = rand::random();
100
101            encoding.clear();
102            for range in idx.iter_ranges() {
103                provider
104                    .provide_encoding(direction, range, &mut encoding)
105                    .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
106            }
107            encoding.extend_from_slice(blinder.as_bytes());
108
109            let leaf = hasher.hash(&encoding);
110
111            leaves.push(leaf);
112            this.blinders.push(blinder);
113            this.idxs.insert(this.idxs.len(), dir_idx.clone());
114            match direction {
115                Direction::Sent => this.sent_idx.union_mut(idx),
116                Direction::Received => this.received_idx.union_mut(idx),
117            }
118        }
119
120        this.tree.insert(hasher, leaves);
121
122        Ok(this)
123    }
124
125    /// Returns the root of the tree.
126    pub fn root(&self) -> TypedHash {
127        self.tree.root()
128    }
129
130    /// Returns the hash algorithm of the tree.
131    pub fn algorithm(&self) -> HashAlgId {
132        self.tree.algorithm()
133    }
134
135    /// Generates a proof for the given indices.
136    ///
137    /// # Arguments
138    ///
139    /// * `idxs` - The transcript indices to prove.
140    pub fn proof<'idx>(
141        &self,
142        idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
143    ) -> Result<EncodingProof, EncodingTreeError> {
144        let mut openings = HashMap::new();
145        for dir_idx in idxs {
146            let direction = dir_idx.0;
147            let idx = &dir_idx.1;
148
149            let leaf_idx = *self
150                .idxs
151                .get_by_right(dir_idx)
152                .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
153            let blinder = self.blinders[leaf_idx].clone();
154
155            openings.insert(
156                leaf_idx,
157                Opening {
158                    direction,
159                    idx: idx.clone(),
160                    blinder,
161                },
162            );
163        }
164
165        let mut indices = openings.keys().copied().collect::<Vec<_>>();
166        indices.sort();
167
168        Ok(EncodingProof {
169            inclusion_proof: self.tree.proof(&indices),
170            openings,
171        })
172    }
173
174    /// Returns whether the tree contains the given transcript index.
175    pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
176        self.idxs.contains_right(idx)
177    }
178
179    pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
180        match direction {
181            Direction::Sent => &self.sent_idx,
182            Direction::Received => &self.received_idx,
183        }
184    }
185
186    /// Returns the committed transcript indices.
187    pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
188        self.idxs.right_values()
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fixtures::{encoder_secret, encoding_provider},
        hash::{Blake3, HashProvider},
        transcript::{encoding::EncodingCommitment, Transcript},
    };
    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    /// Builds a Blake3 encoding tree over `idxs`, using the fixture encoding
    /// provider for `transcript`.
    fn new_tree<'seq>(
        transcript: &Transcript,
        idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
    ) -> Result<EncodingTree, EncodingTreeError> {
        let hasher = Blake3::default();
        let provider = encoding_provider(transcript.sent(), transcript.received());

        EncodingTree::new(&hasher, idxs, &provider)
    }

    /// Committing to and opening the whole transcript authenticates all of it.
    #[test]
    fn test_encoding_tree() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let sent_idx = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let recv_idx = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
        let committed = [&sent_idx, &recv_idx];

        let tree = new_tree(&transcript, committed.iter().copied()).unwrap();

        for idx in committed.iter().copied() {
            assert!(tree.contains(idx));
        }

        let proof = tree.proof(committed.iter().copied()).unwrap();

        let commitment = EncodingCommitment { root: tree.root() };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &encoder_secret(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        assert_eq!(auth_sent, sent_idx.1);
        assert_eq!(auth_recv, recv_idx.1);
    }

    /// Several leaves per direction are authenticated as the union of their
    /// ranges.
    #[test]
    fn test_encoding_tree_multiple_ranges() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let sent = [
            (Direction::Sent, RangeSet::from(0..1)),
            (Direction::Sent, RangeSet::from(1..POST_JSON.len())),
        ];
        let recv = [
            (Direction::Received, RangeSet::from(0..1)),
            (Direction::Received, RangeSet::from(1..OK_JSON.len())),
        ];

        let tree = new_tree(&transcript, sent.iter().chain(recv.iter())).unwrap();

        for idx in sent.iter().chain(recv.iter()) {
            assert!(tree.contains(idx));
        }

        let proof = tree.proof(sent.iter().chain(recv.iter())).unwrap();

        let commitment = EncodingCommitment { root: tree.root() };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &encoder_secret(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        let mut expected_auth_sent = RangeSet::default();
        for (_, ranges) in &sent {
            expected_auth_sent.union_mut(ranges);
        }

        let mut expected_auth_recv = RangeSet::default();
        for (_, ranges) in &recv {
            expected_auth_recv.union_mut(ranges);
        }

        assert_eq!(auth_sent, expected_auth_sent);
        assert_eq!(auth_recv, expected_auth_recv);
    }

    /// Asking for an index that was never committed is rejected.
    #[test]
    fn test_encoding_tree_proof_missing_leaf() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let sent_all = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let recv_head = (Direction::Received, RangeSet::from(0..4));
        let recv_tail = (Direction::Received, RangeSet::from(4..OK_JSON.len()));

        // `recv_tail` is deliberately left out of the tree.
        let tree = new_tree(&transcript, [&sent_all, &recv_head].into_iter()).unwrap();

        let err = tree
            .proof([&sent_all, &recv_head, &recv_tail].into_iter())
            .unwrap_err();
        assert!(matches!(err, EncodingTreeError::MissingLeaf { .. }));
    }

    /// Indices running one byte past the transcript cannot be encoded.
    #[test]
    fn test_encoding_tree_out_of_bounds() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let past_end = [
            (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1)),
            (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1)),
        ];

        for idx in &past_end {
            let err = new_tree(&transcript, [idx].into_iter()).unwrap_err();
            assert!(matches!(err, EncodingTreeError::MissingEncoding { .. }));
        }
    }

    /// A provider built over empty data has no encoding to offer.
    #[test]
    fn test_encoding_tree_missing_encoding() {
        let provider = encoding_provider(&[], &[]);
        let idxs = [(Direction::Sent, RangeSet::from(0..8))];

        let err = EncodingTree::new(&Blake3::default(), idxs.iter(), &provider).unwrap_err();
        assert!(matches!(err, EncodingTreeError::MissingEncoding { .. }));
    }
}
327}