tlsn_prover/
lib.rs

1//! TLSNotary prover library.
2
3#![deny(missing_docs, unreachable_pub, unused_must_use)]
4#![deny(clippy::all)]
5#![forbid(unsafe_code)]
6
7mod config;
8mod error;
9mod future;
10mod notarize;
11mod prove;
12pub mod state;
13
14pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError};
15pub use error::ProverError;
16pub use future::ProverFuture;
17use mpz_common::Context;
18use mpz_core::Block;
19use mpz_garble_core::Delta;
20use rand06_compat::Rand0_6CompatExt;
21use state::{Notarize, Prove};
22
23use futures::{AsyncRead, AsyncWrite, TryFutureExt};
24use mpc_tls::{LeaderCtrl, MpcTlsLeader};
25use rand::Rng;
26use serio::SinkExt;
27use std::sync::Arc;
28use tls_client::{ClientConnection, ServerName as TlsServerName};
29use tls_client_async::{bind_client, TlsConnection};
30use tls_core::msgs::enums::ContentType;
31use tlsn_common::{
32    commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role,
33};
34use tlsn_core::{
35    connection::{
36        ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature,
37        TranscriptLength,
38    },
39    transcript::Transcript,
40};
41use tlsn_deap::Deap;
42use tokio::sync::Mutex;
43
44use tracing::{debug, info_span, instrument, Instrument, Span};
45
/// Shared RCOT sender: a KOS sender whose base OT is a Chou-Orlandi receiver.
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
    mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
    mpz_core::Block,
>;
/// Shared RCOT receiver: a Ferret receiver wrapping a KOS receiver whose base
/// OT is a Chou-Orlandi sender.
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
    mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
    bool,
    mpz_core::Block,
>;
/// MPC side of the prover's VM: a semi-honest garbling `Generator` driven by a
/// derandomized COT sender built on [`RCOTSender`].
pub(crate) type Mpc =
    mpz_garble::protocol::semihonest::Generator<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
