diff --git a/Cargo.lock b/Cargo.lock index 18f05dc167fd1..6fca532f7f4f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8447,6 +8447,7 @@ dependencies = [ "prometheus", "serde", "snap", + "sui-http", "tokio", "tokio-rustls 0.26.0", "tokio-stream", diff --git a/crates/mysten-network/Cargo.toml b/crates/mysten-network/Cargo.toml index 0762b9210445f..5ee2bf60d84a6 100644 --- a/crates/mysten-network/Cargo.toml +++ b/crates/mysten-network/Cargo.toml @@ -33,3 +33,4 @@ tower.workspace = true tower-http.workspace = true pin-project-lite = "0.2.13" tracing.workspace = true +sui-http.workspace = true diff --git a/crates/mysten-network/src/config.rs b/crates/mysten-network/src/config.rs index eab88a024ec41..dbf339c47c376 100644 --- a/crates/mysten-network/src/config.rs +++ b/crates/mysten-network/src/config.rs @@ -12,7 +12,7 @@ use std::time::Duration; use tokio_rustls::rustls::ClientConfig; use tonic::transport::Channel; -#[derive(Debug, Default, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct Config { /// Set the concurrency limit applied to on requests inbound per connection. pub concurrency_limit_per_connection: Option, @@ -106,4 +106,15 @@ impl Config { ) -> Result { connect_lazy_with_config(addr, tls_config, self) } + + pub(crate) fn http_config(&self) -> sui_http::Config { + sui_http::Config::default() + .initial_stream_window_size(self.http2_initial_stream_window_size) + .initial_connection_window_size(self.http2_initial_connection_window_size) + .max_concurrent_streams(self.http2_max_concurrent_streams) + .http2_keepalive_timeout(self.http2_keepalive_timeout) + .http2_keepalive_interval(self.http2_keepalive_interval) + .tcp_keepalive(self.tcp_keepalive) + .tcp_nodelay(self.tcp_nodelay.unwrap_or_default()) + } } diff --git a/crates/mysten-network/src/server.rs b/crates/mysten-network/src/server.rs index d60a3b4637ebf..0f7ab47aeff69 100644 --- a/crates/mysten-network/src/server.rs +++ b/crates/mysten-network/src/server.rs @@ -1,143 +1,44 @@ // Copyright (c) Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 + use crate::metrics::{ DefaultMetricsCallbackProvider, MetricsCallbackProvider, MetricsHandler, GRPC_ENDPOINT_PATH_HEADER, }; use crate::{ config::Config, - multiaddr::{parse_dns, parse_ip4, parse_ip6, Multiaddr, Protocol}, + multiaddr::{Multiaddr, Protocol}, }; use eyre::{eyre, Result}; -use futures::stream::FuturesUnordered; -use futures::{FutureExt, Stream, StreamExt}; -use std::pin::Pin; -use std::sync::Arc; +use std::convert::Infallible; use std::task::{Context, Poll}; -use std::{convert::Infallible, net::SocketAddr}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_rustls::rustls::ServerConfig; -use tokio_rustls::{server::TlsStream, TlsAcceptor}; use tonic::codegen::http::HeaderValue; use tonic::{ body::BoxBody, - codegen::{ - http::{Request, Response}, - BoxFuture, - }, + codegen::http::{Request, Response}, server::NamedService, - transport::server::Router, -}; -use tower::{ - layer::util::{Identity, Stack}, - limit::GlobalConcurrencyLimitLayer, - load_shed::LoadShedLayer, - util::Either, - Layer, Service, ServiceBuilder, }; -use tower_http::classify::{GrpcErrorsAsFailures, SharedClassifier}; +use tower::{Layer, Service, ServiceBuilder}; use tower_http::propagate_header::PropagateHeaderLayer; use tower_http::set_header::SetRequestHeaderLayer; -use tower_http::trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer}; -use tracing::debug; +use tower_http::trace::TraceLayer; pub struct ServerBuilder { - router: Router>, + config: Config, + metrics_provider: M, + router: tonic::service::Routes, health_reporter: tonic_health::server::HealthReporter, } -type AddPathToHeaderFunction = fn(&Request) -> Option; - -type WrapperService = Stack< - Stack< - PropagateHeaderLayer, - Stack< - TraceLayer< - SharedClassifier, - DefaultMakeSpan, - MetricsHandler, - MetricsHandler, - DefaultOnBodyChunk, - DefaultOnEos, - MetricsHandler, - >, - Stack< - SetRequestHeaderLayer, - Stack< - RequestLifetimeLayer, - Stack< - Either, - Stack, Identity>, - >, - >, - >, - >, - >, - Identity, ->; - impl ServerBuilder { pub fn from_config(config: &Config, metrics_provider: M) -> Self { - let mut builder = tonic::transport::server::Server::builder(); - - if let Some(limit) = config.concurrency_limit_per_connection { - builder = builder.concurrency_limit_per_connection(limit); - } - - if let Some(timeout) = config.request_timeout { - builder = builder.timeout(timeout); - } - - if let Some(tcp_nodelay) = config.tcp_nodelay { - builder = builder.tcp_nodelay(tcp_nodelay); - } - - let load_shed = config - .load_shed - .unwrap_or_default() - .then_some(tower::load_shed::LoadShedLayer::new()); - - let metrics = MetricsHandler::new(metrics_provider.clone()); - - let request_metrics = TraceLayer::new_for_grpc() - .on_request(metrics.clone()) - .on_response(metrics.clone()) - .on_failure(metrics); - - let global_concurrency_limit = config - .global_concurrency_limit - .map(tower::limit::GlobalConcurrencyLimitLayer::new); - - fn add_path_to_request_header(request: &Request) -> Option { - let path = request.uri().path(); - Some(HeaderValue::from_str(path).unwrap()) - } - - let layer = ServiceBuilder::new() - .option_layer(global_concurrency_limit) - .option_layer(load_shed) - .layer(RequestLifetimeLayer { metrics_provider }) - .layer(SetRequestHeaderLayer::overriding( - GRPC_ENDPOINT_PATH_HEADER.clone(), - add_path_to_request_header as AddPathToHeaderFunction, - )) - .layer(request_metrics) - .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone())) - .into_inner(); - let (health_reporter, health_service) = tonic_health::server::health_reporter(); - let router = builder - .initial_stream_window_size(config.http2_initial_stream_window_size) - .initial_connection_window_size(config.http2_initial_connection_window_size) - .http2_keepalive_interval(config.http2_keepalive_interval) - .http2_keepalive_timeout(config.http2_keepalive_timeout) - .max_concurrent_streams(config.http2_max_concurrent_streams) - .tcp_keepalive(config.tcp_keepalive) - .layer(layer) - .add_service(health_service); + let router = tonic::service::Routes::new(health_service); Self { + config: config.to_owned(), + metrics_provider, router, health_reporter, } @@ -162,201 +63,79 @@ impl ServerBuilder { } pub async fn bind(self, addr: &Multiaddr, tls_config: Option) -> Result { - let mut iter = addr.iter(); - - let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel(); - let rx_cancellation = rx_cancellation.map(|_| ()); - let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) = match iter - .next() - .ok_or_else(|| eyre!("malformed addr"))? - { - Protocol::Dns(_) => { - let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?; - let (local_addr, incoming) = - listen_and_update_multiaddr(addr, (dns_name.to_string(), tcp_port), tls_config) - .await?; - let server = Box::pin( - self.router - .serve_with_incoming_shutdown(incoming, rx_cancellation), - ); - (local_addr, server) - } - Protocol::Ip4(_) => { - let (socket_addr, _http_or_https) = parse_ip4(addr)?; - let (local_addr, incoming) = - listen_and_update_multiaddr(addr, socket_addr, tls_config).await?; - let server = Box::pin( - self.router - .serve_with_incoming_shutdown(incoming, rx_cancellation), - ); - (local_addr, server) - } - Protocol::Ip6(_) => { - let (socket_addr, _http_or_https) = parse_ip6(addr)?; - let (local_addr, incoming) = - listen_and_update_multiaddr(addr, socket_addr, tls_config).await?; - let server = Box::pin( - self.router - .serve_with_incoming_shutdown(incoming, rx_cancellation), - ); - (local_addr, server) - } - unsupported => return Err(eyre!("unsupported protocol {unsupported}")), - }; - - Ok(Server { - server, - cancel_handle: Some(tx_cancellation), - local_addr, - health_reporter: self.health_reporter, - }) - } -} - -async fn listen_and_update_multiaddr( - address: &Multiaddr, - socket_addr: T, - tls_config: Option, -) -> Result<( - Multiaddr, - impl Stream>, -)> { - let listener = TcpListener::bind(socket_addr).await?; - let local_addr = listener.local_addr()?; - let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port()); - - let tls_acceptor = tls_config.map(|tls_config| TlsAcceptor::from(Arc::new(tls_config))); - let incoming = TcpOrTlsListener::new(listener, tls_acceptor); - let stream = async_stream::stream! { - let mut new_connections = FuturesUnordered::new(); - loop { - tokio::select! { - result = incoming.accept_raw() => { - match result { - Ok((stream, addr)) => { - new_connections.push(incoming.maybe_upgrade(stream, addr)); - } - Err(e) => yield Err(e), - } - } - Some(result) = new_connections.next() => { - yield result; - } - } - } - }; - - Ok((local_addr, stream)) -} - -pub struct TcpOrTlsListener { - listener: TcpListener, - tls_acceptor: Option, -} - -impl TcpOrTlsListener { - fn new(listener: TcpListener, tls_acceptor: Option) -> Self { - Self { - listener, - tls_acceptor, - } - } - - async fn accept_raw(&self) -> std::io::Result<(TcpStream, SocketAddr)> { - self.listener.accept().await - } - - async fn maybe_upgrade( - &self, - stream: TcpStream, - addr: SocketAddr, - ) -> std::io::Result { - if self.tls_acceptor.is_none() { - return Ok(TcpOrTlsStream::Tcp(stream, addr)); - } + let http_config = self + .config + .http_config() + // Temporarily continue allowing clients to connection without TLS even when the server + // is configured with a tls_config + .allow_insecure(true); + + let request_timeout = self.config.request_timeout; + let metrics_provider = self.metrics_provider; + let metrics = MetricsHandler::new(metrics_provider.clone()); + let request_metrics = TraceLayer::new_for_grpc() + .on_request(metrics.clone()) + .on_response(metrics.clone()) + .on_failure(metrics); - // Determine whether new connection is TLS. - let mut buf = [0; 1]; - // `peek` blocks until at least some data is available, so if there is no error then - // it must return the one byte we are requesting. - stream.peek(&mut buf).await?; - if buf[0] == 0x16 { - // First byte of a TLS handshake is 0x16. - debug!("accepting TLS connection from {addr:?}"); - let stream = self.tls_acceptor.as_ref().unwrap().accept(stream).await?; - Ok(TcpOrTlsStream::Tls(stream, addr)) - } else { - debug!("accepting TCP connection from {addr:?}"); - Ok(TcpOrTlsStream::Tcp(stream, addr)) + fn add_path_to_request_header(request: &Request) -> Option { + let path = request.uri().path(); + Some(HeaderValue::from_str(path).unwrap()) } - } -} - -pub enum TcpOrTlsStream { - Tcp(TcpStream, SocketAddr), - Tls(TlsStream, SocketAddr), -} -impl AsyncRead for TcpOrTlsStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf, - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_read(cx, buf), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_read(cx, buf), - } - } -} + let limiting_layers = ServiceBuilder::new() + .option_layer( + self.config + .load_shed + .unwrap_or_default() + .then_some(tower::load_shed::LoadShedLayer::new()), + ) + .option_layer( + self.config + .global_concurrency_limit + .map(tower::limit::GlobalConcurrencyLimitLayer::new), + ); + let route_layers = ServiceBuilder::new() + .map_request(|mut request: http::Request<_>| { + if let Some(connect_info) = request.extensions().get::() { + let tonic_connect_info = tonic::transport::server::TcpConnectInfo { + local_addr: Some(connect_info.local_addr), + remote_addr: Some(connect_info.remote_addr), + }; + request.extensions_mut().insert(tonic_connect_info); + } + request + }) + .layer(RequestLifetimeLayer { metrics_provider }) + .layer(SetRequestHeaderLayer::overriding( + GRPC_ENDPOINT_PATH_HEADER.clone(), + add_path_to_request_header, + )) + .layer(request_metrics) + .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone())) + .layer_fn(move |service| { + crate::grpc_timeout::GrpcTimeout::new(service, request_timeout) + }); -impl AsyncWrite for TcpOrTlsStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_write(cx, buf), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_write(cx, buf), - } - } + let mut builder = sui_http::Builder::new().config(http_config); - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_flush(cx), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_flush(cx), + if let Some(tls_config) = tls_config { + builder = builder.tls_config(tls_config); } - } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.get_mut() { - TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_shutdown(cx), - TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_shutdown(cx), - } - } -} + let server_handle = builder + .serve( + addr, + limiting_layers.service(self.router.into_axum_router().layer(route_layers)), + ) + .map_err(|e| eyre!(e))?; -impl tonic::transport::server::Connected for TcpOrTlsStream { - type ConnectInfo = tonic::transport::server::TcpConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - match self { - TcpOrTlsStream::Tcp(stream, addr) => Self::ConnectInfo { - local_addr: stream.local_addr().ok(), - remote_addr: Some(*addr), - }, - TcpOrTlsStream::Tls(stream, addr) => Self::ConnectInfo { - local_addr: stream.get_ref().0.local_addr().ok(), - remote_addr: Some(*addr), - }, - } + let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port()); + Ok(Server { + server: server_handle, + local_addr, + health_reporter: self.health_reporter, + }) } } @@ -364,15 +143,15 @@ impl tonic::transport::server::Connected for TcpOrTlsStream { pub const SUI_TLS_SERVER_NAME: &str = "sui"; pub struct Server { - server: BoxFuture<(), tonic::transport::Error>, - cancel_handle: Option>, + server: sui_http::ServerHandle, local_addr: Multiaddr, health_reporter: tonic_health::server::HealthReporter, } impl Server { pub async fn serve(self) -> Result<(), tonic::transport::Error> { - self.server.await + self.server.wait_for_shutdown().await; + Ok(()) } pub fn local_addr(&self) -> &Multiaddr { @@ -383,8 +162,8 @@ impl Server { self.health_reporter.clone() } - pub fn take_cancel_handle(&mut self) -> Option> { - self.cancel_handle.take() + pub fn handle(&self) -> &sui_http::ServerHandle { + &self.server } } @@ -458,15 +237,13 @@ mod test { let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap(); let config = Config::new(); - let mut server = config + let server = config .server_builder_with_metrics(metrics.clone()) .bind(&address, None) .await .unwrap(); let address = server.local_addr().to_owned(); - let cancel_handle = server.take_cancel_handle().unwrap(); - let server_handle = tokio::spawn(server.serve()); let channel = config.connect(&address, None).await.unwrap(); let mut client = HealthClient::new(channel); @@ -477,8 +254,7 @@ mod test { .await .unwrap(); - cancel_handle.send(()).unwrap(); - server_handle.await.unwrap().unwrap(); + server.server.shutdown().await; assert!(metrics.metrics_called.lock().unwrap().deref()); } @@ -521,15 +297,13 @@ mod test { let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap(); let config = Config::new(); - let mut server = config + let server = config .server_builder_with_metrics(metrics.clone()) .bind(&address, None) .await .unwrap(); let address = server.local_addr().to_owned(); - let cancel_handle = server.take_cancel_handle().unwrap(); - let server_handle = tokio::spawn(server.serve()); let channel = config.connect(&address, None).await.unwrap(); let mut client = HealthClient::new(channel); @@ -542,18 +316,15 @@ mod test { }) .await; - cancel_handle.send(()).unwrap(); - server_handle.await.unwrap().unwrap(); + server.server.shutdown().await; assert!(metrics.metrics_called.lock().unwrap().deref()); } async fn test_multiaddr(address: Multiaddr) { let config = Config::new(); - let mut server = config.server_builder().bind(&address, None).await.unwrap(); - let address = server.local_addr().to_owned(); - let cancel_handle = server.take_cancel_handle().unwrap(); - let server_handle = tokio::spawn(server.serve()); + let server_handle = config.server_builder().bind(&address, None).await.unwrap(); + let address = server_handle.local_addr().to_owned(); let channel = config.connect(&address, None).await.unwrap(); let mut client = HealthClient::new(channel); @@ -564,8 +335,7 @@ mod test { .await .unwrap(); - cancel_handle.send(()).unwrap(); - server_handle.await.unwrap().unwrap(); + server_handle.server.shutdown().await; } #[tokio::test] diff --git a/crates/sui-core/src/authority_server.rs b/crates/sui-core/src/authority_server.rs index 0f32457740d50..cf0d861691e82 100644 --- a/crates/sui-core/src/authority_server.rs +++ b/crates/sui-core/src/authority_server.rs @@ -44,7 +44,6 @@ use sui_types::{ }, }; use tap::TapFallible; -use tokio::task::JoinHandle; use tonic::metadata::{Ascii, MetadataValue}; use tracing::{error, error_span, info, Instrument}; @@ -72,32 +71,22 @@ use tonic::transport::server::TcpConnectInfo; mod server_tests; pub struct AuthorityServerHandle { - tx_cancellation: tokio::sync::oneshot::Sender<()>, - local_addr: Multiaddr, - handle: JoinHandle>, + server_handle: mysten_network::server::Server, } impl AuthorityServerHandle { pub async fn join(self) -> Result<(), io::Error> { - // Note that dropping `self.complete` would terminate the server. - self.handle - .await? - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + self.server_handle.handle().wait_for_shutdown().await; Ok(()) } pub async fn kill(self) -> Result<(), io::Error> { - self.tx_cancellation.send(()).map_err(|_e| { - io::Error::new(io::ErrorKind::Other, "could not send cancellation signal!") - })?; - self.handle - .await? - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + self.server_handle.handle().shutdown().await; Ok(()) } pub fn address(&self) -> &Multiaddr { - &self.local_addr + self.server_handle.local_addr() } } @@ -153,7 +142,7 @@ impl AuthorityServer { SUI_TLS_SERVER_NAME.to_string(), sui_tls::AllowAll, ); - let mut server = mysten_network::config::Config::new() + let server = mysten_network::config::Config::new() .server_builder() .add_service(ValidatorServer::new(ValidatorService::new_for_tests( self.state, @@ -166,9 +155,7 @@ impl AuthorityServer { let local_addr = server.local_addr().to_owned(); info!("Listening to traffic on {local_addr}"); let handle = AuthorityServerHandle { - tx_cancellation: server.take_cancel_handle().unwrap(), - local_addr, - handle: spawn_monitored_task!(server.serve()), + server_handle: server, }; Ok(handle) }