tlsn_prover/
lib.rs

1//! TLSNotary prover library.
2
3#![deny(missing_docs, unreachable_pub, unused_must_use)]
4#![deny(clippy::all)]
5#![forbid(unsafe_code)]
6
7mod config;
8mod error;
9mod future;
10pub mod state;
11
12pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError};
13pub use error::ProverError;
14pub use future::ProverFuture;
15pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
16
17use mpz_common::Context;
18use mpz_core::Block;
19use mpz_garble_core::Delta;
20use mpz_vm_core::prelude::*;
21
22use futures::{AsyncRead, AsyncWrite, TryFutureExt};
23use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
24use rand::Rng;
25use serio::{stream::IoStreamExt, SinkExt};
26use std::sync::Arc;
27use tls_client::{ClientConnection, ServerName as TlsServerName};
28use tls_client_async::{bind_client, TlsConnection};
29use tls_core::msgs::enums::ContentType;
30use tlsn_common::{
31    commit::{commit_records, hash::prove_hash},
32    context::build_mt_context,
33    encoding,
34    mux::attach_mux,
35    transcript::{decode_transcript, Record, TlsTranscript},
36    zk_aes::ZkAesCtr,
37    Role,
38};
39use tlsn_core::{
40    attestation::Attestation,
41    connection::{
42        ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature,
43        TranscriptLength,
44    },
45    request::{Request, RequestConfig},
46    transcript::{Direction, Transcript, TranscriptCommitment, TranscriptSecret},
47    ProvePayload, Secrets,
48};
49use tlsn_deap::Deap;
50use tokio::sync::Mutex;
51
52use tracing::{debug, info_span, instrument, Instrument, Span};
53
54pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
55    mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
56    mpz_core::Block,
57>;
58pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
59    mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
60    bool,
61    mpz_core::Block,
62>;
63pub(crate) type Mpc =
64    mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
65pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
66
67/// A prover instance.
68#[derive(Debug)]
69pub struct Prover<T: state::ProverState = state::Initialized> {
70    config: ProverConfig,
71    span: Span,
72    state: T,
73}
74
75impl Prover<state::Initialized> {
76    /// Creates a new prover.
77    ///
78    /// # Arguments
79    ///
80    /// * `config` - The configuration for the prover.
81    pub fn new(config: ProverConfig) -> Self {
82        let span = info_span!("prover");
83        Self {
84            config,
85            span,
86            state: state::Initialized,
87        }
88    }
89
90    /// Sets up the prover.
91    ///
92    /// This performs all MPC setup prior to establishing the connection to the
93    /// application server.
94    ///
95    /// # Arguments
96    ///
97    /// * `socket` - The socket to the TLS verifier.
98    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
99    pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
100        self,
101        socket: S,
102    ) -> Result<Prover<state::Setup>, ProverError> {
103        let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
104        let mut mt = build_mt_context(mux_ctrl.clone());
105        let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
106
107        // Sends protocol configuration to verifier for compatibility check.
108        mux_fut
109            .poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
110            .await?;
111
112        let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
113
114        // Allocate resources for MPC-TLS in VM.
115        let mut keys = mpc_tls.alloc()?;
116        translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
117
118        // Allocate for committing to plaintext.
119        let mut zk_aes = ZkAesCtr::new(Role::Prover);
120        zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
121        zk_aes.alloc(
122            &mut (*vm.try_lock().expect("VM is not locked").zk()),
123            self.config.protocol_config().max_recv_data(),
124        )?;
125
126        debug!("setting up mpc-tls");
127
128        mux_fut.poll_with(mpc_tls.preprocess()).await?;
129
130        debug!("mpc-tls setup complete");
131
132        Ok(Prover {
133            config: self.config,
134            span: self.span,
135            state: state::Setup {
136                mux_ctrl,
137                mux_fut,
138                mpc_tls,
139                zk_aes,
140                keys,
141                vm,
142            },
143        })
144    }
145}
146
147impl Prover<state::Setup> {
148    /// Connects to the server using the provided socket.
149    ///
150    /// Returns a handle to the TLS connection, a future which returns the
151    /// prover once the connection is closed.
152    ///
153    /// # Arguments
154    ///
155    /// * `socket` - The socket to the server.
156    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
157    pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
158        self,
159        socket: S,
160    ) -> Result<(TlsConnection, ProverFuture), ProverError> {
161        let state::Setup {
162            mux_ctrl,
163            mut mux_fut,
164            mpc_tls,
165            mut zk_aes,
166            keys,
167            vm,
168            ..
169        } = self.state;
170
171        let (mpc_ctrl, mpc_fut) = mpc_tls.run();
172
173        let server_name =
174            TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
175                ProverError::config(format!(
176                    "invalid server name: {}",
177                    self.config.server_name()
178                ))
179            })?;
180
181        let config = tls_client::ClientConfig::builder()
182            .with_safe_defaults()
183            .with_root_certificates(self.config.crypto_provider().cert.root_store().clone())
184            .with_no_client_auth();
185        let client =
186            ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
187                .map_err(ProverError::config)?;
188
189        let (conn, conn_fut) = bind_client(socket, client);
190
191        let start_time = web_time::UNIX_EPOCH
192            .elapsed()
193            .expect("system time is available")
194            .as_secs();
195
196        let fut = Box::pin({
197            let span = self.span.clone();
198            let mpc_ctrl = mpc_ctrl.clone();
199            async move {
200                let conn_fut = async {
201                    mux_fut
202                        .poll_with(conn_fut.map_err(ProverError::from))
203                        .await?;
204
205                    mpc_ctrl.stop().await?;
206
207                    Ok::<_, ProverError>(())
208                };
209
210                let (_, (mut ctx, mut data)) = futures::try_join!(
211                    conn_fut,
212                    mpc_fut.in_current_span().map_err(ProverError::from)
213                )?;
214
215                {
216                    let mut vm = vm.try_lock().expect("VM should not be locked");
217
218                    translate_transcript(&mut data.transcript, &vm)?;
219
220                    // Prove received plaintext. Prover drops the proof output, as they trust
221                    // themselves.
222                    _ = commit_records(
223                        &mut (*vm.zk()),
224                        &mut zk_aes,
225                        data.transcript
226                            .recv
227                            .iter_mut()
228                            .filter(|record| record.typ == ContentType::ApplicationData),
229                    )
230                    .map_err(ProverError::zk)?;
231
232                    debug!("finalizing mpc");
233
234                    // Finalize DEAP and execute the plaintext proofs.
235                    mux_fut
236                        .poll_with(vm.finalize(&mut ctx))
237                        .await
238                        .map_err(ProverError::mpc)?;
239
240                    debug!("mpc finalized");
241                }
242
243                let transcript = data
244                    .transcript
245                    .to_transcript()
246                    .expect("transcript is complete");
247                let transcript_refs = data
248                    .transcript
249                    .to_transcript_refs()
250                    .expect("transcript is complete");
251
252                let connection_info = ConnectionInfo {
253                    time: start_time,
254                    version: data
255                        .protocol_version
256                        .try_into()
257                        .expect("only supported version should have been accepted"),
258                    transcript_length: TranscriptLength {
259                        sent: transcript.sent().len() as u32,
260                        received: transcript.received().len() as u32,
261                    },
262                };
263
264                let server_cert_data =
265                    ServerCertData {
266                        certs: data
267                            .server_cert_details
268                            .cert_chain()
269                            .iter()
270                            .cloned()
271                            .map(|c| c.into())
272                            .collect(),
273                        sig: ServerSignature {
274                            scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
275                                "only supported signature scheme should have been accepted",
276                            ),
277                            sig: data.server_kx_details.kx_sig().sig.0.clone(),
278                        },
279                        handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
280                            client_random: data.client_random.0,
281                            server_random: data.server_random.0,
282                            server_ephemeral_key: data
283                                .server_key
284                                .try_into()
285                                .expect("only supported key scheme should have been accepted"),
286                        }),
287                    };
288
289                // Pull out ZK VM.
290                let (_, vm) = Arc::into_inner(vm)
291                    .expect("vm should have only 1 reference")
292                    .into_inner()
293                    .into_inner();
294
295                Ok(Prover {
296                    config: self.config,
297                    span: self.span,
298                    state: state::Committed {
299                        mux_ctrl,
300                        mux_fut,
301                        ctx,
302                        _keys: keys,
303                        vm,
304                        connection_info,
305                        server_cert_data,
306                        transcript,
307                        transcript_refs,
308                    },
309                })
310            }
311            .instrument(span)
312        });
313
314        Ok((
315            conn,
316            ProverFuture {
317                fut,
318                ctrl: ProverControl { mpc_ctrl },
319            },
320        ))
321    }
322}
323
324impl Prover<state::Committed> {
325    /// Returns the connection information.
326    pub fn connection_info(&self) -> &ConnectionInfo {
327        &self.state.connection_info
328    }
329
330    /// Returns the transcript.
331    pub fn transcript(&self) -> &Transcript {
332        &self.state.transcript
333    }
334
335    /// Proves information to the verifier.
336    ///
337    /// # Arguments
338    ///
339    /// * `config` - The disclosure configuration.
340    #[instrument(parent = &self.span, level = "info", skip_all, err)]
341    pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
342        let state::Committed {
343            mux_fut,
344            ctx,
345            vm,
346            server_cert_data,
347            transcript_refs,
348            ..
349        } = &mut self.state;
350
351        let mut output = ProverOutput {
352            transcript_commitments: Vec::new(),
353            transcript_secrets: Vec::new(),
354        };
355
356        let payload = ProvePayload {
357            server_identity: config
358                .server_identity()
359                .then(|| (self.config.server_name().clone(), server_cert_data.clone())),
360            transcript: config.transcript().cloned(),
361            transcript_commit: config.transcript_commit().map(|config| config.to_request()),
362        };
363
364        // Send payload.
365        mux_fut
366            .poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
367            .await?;
368
369        if let Some(partial_transcript) = config.transcript() {
370            decode_transcript(
371                vm,
372                partial_transcript.sent_authed(),
373                partial_transcript.received_authed(),
374                transcript_refs,
375            )
376            .map_err(ProverError::zk)?;
377        }
378
379        let mut hash_commitments = None;
380        if let Some(commit_config) = config.transcript_commit() {
381            if commit_config.has_encoding() {
382                let hasher = self
383                    .config
384                    .crypto_provider()
385                    .hash
386                    .get(commit_config.encoding_hash_alg())
387                    .map_err(ProverError::config)?;
388
389                let (commitment, tree) = mux_fut
390                    .poll_with(
391                        encoding::receive(
392                            ctx,
393                            hasher,
394                            transcript_refs,
395                            |plaintext| vm.get_macs(plaintext).expect("reference is valid"),
396                            commit_config.iter_encoding(),
397                        )
398                        .map_err(ProverError::commit),
399                    )
400                    .await?;
401
402                output
403                    .transcript_commitments
404                    .push(TranscriptCommitment::Encoding(commitment));
405                output
406                    .transcript_secrets
407                    .push(TranscriptSecret::Encoding(tree));
408            }
409
410            if commit_config.has_hash() {
411                hash_commitments = Some(
412                    prove_hash(
413                        vm,
414                        transcript_refs,
415                        commit_config
416                            .iter_hash()
417                            .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
418                    )
419                    .map_err(ProverError::commit)?,
420                );
421            }
422        }
423
424        mux_fut
425            .poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
426            .await?;
427
428        if let Some((hash_fut, hash_secrets)) = hash_commitments {
429            let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
430            for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
431                output
432                    .transcript_commitments
433                    .push(TranscriptCommitment::Hash(commitment));
434                output
435                    .transcript_secrets
436                    .push(TranscriptSecret::Hash(secret));
437            }
438        }
439
440        Ok(output)
441    }
442
443    /// Requests an attestation from the verifier.
444    ///
445    /// # Arguments
446    ///
447    /// * `config` - The attestation request configuration.
448    #[instrument(parent = &self.span, level = "info", skip_all, err)]
449    #[deprecated(
450        note = "attestation functionality will be removed from this API in future releases."
451    )]
452    pub async fn notarize(
453        &mut self,
454        config: &RequestConfig,
455    ) -> Result<(Attestation, Secrets), ProverError> {
456        let mut builder = ProveConfig::builder(self.transcript());
457
458        if let Some(config) = config.transcript_commit() {
459            // Temporarily, we reject attestation requests which contain hash commitments to
460            // subsets of the transcript. We do this because we want to preserve the
461            // obliviousness of the reference notary, and hash commitments currently leak
462            // the ranges which are being committed.
463            for ((direction, idx), _) in config.iter_hash() {
464                let len = match direction {
465                    Direction::Sent => self.transcript().sent().len(),
466                    Direction::Received => self.transcript().received().len(),
467                };
468
469                if idx.start() > 0 || idx.end() < len || idx.count() != 1 {
470                    return Err(ProverError::attestation(
471                        "hash commitments to subsets of the transcript are currently not supported in attestation requests",
472                    ));
473                }
474            }
475
476            builder.transcript_commit(config.clone());
477        }
478
479        let disclosure_config = builder.build().map_err(ProverError::attestation)?;
480
481        let ProverOutput {
482            transcript_commitments,
483            transcript_secrets,
484            ..
485        } = self.prove(&disclosure_config).await?;
486
487        let state::Committed {
488            mux_fut,
489            ctx,
490            server_cert_data,
491            transcript,
492            ..
493        } = &mut self.state;
494
495        let mut builder = Request::builder(config);
496
497        builder
498            .server_name(self.config.server_name().clone())
499            .server_cert_data(server_cert_data.clone())
500            .transcript(transcript.clone())
501            .transcript_commitments(transcript_secrets, transcript_commitments);
502
503        let (request, secrets) = builder
504            .build(self.config.crypto_provider())
505            .map_err(ProverError::attestation)?;
506
507        let attestation = mux_fut
508            .poll_with(async {
509                debug!("sending attestation request");
510
511                ctx.io_mut().send(request.clone()).await?;
512
513                let attestation: Attestation = ctx.io_mut().expect_next().await?;
514
515                Ok::<_, ProverError>(attestation)
516            })
517            .await?;
518
519        // Check the attestation is consistent with the Prover's view.
520        request
521            .validate(&attestation)
522            .map_err(ProverError::attestation)?;
523
524        Ok((attestation, secrets))
525    }
526
527    /// Closes the connection with the verifier.
528    #[instrument(parent = &self.span, level = "info", skip_all, err)]
529    pub async fn close(self) -> Result<(), ProverError> {
530        let state::Committed {
531            mux_ctrl, mux_fut, ..
532        } = self.state;
533
534        // Wait for the verifier to correctly close the connection.
535        if !mux_fut.is_complete() {
536            mux_ctrl.close();
537            mux_fut.await?;
538        }
539
540        Ok(())
541    }
542}
543
544fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
545    let mut rng = rand::rng();
546    let delta = Delta::new(Block::random(&mut rng));
547
548    let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
549    let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
550    let rcot_send = mpz_ot::kos::Sender::new(
551        mpz_ot::kos::SenderConfig::default(),
552        delta.into_inner(),
553        base_ot_recv,
554    );
555    let rcot_recv =
556        mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
557    let rcot_recv = mpz_ot::ferret::Receiver::new(
558        mpz_ot::ferret::FerretConfig::builder()
559            .lpn_type(mpz_ot::ferret::LpnType::Regular)
560            .build()
561            .expect("ferret config is valid"),
562        Block::random(&mut rng),
563        rcot_recv,
564    );
565
566    let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
567    let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
568
569    let mpc = Mpc::new(
570        mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
571        rng.random(),
572        delta,
573    );
574
575    let zk = Zk::new(rcot_recv.clone());
576
577    let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
578
579    (
580        vm.clone(),
581        MpcTlsLeader::new(
582            config.build_mpc_tls_config(),
583            ctx,
584            vm,
585            (rcot_send.clone(), rcot_send.clone(), rcot_send),
586            rcot_recv,
587        ),
588    )
589}
590
591/// A controller for the prover.
592#[derive(Clone)]
593pub struct ProverControl {
594    mpc_ctrl: LeaderCtrl,
595}
596
597impl ProverControl {
598    /// Defers decryption of data from the server until the server has closed
599    /// the connection.
600    ///
601    /// This is a performance optimization which will significantly reduce the
602    /// amount of upload bandwidth used by the prover.
603    ///
604    /// # Notes
605    ///
606    /// * The prover may need to close the connection to the server in order for
607    ///   it to close the connection on its end. If neither the prover or server
608    ///   close the connection this will cause a deadlock.
609    pub async fn defer_decryption(&self) -> Result<(), ProverError> {
610        self.mpc_ctrl
611            .defer_decryption()
612            .await
613            .map_err(ProverError::from)
614    }
615}
616
617/// Translates VM references to the ZK address space.
618fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
619    keys.client_write_key = vm
620        .translate(keys.client_write_key)
621        .map_err(ProverError::mpc)?;
622    keys.client_write_iv = vm
623        .translate(keys.client_write_iv)
624        .map_err(ProverError::mpc)?;
625    keys.server_write_key = vm
626        .translate(keys.server_write_key)
627        .map_err(ProverError::mpc)?;
628    keys.server_write_iv = vm
629        .translate(keys.server_write_iv)
630        .map_err(ProverError::mpc)?;
631
632    Ok(())
633}
634
635/// Translates VM references to the ZK address space.
636fn translate_transcript<Mpc, Zk>(
637    transcript: &mut TlsTranscript,
638    vm: &Deap<Mpc, Zk>,
639) -> Result<(), ProverError> {
640    for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
641    {
642        if let Some(plaintext_ref) = plaintext_ref.as_mut() {
643            *plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?;
644        }
645    }
646
647    Ok(())
648}