diff --git a/Cargo.lock b/Cargo.lock index 216d91c..d9024e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -770,6 +770,8 @@ dependencies = [ "serde_yaml", "tokio", "tokio-tungstenite", + "tracing", + "tracing-subscriber", "url", ] diff --git a/kble/Cargo.toml b/kble/Cargo.toml index 2912729..e1d2c86 100644 --- a/kble/Cargo.toml +++ b/kble/Cargo.toml @@ -24,5 +24,7 @@ clap.workspace = true serde.workspace = true serde_yaml = "0.9" serde_with = "3.7" +tracing-subscriber.workspace = true +tracing.workspace = true notalawyer.workspace = true notalawyer-clap.workspace = true diff --git a/kble/src/app.rs b/kble/src/app.rs new file mode 100644 index 0000000..0fa07e0 --- /dev/null +++ b/kble/src/app.rs @@ -0,0 +1,28 @@ +use crate::{plug, spaghetti::Config}; +use anyhow::Result; +use futures::future; +use futures::StreamExt; +use std::collections::HashMap; + +pub async fn run(config: &Config) -> Result<()> { + let mut sinks = HashMap::new(); + let mut streams = HashMap::new(); + for (name, url) in config.plugs().iter() { + let (sink, stream) = plug::connect(url).await?; + sinks.insert(name.as_str(), sink); + streams.insert(name.as_str(), stream); + } + let mut edges = vec![]; + for (stream_name, sink_name) in config.links().iter() { + let Some(stream) = streams.remove(stream_name.as_str()) else { + unreachable!("No such plug: {stream_name}"); + }; + let Some(sink) = sinks.remove(sink_name.as_str()) else { + unreachable!("No such plug or already used: {sink_name}"); + }; + let edge = stream.forward(sink); + edges.push(edge); + } + future::try_join_all(edges).await?; + Ok(()) +} diff --git a/kble/src/main.rs b/kble/src/main.rs index 1fda032..09862eb 100644 --- a/kble/src/main.rs +++ b/kble/src/main.rs @@ -3,10 +3,14 @@ use std::path::PathBuf; use anyhow::{Context, Result}; use clap::Parser; use notalawyer_clap::*; +use tracing_subscriber::{prelude::*, EnvFilter}; +mod app; mod plug; mod spaghetti; +use spaghetti::{Config, Raw}; + #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { @@ -21,15 +25,26 @@ impl Args { .open(&self.spaghetti) .with_context(|| format!("Failed to open {:?}", &self.spaghetti))?; let spagetthi_rdr = std::io::BufReader::new(spaghetti_file); - serde_yaml::from_reader(spagetthi_rdr) - .with_context(|| format!("Unable to parse {:?}", self.spaghetti)) + let raw: Config = serde_yaml::from_reader(spagetthi_rdr) + .with_context(|| format!("Unable to parse {:?}", self.spaghetti))?; + raw.validate() + .with_context(|| format!("Invalid configuration in {:?}", self.spaghetti)) } } #[tokio::main] async fn main() -> Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_writer(std::io::stderr), + ) + .with(EnvFilter::from_default_env()) + .init(); + let args = Args::parse_with_license_notice(include_notice!()); let config = args.load_spaghetti_config()?; - config.run().await?; + app::run(&config).await?; Ok(()) } diff --git a/kble/src/spaghetti.rs b/kble/src/spaghetti.rs index bddabf5..15c8fa3 100644 --- a/kble/src/spaghetti.rs +++ b/kble/src/spaghetti.rs @@ -1,40 +1,74 @@ +use anyhow::{anyhow, Result}; use std::collections::HashMap; -use anyhow::{anyhow, Result}; -use futures::{future, StreamExt}; use serde::{Deserialize, Serialize}; use url::Url; -use crate::plug; - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct Config { +pub struct Inner { plugs: HashMap, links: HashMap, } -impl Config { - pub async fn run(&self) -> Result<()> { - let mut sinks = HashMap::new(); - let mut streams = HashMap::new(); - for (name, url) in self.plugs.iter() { - let (sink, stream) = plug::connect(url).await?; - sinks.insert(name.as_str(), sink); - streams.insert(name.as_str(), stream); +#[derive(PartialEq, Debug)] +pub enum Raw {} +pub enum Validated {} + +#[derive(Serialize, Debug, Clone, PartialEq, Eq)] +pub struct Config { + #[serde(flatten)] + inner: Inner, + state: std::marker::PhantomData, +} + +impl<'de> serde::Deserialize<'de> for Config { + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + let inner = Inner::deserialize(deserializer)?; + Ok(Config::new(inner)) + } +} + +impl Config { + fn new(inner: Inner) -> Self { + Config { + inner, + state: std::marker::PhantomData, } - let mut edges = vec![]; - for (stream_name, sink_name) in self.links.iter() { - let Some(stream) = streams.remove(stream_name.as_str()) else { + } +} + +impl Config { + pub fn validate(self) -> Result> { + use std::collections::HashSet; + let mut seen_sinks = HashSet::new(); + + for (stream_name, sink_name) in self.inner.links.iter() { + if !self.inner.plugs.contains_key(stream_name) { return Err(anyhow!("No such plug: {stream_name}")); - }; - let Some(sink) = sinks.remove(sink_name.as_str()) else { - return Err(anyhow!("No such plug or already used: {sink_name}")); - }; - let edge = stream.forward(sink); - edges.push(edge); + } + if !self.inner.plugs.contains_key(sink_name) { + return Err(anyhow!("No such plug: {sink_name}")); + } + + if seen_sinks.contains(sink_name) { + return Err(anyhow!("Sink {sink_name} used more than once")); + } + seen_sinks.insert(sink_name); } - future::try_join_all(edges).await?; - Ok(()) + Ok(Config::new(self.inner)) + } +} + +impl Config { + pub fn plugs(&self) -> &HashMap { + &self.inner.plugs + } + + pub fn links(&self) -> &HashMap { + &self.inner.links } } @@ -47,7 +81,7 @@ mod tests { #[test] fn test_de() { let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: seriald\n"; - let expected = Config { + let inner = Inner { plugs: HashMap::from_iter([ ("tfsync".to_string(), Url::parse("exec:tfsync foo").unwrap()), ( @@ -57,7 +91,33 @@ mod tests { ]), links: HashMap::from_iter([("tfsync".to_string(), "seriald".to_string())]), }; + let expected = Config { + inner, + state: std::marker::PhantomData, + }; let actual = serde_yaml::from_str(yaml).unwrap(); assert_eq!(expected, actual); + actual.validate().unwrap(); + } + + #[test] + fn test_de_invalid_dest() { + let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: serialdxxxx\n"; + let actual: Config = serde_yaml::from_str(yaml).unwrap(); + assert!(actual.validate().is_err()); + } + + #[test] + fn test_de_invalid_source() { + let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsyncxxxx: seriald\n"; + let actual: Config = serde_yaml::from_str(yaml).unwrap(); + assert!(actual.validate().is_err()); + } + + #[test] + fn test_de_duplicate_sink() { + let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: seriald\n seriald: seriald\n"; + let actual: Config = serde_yaml::from_str(yaml).unwrap(); + assert!(actual.validate().is_err()); } }