diff --git a/Cargo.lock b/Cargo.lock index 50885c3e4d922..a73baef1b0246 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2790,9 +2790,6 @@ dependencies = [ "fastcrypto", "futures", "http 1.1.0", - "hyper 1.4.1", - "hyper-rustls 0.27.2", - "hyper-util", "itertools 0.13.0", "mockall", "mysten-common", @@ -2809,6 +2806,7 @@ dependencies = [ "serde", "shared-crypto", "strum_macros 0.24.3", + "sui-http", "sui-macros", "sui-protocol-config", "sui-tls", @@ -2817,7 +2815,6 @@ dependencies = [ "tempfile", "thiserror 1.0.64", "tokio", - "tokio-rustls 0.26.0", "tokio-stream", "tokio-util 0.7.10", "tonic 0.12.3", @@ -8450,6 +8447,7 @@ dependencies = [ "prometheus", "serde", "snap", + "sui-http", "tokio", "tokio-rustls 0.26.0", "tokio-stream", @@ -9745,9 +9743,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.13" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -13860,6 +13858,27 @@ dependencies = [ "axum 0.7.5", ] +[[package]] +name = "sui-http" +version = "0.0.0" +dependencies = [ + "axum 0.7.5", + "bytes", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "pin-project-lite", + "reqwest 0.12.5", + "socket2 0.5.6", + "tokio", + "tokio-rustls 0.26.0", + "tokio-util 0.7.10", + "tower 0.4.13", + "tracing", +] + [[package]] name = "sui-indexer" version = "1.40.0" @@ -14659,6 +14678,7 @@ dependencies = [ "sui-archival", "sui-config", "sui-core", + "sui-http", "sui-json-rpc", "sui-json-rpc-api", "sui-macros", diff --git a/Cargo.toml b/Cargo.toml index b7ffaf614990d..e6a6739e73c22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,6 +115,7 @@ members = [ "crates/sui-graphql-rpc", "crates/sui-graphql-rpc-client", "crates/sui-graphql-rpc-headers", + "crates/sui-http", "crates/sui-indexer", 
"crates/sui-indexer-alt", "crates/sui-indexer-alt-framework", @@ -641,6 +642,7 @@ sui-graphql-rpc = { path = "crates/sui-graphql-rpc" } sui-graphql-rpc-client = { path = "crates/sui-graphql-rpc-client" } sui-graphql-rpc-headers = { path = "crates/sui-graphql-rpc-headers" } sui-genesis-builder = { path = "crates/sui-genesis-builder" } +sui-http = { path = "crates/sui-http" } sui-indexer = { path = "crates/sui-indexer" } sui-indexer-alt-framework = { path = "crates/sui-indexer-alt-framework" } sui-indexer-alt-jsonrpc = { path = "crates/sui-indexer-alt-jsonrpc" } diff --git a/consensus/core/Cargo.toml b/consensus/core/Cargo.toml index 6d4756ac4075f..202e907bf6f8c 100644 --- a/consensus/core/Cargo.toml +++ b/consensus/core/Cargo.toml @@ -25,9 +25,6 @@ enum_dispatch.workspace = true fastcrypto.workspace = true futures.workspace = true http.workspace = true -hyper.workspace = true -hyper-util.workspace = true -hyper-rustls.workspace = true itertools.workspace = true quinn-proto.workspace = true mockall.workspace = true @@ -49,7 +46,6 @@ sui-tls.workspace = true tap.workspace = true thiserror.workspace = true tokio.workspace = true -tokio-rustls.workspace = true tokio-stream.workspace = true tokio-util.workspace = true tonic.workspace = true @@ -58,6 +54,7 @@ tower-http.workspace = true tracing.workspace = true typed-store.workspace = true tonic-rustls.workspace = true +sui-http.workspace = true [dev-dependencies] rstest.workspace = true diff --git a/consensus/core/src/error.rs b/consensus/core/src/error.rs index d78915b0c3097..4ec000d60152f 100644 --- a/consensus/core/src/error.rs +++ b/consensus/core/src/error.rs @@ -176,9 +176,6 @@ pub(crate) enum ConsensusError { #[error("Failed to connect as client: {0:?}")] NetworkClientConnection(String), - #[error("Failed to connect as server: {0:?}")] - NetworkServerConnection(String), - #[error("Failed to send request: {0:?}")] NetworkRequest(String), diff --git a/consensus/core/src/network/tonic_network.rs 
b/consensus/core/src/network/tonic_network.rs index 374c81302fe9b..367af543bd255 100644 --- a/consensus/core/src/network/tonic_network.rs +++ b/consensus/core/src/network/tonic_network.rs @@ -11,32 +11,19 @@ use std::{ use async_trait::async_trait; use bytes::Bytes; -use cfg_if::cfg_if; use consensus_config::{AuthorityIndex, NetworkKeyPair, NetworkPublicKey}; use futures::{stream, Stream, StreamExt as _}; -use hyper_util::rt::{tokio::TokioIo, TokioTimer}; -use hyper_util::service::TowerToHyperService; -use mysten_common::sync::notify_once::NotifyOnce; -use mysten_metrics::monitored_future; use mysten_network::{ callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler}, multiaddr::Protocol, Multiaddr, }; use parking_lot::RwLock; +use sui_http::ServerHandle; use sui_tls::AllowPublicKeys; -use tokio::{ - pin, - task::JoinSet, - time::{timeout, Instant}, -}; -use tokio_rustls::TlsAcceptor; use tokio_stream::{iter, Iter}; use tonic::{Request, Response, Streaming}; -use tower_http::{ - trace::{DefaultMakeSpan, DefaultOnFailure, TraceLayer}, - ServiceBuilderExt, -}; +use tower_http::trace::{DefaultMakeSpan, DefaultOnFailure, TraceLayer}; use tracing::{debug, error, info, trace, warn}; use super::{ @@ -66,14 +53,6 @@ const MAX_FETCH_RESPONSE_BYTES: usize = 4 * 1024 * 1024; // Maximum total bytes fetched in a single fetch_blocks() call, after combining the responses. const MAX_TOTAL_FETCHED_BYTES: usize = 128 * 1024 * 1024; -// Maximum number of connections in backlog. -#[cfg(not(msim))] -const MAX_CONNECTIONS_BACKLOG: u32 = 1024; - -// The time we are willing to wait for a connection to get gracefully shutdown before we attempt to -// forcefully shutdown its task. -const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1); - // Implements Tonic RPC client for Consensus. 
pub(crate) struct TonicClient { context: Arc, @@ -667,8 +646,7 @@ pub(crate) struct TonicManager { context: Arc, network_keypair: NetworkKeyPair, client: Arc, - server: JoinSet<()>, - shutdown_notif: Arc, + server: Option, } impl TonicManager { @@ -677,8 +655,7 @@ impl TonicManager { context: context.clone(), network_keypair: network_keypair.clone(), client: Arc::new(TonicClient::new(context, network_keypair)), - server: JoinSet::new(), - shutdown_notif: Arc::new(NotifyOnce::new()), + server: None, } } } @@ -716,32 +693,41 @@ impl NetworkManager for TonicManager { let service = TonicServiceProxy::new(self.context.clone(), service); let config = &self.context.parameters.tonic; + let connections_info = Arc::new(ConnectionsInfo::new(self.context.clone())); + let layers = tower::ServiceBuilder::new() + // Add a layer to extract a peer's PeerInfo from their TLS certs + .map_request(move |mut request: http::Request<_>| { + if let Some(peer_certificates) = + request.extensions().get::() + { + if let Some(peer_info) = + peer_info_from_certs(&connections_info, peer_certificates) + { + request.extensions_mut().insert(peer_info); + } + } + request + }) + .layer(CallbackLayer::new(MetricsCallbackMaker::new( + self.context.metrics.network_metrics.inbound.clone(), + self.context.parameters.tonic.excessive_message_size, + ))) + .layer( + TraceLayer::new_for_grpc() + .make_span_with(DefaultMakeSpan::new().level(tracing::Level::TRACE)) + .on_failure(DefaultOnFailure::new().level(tracing::Level::DEBUG)), + ) + .layer_fn(|service| mysten_network::grpc_timeout::GrpcTimeout::new(service, None)); + let consensus_service = tonic::service::Routes::new( ConsensusServiceServer::new(service) .max_encoding_message_size(config.message_size_limit) .max_decoding_message_size(config.message_size_limit), ) - .into_axum_router(); - - let inbound_metrics = self.context.metrics.network_metrics.inbound.clone(); - let excessive_message_size = self.context.parameters.tonic.excessive_message_size; - - 
let http = { - let mut builder = - hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()) - .http2_only(); - builder - .http2() - .timer(TokioTimer::new()) - .initial_connection_window_size(64 << 20) - .initial_stream_window_size(32 << 20) - .keep_alive_interval(Some(config.keepalive_interval)) - .keep_alive_timeout(config.keepalive_interval); - - Arc::new(builder) - }; + .into_axum_router() + .route_layer(layers); - let tls_server_config = sui_tls::create_rustls_server_config( + let tls_server_config = sui_tls::create_rustls_server_config_with_client_verifier( self.network_keypair.clone().private_key().into_inner(), certificate_server_name(&self.context), AllowPublicKeys::new( @@ -752,222 +738,64 @@ impl NetworkManager for TonicManager { .collect(), ), ); - let tls_acceptor = TlsAcceptor::from(Arc::new(tls_server_config)); - - // Create listener to incoming connections. - let deadline = Instant::now() + Duration::from_secs(20); - let listener = loop { - if Instant::now() > deadline { - panic!("Failed to start server: timeout"); - } - cfg_if!( - if #[cfg(msim)] { - // msim does not have a working stub for TcpSocket. So create TcpListener directly. - match tokio::net::TcpListener::bind(own_address).await { - Ok(listener) => break listener, - Err(e) => { - warn!("Error binding to {own_address}: {e:?}"); - tokio::time::sleep(Duration::from_secs(1)).await; - } - } - } else { - let tcp_connection_metrics = &self.context.metrics.network_metrics.tcp_connection_metrics; - // Try creating an ephemeral port to test the highest allowed send and recv buffer sizes. - // Buffer sizes are not set explicitly on the socket used for real traffic, - // to allow the OS to set appropriate values. 
- { - let ephemeral_addr = SocketAddr::new(own_address.ip(), 0); - let ephemeral_socket = create_socket(&ephemeral_addr); - if let Err(e) = ephemeral_socket.set_send_buffer_size(32 << 20) { - info!("Failed to set send buffer size: {e:?}"); - } - if let Err(e) = ephemeral_socket.set_recv_buffer_size(32 << 20) { - info!("Failed to set recv buffer size: {e:?}"); - } - if ephemeral_socket.bind(ephemeral_addr).is_ok() { - tcp_connection_metrics.socket_send_buffer_max_size.set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64); - tcp_connection_metrics.socket_recv_buffer_max_size.set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64); - }; - } - - info!("Binding tonic server to address {:?}", own_address); - - // Create TcpListener via TCP socket. - let socket = create_socket(&own_address); - match socket.bind(own_address) { - Ok(_) => { - info!("Successfully bound tonic server to address {:?}", own_address) - } - Err(e) => { - warn!("Error binding to {own_address}: {e:?}"); - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - } - }; - - tcp_connection_metrics.socket_send_buffer_size.set(socket.send_buffer_size().unwrap_or(0) as i64); - tcp_connection_metrics.socket_recv_buffer_size.set(socket.recv_buffer_size().unwrap_or(0) as i64); - - match socket.listen(MAX_CONNECTIONS_BACKLOG) { - Ok(listener) => break listener, - Err(e) => { - warn!("Error listening at {own_address}: {e:?}"); - tokio::time::sleep(Duration::from_secs(1)).await; - } - } + // Calculate some metrics around send/recv buffer sizes for the current machine/OS + #[cfg(not(msim))] + { + let tcp_connection_metrics = + &self.context.metrics.network_metrics.tcp_connection_metrics; + + // Try creating an ephemeral port to test the highest allowed send and recv buffer sizes. + // Buffer sizes are not set explicitly on the socket used for real traffic, + // to allow the OS to set appropriate values. 
+ { + let ephemeral_addr = SocketAddr::new(own_address.ip(), 0); + let ephemeral_socket = create_socket(&ephemeral_addr); + tcp_connection_metrics + .socket_send_buffer_size + .set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64); + tcp_connection_metrics + .socket_recv_buffer_size + .set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64); + + if let Err(e) = ephemeral_socket.set_send_buffer_size(32 << 20) { + info!("Failed to set send buffer size: {e:?}"); } - ); - }; - - let connections_info = Arc::new(ConnectionsInfo::new(self.context.clone())); - - let shutdown_notif = self.shutdown_notif.clone(); - - self.server.spawn(monitored_future!(async move { - let mut connection_handlers = JoinSet::new(); - - loop { - let (tcp_stream, peer_addr) = tokio::select! { - result = listener.accept() => { - match result { - // This is the only branch that has addition processing. - // Other branches continue or break from the loop. - Ok(incoming) => incoming, - Err(e) => { - warn!("Error accepting connection: {}", e); - continue; - } - } - }, - Some(result) = connection_handlers.join_next() => { - match result { - Ok(Ok(())) => {}, - Ok(Err(e)) => { - debug!("Error serving connection: {e:?}"); - }, - Err(e) => { - debug!("Connection task error, likely shutting down: {e:?}"); - } - } - continue; - }, - _ = shutdown_notif.wait() => { - info!("Received shutdown. Stopping consensus service."); - if timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, async { - while connection_handlers.join_next().await.is_some() {} - }).await.is_err() { - warn!("Failed to stop all connection handlers in {CONNECTION_SHUTDOWN_GRACE_PERIOD:?}. 
Forcing shutdown."); - connection_handlers.shutdown().await; - } - return; - }, + if let Err(e) = ephemeral_socket.set_recv_buffer_size(32 << 20) { + info!("Failed to set recv buffer size: {e:?}"); + } + if ephemeral_socket.bind(ephemeral_addr).is_ok() { + tcp_connection_metrics + .socket_send_buffer_max_size + .set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64); + tcp_connection_metrics + .socket_recv_buffer_max_size + .set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64); }; - trace!("Received TCP connection attempt from {peer_addr}"); - - let tls_acceptor = tls_acceptor.clone(); - let consensus_service = consensus_service.clone(); - let inbound_metrics = inbound_metrics.clone(); - let http = http.clone(); - let connections_info = connections_info.clone(); - let shutdown_notif = shutdown_notif.clone(); - - connection_handlers.spawn(async move { - let tls_stream = tls_acceptor.accept(tcp_stream).await.map_err(|e| { - let msg = format!("Error accepting TLS connection: {e:?}"); - trace!(msg); - ConsensusError::NetworkServerConnection(msg) - })?; - trace!("Accepted TLS connection"); - - let certificate_public_key = - if let Some(certs) = tls_stream.get_ref().1.peer_certificates() { - if certs.len() != 1 { - let msg = format!( - "Unexpected number of certificates from TLS stream: {}", - certs.len() - ); - trace!(msg); - return Err(ConsensusError::NetworkServerConnection(msg)); - } - trace!("Received {} certificates", certs.len()); - sui_tls::public_key_from_certificate(&certs[0]).map_err(|e| { - trace!("Failed to extract public key from certificate: {e:?}"); - ConsensusError::NetworkServerConnection(format!( - "Failed to extract public key from certificate: {e:?}" - )) - })? - } else { - return Err(ConsensusError::NetworkServerConnection( - "No certificate found in TLS stream".to_string(), - )); - }; - let client_public_key = NetworkPublicKey::new(certificate_public_key); - // TODO: improvement connection management. limit connection per peer to 1. 
- let Some(authority_index) = - connections_info.authority_index(&client_public_key) - else { - let msg = format!( - "Failed to find the authority with public key {client_public_key:?}" - ); - error!("{}", msg); - return Err(ConsensusError::NetworkServerConnection(msg)); - }; - let svc = tower::ServiceBuilder::new() - // NOTE: the PeerInfo extension is copied to every request served. - // If PeerInfo starts to contain complex values, it should be wrapped in an Arc<>. - .add_extension(PeerInfo { authority_index }) - .layer(CallbackLayer::new(MetricsCallbackMaker::new( - inbound_metrics, - excessive_message_size, - ))) - .layer( - TraceLayer::new_for_grpc() - .make_span_with(DefaultMakeSpan::new().level(tracing::Level::TRACE)) - .on_failure(DefaultOnFailure::new().level(tracing::Level::DEBUG)), - ) - .service(consensus_service); - - pin! { - let connection = http.serve_connection(TokioIo::new(tls_stream), TowerToHyperService::new(svc)); - } - trace!("Connection ready. Starting to serve requests for {peer_addr:?}"); - - let mut has_shutdown = false; - loop { - tokio::select! { - result = connection.as_mut() => { - match result { - Ok(()) => { - trace!("Connection closed for {peer_addr:?}"); - break; - }, - Err(e) => { - let msg = format!("Connection error serving {peer_addr:?}: {e:?}"); - trace!(msg); - return Err(ConsensusError::NetworkServerConnection(msg)); - }, - } - }, - _ = shutdown_notif.wait(), if !has_shutdown => { - trace!("Received shutdown. 
Stopping connection for {peer_addr:?}"); - connection.as_mut().graceful_shutdown(); - has_shutdown = true; - }, - } - } - - Ok(()) - }); } - })); + } + + let server = sui_http::Builder::new() + .config( + sui_http::Config::default() + .initial_connection_window_size(64 << 20) + .initial_stream_window_size(32 << 20) + .http2_keepalive_interval(Some(config.keepalive_interval)) + .http2_keepalive_timeout(Some(config.keepalive_interval)) + .accept_http1(false), + ) + .tls_config(tls_server_config) + .serve(own_address, consensus_service) + .unwrap(); info!("Server started at: {own_address}"); + self.server = Some(server); } async fn stop(&mut self) { - let _ = self.shutdown_notif.notify(); - self.server.join_next().await; + if let Some(server) = self.server.take() { + server.shutdown().await; + } self.context .metrics @@ -978,6 +806,46 @@ impl NetworkManager for TonicManager { } } +// Ensure that if there is an active network running that it is shutdown when the TonicManager is +// dropped. 
+impl Drop for TonicManager { + fn drop(&mut self) { + if let Some(server) = self.server.as_ref() { + server.trigger_shutdown(); + } + } +} + +// TODO: improve sui-http to allow for providing a MakeService so that this can be done once per +// connection +fn peer_info_from_certs( + connections_info: &ConnectionsInfo, + peer_certificates: &sui_http::PeerCertificates, +) -> Option { + let certs = peer_certificates.peer_certs(); + + if certs.len() != 1 { + trace!( + "Unexpected number of certificates from TLS stream: {}", + certs.len() + ); + return None; + } + trace!("Received {} certificates", certs.len()); + let public_key = sui_tls::public_key_from_certificate(&certs[0]) + .map_err(|e| { + trace!("Failed to extract public key from certificate: {e:?}"); + e + }) + .ok()?; + let client_public_key = NetworkPublicKey::new(public_key); + let Some(authority_index) = connections_info.authority_index(&client_public_key) else { + error!("Failed to find the authority with public key {client_public_key:?}"); + return None; + }; + Some(PeerInfo { authority_index }) +} + /// Attempts to convert a multiaddr of the form `/[ip4,ip6,dns]/{}/udp/{port}` into /// a host:port string. 
fn to_host_port_str(addr: &Multiaddr) -> Result { diff --git a/crates/mysten-network/Cargo.toml b/crates/mysten-network/Cargo.toml index 0762b9210445f..5ee2bf60d84a6 100644 --- a/crates/mysten-network/Cargo.toml +++ b/crates/mysten-network/Cargo.toml @@ -33,3 +33,4 @@ tower.workspace = true tower-http.workspace = true pin-project-lite = "0.2.13" tracing.workspace = true +sui-http.workspace = true diff --git a/crates/mysten-network/src/config.rs b/crates/mysten-network/src/config.rs index eab88a024ec41..dbf339c47c376 100644 --- a/crates/mysten-network/src/config.rs +++ b/crates/mysten-network/src/config.rs @@ -12,7 +12,7 @@ use std::time::Duration; use tokio_rustls::rustls::ClientConfig; use tonic::transport::Channel; -#[derive(Debug, Default, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct Config { /// Set the concurrency limit applied to on requests inbound per connection. pub concurrency_limit_per_connection: Option, @@ -106,4 +106,15 @@ impl Config { ) -> Result { connect_lazy_with_config(addr, tls_config, self) } + + pub(crate) fn http_config(&self) -> sui_http::Config { + sui_http::Config::default() + .initial_stream_window_size(self.http2_initial_stream_window_size) + .initial_connection_window_size(self.http2_initial_connection_window_size) + .max_concurrent_streams(self.http2_max_concurrent_streams) + .http2_keepalive_timeout(self.http2_keepalive_timeout) + .http2_keepalive_interval(self.http2_keepalive_interval) + .tcp_keepalive(self.tcp_keepalive) + .tcp_nodelay(self.tcp_nodelay.unwrap_or_default()) + } } diff --git a/crates/mysten-network/src/grpc_timeout.rs b/crates/mysten-network/src/grpc_timeout.rs new file mode 100644 index 0000000000000..69757d8bdf88c --- /dev/null +++ b/crates/mysten-network/src/grpc_timeout.rs @@ -0,0 +1,284 @@ +// Copyright (c) Mysten Labs, Inc. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Ported from `tonic` crate +// SPDX-License-Identifier: MIT + +use http::{HeaderMap, HeaderValue, Request, Response}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tonic::Status; +use tower::Service; + +const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; + +#[derive(Debug, Clone)] +pub struct GrpcTimeout { + inner: S, + server_timeout: Option, +} + +impl GrpcTimeout { + pub fn new(inner: S, server_timeout: Option) -> Self { + Self { + inner, + server_timeout, + } + } +} + +impl Service> for GrpcTimeout +where + S: Service, Response = Response>, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { + tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); + None + }); + + // Use the shorter of the two durations, if either are set + let timeout_duration = match (client_timeout, self.server_timeout) { + (None, None) => None, + (Some(dur), None) => Some(dur), + (None, Some(dur)) => Some(dur), + (Some(header), Some(server)) => { + let shorter_duration = std::cmp::min(header, server); + Some(shorter_duration) + } + }; + + ResponseFuture { + inner: self.inner.call(req), + sleep: timeout_duration.map(tokio::time::sleep), + } + } +} + +pin_project! 
{ + pub struct ResponseFuture { + #[pin] + inner: F, + #[pin] + sleep: Option, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(result) = this.inner.poll(cx) { + return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full))); + } + + if let Some(sleep) = this.sleep.as_pin_mut() { + ready!(sleep.poll(cx)); + let response = Status::deadline_exceeded("Timeout expired") + .into_http() + .map(|_| MaybeEmptyBody::empty()); + return Poll::Ready(Ok(response)); + } + + Poll::Pending + } +} + +pin_project! { + pub struct MaybeEmptyBody { + #[pin] + inner: Option, + } +} + +impl MaybeEmptyBody { + fn full(inner: B) -> Self { + Self { inner: Some(inner) } + } + + fn empty() -> Self { + Self { inner: None } + } +} + +impl http_body::Body for MaybeEmptyBody +where + B: http_body::Body + Send, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.project().inner.as_pin_mut() { + Some(b) => b.poll_frame(cx), + None => Poll::Ready(None), + } + } + + fn is_end_stream(&self) -> bool { + match &self.inner { + Some(b) => b.is_end_stream(), + None => true, + } + } + + fn size_hint(&self) -> http_body::SizeHint { + match &self.inner { + Some(body) => body.size_hint(), + None => http_body::SizeHint::with_exact(0), + } + } +} + +const SECONDS_IN_HOUR: u64 = 60 * 60; +const SECONDS_IN_MINUTE: u64 = 60; + +/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns +/// the value we attempted to parse. +/// +/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). 
+fn try_parse_grpc_timeout( + headers: &HeaderMap, +) -> Result, &HeaderValue> { + let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else { + return Ok(None); + }; + + let (timeout_value, timeout_unit) = val + .to_str() + .map_err(|_| val) + .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })? + // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this + // `split_at` will never panic from trying to split in the middle of a character. + // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str + // + // `len - 1` also wont panic since we just checked `s.is_empty`. + .split_at(val.len() - 1); + + // gRPC spec specifies `TimeoutValue` will be at most 8 digits + // Caping this at 8 digits also prevents integer overflow from ever occurring + if timeout_value.len() > 8 { + return Err(val); + } + + let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?; + + let duration = match timeout_unit { + // Hours + "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR), + // Minutes + "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE), + // Seconds + "S" => Duration::from_secs(timeout_value), + // Milliseconds + "m" => Duration::from_millis(timeout_value), + // Microseconds + "u" => Duration::from_micros(timeout_value), + // Nanoseconds + "n" => Duration::from_nanos(timeout_value), + _ => return Err(val), + }; + + Ok(Some(duration)) +} + +#[cfg(test)] +mod tests { + use super::*; + + // Helper function to reduce the boiler plate of our test cases + fn setup_map_try_parse(val: Option<&str>) -> Result, HeaderValue> { + let mut hm = HeaderMap::new(); + if let Some(v) = val { + let hv = HeaderValue::from_str(v).unwrap(); + hm.insert(GRPC_TIMEOUT_HEADER, hv); + }; + + try_parse_grpc_timeout(&hm).map_err(|e| e.clone()) + } + + #[test] + fn test_hours() { + let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration); + } + + #[test] 
+ fn test_minutes() { + let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(60), parsed_duration); + } + + #[test] + fn test_seconds() { + let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(42), parsed_duration); + } + + #[test] + fn test_milliseconds() { + let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap(); + assert_eq!(Duration::from_millis(13), parsed_duration); + } + + #[test] + fn test_microseconds() { + let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap(); + assert_eq!(Duration::from_micros(2), parsed_duration); + } + + #[test] + fn test_nanoseconds() { + let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap(); + assert_eq!(Duration::from_nanos(82), parsed_duration); + } + + #[test] + fn test_header_not_present() { + let parsed_duration = setup_map_try_parse(None).unwrap(); + assert!(parsed_duration.is_none()); + } + + #[test] + #[should_panic(expected = "82f")] + fn test_invalid_unit() { + // "f" is not a valid TimeoutUnit + setup_map_try_parse(Some("82f")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "123456789H")] + fn test_too_many_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("123456789H")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "oneH")] + fn test_invalid_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("oneH")).unwrap().unwrap(); + } +} diff --git a/crates/mysten-network/src/lib.rs b/crates/mysten-network/src/lib.rs index 8c226fabdf77f..faa8485c036fa 100644 --- a/crates/mysten-network/src/lib.rs +++ b/crates/mysten-network/src/lib.rs @@ -1,10 +1,12 @@ // Copyright (c) Mysten Labs, Inc. 
// SPDX-License-Identifier: Apache-2.0 + pub mod anemo_ext; pub mod callback; pub mod client; pub mod codec; pub mod config; +pub mod grpc_timeout; pub mod metrics; pub mod multiaddr; pub mod server; diff --git a/crates/mysten-network/src/multiaddr.rs b/crates/mysten-network/src/multiaddr.rs index 0fbb1f7fd68c5..429b8ce183959 100644 --- a/crates/mysten-network/src/multiaddr.rs +++ b/crates/mysten-network/src/multiaddr.rs @@ -268,6 +268,36 @@ impl<'de> serde::Deserialize<'de> for Multiaddr { } } +impl std::net::ToSocketAddrs for Multiaddr { + type Iter = Box>; + + fn to_socket_addrs(&self) -> std::io::Result { + let mut iter = self.iter(); + + match (iter.next(), iter.next()) { + (Some(Protocol::Ip4(ip4)), Some(Protocol::Tcp(port) | Protocol::Udp(port))) => { + (ip4, port) + .to_socket_addrs() + .map(|iter| Box::new(iter) as _) + } + (Some(Protocol::Ip6(ip6)), Some(Protocol::Tcp(port) | Protocol::Udp(port))) => { + (ip6, port) + .to_socket_addrs() + .map(|iter| Box::new(iter) as _) + } + (Some(Protocol::Dns(hostname)), Some(Protocol::Tcp(port) | Protocol::Udp(port))) => { + (hostname.as_ref(), port) + .to_socket_addrs() + .map(|iter| Box::new(iter) as _) + } + _ => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "unable to convert Multiaddr to SocketAddr", + )), + } + } +} + pub(crate) fn parse_tcp<'a, T: Iterator>>(protocols: &mut T) -> Result { if let Protocol::Tcp(port) = protocols .next() diff --git a/crates/mysten-network/src/server.rs b/crates/mysten-network/src/server.rs index d60a3b4637ebf..0f7ab47aeff69 100644 --- a/crates/mysten-network/src/server.rs +++ b/crates/mysten-network/src/server.rs @@ -1,143 +1,44 @@ // Copyright (c) Mysten Labs, Inc. 
// SPDX-License-Identifier: Apache-2.0 + use crate::metrics::{ DefaultMetricsCallbackProvider, MetricsCallbackProvider, MetricsHandler, GRPC_ENDPOINT_PATH_HEADER, }; use crate::{ config::Config, - multiaddr::{parse_dns, parse_ip4, parse_ip6, Multiaddr, Protocol}, + multiaddr::{Multiaddr, Protocol}, }; use eyre::{eyre, Result}; -use futures::stream::FuturesUnordered; -use futures::{FutureExt, Stream, StreamExt}; -use std::pin::Pin; -use std::sync::Arc; +use std::convert::Infallible; use std::task::{Context, Poll}; -use std::{convert::Infallible, net::SocketAddr}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_rustls::rustls::ServerConfig; -use tokio_rustls::{server::TlsStream, TlsAcceptor}; use tonic::codegen::http::HeaderValue; use tonic::{ body::BoxBody, - codegen::{ - http::{Request, Response}, - BoxFuture, - }, + codegen::http::{Request, Response}, server::NamedService, - transport::server::Router, -}; -use tower::{ - layer::util::{Identity, Stack}, - limit::GlobalConcurrencyLimitLayer, - load_shed::LoadShedLayer, - util::Either, - Layer, Service, ServiceBuilder, }; -use tower_http::classify::{GrpcErrorsAsFailures, SharedClassifier}; +use tower::{Layer, Service, ServiceBuilder}; use tower_http::propagate_header::PropagateHeaderLayer; use tower_http::set_header::SetRequestHeaderLayer; -use tower_http::trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer}; -use tracing::debug; +use tower_http::trace::TraceLayer; pub struct ServerBuilder { - router: Router>, + config: Config, + metrics_provider: M, + router: tonic::service::Routes, health_reporter: tonic_health::server::HealthReporter, } -type AddPathToHeaderFunction = fn(&Request) -> Option; - -type WrapperService = Stack< - Stack< - PropagateHeaderLayer, - Stack< - TraceLayer< - SharedClassifier, - DefaultMakeSpan, - MetricsHandler, - MetricsHandler, - DefaultOnBodyChunk, - DefaultOnEos, - MetricsHandler, - >, - Stack< - 
SetRequestHeaderLayer, - Stack< - RequestLifetimeLayer, - Stack< - Either, - Stack, Identity>, - >, - >, - >, - >, - >, - Identity, ->; - impl ServerBuilder { pub fn from_config(config: &Config, metrics_provider: M) -> Self { - let mut builder = tonic::transport::server::Server::builder(); - - if let Some(limit) = config.concurrency_limit_per_connection { - builder = builder.concurrency_limit_per_connection(limit); - } - - if let Some(timeout) = config.request_timeout { - builder = builder.timeout(timeout); - } - - if let Some(tcp_nodelay) = config.tcp_nodelay { - builder = builder.tcp_nodelay(tcp_nodelay); - } - - let load_shed = config - .load_shed - .unwrap_or_default() - .then_some(tower::load_shed::LoadShedLayer::new()); - - let metrics = MetricsHandler::new(metrics_provider.clone()); - - let request_metrics = TraceLayer::new_for_grpc() - .on_request(metrics.clone()) - .on_response(metrics.clone()) - .on_failure(metrics); - - let global_concurrency_limit = config - .global_concurrency_limit - .map(tower::limit::GlobalConcurrencyLimitLayer::new); - - fn add_path_to_request_header(request: &Request) -> Option { - let path = request.uri().path(); - Some(HeaderValue::from_str(path).unwrap()) - } - - let layer = ServiceBuilder::new() - .option_layer(global_concurrency_limit) - .option_layer(load_shed) - .layer(RequestLifetimeLayer { metrics_provider }) - .layer(SetRequestHeaderLayer::overriding( - GRPC_ENDPOINT_PATH_HEADER.clone(), - add_path_to_request_header as AddPathToHeaderFunction, - )) - .layer(request_metrics) - .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone())) - .into_inner(); - let (health_reporter, health_service) = tonic_health::server::health_reporter(); - let router = builder - .initial_stream_window_size(config.http2_initial_stream_window_size) - .initial_connection_window_size(config.http2_initial_connection_window_size) - .http2_keepalive_interval(config.http2_keepalive_interval) - 
.http2_keepalive_timeout(config.http2_keepalive_timeout) - .max_concurrent_streams(config.http2_max_concurrent_streams) - .tcp_keepalive(config.tcp_keepalive) - .layer(layer) - .add_service(health_service); + let router = tonic::service::Routes::new(health_service); Self { + config: config.to_owned(), + metrics_provider, router, health_reporter, } @@ -162,201 +63,79 @@ impl ServerBuilder { } pub async fn bind(self, addr: &Multiaddr, tls_config: Option) -> Result { - let mut iter = addr.iter(); - - let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel(); - let rx_cancellation = rx_cancellation.map(|_| ()); - let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) = match iter - .next() - .ok_or_else(|| eyre!("malformed addr"))? - { - Protocol::Dns(_) => { - let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?; - let (local_addr, incoming) = - listen_and_update_multiaddr(addr, (dns_name.to_string(), tcp_port), tls_config) - .await?; - let server = Box::pin( - self.router - .serve_with_incoming_shutdown(incoming, rx_cancellation), - ); - (local_addr, server) - } - Protocol::Ip4(_) => { - let (socket_addr, _http_or_https) = parse_ip4(addr)?; - let (local_addr, incoming) = - listen_and_update_multiaddr(addr, socket_addr, tls_config).await?; - let server = Box::pin( - self.router - .serve_with_incoming_shutdown(incoming, rx_cancellation), - ); - (local_addr, server) - } - Protocol::Ip6(_) => { - let (socket_addr, _http_or_https) = parse_ip6(addr)?; - let (local_addr, incoming) = - listen_and_update_multiaddr(addr, socket_addr, tls_config).await?; - let server = Box::pin( - self.router - .serve_with_incoming_shutdown(incoming, rx_cancellation), - ); - (local_addr, server) - } - unsupported => return Err(eyre!("unsupported protocol {unsupported}")), - }; - - Ok(Server { - server, - cancel_handle: Some(tx_cancellation), - local_addr, - health_reporter: self.health_reporter, - }) - } -} - -async fn listen_and_update_multiaddr( - 
address: &Multiaddr, - socket_addr: T, - tls_config: Option, -) -> Result<( - Multiaddr, - impl Stream>, -)> { - let listener = TcpListener::bind(socket_addr).await?; - let local_addr = listener.local_addr()?; - let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port()); - - let tls_acceptor = tls_config.map(|tls_config| TlsAcceptor::from(Arc::new(tls_config))); - let incoming = TcpOrTlsListener::new(listener, tls_acceptor); - let stream = async_stream::stream! { - let mut new_connections = FuturesUnordered::new(); - loop { - tokio::select! { - result = incoming.accept_raw() => { - match result { - Ok((stream, addr)) => { - new_connections.push(incoming.maybe_upgrade(stream, addr)); - } - Err(e) => yield Err(e), - } - } - Some(result) = new_connections.next() => { - yield result; - } - } - } - }; - - Ok((local_addr, stream)) -} - -pub struct TcpOrTlsListener { - listener: TcpListener, - tls_acceptor: Option, -} - -impl TcpOrTlsListener { - fn new(listener: TcpListener, tls_acceptor: Option) -> Self { - Self { - listener, - tls_acceptor, - } - } - - async fn accept_raw(&self) -> std::io::Result<(TcpStream, SocketAddr)> { - self.listener.accept().await - } - - async fn maybe_upgrade( - &self, - stream: TcpStream, - addr: SocketAddr, - ) -> std::io::Result { - if self.tls_acceptor.is_none() { - return Ok(TcpOrTlsStream::Tcp(stream, addr)); - } + let http_config = self + .config + .http_config() + // Temporarily continue allowing clients to connect without TLS even when the server + // is configured with a tls_config + .allow_insecure(true); + + let request_timeout = self.config.request_timeout; + let metrics_provider = self.metrics_provider; + let metrics = MetricsHandler::new(metrics_provider.clone()); + let request_metrics = TraceLayer::new_for_grpc() + .on_request(metrics.clone()) + .on_response(metrics.clone()) + .on_failure(metrics); - // Determine whether new connection is TLS. 
- let mut buf = [0; 1]; - // `peek` blocks until at least some data is available, so if there is no error then - // it must return the one byte we are requesting. - stream.peek(&mut buf).await?; - if buf[0] == 0x16 { - // First byte of a TLS handshake is 0x16. - debug!("accepting TLS connection from {addr:?}"); - let stream = self.tls_acceptor.as_ref().unwrap().accept(stream).await?; - Ok(TcpOrTlsStream::Tls(stream, addr)) - } else { - debug!("accepting TCP connection from {addr:?}"); - Ok(TcpOrTlsStream::Tcp(stream, addr)) + fn add_path_to_request_header(request: &Request) -> Option { + let path = request.uri().path(); + Some(HeaderValue::from_str(path).unwrap()) } - } -} - -pub enum TcpOrTlsStream { - Tcp(TcpStream, SocketAddr), - Tls(TlsStream, SocketAddr), -} -impl AsyncRead for TcpOrTlsStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf, - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_read(cx, buf), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_read(cx, buf), - } - } -} + let limiting_layers = ServiceBuilder::new() + .option_layer( + self.config + .load_shed + .unwrap_or_default() + .then_some(tower::load_shed::LoadShedLayer::new()), + ) + .option_layer( + self.config + .global_concurrency_limit + .map(tower::limit::GlobalConcurrencyLimitLayer::new), + ); + let route_layers = ServiceBuilder::new() + .map_request(|mut request: http::Request<_>| { + if let Some(connect_info) = request.extensions().get::() { + let tonic_connect_info = tonic::transport::server::TcpConnectInfo { + local_addr: Some(connect_info.local_addr), + remote_addr: Some(connect_info.remote_addr), + }; + request.extensions_mut().insert(tonic_connect_info); + } + request + }) + .layer(RequestLifetimeLayer { metrics_provider }) + .layer(SetRequestHeaderLayer::overriding( + GRPC_ENDPOINT_PATH_HEADER.clone(), + add_path_to_request_header, + )) + .layer(request_metrics) + 
.layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone())) + .layer_fn(move |service| { + crate::grpc_timeout::GrpcTimeout::new(service, request_timeout) + }); -impl AsyncWrite for TcpOrTlsStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_write(cx, buf), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_write(cx, buf), - } - } + let mut builder = sui_http::Builder::new().config(http_config); - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_flush(cx), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_flush(cx), + if let Some(tls_config) = tls_config { + builder = builder.tls_config(tls_config); } - } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_shutdown(cx), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_shutdown(cx), - } - } -} + let server_handle = builder + .serve( + addr, + limiting_layers.service(self.router.into_axum_router().layer(route_layers)), + ) + .map_err(|e| eyre!(e))?; -impl tonic::transport::server::Connected for TcpOrTlsStream { - type ConnectInfo = tonic::transport::server::TcpConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - match self { - TcpOrTlsStream::Tcp(stream, addr) => Self::ConnectInfo { - local_addr: stream.local_addr().ok(), - remote_addr: Some(*addr), - }, - TcpOrTlsStream::Tls(stream, addr) => Self::ConnectInfo { - local_addr: stream.get_ref().0.local_addr().ok(), - remote_addr: Some(*addr), - }, - } + let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port()); + Ok(Server { + server: server_handle, + local_addr, + health_reporter: self.health_reporter, + }) } } @@ -364,15 +143,15 @@ impl 
tonic::transport::server::Connected for TcpOrTlsStream { pub const SUI_TLS_SERVER_NAME: &str = "sui"; pub struct Server { - server: BoxFuture<(), tonic::transport::Error>, - cancel_handle: Option>, + server: sui_http::ServerHandle, local_addr: Multiaddr, health_reporter: tonic_health::server::HealthReporter, } impl Server { pub async fn serve(self) -> Result<(), tonic::transport::Error> { - self.server.await + self.server.wait_for_shutdown().await; + Ok(()) } pub fn local_addr(&self) -> &Multiaddr { @@ -383,8 +162,8 @@ impl Server { self.health_reporter.clone() } - pub fn take_cancel_handle(&mut self) -> Option> { - self.cancel_handle.take() + pub fn handle(&self) -> &sui_http::ServerHandle { + &self.server } } @@ -458,15 +237,13 @@ mod test { let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap(); let config = Config::new(); - let mut server = config + let server = config .server_builder_with_metrics(metrics.clone()) .bind(&address, None) .await .unwrap(); let address = server.local_addr().to_owned(); - let cancel_handle = server.take_cancel_handle().unwrap(); - let server_handle = tokio::spawn(server.serve()); let channel = config.connect(&address, None).await.unwrap(); let mut client = HealthClient::new(channel); @@ -477,8 +254,7 @@ mod test { .await .unwrap(); - cancel_handle.send(()).unwrap(); - server_handle.await.unwrap().unwrap(); + server.server.shutdown().await; assert!(metrics.metrics_called.lock().unwrap().deref()); } @@ -521,15 +297,13 @@ mod test { let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap(); let config = Config::new(); - let mut server = config + let server = config .server_builder_with_metrics(metrics.clone()) .bind(&address, None) .await .unwrap(); let address = server.local_addr().to_owned(); - let cancel_handle = server.take_cancel_handle().unwrap(); - let server_handle = tokio::spawn(server.serve()); let channel = config.connect(&address, None).await.unwrap(); let mut client = HealthClient::new(channel); 
@@ -542,18 +316,15 @@ mod test { }) .await; - cancel_handle.send(()).unwrap(); - server_handle.await.unwrap().unwrap(); + server.server.shutdown().await; assert!(metrics.metrics_called.lock().unwrap().deref()); } async fn test_multiaddr(address: Multiaddr) { let config = Config::new(); - let mut server = config.server_builder().bind(&address, None).await.unwrap(); - let address = server.local_addr().to_owned(); - let cancel_handle = server.take_cancel_handle().unwrap(); - let server_handle = tokio::spawn(server.serve()); + let server_handle = config.server_builder().bind(&address, None).await.unwrap(); + let address = server_handle.local_addr().to_owned(); let channel = config.connect(&address, None).await.unwrap(); let mut client = HealthClient::new(channel); @@ -564,8 +335,7 @@ mod test { .await .unwrap(); - cancel_handle.send(()).unwrap(); - server_handle.await.unwrap().unwrap(); + server_handle.server.shutdown().await; } #[tokio::test] diff --git a/crates/sui-core/src/authority_server.rs b/crates/sui-core/src/authority_server.rs index 0f32457740d50..8ee1290749965 100644 --- a/crates/sui-core/src/authority_server.rs +++ b/crates/sui-core/src/authority_server.rs @@ -44,7 +44,6 @@ use sui_types::{ }, }; use tap::TapFallible; -use tokio::task::JoinHandle; use tonic::metadata::{Ascii, MetadataValue}; use tracing::{error, error_span, info, Instrument}; @@ -72,32 +71,22 @@ use tonic::transport::server::TcpConnectInfo; mod server_tests; pub struct AuthorityServerHandle { - tx_cancellation: tokio::sync::oneshot::Sender<()>, - local_addr: Multiaddr, - handle: JoinHandle>, + server_handle: mysten_network::server::Server, } impl AuthorityServerHandle { pub async fn join(self) -> Result<(), io::Error> { - // Note that dropping `self.complete` would terminate the server. - self.handle - .await? 
- .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + self.server_handle.handle().wait_for_shutdown().await; Ok(()) } pub async fn kill(self) -> Result<(), io::Error> { - self.tx_cancellation.send(()).map_err(|_e| { - io::Error::new(io::ErrorKind::Other, "could not send cancellation signal!") - })?; - self.handle - .await? - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + self.server_handle.handle().shutdown().await; Ok(()) } pub fn address(&self) -> &Multiaddr { - &self.local_addr + self.server_handle.local_addr() } } @@ -151,9 +140,8 @@ impl AuthorityServer { let tls_config = sui_tls::create_rustls_server_config( self.state.config.network_key_pair().copy().private(), SUI_TLS_SERVER_NAME.to_string(), - sui_tls::AllowAll, ); - let mut server = mysten_network::config::Config::new() + let server = mysten_network::config::Config::new() .server_builder() .add_service(ValidatorServer::new(ValidatorService::new_for_tests( self.state, @@ -166,9 +154,7 @@ impl AuthorityServer { let local_addr = server.local_addr().to_owned(); info!("Listening to traffic on {local_addr}"); let handle = AuthorityServerHandle { - tx_cancellation: server.take_cancel_handle().unwrap(), - local_addr, - handle: spawn_monitored_task!(server.serve()), + server_handle: server, }; Ok(handle) } diff --git a/crates/sui-http/Cargo.toml b/crates/sui-http/Cargo.toml new file mode 100644 index 0000000000000..3a0cdfa348cd3 --- /dev/null +++ b/crates/sui-http/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "sui-http" +version = "0.0.0" +authors = ["Brandon Williams "] +license = "Apache-2.0" +edition = "2021" +publish = false + +[dependencies] +bytes = "1" +http = "1" +http-body = "1" +http-body-util = "0.1" +hyper = { version = "1", features = ["http1", "http2"] } +hyper-util = { version = "0.1.4", features = ["tokio", "server-auto", "service"] } +pin-project-lite = "0.2.15" +socket2 = { version = "0.5", features = ["all"] } +tokio = { version = "1.36.0", default-features = false, features = 
["macros"] } +tokio-util = { version = "0.7.10" } +tower = { version = "0.4", default-features = false, features = ["util"] } +tracing = { version = "0.1" } + +# TLS support +tokio-rustls = { version = "0.26", default-features = false } + +[dev-dependencies] +axum.workspace = true +reqwest.workspace = true diff --git a/crates/sui-http/src/body.rs b/crates/sui-http/src/body.rs new file mode 100644 index 0000000000000..9a270d3423255 --- /dev/null +++ b/crates/sui-http/src/body.rs @@ -0,0 +1,29 @@ +// Copyright (c) Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use crate::BoxError; +use bytes::Bytes; +use http_body_util::BodyExt; + +pub type BoxBody = http_body_util::combinators::UnsyncBoxBody; + +pub fn boxed(body: B) -> BoxBody +where + B: http_body::Body + Send + 'static, + B::Error: Into, +{ + try_downcast(body).unwrap_or_else(|body| body.map_err(Into::into).boxed_unsync()) +} + +pub(crate) fn try_downcast(k: K) -> Result +where + T: 'static, + K: Send + 'static, +{ + let mut k = Some(k); + if let Some(k) = ::downcast_mut::>(&mut k) { + Ok(k.take().unwrap()) + } else { + Err(k.unwrap()) + } +} diff --git a/crates/sui-http/src/config.rs b/crates/sui-http/src/config.rs new file mode 100644 index 0000000000000..fa11d7bd67aa5 --- /dev/null +++ b/crates/sui-http/src/config.rs @@ -0,0 +1,250 @@ +// Copyright (c) Mysten Labs, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::time::Duration; + +const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; + +pub struct Config { + init_stream_window_size: Option, + init_connection_window_size: Option, + max_concurrent_streams: Option, + pub(crate) tcp_keepalive: Option, + pub(crate) tcp_nodelay: bool, + http2_keepalive_interval: Option, + http2_keepalive_timeout: Option, + http2_adaptive_window: Option, + http2_max_pending_accept_reset_streams: Option, + http2_max_header_list_size: Option, + max_frame_size: Option, + pub(crate) accept_http1: bool, + enable_connect_protocol: bool, + pub(crate) max_connection_age: Option, + pub(crate) allow_insecure: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + init_stream_window_size: None, + init_connection_window_size: None, + max_concurrent_streams: None, + tcp_keepalive: None, + tcp_nodelay: false, + http2_keepalive_interval: None, + http2_keepalive_timeout: None, + http2_adaptive_window: None, + http2_max_pending_accept_reset_streams: None, + http2_max_header_list_size: None, + max_frame_size: None, + accept_http1: true, + enable_connect_protocol: true, + max_connection_age: None, + allow_insecure: false, + } + } +} + +impl Config { + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Default is 65,535 + /// + /// [spec]: https://httpwg.org/specs/rfc9113.html#InitialWindowSize + pub fn initial_stream_window_size(self, sz: impl Into>) -> Self { + Self { + init_stream_window_size: sz.into(), + ..self + } + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Default is 65,535 + pub fn initial_connection_window_size(self, sz: impl Into>) -> Self { + Self { + init_connection_window_size: sz.into(), + ..self + } + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is no limit (`None`). 
+ /// + /// [spec]: https://httpwg.org/specs/rfc9113.html#n-stream-concurrency + pub fn max_concurrent_streams(self, max: impl Into>) -> Self { + Self { + max_concurrent_streams: max.into(), + ..self + } + } + + /// Sets the maximum duration that a connection may exist + /// + /// Default is no limit (`None`). + pub fn max_connection_age(self, max_connection_age: Duration) -> Self { + Self { + max_connection_age: Some(max_connection_age), + ..self + } + } + + /// Set whether HTTP2 Ping frames are enabled on accepted connections. + /// + /// If `None` is specified, HTTP2 keepalive is disabled, otherwise the duration + /// specified will be the time interval between HTTP2 Ping frames. + /// The timeout for receiving an acknowledgement of the keepalive ping + /// can be set with [`Config::http2_keepalive_timeout`]. + /// + /// Default is no HTTP2 keepalive (`None`) + pub fn http2_keepalive_interval(self, http2_keepalive_interval: Option) -> Self { + Self { + http2_keepalive_interval, + ..self + } + } + + /// Sets a timeout for receiving an acknowledgement of the keepalive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will be closed. + /// Does nothing if [`Config::http2_keepalive_interval`] is disabled. + /// + /// Default is 20 seconds. + pub fn http2_keepalive_timeout(self, http2_keepalive_timeout: Option) -> Self { + Self { + http2_keepalive_timeout, + ..self + } + } + + /// Sets whether to use an adaptive flow control. Defaults to false. + /// Enabling this will override the limits set in [`Config::initial_stream_window_size`] and + /// [`Config::initial_connection_window_size`]. + pub fn http2_adaptive_window(self, enabled: Option) -> Self { + Self { + http2_adaptive_window: enabled, + ..self + } + } + + /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent. + /// + /// This will default to whatever the default in h2 is. As of v0.3.17, it is 20. 
+ /// + /// See <https://github.com/hyperium/hyper/issues/2877> for more information. 
+ pub fn http2_max_pending_accept_reset_streams(self, max: Option) -> Self { + Self { + http2_max_pending_accept_reset_streams: max, + ..self + } + } + + /// Set whether TCP keepalive messages are enabled on accepted connections. + /// + /// If `None` is specified, keepalive is disabled, otherwise the duration + /// specified will be the time to remain idle before sending TCP keepalive + /// probes. + /// + /// Default is no keepalive (`None`) + pub fn tcp_keepalive(self, tcp_keepalive: Option) -> Self { + Self { + tcp_keepalive, + ..self + } + } + + /// Set the value of `TCP_NODELAY` option for accepted connections. Disabled by default. + pub fn tcp_nodelay(self, enabled: bool) -> Self { + Self { + tcp_nodelay: enabled, + ..self + } + } + + /// Sets the max size of received header frames. + /// + /// This will default to whatever the default in hyper is. As of v1.4.1, it is 16 KiB. + pub fn http2_max_header_list_size(self, max: impl Into>) -> Self { + Self { + http2_max_header_list_size: max.into(), + ..self + } + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, will default from underlying transport. + pub fn max_frame_size(self, frame_size: impl Into>) -> Self { + Self { + max_frame_size: frame_size.into(), + ..self + } + } + + /// Allow this server to accept HTTP/1 requests. + /// + /// Default is `true`. + pub fn accept_http1(self, accept_http1: bool) -> Self { + Config { + accept_http1, + ..self + } + } + + /// Allow accepting insecure connections when a tls_config is provided. + /// + /// This will allow clients to connect both using TLS as well as without TLS on the same + /// network interface. + /// + /// Default is `false`. 
+ /// + /// NOTE: This presently will only work for `tokio::net::TcpStream` IO connections + pub fn allow_insecure(self, allow_insecure: bool) -> Self { + Config { + allow_insecure, + ..self + } + } + + pub(crate) fn connection_builder( + &self, + ) -> hyper_util::server::conn::auto::Builder { + let mut builder = + hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + + if !self.accept_http1 { + builder = builder.http2_only(); + } + + if self.enable_connect_protocol { + builder.http2().enable_connect_protocol(); + } + + let http2_keepalive_timeout = self + .http2_keepalive_timeout + .unwrap_or_else(|| Duration::new(DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS, 0)); + + builder + .http2() + .timer(hyper_util::rt::TokioTimer::new()) + .initial_connection_window_size(self.init_connection_window_size) + .initial_stream_window_size(self.init_stream_window_size) + .max_concurrent_streams(self.max_concurrent_streams) + .keep_alive_interval(self.http2_keepalive_interval) + .keep_alive_timeout(http2_keepalive_timeout) + .adaptive_window(self.http2_adaptive_window.unwrap_or_default()) + .max_pending_accept_reset_streams(self.http2_max_pending_accept_reset_streams) + .max_frame_size(self.max_frame_size); + + if let Some(max_header_list_size) = self.http2_max_header_list_size { + builder.http2().max_header_list_size(max_header_list_size); + } + + builder + } +} diff --git a/crates/sui-http/src/connection_handler.rs b/crates/sui-http/src/connection_handler.rs new file mode 100644 index 0000000000000..1474e88dc9d1e --- /dev/null +++ b/crates/sui-http/src/connection_handler.rs @@ -0,0 +1,83 @@ +// Copyright (c) Mysten Labs, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::{pin::pin, time::Duration}; + +use http::{Request, Response}; +use tracing::{debug, trace}; + +use crate::{fuse::Fuse, ActiveConnections, BoxError, ConnectionId}; + +// This is moved to its own function as a way to get around +// https://github.com/rust-lang/rust/issues/102211 +pub async fn serve_connection( + hyper_io: IO, + hyper_svc: S, + builder: hyper_util::server::conn::auto::Builder, + graceful_shutdown_token: tokio_util::sync::CancellationToken, + max_connection_age: Option, + on_connection_close: C, +) where + B: http_body::Body + Send + 'static, + B::Data: Send, + B::Error: Into, + IO: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + S: hyper::service::Service, Response = Response> + 'static, + S::Future: Send + 'static, + S::Error: Into>, +{ + let mut sig = pin!(Fuse::new(graceful_shutdown_token.cancelled_owned())); + + let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc)); + + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + loop { + tokio::select! 
{ + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sleep => { + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + } + } + + trace!("connection closed"); + drop(on_connection_close); +} + +async fn sleep_or_pending(wait_for: Option) { + match wait_for { + Some(wait) => tokio::time::sleep(wait).await, + None => std::future::pending().await, + }; +} + +pub(crate) struct OnConnectionClose { + id: ConnectionId, + active_connections: ActiveConnections, +} + +impl OnConnectionClose { + pub(crate) fn new(id: ConnectionId, active_connections: ActiveConnections) -> Self { + Self { + id, + active_connections, + } + } +} + +impl Drop for OnConnectionClose { + fn drop(&mut self) { + self.active_connections.write().unwrap().remove(&self.id); + } +} diff --git a/crates/sui-http/src/connection_info.rs b/crates/sui-http/src/connection_info.rs new file mode 100644 index 0000000000000..ab104439eb7cd --- /dev/null +++ b/crates/sui-http/src/connection_info.rs @@ -0,0 +1,93 @@ +// Copyright (c) Mysten Labs, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use tokio_rustls::rustls::pki_types::CertificateDer; + +pub(crate) type ActiveConnections = + Arc>>>; + +pub type ConnectionId = usize; + +#[derive(Debug)] +pub struct ConnectionInfo(Arc>); + +#[derive(Clone, Debug)] +pub struct PeerCertificates(Arc>>); + +impl PeerCertificates { + pub fn peer_certs(&self) -> &[tokio_rustls::rustls::pki_types::CertificateDer<'static>] { + self.0.as_ref() + } +} + +impl ConnectionInfo { + pub(crate) fn new( + address: A, + peer_certificates: Option>>>, + graceful_shutdown_token: tokio_util::sync::CancellationToken, + ) -> Self { + Self(Arc::new(Inner { + address, + time_established: std::time::Instant::now(), + peer_certificates: peer_certificates.map(PeerCertificates), + graceful_shutdown_token, + })) + } + + /// The peer's remote address + pub fn remote_address(&self) -> &A { + &self.0.address + } + + /// Time the Connection was established + pub fn time_established(&self) -> std::time::Instant { + self.0.time_established + } + + pub fn peer_certificates(&self) -> Option<&PeerCertificates> { + self.0.peer_certificates.as_ref() + } + + /// A stable identifier for this connection + pub fn id(&self) -> ConnectionId { + &*self.0 as *const _ as usize + } + + /// Trigger a graceful shutdown of this connection + pub fn close(&self) { + self.0.graceful_shutdown_token.cancel() + } +} + +#[derive(Debug)] +struct Inner { + address: A, + + // Time that the connection was established + time_established: std::time::Instant, + + peer_certificates: Option, + graceful_shutdown_token: tokio_util::sync::CancellationToken, +} + +#[derive(Debug, Clone)] +pub struct ConnectInfo { + /// Returns the local address of this connection. + pub local_addr: A, + /// Returns the remote (peer) address of this connection. + pub remote_addr: A, +} + +impl ConnectInfo { + /// Return the local address the IO resource is connected. 
+ pub fn local_addr(&self) -> &A { + &self.local_addr + } + + /// Return the remote address the IO resource is connected to. + pub fn remote_addr(&self) -> &A { + &self.remote_addr + } +} diff --git a/crates/sui-http/src/fuse.rs b/crates/sui-http/src/fuse.rs new file mode 100644 index 0000000000000..362f0bbead2a3 --- /dev/null +++ b/crates/sui-http/src/fuse.rs @@ -0,0 +1,43 @@ +// Copyright (c) Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +// From `futures-util` crate +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +pin_project_lite::pin_project! { + pub struct Fuse { + #[pin] + inner: Option, + } +} + +impl Fuse { + pub fn new(future: F) -> Self { + Self { + inner: Some(future), + } + } +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/crates/sui-http/src/io.rs b/crates/sui-http/src/io.rs new file mode 100644 index 0000000000000..6172b3b908ef6 --- /dev/null +++ b/crates/sui-http/src/io.rs @@ -0,0 +1,103 @@ +// Copyright (c) Mysten Labs, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::io; +use std::io::IoSlice; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::server::TlsStream; + +pub(crate) enum ServerIo { + Io(IO), + TlsIo(Box>), +} + +impl ServerIo { + pub(crate) fn new_io(io: IO) -> Self { + Self::Io(io) + } + + pub(crate) fn new_tls_io(io: TlsStream) -> Self { + Self::TlsIo(Box::new(io)) + } + + pub(crate) fn peer_certs( + &self, + ) -> Option>>> { + match self { + Self::Io(_) => None, + Self::TlsIo(io) => { + let (_inner, session) = io.get_ref(); + + session + .peer_certificates() + .map(|certs| certs.to_owned().into()) + } + } + } +} + +impl AsyncRead for ServerIo +where + IO: AsyncWrite + AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_read(cx, buf), + Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ServerIo +where + IO: AsyncWrite + AsyncRead + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write(cx, buf), + Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_flush(cx), + Self::TlsIo(io) => Pin::new(io).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_shutdown(cx), + Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs), + Self::TlsIo(io) => 
Pin::new(io).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Io(io) => io.is_write_vectored(), + Self::TlsIo(io) => io.is_write_vectored(), + } + } +} diff --git a/crates/sui-http/src/lib.rs b/crates/sui-http/src/lib.rs new file mode 100644 index 0000000000000..3d40203520bca --- /dev/null +++ b/crates/sui-http/src/lib.rs @@ -0,0 +1,424 @@ +// Copyright (c) Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use connection_handler::OnConnectionClose; +use http::{Request, Response}; +use hyper_util::service::TowerToHyperService; +use io::ServerIo; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::task::JoinSet; +use tokio_rustls::TlsAcceptor; +use tower::{Service, ServiceBuilder, ServiceExt}; +use tracing::trace; + +use self::body::BoxBody; +use self::connection_info::ActiveConnections; + +pub use http; + +pub mod body; +mod config; +mod connection_handler; +mod connection_info; +mod fuse; +mod io; +mod listener; + +pub use config::Config; +pub use listener::Listener; +pub use listener::ListenerExt; + +pub use connection_info::ConnectInfo; +pub use connection_info::ConnectionId; +pub use connection_info::ConnectionInfo; +pub use connection_info::PeerCertificates; + +pub(crate) type BoxError = Box; +/// h2 alpn in plain format for rustls. +const ALPN_H2: &[u8] = b"h2"; +/// h1 alpn in plain format for rustls. 
+const ALPN_H1: &[u8] = b"http/1.1"; + +#[derive(Default)] +pub struct Builder { + config: Config, + tls_config: Option, +} + +impl Builder { + pub fn new() -> Self { + Self::default() + } + + pub fn config(mut self, config: Config) -> Self { + self.config = config; + self + } + + pub fn tls_config(mut self, tls_config: tokio_rustls::rustls::ServerConfig) -> Self { + self.tls_config = Some(tls_config); + self + } + + pub fn serve( + self, + addr: A, + service: S, + ) -> Result, BoxError> + where + A: std::net::ToSocketAddrs, + S: Service< + Request, + Response = Response, + Error: Into, + Future: Send, + > + Clone + + Send + + 'static, + ResponseBody: http_body::Body> + Send + 'static, + { + let listener = listener::TcpListenerWithOptions::new( + addr, + self.config.tcp_nodelay, + self.config.tcp_keepalive, + )?; + + Self::serve_with_listener(self, listener, service) + } + + fn serve_with_listener( + self, + listener: L, + service: S, + ) -> Result, BoxError> + where + L: Listener, + S: Service< + Request, + Response = Response, + Error: Into, + Future: Send, + > + Clone + + Send + + 'static, + ResponseBody: http_body::Body> + Send + 'static, + { + let local_addr = listener.local_addr()?; + let graceful_shutdown_token = tokio_util::sync::CancellationToken::new(); + let connections = ActiveConnections::default(); + + let tls_config = self.tls_config.map(|mut tls| { + tls.alpn_protocols.push(ALPN_H2.into()); + if self.config.accept_http1 { + tls.alpn_protocols.push(ALPN_H1.into()); + } + Arc::new(tls) + }); + + let (watch_sender, watch_reciever) = tokio::sync::watch::channel(()); + let server = Server { + config: self.config, + tls_config, + listener, + local_addr: local_addr.clone(), + service: ServiceBuilder::new() + .layer(tower::util::BoxCloneService::layer()) + .map_response(|response: Response| response.map(body::boxed)) + .map_err(Into::into) + .service(service), + pending_connections: JoinSet::new(), + connection_handlers: JoinSet::new(), + connections: 
connections.clone(), + graceful_shutdown_token: graceful_shutdown_token.clone(), + _watch_reciever: watch_reciever, + }; + + let handle = ServerHandle(Arc::new(HandleInner { + local_addr, + connections, + graceful_shutdown_token, + watch_sender, + })); + + tokio::spawn(server.serve()); + + Ok(handle) + } +} + +#[derive(Debug)] +pub struct ServerHandle(Arc>); + +#[derive(Debug)] +struct HandleInner { + /// The local address of the server. + local_addr: A, + connections: ActiveConnections, + graceful_shutdown_token: tokio_util::sync::CancellationToken, + watch_sender: tokio::sync::watch::Sender<()>, +} + +impl ServerHandle { + /// Returns the local address of the server + pub fn local_addr(&self) -> &A { + &self.0.local_addr + } + + /// Trigger a graceful shutdown of the server, but don't wait till the server has completed + /// shutting down + pub fn trigger_shutdown(&self) { + self.0.graceful_shutdown_token.cancel(); + } + + /// Completes once the network has been shutdown. + /// + /// This explicitly *does not* trigger the network to shutdown, see `trigger_shutdown` or + /// `shutdown` if you want to trigger shutting down the server. + pub async fn wait_for_shutdown(&self) { + self.0.watch_sender.closed().await + } + + /// Triggers a shutdown of the server and waits for it to complete shutting down. + pub async fn shutdown(&self) { + self.trigger_shutdown(); + self.wait_for_shutdown().await; + } + + /// Checks if the Server has been shutdown. 
+ pub fn is_shutdown(&self) -> bool { + self.0.watch_sender.is_closed() + } + + pub fn connections( + &self, + ) -> std::sync::RwLockReadGuard<'_, HashMap>> { + self.0.connections.read().unwrap() + } + + /// Returns the number of active connections the server is handling + pub fn number_of_connections(&self) -> usize { + self.connections().len() + } +} + +type ConnectingOutput = Result<(ServerIo, Addr), crate::BoxError>; + +struct Server { + config: Config, + tls_config: Option>, + + listener: L, + local_addr: L::Addr, + service: tower::util::BoxCloneService, Response, crate::BoxError>, + + pending_connections: JoinSet>, + connection_handlers: JoinSet<()>, + connections: ActiveConnections, + graceful_shutdown_token: tokio_util::sync::CancellationToken, + // Used to signal to a ServerHandle when the server has completed shutting down + _watch_reciever: tokio::sync::watch::Receiver<()>, +} + +impl Server +where + L: Listener, +{ + async fn serve(mut self) -> Result<(), BoxError> { + loop { + tokio::select! 
{ + _ = self.graceful_shutdown_token.cancelled() => { + trace!("signal received, shutting down"); + break; + }, + (io, remote_addr) = self.listener.accept() => { + self.handle_incomming(io, remote_addr); + }, + Some(maybe_connection) = self.pending_connections.join_next() => { + // If a task panics, just propagate it + let (io, remote_addr) = match maybe_connection.unwrap() { + Ok((io, remote_addr)) => { + (io, remote_addr) + } + Err(e) => { + tracing::debug!(error = %e, "error accepting connection"); + continue; + } + }; + + trace!("connection accepted"); + self.handle_connection(io, remote_addr); + }, + Some(connection_handler_output) = self.connection_handlers.join_next() => { + // If a task panics, just propagate it + let _: () = connection_handler_output.unwrap(); + }, + } + } + + // Shutting down, wait for all connection handlers to finish + self.shutdown().await; + + Ok(()) + } + + fn handle_incomming(&mut self, io: L::Io, remote_addr: L::Addr) { + if let Some(tls) = self.tls_config.clone() { + let tls_acceptor = TlsAcceptor::from(tls); + let allow_insecure = self.config.allow_insecure; + self.pending_connections.spawn(async move { + if allow_insecure { + // XXX: If we want to allow for supporting insecure traffic from other types of + // io, we'll need to implement a generic peekable IO type + if let Some(tcp) = + ::downcast_ref::(&io) + { + // Determine whether new connection is TLS. + let mut buf = [0; 1]; + // `peek` blocks until at least some data is available, so if there is no error then + // it must return the one byte we are requesting. 
+ tcp.peek(&mut buf).await?; + // First byte of a TLS handshake is 0x16, so if it isn't 0x16 then its + // insecure + if buf != [0x16] { + tracing::trace!("accepting insecure connection"); + return Ok((ServerIo::new_io(io), remote_addr)); + } + } else { + tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'"); + } + } + + tracing::trace!("accepting TLS connection"); + let io = tls_acceptor.accept(io).await?; + Ok((ServerIo::new_tls_io(io), remote_addr)) + }); + } else { + self.handle_connection(ServerIo::new_io(io), remote_addr); + } + } + + fn handle_connection(&mut self, io: ServerIo, remote_addr: L::Addr) { + let connection_shutdown_token = self.graceful_shutdown_token.child_token(); + let connection_info = ConnectionInfo::new( + remote_addr, + io.peer_certs(), + connection_shutdown_token.clone(), + ); + let connection_id = connection_info.id(); + let connect_info = connection_info::ConnectInfo { + local_addr: self.local_addr.clone(), + remote_addr: connection_info.remote_address().clone(), + }; + let peer_certificates = connection_info.peer_certificates().cloned(); + let hyper_io = hyper_util::rt::TokioIo::new(io); + + let hyper_svc = TowerToHyperService::new(self.service.clone().map_request( + move |mut request: Request| { + request.extensions_mut().insert(connect_info.clone()); + if let Some(peer_certificates) = peer_certificates.clone() { + request.extensions_mut().insert(peer_certificates); + } + + request.map(body::boxed) + }, + )); + + self.connections + .write() + .unwrap() + .insert(connection_id, connection_info); + let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone()); + + self.connection_handlers + .spawn(connection_handler::serve_connection( + hyper_io, + hyper_svc, + self.config.connection_builder(), + connection_shutdown_token, + self.config.max_connection_age, + on_connection_close, + )); + } + + async fn shutdown(mut self) { + // The time we are willing to wait for a 
connection to get gracefully shutdown before we + // attempt to forcefully shutdown all active connections + const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1); + + // Just to be careful make sure the token is canceled + self.graceful_shutdown_token.cancel(); + + // Terminate any in-progress pending connections + self.pending_connections.shutdown().await; + + // Wait for all connection handlers to terminate + trace!( + "waiting for {} connections to close", + self.connection_handlers.len() + ); + + let graceful_shutdown = + async { while self.connection_handlers.join_next().await.is_some() {} }; + + if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown) + .await + .is_err() + { + tracing::warn!( + "Failed to stop all connection handlers in {:?}. Forcing shutdown.", + CONNECTION_SHUTDOWN_GRACE_PERIOD + ); + self.connection_handlers.shutdown().await; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::Router; + + #[tokio::test] + async fn simple() { + const MESSAGE: &str = "Hello, World!"; + + let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE })); + + let handle = Builder::new().serve(("localhost", 0), app).unwrap(); + + let url = format!("http://{}", handle.local_addr()); + + let response = reqwest::get(url).await.unwrap().bytes().await.unwrap(); + + assert_eq!(response, MESSAGE.as_bytes()); + } + + #[tokio::test] + async fn shutdown() { + const MESSAGE: &str = "Hello, World!"; + + let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE })); + + let handle = Builder::new().serve(("localhost", 0), app).unwrap(); + + let url = format!("http://{}", handle.local_addr()); + + let response = reqwest::get(url).await.unwrap().bytes().await.unwrap(); + + // a request was just made so we should have 1 active connection + assert_eq!(handle.connections().len(), 1); + + assert_eq!(response, MESSAGE.as_bytes()); + + assert!(!handle.is_shutdown()); + + handle.shutdown().await; + + 
assert!(handle.is_shutdown()); + + // Now that the network has been shutdown there should be zero connections + assert_eq!(handle.connections().len(), 0); + } +} diff --git a/crates/sui-http/src/listener.rs b/crates/sui-http/src/listener.rs new file mode 100644 index 0000000000000..b933a05a6f358 --- /dev/null +++ b/crates/sui-http/src/listener.rs @@ -0,0 +1,238 @@ +// Copyright (c) Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::time::Duration; + +/// Types that can listen for connections. +pub trait Listener: Send + 'static { + /// The listener's IO type. + type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static; + + /// The listener's address type. + // all these bounds are necessary to add this information in a request extension + type Addr: Clone + Send + Sync + 'static; + + /// Accept a new incoming connection to this listener. + /// + /// If the underlying accept call can return an error, this function must + /// take care of logging and retrying. + fn accept(&mut self) -> impl std::future::Future + Send; + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> std::io::Result; +} + +/// Extensions to [`Listener`]. +pub trait ListenerExt: Listener + Sized { + /// Run a mutable closure on every accepted `Io`. 
+ /// + /// # Example + /// + /// ``` + /// use tracing::trace; + /// use sui_http::ListenerExt; + /// + /// # async { + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + /// .await + /// .unwrap() + /// .tap_io(|tcp_stream| { + /// if let Err(err) = tcp_stream.set_nodelay(true) { + /// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + /// } + /// }); + /// # }; + /// ``` + fn tap_io(self, tap_fn: F) -> TapIo + where + F: FnMut(&mut Self::Io) + Send + 'static, + { + TapIo { + listener: self, + tap_fn, + } + } +} + +impl ListenerExt for L {} + +impl Listener for tokio::net::TcpListener { + type Io = tokio::net::TcpStream; + type Addr = std::net::SocketAddr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } + } + + #[inline] + fn local_addr(&self) -> std::io::Result { + Self::local_addr(self) + } +} + +#[derive(Debug)] +pub struct TcpListenerWithOptions { + inner: tokio::net::TcpListener, + nodelay: bool, + keepalive: Option, +} + +impl TcpListenerWithOptions { + pub fn new( + addr: A, + nodelay: bool, + keepalive: Option, + ) -> Result { + let std_listener = std::net::TcpListener::bind(addr)?; + std_listener.set_nonblocking(true)?; + let listener = tokio::net::TcpListener::from_std(std_listener)?; + + Ok(Self::from_listener(listener, nodelay, keepalive)) + } + + /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. + pub fn from_listener( + listener: tokio::net::TcpListener, + nodelay: bool, + keepalive: Option, + ) -> Self { + Self { + inner: listener, + nodelay, + keepalive, + } + } + + // Consistent with hyper-0.14, this function does not return an error. 
+ fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) { + if self.nodelay { + if let Err(e) = stream.set_nodelay(true) { + tracing::warn!("error trying to set TCP nodelay: {}", e); + } + } + + if let Some(timeout) = self.keepalive { + let sock_ref = socket2::SockRef::from(&stream); + let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); + + if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { + tracing::warn!("error trying to set TCP keepalive: {}", e); + } + } + } +} + +impl Listener for TcpListenerWithOptions { + type Io = tokio::net::TcpStream; + type Addr = std::net::SocketAddr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + let (io, addr) = Listener::accept(&mut self.inner).await; + self.set_accepted_socket_options(&io); + (io, addr) + } + + #[inline] + fn local_addr(&self) -> std::io::Result { + Listener::local_addr(&self.inner) + } +} + +// Uncomment once we update tokio to >=1.41.0 +// #[cfg(unix)] +// impl Listener for tokio::net::UnixListener { +// type Io = tokio::net::UnixStream; +// type Addr = std::os::unix::net::SocketAddr; + +// async fn accept(&mut self) -> (Self::Io, Self::Addr) { +// loop { +// match Self::accept(self).await { +// Ok((io, addr)) => return (io, addr.into()), +// Err(e) => handle_accept_error(e).await, +// } +// } +// } + +// #[inline] +// fn local_addr(&self) -> std::io::Result { +// Self::local_addr(self).map(Into::into) +// } +// } + +/// Return type of [`ListenerExt::tap_io`]. +/// +/// See that method for details. 
+pub struct TapIo { + listener: L, + tap_fn: F, +} + +impl std::fmt::Debug for TapIo +where + L: Listener + std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TapIo") + .field("listener", &self.listener) + .finish_non_exhaustive() + } +} + +impl Listener for TapIo +where + L: Listener, + F: FnMut(&mut L::Io) + Send + 'static, +{ + type Io = L::Io; + type Addr = L::Addr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + let (mut io, addr) = self.listener.accept().await; + (self.tap_fn)(&mut io); + (io, addr) + } + + fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() + } +} + +async fn handle_accept_error(e: std::io::Error) { + if is_connection_error(&e) { + return; + } + + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + // + // > A possible scenario is that the process has hit the max open files + // > allowed, and so trying to accept a new connection will fail with + // > `EMFILE`. In some cases, it's preferable to just wait for some time, if + // > the application will likely close some files (or connections), and try + // > to accept the connection again. If this option is `true`, the error + // > will be logged at the `error` level, since it is still a big deal, + // > and then the listener will sleep for 1 second. + // + // hyper allowed customizing this but axum does not. 
+ tracing::error!("accept error: {e}"); + tokio::time::sleep(Duration::from_secs(1)).await; +} + +fn is_connection_error(e: &std::io::Error) -> bool { + use std::io::ErrorKind; + + matches!( + e.kind(), + ErrorKind::ConnectionRefused + | ErrorKind::ConnectionAborted + | ErrorKind::ConnectionReset + | ErrorKind::BrokenPipe + | ErrorKind::Interrupted + | ErrorKind::WouldBlock + | ErrorKind::TimedOut + ) +} diff --git a/crates/sui-node/Cargo.toml b/crates/sui-node/Cargo.toml index 430d73f963abe..db60f9770d470 100644 --- a/crates/sui-node/Cargo.toml +++ b/crates/sui-node/Cargo.toml @@ -55,6 +55,7 @@ telemetry-subscribers.workspace = true fastcrypto.workspace = true fastcrypto-zkp.workspace = true move-vm-profiler.workspace = true +sui-http.workspace = true [target.'cfg(msim)'.dependencies] sui-simulator.workspace = true diff --git a/crates/sui-node/src/lib.rs b/crates/sui-node/src/lib.rs index 8e84d01b2e29b..094162ad6e394 100644 --- a/crates/sui-node/src/lib.rs +++ b/crates/sui-node/src/lib.rs @@ -17,7 +17,6 @@ use mysten_network::server::SUI_TLS_SERVER_NAME; use prometheus::Registry; use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt; -use std::net::SocketAddr; use std::path::PathBuf; use std::str::FromStr; #[cfg(msim)] @@ -229,7 +228,7 @@ pub struct SuiNode { config: NodeConfig, validator_components: Mutex>, /// The http server responsible for serving JSON-RPC as well as the experimental rest service - _http_server: Option>, + _http_server: Option, state: Arc, transaction_orchestrator: Option>>, registry_service: RegistryService, @@ -1492,7 +1491,6 @@ impl SuiNode { let tls_config = sui_tls::create_rustls_server_config( config.network_key_pair().copy().private(), SUI_TLS_SERVER_NAME.to_string(), - sui_tls::AllowAll, ); let server = server_builder .bind(config.network_address(), Some(tls_config)) @@ -2020,7 +2018,7 @@ pub async fn build_http_server( prometheus_registry: &Registry, _custom_runtime: Option, software_version: &'static str, -) -> Result>> { 
+) -> Result> { // Validators do not expose these APIs if config.consensus_config().is_some() { return Ok(None); @@ -2125,25 +2123,23 @@ pub async fn build_http_server( rpc_service.into_router().await }; - router = router.merge(rpc_router); - - let listener = tokio::net::TcpListener::bind(&config.json_rpc_address) - .await - .unwrap(); - let addr = listener.local_addr().unwrap(); + let layers = ServiceBuilder::new() + .map_request(|mut request: axum::http::Request<_>| { + if let Some(connect_info) = request.extensions().get::() { + let axum_connect_info = axum::extract::ConnectInfo(connect_info.remote_addr); + request.extensions_mut().insert(axum_connect_info); + } + request + }) + .layer(axum::middleware::from_fn(server_timing_middleware)); - router = router.layer(axum::middleware::from_fn(server_timing_middleware)); + router = router.merge(rpc_router).layer(layers); - let handle = tokio::spawn(async move { - axum::serve( - listener, - router.into_make_service_with_connect_info::(), - ) - .await - .unwrap() - }); + let handle = sui_http::Builder::new() + .serve(&config.json_rpc_address, router) + .map_err(|e| anyhow::anyhow!("{e}"))?; - info!(local_addr =? addr, "Sui JSON-RPC server listening on {addr}"); + info!(local_addr =? 
handle.local_addr(), "Sui JSON-RPC server listening on {}", handle.local_addr()); Ok(Some(handle)) } diff --git a/crates/sui-tls/src/lib.rs b/crates/sui-tls/src/lib.rs index 7f40317d43303..a94169391b7c0 100644 --- a/crates/sui-tls/src/lib.rs +++ b/crates/sui-tls/src/lib.rs @@ -5,6 +5,8 @@ mod acceptor; mod certgen; mod verifier; +use std::sync::Arc; + pub use acceptor::{TlsAcceptor, TlsConnectionInfo}; pub use certgen::SelfSignedCertificate; use rustls::ClientConfig; @@ -20,7 +22,29 @@ use tokio_rustls::rustls::ServerConfig; pub const SUI_VALIDATOR_SERVER_NAME: &str = "sui"; -pub fn create_rustls_server_config( +pub fn create_rustls_server_config( + private_key: Ed25519PrivateKey, + server_name: String, +) -> ServerConfig { + // TODO: refactor to use key bytes + let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str()); + let tls_cert = self_signed_cert.rustls_certificate(); + let tls_private_key = self_signed_cert.rustls_private_key(); + let mut tls_config = rustls::ServerConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap_or_else(|e| panic!("Failed to create TLS server config: {:?}", e)) + .with_no_client_auth() + .with_single_cert(vec![tls_cert], tls_private_key) + .unwrap_or_else(|e| panic!("Failed to create TLS server config: {:?}", e)); + tls_config.alpn_protocols = vec![b"h2".to_vec()]; + tls_config +} + +/// Create a TLS server config which requires mTLS, eg the client to also provide a cert and be +/// verified by the server based on the provided policy +pub fn create_rustls_server_config_with_client_verifier( private_key: Ed25519PrivateKey, server_name: String, allower: A, diff --git a/crates/sui-tls/src/verifier.rs b/crates/sui-tls/src/verifier.rs index b1e87fbf88823..6c69d2010fff4 100644 --- a/crates/sui-tls/src/verifier.rs +++ b/crates/sui-tls/src/verifier.rs @@ -88,7 +88,7 @@ impl ClientCertVerifier { let mut config = 
rustls::ServerConfig::builder_with_provider(Arc::new( rustls::crypto::ring::default_provider(), )) - .with_safe_default_protocol_versions()? + .with_protocol_versions(&[&rustls::version::TLS13])? .with_client_cert_verifier(std::sync::Arc::new(self)) .with_single_cert(certificates, private_key)?; config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; @@ -186,7 +186,7 @@ impl ServerCertVerifier { rustls::ClientConfig::builder_with_provider(Arc::new( rustls::crypto::ring::default_provider(), )) - .with_safe_default_protocol_versions()? + .with_protocol_versions(&[&rustls::version::TLS13])? .dangerous() .with_custom_certificate_verifier(std::sync::Arc::new(self)) .with_client_auth_cert(certificates, private_key) @@ -198,7 +198,7 @@ impl ServerCertVerifier { Ok(rustls::ClientConfig::builder_with_provider(Arc::new( rustls::crypto::ring::default_provider(), )) - .with_safe_default_protocol_versions()? + .with_protocol_versions(&[&rustls::version::TLS13])? .dangerous() .with_custom_certificate_verifier(std::sync::Arc::new(self)) .with_no_client_auth())