/// ZK side of the prover's VM: an `mpz_zk::Prover` built on [`RCOTReceiver`].
pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
58
59/// A prover instance.
60#[derive(Debug)]
61pub struct Prover<T: state::ProverState> {
62    config: ProverConfig,
63    span: Span,
64    state: T,
65}
66
67impl Prover<state::Initialized> {
68    /// Creates a new prover.
69    ///
70    /// # Arguments
71    ///
72    /// * `config` - The configuration for the prover.
73    pub fn new(config: ProverConfig) -> Self {
74        let span = info_span!("prover");
75        Self {
76            config,
77            span,
78            state: state::Initialized,
79        }
80    }
81
82    /// Sets up the prover.
83    ///
84    /// This performs all MPC setup prior to establishing the connection to the
85    /// application server.
86    ///
87    /// # Arguments
88    ///
89    /// * `socket` - The socket to the TLS verifier.
90    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
91    pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
92        self,
93        socket: S,
94    ) -> Result<Prover<state::Setup>, ProverError> {
95        let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
96        let mut mt = build_mt_context(mux_ctrl.clone());
97        let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
98
99        // Sends protocol configuration to verifier for compatibility check.
100        mux_fut
101            .poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
102            .await?;
103
104        let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
105
106        // Allocate resources for MPC-TLS in VM.
107        let keys = mpc_tls.alloc()?;
108        // Allocate for committing to plaintext.
109        let mut zk_aes = ZkAesCtr::new(Role::Prover);
110        zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
111        zk_aes.alloc(
112            &mut (*vm.try_lock().expect("VM is not locked").zk()),
113            self.config.protocol_config().max_recv_data(),
114        )?;
115
116        debug!("setting up mpc-tls");
117
118        mux_fut.poll_with(mpc_tls.preprocess()).await?;
119
120        debug!("mpc-tls setup complete");
121
122        Ok(Prover {
123            config: self.config,
124            span: self.span,
125            state: state::Setup {
126                mux_ctrl,
127                mux_fut,
128                mt,
129                mpc_tls,
130                zk_aes,
131                keys,
132                vm,
133            },
134        })
135    }
136}
137
138impl Prover<state::Setup> {
139    /// Connects to the server using the provided socket.
140    ///
141    /// Returns a handle to the TLS connection, a future which returns the
142    /// prover once the connection is closed.
143    ///
144    /// # Arguments
145    ///
146    /// * `socket` - The socket to the server.
147    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
148    pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
149        self,
150        socket: S,
151    ) -> Result<(TlsConnection, ProverFuture), ProverError> {
152        let state::Setup {
153            mux_ctrl,
154            mut mux_fut,
155            mt,
156            mpc_tls,
157            mut zk_aes,
158            keys,
159            vm,
160        } = self.state;
161
162        let (mpc_ctrl, mpc_fut) = mpc_tls.run();
163
164        let server_name =
165            TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
166                ProverError::config(format!(
167                    "invalid server name: {}",
168                    self.config.server_name()
169                ))
170            })?;
171
172        let config = tls_client::ClientConfig::builder()
173            .with_safe_defaults()
174            .with_root_certificates(self.config.crypto_provider().cert.root_store().clone())
175            .with_no_client_auth();
176        let client =
177            ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
178                .map_err(ProverError::config)?;
179
180        let (conn, conn_fut) = bind_client(socket, client);
181
182        let start_time = web_time::UNIX_EPOCH
183            .elapsed()
184            .expect("system time is available")
185            .as_secs();
186
187        let fut = Box::pin({
188            let span = self.span.clone();
189            let mpc_ctrl = mpc_ctrl.clone();
190            async move {
191                let conn_fut = async {
192                    mux_fut
193                        .poll_with(conn_fut.map_err(ProverError::from))
194                        .await?;
195
196                    mpc_ctrl.stop().await?;
197
198                    Ok::<_, ProverError>(())
199                };
200
201                let (_, (mut ctx, mut data)) = futures::try_join!(
202                    conn_fut,
203                    mpc_fut.in_current_span().map_err(ProverError::from)
204                )?;
205
206                {
207                    let mut vm = vm.try_lock().expect("VM should not be locked");
208
209                    // Prove received plaintext. Prover drops the proof output, as they trust
210                    // themselves.
211                    _ = commit_records(
212                        &mut (*vm.zk()),
213                        &mut zk_aes,
214                        data.transcript
215                            .recv
216                            .iter_mut()
217                            .filter(|record| record.typ == ContentType::ApplicationData),
218                    )
219                    .map_err(ProverError::zk)?;
220
221                    debug!("finalizing mpc");
222
223                    // Finalize DEAP and execute the plaintext proofs.
224                    mux_fut
225                        .poll_with(vm.finalize(&mut ctx))
226                        .await
227                        .map_err(ProverError::mpc)?;
228
229                    debug!("mpc finalized");
230                }
231
232                let transcript = data
233                    .transcript
234                    .to_transcript()
235                    .expect("transcript is complete");
236                let transcript_refs = data
237                    .transcript
238                    .to_transcript_refs()
239                    .expect("transcript is complete");
240
241                let connection_info = ConnectionInfo {
242                    time: start_time,
243                    version: data
244                        .protocol_version
245                        .try_into()
246                        .expect("only supported version should have been accepted"),
247                    transcript_length: TranscriptLength {
248                        sent: transcript.sent().len() as u32,
249                        received: transcript.received().len() as u32,
250                    },
251                };
252
253                let server_cert_data =
254                    ServerCertData {
255                        certs: data
256                            .server_cert_details
257                            .cert_chain()
258                            .iter()
259                            .cloned()
260                            .map(|c| c.into())
261                            .collect(),
262                        sig: ServerSignature {
263                            scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
264                                "only supported signature scheme should have been accepted",
265                            ),
266                            sig: data.server_kx_details.kx_sig().sig.0.clone(),
267                        },
268                        handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
269                            client_random: data.client_random.0,
270                            server_random: data.server_random.0,
271                            server_ephemeral_key: data
272                                .server_key
273                                .try_into()
274                                .expect("only supported key scheme should have been accepted"),
275                        }),
276                    };
277
278                // Pull out ZK VM.
279                let (_, vm) = Arc::into_inner(vm)
280                    .expect("vm should have only 1 reference")
281                    .into_inner()
282                    .into_inner();
283
284                Ok(Prover {
285                    config: self.config,
286                    span: self.span,
287                    state: state::Closed {
288                        mux_ctrl,
289                        mux_fut,
290                        mt,
291                        ctx,
292                        _keys: keys,
293                        vm,
294                        connection_info,
295                        server_cert_data,
296                        transcript,
297                        transcript_refs,
298                    },
299                })
300            }
301            .instrument(span)
302        });
303
304        Ok((
305            conn,
306            ProverFuture {
307                fut,
308                ctrl: ProverControl { mpc_ctrl },
309            },
310        ))
311    }
312}
313
314impl Prover<state::Closed> {
315    /// Returns the transcript.
316    pub fn transcript(&self) -> &Transcript {
317        &self.state.transcript
318    }
319
320    /// Starts notarization of the TLS session.
321    ///
322    /// Used when the TLS verifier is a Notary to transition the prover to the
323    /// next state where it can generate commitments to the transcript prior
324    /// to finalization.
325    pub fn start_notarize(self) -> Prover<Notarize> {
326        Prover {
327            config: self.config,
328            span: self.span,
329            state: self.state.into(),
330        }
331    }
332
333    /// Starts proving the TLS session.
334    ///
335    /// This function transitions the prover into a state where it can prove
336    /// content of the transcript.
337    pub fn start_prove(self) -> Prover<Prove> {
338        Prover {
339            config: self.config,
340            span: self.span,
341            state: self.state.into(),
342        }
343    }
344}
345
346fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
347    let mut rng = rand::rng();
348    let delta = Delta::new(Block::random(&mut rng.compat_by_ref()));
349
350    let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
351    let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
352    let rcot_send = mpz_ot::kos::Sender::new(
353        mpz_ot::kos::SenderConfig::default(),
354        delta.into_inner(),
355        base_ot_recv,
356    );
357    let rcot_recv =
358        mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
359    let rcot_recv = mpz_ot::ferret::Receiver::new(
360        mpz_ot::ferret::FerretConfig::builder()
361            .lpn_type(mpz_ot::ferret::LpnType::Regular)
362            .build()
363            .expect("ferret config is valid"),
364        Block::random(&mut rng.compat_by_ref()),
365        rcot_recv,
366    );
367
368    let mut rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(4, rcot_send);
369    let mut rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(2, rcot_recv);
370
371    let mpc = Mpc::new(
372        mpz_ot::cot::DerandCOTSender::new(rcot_send.next().expect("enough senders are available")),
373        rng.random(),
374        delta,
375    );
376
377    let zk = Zk::new(rcot_recv.next().expect("enough receivers are available"));
378
379    let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
380
381    (
382        vm.clone(),
383        MpcTlsLeader::new(
384            config.build_mpc_tls_config(),
385            ctx,
386            vm,
387            (
388                rcot_send.next().expect("enough senders are available"),
389                rcot_send.next().expect("enough senders are available"),
390                rcot_send.next().expect("enough senders are available"),
391            ),
392            rcot_recv.next().expect("enough receivers are available"),
393        ),
394    )
395}
396
397/// A controller for the prover.
398#[derive(Clone)]
399pub struct ProverControl {
400    mpc_ctrl: LeaderCtrl,
401}
402
403impl ProverControl {
404    /// Defers decryption of data from the server until the server has closed
405    /// the connection.
406    ///
407    /// This is a performance optimization which will significantly reduce the
408    /// amount of upload bandwidth used by the prover.
409    ///
410    /// # Notes
411    ///
412    /// * The prover may need to close the connection to the server in order for
413    ///   it to close the connection on its end. If neither the prover or server
414    ///   close the connection this will cause a deadlock.
415    pub async fn defer_decryption(&self) -> Result<(), ProverError> {
416        self.mpc_ctrl
417            .defer_decryption()
418            .await
419            .map_err(ProverError::from)
420    }
421}