1use std::{collections::HashSet, fmt};
4
5use rangeset::ToRangeSet;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 hash::{impl_domain_separator, HashAlgId},
10 transcript::{
11 encoding::{EncodingCommitment, EncodingTree},
12 hash::{PlaintextHash, PlaintextHashSecret},
13 Direction, Idx, Transcript,
14 },
15};
16
/// Upper bound on the total amount of committed transcript data (10^9 units;
/// presumably bytes — NOTE(review): confirm against the call sites).
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000 * 1_000 * 1_000;
25
/// Kind of transcript commitment.
// NOTE: variant and field order are part of the serde-derived representation;
// do not reorder without considering compatibility of serialized data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitmentKind {
    /// Encoding commitment.
    Encoding,
    /// Hash commitment.
    Hash {
        /// Hash algorithm used for the commitment.
        alg: HashAlgId,
    },
}
38
39impl fmt::Display for TranscriptCommitmentKind {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Encoding => f.write_str("encoding"),
43 Self::Hash { alg } => write!(f, "hash ({alg})"),
44 }
45 }
46}
47
/// Transcript commitment.
// NOTE: variant order is part of the serde-derived representation; do not
// reorder without considering compatibility of serialized data.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitment {
    /// Encoding commitment.
    Encoding(EncodingCommitment),
    /// Plaintext hash commitment.
    Hash(PlaintextHash),
}

// Presumably implements the crate's domain-separation support for hashing this
// type; see `crate::hash::impl_domain_separator`.
impl_domain_separator!(TranscriptCommitment);
59
/// Secret data corresponding to a [`TranscriptCommitment`].
// NOTE: variant order is part of the serde-derived representation; do not
// reorder without considering compatibility of serialized data.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptSecret {
    /// Encoding tree backing an encoding commitment.
    Encoding(EncodingTree),
    /// Secret data for a plaintext hash commitment.
    Hash(PlaintextHashSecret),
}

