1#![deny(missing_docs, unreachable_pub, unused_must_use)]
4#![deny(clippy::all)]
5#![forbid(unsafe_code)]
6
7mod config;
8mod error;
9mod future;
10mod notarize;
11mod prove;
12pub mod state;
13
14pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError};
15pub use error::ProverError;
16pub use future::ProverFuture;
17use mpz_common::Context;
18use mpz_core::Block;
19use mpz_garble_core::Delta;
20use rand06_compat::Rand0_6CompatExt;
21use state::{Notarize, Prove};
22
23use futures::{AsyncRead, AsyncWrite, TryFutureExt};
24use mpc_tls::{LeaderCtrl, MpcTlsLeader};
25use rand::Rng;
26use serio::SinkExt;
27use std::sync::Arc;
28use tls_client::{ClientConnection, ServerName as TlsServerName};
29use tls_client_async::{bind_client, TlsConnection};
30use tls_core::msgs::enums::ContentType;
31use tlsn_common::{
32 commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role,
33};
34use tlsn_core::{
35 connection::{
36 ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature,
37 TranscriptLength,
38 },
39 transcript::Transcript,
40};
41use tlsn_deap::Deap;
42use tokio::sync::Mutex;
43
44use tracing::{debug, info_span, instrument, Instrument, Span};
45
46pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
47 mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
48 mpz_core::Block,
49>;
50pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
51 mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
52 bool,
53 mpz_core::Block,
54>;
55pub(crate) type Mpc =
56 mpz_garble::protocol::semihonest::Generator<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
57pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
58
59#[derive(Debug)]
61pub struct Prover<T: state::ProverState> {
62 config: ProverConfig,
63 span: Span,
64 state: T,
65}
66
67impl Prover<state::Initialized> {
68 pub fn new(config: ProverConfig) -> Self {
74 let span = info_span!("prover");
75 Self {
76 config,
77 span,
78 state: state::Initialized,
79 }
80 }
81
82 #[instrument(parent = &self.span, level = "debug", skip_all, err)]
91 pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
92 self,
93 socket: S,
94 ) -> Result<Prover<state::Setup>, ProverError> {
95 let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
96 let mut mt = build_mt_context(mux_ctrl.clone());
97 let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
98
99 mux_fut
101 .poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
102 .await?;
103
104 let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
105
106 let keys = mpc_tls.alloc()?;
108 let mut zk_aes = ZkAesCtr::new(Role::Prover);
110 zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
111 zk_aes.alloc(
112 &mut (*vm.try_lock().expect("VM is not locked").zk()),
113 self.config.protocol_config().max_recv_data(),
114 )?;
115
116 debug!("setting up mpc-tls");
117
118 mux_fut.poll_with(mpc_tls.preprocess()).await?;
119
120 debug!("mpc-tls setup complete");
121
122 Ok(Prover {
123 config: self.config,
124 span: self.span,
125 state: state::Setup {
126 mux_ctrl,
127 mux_fut,
128 mt,
129 mpc_tls,
130 zk_aes,
131 keys,
132 vm,
133 },
134 })
135 }
136}
137
138impl Prover<state::Setup> {
139 #[instrument(parent = &self.span, level = "debug", skip_all, err)]
148 pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
149 self,
150 socket: S,
151 ) -> Result<(TlsConnection, ProverFuture), ProverError> {
152 let state::Setup {
153 mux_ctrl,
154 mut mux_fut,
155 mt,
156 mpc_tls,
157 mut zk_aes,
158 keys,
159 vm,
160 } = self.state;
161
162 let (mpc_ctrl, mpc_fut) = mpc_tls.run();
163
164 let server_name =
165 TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
166 ProverError::config(format!(
167 "invalid server name: {}",
168 self.config.server_name()
169 ))
170 })?;
171
172 let config = tls_client::ClientConfig::builder()
173 .with_safe_defaults()
174 .with_root_certificates(self.config.crypto_provider().cert.root_store().clone())
175 .with_no_client_auth();
176 let client =
177 ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
178 .map_err(ProverError::config)?;
179
180 let (conn, conn_fut) = bind_client(socket, client);
181
182 let start_time = web_time::UNIX_EPOCH
183 .elapsed()
184 .expect("system time is available")
185 .as_secs();
186
187 let fut = Box::pin({
188 let span = self.span.clone();
189 let mpc_ctrl = mpc_ctrl.clone();
190 async move {
191 let conn_fut = async {
192 mux_fut
193 .poll_with(conn_fut.map_err(ProverError::from))
194 .await?;
195
196 mpc_ctrl.stop().await?;
197
198 Ok::<_, ProverError>(())
199 };
200
201 let (_, (mut ctx, mut data)) = futures::try_join!(
202 conn_fut,
203 mpc_fut.in_current_span().map_err(ProverError::from)
204 )?;
205
206 {
207 let mut vm = vm.try_lock().expect("VM should not be locked");
208
209 _ = commit_records(
212 &mut (*vm.zk()),
213 &mut zk_aes,
214 data.transcript
215 .recv
216 .iter_mut()
217 .filter(|record| record.typ == ContentType::ApplicationData),
218 )
219 .map_err(ProverError::zk)?;
220
221 debug!("finalizing mpc");
222
223 mux_fut
225 .poll_with(vm.finalize(&mut ctx))
226 .await
227 .map_err(ProverError::mpc)?;
228
229 debug!("mpc finalized");
230 }
231
232 let transcript = data
233 .transcript
234 .to_transcript()
235 .expect("transcript is complete");
236 let transcript_refs = data
237 .transcript
238 .to_transcript_refs()
239 .expect("transcript is complete");
240
241 let connection_info = ConnectionInfo {
242 time: start_time,
243 version: data
244 .protocol_version
245 .try_into()
246 .expect("only supported version should have been accepted"),
247 transcript_length: TranscriptLength {
248 sent: transcript.sent().len() as u32,
249 received: transcript.received().len() as u32,
250 },
251 };
252
253 let server_cert_data =
254 ServerCertData {
255 certs: data
256 .server_cert_details
257 .cert_chain()
258 .iter()
259 .cloned()
260 .map(|c| c.into())
261 .collect(),
262 sig: ServerSignature {
263 scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
264 "only supported signature scheme should have been accepted",
265 ),
266 sig: data.server_kx_details.kx_sig().sig.0.clone(),
267 },
268 handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
269 client_random: data.client_random.0,
270 server_random: data.server_random.0,
271 server_ephemeral_key: data
272 .server_key
273 .try_into()
274 .expect("only supported key scheme should have been accepted"),
275 }),
276 };
277
278 let (_, vm) = Arc::into_inner(vm)
280 .expect("vm should have only 1 reference")
281 .into_inner()
282 .into_inner();
283
284 Ok(Prover {
285 config: self.config,
286 span: self.span,
287 state: state::Closed {
288 mux_ctrl,
289 mux_fut,
290 mt,
291 ctx,
292 _keys: keys,
293 vm,
294 connection_info,
295 server_cert_data,
296 transcript,
297 transcript_refs,
298 },
299 })
300 }
301 .instrument(span)
302 });
303
304 Ok((
305 conn,
306 ProverFuture {
307 fut,
308 ctrl: ProverControl { mpc_ctrl },
309 },
310 ))
311 }
312}
313
314impl Prover<state::Closed> {
315 pub fn transcript(&self) -> &Transcript {
317 &self.state.transcript
318 }
319
320 pub fn start_notarize(self) -> Prover<Notarize> {
326 Prover {
327 config: self.config,
328 span: self.span,
329 state: self.state.into(),
330 }
331 }
332
333 pub fn start_prove(self) -> Prover<Prove> {
338 Prover {
339 config: self.config,
340 span: self.span,
341 state: self.state.into(),
342 }
343 }
344}
345
346fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
347 let mut rng = rand::rng();
348 let delta = Delta::new(Block::random(&mut rng.compat_by_ref()));
349
350 let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
351 let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
352 let rcot_send = mpz_ot::kos::Sender::new(
353 mpz_ot::kos::SenderConfig::default(),
354 delta.into_inner(),
355 base_ot_recv,
356 );
357 let rcot_recv =
358 mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
359 let rcot_recv = mpz_ot::ferret::Receiver::new(
360 mpz_ot::ferret::FerretConfig::builder()
361 .lpn_type(mpz_ot::ferret::LpnType::Regular)
362 .build()
363 .expect("ferret config is valid"),
364 Block::random(&mut rng.compat_by_ref()),
365 rcot_recv,
366 );
367
368 let mut rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(4, rcot_send);
369 let mut rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(2, rcot_recv);
370
371 let mpc = Mpc::new(
372 mpz_ot::cot::DerandCOTSender::new(rcot_send.next().expect("enough senders are available")),
373 rng.random(),
374 delta,
375 );
376
377 let zk = Zk::new(rcot_recv.next().expect("enough receivers are available"));
378
379 let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
380
381 (
382 vm.clone(),
383 MpcTlsLeader::new(
384 config.build_mpc_tls_config(),
385 ctx,
386 vm,
387 (
388 rcot_send.next().expect("enough senders are available"),
389 rcot_send.next().expect("enough senders are available"),
390 rcot_send.next().expect("enough senders are available"),
391 ),
392 rcot_recv.next().expect("enough receivers are available"),
393 ),
394 )
395}
396
397#[derive(Clone)]
399pub struct ProverControl {
400 mpc_ctrl: LeaderCtrl,
401}
402
403impl ProverControl {
404 pub async fn defer_decryption(&self) -> Result<(), ProverError> {
416 self.mpc_ctrl
417 .defer_decryption()
418 .await
419 .map_err(ProverError::from)
420 }
421}