diff --git a/kble/src/app.rs b/kble/src/app.rs index e4651ec..b888b81 100644 --- a/kble/src/app.rs +++ b/kble/src/app.rs @@ -8,10 +8,17 @@ use std::collections::HashMap; use tokio::sync::broadcast; use tracing::{debug, warn}; +struct Connection { + backend: plug::Backend, + stream: Option, + sink: Option, +} + struct Connections<'a> { // Some: connections not used yet // None: connections is used in a link - map: HashMap<&'a str, (Option, Option)>, + map: HashMap<&'a str, Connection>, + termination_grace_period_secs: u64, } struct Link<'a> { @@ -21,8 +28,8 @@ struct Link<'a> { dest: plug::PlugSink, } -pub async fn run(config: &Config) -> Result<()> { - let mut conns = connect_to_plugs(config).await?; +pub async fn run(config: &Config, termination_grace_period_secs: u64) -> Result<()> { + let mut conns = connect_to_plugs(config, termination_grace_period_secs).await?; let links = connect_links(&mut conns, config); let (quit_tx, _) = broadcast::channel(1); @@ -37,48 +44,106 @@ pub async fn run(config: &Config) -> Result<()> { let links = future::join_all(link_futs).await; let links = links.into_iter().chain(std::iter::once(terminated_link)); - let link_close_futs = future::try_join_all(links.map(|link| link.close())); - future::try_join(conns.close_all(), link_close_futs).await?; + for link in links { + conns.return_link(link); + } + conns.close_and_wait().await?; Ok(()) } impl<'a> Connections<'a> { - fn new() -> Self { + fn new(termination_grace_period_secs: u64) -> Self { Self { map: HashMap::new(), + termination_grace_period_secs, } } - fn insert(&mut self, name: &'a str, stream: plug::PlugStream, sink: plug::PlugSink) { - self.map.insert(name, (Some(stream), Some(sink))); + fn insert( + &mut self, + name: &'a str, + backend: plug::Backend, + stream: plug::PlugStream, + sink: plug::PlugSink, + ) { + self.map.insert( + name, + Connection { + backend, + stream: Some(stream), + sink: Some(sink), + }, + ); + } + + fn return_link(&mut self, link: Link<'a>) { + let conn = self.map.get_mut(link.source_name).unwrap_or_else(|| { + panic!( + "tried to return a invalid link with source name {}", + link.source_name, + ) + }); + conn.stream = Some(link.source); + + let conn = self.map.get_mut(link.dest_name).unwrap_or_else(|| { + panic!( + "tried to return a invalid link with dest name {}", + link.dest_name, + ) + }); + conn.sink = Some(link.dest); } - // close all connections whose sink is not used in a link - async fn close_all(self) -> Result<()> { - let futs = self.map.into_iter().map(|(name, (_, sink))| async move { - if let Some(mut s) = sink { - debug!("Closing {name}"); - s.close().await?; - debug!("Closed {name}"); + // close all connections + // assume all links are returned + async fn close_and_wait(self) -> Result<()> { + let futs = self.map.into_iter().map(|(name, mut conn)| async move { + let fut = async { + if let Some(mut s) = conn.sink { + debug!("Closing {name}"); + s.close().await?; + debug!("Closed {name}"); + } + debug!("Waiting for plug {name} to exit"); + conn.backend.wait().await?; + debug!("Plug {name} exited"); + anyhow::Ok(()) + }; + let close_result = tokio::time::timeout( + std::time::Duration::from_secs(self.termination_grace_period_secs), + fut, + ) + .await; + + match close_result { + Ok(result) => result, + Err(_) => { + // abandon the connection + warn!("Plug {name} didn't exit in time"); + conn.backend.kill().await?; + Ok(()) + } } - anyhow::Ok(()) }); future::try_join_all(futs).await?; Ok(()) } fn take_stream(&mut self, name: &str) -> Option { - self.map.get_mut(name)?.0.take() + self.map.get_mut(name)?.stream.take() } fn take_sink(&mut self, name: &str) -> Option { - self.map.get_mut(name)?.1.take() + self.map.get_mut(name)?.sink.take() } } -async fn connect_to_plugs(config: &Config) -> Result { - let mut conns = Connections::new(); +async fn connect_to_plugs( + config: &Config, + termination_grace_period_secs: u64, +) -> Result { + let mut conns = Connections::new(termination_grace_period_secs); for (name, url) in config.plugs().iter() { debug!("Connecting to {name}"); let connect_result = plug::connect(url).await.with_context(move || { @@ -87,16 +152,16 @@ async fn connect_to_plugs(config: &Config) -> Result { } }); - let (sink, stream) = match connect_result { + let (backend, sink, stream) = match connect_result { Ok(p) => p, Err(e) => { warn!("Error connecting to {name}: {e}"); - conns.close_all().await?; + conns.close_and_wait().await?; return Err(e); } }; debug!("Connected to {name}"); - conns.insert(name.as_str(), stream, sink); + conns.insert(name.as_str(), backend, stream, sink); } Ok(conns) } @@ -149,11 +214,4 @@ impl<'a> Link<'a> { } self } - - async fn close(mut self) -> Result<()> { - debug!("Closing {}", self.dest_name); - self.dest.close().await?; - debug!("Closed {}", self.dest_name); - Ok(()) - } } diff --git a/kble/src/main.rs b/kble/src/main.rs index 09862eb..4d74f82 100644 --- a/kble/src/main.rs +++ b/kble/src/main.rs @@ -16,6 +16,11 @@ use spaghetti::{Config, Raw}; struct Args { #[clap(long, short)] spaghetti: PathBuf, + + /// Period to wait for each child process to exit after a closing handshake + /// before killing it + #[clap(long, default_value_t = 10)] + termination_grace_period_secs: u64, } impl Args { @@ -45,6 +50,6 @@ async fn main() -> Result<()> { let args = Args::parse_with_license_notice(include_notice!()); let config = args.load_spaghetti_config()?; - app::run(&config).await?; + app::run(&config, args.termination_grace_period_secs).await?; Ok(()) } diff --git a/kble/src/plug.rs b/kble/src/plug.rs index 41ba096..c3a547f 100644 --- a/kble/src/plug.rs +++ b/kble/src/plug.rs @@ -5,7 +5,7 @@ use futures::{future, stream, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use pin_project::pin_project; use tokio::{ io::{AsyncRead, AsyncWrite}, - process::{ChildStdin, ChildStdout}, + process::{Child, ChildStdin, ChildStdout}, }; use tokio_tungstenite::{ tungstenite::{protocol::Role, Message}, @@ -16,7 +16,33 @@ use url::Url; pub type PlugSink = Pin, Error = anyhow::Error> + Send + 'static>>; pub type PlugStream = Pin>> + Send + 'static>>; -pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> { +pub enum Backend { + WebSocketClient, + StdioProcess(Child), +} + +impl Backend { + pub async fn wait(&mut self) -> Result<()> { + match self { + Backend::WebSocketClient => Ok(()), + Backend::StdioProcess(proc) => { + proc.wait() + .await + .with_context(|| format!("Failed to wait for {:?}", proc))?; + Ok(()) + } + } + } + + pub async fn kill(self) -> Result<()> { + match self { + Backend::WebSocketClient => Ok(()), + Backend::StdioProcess(mut proc) => proc.kill().await.map_err(Into::into), + } + } +} + +pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { match url.scheme() { "exec" => connect_exec(url).await, "ws" | "wss" => connect_ws(url).await, @@ -24,7 +50,7 @@ pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> { } } -async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> { +async fn connect_exec(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { assert_eq!(url.scheme(), "exec"); ensure!(url.username().is_empty()); ensure!(url.password().is_none()); @@ -32,18 +58,19 @@ async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> { ensure!(url.port().is_none()); ensure!(url.query().is_none()); ensure!(url.fragment().is_none()); - let proc = tokio::process::Command::new("sh") + let mut proc = tokio::process::Command::new("sh") .args(["-c", url.path()]) .stderr(Stdio::inherit()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn() .with_context(|| format!("Failed to spawn {}", url))?; - let stdin = proc.stdin.unwrap(); - let stdout = proc.stdout.unwrap(); + let stdin = proc.stdin.take().unwrap(); + let stdout = proc.stdout.take().unwrap(); let stdio = ChildStdio { stdin, stdout }; let wss = WebSocketStream::from_raw_socket(stdio, Role::Client, None).await; - Ok(wss_to_pair(wss)) + let (stream, sink) = wss_to_pair(wss); + Ok((Backend::StdioProcess(proc), stream, sink)) } #[pin_project] @@ -89,11 +116,12 @@ impl AsyncRead for ChildStdio { } } -async fn connect_ws(url: &Url) -> Result<(PlugSink, PlugStream)> { +async fn connect_ws(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { let (wss, _resp) = tokio_tungstenite::connect_async(url) .await .with_context(|| format!("Failed to connect to {}", url))?; - Ok(wss_to_pair(wss)) + let (stream, sink) = wss_to_pair(wss); + Ok((Backend::WebSocketClient, stream, sink)) } fn wss_to_pair(wss: WebSocketStream) -> (PlugSink, PlugStream)