1use std::{collections::HashSet, fmt};
4
5use rangeset::{ToRangeSet, UnionMut};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 hash::HashAlgId,
10 transcript::{
11 encoding::{EncodingCommitment, EncodingTree},
12 hash::{PlaintextHash, PlaintextHashSecret},
13 Direction, RangeSet, Transcript,
14 },
15};
16
/// Crate-wide cap on the total amount of committed transcript data (one billion units).
///
/// NOTE(review): not referenced in this file; presumably measured in bytes and enforced
/// elsewhere in the crate — confirm against its users.
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000 * 1_000 * 1_000;
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[non_exhaustive]
29pub enum TranscriptCommitmentKind {
30 Encoding,
32 Hash {
34 alg: HashAlgId,
36 },
37}
38
39impl fmt::Display for TranscriptCommitmentKind {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Encoding => f.write_str("encoding"),
43 Self::Hash { alg } => write!(f, "hash ({alg})"),
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50#[non_exhaustive]
51pub enum TranscriptCommitment {
52 Encoding(EncodingCommitment),
54 Hash(PlaintextHash),
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60#[non_exhaustive]
61pub enum TranscriptSecret {
62 Encoding(EncodingTree),
64 Hash(PlaintextHashSecret),
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct TranscriptCommitConfig {
71 encoding_hash_alg: HashAlgId,
72 has_encoding: bool,
73 has_hash: bool,
74 commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
75}
76
77impl TranscriptCommitConfig {
78 pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder<'_> {
80 TranscriptCommitConfigBuilder::new(transcript)
81 }
82
83 pub fn encoding_hash_alg(&self) -> &HashAlgId {
85 &self.encoding_hash_alg
86 }
87
88 pub fn has_encoding(&self) -> bool {
90 self.has_encoding
91 }
92
93 pub fn has_hash(&self) -> bool {
95 self.has_hash
96 }
97
98 pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
100 self.commits.iter().filter_map(|(idx, kind)| match kind {
101 TranscriptCommitmentKind::Encoding => Some(idx),
102 _ => None,
103 })
104 }
105
106 pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, RangeSet<usize>), &HashAlgId)> {
108 self.commits.iter().filter_map(|(idx, kind)| match kind {
109 TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
110 _ => None,
111 })
112 }
113
114 pub fn to_request(&self) -> TranscriptCommitRequest {
116 TranscriptCommitRequest {
117 encoding: self.has_encoding.then(|| {
118 let mut sent = RangeSet::default();
119 let mut recv = RangeSet::default();
120
121 for (dir, idx) in self.iter_encoding() {
122 match dir {
123 Direction::Sent => sent.union_mut(idx),
124 Direction::Received => recv.union_mut(idx),
125 }
126 }
127
128 (sent, recv)
129 }),
130 hash: self
131 .iter_hash()
132 .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
133 .collect(),
134 }
135 }
136}
137
138#[derive(Debug)]
143pub struct TranscriptCommitConfigBuilder<'a> {
144 transcript: &'a Transcript,
145 encoding_hash_alg: HashAlgId,
146 has_encoding: bool,
147 has_hash: bool,
148 default_kind: TranscriptCommitmentKind,
149 commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
150}
151
152impl<'a> TranscriptCommitConfigBuilder<'a> {
153 pub fn new(transcript: &'a Transcript) -> Self {
155 Self {
156 transcript,
157 encoding_hash_alg: HashAlgId::BLAKE3,
158 has_encoding: false,
159 has_hash: false,
160 default_kind: TranscriptCommitmentKind::Encoding,
161 commits: HashSet::default(),
162 }
163 }
164
165 pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
167 self.encoding_hash_alg = alg;
168 self
169 }
170
171 pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
173 self.default_kind = default_kind;
174 self
175 }
176
177 pub fn commit_with_kind(
185 &mut self,
186 ranges: &dyn ToRangeSet<usize>,
187 direction: Direction,
188 kind: TranscriptCommitmentKind,
189 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
190 let idx = ranges.to_range_set();
191
192 if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
193 return Err(TranscriptCommitConfigBuilderError::new(
194 ErrorKind::Index,
195 format!(
196 "range is out of bounds of the transcript ({}): {} > {}",
197 direction,
198 idx.end().unwrap_or(0),
199 self.transcript.len_of_direction(direction)
200 ),
201 ));
202 }
203
204 match kind {
205 TranscriptCommitmentKind::Encoding => self.has_encoding = true,
206 TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
207 }
208
209 self.commits.insert(((direction, idx), kind));
210
211 Ok(self)
212 }
213
214 pub fn commit(
221 &mut self,
222 ranges: &dyn ToRangeSet<usize>,
223 direction: Direction,
224 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
225 self.commit_with_kind(ranges, direction, self.default_kind)
226 }
227
228 pub fn commit_sent(
234 &mut self,
235 ranges: &dyn ToRangeSet<usize>,
236 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
237 self.commit(ranges, Direction::Sent)
238 }
239
240 pub fn commit_recv(
246 &mut self,
247 ranges: &dyn ToRangeSet<usize>,
248 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
249 self.commit(ranges, Direction::Received)
250 }
251
252 pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
254 Ok(TranscriptCommitConfig {
255 encoding_hash_alg: self.encoding_hash_alg,
256 has_encoding: self.has_encoding,
257 has_hash: self.has_hash,
258 commits: Vec::from_iter(self.commits),
259 })
260 }
261}
262
263#[derive(Debug, thiserror::Error)]
265pub struct TranscriptCommitConfigBuilderError {
266 kind: ErrorKind,
267 source: Option<Box<dyn std::error::Error + Send + Sync>>,
268}
269
270impl TranscriptCommitConfigBuilderError {
271 fn new<E>(kind: ErrorKind, source: E) -> Self
272 where
273 E: Into<Box<dyn std::error::Error + Send + Sync>>,
274 {
275 Self {
276 kind,
277 source: Some(source.into()),
278 }
279 }
280}
281
282#[derive(Debug)]
283enum ErrorKind {
284 Index,
285}
286
287impl fmt::Display for TranscriptCommitConfigBuilderError {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 match self.kind {
290 ErrorKind::Index => f.write_str("index error")?,
291 }
292
293 if let Some(source) = &self.source {
294 write!(f, " caused by: {source}")?;
295 }
296
297 Ok(())
298 }
299}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct TranscriptCommitRequest {
304 encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
305 hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
306}
307
308impl TranscriptCommitRequest {
309 pub fn has_encoding(&self) -> bool {
311 self.encoding.is_some()
312 }
313
314 pub fn has_hash(&self) -> bool {
316 !self.hash.is_empty()
317 }
318
319 pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
321 self.hash.iter()
322 }
323
324 pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
326 self.encoding.as_ref()
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// 12-byte transcript in each direction, shared by the tests below.
    fn transcript() -> Transcript {
        Transcript::new(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        )
    }

    #[test]
    fn test_range_out_of_bounds() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        assert!(builder.commit_sent(&(10..15)).is_err());
        assert!(builder.commit_recv(&(10..15)).is_err());
    }

    #[test]
    fn test_range_ending_at_transcript_end_is_accepted() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        assert!(builder.commit_sent(&(0..12)).is_ok());
        assert!(builder.commit_recv(&(6..12)).is_ok());
    }

    #[test]
    fn test_duplicate_commits_are_deduplicated() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        builder.commit_sent(&(0..4)).unwrap();
        builder.commit_sent(&(0..4)).unwrap();

        let config = builder.build().unwrap();
        assert_eq!(config.iter_encoding().count(), 1);
    }

    #[test]
    fn test_request_flags_and_encoding_union() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        builder.commit_sent(&(0..4)).unwrap();
        builder.commit_sent(&(4..8)).unwrap();
        builder
            .commit_with_kind(
                &(0..2),
                Direction::Received,
                TranscriptCommitmentKind::Hash {
                    alg: HashAlgId::BLAKE3,
                },
            )
            .unwrap();

        let config = builder.build().unwrap();
        assert!(config.has_encoding());
        assert!(config.has_hash());

        let request = config.to_request();
        assert!(request.has_encoding());
        assert!(request.has_hash());
        assert_eq!(request.iter_hash().count(), 1);

        let (sent, _recv) = request.encoding().unwrap();
        assert_eq!(sent.end(), Some(8));
    }
}