1#![deny(missing_docs, unreachable_pub, unused_must_use)]
4#![deny(clippy::all)]
5#![forbid(unsafe_code)]
6
7mod config;
8mod error;
9mod future;
10pub mod state;
11
12pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
13pub use error::ProverError;
14pub use future::ProverFuture;
15pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
16
17use mpz_common::Context;
18use mpz_core::Block;
19use mpz_garble_core::Delta;
20use mpz_vm_core::prelude::*;
21
22use futures::{AsyncRead, AsyncWrite, TryFutureExt};
23use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
24use rand::Rng;
25use serio::{stream::IoStreamExt, SinkExt};
26use std::sync::Arc;
27use tls_client::{ClientConnection, ServerName as TlsServerName};
28use tls_client_async::{bind_client, TlsConnection};
29use tls_core::msgs::enums::ContentType;
30use tlsn_common::{
31 commit::{commit_records, hash::prove_hash},
32 context::build_mt_context,
33 encoding,
34 mux::attach_mux,
35 tag::verify_tags,
36 transcript::{decode_transcript, Record, TlsTranscript},
37 zk_aes_ctr::ZkAesCtr,
38 Role,
39};
40use tlsn_core::{
41 attestation::Attestation,
42 connection::{
43 ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature,
44 TranscriptLength,
45 },
46 request::{Request, RequestConfig},
47 transcript::{Direction, Transcript, TranscriptCommitment, TranscriptSecret},
48 ProvePayload, Secrets,
49};
50use tlsn_deap::Deap;
51use tokio::sync::Mutex;
52
53use tracing::{debug, info, info_span, instrument, Instrument, Span};
54
55pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
56 mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
57 mpz_core::Block,
58>;
59pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
60 mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
61 bool,
62 mpz_core::Block,
63>;
64pub(crate) type Mpc =
65 mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
66pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
67
/// Host for the prover side of a TLSNotary session.
///
/// The type parameter `T` is the prover's protocol state. A prover starts in
/// [`state::Initialized`] (see `Prover::new`), moves to [`state::Setup`] via
/// `setup`, and reaches [`state::Committed`] once the future returned by
/// `connect` resolves.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
    // Prover configuration, carried unchanged through every state transition.
    config: ProverConfig,
    // The "prover" tracing span; instrumented methods use it as their parent.
    span: Span,
    // State-specific data (mux, VM, transcript, ...) for the current state.
    state: T,
}
75
76impl Prover<state::Initialized> {
77 pub fn new(config: ProverConfig) -> Self {
83 let span = info_span!("prover");
84 Self {
85 config,
86 span,
87 state: state::Initialized,
88 }
89 }
90
91 #[instrument(parent = &self.span, level = "debug", skip_all, err)]
100 pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
101 self,
102 socket: S,
103 ) -> Result<Prover<state::Setup>, ProverError> {
104 let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
105 let mut mt = build_mt_context(mux_ctrl.clone());
106 let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
107
108 mux_fut
110 .poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
111 .await?;
112
113 let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
114
115 let mut keys = mpc_tls.alloc()?;
117 translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
118
119 let mut zk_aes_ctr = ZkAesCtr::new(Role::Prover);
121 zk_aes_ctr.set_key(keys.server_write_key, keys.server_write_iv);
122 zk_aes_ctr.alloc(
123 &mut (*vm.try_lock().expect("VM is not locked").zk()),
124 self.config.protocol_config().max_recv_data(),
125 )?;
126
127 debug!("setting up mpc-tls");
128
129 mux_fut.poll_with(mpc_tls.preprocess()).await?;
130
131 debug!("mpc-tls setup complete");
132
133 Ok(Prover {
134 config: self.config,
135 span: self.span,
136 state: state::Setup {
137 mux_ctrl,
138 mux_fut,
139 mpc_tls,
140 zk_aes_ctr,
141 keys,
142 vm,
143 },
144 })
145 }
146}
147
impl Prover<state::Setup> {
    /// Connects to the server using the provided socket.
    ///
    /// Starts the MPC-TLS leader and binds a TLS client to `socket`. Returns the
    /// [`TlsConnection`] the application uses to talk to the server, together
    /// with a [`ProverFuture`]. The TLS connection is only driven while that
    /// future is polled. Once the TLS connection future and MPC-TLS both
    /// complete, the future finalizes the VM, proves the received records in
    /// ZK, and resolves to a prover in the [`state::Committed`] state.
    ///
    /// # Arguments
    ///
    /// * `socket` - The socket to the server.
    ///
    /// # Errors
    ///
    /// Returns an error if the configured server name is invalid or the TLS
    /// client cannot be configured. Errors that occur while the connection runs
    /// are reported by the returned future instead.
    ///
    /// # Panics
    ///
    /// Panics if the current system time cannot be measured relative to the
    /// UNIX epoch.
    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
    pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
        self,
        socket: S,
    ) -> Result<(TlsConnection, ProverFuture), ProverError> {
        let state::Setup {
            mux_ctrl,
            mut mux_fut,
            mpc_tls,
            mut zk_aes_ctr,
            keys,
            vm,
            ..
        } = self.state;

        // `mpc_ctrl` is the leader's control handle (also the TLS client's
        // crypto backend below); `mpc_fut` runs the leader and is joined later.
        let (mpc_ctrl, mpc_fut) = mpc_tls.run();

        let server_name =
            TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
                ProverError::config(format!(
                    "invalid server name: {}",
                    self.config.server_name()
                ))
            })?;

        let config = tls_client::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(self.config.crypto_provider().cert.root_store().clone());

        // Present a client certificate only if one is configured.
        let config = if let Some((cert, key)) = self.config.tls_config().client_auth() {
            config
                .with_single_cert(cert.clone(), key.clone())
                .map_err(ProverError::config)?
        } else {
            config.with_no_client_auth()
        };

        let client =
            ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
                .map_err(ProverError::config)?;

        let (conn, conn_fut) = bind_client(socket, client);

        // Connection start time in whole seconds since the UNIX epoch; recorded
        // in `ConnectionInfo::time`.
        let start_time = web_time::UNIX_EPOCH
            .elapsed()
            .expect("system time is available")
            .as_secs();

        let fut = Box::pin({
            let span = self.span.clone();
            let mpc_ctrl = mpc_ctrl.clone();
            async move {
                // Drive the TLS connection (through the mux) to completion,
                // then tell the MPC-TLS leader to stop.
                let conn_fut = async {
                    mux_fut
                        .poll_with(conn_fut.map_err(ProverError::from))
                        .await?;

                    mpc_ctrl.stop().await?;

                    Ok::<_, ProverError>(())
                };

                info!("starting MPC-TLS");

                // Run the connection and the MPC-TLS leader concurrently; the
                // leader's output supplies the context and session data.
                let (_, (mut ctx, mut data, ..)) = futures::try_join!(
                    conn_fut,
                    mpc_fut.in_current_span().map_err(ProverError::from)
                )?;

                info!("finished MPC-TLS");

                // Scoped so the VM lock guard is dropped before the VM is
                // unwrapped from its `Arc<Mutex<_>>` below.
                {
                    let mut vm = vm.try_lock().expect("VM should not be locked");

                    translate_transcript(&mut data.transcript, &vm)?;

                    debug!("finalizing mpc");

                    mux_fut
                        .poll_with(vm.finalize(&mut ctx))
                        .await
                        .map_err(ProverError::mpc)?;

                    debug!("mpc finalized");
                }

                // Take sole ownership of the VM. This requires every other
                // `Arc` clone (e.g. the one handed to the MPC-TLS leader) to
                // have been dropped by now. Only the second component of the
                // `Deap<Mpc, Zk>` is kept — presumably the ZK VM; confirm.
                let (_, mut vm) = Arc::into_inner(vm)
                    .expect("vm should have only 1 reference")
                    .into_inner()
                    .into_inner();

                // Prove the received records' tags in ZK. The returned value
                // is deliberately discarded on the prover side.
                let _ = verify_tags(
                    &mut vm,
                    (data.keys.server_write_key, data.keys.server_write_iv),
                    data.keys.server_write_mac_key,
                    data.transcript.recv.clone(),
                )
                .map_err(ProverError::zk)?;

                // Commit the received application-data records via ZK
                // AES-CTR; the prover also discards this output.
                _ = commit_records(
                    &mut vm,
                    &mut zk_aes_ctr,
                    data.transcript
                        .recv
                        .iter_mut()
                        .filter(|record| record.typ == ContentType::ApplicationData),
                )
                .map_err(ProverError::zk)?;

                // Execute everything queued in the VM above.
                mux_fut
                    .poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
                    .await?;

                let transcript = data
                    .transcript
                    .to_transcript()
                    .expect("transcript is complete");
                let transcript_refs = data
                    .transcript
                    .to_transcript_refs()
                    .expect("transcript is complete");

                // NOTE(review): the `as u32` casts truncate silently; this
                // assumes transcript lengths always fit in u32 (presumably
                // bounded by the configured data limits — confirm).
                let connection_info = ConnectionInfo {
                    time: start_time,
                    version: data
                        .protocol_version
                        .try_into()
                        .expect("only supported version should have been accepted"),
                    transcript_length: TranscriptLength {
                        sent: transcript.sent().len() as u32,
                        received: transcript.received().len() as u32,
                    },
                };

                // Server identity material: certificate chain, key-exchange
                // signature, and TLS 1.2 handshake parameters.
                let server_cert_data =
                    ServerCertData {
                        certs: data
                            .server_cert_details
                            .cert_chain()
                            .iter()
                            .cloned()
                            .map(|c| c.into())
                            .collect(),
                        sig: ServerSignature {
                            scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
                                "only supported signature scheme should have been accepted",
                            ),
                            sig: data.server_kx_details.kx_sig().sig.0.clone(),
                        },
                        handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
                            client_random: data.client_random.0,
                            server_random: data.server_random.0,
                            server_ephemeral_key: data
                                .server_key
                                .try_into()
                                .expect("only supported key scheme should have been accepted"),
                        }),
                    };

                Ok(Prover {
                    config: self.config,
                    span: self.span,
                    state: state::Committed {
                        mux_ctrl,
                        mux_fut,
                        ctx,
                        _keys: keys,
                        vm,
                        connection_info,
                        server_cert_data,
                        transcript,
                        transcript_refs,
                    },
                })
            }
            .instrument(span)
        });

        Ok((
            conn,
            ProverFuture {
                fut,
                ctrl: ProverControl { mpc_ctrl },
            },
        ))
    }
}
350
impl Prover<state::Committed> {
    /// Returns the connection information.
    pub fn connection_info(&self) -> &ConnectionInfo {
        &self.state.connection_info
    }

