1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 Arc, Mutex,
6 atomic::{AtomicBool, Ordering},
7 },
8 task::{Context, Poll, Waker},
9};
10
11use futures::{AsyncRead, AsyncWrite};
12use mpz_common::{Executor, io::Io, mux::Mux};
13use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
14use tlsn_mux::{Connection, Handle};
15
16use crate::{
17 Error, Result,
18 prover::{Prover, state as prover_state},
19 verifier::{Verifier, state as verifier_state},
20};
21
#[must_use = "session must be polled continuously to make progress, including during closing."]
/// A session running over a single IO stream wrapped in a [`tlsn_mux`] connection.
///
/// Provers and verifiers are created from the session and share its connection. The session
/// itself must be polled (directly, as a future, or via the [`SessionDriver`] obtained from
/// [`Session::split`]) for the connection to make progress.
pub struct Session<Io> {
    /// The multiplexed connection; `None` once the underlying IO has been taken.
    conn: Option<Connection<Io>>,
    /// Executor, built over a handle to `conn`, used to create contexts for provers/verifiers.
    executor: Executor,
    /// Handle to the multiplexed connection, cloned into each prover/verifier.
    handle: Handle,
}
43
44impl<Io> Session<Io>
45where
46 Io: AsyncRead + AsyncWrite + Unpin,
47{
48 pub fn new(io: Io) -> Self {
50 let mut mux_config = tlsn_mux::Config::default();
51
52 mux_config.set_keep_alive(true);
53 mux_config.set_close_sync(true);
54
55 let conn = tlsn_mux::Connection::new(io, mux_config);
56 let handle = conn.handle().expect("handle should be available");
57 let executor = build_executor(MuxHandle {
58 handle: handle.clone(),
59 });
60
61 Self {
62 conn: Some(conn),
63 executor,
64 handle,
65 }
66 }
67
68 pub fn new_prover(
70 &mut self,
71 config: ProverConfig,
72 ) -> Result<Prover<prover_state::Initialized>> {
73 let ctx = self.executor.new_context().map_err(|e| {
74 Error::internal()
75 .with_msg("failed to create new prover")
76 .with_source(e)
77 })?;
78
79 Ok(Prover::new(ctx, self.handle.clone(), config))
80 }
81
82 pub fn new_verifier(
84 &mut self,
85 config: VerifierConfig,
86 ) -> Result<Verifier<verifier_state::Initialized>> {
87 let ctx = self.executor.new_context().map_err(|e| {
88 Error::internal()
89 .with_msg("failed to create new verifier")
90 .with_source(e)
91 })?;
92
93 Ok(Verifier::new(ctx, self.handle.clone(), config))
94 }
95
96 pub fn is_closed(&self) -> bool {
98 self.conn
99 .as_ref()
100 .map(|mux| mux.is_complete())
101 .unwrap_or_default()
102 }
103
104 pub fn close(&mut self) {
109 if let Some(conn) = self.conn.as_mut() {
110 conn.close()
111 }
112 }
113
114 pub fn try_take(&mut self) -> Result<Io> {
116 let conn = self.conn.take().ok_or_else(|| {
117 Error::io().with_msg("failed to take the session io, it was already taken")
118 })?;
119
120 match conn.try_into_io() {
121 Err(conn) => {
122 self.conn = Some(conn);
123 Err(Error::io()
124 .with_msg("failed to take the session io, session was not completed yet"))
125 }
126 Ok(conn) => Ok(conn),
127 }
128 }
129
130 pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
132 self.conn
133 .as_mut()
134 .ok_or_else(|| {
135 Error::io()
136 .with_msg("failed to poll the session connection because it has been taken")
137 })?
138 .poll(cx)
139 .map_err(|e| {
140 Error::io()
141 .with_msg("error occurred while polling the session connection")
142 .with_source(e)
143 })
144 }
145
146 pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
151 let should_close = Arc::new(AtomicBool::new(false));
152 let waker = Arc::new(Mutex::new(None::<Waker>));
153
154 (
155 SessionDriver {
156 conn: self.conn,
157 should_close: should_close.clone(),
158 waker: waker.clone(),
159 },
160 SessionHandle {
161 executor: self.executor,
162 should_close,
163 waker,
164 handle: self.handle,
165 },
166 )
167 }
168}
169
170impl<Io> Future for Session<Io>
171where
172 Io: AsyncRead + AsyncWrite + Unpin,
173{
174 type Output = Result<()>;
175
176 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177 Session::poll(&mut (*self), cx)
178 }
179}
180
#[must_use = "driver must be polled to make progress"]
/// The half of a split [`Session`] that owns and drives the connection.
///
/// Obtained from [`Session::split`]. Resolves to the underlying IO once the connection has
/// completed.
pub struct SessionDriver<Io> {
    /// The connection being driven; `None` once the driver has completed.
    conn: Option<Connection<Io>>,
    /// Close-request flag shared with the [`SessionHandle`].
    should_close: Arc<AtomicBool>,
    /// Waker of the task polling this driver, shared with the [`SessionHandle`] so that a
    /// close request can wake it.
    waker: Arc<Mutex<Option<Waker>>>,
}
191
192impl<Io> SessionDriver<Io>
193where
194 Io: AsyncRead + AsyncWrite + Unpin,
195{
196 pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
198 {
200 let mut waker_guard = self.waker.lock().unwrap();
201 *waker_guard = Some(cx.waker().clone());
202 }
203
204 let conn = self
205 .conn
206 .as_mut()
207 .ok_or_else(|| Error::io().with_msg("session driver already completed"))?;
208
209 if self.should_close.load(Ordering::Acquire) {
210 conn.close();
211 }
212
213 match conn.poll(cx) {
214 Poll::Ready(Ok(())) => {}
215 Poll::Ready(Err(e)) => {
216 return Poll::Ready(Err(Error::io()
217 .with_msg("error polling session connection")
218 .with_source(e)));
219 }
220 Poll::Pending => return Poll::Pending,
221 }
222
223 let conn = self.conn.take().unwrap();
224 Poll::Ready(
225 conn.try_into_io()
226 .map_err(|_| Error::io().with_msg("failed to take session io")),
227 )
228 }
229}
230
231impl<Io> Future for SessionDriver<Io>
232where
233 Io: AsyncRead + AsyncWrite + Unpin,
234{
235 type Output = Result<Io>;
236
237 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238 SessionDriver::poll(&mut *self, cx)
239 }
240}
241
/// The half of a split [`Session`] used to create provers/verifiers and to request a close.
///
/// Obtained from [`Session::split`]; the connection itself is driven by the paired
/// [`SessionDriver`].
pub struct SessionHandle {
    /// Executor used to create contexts for provers and verifiers.
    executor: Executor,
    /// Close-request flag shared with the [`SessionDriver`].
    should_close: Arc<AtomicBool>,
    /// Waker registered by the [`SessionDriver`]'s task, if it has been polled.
    waker: Arc<Mutex<Option<Waker>>>,
    /// Handle to the multiplexed connection, cloned into each prover/verifier.
    handle: Handle,
}
251
252impl SessionHandle {
253 pub fn new_prover(
255 &mut self,
256 config: ProverConfig,
257 ) -> Result<Prover<prover_state::Initialized>> {
258 let ctx = self.executor.new_context().map_err(|e| {
259 Error::internal()
260 .with_msg("failed to create new prover")
261 .with_source(e)
262 })?;
263
264 Ok(Prover::new(ctx, self.handle.clone(), config))
265 }
266
267 pub fn new_verifier(
269 &mut self,
270 config: VerifierConfig,
271 ) -> Result<Verifier<verifier_state::Initialized>> {
272 let ctx = self.executor.new_context().map_err(|e| {
273 Error::internal()
274 .with_msg("failed to create new verifier")
275 .with_source(e)
276 })?;
277
278 Ok(Verifier::new(ctx, self.handle.clone(), config))
279 }
280
281 pub fn close(&self) {
285 self.should_close.store(true, Ordering::Release);
286 if let Some(waker) = self.waker.lock().unwrap().take() {
287 waker.wake();
288 }
289 }
290}
291
/// Adapter that exposes a [`Handle`] to the multiplexed connection through mpz's [`Mux`]
/// trait, so the executor can open streams on the session's connection.
struct MuxHandle {
    /// Handle used to open new streams on the connection.
    handle: Handle,
}
296
297impl std::fmt::Debug for MuxHandle {
298 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299 f.debug_struct("MuxHandle").finish_non_exhaustive()
300 }
301}
302
303impl Mux for MuxHandle {
304 fn open(&self, id: &[u8]) -> Result<Io, std::io::Error> {
305 let stream = self.handle.new_stream(id).map_err(std::io::Error::other)?;
306 let io = Io::from_io(stream);
307
308 Ok(io)
309 }
310}
311
312fn build_executor(mux: MuxHandle) -> Executor {
314 #[cfg(all(feature = "web", target_arch = "wasm32"))]
315 let cores = web_spawn::available_parallelism().map(|n| n.get());
316
317 #[cfg(not(all(feature = "web", target_arch = "wasm32")))]
318 let cores = std::thread::available_parallelism().map(|n| n.get());
319
320 let builder = Executor::builder().num_threads(cores.unwrap_or(8));
321
322 #[cfg(all(feature = "web", target_arch = "wasm32"))]
323 let builder = builder.spawn(|f| {
324 let _ = web_spawn::spawn(f);
325 Ok(())
326 });
327
328 builder.build(mux)
329}