Skip to content

Commit

Permalink
Async tls client fix. New server fixture and tests. Don't rely on clo…
Browse files Browse the repository at this point in the history
…se_notify.
  • Loading branch information
themighty1 authored and themighty1 committed Aug 14, 2023
1 parent d1e5678 commit 05c551f
Show file tree
Hide file tree
Showing 9 changed files with 593 additions and 33 deletions.
3 changes: 3 additions & 0 deletions components/tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,8 @@ thiserror = "1"
log = "0.4"
env_logger = "0.10"

# testing
rstest = "0.12"

# misc
derive_builder = "0.12"
1 change: 1 addition & 0 deletions components/tls/tls-client-async/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ tokio = { workspace = true, features = [
webpki-roots.workspace = true
hyper = { workspace = true, features = ["client", "http1"] }
tls-server-fixture = { path = "../tls-server-fixture" }
rstest = { workspace = true }
33 changes: 19 additions & 14 deletions components/tls/tls-client-async/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
trace!("received {} tls bytes from server", received);

// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
while processed < received {
loop {
processed += client.read_tls(&mut &rx_tls_buf[processed..received])?;
match client.process_new_packets().await {
Ok(_) => {}
Expand All @@ -123,8 +124,16 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
}
}

debug_assert!(processed <= received);

if processed == received {
break;
}
}

// by convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
Expand All @@ -151,7 +160,6 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(

#[cfg(feature = "tracing")]
trace!("sending close_notify to server");

client.send_close_notify().await?;

// Flush all remaining plaintext
Expand All @@ -168,7 +176,7 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(

#[cfg(feature = "tracing")]
debug!("client closed connection");
}
},
}

while client.wants_write() && !client_closed {
Expand All @@ -189,18 +197,19 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if server_closed {
#[cfg(feature = "tracing")]
debug!("server closed, no more data to read");
debug!("server closed without close_notify, no more data to read");

// We didn't get Ok(0) to indicate a clean closure, yet the
// server has already closed. We do not treat this as an error.
break 'outer;
} else {
break;
}
}
// Some servers will not send a close_notify, in which case we need to
// error because we can't reveal the MAC key to the Notary.
// Some servers will not send a close_notify but we do not treat this as
// an error.
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
#[cfg(feature = "tracing")]
error!("server did not send close_notify");
return Err(e)?;
break 'outer;
}
Err(e) => return Err(e)?,
};
Expand All @@ -215,14 +224,10 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
.await;
} else {
#[cfg(feature = "tracing")]
debug!("server closed, no more data to read");
debug!("server closed cleanly, no more data to read");
break 'outer;
}
}

if client_closed && server_closed {
break;
}
}

#[cfg(feature = "tracing")]
Expand Down
Loading

0 comments on commit 05c551f

Please sign in to comment.