1use std::sync::Arc;
2
3use derive_builder::UninitializedFieldError;
4use mpc_tls::Config;
5use rustls_pki_types::{pem::PemObject, CertificateDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
6use tls_core::key;
7use tlsn_common::config::{NetworkSetting, ProtocolConfig};
8use tlsn_core::{connection::ServerName, CryptoProvider};
9
10#[derive(Debug, Clone, derive_builder::Builder)]
12pub struct ProverConfig {
13 #[builder(setter(into))]
15 server_name: ServerName,
16 protocol_config: ProtocolConfig,
18 #[builder(default, setter(into))]
20 crypto_provider: Arc<CryptoProvider>,
21 #[builder(default)]
23 tls_config: TlsConfig,
24}
25
26impl ProverConfig {
27 pub fn builder() -> ProverConfigBuilder {
29 ProverConfigBuilder::default()
30 }
31
32 pub fn server_name(&self) -> &ServerName {
34 &self.server_name
35 }
36
37 pub fn crypto_provider(&self) -> &CryptoProvider {
39 &self.crypto_provider
40 }
41
42 pub fn protocol_config(&self) -> &ProtocolConfig {
44 &self.protocol_config
45 }
46
47 pub fn tls_config(&self) -> &TlsConfig {
49 &self.tls_config
50 }
51
52 pub(crate) fn build_mpc_tls_config(&self) -> Config {
53 let mut builder = Config::builder();
54
55 builder
56 .defer_decryption(self.protocol_config.defer_decryption_from_start())
57 .max_sent(self.protocol_config.max_sent_data())
58 .max_recv_online(self.protocol_config.max_recv_data_online())
59 .max_recv(self.protocol_config.max_recv_data());
60
61 if let Some(max_sent_records) = self.protocol_config.max_sent_records() {
62 builder.max_sent_records(max_sent_records);
63 }
64
65 if let Some(max_recv_records_online) = self.protocol_config.max_recv_records_online() {
66 builder.max_recv_records_online(max_recv_records_online);
67 }
68
69 if let NetworkSetting::Latency = self.protocol_config.network() {
70 builder.low_bandwidth();
71 }
72
73 builder.build().unwrap()
74 }
75}
76
77#[derive(Debug, Clone, Default, derive_builder::Builder)]
79#[builder(build_fn(error = "TlsConfigError"))]
80pub struct TlsConfig {
81 #[builder(default, setter(custom, strip_option))]
84 client_auth: Option<(Vec<key::Certificate>, key::PrivateKey)>,
85}
86
87impl TlsConfig {
88 pub fn builder() -> TlsConfigBuilder {
90 TlsConfigBuilder::default()
91 }
92
93 pub fn client_auth(&self) -> &Option<(Vec<key::Certificate>, key::PrivateKey)> {
96 &self.client_auth
97 }
98}
99
100impl TlsConfigBuilder {
101 pub fn client_auth(&mut self, cert_key: (Vec<Vec<u8>>, Vec<u8>)) -> &mut Self {
114 let certs = cert_key
115 .0
116 .into_iter()
117 .map(key::Certificate)
118 .collect::<Vec<_>>();
119
120 self.client_auth = Some(Some((certs, key::PrivateKey(cert_key.1))));
121 self
122 }
123
124 pub fn client_auth_pem(
137 &mut self,
138 cert_key: (Vec<Vec<u8>>, Vec<u8>),
139 ) -> Result<&mut Self, TlsConfigError> {
140 let key = match PrivatePkcs8KeyDer::from_pem_slice(&cert_key.1) {
141 Ok(key) => (*key.secret_pkcs8_der()).to_vec(),
143 Err(_) => match PrivatePkcs1KeyDer::from_pem_slice(&cert_key.1) {
145 Ok(key) => (*key.secret_pkcs1_der()).to_vec(),
146 Err(_) => return Err(ErrorRepr::InvalidKey.into()),
147 },
148 };
149
150 let certs = cert_key
151 .0
152 .iter()
153 .map(|c| {
154 let c =
155 CertificateDer::from_pem_slice(c).map_err(|_| ErrorRepr::InvalidCertificate)?;
156 Ok::<key::Certificate, TlsConfigError>(key::Certificate(c.as_ref().to_vec()))
157 })
158 .collect::<Result<Vec<_>, _>>()?;
159
160 self.client_auth = Some(Some((certs, key::PrivateKey(key))));
161 Ok(self)
162 }
163}
164
165#[derive(Debug, thiserror::Error)]
167#[error(transparent)]
168pub struct TlsConfigError(#[from] ErrorRepr);
169
170#[derive(Debug, thiserror::Error)]
171#[error("tls config error: {0}")]
172enum ErrorRepr {
173 #[error("missing field: {0:?}")]
174 MissingField(String),
175 #[error("the certificate for client authentication is invalid")]
176 InvalidCertificate,
177 #[error("the private key for client authentication is invalid")]
178 InvalidKey,
179}
180
181impl From<derive_builder::UninitializedFieldError> for TlsConfigError {
182 fn from(e: UninitializedFieldError) -> Self {
183 ErrorRepr::MissingField(e.field_name().to_string()).into()
184 }
185}