    /// Returns the transcript of the TLS connection.
    pub fn transcript(&self) -> &Transcript {
        &self.state.transcript
    }

    /// Proves information about the connection to the verifier.
    ///
    /// Sends a [`ProvePayload`] describing what will be proven. It then:
    ///
    /// * decodes the authenticated ranges of the requested partial transcript,
    /// * if requested, runs the encoding-commitment protocol, and
    /// * if requested, queues hash commitments.
    ///
    /// Finally it executes the VM and collects the resulting commitments and
    /// secrets.
    ///
    /// # Arguments
    ///
    /// * `config` - The proving configuration.
    ///
    /// # Errors
    ///
    /// Returns a [`ProverError`] if I/O with the verifier, decoding, commitment
    /// generation, or VM execution fails, or if the configured encoding hash
    /// algorithm is not available from the crypto provider.
    #[instrument(parent = &self.span, level = "info", skip_all, err)]
    pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
        let state::Committed {
            mux_fut,
            ctx,
            vm,
            server_cert_data,
            transcript_refs,
            ..
        } = &mut self.state;

        let mut output = ProverOutput {
            transcript_commitments: Vec::new(),
            transcript_secrets: Vec::new(),
        };

        // Server identity is only included when the config asks for it.
        let payload = ProvePayload {
            server_identity: config
                .server_identity()
                .then(|| (self.config.server_name().clone(), server_cert_data.clone())),
            transcript: config.transcript().cloned(),
            transcript_commit: config.transcript_commit().map(|config| config.to_request()),
        };

