1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 Arc, Mutex,
6 atomic::{AtomicBool, Ordering},
7 },
8 task::{Context, Poll, Waker},
9};
10
11use futures::{AsyncRead, AsyncWrite};
12use mpz_common::{Executor, io::Io, mux::Mux};
13use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
14use tlsn_mux::{Connection, Handle};
15
16use crate::{
17 Error, Result,
18 prover::{Prover, state as prover_state},
19 verifier::{Verifier, state as verifier_state},
20};
21
/// A session over a single IO, multiplexed so that provers and verifiers can
/// be created on top of it.
///
/// The session implements [`Future`] and must be polled for the connection to
/// make progress, including while it is closing.
#[must_use = "session must be polled continuously to make progress, including during closing."]
pub struct Session<Io> {
    /// Multiplexed connection; `None` once the IO has been taken out with `try_take`.
    conn: Option<Connection<Io>>,
    /// Executor used to create the contexts handed to provers and verifiers.
    executor: Executor,
    /// Handle to the multiplexed connection, cloned into each prover/verifier.
    handle: Handle,
}
43
44impl<Io> Session<Io>
45where
46 Io: AsyncRead + AsyncWrite + Unpin,
47{
48 pub fn new(io: Io) -> Self {
53 let mut mux_config = tlsn_mux::Config::default();
54
55 mux_config.set_keep_alive(true);
56 mux_config.set_close_sync(true);
57
58 let conn = tlsn_mux::Connection::new(io, mux_config);
59 let handle = conn.handle().expect("handle should be available");
60 let executor = build_executor(MuxHandle {
61 handle: handle.clone(),
62 });
63
64 Self {
65 conn: Some(conn),
66 executor,
67 handle,
68 }
69 }
70
71 pub fn new_prover(
73 &mut self,
74 config: ProverConfig,
75 ) -> Result<Prover<prover_state::Initialized>> {
76 let ctx = self.executor.new_context().map_err(|e| {
77 Error::internal()
78 .with_msg("failed to create new prover")
79 .with_source(e)
80 })?;
81
82 Ok(Prover::new(ctx, self.handle.clone(), config))
83 }
84
85 pub fn new_verifier(
87 &mut self,
88 config: VerifierConfig,
89 ) -> Result<Verifier<verifier_state::Initialized>> {
90 let ctx = self.executor.new_context().map_err(|e| {
91 Error::internal()
92 .with_msg("failed to create new verifier")
93 .with_source(e)
94 })?;
95
96 Ok(Verifier::new(ctx, self.handle.clone(), config))
97 }
98
99 pub fn is_closed(&self) -> bool {
101 self.conn
102 .as_ref()
103 .map(|mux| mux.is_complete())
104 .unwrap_or_default()
105 }
106
107 pub fn close(&mut self) {
112 if let Some(conn) = self.conn.as_mut() {
113 conn.close()
114 }
115 }
116
117 pub fn try_take(&mut self) -> Result<Io> {
119 let conn = self.conn.take().ok_or_else(|| {
120 Error::io().with_msg("failed to take the session io, it was already taken")
121 })?;
122
123 match conn.try_into_io() {
124 Err(conn) => {
125 self.conn = Some(conn);
126 Err(Error::io()
127 .with_msg("failed to take the session io, session was not completed yet"))
128 }
129 Ok(conn) => Ok(conn),
130 }
131 }
132
133 pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
135 self.conn
136 .as_mut()
137 .ok_or_else(|| {
138 Error::io()
139 .with_msg("failed to poll the session connection because it has been taken")
140 })?
141 .poll(cx)
142 .map_err(|e| {
143 Error::io()
144 .with_msg("error occurred while polling the session connection")
145 .with_source(e)
146 })
147 }
148
149 pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
154 let should_close = Arc::new(AtomicBool::new(false));
155 let waker = Arc::new(Mutex::new(None::<Waker>));
156
157 (
158 SessionDriver {
159 conn: self.conn,
160 should_close: should_close.clone(),
161 waker: waker.clone(),
162 },
163 SessionHandle {
164 executor: self.executor,
165 should_close,
166 waker,
167 handle: self.handle,
168 },
169 )
170 }
171}
172
173impl<Io> Future for Session<Io>
174where
175 Io: AsyncRead + AsyncWrite + Unpin,
176{
177 type Output = Result<()>;
178
179 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180 Session::poll(&mut (*self), cx)
181 }
182}
183
/// Drives the connection of a [`Session`] after it has been split.
///
/// Resolves to the underlying IO once the connection completes.
#[must_use = "driver must be polled to make progress"]
pub struct SessionDriver<Io> {
    /// Multiplexed connection; `None` once `poll` has taken the connection out.
    conn: Option<Connection<Io>>,
    /// Close-request flag shared with the [`SessionHandle`].
    should_close: Arc<AtomicBool>,
    /// Waker registered on each poll, shared with the [`SessionHandle`] so a
    /// close request can wake the driving task.
    waker: Arc<Mutex<Option<Waker>>>,
}
194
195impl<Io> SessionDriver<Io>
196where
197 Io: AsyncRead + AsyncWrite + Unpin,
198{
199 pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
201 {
203 let mut waker_guard = self.waker.lock().unwrap();
204 *waker_guard = Some(cx.waker().clone());
205 }
206
207 let conn = self
208 .conn
209 .as_mut()
210 .ok_or_else(|| Error::io().with_msg("session driver already completed"))?;
211
212 if self.should_close.load(Ordering::Acquire) {
213 conn.close();
214 }
215
216 match conn.poll(cx) {
217 Poll::Ready(Ok(())) => {}
218 Poll::Ready(Err(e)) => {
219 return Poll::Ready(Err(Error::io()
220 .with_msg("error polling session connection")
221 .with_source(e)));
222 }
223 Poll::Pending => return Poll::Pending,
224 }
225
226 let conn = self.conn.take().unwrap();
227 Poll::Ready(
228 conn.try_into_io()
229 .map_err(|_| Error::io().with_msg("failed to take session io")),
230 )
231 }
232}
233
234impl<Io> Future for SessionDriver<Io>
235where
236 Io: AsyncRead + AsyncWrite + Unpin,
237{
238 type Output = Result<Io>;
239
240 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241 SessionDriver::poll(&mut *self, cx)
242 }
243}
244
/// Handle to a split [`Session`].
///
/// Used to create provers and verifiers and to request that the session close.
pub struct SessionHandle {
    /// Executor used to create the contexts handed to provers and verifiers.
    executor: Executor,
    /// Close-request flag shared with the [`SessionDriver`].
    should_close: Arc<AtomicBool>,
    /// Waker registered by the [`SessionDriver`], taken and woken on close.
    waker: Arc<Mutex<Option<Waker>>>,
    /// Handle to the multiplexed connection, cloned into each prover/verifier.
    handle: Handle,
}
254
255impl SessionHandle {
256 pub fn new_prover(
258 &mut self,
259 config: ProverConfig,
260 ) -> Result<Prover<prover_state::Initialized>> {
261 let ctx = self.executor.new_context().map_err(|e| {
262 Error::internal()
263 .with_msg("failed to create new prover")
264 .with_source(e)
265 })?;
266
267 Ok(Prover::new(ctx, self.handle.clone(), config))
268 }
269
270 pub fn new_verifier(
272 &mut self,
273 config: VerifierConfig,
274 ) -> Result<Verifier<verifier_state::Initialized>> {
275 let ctx = self.executor.new_context().map_err(|e| {
276 Error::internal()
277 .with_msg("failed to create new verifier")
278 .with_source(e)
279 })?;
280
281 Ok(Verifier::new(ctx, self.handle.clone(), config))
282 }
283
284 pub fn close(&self) {
288 self.should_close.store(true, Ordering::Release);
289 if let Some(waker) = self.waker.lock().unwrap().take() {
290 waker.wake();
291 }
292 }
293}
294
/// Adapter exposing a multiplexer [`Handle`] through the [`Mux`] trait, so it
/// can be passed to the executor builder.
struct MuxHandle {
    /// Handle used to open new streams.
    handle: Handle,
}
299
300impl std::fmt::Debug for MuxHandle {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("MuxHandle").finish_non_exhaustive()
303 }
304}
305
306impl Mux for MuxHandle {
307 fn open(&self, id: &[u8]) -> Result<Io, std::io::Error> {
308 let stream = self.handle.new_stream(id).map_err(std::io::Error::other)?;
309 let io = Io::from_io(stream);
310
311 Ok(io)
312 }
313}
314
315fn build_executor(mux: MuxHandle) -> Executor {
317 #[cfg(all(feature = "web", target_arch = "wasm32"))]
318 let cores = web_spawn::available_parallelism().map(|n| n.get());
319
320 #[cfg(not(all(feature = "web", target_arch = "wasm32")))]
321 let cores = std::thread::available_parallelism().map(|n| n.get());
322
323 let builder = Executor::builder().num_threads(cores.unwrap_or(8));
324
325 #[cfg(all(feature = "web", target_arch = "wasm32"))]
326 let builder = builder.spawn(|f| {
327 let _ = web_spawn::spawn(f);
328 Ok(())
329 });
330
331 builder.build(mux)
332}