// Presumably implements the crate's domain-separation support for hashing this
// type; see `crate::hash::impl_domain_separator`.
impl_domain_separator!(TranscriptSecret);
71
/// Configuration of the transcript commitments to make.
///
/// Constructed with [`TranscriptCommitConfigBuilder`].
#[derive(Debug, Clone)]
pub struct TranscriptCommitConfig {
    /// Hash algorithm used for encoding commitments.
    encoding_hash_alg: HashAlgId,
    /// Whether at least one encoding commitment was added.
    has_encoding: bool,
    /// Whether at least one hash commitment was added.
    has_hash: bool,
    /// Committed ranges (per direction) together with their commitment kind.
    commits: Vec<((Direction, Idx), TranscriptCommitmentKind)>,
}
80
81impl TranscriptCommitConfig {
82 pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder {
84 TranscriptCommitConfigBuilder::new(transcript)
85 }
86
87 pub fn encoding_hash_alg(&self) -> &HashAlgId {
89 &self.encoding_hash_alg
90 }
91
92 pub fn has_encoding(&self) -> bool {
94 self.has_encoding
95 }
96
97 pub fn has_hash(&self) -> bool {
99 self.has_hash
100 }
101
102 pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, Idx)> {
104 self.commits.iter().filter_map(|(idx, kind)| match kind {
105 TranscriptCommitmentKind::Encoding => Some(idx),
106 _ => None,
107 })
108 }
109
110 pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, Idx), &HashAlgId)> {
112 self.commits.iter().filter_map(|(idx, kind)| match kind {
113 TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
114 _ => None,
115 })
116 }
117
118 pub fn to_request(&self) -> TranscriptCommitRequest {
120 TranscriptCommitRequest {
121 encoding: self.has_encoding,
122 hash: self
123 .iter_hash()
124 .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
125 .collect(),
126 }
127 }
128}
129
130#[derive(Debug)]
135pub struct TranscriptCommitConfigBuilder<'a> {
136 transcript: &'a Transcript,
137 encoding_hash_alg: HashAlgId,
138 has_encoding: bool,
139 has_hash: bool,
140 default_kind: TranscriptCommitmentKind,
141 commits: HashSet<((Direction, Idx), TranscriptCommitmentKind)>,
142}
143
144impl<'a> TranscriptCommitConfigBuilder<'a> {
145 pub fn new(transcript: &'a Transcript) -> Self {
147 Self {
148 transcript,
149 encoding_hash_alg: HashAlgId::BLAKE3,
150 has_encoding: false,
151 has_hash: false,
152 default_kind: TranscriptCommitmentKind::Encoding,
153 commits: HashSet::default(),
154 }
155 }
156
157 pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
159 self.encoding_hash_alg = alg;
160 self
161 }
162
163 pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
165 self.default_kind = default_kind;
166 self
167 }
168
169 pub fn commit_with_kind(
177 &mut self,
178 ranges: &dyn ToRangeSet<usize>,
179 direction: Direction,
180 kind: TranscriptCommitmentKind,
181 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
182 let idx = Idx::new(ranges.to_range_set());
183
184 if idx.end() > self.transcript.len_of_direction(direction) {
185 return Err(TranscriptCommitConfigBuilderError::new(
186 ErrorKind::Index,
187 format!(
188 "range is out of bounds of the transcript ({}): {} > {}",
189 direction,
190 idx.end(),
191 self.transcript.len_of_direction(direction)
192 ),
193 ));
194 }
195
196 match kind {
197 TranscriptCommitmentKind::Encoding => self.has_encoding = true,
198 TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
199 }
200
201 self.commits.insert(((direction, idx), kind));
202
203 Ok(self)
204 }
205
206 pub fn commit(
213 &mut self,
214 ranges: &dyn ToRangeSet<usize>,
215 direction: Direction,
216 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
217 self.commit_with_kind(ranges, direction, self.default_kind)
218 }
219
220 pub fn commit_sent(
226 &mut self,
227 ranges: &dyn ToRangeSet<usize>,
228 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
229 self.commit(ranges, Direction::Sent)
230 }
231
232 pub fn commit_recv(
238 &mut self,
239 ranges: &dyn ToRangeSet<usize>,
240 ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
241 self.commit(ranges, Direction::Received)
242 }
243
244 pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
246 Ok(TranscriptCommitConfig {
247 encoding_hash_alg: self.encoding_hash_alg,
248 has_encoding: self.has_encoding,
249 has_hash: self.has_hash,
250 commits: Vec::from_iter(self.commits),
251 })
252 }
253}
254
/// Error returned by [`TranscriptCommitConfigBuilder`].
#[derive(Debug, thiserror::Error)]
pub struct TranscriptCommitConfigBuilderError {
    /// Category of the error.
    kind: ErrorKind,
    // `thiserror` uses a field named `source` as the error's `source()`.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
261
262impl TranscriptCommitConfigBuilderError {
263 fn new<E>(kind: ErrorKind, source: E) -> Self
264 where
265 E: Into<Box<dyn std::error::Error + Send + Sync>>,
266 {
267 Self {
268 kind,
269 source: Some(source.into()),
270 }
271 }
272}
273
/// Category of a [`TranscriptCommitConfigBuilderError`].
#[derive(Debug)]
enum ErrorKind {
    /// A committed range extends past the end of the transcript.
    Index,
}
278
279impl fmt::Display for TranscriptCommitConfigBuilderError {
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 match self.kind {
282 ErrorKind::Index => f.write_str("index error")?,
283 }
284
285 if let Some(source) = &self.source {
286 write!(f, " caused by: {}", source)?;
287 }
288
289 Ok(())
290 }
291}
292
/// Request for transcript commitments.
///
/// Can be created with [`TranscriptCommitConfig::to_request`].
// NOTE: field order is part of the serde-derived representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
    /// Whether encoding commitments are requested.
    encoding: bool,
    /// Ranges to hash-commit, each with its direction and hash algorithm.
    hash: Vec<(Direction, Idx, HashAlgId)>,
}
299
300impl TranscriptCommitRequest {
301 pub fn encoding(&self) -> bool {
303 self.encoding
304 }
305
306 pub fn has_hash(&self) -> bool {
308 !self.hash.is_empty()
309 }
310
311 pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, Idx, HashAlgId)> {
313 self.hash.iter()
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// A 12-byte transcript in each direction.
    fn transcript() -> Transcript {
        Transcript::new(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        )
    }

    #[test]
    fn test_range_out_of_bounds() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        assert!(builder.commit_sent(&(10..15)).is_err());
        assert!(builder.commit_recv(&(10..15)).is_err());
    }

    #[test]
    fn test_range_at_boundary() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        // A range ending exactly at the transcript length is in bounds ...
        assert!(builder.commit_sent(&(0..12)).is_ok());
        assert!(builder.commit_recv(&(0..12)).is_ok());
        // ... but one past it is not.
        assert!(builder.commit_sent(&(11..13)).is_err());
        assert!(builder.commit_recv(&(11..13)).is_err());
    }

    #[test]
    fn test_duplicate_commits_are_deduplicated() {
        let transcript = transcript();
        let mut builder = TranscriptCommitConfigBuilder::new(&transcript);

        builder.commit_sent(&(0..4)).unwrap();
        builder.commit_sent(&(0..4)).unwrap();

        let config = builder.build().unwrap();
        assert!(config.has_encoding());
        assert!(!config.has_hash());
        assert_eq!(config.iter_encoding().count(), 1);
    }
}