// tlsn_wasm/io.rs

//! IO adapters for WASM.
//!
//! This module provides adapters that bridge JavaScript IO streams to Rust's
//! async IO traits (`futures::AsyncRead` / `futures::AsyncWrite`).

6use std::{
7    collections::VecDeque,
8    pin::Pin,
9    sync::{Arc, Mutex, MutexGuard, PoisonError},
10    task::{Context, Poll, Waker},
11};
12
13use futures::{AsyncRead, AsyncWrite, Future};
14use js_sys::{Promise, Uint8Array};
15use wasm_bindgen::prelude::*;
16use wasm_bindgen_futures::JsFuture;
17
18/// JavaScript interface for IO channels.
19///
20/// This is the interface that JavaScript objects must implement to be used
21/// as IO streams with the SDK.
22#[wasm_bindgen]
23extern "C" {
24    /// An IO channel from JavaScript.
25    #[wasm_bindgen(typescript_type = "IoChannel")]
26    pub type JsIo;
27
28    /// Reads bytes from the stream.
29    ///
30    /// Returns a Promise that resolves to a Uint8Array, or null if EOF.
31    #[wasm_bindgen(method, catch)]
32    pub fn read(this: &JsIo) -> Result<Promise, JsValue>;
33
34    /// Writes bytes to the stream.
35    ///
36    /// Returns a Promise that resolves when the write is complete.
37    #[wasm_bindgen(method, catch)]
38    pub fn write(this: &JsIo, data: &Uint8Array) -> Result<Promise, JsValue>;
39
40    /// Closes the stream.
41    ///
42    /// Returns a Promise that resolves when the stream is closed.
43    #[wasm_bindgen(method, catch)]
44    pub fn close(this: &JsIo) -> Result<Promise, JsValue>;
45}
46
47/// Internal state for the adapter.
48struct AdapterState {
49    /// Buffered data from reads.
50    read_buffer: VecDeque<u8>,
51    /// Whether we've seen EOF.
52    eof: bool,
53    /// Pending read future.
54    pending_read: Option<JsFuture>,
55    /// Waker for when data becomes available.
56    read_waker: Option<Waker>,
57    /// Whether the stream is closed.
58    closed: bool,
59    /// Any error that occurred.
60    error: Option<String>,
61}
62
63/// Adapter that wraps a JavaScript IoChannel object.
64///
65/// This adapter implements `AsyncRead` and `AsyncWrite` by calling the
66/// JavaScript methods on the underlying object.
67pub(crate) struct JsIoAdapter {
68    inner: JsIo,
69    state: Arc<Mutex<AdapterState>>,
70}
71
72impl JsIoAdapter {
73    fn lock_state(&self) -> std::io::Result<MutexGuard<'_, AdapterState>> {
74        self.state.lock().map_err(|e: PoisonError<_>| {
75            std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
76        })
77    }
78
79    /// Creates a new adapter wrapping the given JavaScript IO object.
80    pub(crate) fn new(js_io: JsIo) -> Self {
81        Self {
82            inner: js_io,
83            state: Arc::new(Mutex::new(AdapterState {
84                read_buffer: VecDeque::new(),
85                eof: false,
86                pending_read: None,
87                read_waker: None,
88                closed: false,
89                error: None,
90            })),
91        }
92    }
93}
94
95impl AsyncRead for JsIoAdapter {
96    fn poll_read(
97        self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99        buf: &mut [u8],
100    ) -> Poll<std::io::Result<usize>> {
101        let this = self.get_mut();
102        let mut state = match this.lock_state() {
103            Ok(guard) => guard,
104            Err(e) => return Poll::Ready(Err(e)),
105        };
106
107        // Check for errors.
108        if let Some(ref err) = state.error {
109            return Poll::Ready(Err(std::io::Error::new(
110                std::io::ErrorKind::Other,
111                err.clone(),
112            )));
113        }
114
115        // If we have buffered data, return it.
116        if !state.read_buffer.is_empty() {
117            let to_read = std::cmp::min(buf.len(), state.read_buffer.len());
118            for (i, byte) in state.read_buffer.drain(..to_read).enumerate() {
119                buf[i] = byte;
120            }
121            return Poll::Ready(Ok(to_read));
122        }
123
124        // If we've seen EOF, return 0.
125        if state.eof {
126            return Poll::Ready(Ok(0));
127        }
128
129        // Store waker for later.
130        state.read_waker = Some(cx.waker().clone());
131
132        // If there's no pending read, start one.
133        if state.pending_read.is_none() {
134            match this.inner.read() {
135                Ok(promise) => {
136                    state.pending_read = Some(JsFuture::from(promise));
137                }
138                Err(e) => {
139                    let err_msg = format!("read error: {:?}", e);
140                    state.error = Some(err_msg.clone());
141                    return Poll::Ready(Err(std::io::Error::new(
142                        std::io::ErrorKind::Other,
143                        err_msg,
144                    )));
145                }
146            }
147        }
148
149        // Poll the pending read.
150        if let Some(ref mut future) = state.pending_read {
151            // SAFETY: We're inside a WASM context where this is safe.
152            let future = unsafe { Pin::new_unchecked(future) };
153            match future.poll(cx) {
154                Poll::Ready(Ok(value)) => {
155                    state.pending_read = None;
156
157                    // Check if it's null (EOF).
158                    if value.is_null() || value.is_undefined() {
159                        tracing::warn!("JsIo read returned null/undefined (EOF)");
160                        state.eof = true;
161                        return Poll::Ready(Ok(0));
162                    }
163
164                    // Convert to bytes.
165                    let array = Uint8Array::new(&value);
166                    let bytes = array.to_vec();
167
168                    if bytes.is_empty() {
169                        tracing::warn!("JsIo read returned empty array (EOF)");
170                        state.eof = true;
171                        return Poll::Ready(Ok(0));
172                    }
173
174                    // Copy to buffer and return.
175                    let to_read = std::cmp::min(buf.len(), bytes.len());
176                    buf[..to_read].copy_from_slice(&bytes[..to_read]);
177
178                    // Buffer any remaining bytes.
179                    if bytes.len() > to_read {
180                        state.read_buffer.extend(&bytes[to_read..]);
181                    }
182
183                    Poll::Ready(Ok(to_read))
184                }
185                Poll::Ready(Err(e)) => {
186                    state.pending_read = None;
187                    let err_msg = format!("read error: {:?}", e);
188                    state.error = Some(err_msg.clone());
189                    Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err_msg)))
190                }
191                Poll::Pending => Poll::Pending,
192            }
193        } else {
194            Poll::Pending
195        }
196    }
197}
198
199impl AsyncWrite for JsIoAdapter {
200    fn poll_write(
201        self: Pin<&mut Self>,
202        _cx: &mut Context<'_>,
203        buf: &[u8],
204    ) -> Poll<std::io::Result<usize>> {
205        let this = self.get_mut();
206        let state = match this.lock_state() {
207            Ok(guard) => guard,
208            Err(e) => return Poll::Ready(Err(e)),
209        };
210
211        // Check for errors.
212        if let Some(ref err) = state.error {
213            return Poll::Ready(Err(std::io::Error::new(
214                std::io::ErrorKind::Other,
215                err.clone(),
216            )));
217        }
218
219        if state.closed {
220            return Poll::Ready(Err(std::io::Error::new(
221                std::io::ErrorKind::BrokenPipe,
222                "stream closed",
223            )));
224        }
225
226        // Create Uint8Array from buffer.
227        let array = Uint8Array::from(buf);
228
229        // Fire-and-forget write: common pattern for WASM IO.
230        // We don't wait for the Promise to resolve to avoid backpressure.
231        match this.inner.write(&array) {
232            Ok(_promise) => {
233                // Return success immediately without waiting for Promise.
234                Poll::Ready(Ok(buf.len()))
235            }
236            Err(e) => {
237                let err_msg = format!("write error: {:?}", e);
238                tracing::error!("JsIo write failed: {}", err_msg);
239                Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err_msg)))
240            }
241        }
242    }
243
244    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
245        // JS streams typically auto-flush.
246        Poll::Ready(Ok(()))
247    }
248
249    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
250        let this = self.get_mut();
251        let mut state = match this.lock_state() {
252            Ok(guard) => guard,
253            Err(e) => return Poll::Ready(Err(e)),
254        };
255
256        if state.closed {
257            return Poll::Ready(Ok(()));
258        }
259
260        // Fire-and-forget close to avoid blocking.
261        match this.inner.close() {
262            Ok(_promise) => {
263                state.closed = true;
264                Poll::Ready(Ok(()))
265            }
266            Err(e) => {
267                let err_msg = format!("close error: {:?}", e);
268                Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err_msg)))
269            }
270        }
271    }
272}
273
274// SAFETY: `JsIo` (a JS handle via wasm_bindgen) is `!Send`. This is safe
275// because `JsIoAdapter` is only used from the main WASM async executor thread.
276// While the extension does use multi-threaded WASM (SharedArrayBuffer + rayon
277// via web-spawn), the rayon worker threads only perform parallel computation
278// (mpz/garble) on shared memory and never access JS handles or this adapter.
279unsafe impl Send for JsIoAdapter {}