Skip to main content

tlsn/
session.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        Arc, Mutex,
6        atomic::{AtomicBool, Ordering},
7    },
8    task::{Context, Poll, Waker},
9};
10
11use futures::{AsyncRead, AsyncWrite};
12use mpz_common::{ThreadId, context::Multithread, io::Io, mux::Mux};
13use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
14use tlsn_mux::{Connection, Handle};
15
16use crate::{
17    Error, Result,
18    prover::{Prover, state as prover_state},
19    verifier::{Verifier, state as verifier_state},
20};
21
/// Maximum concurrency for the multi-threaded context.
///
/// Passed to the `Multithread` builder in `build_mt_context`.
const MAX_CONCURRENCY: usize = 8;
24
/// A TLSNotary session over a communication channel.
///
/// Wraps an async IO stream and provides multiplexing for the protocol. Use
/// [`new_prover`](Self::new_prover) or [`new_verifier`](Self::new_verifier) to
/// create protocol participants.
///
/// The session must be polled continuously (either directly or via
/// [`split`](Self::split)) to drive the underlying connection. After the
/// session closes, the underlying IO can be reclaimed with
/// [`try_take`](Self::try_take).
///
/// **Important**: The order in which provers and verifiers are created must
/// match on both sides. For example, if the prover side calls `new_prover`
/// then `new_verifier`, the verifier side must call `new_verifier` then
/// `new_prover`.
#[must_use = "session must be polled continuously to make progress, including during closing."]
pub struct Session<Io> {
    /// The multiplexed connection; `None` once the IO has been reclaimed
    /// with `try_take`.
    conn: Option<Connection<Io>>,
    /// Multi-threaded context from which each prover/verifier context is
    /// created.
    mt: Multithread,
}
45
46impl<Io> Session<Io>
47where
48    Io: AsyncRead + AsyncWrite + Unpin,
49{
50    /// Creates a new session.
51    pub fn new(io: Io) -> Self {
52        let mut mux_config = tlsn_mux::Config::default();
53
54        mux_config.set_max_num_streams(36);
55        mux_config.set_keep_alive(true);
56        mux_config.set_close_sync(true);
57
58        let conn = tlsn_mux::Connection::new(io, mux_config);
59        let handle = conn.handle().expect("handle should be available");
60        let mt = build_mt_context(MuxHandle { handle });
61
62        Self {
63            conn: Some(conn),
64            mt,
65        }
66    }
67
68    /// Creates a new prover.
69    pub fn new_prover(
70        &mut self,
71        config: ProverConfig,
72    ) -> Result<Prover<prover_state::Initialized>> {
73        let ctx = self.mt.new_context().map_err(|e| {
74            Error::internal()
75                .with_msg("failed to create new prover")
76                .with_source(e)
77        })?;
78
79        Ok(Prover::new(ctx, config))
80    }
81
82    /// Creates a new verifier.
83    pub fn new_verifier(
84        &mut self,
85        config: VerifierConfig,
86    ) -> Result<Verifier<verifier_state::Initialized>> {
87        let ctx = self.mt.new_context().map_err(|e| {
88            Error::internal()
89                .with_msg("failed to create new verifier")
90                .with_source(e)
91        })?;
92
93        Ok(Verifier::new(ctx, config))
94    }
95
96    /// Returns `true` if the session is closed.
97    pub fn is_closed(&self) -> bool {
98        self.conn
99            .as_ref()
100            .map(|mux| mux.is_complete())
101            .unwrap_or_default()
102    }
103
104    /// Closes the session.
105    ///
106    /// This will cause the session to begin closing. Session must continue to
107    /// be polled until completion.
108    pub fn close(&mut self) {
109        if let Some(conn) = self.conn.as_mut() {
110            conn.close()
111        }
112    }
113
114    /// Attempts to take the IO, returning an error if it is not available.
115    pub fn try_take(&mut self) -> Result<Io> {
116        let conn = self.conn.take().ok_or_else(|| {
117            Error::io().with_msg("failed to take the session io, it was already taken")
118        })?;
119
120        match conn.try_into_io() {
121            Err(conn) => {
122                self.conn = Some(conn);
123                Err(Error::io()
124                    .with_msg("failed to take the session io, session was not completed yet"))
125            }
126            Ok(conn) => Ok(conn),
127        }
128    }
129
130    /// Polls the session.
131    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
132        self.conn
133            .as_mut()
134            .ok_or_else(|| {
135                Error::io()
136                    .with_msg("failed to poll the session connection because it has been taken")
137            })?
138            .poll(cx)
139            .map_err(|e| {
140                Error::io()
141                    .with_msg("error occurred while polling the session connection")
142                    .with_source(e)
143            })
144    }
145
146    /// Splits the session into a driver and handle.
147    ///
148    /// The driver must be polled to make progress. The handle is used
149    /// for creating provers/verifiers and closing the session.
150    pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
151        let should_close = Arc::new(AtomicBool::new(false));
152        let waker = Arc::new(Mutex::new(None::<Waker>));
153
154        (
155            SessionDriver {
156                conn: self.conn,
157                should_close: should_close.clone(),
158                waker: waker.clone(),
159            },
160            SessionHandle {
161                mt: self.mt,
162                should_close,
163                waker,
164            },
165        )
166    }
167}
168
169impl<Io> Future for Session<Io>
170where
171    Io: AsyncRead + AsyncWrite + Unpin,
172{
173    type Output = Result<()>;
174
175    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176        Session::poll(&mut (*self), cx)
177    }
178}
179
/// The polling half of a split session.
///
/// Must be polled continuously to drive the session. Returns the underlying
/// IO when the session closes.
#[must_use = "driver must be polled to make progress"]
pub struct SessionDriver<Io> {
    /// The multiplexed connection. It is `None` once the driver has completed
    /// and handed back the IO, or if the session's IO was taken before
    /// splitting.
    conn: Option<Connection<Io>>,
    /// Close request shared with [`SessionHandle`]; checked on every poll.
    should_close: Arc<AtomicBool>,
    /// Waker of the task that last polled the driver. It is shared with
    /// [`SessionHandle`] so that closing can wake the driver.
    waker: Arc<Mutex<Option<Waker>>>,
}
190
191impl<Io> SessionDriver<Io>
192where
193    Io: AsyncRead + AsyncWrite + Unpin,
194{
195    /// Polls the driver.
196    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
197        // Store the waker so the handle can wake us when close() is called.
198        {
199            let mut waker_guard = self.waker.lock().unwrap();
200            *waker_guard = Some(cx.waker().clone());
201        }
202
203        let conn = self
204            .conn
205            .as_mut()
206            .ok_or_else(|| Error::io().with_msg("session driver already completed"))?;
207
208        if self.should_close.load(Ordering::Acquire) {
209            conn.close();
210        }
211
212        match conn.poll(cx) {
213            Poll::Ready(Ok(())) => {}
214            Poll::Ready(Err(e)) => {
215                return Poll::Ready(Err(Error::io()
216                    .with_msg("error polling session connection")
217                    .with_source(e)));
218            }
219            Poll::Pending => return Poll::Pending,
220        }
221
222        let conn = self.conn.take().unwrap();
223        Poll::Ready(
224            conn.try_into_io()
225                .map_err(|_| Error::io().with_msg("failed to take session io")),
226        )
227    }
228}
229
230impl<Io> Future for SessionDriver<Io>
231where
232    Io: AsyncRead + AsyncWrite + Unpin,
233{
234    type Output = Result<Io>;
235
236    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        SessionDriver::poll(&mut *self, cx)
238    }
239}
240
/// The control half of a split session.
///
/// Used to create provers/verifiers and control the session lifecycle.
pub struct SessionHandle {
    /// Multi-threaded context from which each prover/verifier context is
    /// created.
    mt: Multithread,
    /// Close request flag, read by [`SessionDriver`] on every poll.
    should_close: Arc<AtomicBool>,
    /// Waker most recently registered by [`SessionDriver`]; taken and woken
    /// when closing.
    waker: Arc<Mutex<Option<Waker>>>,
}
249
250impl SessionHandle {
251    /// Creates a new prover.
252    pub fn new_prover(
253        &mut self,
254        config: ProverConfig,
255    ) -> Result<Prover<prover_state::Initialized>> {
256        let ctx = self.mt.new_context().map_err(|e| {
257            Error::internal()
258                .with_msg("failed to create new prover")
259                .with_source(e)
260        })?;
261
262        Ok(Prover::new(ctx, config))
263    }
264
265    /// Creates a new verifier.
266    pub fn new_verifier(
267        &mut self,
268        config: VerifierConfig,
269    ) -> Result<Verifier<verifier_state::Initialized>> {
270        let ctx = self.mt.new_context().map_err(|e| {
271            Error::internal()
272                .with_msg("failed to create new verifier")
273                .with_source(e)
274        })?;
275
276        Ok(Verifier::new(ctx, config))
277    }
278
279    /// Signals the session to close.
280    ///
281    /// The driver must continue to be polled until it completes.
282    pub fn close(&self) {
283        self.should_close.store(true, Ordering::Release);
284        if let Some(waker) = self.waker.lock().unwrap().take() {
285            waker.wake();
286        }
287    }
288}
289
/// Multiplexer controller providing streams.
///
/// Implements `Mux` by opening a new stream on the wrapped connection handle
/// for each requested thread id.
struct MuxHandle {
    /// Handle to the multiplexed connection, used to open new streams.
    handle: Handle,
}
294
295impl std::fmt::Debug for MuxHandle {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        f.debug_struct("MuxHandle").finish_non_exhaustive()
298    }
299}
300
301impl Mux for MuxHandle {
302    fn open(&self, id: ThreadId) -> Result<Io, std::io::Error> {
303        let stream = self
304            .handle
305            .new_stream(id.as_ref())
306            .map_err(std::io::Error::other)?;
307        let io = Io::from_io(stream);
308
309        Ok(io)
310    }
311}
312
313/// Builds a multi-threaded context with the given muxer.
314fn build_mt_context(mux: MuxHandle) -> Multithread {
315    let builder = Multithread::builder()
316        .mux(Box::new(mux) as Box<_>)
317        .concurrency(MAX_CONCURRENCY);
318
319    #[cfg(all(feature = "web", target_arch = "wasm32"))]
320    let builder = builder.spawn_handler(|f| {
321        let _ = web_spawn::spawn(f);
322        Ok(())
323    });
324
325    builder.build().unwrap()
326}