1#![deny(missing_docs, unreachable_pub, unused_must_use)]
4#![deny(clippy::all)]
5#![forbid(unsafe_code)]
6
7mod config;
8mod error;
9mod future;
10pub mod state;
11
12pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError};
13pub use error::ProverError;
14pub use future::ProverFuture;
15pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
16
17use mpz_common::Context;
18use mpz_core::Block;
19use mpz_garble_core::Delta;
20use mpz_vm_core::prelude::*;
21
22use futures::{AsyncRead, AsyncWrite, TryFutureExt};
23use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
24use rand::Rng;
25use serio::{stream::IoStreamExt, SinkExt};
26use std::sync::Arc;
27use tls_client::{ClientConnection, ServerName as TlsServerName};
28use tls_client_async::{bind_client, TlsConnection};
29use tls_core::msgs::enums::ContentType;
30use tlsn_common::{
31 commit::{commit_records, hash::prove_hash},
32 context::build_mt_context,
33 encoding,
34 mux::attach_mux,
35 transcript::{decode_transcript, Record, TlsTranscript},
36 zk_aes::ZkAesCtr,
37 Role,
38};
39use tlsn_core::{
40 attestation::Attestation,
41 connection::{
42 ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature,
43 TranscriptLength,
44 },
45 request::{Request, RequestConfig},
46 transcript::{Direction, Transcript, TranscriptCommitment, TranscriptSecret},
47 ProvePayload, Secrets,
48};
49use tlsn_deap::Deap;
50use tokio::sync::Mutex;
51
52use tracing::{debug, info_span, instrument, Instrument, Span};
53
54pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
55 mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
56 mpz_core::Block,
57>;
58pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
59 mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
60 bool,
61 mpz_core::Block,
62>;
63pub(crate) type Mpc =
64 mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
65pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
66
67#[derive(Debug)]
69pub struct Prover<T: state::ProverState = state::Initialized> {
70 config: ProverConfig,
71 span: Span,
72 state: T,
73}
74
75impl Prover<state::Initialized> {
76 pub fn new(config: ProverConfig) -> Self {
82 let span = info_span!("prover");
83 Self {
84 config,
85 span,
86 state: state::Initialized,
87 }
88 }
89
90 #[instrument(parent = &self.span, level = "debug", skip_all, err)]
99 pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
100 self,
101 socket: S,
102 ) -> Result<Prover<state::Setup>, ProverError> {
103 let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
104 let mut mt = build_mt_context(mux_ctrl.clone());
105 let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
106
107 mux_fut
109 .poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
110 .await?;
111
112 let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
113
114 let mut keys = mpc_tls.alloc()?;
116 translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
117
118 let mut zk_aes = ZkAesCtr::new(Role::Prover);
120 zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
121 zk_aes.alloc(
122 &mut (*vm.try_lock().expect("VM is not locked").zk()),
123 self.config.protocol_config().max_recv_data(),
124 )?;
125
126 debug!("setting up mpc-tls");
127
128 mux_fut.poll_with(mpc_tls.preprocess()).await?;
129
130 debug!("mpc-tls setup complete");
131
132 Ok(Prover {
133 config: self.config,
134 span: self.span,
135 state: state::Setup {
136 mux_ctrl,
137 mux_fut,
138 mpc_tls,
139 zk_aes,
140 keys,
141 vm,
142 },
143 })
144 }
145}
146
147impl Prover<state::Setup> {
148 #[instrument(parent = &self.span, level = "debug", skip_all, err)]
157 pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
158 self,
159 socket: S,
160 ) -> Result<(TlsConnection, ProverFuture), ProverError> {
161 let state::Setup {
162 mux_ctrl,
163 mut mux_fut,
164 mpc_tls,
165 mut zk_aes,
166 keys,
167 vm,
168 ..
169 } = self.state;
170
171 let (mpc_ctrl, mpc_fut) = mpc_tls.run();
172
173 let server_name =
174 TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
175 ProverError::config(format!(
176 "invalid server name: {}",
177 self.config.server_name()
178 ))
179 })?;
180
181 let config = tls_client::ClientConfig::builder()
182 .with_safe_defaults()
183 .with_root_certificates(self.config.crypto_provider().cert.root_store().clone())
184 .with_no_client_auth();
185 let client =
186 ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
187 .map_err(ProverError::config)?;
188
189 let (conn, conn_fut) = bind_client(socket, client);
190
191 let start_time = web_time::UNIX_EPOCH
192 .elapsed()
193 .expect("system time is available")
194 .as_secs();
195
196 let fut = Box::pin({
197 let span = self.span.clone();
198 let mpc_ctrl = mpc_ctrl.clone();
199 async move {
200 let conn_fut = async {
201 mux_fut
202 .poll_with(conn_fut.map_err(ProverError::from))
203 .await?;
204
205 mpc_ctrl.stop().await?;
206
207 Ok::<_, ProverError>(())
208 };
209
210 let (_, (mut ctx, mut data)) = futures::try_join!(
211 conn_fut,
212 mpc_fut.in_current_span().map_err(ProverError::from)
213 )?;
214
215 {
216 let mut vm = vm.try_lock().expect("VM should not be locked");
217
218 translate_transcript(&mut data.transcript, &vm)?;
219
220 _ = commit_records(
223 &mut (*vm.zk()),
224 &mut zk_aes,
225 data.transcript
226 .recv
227 .iter_mut()
228 .filter(|record| record.typ == ContentType::ApplicationData),
229 )
230 .map_err(ProverError::zk)?;
231
232 debug!("finalizing mpc");
233
234 mux_fut
236 .poll_with(vm.finalize(&mut ctx))
237 .await
238 .map_err(ProverError::mpc)?;
239
240 debug!("mpc finalized");
241 }
242
243 let transcript = data
244 .transcript
245 .to_transcript()
246 .expect("transcript is complete");
247 let transcript_refs = data
248 .transcript
249 .to_transcript_refs()
250 .expect("transcript is complete");
251
252 let connection_info = ConnectionInfo {
253 time: start_time,
254 version: data
255 .protocol_version
256 .try_into()
257 .expect("only supported version should have been accepted"),
258 transcript_length: TranscriptLength {
259 sent: transcript.sent().len() as u32,
260 received: transcript.received().len() as u32,
261 },
262 };
263
264 let server_cert_data =
265 ServerCertData {
266 certs: data
267 .server_cert_details
268 .cert_chain()
269 .iter()
270 .cloned()
271 .map(|c| c.into())
272 .collect(),
273 sig: ServerSignature {
274 scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
275 "only supported signature scheme should have been accepted",
276 ),
277 sig: data.server_kx_details.kx_sig().sig.0.clone(),
278 },
279 handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
280 client_random: data.client_random.0,
281 server_random: data.server_random.0,
282 server_ephemeral_key: data
283 .server_key
284 .try_into()
285 .expect("only supported key scheme should have been accepted"),
286 }),
287 };
288
289 let (_, vm) = Arc::into_inner(vm)
291 .expect("vm should have only 1 reference")
292 .into_inner()
293 .into_inner();
294
295 Ok(Prover {
296 config: self.config,
297 span: self.span,
298 state: state::Committed {
299 mux_ctrl,
300 mux_fut,
301 ctx,
302 _keys: keys,
303 vm,
304 connection_info,
305 server_cert_data,
306 transcript,
307 transcript_refs,
308 },
309 })
310 }
311 .instrument(span)
312 });
313
314 Ok((
315 conn,
316 ProverFuture {
317 fut,
318 ctrl: ProverControl { mpc_ctrl },
319 },
320 ))
321 }
322}
323
324impl Prover<state::Committed> {
325 pub fn connection_info(&self) -> &ConnectionInfo {
327 &self.state.connection_info
328 }
329
330 pub fn transcript(&self) -> &Transcript {
332 &self.state.transcript
333 }
334
335 #[instrument(parent = &self.span, level = "info", skip_all, err)]
341 pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
342 let state::Committed {
343 mux_fut,
344 ctx,
345 vm,
346 server_cert_data,
347 transcript_refs,
348 ..
349 } = &mut self.state;
350
351 let mut output = ProverOutput {
352 transcript_commitments: Vec::new(),
353 transcript_secrets: Vec::new(),
354 };
355
356 let payload = ProvePayload {
357 server_identity: config
358 .server_identity()
359 .then(|| (self.config.server_name().clone(), server_cert_data.clone())),
360 transcript: config.transcript().cloned(),
361 transcript_commit: config.transcript_commit().map(|config| config.to_request()),
362 };
363
364 mux_fut
366 .poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
367 .await?;
368
369 if let Some(partial_transcript) = config.transcript() {
370 decode_transcript(
371 vm,
372 partial_transcript.sent_authed(),
373 partial_transcript.received_authed(),
374 transcript_refs,
375 )
376 .map_err(ProverError::zk)?;
377 }
378
379 let mut hash_commitments = None;
380 if let Some(commit_config) = config.transcript_commit() {
381 if commit_config.has_encoding() {
382 let hasher = self
383 .config
384 .crypto_provider()
385 .hash
386 .get(commit_config.encoding_hash_alg())
387 .map_err(ProverError::config)?;
388
389 let (commitment, tree) = mux_fut
390 .poll_with(
391 encoding::receive(
392 ctx,
393 hasher,
394 transcript_refs,
395 |plaintext| vm.get_macs(plaintext).expect("reference is valid"),
396 commit_config.iter_encoding(),
397 )
398 .map_err(ProverError::commit),
399 )
400 .await?;
401
402 output
403 .transcript_commitments
404 .push(TranscriptCommitment::Encoding(commitment));
405 output
406 .transcript_secrets
407 .push(TranscriptSecret::Encoding(tree));
408 }
409
410 if commit_config.has_hash() {
411 hash_commitments = Some(
412 prove_hash(
413 vm,
414 transcript_refs,
415 commit_config
416 .iter_hash()
417 .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
418 )
419 .map_err(ProverError::commit)?,
420 );
421 }
422 }
423
424 mux_fut
425 .poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
426 .await?;
427
428 if let Some((hash_fut, hash_secrets)) = hash_commitments {
429 let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
430 for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
431 output
432 .transcript_commitments
433 .push(TranscriptCommitment::Hash(commitment));
434 output
435 .transcript_secrets
436 .push(TranscriptSecret::Hash(secret));
437 }
438 }
439
440 Ok(output)
441 }
442
443 #[instrument(parent = &self.span, level = "info", skip_all, err)]
449 #[deprecated(
450 note = "attestation functionality will be removed from this API in future releases."
451 )]
452 pub async fn notarize(
453 &mut self,
454 config: &RequestConfig,
455 ) -> Result<(Attestation, Secrets), ProverError> {
456 let mut builder = ProveConfig::builder(self.transcript());
457
458 if let Some(config) = config.transcript_commit() {
459 for ((direction, idx), _) in config.iter_hash() {
464 let len = match direction {
465 Direction::Sent => self.transcript().sent().len(),
466 Direction::Received => self.transcript().received().len(),
467 };
468
469 if idx.start() > 0 || idx.end() < len || idx.count() != 1 {
470 return Err(ProverError::attestation(
471 "hash commitments to subsets of the transcript are currently not supported in attestation requests",
472 ));
473 }
474 }
475
476 builder.transcript_commit(config.clone());
477 }
478
479 let disclosure_config = builder.build().map_err(ProverError::attestation)?;
480
481 let ProverOutput {
482 transcript_commitments,
483 transcript_secrets,
484 ..
485 } = self.prove(&disclosure_config).await?;
486
487 let state::Committed {
488 mux_fut,
489 ctx,
490 server_cert_data,
491 transcript,
492 ..
493 } = &mut self.state;
494
495 let mut builder = Request::builder(config);
496
497 builder
498 .server_name(self.config.server_name().clone())
499 .server_cert_data(server_cert_data.clone())
500 .transcript(transcript.clone())
501 .transcript_commitments(transcript_secrets, transcript_commitments);
502
503 let (request, secrets) = builder
504 .build(self.config.crypto_provider())
505 .map_err(ProverError::attestation)?;
506
507 let attestation = mux_fut
508 .poll_with(async {
509 debug!("sending attestation request");
510
511 ctx.io_mut().send(request.clone()).await?;
512
513 let attestation: Attestation = ctx.io_mut().expect_next().await?;
514
515 Ok::<_, ProverError>(attestation)
516 })
517 .await?;
518
519 request
521 .validate(&attestation)
522 .map_err(ProverError::attestation)?;
523
524 Ok((attestation, secrets))
525 }
526
527 #[instrument(parent = &self.span, level = "info", skip_all, err)]
529 pub async fn close(self) -> Result<(), ProverError> {
530 let state::Committed {
531 mux_ctrl, mux_fut, ..
532 } = self.state;
533
534 if !mux_fut.is_complete() {
536 mux_ctrl.close();
537 mux_fut.await?;
538 }
539
540 Ok(())
541 }
542}
543
544fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
545 let mut rng = rand::rng();
546 let delta = Delta::new(Block::random(&mut rng));
547
548 let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
549 let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
550 let rcot_send = mpz_ot::kos::Sender::new(
551 mpz_ot::kos::SenderConfig::default(),
552 delta.into_inner(),
553 base_ot_recv,
554 );
555 let rcot_recv =
556 mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
557 let rcot_recv = mpz_ot::ferret::Receiver::new(
558 mpz_ot::ferret::FerretConfig::builder()
559 .lpn_type(mpz_ot::ferret::LpnType::Regular)
560 .build()
561 .expect("ferret config is valid"),
562 Block::random(&mut rng),
563 rcot_recv,
564 );
565
566 let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
567 let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
568
569 let mpc = Mpc::new(
570 mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
571 rng.random(),
572 delta,
573 );
574
575 let zk = Zk::new(rcot_recv.clone());
576
577 let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
578
579 (
580 vm.clone(),
581 MpcTlsLeader::new(
582 config.build_mpc_tls_config(),
583 ctx,
584 vm,
585 (rcot_send.clone(), rcot_send.clone(), rcot_send),
586 rcot_recv,
587 ),
588 )
589}
590
591#[derive(Clone)]
593pub struct ProverControl {
594 mpc_ctrl: LeaderCtrl,
595}
596
597impl ProverControl {
598 pub async fn defer_decryption(&self) -> Result<(), ProverError> {
610 self.mpc_ctrl
611 .defer_decryption()
612 .await
613 .map_err(ProverError::from)
614 }
615}
616
617fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
619 keys.client_write_key = vm
620 .translate(keys.client_write_key)
621 .map_err(ProverError::mpc)?;
622 keys.client_write_iv = vm
623 .translate(keys.client_write_iv)
624 .map_err(ProverError::mpc)?;
625 keys.server_write_key = vm
626 .translate(keys.server_write_key)
627 .map_err(ProverError::mpc)?;
628 keys.server_write_iv = vm
629 .translate(keys.server_write_iv)
630 .map_err(ProverError::mpc)?;
631
632 Ok(())
633}
634
635fn translate_transcript<Mpc, Zk>(
637 transcript: &mut TlsTranscript,
638 vm: &Deap<Mpc, Zk>,
639) -> Result<(), ProverError> {
640 for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
641 {
642 if let Some(plaintext_ref) = plaintext_ref.as_mut() {
643 *plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?;
644 }
645 }
646
647 Ok(())
648}