// tlsn_core/attestation/builder.rs

1use std::error::Error;
2
3use rand::{rng, Rng};
4
5use crate::{
6    attestation::{
7        Attestation, AttestationConfig, Body, Extension, FieldId, Header, ServerCertCommitment,
8        VERSION,
9    },
10    connection::{ConnectionInfo, ServerEphemKey},
11    hash::HashAlgId,
12    request::Request,
13    serialize::CanonicalSerialize,
14    signing::SignatureAlgId,
15    transcript::TranscriptCommitment,
16    CryptoProvider,
17};
18
/// Initial state of an [`AttestationBuilder`]: the builder is waiting to accept a request.
#[derive(Debug)]
pub struct Accept {}
22
23#[derive(Debug)]
24pub struct Sign {
25    signature_alg: SignatureAlgId,
26    hash_alg: HashAlgId,
27    connection_info: Option<ConnectionInfo>,
28    server_ephemeral_key: Option<ServerEphemKey>,
29    cert_commitment: ServerCertCommitment,
30    extensions: Vec<Extension>,
31    transcript_commitments: Vec<TranscriptCommitment>,
32}
33
34/// An attestation builder.
35#[derive(Debug)]
36pub struct AttestationBuilder<'a, T = Accept> {
37    config: &'a AttestationConfig,
38    state: T,
39}
40
41impl<'a> AttestationBuilder<'a, Accept> {
42    /// Creates a new attestation builder.
43    pub fn new(config: &'a AttestationConfig) -> Self {
44        Self {
45            config,
46            state: Accept {},
47        }
48    }
49
50    /// Accepts the attestation request.
51    pub fn accept_request(
52        self,
53        request: Request,
54    ) -> Result<AttestationBuilder<'a, Sign>, AttestationBuilderError> {
55        let config = self.config;
56
57        let Request {
58            signature_alg,
59            hash_alg,
60            server_cert_commitment: cert_commitment,
61            extensions,
62        } = request;
63
64        if !config.supported_signature_algs().contains(&signature_alg) {
65            return Err(AttestationBuilderError::new(
66                ErrorKind::Request,
67                format!("unsupported signature algorithm: {signature_alg}"),
68            ));
69        }
70
71        if !config.supported_hash_algs().contains(&hash_alg) {
72            return Err(AttestationBuilderError::new(
73                ErrorKind::Request,
74                format!("unsupported hash algorithm: {hash_alg}"),
75            ));
76        }
77
78        if let Some(validator) = config.extension_validator() {
79            validator(&extensions)
80                .map_err(|err| AttestationBuilderError::new(ErrorKind::Extension, err))?;
81        }
82
83        Ok(AttestationBuilder {
84            config: self.config,
85            state: Sign {
86                signature_alg,
87                hash_alg,
88                connection_info: None,
89                server_ephemeral_key: None,
90                cert_commitment,
91                transcript_commitments: Vec::new(),
92                extensions,
93            },
94        })
95    }
96}
97
98impl AttestationBuilder<'_, Sign> {
99    /// Sets the connection information.
100    pub fn connection_info(&mut self, connection_info: ConnectionInfo) -> &mut Self {
101        self.state.connection_info = Some(connection_info);
102        self
103    }
104
105    /// Sets the server ephemeral key.
106    pub fn server_ephemeral_key(&mut self, key: ServerEphemKey) -> &mut Self {
107        self.state.server_ephemeral_key = Some(key);
108        self
109    }
110
111    /// Adds an extension to the attestation.
112    pub fn extension(&mut self, extension: Extension) -> &mut Self {
113        self.state.extensions.push(extension);
114        self
115    }
116
117    /// Sets the transcript commitments.
118    pub fn transcript_commitments(
119        &mut self,
120        transcript_commitments: Vec<TranscriptCommitment>,
121    ) -> &mut Self {
122        self.state.transcript_commitments = transcript_commitments;
123        self
124    }
125
126    /// Builds the attestation.
127    pub fn build(self, provider: &CryptoProvider) -> Result<Attestation, AttestationBuilderError> {
128        let Sign {
129            signature_alg,
130            hash_alg,
131            connection_info,
132            server_ephemeral_key,
133            cert_commitment,
134            extensions,
135            transcript_commitments,
136        } = self.state;
137
138        let hasher = provider.hash.get(&hash_alg).map_err(|_| {
139            AttestationBuilderError::new(
140                ErrorKind::Config,
141                format!("accepted hash algorithm {hash_alg} but it's missing in the provider"),
142            )
143        })?;
144        let signer = provider.signer.get(&signature_alg).map_err(|_| {
145            AttestationBuilderError::new(
146                ErrorKind::Config,
147                format!(
148                    "accepted signature algorithm {signature_alg} but it's missing in the provider"
149                ),
150            )
151        })?;
152
153        let mut field_id = FieldId::default();
154
155        let body = Body {
156            verifying_key: field_id.next(signer.verifying_key()),
157            connection_info: field_id.next(connection_info.ok_or_else(|| {
158                AttestationBuilderError::new(ErrorKind::Field, "connection info was not set")
159            })?),
160            server_ephemeral_key: field_id.next(server_ephemeral_key.ok_or_else(|| {
161                AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set")
162            })?),
163            cert_commitment: field_id.next(cert_commitment),
164            extensions: extensions
165                .into_iter()
166                .map(|extension| field_id.next(extension))
167                .collect(),
168            transcript_commitments: transcript_commitments
169                .into_iter()
170                .map(|commitment| field_id.next(commitment))
171                .collect(),
172        };
173
174        let header = Header {
175            id: rng().random(),
176            version: VERSION,
177            root: body.root(hasher),
178        };
179
180        let signature = signer
181            .sign(&CanonicalSerialize::serialize(&header))
182            .map_err(|err| AttestationBuilderError::new(ErrorKind::Signature, err))?;
183
184        Ok(Attestation {
185            signature,
186            header,
187            body,
188        })
189    }
190}
191
192/// Error for [`AttestationBuilder`].
193#[derive(Debug, thiserror::Error)]
194pub struct AttestationBuilderError {
195    kind: ErrorKind,
196    source: Option<Box<dyn Error + Send + Sync + 'static>>,
197}
198
/// Category of an [`AttestationBuilderError`].
#[derive(Debug)]
enum ErrorKind {
    /// The request asked for an algorithm the configuration does not support.
    Request,
    /// The crypto provider lacks an algorithm the configuration accepted.
    Config,
    /// A required attestation field was not set before building.
    Field,
    /// Signing the attestation header failed.
    Signature,
    /// The configured validator rejected the request's extensions.
    Extension,
}
207
208impl AttestationBuilderError {
209    fn new<E>(kind: ErrorKind, error: E) -> Self
210    where
211        E: Into<Box<dyn Error + Send + Sync + 'static>>,
212    {
213        Self {
214            kind,
215            source: Some(error.into()),
216        }
217    }
218
219    /// Returns whether the error originates from a bad request.
220    pub fn is_request(&self) -> bool {
221        matches!(self.kind, ErrorKind::Request)
222    }
223}
224
225impl std::fmt::Display for AttestationBuilderError {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        match self.kind {
228            ErrorKind::Request => f.write_str("request error")?,
229            ErrorKind::Config => f.write_str("config error")?,
230            ErrorKind::Field => f.write_str("field error")?,
231            ErrorKind::Signature => f.write_str("signature error")?,
232            ErrorKind::Extension => f.write_str("extension error")?,
233        }
234
235        if let Some(source) = &self.source {
236            write!(f, " caused by: {source}")?;
237        }
238
239        Ok(())
240    }
241}
242
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};

    use crate::{
        connection::{HandshakeData, HandshakeDataV1_2},
        fixtures::{encoding_provider, request_fixture, ConnectionFixture, RequestFixture},
        hash::Blake3,
        transcript::Transcript,
    };

    use super::*;

    #[fixture]
    #[once]
    fn attestation_config() -> AttestationConfig {
        AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256K1])
            .build()
            .unwrap()
    }

    #[fixture]
    #[once]
    fn crypto_provider() -> CryptoProvider {
        let mut provider = CryptoProvider::default();
        provider.signer.set_secp256k1(&[42u8; 32]).unwrap();
        provider
    }

    /// Builds a request over the fixture HTTP transcript carrying `extensions`, returning it
    /// together with the connection fixture it was generated from.
    fn request_with_extensions(extensions: Vec<Extension>) -> (Request, ConnectionFixture) {
        let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
        let connection = ConnectionFixture::tlsnotary(transcript.length());

        let RequestFixture { request, .. } = request_fixture(
            transcript,
            encoding_provider(GET_WITH_HEADER, OK_JSON),
            connection.clone(),
            Blake3::default(),
            extensions,
        );

        (request, connection)
    }

    /// The `foo`/`bar` extension used by the extension tests.
    fn foo_extension() -> Extension {
        Extension {
            id: b"foo".to_vec(),
            value: b"bar".to_vec(),
        }
    }

    #[rstest]
    fn test_attestation_builder_accept_unsupported_signer() {
        let (request, _) = request_with_extensions(Vec::new());

        let attestation_config = AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256R1])
            .build()
            .unwrap();

        let err = Attestation::builder(&attestation_config)
            .accept_request(request)
            .unwrap_err();
        assert!(err.is_request());
    }

    #[rstest]
    fn test_attestation_builder_accept_unsupported_hasher() {
        let (request, _) = request_with_extensions(Vec::new());

        let attestation_config = AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256K1])
            .supported_hash_algs([HashAlgId::KECCAK256])
            .build()
            .unwrap();

        let err = Attestation::builder(&attestation_config)
            .accept_request(request)
            .unwrap_err();
        assert!(err.is_request());
    }

    #[rstest]
    fn test_attestation_builder_sign_missing_signer(attestation_config: &AttestationConfig) {
        let (request, _) = request_with_extensions(Vec::new());

        let attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        let mut provider = CryptoProvider::default();
        provider.signer.set_secp256r1(&[42u8; 32]).unwrap();

        let err = attestation_builder.build(&provider).unwrap_err();
        assert!(matches!(err.kind, ErrorKind::Config));
    }

    #[rstest]
    fn test_attestation_builder_sign_missing_server_ephemeral_key(
        attestation_config: &AttestationConfig,
        crypto_provider: &CryptoProvider,
    ) {
        let (request, connection) = request_with_extensions(Vec::new());

        let mut attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            connection_info, ..
        } = connection;

        attestation_builder.connection_info(connection_info);

        let err = attestation_builder.build(crypto_provider).unwrap_err();
        assert!(matches!(err.kind, ErrorKind::Field));
    }

    #[rstest]
    fn test_attestation_builder_sign_missing_connection_info(
        attestation_config: &AttestationConfig,
        crypto_provider: &CryptoProvider,
    ) {
        let (request, connection) = request_with_extensions(Vec::new());

        let mut attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            server_cert_data, ..
        } = connection;

        let HandshakeData::V1_2(HandshakeDataV1_2 {
            server_ephemeral_key,
            ..
        }) = server_cert_data.handshake;

        attestation_builder.server_ephemeral_key(server_ephemeral_key);

        let err = attestation_builder.build(crypto_provider).unwrap_err();
        assert!(matches!(err.kind, ErrorKind::Field));
    }

    #[rstest]
    fn test_attestation_builder_reject_extensions_by_default(
        attestation_config: &AttestationConfig,
    ) {
        let (request, _) = request_with_extensions(vec![foo_extension()]);

        let err = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap_err();

        assert!(matches!(err.kind, ErrorKind::Extension));
    }

    #[rstest]
    fn test_attestation_builder_accept_extension(crypto_provider: &CryptoProvider) {
        let attestation_config = AttestationConfig::builder()
            .supported_signature_algs([SignatureAlgId::SECP256K1])
            .extension_validator(|_| Ok(()))
            .build()
            .unwrap();

        let (request, connection) = request_with_extensions(vec![foo_extension()]);

        let mut attestation_builder = Attestation::builder(&attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            server_cert_data,
            connection_info,
            ..
        } = connection;

        let HandshakeData::V1_2(HandshakeDataV1_2 {
            server_ephemeral_key,
            ..
        }) = server_cert_data.handshake;

        attestation_builder
            .connection_info(connection_info)
            .server_ephemeral_key(server_ephemeral_key);

        let attestation = attestation_builder.build(crypto_provider).unwrap();

        assert_eq!(attestation.body.extensions().count(), 1);
    }

    #[rstest]
    fn test_attestation_builder_added_extension_is_included(
        attestation_config: &AttestationConfig,
        crypto_provider: &CryptoProvider,
    ) {
        // The default config rejects request extensions, but extensions added through the
        // builder are not subject to that validation.
        let (request, connection) = request_with_extensions(Vec::new());

        let mut attestation_builder = Attestation::builder(attestation_config)
            .accept_request(request)
            .unwrap();

        let ConnectionFixture {
            server_cert_data,
            connection_info,
            ..
        } = connection;

        let HandshakeData::V1_2(HandshakeDataV1_2 {
            server_ephemeral_key,
            ..
        }) = server_cert_data.handshake;

        attestation_builder
            .connection_info(connection_info)
            .server_ephemeral_key(server_ephemeral_key)
            .extension(foo_extension());

        let attestation = attestation_builder.build(crypto_provider).unwrap();

        assert_eq!(attestation.body.extensions().count(), 1);
    }
}