From 71ba574fa98d827a7416467b8d6a6c9f192a68a6 Mon Sep 17 00:00:00 2001 From: KOBAYASHI Kazuhiro Date: Tue, 23 Apr 2024 08:07:11 +0900 Subject: [PATCH 1/3] kble: wait for child process --- kble/src/app.rs | 98 ++++++++++++++++++++++++++++++++++++------------ kble/src/plug.rs | 39 ++++++++++++++----- 2 files changed, 103 insertions(+), 34 deletions(-) diff --git a/kble/src/app.rs b/kble/src/app.rs index e4651ec..01037c0 100644 --- a/kble/src/app.rs +++ b/kble/src/app.rs @@ -8,10 +8,16 @@ use std::collections::HashMap; use tokio::sync::broadcast; use tracing::{debug, warn}; +struct Connection { + backend: plug::Backend, + stream: Option, + sink: Option, +} + struct Connections<'a> { // Some: connections not used yet // None: connections is used in a link - map: HashMap<&'a str, (Option, Option)>, + map: HashMap<&'a str, Connection>, } struct Link<'a> { @@ -37,8 +43,10 @@ pub async fn run(config: &Config) -> Result<()> { let links = future::join_all(link_futs).await; let links = links.into_iter().chain(std::iter::once(terminated_link)); - let link_close_futs = future::try_join_all(links.map(|link| link.close())); - future::try_join(conns.close_all(), link_close_futs).await?; + for link in links { + conns.return_link(link); + } + conns.close_and_wait().await?; Ok(()) } @@ -50,30 +58,77 @@ impl<'a> Connections<'a> { } } - fn insert(&mut self, name: &'a str, stream: plug::PlugStream, sink: plug::PlugSink) { - self.map.insert(name, (Some(stream), Some(sink))); + fn insert( + &mut self, + name: &'a str, + backend: plug::Backend, + stream: plug::PlugStream, + sink: plug::PlugSink, + ) { + self.map.insert( + name, + Connection { + backend, + stream: Some(stream), + sink: Some(sink), + }, + ); + } + + fn return_link(&mut self, link: Link<'a>) { + let conn = self.map.get_mut(link.source_name).unwrap_or_else(|| { + panic!( + "tried to return a invalid link with source name {}", + link.source_name, + ) + }); + conn.stream = Some(link.source); + + let conn = self.map.get_mut(link.dest_name).unwrap_or_else(|| { + panic!( + "tried to return a invalid link with dest name {}", + link.dest_name, + ) + }); + conn.sink = Some(link.dest); } - // close all connections whose sink is not used in a link - async fn close_all(self) -> Result<()> { - let futs = self.map.into_iter().map(|(name, (_, sink))| async move { - if let Some(mut s) = sink { - debug!("Closing {name}"); - s.close().await?; - debug!("Closed {name}"); + // close all connections + // assume all links are returned + async fn close_and_wait(self) -> Result<()> { + let futs = self.map.into_iter().map(|(name, conn)| async move { + let fut = async { + if let Some(mut s) = conn.sink { + debug!("Closing {name}"); + s.close().await?; + debug!("Closed {name}"); + } + debug!("Waiting for plug {name} to exit"); + conn.backend.wait().await?; + debug!("Plug {name} exited"); + anyhow::Ok(()) + }; + let close_result = tokio::time::timeout(std::time::Duration::from_secs(10), fut).await; + + match close_result { + Ok(result) => result, + Err(_) => { + // abandon the connection + warn!("Plug {name} didn't exit in time"); + Ok(()) + } } - anyhow::Ok(()) }); future::try_join_all(futs).await?; Ok(()) } fn take_stream(&mut self, name: &str) -> Option { - self.map.get_mut(name)?.0.take() + self.map.get_mut(name)?.stream.take() } fn take_sink(&mut self, name: &str) -> Option { - self.map.get_mut(name)?.1.take() + self.map.get_mut(name)?.sink.take() } } @@ -87,16 +142,16 @@ async fn connect_to_plugs(config: &Config) -> Result { } }); - let (sink, stream) = match connect_result { + let (backend, sink, stream) = match connect_result { Ok(p) => p, Err(e) => { warn!("Error connecting to {name}: {e}"); - conns.close_all().await?; + conns.close_and_wait().await?; return Err(e); } }; debug!("Connected to {name}"); - conns.insert(name.as_str(), stream, sink); + conns.insert(name.as_str(), backend, stream, sink); } Ok(conns) } @@ -149,11 +204,4 @@ impl<'a> Link<'a> { } self } - - async fn close(mut self) -> Result<()> { - debug!("Closing {}", self.dest_name); - self.dest.close().await?; - debug!("Closed {}", self.dest_name); - Ok(()) - } } diff --git a/kble/src/plug.rs b/kble/src/plug.rs index 41ba096..d43c779 100644 --- a/kble/src/plug.rs +++ b/kble/src/plug.rs @@ -5,7 +5,7 @@ use futures::{future, stream, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use pin_project::pin_project; use tokio::{ io::{AsyncRead, AsyncWrite}, - process::{ChildStdin, ChildStdout}, + process::{Child, ChildStdin, ChildStdout}, }; use tokio_tungstenite::{ tungstenite::{protocol::Role, Message}, @@ -16,7 +16,26 @@ use url::Url; pub type PlugSink = Pin, Error = anyhow::Error> + Send + 'static>>; pub type PlugStream = Pin>> + Send + 'static>>; -pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> { +pub enum Backend { + WebSocketClient, + StdioProcess(Child), +} + +impl Backend { + pub async fn wait(self) -> Result<()> { + match self { + Backend::WebSocketClient => Ok(()), + Backend::StdioProcess(mut proc) => { + proc.wait() + .await + .with_context(|| format!("Failed to wait for {:?}", proc))?; + Ok(()) + } + } + } +} + +pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { match url.scheme() { "exec" => connect_exec(url).await, "ws" | "wss" => connect_ws(url).await, @@ -24,7 +43,7 @@ pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> { } } -async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> { +async fn connect_exec(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { assert_eq!(url.scheme(), "exec"); ensure!(url.username().is_empty()); ensure!(url.password().is_none()); @@ -32,18 +51,19 @@ async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> { ensure!(url.port().is_none()); ensure!(url.query().is_none()); ensure!(url.fragment().is_none()); - let proc = tokio::process::Command::new("sh") + let mut proc = tokio::process::Command::new("sh") .args(["-c", url.path()]) .stderr(Stdio::inherit()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn() .with_context(|| format!("Failed to spawn {}", url))?; - let stdin = proc.stdin.unwrap(); - let stdout = proc.stdout.unwrap(); + let stdin = proc.stdin.take().unwrap(); + let stdout = proc.stdout.take().unwrap(); let stdio = ChildStdio { stdin, stdout }; let wss = WebSocketStream::from_raw_socket(stdio, Role::Client, None).await; - Ok(wss_to_pair(wss)) + let (stream, sink) = wss_to_pair(wss); + Ok((Backend::StdioProcess(proc), stream, sink)) } #[pin_project] @@ -89,11 +109,12 @@ impl AsyncRead for ChildStdio { } } -async fn connect_ws(url: &Url) -> Result<(PlugSink, PlugStream)> { +async fn connect_ws(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { let (wss, _resp) = tokio_tungstenite::connect_async(url) .await .with_context(|| format!("Failed to connect to {}", url))?; - Ok(wss_to_pair(wss)) + let (stream, sink) = wss_to_pair(wss); + Ok((Backend::WebSocketClient, stream, sink)) } fn wss_to_pair(wss: WebSocketStream) -> (PlugSink, PlugStream) From 00a923ce29d1de3c17231b76ec245b4fae70c132 Mon Sep 17 00:00:00 2001 From: KOBAYASHI Kazuhiro Date: Thu, 25 Apr 2024 12:59:36 +0900 Subject: [PATCH 2/3] parametrize maximum wait time for child process --- kble/src/app.rs | 18 ++++++++++++------ kble/src/main.rs | 6 +++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/kble/src/app.rs b/kble/src/app.rs index 01037c0..7c165c7 100644 --- a/kble/src/app.rs +++ b/kble/src/app.rs @@ -18,6 +18,7 @@ struct Connections<'a> { // Some: connections not used yet // None: connections is used in a link map: HashMap<&'a str, Connection>, + max_child_wait_secs: u64, } struct Link<'a> { @@ -27,8 +28,8 @@ struct Link<'a> { dest: plug::PlugSink, } -pub async fn run(config: &Config) -> Result<()> { - let mut conns = connect_to_plugs(config).await?; +pub async fn run(config: &Config, max_child_wait_secs: u64) -> Result<()> { + let mut conns = connect_to_plugs(config, max_child_wait_secs).await?; let links = connect_links(&mut conns, config); let (quit_tx, _) = broadcast::channel(1); @@ -52,9 +53,10 @@ pub async fn run(config: &Config) -> Result<()> { } impl<'a> Connections<'a> { - fn new() -> Self { + fn new(max_child_wait_secs: u64) -> Self { Self { map: HashMap::new(), + max_child_wait_secs, } } @@ -108,7 +110,11 @@ impl<'a> Connections<'a> { debug!("Plug {name} exited"); anyhow::Ok(()) }; - let close_result = tokio::time::timeout(std::time::Duration::from_secs(10), fut).await; + let close_result = tokio::time::timeout( + std::time::Duration::from_secs(self.max_child_wait_secs), + fut, + ) + .await; match close_result { Ok(result) => result, @@ -132,8 +138,8 @@ impl<'a> Connections<'a> { } } -async fn connect_to_plugs(config: &Config) -> Result { - let mut conns = Connections::new(); +async fn connect_to_plugs(config: &Config, max_child_wait_secs: u64) -> Result { + let mut conns = Connections::new(max_child_wait_secs); for (name, url) in config.plugs().iter() { debug!("Connecting to {name}"); let connect_result = plug::connect(url).await.with_context(move || { diff --git a/kble/src/main.rs b/kble/src/main.rs index 09862eb..ae82365 100644 --- a/kble/src/main.rs +++ b/kble/src/main.rs @@ -16,6 +16,10 @@ use spaghetti::{Config, Raw}; struct Args { #[clap(long, short)] spaghetti: PathBuf, + + /// Maximum time to wait for a child process to exit after a closing handshake + #[clap(long, default_value_t = 10)] + max_child_wait_secs: u64, } impl Args { @@ -45,6 +49,6 @@ async fn main() -> Result<()> { let args = Args::parse_with_license_notice(include_notice!()); let config = args.load_spaghetti_config()?; - app::run(&config).await?; + app::run(&config, args.max_child_wait_secs).await?; Ok(()) } From f884ac5909fc52d1ae5d3f6a119a903c77095137 Mon Sep 17 00:00:00 2001 From: KOBAYASHI Kazuhiro Date: Tue, 30 Apr 2024 10:50:54 +0900 Subject: [PATCH 3/3] kble: kill after grace period --- kble/src/app.rs | 22 +++++++++++++--------- kble/src/main.rs | 7 ++++--- kble/src/plug.rs | 11 +++++++++-- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/kble/src/app.rs b/kble/src/app.rs index 7c165c7..b888b81 100644 --- a/kble/src/app.rs +++ b/kble/src/app.rs @@ -18,7 +18,7 @@ struct Connections<'a> { // Some: connections not used yet // None: connections is used in a link map: HashMap<&'a str, Connection>, - max_child_wait_secs: u64, + termination_grace_period_secs: u64, } struct Link<'a> { @@ -28,8 +28,8 @@ struct Link<'a> { dest: plug::PlugSink, } -pub async fn run(config: &Config, max_child_wait_secs: u64) -> Result<()> { - let mut conns = connect_to_plugs(config, max_child_wait_secs).await?; +pub async fn run(config: &Config, termination_grace_period_secs: u64) -> Result<()> { + let mut conns = connect_to_plugs(config, termination_grace_period_secs).await?; let links = connect_links(&mut conns, config); let (quit_tx, _) = broadcast::channel(1); @@ -53,10 +53,10 @@ pub async fn run(config: &Config, max_child_wait_secs: u64) -> Result<()> { } impl<'a> Connections<'a> { - fn new(max_child_wait_secs: u64) -> Self { + fn new(termination_grace_period_secs: u64) -> Self { Self { map: HashMap::new(), - max_child_wait_secs, + termination_grace_period_secs, } } @@ -98,7 +98,7 @@ impl<'a> Connections<'a> { // close all connections // assume all links are returned async fn close_and_wait(self) -> Result<()> { - let futs = self.map.into_iter().map(|(name, conn)| async move { + let futs = self.map.into_iter().map(|(name, mut conn)| async move { let fut = async { if let Some(mut s) = conn.sink { debug!("Closing {name}"); @@ -111,7 +111,7 @@ impl<'a> Connections<'a> { anyhow::Ok(()) }; let close_result = tokio::time::timeout( - std::time::Duration::from_secs(self.max_child_wait_secs), + std::time::Duration::from_secs(self.termination_grace_period_secs), fut, ) .await; @@ -121,6 +121,7 @@ impl<'a> Connections<'a> { Err(_) => { // abandon the connection warn!("Plug {name} didn't exit in time"); + conn.backend.kill().await?; Ok(()) } } @@ -138,8 +139,11 @@ impl<'a> Connections<'a> { } } -async fn connect_to_plugs(config: &Config, max_child_wait_secs: u64) -> Result { - let mut conns = Connections::new(max_child_wait_secs); +async fn connect_to_plugs( + config: &Config, + termination_grace_period_secs: u64, +) -> Result { + let mut conns = Connections::new(termination_grace_period_secs); for (name, url) in config.plugs().iter() { debug!("Connecting to {name}"); let connect_result = plug::connect(url).await.with_context(move || { diff --git a/kble/src/main.rs b/kble/src/main.rs index ae82365..4d74f82 100644 --- a/kble/src/main.rs +++ b/kble/src/main.rs @@ -17,9 +17,10 @@ struct Args { #[clap(long, short)] spaghetti: PathBuf, - /// Maximum time to wait for a child process to exit after a closing handshake + /// Period to wait for each child process to exit after a closing handshake + /// before killing it #[clap(long, default_value_t = 10)] - max_child_wait_secs: u64, + termination_grace_period_secs: u64, } impl Args { @@ -49,6 +50,6 @@ async fn main() -> Result<()> { let args = Args::parse_with_license_notice(include_notice!()); let config = args.load_spaghetti_config()?; - app::run(&config, args.max_child_wait_secs).await?; + app::run(&config, args.termination_grace_period_secs).await?; Ok(()) } diff --git a/kble/src/plug.rs b/kble/src/plug.rs index d43c779..c3a547f 100644 --- a/kble/src/plug.rs +++ b/kble/src/plug.rs @@ -22,10 +22,10 @@ pub enum Backend { } impl Backend { - pub async fn wait(self) -> Result<()> { + pub async fn wait(&mut self) -> Result<()> { match self { Backend::WebSocketClient => Ok(()), - Backend::StdioProcess(mut proc) => { + Backend::StdioProcess(proc) => { proc.wait() .await .with_context(|| format!("Failed to wait for {:?}", proc))?; @@ -33,6 +33,13 @@ impl Backend { } } } + + pub async fn kill(self) -> Result<()> { + match self { + Backend::WebSocketClient => Ok(()), + Backend::StdioProcess(mut proc) => proc.kill().await.map_err(Into::into), + } + } } pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {