diff --git a/tlsn/examples/Cargo.toml b/tlsn/examples/Cargo.toml index ff8f3a8e9c..2c7e572b97 100644 --- a/tlsn/examples/Cargo.toml +++ b/tlsn/examples/Cargo.toml @@ -39,6 +39,7 @@ rustls = { version = "0.21" } rustls-pemfile = { version = "1.0.2" } tokio-rustls = { version = "0.24.1" } dotenv = "0.15.0" +httparse = "1" [[example]] name = "twitter_dm" diff --git a/tlsn/examples/twitter_dm.rs b/tlsn/examples/twitter_dm.rs index 297fd2b4ad..2e30b14fab 100644 --- a/tlsn/examples/twitter_dm.rs +++ b/tlsn/examples/twitter_dm.rs @@ -1,6 +1,7 @@ /// This prover implementation talks to the notary server implemented in https://github.com/tlsnotary/notary-server, instead of the simple_notary.rs in this example directory use eyre::Result; use futures::AsyncWriteExt; +use httparse::EMPTY_HEADER; use hyper::{body::to_bytes, client::conn::Parts, Body, Request, StatusCode}; use rustls::{Certificate, ClientConfig, RootCertStore}; use serde::{Deserialize, Serialize}; @@ -12,6 +13,7 @@ use std::{ ops::Range, sync::Arc, }; +use tlsn_core::span::{http::HttpSpanner, invert_ranges, SpanCommit, SpanError}; use tokio::{fs::File, io::AsyncWriteExt as _}; use tokio_rustls::TlsConnector; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; @@ -117,7 +119,7 @@ async fn main() { let request = Request::builder() .uri(format!("https://{NOTARY_DOMAIN}:{NOTARY_PORT}/session")) .method("POST") - .header("Host", NOTARY_DOMAIN.clone()) + .header("Host", NOTARY_DOMAIN) // Need to specify application/json for axum to parse it as json .header("Content-Type", "application/json") .body(Body::from(payload)) @@ -248,29 +250,9 @@ async fn main() { client_socket.close().await.unwrap(); // The Prover task should be done now, so we can grab it. - let mut prover = prover_task.await.unwrap().unwrap(); - - // Identify the ranges in the transcript that contain secrets - let (public_ranges, private_ranges) = find_ranges( - prover.sent_transcript().data(), - &[ - access_token.as_bytes(), - auth_token.as_bytes(), - csrf_token.as_bytes(), - ], - ); - - // Commit to the outbound transcript, isolating the data that contain secrets - for range in public_ranges.iter().chain(private_ranges.iter()) { - prover.add_commitment_sent(range.clone()).unwrap(); - } - - // Commit to the full received transcript in one shot, as we don't need to redact anything - let recv_len = prover.recv_transcript().data().len(); - prover.add_commitment_recv(0..recv_len as u32).unwrap(); + let prover = prover_task.await.unwrap().unwrap(); - // Finalize, returning the notarized session - let notarized_session = prover.finalize().await.unwrap(); + let notarized_session = prover.finalize(Box::new(TwitterSpanner)).await.unwrap(); debug!("Notarization complete!"); @@ -285,40 +267,38 @@ async fn main() { .unwrap(); } -/// Find the ranges of the public and private parts of a sequence. -/// -/// Returns a tuple of `(public, private)` ranges. -fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec>, Vec>) { - let mut private_ranges = Vec::new(); - for s in sub_seq { - for (idx, w) in seq.windows(s.len()).enumerate() { - if w == *s { - private_ranges.push(idx as u32..(idx + w.len()) as u32); - } - } - } +/// Read a PEM-formatted file and return its buffer reader +async fn read_pem_file(file_path: &str) -> Result> { + let key_file = File::open(file_path).await?.into_std().await; + Ok(BufReader::new(key_file)) +} - let mut sorted_ranges = private_ranges.clone(); - sorted_ranges.sort_by_key(|r| r.start); +struct TwitterSpanner; - let mut public_ranges = Vec::new(); - let mut last_end = 0; - for r in sorted_ranges { - if r.start > last_end { - public_ranges.push(last_end..r.start); - } - last_end = r.end; - } +impl SpanCommit for TwitterSpanner { + fn span_request(&mut self, request: &[u8]) -> Result>, SpanError> { + let mut headers = vec![EMPTY_HEADER; 12]; + let mut http_spanner = HttpSpanner::new(); - if last_end < seq.len() as u32 { - public_ranges.push(last_end..seq.len() as u32); - } + http_spanner.parse_request(&mut headers, request).unwrap(); - (public_ranges, private_ranges) -} + let cookie = http_spanner + .header_value_span_request("Cookie", request) + .unwrap(); + let authorization = http_spanner + .header_value_span_request("Authorization", request) + .unwrap(); + let csrf = http_spanner + .header_value_span_request("X-Csrf-Token", request) + .unwrap(); -/// Read a PEM-formatted file and return its buffer reader -async fn read_pem_file(file_path: &str) -> Result> { - let key_file = File::open(file_path).await?.into_std().await; - Ok(BufReader::new(key_file)) + invert_ranges(vec![cookie, authorization, csrf], request.len()) + } + + fn span_response(&mut self, response: &[u8]) -> Result>, SpanError> { + Ok(vec![Range { + start: 0, + end: response.len(), + }]) + } } diff --git a/tlsn/tests-integration/Cargo.toml b/tlsn/tests-integration/Cargo.toml index 2c6f63d81c..2444fd9813 100644 --- a/tlsn/tests-integration/Cargo.toml +++ b/tlsn/tests-integration/Cargo.toml @@ -13,6 +13,7 @@ publish = false tlsn-tls-core.workspace = true tlsn-prover.workspace = true tlsn-notary.workspace = true +tlsn-core.workspace = true tls-server-fixture.workspace = true p256 = { workspace = true, features = ["ecdsa"] } diff --git a/tlsn/tests-integration/tests/test.rs b/tlsn/tests-integration/tests/test.rs index 9a34927066..b79365a20e 100644 --- a/tlsn/tests-integration/tests/test.rs +++ b/tlsn/tests-integration/tests/test.rs @@ -1,6 +1,7 @@ use futures::AsyncWriteExt; use hyper::{body::to_bytes, Body, Request, StatusCode}; use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_core::span::TotalSpanner; use tlsn_notary::{bind_notary, NotaryConfig}; use tlsn_prover::{bind_prover, ProverConfig}; use tokio::io::{AsyncRead, AsyncWrite}; @@ -76,15 +77,9 @@ async fn prover(notary_socke client_socket.close().await.unwrap(); - let mut prover = prover_task.await.unwrap().unwrap(); + let prover = prover_task.await.unwrap().unwrap(); - let sent_len = prover.sent_transcript().data().len(); - let recv_len = prover.recv_transcript().data().len(); - - prover.add_commitment_sent(0..sent_len as u32).unwrap(); - prover.add_commitment_recv(0..recv_len as u32).unwrap(); - - _ = prover.finalize().await.unwrap(); + _ = prover.finalize(Box::new(TotalSpanner)).await.unwrap(); } #[instrument(skip(socket))] diff --git a/tlsn/tlsn-core/Cargo.toml b/tlsn/tlsn-core/Cargo.toml index 64aacec495..b46b239981 100644 --- a/tlsn/tlsn-core/Cargo.toml +++ b/tlsn/tlsn-core/Cargo.toml @@ -30,6 +30,9 @@ rs_merkle.workspace = true rstest = { workspace = true, optional = true} hex = { workspace = true, optional = true} tracing = { workspace = true, optional = true } +httparse = "1" +pest = "2" +pest_derive = "2" [dev-dependencies] rstest.workspace = true diff --git a/tlsn/tlsn-core/src/lib.rs b/tlsn/tlsn-core/src/lib.rs index 30e6766d28..e541e6621d 100644 --- a/tlsn/tlsn-core/src/lib.rs +++ b/tlsn/tlsn-core/src/lib.rs @@ -2,7 +2,6 @@ #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] -#![forbid(unsafe_code)] pub mod commitment; mod error; @@ -15,6 +14,7 @@ pub mod merkle; pub mod msg; mod session; pub mod signature; +pub mod span; pub mod substrings; pub mod transcript; mod utils; diff --git a/tlsn/tlsn-core/src/span.rs b/tlsn/tlsn-core/src/span.rs new file mode 100644 index 0000000000..ee52353a44 --- /dev/null +++ b/tlsn/tlsn-core/src/span.rs @@ -0,0 +1,156 @@ +//! This module provides tooling to create spanning information for the [transcripts](crate::transcript::Transcript). +//! +//! When creating a [NotarizedSession](crate::NotarizedSession), the +//! [SessionData](crate::SessionData) inside contains the plaintext of the request and response. +//! The prover can decide to only commit to a subset of these bytes in order to withhold content +//! from the verifier. Consumers of this crate can implement the [SpanCommit] trait to come up with +//! their own approach for identifying the byte ranges which shall be committed to. + +use std::ops::Range; + +/// A trait for identifying byte ranges in the request and response for which commitments will be +/// created +pub trait SpanCommit { + /// Identify byte ranges in the request to commit to + fn span_request(&mut self, request: &[u8]) -> Result>, SpanError>; + /// Identify byte ranges in the response to commit to + fn span_response(&mut self, response: &[u8]) -> Result>, SpanError>; +} + +/// A Spanner that commits to the entire request and response +pub struct TotalSpanner; + +impl SpanCommit for TotalSpanner { + fn span_request(&mut self, request: &[u8]) -> Result>, SpanError> { + Ok(vec![Range { + start: 0, + end: request.len(), + }]) + } + + fn span_response(&mut self, response: &[u8]) -> Result>, SpanError> { + Ok(vec![Range { + start: 0, + end: response.len(), + }]) + } +} + +/// Inverts a set of ranges, i.e. returns the complement of the ranges +pub fn invert_ranges( + ranges: Vec>, + len: usize, +) -> Result>, SpanError> { + for (k, range) in ranges.iter().enumerate() { + // Check that there is no invalid or empty range + if range.start >= range.end { + return Err(SpanError::InvalidRange); + } + + // Check that ranges are not out of bounds + if range.start >= len || range.end > len { + return Err(SpanError::InvalidRange); + } + + // Check that ranges are not overlapping + if ranges + .iter() + .enumerate() + .any(|(l, r)| k != l && r.start < range.end && r.end > range.start) + { + return Err(SpanError::InvalidRange); + } + } + + // Now invert ranges + let mut inverted = vec![Range { start: 0, end: len }]; + + for range in ranges.iter() { + let inv = inverted + .iter_mut() + .find(|inv| range.start >= inv.start && range.end <= inv.end) + .expect("Should have found range to invert"); + + let original_end = inv.end; + inv.end = range.start; + + inverted.push(Range { + start: range.end, + end: original_end, + }); + } + + // Remove empty ranges + inverted.retain(|r| r.start != r.end); + + Ok(inverted) +} + +/// An error that can occur during span creation +#[allow(missing_docs)] +#[derive(Debug, thiserror::Error)] +pub enum SpanError { + #[error("Error during parsing")] + ParseError, + #[error("Found invalid ranges")] + InvalidRange, + #[error("Custom error: {0}")] + Custom(String), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_invert_ranges_errors() { + let empty_range = Range { start: 0, end: 0 }; + let invalid_range = Range { start: 2, end: 1 }; + let out_of_bounds = Range { start: 4, end: 11 }; + + let ranges = vec![empty_range, invalid_range, out_of_bounds]; + + for range in ranges { + assert!(invert_ranges(vec![range], 10).is_err()); + } + } + + #[test] + fn test_invert_ranges_overlapping() { + let overlapping1 = vec![Range { start: 2, end: 5 }, Range { start: 4, end: 7 }]; + let overlapping2 = vec![Range { start: 2, end: 5 }, Range { start: 1, end: 4 }]; + let overlapping3 = vec![Range { start: 2, end: 5 }, Range { start: 3, end: 4 }]; + let overlapping4 = vec![Range { start: 2, end: 5 }, Range { start: 2, end: 5 }]; + + // this should not be an error + let ok1 = vec![Range { start: 2, end: 5 }, Range { start: 5, end: 8 }]; + let ok2 = vec![Range { start: 2, end: 5 }, Range { start: 7, end: 10 }]; + + let overlap = vec![overlapping1, overlapping2, overlapping3, overlapping4]; + let ok = vec![ok1, ok2]; + + for range in overlap { + assert!(invert_ranges(range, 10).is_err()); + } + + for range in ok { + assert!(invert_ranges(range, 10).is_ok()); + } + } + + #[test] + fn test_invert_ranges() { + let len = 20; + + let ranges = vec![ + Range { start: 0, end: 5 }, + Range { start: 5, end: 10 }, + Range { start: 12, end: 16 }, + Range { start: 18, end: 20 }, + ]; + + let expected = vec![Range { start: 10, end: 12 }, Range { start: 16, end: 18 }]; + + assert_eq!(invert_ranges(ranges, len).unwrap(), expected); + } +} diff --git a/tlsn/tlsn-prover/src/error.rs b/tlsn/tlsn-prover/src/error.rs index bf963172cb..8b5ea25231 100644 --- a/tlsn/tlsn-prover/src/error.rs +++ b/tlsn/tlsn-prover/src/error.rs @@ -22,6 +22,8 @@ pub enum ProverError { ServerNoCloseNotify, #[error(transparent)] CommitmentError(#[from] CommitmentError), + #[error(transparent)] + SpanError(#[from] tlsn_core::span::SpanError), } impl From for ProverError { diff --git a/tlsn/tlsn-prover/src/lib.rs b/tlsn/tlsn-prover/src/lib.rs index adab92bd68..2dea30ba03 100644 --- a/tlsn/tlsn-prover/src/lib.rs +++ b/tlsn/tlsn-prover/src/lib.rs @@ -36,6 +36,7 @@ use tlsn_core::{ commitment::Blake3, merkle::MerkleTree, msg::{SignedSessionHeader, TlsnMessage}, + span::SpanCommit, transcript::Transcript, Direction, NotarizedSession, SessionData, SubstringsCommitment, SubstringsCommitmentSet, }; @@ -260,16 +261,6 @@ where &self.state.transcript_rx } - /// Add a commitment to the sent requests - pub fn add_commitment_sent(&mut self, range: Range) -> Result<(), ProverError> { - self.add_commitment(range, Direction::Sent) - } - - /// Add a commitment to the received responses - pub fn add_commitment_recv(&mut self, range: Range) -> Result<(), ProverError> { - self.add_commitment(range, Direction::Received) - } - #[cfg_attr( feature = "tracing", instrument(level = "debug", skip(self, range), err) @@ -313,7 +304,18 @@ where /// Finalize the notarization returning a [`NotarizedSession`] #[cfg_attr(feature = "tracing", instrument(level = "info", skip(self), err))] - pub async fn finalize(self) -> Result { + pub async fn finalize( + mut self, + mut spanner: Box, + ) -> Result { + // Add commitments identified by the spanner + for range in spanner.span_request(self.state.transcript_tx.data())? { + self.add_commitment(range.start as u32..range.end as u32, Direction::Sent)?; + } + for range in spanner.span_response(self.state.transcript_rx.data())? { + self.add_commitment(range.start as u32..range.end as u32, Direction::Received)?; + } + let Notarize { notary_mux: mut mux, mut vm,