Skip to main content

tlsn/
session.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        Arc, Mutex,
6        atomic::{AtomicBool, Ordering},
7    },
8    task::{Context, Poll, Waker},
9};
10
11use futures::{AsyncRead, AsyncWrite};
12use mpz_common::{Executor, io::Io, mux::Mux};
13use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
14use tlsn_mux::{Connection, Handle};
15
16use crate::{
17    Error, Result,
18    prover::{Prover, state as prover_state},
19    verifier::{Verifier, state as verifier_state},
20};
21
22/// A TLSNotary session over a communication channel.
23///
24/// Wraps an async IO stream and provides multiplexing for the protocol. Use
25/// [`new_prover`](Self::new_prover) or [`new_verifier`](Self::new_verifier) to
26/// create protocol participants.
27///
28/// The session must be polled continuously (either directly or via
29/// [`split`](Self::split)) to drive the underlying connection. After the
30/// session closes, the underlying IO can be reclaimed with
31/// [`try_take`](Self::try_take).
32///
33/// **Important**: The order in which provers and verifiers are created must
34/// match on both sides. For example, if the prover side calls `new_prover`
35/// then `new_verifier`, the verifier side must call `new_verifier` then
36/// `new_prover`.
37#[must_use = "session must be polled continuously to make progress, including during closing."]
38pub struct Session<Io> {
39    conn: Option<Connection<Io>>,
40    executor: Executor,
41    handle: Handle,
42}
43
44impl<Io> Session<Io>
45where
46    Io: AsyncRead + AsyncWrite + Unpin,
47{
48    /// Creates a new session over `io`, the prover ↔ verifier channel.
49    ///
50    /// On TCP transports, disable Nagle's algorithm; see
51    /// [Performance](crate#performance).
52    pub fn new(io: Io) -> Self {
53        let mut mux_config = tlsn_mux::Config::default();
54
55        mux_config.set_keep_alive(true);
56        mux_config.set_close_sync(true);
57
58        let conn = tlsn_mux::Connection::new(io, mux_config);
59        let handle = conn.handle().expect("handle should be available");
60        let executor = build_executor(MuxHandle {
61            handle: handle.clone(),
62        });
63
64        Self {
65            conn: Some(conn),
66            executor,
67            handle,
68        }
69    }
70
71    /// Creates a new prover.
72    pub fn new_prover(
73        &mut self,
74        config: ProverConfig,
75    ) -> Result<Prover<prover_state::Initialized>> {
76        let ctx = self.executor.new_context().map_err(|e| {
77            Error::internal()
78                .with_msg("failed to create new prover")
79                .with_source(e)
80        })?;
81
82        Ok(Prover::new(ctx, self.handle.clone(), config))
83    }
84
85    /// Creates a new verifier.
86    pub fn new_verifier(
87        &mut self,
88        config: VerifierConfig,
89    ) -> Result<Verifier<verifier_state::Initialized>> {
90        let ctx = self.executor.new_context().map_err(|e| {
91            Error::internal()
92                .with_msg("failed to create new verifier")
93                .with_source(e)
94        })?;
95
96        Ok(Verifier::new(ctx, self.handle.clone(), config))
97    }
98
99    /// Returns `true` if the session is closed.
100    pub fn is_closed(&self) -> bool {
101        self.conn
102            .as_ref()
103            .map(|mux| mux.is_complete())
104            .unwrap_or_default()
105    }
106
107    /// Closes the session.
108    ///
109    /// This will cause the session to begin closing. Session must continue to
110    /// be polled until completion.
111    pub fn close(&mut self) {
112        if let Some(conn) = self.conn.as_mut() {
113            conn.close()
114        }
115    }
116
117    /// Attempts to take the IO, returning an error if it is not available.
118    pub fn try_take(&mut self) -> Result<Io> {
119        let conn = self.conn.take().ok_or_else(|| {
120            Error::io().with_msg("failed to take the session io, it was already taken")
121        })?;
122
123        match conn.try_into_io() {
124            Err(conn) => {
125                self.conn = Some(conn);
126                Err(Error::io()
127                    .with_msg("failed to take the session io, session was not completed yet"))
128            }
129            Ok(conn) => Ok(conn),
130        }
131    }
132
133    /// Polls the session.
134    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
135        self.conn
136            .as_mut()
137            .ok_or_else(|| {
138                Error::io()
139                    .with_msg("failed to poll the session connection because it has been taken")
140            })?
141            .poll(cx)
142            .map_err(|e| {
143                Error::io()
144                    .with_msg("error occurred while polling the session connection")
145                    .with_source(e)
146            })
147    }
148
149    /// Splits the session into a driver and handle.
150    ///
151    /// The driver must be polled to make progress. The handle is used
152    /// for creating provers/verifiers and closing the session.
153    pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
154        let should_close = Arc::new(AtomicBool::new(false));
155        let waker = Arc::new(Mutex::new(None::<Waker>));
156
157        (
158            SessionDriver {
159                conn: self.conn,
160                should_close: should_close.clone(),
161                waker: waker.clone(),
162            },
163            SessionHandle {
164                executor: self.executor,
165                should_close,
166                waker,
167                handle: self.handle,
168            },
169        )
170    }
171}
172
173impl<Io> Future for Session<Io>
174where
175    Io: AsyncRead + AsyncWrite + Unpin,
176{
177    type Output = Result<()>;
178
179    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180        Session::poll(&mut (*self), cx)
181    }
182}
183
184/// The polling half of a split session.
185///
186/// Must be polled continuously to drive the session. Returns the underlying
187/// IO when the session closes.
188#[must_use = "driver must be polled to make progress"]
189pub struct SessionDriver<Io> {
190    conn: Option<Connection<Io>>,
191    should_close: Arc<AtomicBool>,
192    waker: Arc<Mutex<Option<Waker>>>,
193}
194
195impl<Io> SessionDriver<Io>
196where
197    Io: AsyncRead + AsyncWrite + Unpin,
198{
199    /// Polls the driver.
200    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
201        // Store the waker so the handle can wake us when close() is called.
202        {
203            let mut waker_guard = self.waker.lock().unwrap();
204            *waker_guard = Some(cx.waker().clone());
205        }
206
207        let conn = self
208            .conn
209            .as_mut()
210            .ok_or_else(|| Error::io().with_msg("session driver already completed"))?;
211
212        if self.should_close.load(Ordering::Acquire) {
213            conn.close();
214        }
215
216        match conn.poll(cx) {
217            Poll::Ready(Ok(())) => {}
218            Poll::Ready(Err(e)) => {
219                return Poll::Ready(Err(Error::io()
220                    .with_msg("error polling session connection")
221                    .with_source(e)));
222            }
223            Poll::Pending => return Poll::Pending,
224        }
225
226        let conn = self.conn.take().unwrap();
227        Poll::Ready(
228            conn.try_into_io()
229                .map_err(|_| Error::io().with_msg("failed to take session io")),
230        )
231    }
232}
233
234impl<Io> Future for SessionDriver<Io>
235where
236    Io: AsyncRead + AsyncWrite + Unpin,
237{
238    type Output = Result<Io>;
239
240    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241        SessionDriver::poll(&mut *self, cx)
242    }
243}
244
245/// The control half of a split session.
246///
247/// Used to create provers/verifiers and control the session lifecycle.
248pub struct SessionHandle {
249    executor: Executor,
250    should_close: Arc<AtomicBool>,
251    waker: Arc<Mutex<Option<Waker>>>,
252    handle: Handle,
253}
254
255impl SessionHandle {
256    /// Creates a new prover.
257    pub fn new_prover(
258        &mut self,
259        config: ProverConfig,
260    ) -> Result<Prover<prover_state::Initialized>> {
261        let ctx = self.executor.new_context().map_err(|e| {
262            Error::internal()
263                .with_msg("failed to create new prover")
264                .with_source(e)
265        })?;
266
267        Ok(Prover::new(ctx, self.handle.clone(), config))
268    }
269
270    /// Creates a new verifier.
271    pub fn new_verifier(
272        &mut self,
273        config: VerifierConfig,
274    ) -> Result<Verifier<verifier_state::Initialized>> {
275        let ctx = self.executor.new_context().map_err(|e| {
276            Error::internal()
277                .with_msg("failed to create new verifier")
278                .with_source(e)
279        })?;
280
281        Ok(Verifier::new(ctx, self.handle.clone(), config))
282    }
283
284    /// Signals the session to close.
285    ///
286    /// The driver must continue to be polled until it completes.
287    pub fn close(&self) {
288        self.should_close.store(true, Ordering::Release);
289        if let Some(waker) = self.waker.lock().unwrap().take() {
290            waker.wake();
291        }
292    }
293}
294
295/// Multiplexer controller providing streams.
296struct MuxHandle {
297    handle: Handle,
298}
299
300impl std::fmt::Debug for MuxHandle {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        f.debug_struct("MuxHandle").finish_non_exhaustive()
303    }
304}
305
306impl Mux for MuxHandle {
307    fn open(&self, id: &[u8]) -> Result<Io, std::io::Error> {
308        let stream = self.handle.new_stream(id).map_err(std::io::Error::other)?;
309        let io = Io::from_io(stream);
310
311        Ok(io)
312    }
313}
314
315/// Builds a work-stealing executor with the given muxer.
316fn build_executor(mux: MuxHandle) -> Executor {
317    #[cfg(all(feature = "web", target_arch = "wasm32"))]
318    let cores = web_spawn::available_parallelism().map(|n| n.get());
319
320    #[cfg(not(all(feature = "web", target_arch = "wasm32")))]
321    let cores = std::thread::available_parallelism().map(|n| n.get());
322
323    let builder = Executor::builder().num_threads(cores.unwrap_or(8));
324
325    #[cfg(all(feature = "web", target_arch = "wasm32"))]
326    let builder = builder.spawn(|f| {
327        let _ = web_spawn::spawn(f);
328        Ok(())
329    });
330
331    builder.build(mux)
332}