        // Tell the verifier what is about to be proven.
        mux_fut
            .poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
            .await?;

        // Queue decoding of the authenticated sent/received ranges.
        if let Some(partial_transcript) = config.transcript() {
            decode_transcript(
                vm,
                partial_transcript.sent_authed(),
                partial_transcript.received_authed(),
                transcript_refs,
            )
            .map_err(ProverError::zk)?;
        }

        // Filled only if hash commitments are requested; their values can only
        // be received after the VM has executed (see below).
        let mut hash_commitments = None;
        if let Some(commit_config) = config.transcript_commit() {
            if commit_config.has_encoding() {
                let hasher = self
                    .config
                    .crypto_provider()
                    .hash
                    .get(commit_config.encoding_hash_alg())
                    .map_err(ProverError::config)?;

                // Encoding protocol with the verifier, using the VM's MACs for
                // the referenced plaintext; yields a commitment and its secret.
                let (commitment, tree) = mux_fut
                    .poll_with(
                        encoding::receive(
                            ctx,
                            hasher,
                            transcript_refs,
                            |plaintext| vm.get_macs(plaintext).expect("reference is valid"),
                            commit_config.iter_encoding(),
                        )
                        .map_err(ProverError::commit),
                    )
                    .await?;

                output
                    .transcript_commitments
                    .push(TranscriptCommitment::Encoding(commitment));
                output
                    .transcript_secrets
                    .push(TranscriptSecret::Encoding(tree));
            }

            if commit_config.has_hash() {
                hash_commitments = Some(
                    prove_hash(
                        vm,
                        transcript_refs,
                        commit_config
                            .iter_hash()
                            .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
                    )
                    .map_err(ProverError::commit)?,
                );
            }
        }

