tlsn_core/transcript/encoding/
tree.rs1use std::collections::HashMap;
2
3use bimap::BiMap;
4use rangeset::{RangeSet, UnionMut};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
9 merkle::MerkleTree,
10 transcript::{
11 encoding::{
12 proof::{EncodingProof, Opening},
13 EncodingProvider,
14 },
15 Direction,
16 },
17};
18
19#[derive(Debug, thiserror::Error)]
21pub enum EncodingTreeError {
22 #[error("index is out of bounds of the transcript")]
24 OutOfBounds {
25 index: RangeSet<usize>,
27 transcript_length: usize,
29 },
30 #[error("encoding provider is missing an encoding for an index")]
32 MissingEncoding {
33 index: RangeSet<usize>,
35 },
36 #[error("index is missing from the tree")]
38 MissingLeaf {
39 index: RangeSet<usize>,
41 },
42}
43
44#[derive(Clone, Serialize, Deserialize)]
46pub struct EncodingTree {
47 tree: MerkleTree,
49 blinders: Vec<Blinder>,
51 idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
54 sent_idx: RangeSet<usize>,
56 received_idx: RangeSet<usize>,
58}
59
60opaque_debug::implement!(EncodingTree);
61
62impl EncodingTree {
63 pub fn new<'idx>(
71 hasher: &dyn HashAlgorithm,
72 idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
73 provider: &dyn EncodingProvider,
74 ) -> Result<Self, EncodingTreeError> {
75 let mut this = Self {
76 tree: MerkleTree::new(hasher.id()),
77 blinders: Vec::new(),
78 idxs: BiMap::new(),
79 sent_idx: RangeSet::default(),
80 received_idx: RangeSet::default(),
81 };
82
83 let mut leaves = Vec::new();
84 let mut encoding = Vec::new();
85 for dir_idx in idxs {
86 let direction = dir_idx.0;
87 let idx = &dir_idx.1;
88
89 if idx.is_empty() {
91 continue;
92 }
93
94 if this.idxs.contains_right(dir_idx) {
95 continue;
97 }
98
99 let blinder: Blinder = rand::random();
100
101 encoding.clear();
102 for range in idx.iter_ranges() {
103 provider
104 .provide_encoding(direction, range, &mut encoding)
105 .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
106 }
107 encoding.extend_from_slice(blinder.as_bytes());
108
109 let leaf = hasher.hash(&encoding);
110
111 leaves.push(leaf);
112 this.blinders.push(blinder);
113 this.idxs.insert(this.idxs.len(), dir_idx.clone());
114 match direction {
115 Direction::Sent => this.sent_idx.union_mut(idx),
116 Direction::Received => this.received_idx.union_mut(idx),
117 }
118 }
119
120 this.tree.insert(hasher, leaves);
121
122 Ok(this)
123 }
124
125 pub fn root(&self) -> TypedHash {
127 self.tree.root()
128 }
129
130 pub fn algorithm(&self) -> HashAlgId {
132 self.tree.algorithm()
133 }
134
135 pub fn proof<'idx>(
141 &self,
142 idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
143 ) -> Result<EncodingProof, EncodingTreeError> {
144 let mut openings = HashMap::new();
145 for dir_idx in idxs {
146 let direction = dir_idx.0;
147 let idx = &dir_idx.1;
148
149 let leaf_idx = *self
150 .idxs
151 .get_by_right(dir_idx)
152 .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
153 let blinder = self.blinders[leaf_idx].clone();
154
155 openings.insert(
156 leaf_idx,
157 Opening {
158 direction,
159 idx: idx.clone(),
160 blinder,
161 },
162 );
163 }
164
165 let mut indices = openings.keys().copied().collect::<Vec<_>>();
166 indices.sort();
167
168 Ok(EncodingProof {
169 inclusion_proof: self.tree.proof(&indices),
170 openings,
171 })
172 }
173
174 pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
176 self.idxs.contains_right(idx)
177 }
178
179 pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
180 match direction {
181 Direction::Sent => &self.sent_idx,
182 Direction::Received => &self.received_idx,
183 }
184 }
185
186 pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
188 self.idxs.right_values()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fixtures::{encoder_secret, encoding_provider},
        hash::{Blake3, HashProvider},
        transcript::{encoding::EncodingCommitment, Transcript},
    };
    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    /// Builds a Blake3 encoding tree over `transcript` for the given indices.
    fn new_tree<'seq>(
        transcript: &Transcript,
        idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
    ) -> Result<EncodingTree, EncodingTreeError> {
        let provider = encoding_provider(transcript.sent(), transcript.received());

        EncodingTree::new(&Blake3::default(), idxs, &provider)
    }

    #[test]
    fn test_encoding_tree() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));

        let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        assert_eq!(auth_sent, idx_0.1);
        assert_eq!(auth_recv, idx_1.1);
    }

    #[test]
    fn test_encoding_tree_multiple_ranges() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..1));
        let idx_1 = (Direction::Sent, RangeSet::from(1..POST_JSON.len()));
        let idx_2 = (Direction::Received, RangeSet::from(0..1));
        let idx_3 = (Direction::Received, RangeSet::from(1..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));
        assert!(tree.contains(&idx_2));
        assert!(tree.contains(&idx_3));

        let proof = tree
            .proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
            .unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        let mut expected_auth_sent = RangeSet::default();
        expected_auth_sent.union_mut(&idx_0.1);
        expected_auth_sent.union_mut(&idx_1.1);

        let mut expected_auth_recv = RangeSet::default();
        expected_auth_recv.union_mut(&idx_2.1);
        expected_auth_recv.union_mut(&idx_3.1);

        assert_eq!(auth_sent, expected_auth_sent);
        assert_eq!(auth_recv, expected_auth_recv);
    }

    #[test]
    fn test_encoding_tree_skips_empty_and_duplicate_indices() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
        let empty = (Direction::Received, RangeSet::default());

        // Empty sets and repeated pairs must not produce extra leaves.
        let tree = new_tree(
            &transcript,
            [&idx_0, &empty, &idx_1, &idx_0, &idx_1].into_iter(),
        )
        .unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));
        assert!(!tree.contains(&empty));
        assert_eq!(tree.transcript_indices().count(), 2);
        assert_eq!(tree.idx(Direction::Sent), &idx_0.1);
        assert_eq!(tree.idx(Direction::Received), &idx_1.1);
    }

    #[test]
    fn test_encoding_tree_proof_missing_leaf() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, RangeSet::from(0..4));
        let idx_2 = (Direction::Received, RangeSet::from(4..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        let result = tree
            .proof([&idx_0, &idx_1, &idx_2].into_iter())
            .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
    }

    #[test]
    fn test_encoding_tree_out_of_bounds() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1));
        let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1));

        // `EncodingTree::new` has no transcript length to check against. An
        // out-of-bounds range therefore surfaces as the provider failing to supply an
        // encoding (`MissingEncoding`), not as `OutOfBounds`.
        let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));

        let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }

    #[test]
    fn test_encoding_tree_missing_encoding() {
        let provider = encoding_provider(&[], &[]);

        let result = EncodingTree::new(
            &Blake3::default(),
            [(Direction::Sent, RangeSet::from(0..8))].iter(),
            &provider,
        )
        .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }
}