Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract common abstractions #244

Merged
merged 18 commits into from
Nov 21, 2024
19 changes: 19 additions & 0 deletions .github/workflows/common.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Common
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
env:
CARGO_TERM_COLOR: always
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build
run: |
cargo build -p atrium-common --verbose
- name: Run tests
run: |
cargo test -p atrium-common --lib
1 change: 1 addition & 0 deletions .github/workflows/wasm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ jobs:
- run: wasm-pack test --node atrium-xrpc
- run: wasm-pack test --node atrium-xrpc-client
- run: wasm-pack test --node atrium-oauth/identity
- run: wasm-pack test --node atrium-common
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = [
"atrium-api",
"atrium-common",
"atrium-crypto",
"atrium-xrpc",
"atrium-xrpc-client",
Expand All @@ -26,6 +27,7 @@ keywords = ["atproto", "bluesky"]
[workspace.dependencies]
# Intra-workspace dependencies
atrium-api = { version = "0.24.8", path = "atrium-api", default-features = false }
atrium-common = { version = "0.1.0", path = "atrium-common" }
atrium-identity = { version = "0.1.0", path = "atrium-oauth/identity" }
atrium-xrpc = { version = "0.12.0", path = "atrium-xrpc" }
atrium-xrpc-client = { version = "0.5.10", path = "atrium-xrpc-client" }
Expand Down
36 changes: 36 additions & 0 deletions atrium-common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[package]
name = "atrium-common"
version = "0.1.0"
authors = ["sugyan <[email protected]>", "avdb13 <[email protected]>"]
edition.workspace = true
rust-version.workspace = true
description = "Utility library for common abstractions in atproto"
documentation = "https://docs.rs/atrium-common"
readme = "README.md"
repository.workspace = true
license.workspace = true
keywords = ["atproto", "bluesky"]

[dependencies]
dashmap.workspace = true
thiserror.workspace = true
tokio = { workspace = true, default-features = false, features = ["sync"] }
trait-variant.workspace = true

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
moka = { workspace = true, features = ["future"] }

[target.'cfg(target_arch = "wasm32")'.dependencies]
lru.workspace = true
web-time.workspace = true

[dev-dependencies]
futures.workspace = true

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
gloo-timers.workspace = true
tokio = { workspace = true, features = ["time"] }
wasm-bindgen-test.workspace = true
3 changes: 3 additions & 0 deletions atrium-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod resolver;
pub mod store;
pub mod types;
222 changes: 222 additions & 0 deletions atrium-common/src/resolver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
mod cached;
mod throttled;

pub use self::cached::CachedResolver;
pub use self::throttled::ThrottledResolver;
use std::future::Future;

#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait Resolver {
type Input: ?Sized;
type Output;
type Error;

fn resolve(
&self,
input: &Self::Input,
) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>;
}

#[cfg(test)]
mod tests {
use super::*;
use crate::types::cached::r#impl::{Cache, CacheImpl};
use crate::types::cached::{CacheConfig, Cacheable};
use crate::types::throttled::Throttleable;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;

#[cfg(not(target_arch = "wasm32"))]
async fn sleep(duration: Duration) {
tokio::time::sleep(duration).await;
}

#[cfg(target_arch = "wasm32")]
async fn sleep(duration: Duration) {
gloo_timers::future::sleep(duration).await;
}

#[derive(Debug, PartialEq)]
struct Error;

type Result<T> = core::result::Result<T, Error>;

struct MockResolver {
data: HashMap<String, String>,
counts: Arc<RwLock<HashMap<String, usize>>>,
}

impl Resolver for MockResolver {
type Input = String;
type Output = String;
type Error = Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
sleep(Duration::from_millis(10)).await;
*self.counts.write().await.entry(input.clone()).or_default() += 1;
if let Some(value) = self.data.get(input) {
Ok(value.clone())
} else {
Err(Error)
}
}
}

