1use std::{collections::HashSet, fmt};
4
5use rangeset::ToRangeSet;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 hash::HashAlgId,
10 transcript::{Direction, Idx, Transcript},
11};
12
/// Crate-wide upper bound on the total amount of committed data.
///
/// NOTE(review): the unit is presumably bytes, and the limit is enforced
/// outside the portion of the file shown here — confirm against the code
/// that reads this constant.
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000_000_000;
21
/// Kind of commitment that can be made over a range of the transcript.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitmentKind {
    /// An encoding commitment.
    Encoding,
    /// A hash commitment.
    Hash {
        /// Identifier of the hash algorithm used for the commitment.
        alg: HashAlgId,
    },
}
34
35impl fmt::Display for TranscriptCommitmentKind {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 match self {
38 Self::Encoding => f.write_str("encoding"),
39 Self::Hash { alg } => write!(f, "hash ({alg})"),
40 }
41 }
42}
43
/// Configuration describing which parts of a transcript are committed to and
/// with which kind of commitment.
///
/// Instances are produced by [`TranscriptCommitConfigBuilder::build`].
#[derive(Debug, Clone)]
pub struct TranscriptCommitConfig {
    /// Hash algorithm configured for encoding commitments.
    encoding_hash_alg: HashAlgId,
    /// Requested commitments: a `(direction, index)` pair identifying the
    /// transcript data, together with the kind of commitment to make over it.
    commits: Vec<((Direction, Idx), TranscriptCommitmentKind)>,
}
50
51impl TranscriptCommitConfig {
52 pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder {
54 TranscriptCommitConfigBuilder::new(transcript)
55 }
56
57 pub fn encoding_hash_alg(&self) -> &HashAlgId {
59 &self.encoding_hash_alg
60 }
61
62 pub fn has_encoding(&self) -> bool {
64 self.commits
65 .iter()
66 .any(|(_, kind)| matches!(kind, TranscriptCommitmentKind::Encoding))
67 }
68
69 pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, Idx)> {
71 self.commits.iter().filter_map(|(idx, kind)| match kind {
72 TranscriptCommitmentKind::Encoding => Some(idx),
73 _ => None,
74 })
75 }
76
77 pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, Idx), &HashAlgId)> {
79 self.commits.iter().filter_map(|(idx, kind)| match kind {
80 TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
81 _ => None,
82 })
83 }
84}
85
/// Builder for [`TranscriptCommitConfig`].
///
/// Borrows the transcript for its whole lifetime so that every requested
/// range can be bounds-checked as it is added.
#[derive(Debug)]
pub struct TranscriptCommitConfigBuilder<'a> {
    /// Transcript the requested ranges are validated against.
    transcript: &'a Transcript,
    /// Hash algorithm to record in the built config for encoding commitments.
    encoding_hash_alg: HashAlgId,
    /// Commitment kind applied when the caller does not specify one.
    default_kind: TranscriptCommitmentKind,
    /// Commitments requested so far. Stored in a set, so requesting the same
    /// `(direction, index, kind)` combination more than once keeps one entry.
    commits: HashSet<((Direction, Idx), TranscriptCommitmentKind)>,
}
97
98impl<'a> TranscriptCommitConfigBuilder<'a> {
99 pub fn new(transcript: &'a Transcript) -> Self {
101 Self {
102 transcript,
103 encoding_hash_alg: HashAlgId::BLAKE3,
104 default_kind: TranscriptCommitmentKind::Encoding,
105 commits: HashSet::default(),
106 }
107 }
108
109 pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
111 self.encoding_hash_alg = alg;
112 self
113 }
114
115 pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
117 self.default_kind = default_kind;
118 self
119 }
120
121 pub fn commit_with_kind(
129 &mut self,
130 ranges: &dyn ToRangeSet<usize>,
131 direction: Direction,
132 kind: TranscriptCommitmentKind,
133 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
134 let idx = Idx::new(ranges.to_range_set());
135
136 if idx.end() > self.transcript.len_of_direction(direction) {
137 return Err(TranscriptCommitConfigBuilderError::new(
138 ErrorKind::Index,
139 format!(
140 "range is out of bounds of the transcript ({}): {} > {}",
141 direction,
142 idx.end(),
143 self.transcript.len_of_direction(direction)
144 ),
145 ));
146 }
147
148 self.commits.insert(((direction, idx), kind));
149
150 Ok(self)
151 }
152
153 pub fn commit(
160 &mut self,
161 ranges: &dyn ToRangeSet<usize>,
162 direction: Direction,
163 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
164 self.commit_with_kind(ranges, direction, self.default_kind)
165 }
166
167 pub fn commit_sent(
173 &mut self,
174 ranges: &dyn ToRangeSet<usize>,
175 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
176 self.commit(ranges, Direction::Sent)
177 }
178
179 pub fn commit_recv(
185 &mut self,
186 ranges: &dyn ToRangeSet<usize>,
187 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
188 self.commit(ranges, Direction::Received)
189 }
190
191 pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
193 Ok(TranscriptCommitConfig {
194 encoding_hash_alg: self.encoding_hash_alg,
195 commits: Vec::from_iter(self.commits),
196 })
197 }
198}
199
/// Error returned by [`TranscriptCommitConfigBuilder`].
///
/// The `thiserror` derive is used without an `#[error(...)]` attribute, so
/// the `Display` implementation for this type is written by hand.
#[derive(Debug, thiserror::Error)]
pub struct TranscriptCommitConfigBuilderError {
    /// Category of the failure.
    kind: ErrorKind,
    /// Optional underlying cause; appended to the `Display` output when set.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
206
207impl TranscriptCommitConfigBuilderError {
208 fn new<E>(kind: ErrorKind, source: E) -> Self
209 where
210 E: Into<Box<dyn std::error::Error + Send + Sync>>,
211 {
212 Self {
213 kind,
214 source: Some(source.into()),
215 }
216 }
217}
218
/// Category of a [`TranscriptCommitConfigBuilderError`].
#[derive(Debug)]
enum ErrorKind {
    /// A requested range extends past the end of the transcript.
    Index,
}
223
224impl fmt::Display for TranscriptCommitConfigBuilderError {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 match self.kind {
227 ErrorKind::Index => f.write_str("index error")?,
228 }
229
230 if let Some(source) = &self.source {
231 write!(f, " caused by: {}", source)?;
232 }
233
234 Ok(())
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Committing a range that ends past the transcript length must be
    /// rejected for both the sent and the received direction.
    #[test]
    fn test_range_out_of_bounds() {
        // Both directions hold 12 bytes, so any range ending at 15 is invalid.
        let sent = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let received = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let transcript = Transcript::new(sent, received);

        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);
        let out_of_bounds = 10..15;

        let sent_result = builder.commit_sent(&out_of_bounds);
        assert!(sent_result.is_err());

        let recv_result = builder.commit_recv(&out_of_bounds);
        assert!(recv_result.is_err());
    }
}