1use std::{collections::HashSet, fmt};
4
5use rangeset::ToRangeSet;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 hash::HashAlgId,
10 transcript::{
11 encoding::{EncodingCommitment, EncodingTree},
12 hash::{PlaintextHash, PlaintextHashSecret},
13 Direction, Idx, Transcript,
14 },
15};
16
/// Upper limit on the total amount of committed data (10^9).
///
/// NOTE(review): presumably measured in bytes and enforced elsewhere in the
/// crate; nothing in this section reads it — confirm against its users.
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000 * 1_000 * 1_000;
25
/// Kind of a transcript commitment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitmentKind {
    /// Encoding commitment.
    Encoding,
    /// Hash commitment.
    Hash {
        /// Hash algorithm used for the commitment.
        alg: HashAlgId,
    },
}
38
39impl fmt::Display for TranscriptCommitmentKind {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Encoding => f.write_str("encoding"),
43 Self::Hash { alg } => write!(f, "hash ({alg})"),
44 }
45 }
46}
47
/// A commitment to (part of) a transcript.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitment {
    /// Encoding commitment.
    Encoding(EncodingCommitment),
    /// Plaintext hash commitment.
    Hash(PlaintextHash),
}
57
/// Secret data for a transcript commitment, with one variant per commitment
/// kind.
///
/// NOTE(review): presumably retained by the committing party to later open
/// the commitment — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptSecret {
    /// Encoding tree backing an encoding commitment.
    Encoding(EncodingTree),
    /// Secret backing a plaintext hash commitment.
    Hash(PlaintextHashSecret),
}
67
/// Configuration of the commitments to make to a transcript.
///
/// Built with [`TranscriptCommitConfigBuilder`].
#[derive(Debug, Clone)]
pub struct TranscriptCommitConfig {
    /// Hash algorithm associated with encoding commitments.
    encoding_hash_alg: HashAlgId,
    /// Whether at least one encoding commitment is configured.
    has_encoding: bool,
    /// Whether at least one hash commitment is configured.
    has_hash: bool,
    /// Each configured commitment: its `(direction, index)` and its kind.
    commits: Vec<((Direction, Idx), TranscriptCommitmentKind)>,
}
76
77impl TranscriptCommitConfig {
78 pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder {
80 TranscriptCommitConfigBuilder::new(transcript)
81 }
82
83 pub fn encoding_hash_alg(&self) -> &HashAlgId {
85 &self.encoding_hash_alg
86 }
87
88 pub fn has_encoding(&self) -> bool {
90 self.has_encoding
91 }
92
93 pub fn has_hash(&self) -> bool {
95 self.has_hash
96 }
97
98 pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, Idx)> {
100 self.commits.iter().filter_map(|(idx, kind)| match kind {
101 TranscriptCommitmentKind::Encoding => Some(idx),
102 _ => None,
103 })
104 }
105
106 pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, Idx), &HashAlgId)> {
108 self.commits.iter().filter_map(|(idx, kind)| match kind {
109 TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
110 _ => None,
111 })
112 }
113
114 pub fn to_request(&self) -> TranscriptCommitRequest {
116 TranscriptCommitRequest {
117 encoding: self.has_encoding,
118 hash: self
119 .iter_hash()
120 .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
121 .collect(),
122 }
123 }
124}
125
126#[derive(Debug)]
131pub struct TranscriptCommitConfigBuilder<'a> {
132 transcript: &'a Transcript,
133 encoding_hash_alg: HashAlgId,
134 has_encoding: bool,
135 has_hash: bool,
136 default_kind: TranscriptCommitmentKind,
137 commits: HashSet<((Direction, Idx), TranscriptCommitmentKind)>,
138}
139
140impl<'a> TranscriptCommitConfigBuilder<'a> {
141 pub fn new(transcript: &'a Transcript) -> Self {
143 Self {
144 transcript,
145 encoding_hash_alg: HashAlgId::BLAKE3,
146 has_encoding: false,
147 has_hash: false,
148 default_kind: TranscriptCommitmentKind::Encoding,
149 commits: HashSet::default(),
150 }
151 }
152
153 pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
155 self.encoding_hash_alg = alg;
156 self
157 }
158
159 pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
161 self.default_kind = default_kind;
162 self
163 }
164
165 pub fn commit_with_kind(
173 &mut self,
174 ranges: &dyn ToRangeSet<usize>,
175 direction: Direction,
176 kind: TranscriptCommitmentKind,
177 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
178 let idx = Idx::new(ranges.to_range_set());
179
180 if idx.end() > self.transcript.len_of_direction(direction) {
181 return Err(TranscriptCommitConfigBuilderError::new(
182 ErrorKind::Index,
183 format!(
184 "range is out of bounds of the transcript ({}): {} > {}",
185 direction,
186 idx.end(),
187 self.transcript.len_of_direction(direction)
188 ),
189 ));
190 }
191
192 match kind {
193 TranscriptCommitmentKind::Encoding => self.has_encoding = true,
194 TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
195 }
196
197 self.commits.insert(((direction, idx), kind));
198
199 Ok(self)
200 }
201
202 pub fn commit(
209 &mut self,
210 ranges: &dyn ToRangeSet<usize>,
211 direction: Direction,
212 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
213 self.commit_with_kind(ranges, direction, self.default_kind)
214 }
215
216 pub fn commit_sent(
222 &mut self,
223 ranges: &dyn ToRangeSet<usize>,
224 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
225 self.commit(ranges, Direction::Sent)
226 }
227
228 pub fn commit_recv(
234 &mut self,
235 ranges: &dyn ToRangeSet<usize>,
236 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
237 self.commit(ranges, Direction::Received)
238 }
239
240 pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
242 Ok(TranscriptCommitConfig {
243 encoding_hash_alg: self.encoding_hash_alg,
244 has_encoding: self.has_encoding,
245 has_hash: self.has_hash,
246 commits: Vec::from_iter(self.commits),
247 })
248 }
249}
250
/// Error returned by [`TranscriptCommitConfigBuilder`].
#[derive(Debug, thiserror::Error)]
pub struct TranscriptCommitConfigBuilderError {
    /// Category of the failure.
    kind: ErrorKind,
    /// Underlying cause, if any; it is appended to the `Display` output.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
257
258impl TranscriptCommitConfigBuilderError {
259 fn new<E>(kind: ErrorKind, source: E) -> Self
260 where
261 E: Into<Box<dyn std::error::Error + Send + Sync>>,
262 {
263 Self {
264 kind,
265 source: Some(source.into()),
266 }
267 }
268}
269
/// Category of a [`TranscriptCommitConfigBuilderError`].
#[derive(Debug)]
enum ErrorKind {
    /// A committed range lies outside the transcript.
    Index,
}
274
275impl fmt::Display for TranscriptCommitConfigBuilderError {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 match self.kind {
278 ErrorKind::Index => f.write_str("index error")?,
279 }
280
281 if let Some(source) = &self.source {
282 write!(f, " caused by: {source}")?;
283 }
284
285 Ok(())
286 }
287}
288
/// Request describing the transcript commitments of a
/// [`TranscriptCommitConfig`] (see `TranscriptCommitConfig::to_request`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
    /// Whether any encoding commitment is requested.
    encoding: bool,
    /// Requested hash commitments as `(direction, index, algorithm)`.
    hash: Vec<(Direction, Idx, HashAlgId)>,
}
295
296impl TranscriptCommitRequest {
297 pub fn encoding(&self) -> bool {
299 self.encoding
300 }
301
302 pub fn has_hash(&self) -> bool {
304 !self.hash.is_empty()
305 }
306
307 pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, Idx, HashAlgId)> {
309 self.hash.iter()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Committing to a range that runs past the end of either direction of
    /// the transcript must be rejected by the builder.
    #[test]
    fn test_range_out_of_bounds() {
        // Both directions hold 12 bytes, so a range ending at 15 overruns them.
        let transcript = Transcript::new(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        );
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        let sent_result = builder.commit_sent(&(10..15));
        assert!(sent_result.is_err());

        let recv_result = builder.commit_recv(&(10..15));
        assert!(recv_result.is_err());
    }
}
328}