tlsn_prover/
config.rs

1use std::sync::Arc;
2
3use mpc_tls::Config;
4use tlsn_common::config::{NetworkSetting, ProtocolConfig};
5use tlsn_core::{connection::ServerName, CryptoProvider};
6
7/// Configuration for the prover
8#[derive(Debug, Clone, derive_builder::Builder)]
9pub struct ProverConfig {
10    /// The server DNS name.
11    #[builder(setter(into))]
12    server_name: ServerName,
13    /// Protocol configuration to be checked with the verifier.
14    protocol_config: ProtocolConfig,
15    /// Cryptography provider.
16    #[builder(default, setter(into))]
17    crypto_provider: Arc<CryptoProvider>,
18}
19
20impl ProverConfig {
21    /// Create a new builder for `ProverConfig`.
22    pub fn builder() -> ProverConfigBuilder {
23        ProverConfigBuilder::default()
24    }
25
26    /// Returns the server DNS name.
27    pub fn server_name(&self) -> &ServerName {
28        &self.server_name
29    }
30
31    /// Returns the crypto provider.
32    pub fn crypto_provider(&self) -> &CryptoProvider {
33        &self.crypto_provider
34    }
35
36    /// Returns the protocol configuration.
37    pub fn protocol_config(&self) -> &ProtocolConfig {
38        &self.protocol_config
39    }
40
41    pub(crate) fn build_mpc_tls_config(&self) -> Config {
42        let mut builder = Config::builder();
43
44        builder
45            .defer_decryption(self.protocol_config.defer_decryption_from_start())
46            .max_sent(self.protocol_config.max_sent_data())
47            .max_recv_online(self.protocol_config.max_recv_data_online())
48            .max_recv(self.protocol_config.max_recv_data());
49
50        if let Some(max_sent_records) = self.protocol_config.max_sent_records() {
51            builder.max_sent_records(max_sent_records);
52        }
53
54        if let Some(max_recv_records) = self.protocol_config.max_recv_records() {
55            builder.max_recv_records(max_recv_records);
56        }
57
58        if let NetworkSetting::Bandwidth = self.protocol_config.network() {
59            builder.high_bandwidth();
60        }
61
62        builder.build().unwrap()
63    }
64}