Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket接続のclosing handshakeを行う #78

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 152 additions & 21 deletions kble/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,159 @@
use crate::{plug, spaghetti::Config};
use anyhow::Result;
use futures::future;
use futures::StreamExt;
use crate::{
plug,
spaghetti::{Config, Validated},
};
use anyhow::{Context, Result};
use futures::{future, SinkExt, StreamExt};
use std::collections::HashMap;
use tokio::sync::broadcast;
use tracing::{debug, warn};

struct Connections<'a> {
// Some: connections not used yet
// None: connections is used in a link
map: HashMap<&'a str, (Option<plug::PlugStream>, Option<plug::PlugSink>)>,
}

struct Link<'a> {
source_name: &'a str,
dest_name: &'a str,
source: plug::PlugStream,
dest: plug::PlugSink,
}

pub async fn run(config: &Config) -> Result<()> {
let mut sinks = HashMap::new();
let mut streams = HashMap::new();
let mut conns = connect_to_plugs(config).await?;
let links = connect_links(&mut conns, config);

let (quit_tx, _) = broadcast::channel(1);
let link_futs = links.map(|link| {
let quit_rx = quit_tx.subscribe();
let fut = link.forward(quit_rx);
Box::pin(fut)
});

let (terminated_link, _, link_futs) = futures::future::select_all(link_futs).await;
quit_tx.send(())?;
let links = future::join_all(link_futs).await;
let links = links.into_iter().chain(std::iter::once(terminated_link));

let link_close_futs = future::try_join_all(links.map(|link| link.close()));
future::try_join(conns.close_all(), link_close_futs).await?;

Ok(())
}

impl<'a> Connections<'a> {
fn new() -> Self {
Self {
map: HashMap::new(),
}
}

fn insert(&mut self, name: &'a str, stream: plug::PlugStream, sink: plug::PlugSink) {
self.map.insert(name, (Some(stream), Some(sink)));
}

// close all connections whose sink is not used in a link
async fn close_all(self) -> Result<()> {
let futs = self.map.into_iter().map(|(name, (_, sink))| async move {
if let Some(mut s) = sink {
debug!("Closing {name}");
s.close().await?;
debug!("Closed {name}");
}
anyhow::Ok(())
});
future::try_join_all(futs).await?;
Ok(())
}

fn take_stream(&mut self, name: &str) -> Option<plug::PlugStream> {
self.map.get_mut(name)?.0.take()
}

fn take_sink(&mut self, name: &str) -> Option<plug::PlugSink> {
self.map.get_mut(name)?.1.take()
}
}

async fn connect_to_plugs(config: &Config) -> Result<Connections> {
let mut conns = Connections::new();
for (name, url) in config.plugs().iter() {
let (sink, stream) = plug::connect(url).await?;
sinks.insert(name.as_str(), sink);
streams.insert(name.as_str(), stream);
}
let mut edges = vec![];
for (stream_name, sink_name) in config.links().iter() {
let Some(stream) = streams.remove(stream_name.as_str()) else {
unreachable!("No such plug: {stream_name}");
};
let Some(sink) = sinks.remove(sink_name.as_str()) else {
unreachable!("No such plug or already used: {sink_name}");
debug!("Connecting to {name}");
let connect_result = plug::connect(url).await.with_context(move || {
format! {
"Failed to connect to plug `{name}`"
}
});

let (sink, stream) = match connect_result {
Ok(p) => p,
Err(e) => {
warn!("Error connecting to {name}: {e}");
conns.close_all().await?;
return Err(e);
}
};
let edge = stream.forward(sink);
edges.push(edge);
debug!("Connected to {name}");
conns.insert(name.as_str(), stream, sink);
}
Ok(conns)
}

fn connect_links<'a, 'conns>(
conns: &'conns mut Connections<'a>,
config: &'a Config<Validated>, // emphasize that the config is validated
) -> impl Iterator<Item = Link<'a>> + 'conns {
config.links().iter().map(|(source_name, dest_name)| {
// Those panics shouldn't happen if config is valid and conns is properly initialized
let source = conns.take_stream(source_name).unwrap_or_else(|| {
panic!("stream not found: {source_name}");
});
let dest = conns.take_sink(dest_name).unwrap_or_else(|| {
panic!("sink not found: {dest_name}");
});

Link {
source_name,
dest_name,
source,
dest,
}
})
}

impl<'a> Link<'a> {
async fn forward(mut self, mut quit_rx: broadcast::Receiver<()>) -> Self {
loop {
let recv_result = tokio::select! {
_ = quit_rx.recv() => break,
recv_result = self.source.next() => match recv_result {
Some(data) => data,
None => break,
},
};

let data = match recv_result {
Err(e) => {
warn!("Error reading from {}: {}", self.source_name, e);
break;
}
Ok(data) => data,
};

if let Err(e) = self.dest.send(data).await {
warn!("Error writing to {}: {}", self.dest_name, e);
break;
}
}
self
}

async fn close(mut self) -> Result<()> {
debug!("Closing {}", self.dest_name);
self.dest.close().await?;
debug!("Closed {}", self.dest_name);
Ok(())
}
future::try_join_all(edges).await?;
Ok(())
}