1use std::error::Error;
2
3use rand::{rng, Rng};
4
5use crate::{
6 attestation::{
7 Attestation, AttestationConfig, Body, Extension, FieldId, Header, ServerCertCommitment,
8 VERSION,
9 },
10 connection::{ConnectionInfo, ServerEphemKey},
11 hash::HashAlgId,
12 request::Request,
13 serialize::CanonicalSerialize,
14 signing::SignatureAlgId,
15 transcript::TranscriptCommitment,
16 CryptoProvider,
17};
18
19#[derive(Debug)]
21pub struct Accept {}
22
23#[derive(Debug)]
24pub struct Sign {
25 signature_alg: SignatureAlgId,
26 hash_alg: HashAlgId,
27 connection_info: Option<ConnectionInfo>,
28 server_ephemeral_key: Option<ServerEphemKey>,
29 cert_commitment: ServerCertCommitment,
30 extensions: Vec<Extension>,
31 transcript_commitments: Vec<TranscriptCommitment>,
32}
33
34#[derive(Debug)]
36pub struct AttestationBuilder<'a, T = Accept> {
37 config: &'a AttestationConfig,
38 state: T,
39}
40
41impl<'a> AttestationBuilder<'a, Accept> {
42 pub fn new(config: &'a AttestationConfig) -> Self {
44 Self {
45 config,
46 state: Accept {},
47 }
48 }
49
50 pub fn accept_request(
52 self,
53 request: Request,
54 ) -> Result<AttestationBuilder<'a, Sign>, AttestationBuilderError> {
55 let config = self.config;
56
57 let Request {
58 signature_alg,
59 hash_alg,
60 server_cert_commitment: cert_commitment,
61 extensions,
62 } = request;
63
64 if !config.supported_signature_algs().contains(&signature_alg) {
65 return Err(AttestationBuilderError::new(
66 ErrorKind::Request,
67 format!("unsupported signature algorithm: {signature_alg}"),
68 ));
69 }
70
71 if !config.supported_hash_algs().contains(&hash_alg) {
72 return Err(AttestationBuilderError::new(
73 ErrorKind::Request,
74 format!("unsupported hash algorithm: {hash_alg}"),
75 ));
76 }
77
78 if let Some(validator) = config.extension_validator() {
79 validator(&extensions)
80 .map_err(|err| AttestationBuilderError::new(ErrorKind::Extension, err))?;
81 }
82
83 Ok(AttestationBuilder {
84 config: self.config,
85 state: Sign {
86 signature_alg,
87 hash_alg,
88 connection_info: None,
89 server_ephemeral_key: None,
90 cert_commitment,
91 transcript_commitments: Vec::new(),
92 extensions,
93 },
94 })
95 }
96}
97
98impl AttestationBuilder<'_, Sign> {
99 pub fn connection_info(&mut self, connection_info: ConnectionInfo) -> &mut Self {
101 self.state.connection_info = Some(connection_info);
102 self
103 }
104
105 pub fn server_ephemeral_key(&mut self, key: ServerEphemKey) -> &mut Self {
107 self.state.server_ephemeral_key = Some(key);
108 self
109 }
110
111 pub fn extension(&mut self, extension: Extension) -> &mut Self {
113 self.state.extensions.push(extension);
114 self
115 }
116
117 pub fn transcript_commitments(
119 &mut self,
120 transcript_commitments: Vec<TranscriptCommitment>,
121 ) -> &mut Self {
122 self.state.transcript_commitments = transcript_commitments;
123 self
124 }
125
126 pub fn build(self, provider: &CryptoProvider) -> Result<Attestation, AttestationBuilderError> {
128 let Sign {
129 signature_alg,
130 hash_alg,
131 connection_info,
132 server_ephemeral_key,
133 cert_commitment,
134 extensions,
135 transcript_commitments,
136 } = self.state;
137
138 let hasher = provider.hash.get(&hash_alg).map_err(|_| {
139 AttestationBuilderError::new(
140 ErrorKind::Config,
141 format!("accepted hash algorithm {hash_alg} but it's missing in the provider"),
142 )
143 })?;
144 let signer = provider.signer.get(&signature_alg).map_err(|_| {
145 AttestationBuilderError::new(
146 ErrorKind::Config,
147 format!(
148 "accepted signature algorithm {signature_alg} but it's missing in the provider"
149 ),
150 )
151 })?;
152
153 let mut field_id = FieldId::default();
154
155 let body = Body {
156 verifying_key: field_id.next(signer.verifying_key()),
157 connection_info: field_id.next(connection_info.ok_or_else(|| {
158 AttestationBuilderError::new(ErrorKind::Field, "connection info was not set")
159 })?),
160 server_ephemeral_key: field_id.next(server_ephemeral_key.ok_or_else(|| {
161 AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set")
162 })?),
163 cert_commitment: field_id.next(cert_commitment),
164 extensions: extensions
165 .into_iter()
166 .map(|extension| field_id.next(extension))
167 .collect(),
168 transcript_commitments: transcript_commitments
169 .into_iter()
170 .map(|commitment| field_id.next(commitment))
171 .collect(),
172 };
173
174 let header = Header {
175 id: rng().random(),
176 version: VERSION,
177 root: body.root(hasher),
178 };
179
180 let signature = signer
181 .sign(&CanonicalSerialize::serialize(&header))
182 .map_err(|err| AttestationBuilderError::new(ErrorKind::Signature, err))?;
183
184 Ok(Attestation {
185 signature,
186 header,
187 body,
188 })
189 }
190}
191
192#[derive(Debug, thiserror::Error)]
194pub struct AttestationBuilderError {
195 kind: ErrorKind,
196 source: Option<Box<dyn Error + Send + Sync + 'static>>,
197}
198
/// Category of an [`AttestationBuilderError`].
#[derive(Debug)]
enum ErrorKind {
    /// The request asked for an algorithm the configuration does not support.
    Request,
    /// An accepted algorithm is missing from the crypto provider.
    Config,
    /// A required field was not set before building.
    Field,
    /// Signing the attestation header failed.
    Signature,
    /// The configured extension validator rejected the requested extensions.
    Extension,
}
207
208impl AttestationBuilderError {
209 fn new<E>(kind: ErrorKind, error: E) -> Self
210 where
211 E: Into<Box<dyn Error + Send + Sync + 'static>>,
212 {
213 Self {
214 kind,
215 source: Some(error.into()),
216 }
217 }
218
219 pub fn is_request(&self) -> bool {
221 matches!(self.kind, ErrorKind::Request)
222 }
223}
224
225impl std::fmt::Display for AttestationBuilderError {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 match self.kind {
228 ErrorKind::Request => f.write_str("request error")?,
229 ErrorKind::Config => f.write_str("config error")?,
230 ErrorKind::Field => f.write_str("field error")?,
231 ErrorKind::Signature => f.write_str("signature error")?,
232 ErrorKind::Extension => f.write_str("extension error")?,
233 }
234
235 if let Some(source) = &self.source {
236 write!(f, " caused by: {source}")?;
237 }
238
239 Ok(())
240 }
241}
242
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};

    use crate::{
        connection::{HandshakeData, HandshakeDataV1_2},
        fixtures::{encoding_provider, request_fixture, ConnectionFixture, RequestFixture},
        hash::Blake3,
        transcript::Transcript,
    };

    use super::*;

    #[fixture]
    #[once]
    fn attestation_config() -> AttestationConfig {
        AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256K1])
            .build()
            .unwrap()
    }

    #[fixture]
    #[once]
    fn crypto_provider() -> CryptoProvider {
        let mut provider = CryptoProvider::default();
        provider.signer.set_secp256k1(&[42u8; 32]).unwrap();
        provider
    }

    /// Builds a request carrying `extensions` over the standard
    /// `GET_WITH_HEADER` / `OK_JSON` transcript.
    ///
    /// Returns the request together with the connection fixture it was
    /// derived from.
    fn request_with(extensions: Vec<Extension>) -> (Request, ConnectionFixture) {
        let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
        let connection = ConnectionFixture::tlsnotary(transcript.length());

        let RequestFixture { request, .. } = request_fixture(
            transcript,
            encoding_provider(GET_WITH_HEADER, OK_JSON),
            connection.clone(),
            Blake3::default(),
            extensions,
        );

        (request, connection)
    }

    /// A single `foo=bar` extension used by the extension tests.
    fn foo_extension() -> Vec<Extension> {
        vec![Extension {
            id: b"foo".to_vec(),
            value: b"bar".to_vec(),
        }]
    }

    #[rstest]
    fn test_attestation_builder_accept_unsupported_signer() {
        let (request, _) = request_with(Vec::new());

        let attestation_config = AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256R1])
            .build()
            .unwrap();

        let err = Attestation::builder(&attestation_config)
            .accept_request(request)
            .unwrap_err();
        assert!(err.is_request());
    }

    #[rstest]
    fn test_attestation_builder_accept_unsupported_hasher() {
        let (request, _) = request_with(Vec::new());

        let attestation_config = AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256K1])
            .supported_hash_algs([HashAlgId::KECCAK256])
            .build()
            .unwrap();

        let err = Attestation::builder(&attestation_config)
            .accept_request(request)
            .unwrap_err();
        assert!(err.is_request());
    }

    #[rstest]
    fn test_attestation_builder_sign_missing_signer(attestation_config: &AttestationConfig) {
        let (request, _) = request_with(Vec::new());

        let attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        // The provider only has a secp256r1 signer, but the request asks for
        // secp256k1.
        let mut provider = CryptoProvider::default();
        provider.signer.set_secp256r1(&[42u8; 32]).unwrap();

        let err = attestation_builder.build(&provider).unwrap_err();
        assert!(matches!(err.kind, ErrorKind::Config));
    }

    #[rstest]
    fn test_attestation_builder_sign_missing_server_ephemeral_key(
        attestation_config: &AttestationConfig,
        crypto_provider: &CryptoProvider,
    ) {
        let (request, connection) = request_with(Vec::new());

        let mut attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            connection_info, ..
        } = connection;

        attestation_builder.connection_info(connection_info);

        let err = attestation_builder.build(crypto_provider).unwrap_err();
        assert!(matches!(err.kind, ErrorKind::Field));
    }

    #[rstest]
    fn test_attestation_builder_sign_missing_connection_info(
        attestation_config: &AttestationConfig,
        crypto_provider: &CryptoProvider,
    ) {
        let (request, connection) = request_with(Vec::new());

        let mut attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            server_cert_data, ..
        } = connection;

        let HandshakeData::V1_2(HandshakeDataV1_2 {
            server_ephemeral_key,
            ..
        }) = server_cert_data.handshake;

        attestation_builder.server_ephemeral_key(server_ephemeral_key);

        let err = attestation_builder.build(crypto_provider).unwrap_err();
        assert!(matches!(err.kind, ErrorKind::Field));
    }

    #[rstest]
    fn test_attestation_builder_reject_extensions_by_default(
        attestation_config: &AttestationConfig,
    ) {
        let (request, _) = request_with(foo_extension());

        let err = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap_err();

        assert!(matches!(err.kind, ErrorKind::Extension));
    }

    #[rstest]
    fn test_attestation_builder_accept_extension(crypto_provider: &CryptoProvider) {
        let attestation_config = AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256K1])
            .extension_validator(|_| Ok(()))
            .build()
            .unwrap();

        let (request, connection) = request_with(foo_extension());

        let mut attestation_builder = Attestation::builder(&attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            server_cert_data,
            connection_info,
            ..
        } = connection;

        let HandshakeData::V1_2(HandshakeDataV1_2 {
            server_ephemeral_key,
            ..
        }) = server_cert_data.handshake;

        attestation_builder
            .connection_info(connection_info)
            .server_ephemeral_key(server_ephemeral_key);

        let attestation = attestation_builder.build(crypto_provider).unwrap();

        assert_eq!(attestation.body.extensions().count(), 1);
    }
}
484}