fn mock_resolver(counts: Arc<RwLock<HashMap<String, usize>>>) -> MockResolver {
MockResolver {
data: [
(String::from("k1"), String::from("v1")),
(String::from("k2"), String::from("v2")),
]
.into_iter()
.collect(),
counts,
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_no_cached() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone());
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),]
.into_iter()
.collect()
);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_cached() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default()));
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),]
.into_iter()
.collect()
);
}

#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_cached_with_max_capacity() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone())
.cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() }));
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),]
.into_iter()
.collect()
);
}

#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_cached_with_time_to_live() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig {
time_to_live: Some(Duration::from_millis(10)),
..Default::default()
}));
for _ in 0..10 {
let result = resolver.resolve(&String::from("k1")).await;
assert_eq!(result.expect("failed to resolve"), "v1");
}
sleep(Duration::from_millis(10)).await;
for _ in 0..10 {
let result = resolver.resolve(&String::from("k1")).await;
assert_eq!(result.expect("failed to resolve"), "v1");
}
assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect());
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_throttled() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = Arc::new(mock_resolver(counts.clone()).throttled());

let mut handles = Vec::new();
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let resolver = resolver.clone();
handles.push(async move { (resolver.resolve(&input.to_string()).await, expected) });
}
for (result, expected) in futures::future::join_all(handles).await {
let result = result.and_then(|opt| opt.ok_or(Error));

match expected {
Some(value) => {
assert_eq!(result.expect("failed to resolve"), value)
}
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),]
.into_iter()
.collect()
);
}
}
31 changes: 31 additions & 0 deletions atrium-common/src/resolver/cached.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::hash::Hash;

use crate::types::cached::r#impl::{Cache, CacheImpl};
use crate::types::cached::Cached;

use super::Resolver;

pub type CachedResolver<R> = Cached<R, CacheImpl<<R as Resolver>::Input, <R as Resolver>::Output>>;

impl<R, C> Resolver for Cached<R, C>
where
R: Resolver + Send + Sync + 'static,
R::Input: Clone + Hash + Eq + Send + Sync + 'static,
R::Output: Clone + Send + Sync + 'static,
C: Cache<Input = R::Input, Output = R::Output> + Send + Sync + 'static,
C::Input: Clone + Hash + Eq + Send + Sync + 'static,
C::Output: Clone + Send + Sync + 'static,
{
type Input = R::Input;
type Output = R::Output;
type Error = R::Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
if let Some(output) = self.cache.get(input).await {
return Ok(output);
}
let output = self.inner.resolve(input).await?;
self.cache.set(input.clone(), output.clone()).await;
Ok(output)
}
}
43 changes: 43 additions & 0 deletions atrium-common/src/resolver/throttled.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::{hash::Hash, sync::Arc};

use dashmap::{DashMap, Entry};
use tokio::sync::broadcast::{channel, Sender};
use tokio::sync::Mutex;

use crate::types::throttled::Throttled;

use super::Resolver;

pub type SenderMap<R> =
DashMap<<R as Resolver>::Input, Arc<Mutex<Sender<Option<<R as Resolver>::Output>>>>>;

pub type ThrottledResolver<R> = Throttled<R, SenderMap<R>>;

impl<R> Resolver for Throttled<R, SenderMap<R>>
where
R: Resolver + Send + Sync + 'static,
R::Input: Clone + Hash + Eq + Send + Sync + 'static,
R::Output: Clone + Send + Sync + 'static,
{
type Input = R::Input;
type Output = Option<R::Output>;
type Error = R::Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
match self.pending.entry(input.clone()) {
Entry::Occupied(occupied) => {
let tx = occupied.get().lock().await.clone();
drop(occupied);
Ok(tx.subscribe().recv().await.expect("recv"))
}
Entry::Vacant(vacant) => {
let (tx, _) = channel(1);
vacant.insert(Arc::new(Mutex::new(tx.clone())));
let result = self.inner.resolve(input).await;
tx.send(result.as_ref().ok().cloned()).ok();
self.pending.remove(input);
result.map(Some)
}
}
}
}
Loading
Loading