Skip to content

Commit

Permalink
Merge pull request #82 from arkedge/wait_child_process
Browse files Browse the repository at this point in the history
kble: wait for child process
  • Loading branch information
KOBA789 authored May 9, 2024
2 parents 86eb08a + f884ac5 commit 60c349b
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 40 deletions.
118 changes: 88 additions & 30 deletions kble/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ use std::collections::HashMap;
use tokio::sync::broadcast;
use tracing::{debug, warn};

struct Connection {
backend: plug::Backend,
stream: Option<plug::PlugStream>,
sink: Option<plug::PlugSink>,
}

struct Connections<'a> {
// Some: connections not used yet
// None: connections is used in a link
map: HashMap<&'a str, (Option<plug::PlugStream>, Option<plug::PlugSink>)>,
map: HashMap<&'a str, Connection>,
termination_grace_period_secs: u64,
}

struct Link<'a> {
Expand All @@ -21,8 +28,8 @@ struct Link<'a> {
dest: plug::PlugSink,
}

pub async fn run(config: &Config) -> Result<()> {
let mut conns = connect_to_plugs(config).await?;
pub async fn run(config: &Config, termination_grace_period_secs: u64) -> Result<()> {
let mut conns = connect_to_plugs(config, termination_grace_period_secs).await?;
let links = connect_links(&mut conns, config);

let (quit_tx, _) = broadcast::channel(1);
Expand All @@ -37,48 +44,106 @@ pub async fn run(config: &Config) -> Result<()> {
let links = future::join_all(link_futs).await;
let links = links.into_iter().chain(std::iter::once(terminated_link));

let link_close_futs = future::try_join_all(links.map(|link| link.close()));
future::try_join(conns.close_all(), link_close_futs).await?;
for link in links {
conns.return_link(link);
}
conns.close_and_wait().await?;

Ok(())
}

impl<'a> Connections<'a> {
fn new() -> Self {
fn new(termination_grace_period_secs: u64) -> Self {
Self {
map: HashMap::new(),
termination_grace_period_secs,
}
}

fn insert(&mut self, name: &'a str, stream: plug::PlugStream, sink: plug::PlugSink) {
self.map.insert(name, (Some(stream), Some(sink)));
fn insert(
&mut self,
name: &'a str,
backend: plug::Backend,
stream: plug::PlugStream,
sink: plug::PlugSink,
) {
self.map.insert(
name,
Connection {
backend,
stream: Some(stream),
sink: Some(sink),
},
);
}

fn return_link(&mut self, link: Link<'a>) {
let conn = self.map.get_mut(link.source_name).unwrap_or_else(|| {
panic!(
"tried to return a invalid link with source name {}",
link.source_name,
)
});
conn.stream = Some(link.source);

let conn = self.map.get_mut(link.dest_name).unwrap_or_else(|| {
panic!(
"tried to return a invalid link with dest name {}",
link.dest_name,
)
});
conn.sink = Some(link.dest);
}

// close all connections whose sink is not used in a link
async fn close_all(self) -> Result<()> {
let futs = self.map.into_iter().map(|(name, (_, sink))| async move {
if let Some(mut s) = sink {
debug!("Closing {name}");
s.close().await?;
debug!("Closed {name}");
// close all connections
// assume all links are returned
async fn close_and_wait(self) -> Result<()> {
let futs = self.map.into_iter().map(|(name, mut conn)| async move {
let fut = async {
if let Some(mut s) = conn.sink {
debug!("Closing {name}");
s.close().await?;
debug!("Closed {name}");
}
debug!("Waiting for plug {name} to exit");
conn.backend.wait().await?;
debug!("Plug {name} exited");
anyhow::Ok(())
};
let close_result = tokio::time::timeout(
std::time::Duration::from_secs(self.termination_grace_period_secs),
fut,
)
.await;

match close_result {
Ok(result) => result,
Err(_) => {
// abandon the connection
warn!("Plug {name} didn't exit in time");
conn.backend.kill().await?;
Ok(())
}
}
anyhow::Ok(())
});
future::try_join_all(futs).await?;
Ok(())
}

fn take_stream(&mut self, name: &str) -> Option<plug::PlugStream> {
self.map.get_mut(name)?.0.take()
self.map.get_mut(name)?.stream.take()
}

fn take_sink(&mut self, name: &str) -> Option<plug::PlugSink> {
self.map.get_mut(name)?.1.take()
self.map.get_mut(name)?.sink.take()
}
}

async fn connect_to_plugs(config: &Config) -> Result<Connections> {
let mut conns = Connections::new();
async fn connect_to_plugs(
config: &Config,
termination_grace_period_secs: u64,
) -> Result<Connections> {
let mut conns = Connections::new(termination_grace_period_secs);
for (name, url) in config.plugs().iter() {
debug!("Connecting to {name}");
let connect_result = plug::connect(url).await.with_context(move || {
Expand All @@ -87,16 +152,16 @@ async fn connect_to_plugs(config: &Config) -> Result<Connections> {
}
});

let (sink, stream) = match connect_result {
let (backend, sink, stream) = match connect_result {
Ok(p) => p,
Err(e) => {
warn!("Error connecting to {name}: {e}");
conns.close_all().await?;
conns.close_and_wait().await?;
return Err(e);
}
};
debug!("Connected to {name}");
conns.insert(name.as_str(), stream, sink);
conns.insert(name.as_str(), backend, stream, sink);
}
Ok(conns)
}
Expand Down Expand Up @@ -149,11 +214,4 @@ impl<'a> Link<'a> {
}
self
}

async fn close(mut self) -> Result<()> {
debug!("Closing {}", self.dest_name);
self.dest.close().await?;
debug!("Closed {}", self.dest_name);
Ok(())
}
}
7 changes: 6 additions & 1 deletion kble/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ use spaghetti::{Config, Raw};
struct Args {
#[clap(long, short)]
spaghetti: PathBuf,

/// Period to wait for each child process to exit after a closing handshake
/// before killing it
#[clap(long, default_value_t = 10)]
termination_grace_period_secs: u64,
}

impl Args {
Expand Down Expand Up @@ -45,6 +50,6 @@ async fn main() -> Result<()> {

let args = Args::parse_with_license_notice(include_notice!());
let config = args.load_spaghetti_config()?;
app::run(&config).await?;
app::run(&config, args.termination_grace_period_secs).await?;
Ok(())
}
46 changes: 37 additions & 9 deletions kble/src/plug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures::{future, stream, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
process::{ChildStdin, ChildStdout},
process::{Child, ChildStdin, ChildStdout},
};
use tokio_tungstenite::{
tungstenite::{protocol::Role, Message},
Expand All @@ -16,34 +16,61 @@ use url::Url;
pub type PlugSink = Pin<Box<dyn Sink<Vec<u8>, Error = anyhow::Error> + Send + 'static>>;
pub type PlugStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + Send + 'static>>;

pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> {
pub enum Backend {
WebSocketClient,
StdioProcess(Child),
}

impl Backend {
pub async fn wait(&mut self) -> Result<()> {
match self {
Backend::WebSocketClient => Ok(()),
Backend::StdioProcess(proc) => {
proc.wait()
.await
.with_context(|| format!("Failed to wait for {:?}", proc))?;
Ok(())
}
}
}

pub async fn kill(self) -> Result<()> {
match self {
Backend::WebSocketClient => Ok(()),
Backend::StdioProcess(mut proc) => proc.kill().await.map_err(Into::into),
}
}
}

pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {
match url.scheme() {
"exec" => connect_exec(url).await,
"ws" | "wss" => connect_ws(url).await,
_ => Err(anyhow!("Unsupported scheme: {}", url.scheme())),
}
}

async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> {
async fn connect_exec(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {
assert_eq!(url.scheme(), "exec");
ensure!(url.username().is_empty());
ensure!(url.password().is_none());
ensure!(url.host().is_none());
ensure!(url.port().is_none());
ensure!(url.query().is_none());
ensure!(url.fragment().is_none());
let proc = tokio::process::Command::new("sh")
let mut proc = tokio::process::Command::new("sh")
.args(["-c", url.path()])
.stderr(Stdio::inherit())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.with_context(|| format!("Failed to spawn {}", url))?;
let stdin = proc.stdin.unwrap();
let stdout = proc.stdout.unwrap();
let stdin = proc.stdin.take().unwrap();
let stdout = proc.stdout.take().unwrap();
let stdio = ChildStdio { stdin, stdout };
let wss = WebSocketStream::from_raw_socket(stdio, Role::Client, None).await;
Ok(wss_to_pair(wss))
let (stream, sink) = wss_to_pair(wss);
Ok((Backend::StdioProcess(proc), stream, sink))
}

#[pin_project]
Expand Down Expand Up @@ -89,11 +116,12 @@ impl AsyncRead for ChildStdio {
}
}

async fn connect_ws(url: &Url) -> Result<(PlugSink, PlugStream)> {
async fn connect_ws(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {
let (wss, _resp) = tokio_tungstenite::connect_async(url)
.await
.with_context(|| format!("Failed to connect to {}", url))?;
Ok(wss_to_pair(wss))
let (stream, sink) = wss_to_pair(wss);
Ok((Backend::WebSocketClient, stream, sink))
}

fn wss_to_pair<S>(wss: WebSocketStream<S>) -> (PlugSink, PlugStream)
Expand Down

0 comments on commit 60c349b

Please sign in to comment.