tlsn_core/transcript/encoding/
tree.rs

1use std::collections::HashMap;
2
3use bimap::BiMap;
4use rangeset::{RangeSet, UnionMut};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
9    merkle::MerkleTree,
10    transcript::{
11        encoding::{
12            proof::{EncodingProof, Opening},
13            EncodingProvider,
14        },
15        Direction,
16    },
17};
18
/// Encoding tree builder error.
///
/// This is the error type returned by [`EncodingTree::new`] and
/// [`EncodingTree::proof`].
#[derive(Debug, thiserror::Error)]
pub enum EncodingTreeError {
    /// Index is out of bounds of the transcript.
    ///
    /// NOTE(review): nothing in the visible part of this file constructs this
    /// variant; `EncodingTree::new` reports every provider failure as
    /// [`EncodingTreeError::MissingEncoding`] instead. Confirm whether it is
    /// still produced elsewhere.
    #[error("index is out of bounds of the transcript")]
    OutOfBounds {
        /// The index.
        index: RangeSet<usize>,
        /// The transcript length.
        transcript_length: usize,
    },
    /// Encoding provider is missing an encoding for an index.
    ///
    /// Returned by [`EncodingTree::new`] when the provider fails to provide
    /// the encoding for any range of a subsequence. The provider's own error
    /// is discarded.
    #[error("encoding provider is missing an encoding for an index")]
    MissingEncoding {
        /// The index which is missing. This is the whole index of the
        /// subsequence being committed, not only the range that failed.
        index: RangeSet<usize>,
    },
    /// Index is missing from the tree.
    ///
    /// Returned by [`EncodingTree::proof`] when a requested
    /// `(direction, index)` pair has no leaf in the tree.
    #[error("index is missing from the tree")]
    MissingLeaf {
        /// The index which is missing. The direction of the request is not
        /// recorded.
        index: RangeSet<usize>,
    },
}
43
/// A merkle tree of transcript encodings.
///
/// Each leaf commits to one `(Direction, RangeSet<usize>)` subsequence of the
/// transcript: it is the hash of that subsequence's encoding followed by the
/// bytes of a per-leaf blinder.
#[derive(Clone, Serialize, Deserialize)]
pub struct EncodingTree {
    /// Merkle tree of the commitments.
    tree: MerkleTree,
    /// Nonces used to blind the hashes.
    ///
    /// Stored in leaf order: `blinders[i]` is the blinder hashed into leaf
    /// `i`.
    blinders: Vec<Blinder>,
    /// Mapping between the index of a leaf and the transcript index it
    /// corresponds to.
    ///
    /// The left side is the leaf index, the right side is the committed
    /// `(direction, ranges)` pair; lookups by the right side decide whether
    /// (and where) a subsequence is committed.
    idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
    /// Union of all transcript indices in the sent direction.
    sent_idx: RangeSet<usize>,
    /// Union of all transcript indices in the received direction.
    received_idx: RangeSet<usize>,
}
59
// Opaque `Debug` implementation: keeps the tree's contents (including the
// blinders) out of debug output.
opaque_debug::implement!(EncodingTree);
61
62impl EncodingTree {
63    /// Creates a new encoding tree.
64    ///
65    /// # Arguments
66    ///
67    /// * `hasher` - The hash algorithm to use.
68    /// * `idxs` - The subsequence indices to commit to.
69    /// * `provider` - The encoding provider.
70    pub fn new<'idx>(
71        hasher: &dyn HashAlgorithm,
72        idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
73        provider: &dyn EncodingProvider,
74    ) -> Result<Self, EncodingTreeError> {
75        let mut this = Self {
76            tree: MerkleTree::new(hasher.id()),
77            blinders: Vec::new(),
78            idxs: BiMap::new(),
79            sent_idx: RangeSet::default(),
80            received_idx: RangeSet::default(),
81        };
82
83        let mut leaves = Vec::new();
84        let mut encoding = Vec::new();
85        for dir_idx in idxs {
86            let direction = dir_idx.0;
87            let idx = &dir_idx.1;
88
89            // Ignore empty indices.
90            if idx.is_empty() {
91                continue;
92            }
93
94            if this.idxs.contains_right(dir_idx) {
95                // The subsequence is already in the tree.
96                continue;
97            }
98
99            let blinder: Blinder = rand::random();
100
101            encoding.clear();
102            for range in idx.iter_ranges() {
103                provider
104                    .provide_encoding(direction, range, &mut encoding)
105                    .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
106            }
107            encoding.extend_from_slice(blinder.as_bytes());
108
109            let leaf = hasher.hash(&encoding);
110
111            leaves.push(leaf);
112            this.blinders.push(blinder);
113            this.idxs.insert(this.idxs.len(), dir_idx.clone());
114            match direction {
115                Direction::Sent => this.sent_idx.union_mut(idx),
116                Direction::Received => this.received_idx.union_mut(idx),
117            }
118        }
119
120        this.tree.insert(hasher, leaves);
121
122        Ok(this)
123    }
124
125    /// Returns the root of the tree.
126    pub fn root(&self) -> TypedHash {
127        self.tree.root()
128    }
129
130    /// Returns the hash algorithm of the tree.
131    pub fn algorithm(&self) -> HashAlgId {
132        self.tree.algorithm()
133    }
134
135    /// Generates a proof for the given indices.
136    ///
137    /// # Arguments
138    ///
139    /// * `idxs` - The transcript indices to prove.
140    pub fn proof<'idx>(
141        &self,
142        idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
143    ) -> Result<EncodingProof, EncodingTreeError> {
144        let mut openings = HashMap::new();
145        for dir_idx in idxs {
146            let direction = dir_idx.0;
147            let idx = &dir_idx.1;
148
149            let leaf_idx = *self
150                .idxs
151                .get_by_right(dir_idx)
152                .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
153            let blinder = self.blinders[leaf_idx].clone();
154
155            openings.insert(
156                leaf_idx,
157                Opening {
158                    direction,
159                    idx: idx.clone(),
160                    blinder,
161                },
162            );
163        }
164
165        let mut indices = openings.keys().copied().collect::<Vec<_>>();
166        indices.sort();
167
168        Ok(EncodingProof {
169            inclusion_proof: self.tree.proof(&indices),
170            openings,
171        })
172    }
173
174    /// Returns whether the tree contains the given transcript index.
175    pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
176        self.idxs.contains_right(idx)
177    }
178
179    pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
180        match direction {
181            Direction::Sent => &self.sent_idx,
182            Direction::Received => &self.received_idx,
183        }
184    }
185
186    /// Returns the committed transcript indices.
187    pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
188        self.idxs.right_values()
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fixtures::{encoder_secret, encoding_provider},
        hash::{Blake3, HashProvider},
        transcript::{encoding::EncodingCommitment, Transcript},
    };
    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    /// Builds a Blake3 encoding tree over `idxs`, with encodings supplied by
    /// the fixture provider for `transcript`.
    fn new_tree<'seq>(
        transcript: &Transcript,
        idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
    ) -> Result<EncodingTree, EncodingTreeError> {
        let hasher = Blake3::default();
        let provider = encoding_provider(transcript.sent(), transcript.received());

        EncodingTree::new(&hasher, idxs, &provider)
    }

    /// Commits to the whole transcript with one leaf per direction and checks
    /// that the resulting proof authenticates exactly those ranges.
    #[test]
    fn test_encoding_tree() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let sent = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let recv = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
        // An array of references is `Copy`, so it can be iterated repeatedly.
        let committed = [&sent, &recv];

        let tree = new_tree(&transcript, committed.into_iter()).unwrap();

        for idx in committed {
            assert!(tree.contains(idx));
        }

        let proof = tree.proof(committed.into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        assert_eq!(auth_sent, sent.1);
        assert_eq!(auth_recv, recv.1);
    }

    /// Commits to each direction as two separate leaves and checks that the
    /// proof authenticates the union of the ranges per direction.
    #[test]
    fn test_encoding_tree_multiple_ranges() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let sent_head = (Direction::Sent, RangeSet::from(0..1));
        let sent_tail = (Direction::Sent, RangeSet::from(1..POST_JSON.len()));
        let recv_head = (Direction::Received, RangeSet::from(0..1));
        let recv_tail = (Direction::Received, RangeSet::from(1..OK_JSON.len()));
        let committed = [&sent_head, &sent_tail, &recv_head, &recv_tail];

        let tree = new_tree(&transcript, committed.into_iter()).unwrap();

        for idx in committed {
            assert!(tree.contains(idx));
        }

        let proof = tree.proof(committed.into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        let mut expected_auth_sent = RangeSet::default();
        for idx in [&sent_head, &sent_tail] {
            expected_auth_sent.union_mut(&idx.1);
        }

        let mut expected_auth_recv = RangeSet::default();
        for idx in [&recv_head, &recv_tail] {
            expected_auth_recv.union_mut(&idx.1);
        }

        assert_eq!(auth_sent, expected_auth_sent);
        assert_eq!(auth_recv, expected_auth_recv);
    }

    /// Asking for a proof of a subsequence that was never committed fails
    /// with `MissingLeaf`.
    #[test]
    fn test_encoding_tree_proof_missing_leaf() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let sent = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let recv_committed = (Direction::Received, RangeSet::from(0..4));
        let recv_uncommitted = (Direction::Received, RangeSet::from(4..OK_JSON.len()));

        let tree = new_tree(&transcript, [&sent, &recv_committed].into_iter()).unwrap();

        let requested = [&sent, &recv_committed, &recv_uncommitted];
        let err = tree.proof(requested.into_iter()).unwrap_err();
        assert!(matches!(err, EncodingTreeError::MissingLeaf { .. }));
    }

    /// An index reaching one byte past the end of the transcript is reported
    /// as a missing encoding, in either direction.
    #[test]
    fn test_encoding_tree_out_of_bounds() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let oversized = [
            (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1)),
            (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1)),
        ];

        for idx in &oversized {
            let err = new_tree(&transcript, [idx].into_iter()).unwrap_err();
            assert!(matches!(err, EncodingTreeError::MissingEncoding { .. }));
        }
    }

    /// A provider built over empty transcripts cannot encode any range.
    #[test]
    fn test_encoding_tree_missing_encoding() {
        let provider = encoding_provider(&[], &[]);
        let hasher = Blake3::default();
        let uncovered = [(Direction::Sent, RangeSet::from(0..8))];

        let err = EncodingTree::new(&hasher, uncovered.iter(), &provider).unwrap_err();
        assert!(matches!(err, EncodingTreeError::MissingEncoding { .. }));
    }
}