diff --git a/kble/src/app.rs b/kble/src/app.rs index 7c165c7..b888b81 100644 --- a/kble/src/app.rs +++ b/kble/src/app.rs @@ -18,7 +18,7 @@ struct Connections<'a> { // Some: connections not used yet // None: connections is used in a link map: HashMap<&'a str, Connection>, - max_child_wait_secs: u64, + termination_grace_period_secs: u64, } struct Link<'a> { @@ -28,8 +28,8 @@ struct Link<'a> { dest: plug::PlugSink, } -pub async fn run(config: &Config, max_child_wait_secs: u64) -> Result<()> { - let mut conns = connect_to_plugs(config, max_child_wait_secs).await?; +pub async fn run(config: &Config, termination_grace_period_secs: u64) -> Result<()> { + let mut conns = connect_to_plugs(config, termination_grace_period_secs).await?; let links = connect_links(&mut conns, config); let (quit_tx, _) = broadcast::channel(1); @@ -53,10 +53,10 @@ pub async fn run(config: &Config, max_child_wait_secs: u64) -> Result<()> { } impl<'a> Connections<'a> { - fn new(max_child_wait_secs: u64) -> Self { + fn new(termination_grace_period_secs: u64) -> Self { Self { map: HashMap::new(), - max_child_wait_secs, + termination_grace_period_secs, } } @@ -98,7 +98,7 @@ impl<'a> Connections<'a> { // close all connections // assume all links are returned async fn close_and_wait(self) -> Result<()> { - let futs = self.map.into_iter().map(|(name, conn)| async move { + let futs = self.map.into_iter().map(|(name, mut conn)| async move { let fut = async { if let Some(mut s) = conn.sink { debug!("Closing {name}"); @@ -111,7 +111,7 @@ impl<'a> Connections<'a> { anyhow::Ok(()) }; let close_result = tokio::time::timeout( - std::time::Duration::from_secs(self.max_child_wait_secs), + std::time::Duration::from_secs(self.termination_grace_period_secs), fut, ) .await; @@ -121,6 +121,7 @@ impl<'a> Connections<'a> { Err(_) => { // abandon the connection warn!("Plug {name} didn't exit in time"); + conn.backend.kill().await?; Ok(()) } } @@ -138,8 +139,11 @@ impl<'a> Connections<'a> { } } -async fn connect_to_plugs(config: &Config, max_child_wait_secs: u64) -> Result { - let mut conns = Connections::new(max_child_wait_secs); +async fn connect_to_plugs( + config: &Config, + termination_grace_period_secs: u64, +) -> Result { + let mut conns = Connections::new(termination_grace_period_secs); for (name, url) in config.plugs().iter() { debug!("Connecting to {name}"); let connect_result = plug::connect(url).await.with_context(move || { diff --git a/kble/src/main.rs b/kble/src/main.rs index ae82365..4d74f82 100644 --- a/kble/src/main.rs +++ b/kble/src/main.rs @@ -17,9 +17,10 @@ struct Args { #[clap(long, short)] spaghetti: PathBuf, - /// Maximum time to wait for a child process to exit after a closing handshake + /// Period to wait for each child process to exit after a closing handshake + /// before killing it #[clap(long, default_value_t = 10)] - max_child_wait_secs: u64, + termination_grace_period_secs: u64, } impl Args { @@ -49,6 +50,6 @@ async fn main() -> Result<()> { let args = Args::parse_with_license_notice(include_notice!()); let config = args.load_spaghetti_config()?; - app::run(&config, args.max_child_wait_secs).await?; + app::run(&config, args.termination_grace_period_secs).await?; Ok(()) } diff --git a/kble/src/plug.rs b/kble/src/plug.rs index d43c779..c3a547f 100644 --- a/kble/src/plug.rs +++ b/kble/src/plug.rs @@ -22,10 +22,10 @@ pub enum Backend { } impl Backend { - pub async fn wait(self) -> Result<()> { + pub async fn wait(&mut self) -> Result<()> { match self { Backend::WebSocketClient => Ok(()), - Backend::StdioProcess(mut proc) => { + Backend::StdioProcess(proc) => { proc.wait() .await .with_context(|| format!("Failed to wait for {:?}", proc))?; @@ -33,6 +33,13 @@ impl Backend { } } } + + pub async fn kill(self) -> Result<()> { + match self { + Backend::WebSocketClient => Ok(()), + Backend::StdioProcess(mut proc) => proc.kill().await.map_err(Into::into), + } + } } pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {