Skip to main content

tlsn/
session.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        Arc, Mutex,
6        atomic::{AtomicBool, Ordering},
7    },
8    task::{Context, Poll, Waker},
9};
10
11use futures::{AsyncRead, AsyncWrite};
12use mpz_common::{Executor, io::Io, mux::Mux};
13use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
14use tlsn_mux::{Connection, Handle};
15
16use crate::{
17    Error, Result,
18    prover::{Prover, state as prover_state},
19    verifier::{Verifier, state as verifier_state},
20};
21
22/// A TLSNotary session over a communication channel.
23///
24/// Wraps an async IO stream and provides multiplexing for the protocol. Use
25/// [`new_prover`](Self::new_prover) or [`new_verifier`](Self::new_verifier) to
26/// create protocol participants.
27///
28/// The session must be polled continuously (either directly or via
29/// [`split`](Self::split)) to drive the underlying connection. After the
30/// session closes, the underlying IO can be reclaimed with
31/// [`try_take`](Self::try_take).
32///
33/// **Important**: The order in which provers and verifiers are created must
34/// match on both sides. For example, if the prover side calls `new_prover`
35/// then `new_verifier`, the verifier side must call `new_verifier` then
36/// `new_prover`.
37#[must_use = "session must be polled continuously to make progress, including during closing."]
38pub struct Session<Io> {
39    conn: Option<Connection<Io>>,
40    executor: Executor,
41    handle: Handle,
42}
43
44impl<Io> Session<Io>
45where
46    Io: AsyncRead + AsyncWrite + Unpin,
47{
48    /// Creates a new session.
49    pub fn new(io: Io) -> Self {
50        let mut mux_config = tlsn_mux::Config::default();
51
52        mux_config.set_keep_alive(true);
53        mux_config.set_close_sync(true);
54
55        let conn = tlsn_mux::Connection::new(io, mux_config);
56        let handle = conn.handle().expect("handle should be available");
57        let executor = build_executor(MuxHandle {
58            handle: handle.clone(),
59        });
60
61        Self {
62            conn: Some(conn),
63            executor,
64            handle,
65        }
66    }
67
68    /// Creates a new prover.
69    pub fn new_prover(
70        &mut self,
71        config: ProverConfig,
72    ) -> Result<Prover<prover_state::Initialized>> {
73        let ctx = self.executor.new_context().map_err(|e| {
74            Error::internal()
75                .with_msg("failed to create new prover")
76                .with_source(e)
77        })?;
78
79        Ok(Prover::new(ctx, self.handle.clone(), config))
80    }
81
82    /// Creates a new verifier.
83    pub fn new_verifier(
84        &mut self,
85        config: VerifierConfig,
86    ) -> Result<Verifier<verifier_state::Initialized>> {
87        let ctx = self.executor.new_context().map_err(|e| {
88            Error::internal()
89                .with_msg("failed to create new verifier")
90                .with_source(e)
91        })?;
92
93        Ok(Verifier::new(ctx, self.handle.clone(), config))
94    }
95
96    /// Returns `true` if the session is closed.
97    pub fn is_closed(&self) -> bool {
98        self.conn
99            .as_ref()
100            .map(|mux| mux.is_complete())
101            .unwrap_or_default()
102    }
103
104    /// Closes the session.
105    ///
106    /// This will cause the session to begin closing. Session must continue to
107    /// be polled until completion.
108    pub fn close(&mut self) {
109        if let Some(conn) = self.conn.as_mut() {
110            conn.close()
111        }
112    }
113
114    /// Attempts to take the IO, returning an error if it is not available.
115    pub fn try_take(&mut self) -> Result<Io> {
116        let conn = self.conn.take().ok_or_else(|| {
117            Error::io().with_msg("failed to take the session io, it was already taken")
118        })?;
119
120        match conn.try_into_io() {
121            Err(conn) => {
122                self.conn = Some(conn);
123                Err(Error::io()
124                    .with_msg("failed to take the session io, session was not completed yet"))
125            }
126            Ok(conn) => Ok(conn),
127        }
128    }
129
130    /// Polls the session.
131    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
132        self.conn
133            .as_mut()
134            .ok_or_else(|| {
135                Error::io()
136                    .with_msg("failed to poll the session connection because it has been taken")
137            })?
138            .poll(cx)
139            .map_err(|e| {
140                Error::io()
141                    .with_msg("error occurred while polling the session connection")
142                    .with_source(e)
143            })
144    }
145
146    /// Splits the session into a driver and handle.
147    ///
148    /// The driver must be polled to make progress. The handle is used
149    /// for creating provers/verifiers and closing the session.
150    pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
151        let should_close = Arc::new(AtomicBool::new(false));
152        let waker = Arc::new(Mutex::new(None::<Waker>));
153
154        (
155            SessionDriver {
156                conn: self.conn,
157                should_close: should_close.clone(),
158                waker: waker.clone(),
159            },
160            SessionHandle {
161                executor: self.executor,
162                should_close,
163                waker,
164                handle: self.handle,
165            },
166        )
167    }
168}
169
170impl<Io> Future for Session<Io>
171where
172    Io: AsyncRead + AsyncWrite + Unpin,
173{
174    type Output = Result<()>;
175
176    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177        Session::poll(&mut (*self), cx)
178    }
179}
180
181/// The polling half of a split session.
182///
183/// Must be polled continuously to drive the session. Returns the underlying
184/// IO when the session closes.
185#[must_use = "driver must be polled to make progress"]
186pub struct SessionDriver<Io> {
187    conn: Option<Connection<Io>>,
188    should_close: Arc<AtomicBool>,
189    waker: Arc<Mutex<Option<Waker>>>,
190}
191
192impl<Io> SessionDriver<Io>
193where
194    Io: AsyncRead + AsyncWrite + Unpin,
195{
196    /// Polls the driver.
197    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
198        // Store the waker so the handle can wake us when close() is called.
199        {
200            let mut waker_guard = self.waker.lock().unwrap();
201            *waker_guard = Some(cx.waker().clone());
202        }
203
204        let conn = self
205            .conn
206            .as_mut()
207            .ok_or_else(|| Error::io().with_msg("session driver already completed"))?;
208
209        if self.should_close.load(Ordering::Acquire) {
210            conn.close();
211        }
212
213        match conn.poll(cx) {
214            Poll::Ready(Ok(())) => {}
215            Poll::Ready(Err(e)) => {
216                return Poll::Ready(Err(Error::io()
217                    .with_msg("error polling session connection")
218                    .with_source(e)));
219            }
220            Poll::Pending => return Poll::Pending,
221        }
222
223        let conn = self.conn.take().unwrap();
224        Poll::Ready(
225            conn.try_into_io()
226                .map_err(|_| Error::io().with_msg("failed to take session io")),
227        )
228    }
229}
230
231impl<Io> Future for SessionDriver<Io>
232where
233    Io: AsyncRead + AsyncWrite + Unpin,
234{
235    type Output = Result<Io>;
236
237    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238        SessionDriver::poll(&mut *self, cx)
239    }
240}
241
242/// The control half of a split session.
243///
244/// Used to create provers/verifiers and control the session lifecycle.
245pub struct SessionHandle {
246    executor: Executor,
247    should_close: Arc<AtomicBool>,
248    waker: Arc<Mutex<Option<Waker>>>,
249    handle: Handle,
250}
251
252impl SessionHandle {
253    /// Creates a new prover.
254    pub fn new_prover(
255        &mut self,
256        config: ProverConfig,
257    ) -> Result<Prover<prover_state::Initialized>> {
258        let ctx = self.executor.new_context().map_err(|e| {
259            Error::internal()
260                .with_msg("failed to create new prover")
261                .with_source(e)
262        })?;
263
264        Ok(Prover::new(ctx, self.handle.clone(), config))
265    }
266
267    /// Creates a new verifier.
268    pub fn new_verifier(
269        &mut self,
270        config: VerifierConfig,
271    ) -> Result<Verifier<verifier_state::Initialized>> {
272        let ctx = self.executor.new_context().map_err(|e| {
273            Error::internal()
274                .with_msg("failed to create new verifier")
275                .with_source(e)
276        })?;
277
278        Ok(Verifier::new(ctx, self.handle.clone(), config))
279    }
280
281    /// Signals the session to close.
282    ///
283    /// The driver must continue to be polled until it completes.
284    pub fn close(&self) {
285        self.should_close.store(true, Ordering::Release);
286        if let Some(waker) = self.waker.lock().unwrap().take() {
287            waker.wake();
288        }
289    }
290}
291
292/// Multiplexer controller providing streams.
293struct MuxHandle {
294    handle: Handle,
295}
296
297impl std::fmt::Debug for MuxHandle {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        f.debug_struct("MuxHandle").finish_non_exhaustive()
300    }
301}
302
303impl Mux for MuxHandle {
304    fn open(&self, id: &[u8]) -> Result<Io, std::io::Error> {
305        let stream = self.handle.new_stream(id).map_err(std::io::Error::other)?;
306        let io = Io::from_io(stream);
307
308        Ok(io)
309    }
310}
311
312/// Builds a work-stealing executor with the given muxer.
313fn build_executor(mux: MuxHandle) -> Executor {
314    #[cfg(all(feature = "web", target_arch = "wasm32"))]
315    let cores = web_spawn::available_parallelism().map(|n| n.get());
316
317    #[cfg(not(all(feature = "web", target_arch = "wasm32")))]
318    let cores = std::thread::available_parallelism().map(|n| n.get());
319
320    let builder = Executor::builder().num_threads(cores.unwrap_or(8));
321
322    #[cfg(all(feature = "web", target_arch = "wasm32"))]
323    let builder = builder.spawn(|f| {
324        let _ = web_spawn::spawn(f);
325        Ok(())
326    });
327
328    builder.build(mux)
329}