        // Execute everything queued above (decoding and hash proofs).
        mux_fut
            .poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
            .await?;

        // Hash commitments are paired with their secrets in order.
        if let Some((hash_fut, hash_secrets)) = hash_commitments {
            let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
            for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
                output
                    .transcript_commitments
                    .push(TranscriptCommitment::Hash(commitment));
                output
                    .transcript_secrets
                    .push(TranscriptSecret::Hash(secret));
            }
        }

        Ok(output)
    }

    /// Requests an attestation from the verifier.
    ///
    /// Runs [`Prover::prove`] with the transcript commitments from `config`,
    /// sends the resulting attestation [`Request`], receives an
    /// [`Attestation`], and validates it against the request.
    ///
    /// # Arguments
    ///
    /// * `config` - The attestation request configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if `config` contains a hash commitment that does not
    /// cover an entire direction of the transcript, if proving fails, if the
    /// request cannot be built or exchanged, or if the received attestation
    /// does not validate against the request.
    #[instrument(parent = &self.span, level = "info", skip_all, err)]
    #[deprecated(
        note = "attestation functionality will be removed from this API in future releases."
    )]
    pub async fn notarize(
        &mut self,
        config: &RequestConfig,
    ) -> Result<(Attestation, Secrets), ProverError> {
        let mut builder = ProveConfig::builder(self.transcript());

        if let Some(config) = config.transcript_commit() {
            // Attestation requests only accept hash commitments that cover a
            // whole direction of the transcript as one contiguous range.
            for ((direction, idx), _) in config.iter_hash() {
                let len = match direction {
                    Direction::Sent => self.transcript().sent().len(),
                    Direction::Received => self.transcript().received().len(),
                };

                if idx.start() > 0 || idx.end() < len || idx.count() != 1 {
                    return Err(ProverError::attestation(
                        "hash commitments to subsets of the transcript are currently not supported in attestation requests",
                    ));
                }
            }

            builder.transcript_commit(config.clone());
        }

        let disclosure_config = builder.build().map_err(ProverError::attestation)?;

        let ProverOutput {
            transcript_commitments,
            transcript_secrets,
            ..
        } = self.prove(&disclosure_config).await?;

        // Re-borrow the state after `prove` has released its borrow of `self`.
        let state::Committed {
            mux_fut,
            ctx,
            server_cert_data,
            transcript,
            ..
        } = &mut self.state;

        // `config` here is the `RequestConfig` argument again; the inner
        // commit-config binding above has gone out of scope.
        let mut builder = Request::builder(config);

        builder
            .server_name(self.config.server_name().clone())
            .server_cert_data(server_cert_data.clone())
            .transcript(transcript.clone())
            .transcript_commitments(transcript_secrets, transcript_commitments);

        let (request, secrets) = builder
            .build(self.config.crypto_provider())
            .map_err(ProverError::attestation)?;

        // A clone is sent so the original request is still available to
        // validate the response.
        let attestation = mux_fut
            .poll_with(async {
                debug!("sending attestation request");

                ctx.io_mut().send(request.clone()).await?;

                let attestation: Attestation = ctx.io_mut().expect_next().await?;

                Ok::<_, ProverError>(attestation)
            })
            .await?;

        request
            .validate(&attestation)
            .map_err(ProverError::attestation)?;

        Ok((attestation, secrets))
    }

    /// Closes the connection with the verifier.
    ///
    /// If the multiplexer has not completed yet, it is asked to close and is
    /// then awaited until it finishes.
    ///
    /// # Errors
    ///
    /// Returns an error if the multiplexer future fails while shutting down.
    #[instrument(parent = &self.span, level = "info", skip_all, err)]
    pub async fn close(self) -> Result<(), ProverError> {
        let state::Committed {
            mux_ctrl, mux_fut, ..
        } = self.state;

        // Skip shutdown if the mux future already completed.
        if !mux_fut.is_complete() {
            mux_ctrl.close();
            mux_fut.await?;
        }

        Ok(())
    }
}
570
571fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
572 let mut rng = rand::rng();
573 let delta = Delta::new(Block::random(&mut rng));
574
575 let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
576 let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
577 let rcot_send = mpz_ot::kos::Sender::new(
578 mpz_ot::kos::SenderConfig::default(),
579 delta.into_inner(),
580 base_ot_recv,
581 );
582 let rcot_recv =
583 mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
584 let rcot_recv = mpz_ot::ferret::Receiver::new(
585 mpz_ot::ferret::FerretConfig::builder()
586 .lpn_type(mpz_ot::ferret::LpnType::Regular)
587 .build()
588 .expect("ferret config is valid"),
589 Block::random(&mut rng),
590 rcot_recv,
591 );
592
593 let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
594 let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
595
596 let mpc = Mpc::new(
597 mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
598 rng.random(),
599 delta,
600 );
601
602 let zk = Zk::new(rcot_recv.clone());
603
604 let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
605
606 (
607 vm.clone(),
608 MpcTlsLeader::new(
609 config.build_mpc_tls_config(),
610 ctx,
611 vm,
612 (rcot_send.clone(), rcot_send.clone(), rcot_send),
613 rcot_recv,
614 ),
615 )
616}
617
618#[derive(Clone)]
620pub struct ProverControl {
621 mpc_ctrl: LeaderCtrl,
622}
623
624impl ProverControl {
625 pub async fn defer_decryption(&self) -> Result<(), ProverError> {
637 self.mpc_ctrl
638 .defer_decryption()
639 .await
640 .map_err(ProverError::from)
641 }
642}
643
644fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
646 keys.client_write_key = vm
647 .translate(keys.client_write_key)
648 .map_err(ProverError::mpc)?;
649 keys.client_write_iv = vm
650 .translate(keys.client_write_iv)
651 .map_err(ProverError::mpc)?;
652 keys.server_write_key = vm
653 .translate(keys.server_write_key)
654 .map_err(ProverError::mpc)?;
655 keys.server_write_iv = vm
656 .translate(keys.server_write_iv)
657 .map_err(ProverError::mpc)?;
658 keys.server_write_mac_key = vm
659 .translate(keys.server_write_mac_key)
660 .map_err(ProverError::mpc)?;
661
662 Ok(())
663}
664
665fn translate_transcript<Mpc, Zk>(
667 transcript: &mut TlsTranscript,
668 vm: &Deap<Mpc, Zk>,
669) -> Result<(), ProverError> {
670 for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
671 {
672 if let Some(plaintext_ref) = plaintext_ref.as_mut() {
673 *plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?;
674 }
675 }
676
677 Ok(())
678}