// tlsn_core/attestation/config.rs

1use std::{fmt::Debug, sync::Arc};
2
3use crate::{
4    attestation::{Extension, InvalidExtension},
5    hash::{HashAlgId, DEFAULT_SUPPORTED_HASH_ALGS},
6    signing::SignatureAlgId,
7};
8
9type ExtensionValidator = Arc<dyn Fn(&[Extension]) -> Result<(), InvalidExtension> + Send + Sync>;
10
/// Category of an [`AttestationConfigError`].
#[derive(Debug)]
// The only constructor path (`AttestationConfigError::builder`) is itself
// unused in this file, so the variant would otherwise trip the lint.
#[allow(dead_code)]
enum ErrorKind {
    /// Error raised while building an [`AttestationConfig`].
    Builder,
}

impl std::fmt::Display for ErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each kind to its fixed lowercase label; width/fill flags are
        // intentionally not applied.
        let label = match self {
            Self::Builder => "builder",
        };
        f.write_str(label)
    }
}
24
25/// Error for [`AttestationConfig`].
26#[derive(Debug, thiserror::Error)]
27#[error("attestation config error: kind: {kind}, reason: {reason}")]
28pub struct AttestationConfigError {
29    kind: ErrorKind,
30    reason: String,
31}
32
33impl AttestationConfigError {
34    #[allow(dead_code)]
35    fn builder(reason: impl Into<String>) -> Self {
36        Self {
37            kind: ErrorKind::Builder,
38            reason: reason.into(),
39        }
40    }
41}
42
43/// Attestation configuration.
44#[derive(Clone)]
45pub struct AttestationConfig {
46    supported_signature_algs: Vec<SignatureAlgId>,
47    supported_hash_algs: Vec<HashAlgId>,
48    extension_validator: Option<ExtensionValidator>,
49}
50
51impl AttestationConfig {
52    /// Creates a new builder.
53    pub fn builder() -> AttestationConfigBuilder {
54        AttestationConfigBuilder::default()
55    }
56
57    pub(crate) fn supported_signature_algs(&self) -> &[SignatureAlgId] {
58        &self.supported_signature_algs
59    }
60
61    pub(crate) fn supported_hash_algs(&self) -> &[HashAlgId] {
62        &self.supported_hash_algs
63    }
64
65    pub(crate) fn extension_validator(&self) -> Option<&ExtensionValidator> {
66        self.extension_validator.as_ref()
67    }
68}
69
70impl Debug for AttestationConfig {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("AttestationConfig")
73            .field("supported_signature_algs", &self.supported_signature_algs)
74            .field("supported_hash_algs", &self.supported_hash_algs)
75            .finish_non_exhaustive()
76    }
77}
78
79/// Builder for [`AttestationConfig`].
80pub struct AttestationConfigBuilder {
81    supported_signature_algs: Vec<SignatureAlgId>,
82    supported_hash_algs: Vec<HashAlgId>,
83    extension_validator: Option<ExtensionValidator>,
84}
85
86impl Default for AttestationConfigBuilder {
87    fn default() -> Self {
88        Self {
89            supported_signature_algs: Vec::default(),
90            supported_hash_algs: DEFAULT_SUPPORTED_HASH_ALGS.to_vec(),
91            extension_validator: Some(Arc::new(|e| {
92                if !e.is_empty() {
93                    Err(InvalidExtension::new(
94                        "all extensions are disallowed by default",
95                    ))
96                } else {
97                    Ok(())
98                }
99            })),
100        }
101    }
102}
103
104impl AttestationConfigBuilder {
105    /// Sets the supported signature algorithms.
106    pub fn supported_signature_algs(
107        &mut self,
108        supported_signature_algs: impl Into<Vec<SignatureAlgId>>,
109    ) -> &mut Self {
110        self.supported_signature_algs = supported_signature_algs.into();
111        self
112    }
113
114    /// Sets the supported hash algorithms.
115    pub fn supported_hash_algs(
116        &mut self,
117        supported_hash_algs: impl Into<Vec<HashAlgId>>,
118    ) -> &mut Self {
119        self.supported_hash_algs = supported_hash_algs.into();
120        self
121    }
122
123    /// Sets the extension validator.
124    ///
125    /// # Example
126    /// ```
127    /// # use tlsn_core::attestation::{AttestationConfig, InvalidExtension};
128    /// # let mut builder = AttestationConfig::builder();
129    /// builder.extension_validator(|extensions| {
130    ///     for extension in extensions {
131    ///         if extension.id != b"example.type" {
132    ///             return Err(InvalidExtension::new("invalid extension type"));
133    ///         }
134    ///     }
135    ///     Ok(())
136    /// });
137    /// ```
138    pub fn extension_validator<F>(&mut self, f: F) -> &mut Self
139    where
140        F: Fn(&[Extension]) -> Result<(), InvalidExtension> + Send + Sync + 'static,
141    {
142        self.extension_validator = Some(Arc::new(f));
143        self
144    }
145
146    /// Builds the configuration.
147    pub fn build(&self) -> Result<AttestationConfig, AttestationConfigError> {
148        Ok(AttestationConfig {
149            supported_signature_algs: self.supported_signature_algs.clone(),
150            supported_hash_algs: self.supported_hash_algs.clone(),
151            extension_validator: self.extension_validator.clone(),
152        })
153    }
154}
155
156impl Debug for AttestationConfigBuilder {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("AttestationConfigBuilder")
159            .field("supported_signature_algs", &self.supported_signature_algs)
160            .field("supported_hash_algs", &self.supported_hash_algs)
161            .finish_non_exhaustive()
162    }
163}