Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

簡単なリファクタ #74

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions kble/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ clap.workspace = true
serde.workspace = true
serde_yaml = "0.9"
serde_with = "3.7"
tracing-subscriber.workspace = true
tracing.workspace = true
notalawyer.workspace = true
notalawyer-clap.workspace = true
28 changes: 28 additions & 0 deletions kble/src/app.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::{plug, spaghetti::Config};
use anyhow::Result;
use futures::future;
use futures::StreamExt;
use std::collections::HashMap;

pub async fn run(config: &Config) -> Result<()> {
let mut sinks = HashMap::new();
let mut streams = HashMap::new();
for (name, url) in config.plugs().iter() {
let (sink, stream) = plug::connect(url).await?;
sinks.insert(name.as_str(), sink);
streams.insert(name.as_str(), stream);
}
let mut edges = vec![];
for (stream_name, sink_name) in config.links().iter() {
let Some(stream) = streams.remove(stream_name.as_str()) else {
unreachable!("No such plug: {stream_name}");
};
let Some(sink) = sinks.remove(sink_name.as_str()) else {
unreachable!("No such plug or already used: {sink_name}");
};
let edge = stream.forward(sink);
edges.push(edge);
}
future::try_join_all(edges).await?;
Ok(())
}
21 changes: 18 additions & 3 deletions kble/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ use std::path::PathBuf;
use anyhow::{Context, Result};
use clap::Parser;
use notalawyer_clap::*;
use tracing_subscriber::{prelude::*, EnvFilter};

mod app;
mod plug;
mod spaghetti;

use spaghetti::{Config, Raw};

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
Expand All @@ -21,15 +25,26 @@ impl Args {
.open(&self.spaghetti)
.with_context(|| format!("Failed to open {:?}", &self.spaghetti))?;
let spagetthi_rdr = std::io::BufReader::new(spaghetti_file);
serde_yaml::from_reader(spagetthi_rdr)
.with_context(|| format!("Unable to parse {:?}", self.spaghetti))
let raw: Config<Raw> = serde_yaml::from_reader(spagetthi_rdr)
.with_context(|| format!("Unable to parse {:?}", self.spaghetti))?;
raw.validate()
.with_context(|| format!("Invalid configuration in {:?}", self.spaghetti))
}
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_writer(std::io::stderr),
)
.with(EnvFilter::from_default_env())
.init();

let args = Args::parse_with_license_notice(include_notice!());
let config = args.load_spaghetti_config()?;
config.run().await?;
app::run(&config).await?;
Ok(())
}
110 changes: 85 additions & 25 deletions kble/src/spaghetti.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,74 @@
use anyhow::{anyhow, Result};
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use futures::{future, StreamExt};
use serde::{Deserialize, Serialize};
use url::Url;

use crate::plug;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Config {
pub struct Inner {
plugs: HashMap<String, Url>,
links: HashMap<String, String>,
}

impl Config {
pub async fn run(&self) -> Result<()> {
let mut sinks = HashMap::new();
let mut streams = HashMap::new();
for (name, url) in self.plugs.iter() {
let (sink, stream) = plug::connect(url).await?;
sinks.insert(name.as_str(), sink);
streams.insert(name.as_str(), stream);
#[derive(PartialEq, Debug)]
pub enum Raw {}
pub enum Validated {}

#[derive(Serialize, Debug, Clone, PartialEq, Eq)]
pub struct Config<State = Validated> {
#[serde(flatten)]
inner: Inner,
state: std::marker::PhantomData<State>,
}

impl<'de> serde::Deserialize<'de> for Config<Raw> {
fn deserialize<D>(deserializer: D) -> Result<Config<Raw>, D::Error>
where
D: serde::Deserializer<'de>,
{
let inner = Inner::deserialize(deserializer)?;
Ok(Config::new(inner))
}
}

impl<State> Config<State> {
fn new(inner: Inner) -> Self {
Config {
inner,
state: std::marker::PhantomData,
}
let mut edges = vec![];
for (stream_name, sink_name) in self.links.iter() {
let Some(stream) = streams.remove(stream_name.as_str()) else {
}
}

impl Config<Raw> {
pub fn validate(self) -> Result<Config<Validated>> {
use std::collections::HashSet;
let mut seen_sinks = HashSet::new();

for (stream_name, sink_name) in self.inner.links.iter() {
if !self.inner.plugs.contains_key(stream_name) {
return Err(anyhow!("No such plug: {stream_name}"));
};
let Some(sink) = sinks.remove(sink_name.as_str()) else {
return Err(anyhow!("No such plug or already used: {sink_name}"));
};
let edge = stream.forward(sink);
edges.push(edge);
}
if !self.inner.plugs.contains_key(sink_name) {
return Err(anyhow!("No such plug: {sink_name}"));
}

if seen_sinks.contains(sink_name) {
return Err(anyhow!("Sink {sink_name} used more than once"));
}
seen_sinks.insert(sink_name);
}
future::try_join_all(edges).await?;
Ok(())
Ok(Config::new(self.inner))
}
}

impl Config<Validated> {
pub fn plugs(&self) -> &HashMap<String, Url> {
&self.inner.plugs
}

pub fn links(&self) -> &HashMap<String, String> {
&self.inner.links
}
}

Expand All @@ -47,7 +81,7 @@ mod tests {
#[test]
fn test_de() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: seriald\n";
let expected = Config {
let inner = Inner {
plugs: HashMap::from_iter([
("tfsync".to_string(), Url::parse("exec:tfsync foo").unwrap()),
(
Expand All @@ -57,7 +91,33 @@ mod tests {
]),
links: HashMap::from_iter([("tfsync".to_string(), "seriald".to_string())]),
};
let expected = Config {
inner,
state: std::marker::PhantomData,
};
let actual = serde_yaml::from_str(yaml).unwrap();
assert_eq!(expected, actual);
actual.validate().unwrap();
}

#[test]
fn test_de_invalid_dest() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: serialdxxxx\n";
let actual: Config<Raw> = serde_yaml::from_str(yaml).unwrap();
assert!(actual.validate().is_err());
}

#[test]
fn test_de_invalid_source() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsyncxxxx: seriald\n";
let actual: Config<Raw> = serde_yaml::from_str(yaml).unwrap();
assert!(actual.validate().is_err());
}

#[test]
fn test_de_duplicate_sink() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: seriald\n seriald: seriald\n";
let actual: Config<Raw> = serde_yaml::from_str(yaml).unwrap();
assert!(actual.validate().is_err());
}
}