tlsn_prover/
lib.rs

1//! TLSNotary prover library.
2
3#![deny(missing_docs, unreachable_pub, unused_must_use)]
4#![deny(clippy::all)]
5#![forbid(unsafe_code)]
6
7mod config;
8mod error;
9mod future;
10pub mod state;
11
12pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
13pub use error::ProverError;
14pub use future::ProverFuture;
15pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
16
17use mpz_common::Context;
18use mpz_core::Block;
19use mpz_garble_core::Delta;
20use mpz_vm_core::prelude::*;
21
22use futures::{AsyncRead, AsyncWrite, TryFutureExt};
23use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
24use rand::Rng;
25use serio::{stream::IoStreamExt, SinkExt};
26use std::sync::Arc;
27use tls_client::{ClientConnection, ServerName as TlsServerName};
28use tls_client_async::{bind_client, TlsConnection};
29use tls_core::msgs::enums::ContentType;
30use tlsn_common::{
31    commit::{commit_records, hash::prove_hash},
32    context::build_mt_context,
33    encoding,
34    mux::attach_mux,
35    tag::verify_tags,
36    transcript::{decode_transcript, Record, TlsTranscript},
37    zk_aes_ctr::ZkAesCtr,
38    Role,
39};
40use tlsn_core::{
41    attestation::Attestation,
42    connection::{
43        ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature,
44        TranscriptLength,
45    },
46    request::{Request, RequestConfig},
47    transcript::{Direction, Transcript, TranscriptCommitment, TranscriptSecret},
48    ProvePayload, Secrets,
49};
50use tlsn_deap::Deap;
51use tokio::sync::Mutex;
52
53use tracing::{debug, info, info_span, instrument, Instrument, Span};
54
55pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
56    mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
57    mpz_core::Block,
58>;
59pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
60    mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
61    bool,
62    mpz_core::Block,
63>;
64pub(crate) type Mpc =
65    mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
66pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
67
68/// A prover instance.
69#[derive(Debug)]
70pub struct Prover<T: state::ProverState = state::Initialized> {
71    config: ProverConfig,
72    span: Span,
73    state: T,
74}
75
76impl Prover<state::Initialized> {
77    /// Creates a new prover.
78    ///
79    /// # Arguments
80    ///
81    /// * `config` - The configuration for the prover.
82    pub fn new(config: ProverConfig) -> Self {
83        let span = info_span!("prover");
84        Self {
85            config,
86            span,
87            state: state::Initialized,
88        }
89    }
90
91    /// Sets up the prover.
92    ///
93    /// This performs all MPC setup prior to establishing the connection to the
94    /// application server.
95    ///
96    /// # Arguments
97    ///
98    /// * `socket` - The socket to the TLS verifier.
99    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
100    pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
101        self,
102        socket: S,
103    ) -> Result<Prover<state::Setup>, ProverError> {
104        let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
105        let mut mt = build_mt_context(mux_ctrl.clone());
106        let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
107
108        // Sends protocol configuration to verifier for compatibility check.
109        mux_fut
110            .poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
111            .await?;
112
113        let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
114
115        // Allocate resources for MPC-TLS in the VM.
116        let mut keys = mpc_tls.alloc()?;
117        translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
118
119        // Allocate for committing to plaintext.
120        let mut zk_aes_ctr = ZkAesCtr::new(Role::Prover);
121        zk_aes_ctr.set_key(keys.server_write_key, keys.server_write_iv);
122        zk_aes_ctr.alloc(
123            &mut (*vm.try_lock().expect("VM is not locked").zk()),
124            self.config.protocol_config().max_recv_data(),
125        )?;
126
127        debug!("setting up mpc-tls");
128
129        mux_fut.poll_with(mpc_tls.preprocess()).await?;
130
131        debug!("mpc-tls setup complete");
132
133        Ok(Prover {
134            config: self.config,
135            span: self.span,
136            state: state::Setup {
137                mux_ctrl,
138                mux_fut,
139                mpc_tls,
140                zk_aes_ctr,
141                keys,
142                vm,
143            },
144        })
145    }
146}
147
148impl Prover<state::Setup> {
149    /// Connects to the server using the provided socket.
150    ///
151    /// Returns a handle to the TLS connection, a future which returns the
152    /// prover once the connection is closed.
153    ///
154    /// # Arguments
155    ///
156    /// * `socket` - The socket to the server.
157    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
158    pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
159        self,
160        socket: S,
161    ) -> Result<(TlsConnection, ProverFuture), ProverError> {
162        let state::Setup {
163            mux_ctrl,
164            mut mux_fut,
165            mpc_tls,
166            mut zk_aes_ctr,
167            keys,
168            vm,
169            ..
170        } = self.state;
171
172        let (mpc_ctrl, mpc_fut) = mpc_tls.run();
173
174        let server_name =
175            TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
176                ProverError::config(format!(
177                    "invalid server name: {}",
178                    self.config.server_name()
179                ))
180            })?;
181
182        let config = tls_client::ClientConfig::builder()
183            .with_safe_defaults()
184            .with_root_certificates(self.config.crypto_provider().cert.root_store().clone());
185
186        let config = if let Some((cert, key)) = self.config.tls_config().client_auth() {
187            config
188                .with_single_cert(cert.clone(), key.clone())
189                .map_err(ProverError::config)?
190        } else {
191            config.with_no_client_auth()
192        };
193
194        let client =
195            ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
196                .map_err(ProverError::config)?;
197
198        let (conn, conn_fut) = bind_client(socket, client);
199
200        let start_time = web_time::UNIX_EPOCH
201            .elapsed()
202            .expect("system time is available")
203            .as_secs();
204
205        let fut = Box::pin({
206            let span = self.span.clone();
207            let mpc_ctrl = mpc_ctrl.clone();
208            async move {
209                let conn_fut = async {
210                    mux_fut
211                        .poll_with(conn_fut.map_err(ProverError::from))
212                        .await?;
213
214                    mpc_ctrl.stop().await?;
215
216                    Ok::<_, ProverError>(())
217                };
218
219                info!("starting MPC-TLS");
220
221                let (_, (mut ctx, mut data, ..)) = futures::try_join!(
222                    conn_fut,
223                    mpc_fut.in_current_span().map_err(ProverError::from)
224                )?;
225
226                info!("finished MPC-TLS");
227
228                {
229                    let mut vm = vm.try_lock().expect("VM should not be locked");
230
231                    translate_transcript(&mut data.transcript, &vm)?;
232
233                    debug!("finalizing mpc");
234
235                    // Finalize DEAP.
236                    mux_fut
237                        .poll_with(vm.finalize(&mut ctx))
238                        .await
239                        .map_err(ProverError::mpc)?;
240
241                    debug!("mpc finalized");
242                }
243
244                // Pull out ZK VM.
245                let (_, mut vm) = Arc::into_inner(vm)
246                    .expect("vm should have only 1 reference")
247                    .into_inner()
248                    .into_inner();
249
250                // Prove tag verification of received records.
251                // The prover drops the proof output.
252                let _ = verify_tags(
253                    &mut vm,
254                    (data.keys.server_write_key, data.keys.server_write_iv),
255                    data.keys.server_write_mac_key,
256                    data.transcript.recv.clone(),
257                )
258                .map_err(ProverError::zk)?;
259
260                // Prove received plaintext. Prover drops the proof output, as
261                // they trust themselves.
262                _ = commit_records(
263                    &mut vm,
264                    &mut zk_aes_ctr,
265                    data.transcript
266                        .recv
267                        .iter_mut()
268                        .filter(|record| record.typ == ContentType::ApplicationData),
269                )
270                .map_err(ProverError::zk)?;
271
272                mux_fut
273                    .poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
274                    .await?;
275
276                let transcript = data
277                    .transcript
278                    .to_transcript()
279                    .expect("transcript is complete");
280                let transcript_refs = data
281                    .transcript
282                    .to_transcript_refs()
283                    .expect("transcript is complete");
284
285                let connection_info = ConnectionInfo {
286                    time: start_time,
287                    version: data
288                        .protocol_version
289                        .try_into()
290                        .expect("only supported version should have been accepted"),
291                    transcript_length: TranscriptLength {
292                        sent: transcript.sent().len() as u32,
293                        received: transcript.received().len() as u32,
294                    },
295                };
296
297                let server_cert_data =
298                    ServerCertData {
299                        certs: data
300                            .server_cert_details
301                            .cert_chain()
302                            .iter()
303                            .cloned()
304                            .map(|c| c.into())
305                            .collect(),
306                        sig: ServerSignature {
307                            scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
308                                "only supported signature scheme should have been accepted",
309                            ),
310                            sig: data.server_kx_details.kx_sig().sig.0.clone(),
311                        },
312                        handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
313                            client_random: data.client_random.0,
314                            server_random: data.server_random.0,
315                            server_ephemeral_key: data
316                                .server_key
317                                .try_into()
318                                .expect("only supported key scheme should have been accepted"),
319                        }),
320                    };
321
322                Ok(Prover {
323                    config: self.config,
324                    span: self.span,
325                    state: state::Committed {
326                        mux_ctrl,
327                        mux_fut,
328                        ctx,
329                        _keys: keys,
330                        vm,
331                        connection_info,
332                        server_cert_data,
333                        transcript,
334                        transcript_refs,
335                    },
336                })
337            }
338            .instrument(span)
339        });
340
341        Ok((
342            conn,
343            ProverFuture {
344                fut,
345                ctrl: ProverControl { mpc_ctrl },
346            },
347        ))
348    }
349}
350
351impl Prover<state::Committed> {
352    /// Returns the connection information.
353    pub fn connection_info(&self) -> &ConnectionInfo {
354        &self.state.connection_info
355    }
356
357    /// Returns the transcript.
358    pub fn transcript(&self) -> &Transcript {
359        &self.state.transcript
360    }
361
362    /// Proves information to the verifier.
363    ///
364    /// # Arguments
365    ///
366    /// * `config` - The disclosure configuration.
367    #[instrument(parent = &self.span, level = "info", skip_all, err)]
368    pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
369        let state::Committed {
370            mux_fut,
371            ctx,
372            vm,
373            server_cert_data,
374            transcript_refs,
375            ..
376        } = &mut self.state;
377
378        let mut output = ProverOutput {
379            transcript_commitments: Vec::new(),
380            transcript_secrets: Vec::new(),
381        };
382
383        let payload = ProvePayload {
384            server_identity: config
385                .server_identity()
386                .then(|| (self.config.server_name().clone(), server_cert_data.clone())),
387            transcript: config.transcript().cloned(),
388            transcript_commit: config.transcript_commit().map(|config| config.to_request()),
389        };
390
391        // Send payload.
392        mux_fut
393            .poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
394            .await?;
395
396        if let Some(partial_transcript) = config.transcript() {
397            decode_transcript(
398                vm,
399                partial_transcript.sent_authed(),
400                partial_transcript.received_authed(),
401                transcript_refs,
402            )
403            .map_err(ProverError::zk)?;
404        }
405
406        let mut hash_commitments = None;
407        if let Some(commit_config) = config.transcript_commit() {
408            if commit_config.has_encoding() {
409                let hasher = self
410                    .config
411                    .crypto_provider()
412                    .hash
413                    .get(commit_config.encoding_hash_alg())
414                    .map_err(ProverError::config)?;
415
416                let (commitment, tree) = mux_fut
417                    .poll_with(
418                        encoding::receive(
419                            ctx,
420                            hasher,
421                            transcript_refs,
422                            |plaintext| vm.get_macs(plaintext).expect("reference is valid"),
423                            commit_config.iter_encoding(),
424                        )
425                        .map_err(ProverError::commit),
426                    )
427                    .await?;
428
429                output
430                    .transcript_commitments
431                    .push(TranscriptCommitment::Encoding(commitment));
432                output
433                    .transcript_secrets
434                    .push(TranscriptSecret::Encoding(tree));
435            }
436
437            if commit_config.has_hash() {
438                hash_commitments = Some(
439                    prove_hash(
440                        vm,
441                        transcript_refs,
442                        commit_config
443                            .iter_hash()
444                            .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
445                    )
446                    .map_err(ProverError::commit)?,
447                );
448            }
449        }
450
451        mux_fut
452            .poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
453            .await?;
454
455        if let Some((hash_fut, hash_secrets)) = hash_commitments {
456            let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
457            for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
458                output
459                    .transcript_commitments
460                    .push(TranscriptCommitment::Hash(commitment));
461                output
462                    .transcript_secrets
463                    .push(TranscriptSecret::Hash(secret));
464            }
465        }
466
467        Ok(output)
468    }
469
470    /// Requests an attestation from the verifier.
471    ///
472    /// # Arguments
473    ///
474    /// * `config` - The attestation request configuration.
475    #[instrument(parent = &self.span, level = "info", skip_all, err)]
476    #[deprecated(
477        note = "attestation functionality will be removed from this API in future releases."
478    )]
479    pub async fn notarize(
480        &mut self,
481        config: &RequestConfig,
482    ) -> Result<(Attestation, Secrets), ProverError> {
483        let mut builder = ProveConfig::builder(self.transcript());
484
485        if let Some(config) = config.transcript_commit() {
486            // Temporarily, we reject attestation requests which contain hash commitments to
487            // subsets of the transcript. We do this because we want to preserve the
488            // obliviousness of the reference notary, and hash commitments currently leak
489            // the ranges which are being committed.
490            for ((direction, idx), _) in config.iter_hash() {
491                let len = match direction {
492                    Direction::Sent => self.transcript().sent().len(),
493                    Direction::Received => self.transcript().received().len(),
494                };
495
496                if idx.start() > 0 || idx.end() < len || idx.count() != 1 {
497                    return Err(ProverError::attestation(
498                        "hash commitments to subsets of the transcript are currently not supported in attestation requests",
499                    ));
500                }
501            }
502
503            builder.transcript_commit(config.clone());
504        }
505
506        let disclosure_config = builder.build().map_err(ProverError::attestation)?;
507
508        let ProverOutput {
509            transcript_commitments,
510            transcript_secrets,
511            ..
512        } = self.prove(&disclosure_config).await?;
513
514        let state::Committed {
515            mux_fut,
516            ctx,
517            server_cert_data,
518            transcript,
519            ..
520        } = &mut self.state;
521
522        let mut builder = Request::builder(config);
523
524        builder
525            .server_name(self.config.server_name().clone())
526            .server_cert_data(server_cert_data.clone())
527            .transcript(transcript.clone())
528            .transcript_commitments(transcript_secrets, transcript_commitments);
529
530        let (request, secrets) = builder
531            .build(self.config.crypto_provider())
532            .map_err(ProverError::attestation)?;
533
534        let attestation = mux_fut
535            .poll_with(async {
536                debug!("sending attestation request");
537
538                ctx.io_mut().send(request.clone()).await?;
539
540                let attestation: Attestation = ctx.io_mut().expect_next().await?;
541
542                Ok::<_, ProverError>(attestation)
543            })
544            .await?;
545
546        // Check the attestation is consistent with the Prover's view.
547        request
548            .validate(&attestation)
549            .map_err(ProverError::attestation)?;
550
551        Ok((attestation, secrets))
552    }
553
554    /// Closes the connection with the verifier.
555    #[instrument(parent = &self.span, level = "info", skip_all, err)]
556    pub async fn close(self) -> Result<(), ProverError> {
557        let state::Committed {
558            mux_ctrl, mux_fut, ..
559        } = self.state;
560
561        // Wait for the verifier to correctly close the connection.
562        if !mux_fut.is_complete() {
563            mux_ctrl.close();
564            mux_fut.await?;
565        }
566
567        Ok(())
568    }
569}
570
571fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
572    let mut rng = rand::rng();
573    let delta = Delta::new(Block::random(&mut rng));
574
575    let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
576    let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
577    let rcot_send = mpz_ot::kos::Sender::new(
578        mpz_ot::kos::SenderConfig::default(),
579        delta.into_inner(),
580        base_ot_recv,
581    );
582    let rcot_recv =
583        mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
584    let rcot_recv = mpz_ot::ferret::Receiver::new(
585        mpz_ot::ferret::FerretConfig::builder()
586            .lpn_type(mpz_ot::ferret::LpnType::Regular)
587            .build()
588            .expect("ferret config is valid"),
589        Block::random(&mut rng),
590        rcot_recv,
591    );
592
593    let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
594    let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
595
596    let mpc = Mpc::new(
597        mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
598        rng.random(),
599        delta,
600    );
601
602    let zk = Zk::new(rcot_recv.clone());
603
604    let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
605
606    (
607        vm.clone(),
608        MpcTlsLeader::new(
609            config.build_mpc_tls_config(),
610            ctx,
611            vm,
612            (rcot_send.clone(), rcot_send.clone(), rcot_send),
613            rcot_recv,
614        ),
615    )
616}
617
618/// A controller for the prover.
619#[derive(Clone)]
620pub struct ProverControl {
621    mpc_ctrl: LeaderCtrl,
622}
623
624impl ProverControl {
625    /// Defers decryption of data from the server until the server has closed
626    /// the connection.
627    ///
628    /// This is a performance optimization which will significantly reduce the
629    /// amount of upload bandwidth used by the prover.
630    ///
631    /// # Notes
632    ///
633    /// * The prover may need to close the connection to the server in order for
634    ///   it to close the connection on its end. If neither the prover or server
635    ///   close the connection this will cause a deadlock.
636    pub async fn defer_decryption(&self) -> Result<(), ProverError> {
637        self.mpc_ctrl
638            .defer_decryption()
639            .await
640            .map_err(ProverError::from)
641    }
642}
643
644/// Translates VM references to the ZK address space.
645fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
646    keys.client_write_key = vm
647        .translate(keys.client_write_key)
648        .map_err(ProverError::mpc)?;
649    keys.client_write_iv = vm
650        .translate(keys.client_write_iv)
651        .map_err(ProverError::mpc)?;
652    keys.server_write_key = vm
653        .translate(keys.server_write_key)
654        .map_err(ProverError::mpc)?;
655    keys.server_write_iv = vm
656        .translate(keys.server_write_iv)
657        .map_err(ProverError::mpc)?;
658    keys.server_write_mac_key = vm
659        .translate(keys.server_write_mac_key)
660        .map_err(ProverError::mpc)?;
661
662    Ok(())
663}
664
665/// Translates VM references to the ZK address space.
666fn translate_transcript<Mpc, Zk>(
667    transcript: &mut TlsTranscript,
668    vm: &Deap<Mpc, Zk>,
669) -> Result<(), ProverError> {
670    for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
671    {
672        if let Some(plaintext_ref) = plaintext_ref.as_mut() {
673            *plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?;
674        }
675    }
676
677    Ok(())
678}