1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
//! This module handles the proving phase of the prover.
//!
//! Here the prover deals with a verifier directly, so there is no notary involved. Instead
//! the verifier directly verifies parts of the transcript.

use super::{state::Prove as ProveState, Prover, ProverError};
use mpz_garble::{Memory, Prove};
use mpz_ot::VerifiableOTReceiver;
use serio::SinkExt as _;
use tlsn_core::{proof::SessionInfo, transcript::get_value_ids, Direction, ServerName, Transcript};
use utils::range::{RangeSet, RangeUnion};

use tracing::{info, instrument};

impl Prover<ProveState> {
    /// Returns the transcript of the sent requests.
    pub fn sent_transcript(&self) -> &Transcript {
        &self.state.transcript_tx
    }

    /// Returns the transcript of the received responses.
    pub fn recv_transcript(&self) -> &Transcript {
        &self.state.transcript_rx
    }

    /// Marks parts of a transcript to be revealed to the verifier.
    ///
    /// Ranges accumulate across calls: they are unioned with any ranges previously
    /// marked for the same direction. Nothing is sent here; the marked ranges are
    /// opened to the verifier when [`Prover::prove`] is called.
    ///
    /// # Arguments
    ///
    /// * `ranges` - The byte ranges of the transcript to reveal.
    /// * `direction` - Which transcript (sent or received) the ranges refer to.
    ///
    /// # Errors
    ///
    /// Returns [`ProverError::InvalidRange`] if any range extends past the end of the
    /// selected transcript. In that case no ranges are recorded.
    pub fn reveal(
        &mut self,
        ranges: impl Into<RangeSet<usize>>,
        direction: Direction,
    ) -> Result<(), ProverError> {
        let range_set = ranges.into();

        let transcript = match direction {
            Direction::Sent => &self.state.transcript_tx,
            Direction::Received => &self.state.transcript_rx,
        };

        // `prove` slices the transcript data with exactly these ranges, so every
        // range's exclusive end must not exceed the transcript length. Checking each
        // `end` directly does not depend on whether `RangeSet::max` reports an
        // inclusive or an exclusive bound, and turns a later slicing panic into an
        // error here.
        let len = transcript.data().len();
        if range_set.iter_ranges().any(|range| range.end > len) {
            return Err(ProverError::InvalidRange);
        }

        let ids = match direction {
            Direction::Sent => &mut self.state.proving_info.sent_ids,
            Direction::Received => &mut self.state.proving_info.recv_ids,
        };
        *ids = ids.union(&range_set);

        Ok(())
    }

    /// Proves the transcript ranges previously marked with [`Prover::reveal`].
    ///
    /// The pending reveal set is taken out of the prover state up front, so after
    /// this call (successful or not) no ranges remain marked.
    ///
    /// The prover sends the proving info (including the revealed cleartext) to the
    /// verifier, then proves the corresponding values in the VM.
    ///
    /// # Errors
    ///
    /// Propagates errors from sending the proving info, from the VM proof, and from
    /// the multiplexer future.
    ///
    /// # Panics
    ///
    /// Panics if a marked transcript byte has no corresponding value in VM memory.
    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
    pub async fn prove(&mut self) -> Result<(), ProverError> {
        let mut proving_info = std::mem::take(&mut self.state.proving_info);

        self.state
            .mux_fut
            .poll_with(async {
                // Now prove the transcript parts which have been marked for reveal
                let sent_value_ids = proving_info
                    .sent_ids
                    .iter_ranges()
                    .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::<Vec<String>>());
                let recv_value_ids = proving_info.recv_ids.iter_ranges().map(|r| {
                    get_value_ids(&r.into(), Direction::Received).collect::<Vec<String>>()
                });

                // One VM array reference per revealed range, sent ranges first, then
                // received ranges. This order matches the order of `cleartext` below.
                let value_refs = sent_value_ids
                    .chain(recv_value_ids)
                    .map(|ids| {
                        let inner_refs = ids
                            .iter()
                            .map(|id| {
                                self.state
                                    .vm
                                    .get_value(id.as_str())
                                    .expect("Byte should be in VM memory")
                            })
                            .collect::<Vec<_>>();

                        self.state
                            .vm
                            .array_from_values(inner_refs.as_slice())
                            .expect("Byte should be in VM Memory")
                    })
                    .collect::<Vec<_>>();

                // Extract cleartext we want to reveal from transcripts, in the same
                // sent-then-received order used for `value_refs`.
                let mut cleartext =
                    Vec::with_capacity(proving_info.sent_ids.len() + proving_info.recv_ids.len());
                proving_info
                    .sent_ids
                    .iter_ranges()
                    .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_tx.data()[r]));
                proving_info
                    .recv_ids
                    .iter_ranges()
                    .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_rx.data()[r]));
                proving_info.cleartext = cleartext;

                // Send the proving info to the verifier
                self.state.io.send(proving_info).await?;

                info!("Sent proving info to verifier");

                // Prove the revealed transcript parts
                self.state.vm.prove(value_refs.as_slice()).await?;

                info!("Successfully proved cleartext");

                Ok::<_, ProverError>(())
            })
            .await?;

        Ok(())
    }

    /// Finalizes the proving phase, consuming the prover.
    ///
    /// This accepts the OT reveal and finalizes the VM. It then sends session info
    /// (the configured server DNS name and the handshake decommitment) to the
    /// verifier. Finally it closes the multiplexer if it has not already completed
    /// and waits for it to finish.
    ///
    /// # Errors
    ///
    /// Propagates errors from the OT reveal, VM finalization (wrapped as
    /// [`ProverError::MpcError`]), sending the session info, and the multiplexer.
    ///
    /// # Panics
    ///
    /// Panics if VM finalization does not return an encoder seed.
    #[instrument(parent = &self.span, level = "debug", skip_all, err)]
    pub async fn finalize(self) -> Result<(), ProverError> {
        let ProveState {
            mut io,
            mux_ctrl,
            mut mux_fut,
            mut vm,
            mut ot_recv,
            mut ctx,
            handshake_decommitment,
            ..
        } = self.state;

        // Create session info.
        let session_info = SessionInfo {
            server_name: ServerName::Dns(self.config.server_dns().to_string()),
            handshake_decommitment,
        };

        mux_fut
            .poll_with(async move {
                ot_recv.accept_reveal(&mut ctx).await?;

                // The returned encoder seed is not used by this prover; only its
                // presence is asserted.
                _ = vm
                    .finalize()
                    .await
                    .map_err(|e| ProverError::MpcError(Box::new(e)))?
                    .expect("encoder seed returned");

                // Send session_info to the verifier
                io.send(session_info).await?;

                Ok::<_, ProverError>(())
            })
            .await?;

        // Wait for the verifier to correctly close the connection.
        if !mux_fut.is_complete() {
            mux_ctrl.mux().close();
            mux_fut.await?;
        }

        Ok(())
    }
}