tlsn_prover/
config.rs

1use std::sync::Arc;
2
3use derive_builder::UninitializedFieldError;
4use mpc_tls::Config;
5use rustls_pki_types::{pem::PemObject, CertificateDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
6use tls_core::key;
7use tlsn_common::config::{NetworkSetting, ProtocolConfig};
8use tlsn_core::{connection::ServerName, CryptoProvider};
9
10/// Configuration for the prover.
11#[derive(Debug, Clone, derive_builder::Builder)]
12pub struct ProverConfig {
13    /// The server DNS name.
14    #[builder(setter(into))]
15    server_name: ServerName,
16    /// Protocol configuration to be checked with the verifier.
17    protocol_config: ProtocolConfig,
18    /// Cryptography provider.
19    #[builder(default, setter(into))]
20    crypto_provider: Arc<CryptoProvider>,
21    /// TLS configuration.
22    #[builder(default)]
23    tls_config: TlsConfig,
24}
25
26impl ProverConfig {
27    /// Creates a new builder for `ProverConfig`.
28    pub fn builder() -> ProverConfigBuilder {
29        ProverConfigBuilder::default()
30    }
31
32    /// Returns the server DNS name.
33    pub fn server_name(&self) -> &ServerName {
34        &self.server_name
35    }
36
37    /// Returns the crypto provider.
38    pub fn crypto_provider(&self) -> &CryptoProvider {
39        &self.crypto_provider
40    }
41
42    /// Returns the protocol configuration.
43    pub fn protocol_config(&self) -> &ProtocolConfig {
44        &self.protocol_config
45    }
46
47    /// Returns the TLS configuration.
48    pub fn tls_config(&self) -> &TlsConfig {
49        &self.tls_config
50    }
51
52    pub(crate) fn build_mpc_tls_config(&self) -> Config {
53        let mut builder = Config::builder();
54
55        builder
56            .defer_decryption(self.protocol_config.defer_decryption_from_start())
57            .max_sent(self.protocol_config.max_sent_data())
58            .max_recv_online(self.protocol_config.max_recv_data_online())
59            .max_recv(self.protocol_config.max_recv_data());
60
61        if let Some(max_sent_records) = self.protocol_config.max_sent_records() {
62            builder.max_sent_records(max_sent_records);
63        }
64
65        if let Some(max_recv_records_online) = self.protocol_config.max_recv_records_online() {
66            builder.max_recv_records_online(max_recv_records_online);
67        }
68
69        if let NetworkSetting::Latency = self.protocol_config.network() {
70            builder.low_bandwidth();
71        }
72
73        builder.build().unwrap()
74    }
75}
76
77/// Configuration for the prover's TLS connection.
78#[derive(Debug, Clone, Default, derive_builder::Builder)]
79#[builder(build_fn(error = "TlsConfigError"))]
80pub struct TlsConfig {
81    /// Certificate chain and a matching private key for client
82    /// authentication.
83    #[builder(default, setter(custom, strip_option))]
84    client_auth: Option<(Vec<key::Certificate>, key::PrivateKey)>,
85}
86
87impl TlsConfig {
88    /// Creates a new builder for `TlsConfig`.
89    pub fn builder() -> TlsConfigBuilder {
90        TlsConfigBuilder::default()
91    }
92
93    /// Returns a certificate chain and a matching private key for client
94    /// authentication.
95    pub fn client_auth(&self) -> &Option<(Vec<key::Certificate>, key::PrivateKey)> {
96        &self.client_auth
97    }
98}
99
100impl TlsConfigBuilder {
101    /// Sets a DER-encoded certificate chain and a matching private key for
102    /// client authentication.
103    ///
104    /// Often the chain will consist of a single end-entity certificate.
105    ///
106    /// # Arguments
107    ///
108    /// * `cert_key` - A tuple containing the certificate chain and the private
109    ///   key.
110    ///
111    ///   - Each certificate in the chain must be in the X.509 format.
112    ///   - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
113    pub fn client_auth(&mut self, cert_key: (Vec<Vec<u8>>, Vec<u8>)) -> &mut Self {
114        let certs = cert_key
115            .0
116            .into_iter()
117            .map(key::Certificate)
118            .collect::<Vec<_>>();
119
120        self.client_auth = Some(Some((certs, key::PrivateKey(cert_key.1))));
121        self
122    }
123
124    /// Sets a PEM-encoded certificate chain and a matching private key for
125    /// client authentication.
126    ///
127    /// Often the chain will consist of a single end-entity certificate.
128    ///
129    /// # Arguments
130    ///
131    /// * `cert_key` - A tuple containing the certificate chain and the private
132    ///   key.
133    ///
134    ///   - Each certificate in the chain must be in the X.509 format.
135    ///   - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
136    pub fn client_auth_pem(
137        &mut self,
138        cert_key: (Vec<Vec<u8>>, Vec<u8>),
139    ) -> Result<&mut Self, TlsConfigError> {
140        let key = match PrivatePkcs8KeyDer::from_pem_slice(&cert_key.1) {
141            // Try to parse as PEM PKCS#8.
142            Ok(key) => (*key.secret_pkcs8_der()).to_vec(),
143            // Otherwise, try to parse as PEM PKCS#1.
144            Err(_) => match PrivatePkcs1KeyDer::from_pem_slice(&cert_key.1) {
145                Ok(key) => (*key.secret_pkcs1_der()).to_vec(),
146                Err(_) => return Err(ErrorRepr::InvalidKey.into()),
147            },
148        };
149
150        let certs = cert_key
151            .0
152            .iter()
153            .map(|c| {
154                let c =
155                    CertificateDer::from_pem_slice(c).map_err(|_| ErrorRepr::InvalidCertificate)?;
156                Ok::<key::Certificate, TlsConfigError>(key::Certificate(c.as_ref().to_vec()))
157            })
158            .collect::<Result<Vec<_>, _>>()?;
159
160        self.client_auth = Some(Some((certs, key::PrivateKey(key))));
161        Ok(self)
162    }
163}
164
165/// TLS configuration error.
166#[derive(Debug, thiserror::Error)]
167#[error(transparent)]
168pub struct TlsConfigError(#[from] ErrorRepr);
169
170#[derive(Debug, thiserror::Error)]
171#[error("tls config error: {0}")]
172enum ErrorRepr {
173    #[error("missing field: {0:?}")]
174    MissingField(String),
175    #[error("the certificate for client authentication is invalid")]
176    InvalidCertificate,
177    #[error("the private key for client authentication is invalid")]
178    InvalidKey,
179}
180
181impl From<derive_builder::UninitializedFieldError> for TlsConfigError {
182    fn from(e: UninitializedFieldError) -> Self {
183        ErrorRepr::MissingField(e.field_name().to_string()).into()
184    }
185}