1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 Arc, Mutex,
6 atomic::{AtomicBool, Ordering},
7 },
8 task::{Context, Poll, Waker},
9};
10
11use futures::{AsyncRead, AsyncWrite};
12use mpz_common::{ThreadId, context::Multithread, io::Io, mux::Mux};
13use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
14use tlsn_mux::{Connection, Handle};
15
16use crate::{
17 Error, Result,
18 prover::{Prover, state as prover_state},
19 verifier::{Verifier, state as verifier_state},
20};
21
/// Concurrency limit passed to the [`Multithread`] builder when the session's
/// multi-threaded context is created (see `build_mt_context`).
const MAX_CONCURRENCY: usize = 8;
24
/// A session running over a single multiplexed connection.
///
/// The session owns the connection and the multi-threaded context from which
/// provers and verifiers obtain their own contexts. It has to be polled
/// (directly via [`Session::poll`] or as a [`Future`]) for the connection to
/// make progress.
#[must_use = "session must be polled continuously to make progress, including during closing."]
pub struct Session<Io> {
    /// The multiplexed connection; `None` once the underlying IO has been
    /// taken back with `try_take`.
    conn: Option<Connection<Io>>,
    /// Multi-threaded context used to create a context for each new prover or
    /// verifier.
    mt: Multithread,
}
45
46impl<Io> Session<Io>
47where
48 Io: AsyncRead + AsyncWrite + Unpin,
49{
50 pub fn new(io: Io) -> Self {
52 let mut mux_config = tlsn_mux::Config::default();
53
54 mux_config.set_max_num_streams(36);
55 mux_config.set_keep_alive(true);
56 mux_config.set_close_sync(true);
57
58 let conn = tlsn_mux::Connection::new(io, mux_config);
59 let handle = conn.handle().expect("handle should be available");
60 let mt = build_mt_context(MuxHandle { handle });
61
62 Self {
63 conn: Some(conn),
64 mt,
65 }
66 }
67
68 pub fn new_prover(
70 &mut self,
71 config: ProverConfig,
72 ) -> Result<Prover<prover_state::Initialized>> {
73 let ctx = self.mt.new_context().map_err(|e| {
74 Error::internal()
75 .with_msg("failed to create new prover")
76 .with_source(e)
77 })?;
78
79 Ok(Prover::new(ctx, config))
80 }
81
82 pub fn new_verifier(
84 &mut self,
85 config: VerifierConfig,
86 ) -> Result<Verifier<verifier_state::Initialized>> {
87 let ctx = self.mt.new_context().map_err(|e| {
88 Error::internal()
89 .with_msg("failed to create new verifier")
90 .with_source(e)
91 })?;
92
93 Ok(Verifier::new(ctx, config))
94 }
95
96 pub fn is_closed(&self) -> bool {
98 self.conn
99 .as_ref()
100 .map(|mux| mux.is_complete())
101 .unwrap_or_default()
102 }
103
104 pub fn close(&mut self) {
109 if let Some(conn) = self.conn.as_mut() {
110 conn.close()
111 }
112 }
113
114 pub fn try_take(&mut self) -> Result<Io> {
116 let conn = self.conn.take().ok_or_else(|| {
117 Error::io().with_msg("failed to take the session io, it was already taken")
118 })?;
119
120 match conn.try_into_io() {
121 Err(conn) => {
122 self.conn = Some(conn);
123 Err(Error::io()
124 .with_msg("failed to take the session io, session was not completed yet"))
125 }
126 Ok(conn) => Ok(conn),
127 }
128 }
129
130 pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
132 self.conn
133 .as_mut()
134 .ok_or_else(|| {
135 Error::io()
136 .with_msg("failed to poll the session connection because it has been taken")
137 })?
138 .poll(cx)
139 .map_err(|e| {
140 Error::io()
141 .with_msg("error occurred while polling the session connection")
142 .with_source(e)
143 })
144 }
145
146 pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
151 let should_close = Arc::new(AtomicBool::new(false));
152 let waker = Arc::new(Mutex::new(None::<Waker>));
153
154 (
155 SessionDriver {
156 conn: self.conn,
157 should_close: should_close.clone(),
158 waker: waker.clone(),
159 },
160 SessionHandle {
161 mt: self.mt,
162 should_close,
163 waker,
164 },
165 )
166 }
167}
168
169impl<Io> Future for Session<Io>
170where
171 Io: AsyncRead + AsyncWrite + Unpin,
172{
173 type Output = Result<()>;
174
175 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176 Session::poll(&mut (*self), cx)
177 }
178}
179
/// The half of a split session that owns the connection and drives it.
///
/// Created by `Session::split`. It has to be polled for the connection to make
/// progress.
#[must_use = "driver must be polled to make progress"]
pub struct SessionDriver<Io> {
    /// The multiplexed connection; `None` after the driver has completed and
    /// the connection was converted back into its IO.
    conn: Option<Connection<Io>>,
    /// Flag shared with the `SessionHandle`; when set, the driver closes the
    /// connection on its next poll.
    should_close: Arc<AtomicBool>,
    /// Slot shared with the `SessionHandle` holding the waker of the most
    /// recent poll, so the handle can wake the driver.
    waker: Arc<Mutex<Option<Waker>>>,
}
190
191impl<Io> SessionDriver<Io>
192where
193 Io: AsyncRead + AsyncWrite + Unpin,
194{
195 pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
197 {
199 let mut waker_guard = self.waker.lock().unwrap();
200 *waker_guard = Some(cx.waker().clone());
201 }
202
203 let conn = self
204 .conn
205 .as_mut()
206 .ok_or_else(|| Error::io().with_msg("session driver already completed"))?;
207
208 if self.should_close.load(Ordering::Acquire) {
209 conn.close();
210 }
211
212 match conn.poll(cx) {
213 Poll::Ready(Ok(())) => {}
214 Poll::Ready(Err(e)) => {
215 return Poll::Ready(Err(Error::io()
216 .with_msg("error polling session connection")
217 .with_source(e)));
218 }
219 Poll::Pending => return Poll::Pending,
220 }
221
222 let conn = self.conn.take().unwrap();
223 Poll::Ready(
224 conn.try_into_io()
225 .map_err(|_| Error::io().with_msg("failed to take session io")),
226 )
227 }
228}
229
230impl<Io> Future for SessionDriver<Io>
231where
232 Io: AsyncRead + AsyncWrite + Unpin,
233{
234 type Output = Result<Io>;
235
236 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 SessionDriver::poll(&mut *self, cx)
238 }
239}
240
/// The half of a split session used to create provers and verifiers and to
/// request that the session be closed.
///
/// Created by `Session::split`.
pub struct SessionHandle {
    /// Multi-threaded context used to create a context for each new prover or
    /// verifier.
    mt: Multithread,
    /// Flag shared with the `SessionDriver`; set by `close`.
    should_close: Arc<AtomicBool>,
    /// Slot shared with the `SessionDriver` holding the waker of its most
    /// recent poll.
    waker: Arc<Mutex<Option<Waker>>>,
}
249
250impl SessionHandle {
251 pub fn new_prover(
253 &mut self,
254 config: ProverConfig,
255 ) -> Result<Prover<prover_state::Initialized>> {
256 let ctx = self.mt.new_context().map_err(|e| {
257 Error::internal()
258 .with_msg("failed to create new prover")
259 .with_source(e)
260 })?;
261
262 Ok(Prover::new(ctx, config))
263 }
264
265 pub fn new_verifier(
267 &mut self,
268 config: VerifierConfig,
269 ) -> Result<Verifier<verifier_state::Initialized>> {
270 let ctx = self.mt.new_context().map_err(|e| {
271 Error::internal()
272 .with_msg("failed to create new verifier")
273 .with_source(e)
274 })?;
275
276 Ok(Verifier::new(ctx, config))
277 }
278
279 pub fn close(&self) {
283 self.should_close.store(true, Ordering::Release);
284 if let Some(waker) = self.waker.lock().unwrap().take() {
285 waker.wake();
286 }
287 }
288}
289
/// Wrapper around a connection [`Handle`] so that it can implement [`Mux`].
struct MuxHandle {
    /// Handle of the connection, used to open new streams.
    handle: Handle,
}
294
295impl std::fmt::Debug for MuxHandle {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 f.debug_struct("MuxHandle").finish_non_exhaustive()
298 }
299}
300
301impl Mux for MuxHandle {
302 fn open(&self, id: ThreadId) -> Result<Io, std::io::Error> {
303 let stream = self
304 .handle
305 .new_stream(id.as_ref())
306 .map_err(std::io::Error::other)?;
307 let io = Io::from_io(stream);
308
309 Ok(io)
310 }
311}
312
313fn build_mt_context(mux: MuxHandle) -> Multithread {
315 let builder = Multithread::builder()
316 .mux(Box::new(mux) as Box<_>)
317 .concurrency(MAX_CONCURRENCY);
318
319 #[cfg(all(feature = "web", target_arch = "wasm32"))]
320 let builder = builder.spawn_handler(|f| {
321 let _ = web_spawn::spawn(f);
322 Ok(())
323 });
324
325 builder.build().unwrap()
326}