Skip to content

Commit

Permalink
identity: implemented google id flow
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Sep 24, 2024
1 parent fda7a8c commit 01c8ced
Show file tree
Hide file tree
Showing 12 changed files with 891 additions and 50 deletions.
494 changes: 462 additions & 32 deletions Cargo.lock

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ members = [
"crates/did-chain",
"crates/did-simple",
"crates/egui-picking",
"crates/header-parsing",
"crates/picking-xr",
"crates/replicate/client",
"crates/replicate/common",
Expand All @@ -29,8 +30,11 @@ edition = "2021"
rust-version = "1.78.0"

[workspace.dependencies]
arc-swap = "1.7.1"
async-compat = "0.2.4"
axum = "0.7.5"
axum-extra = "0.9.3"
axum-macros = "0.4.1"
base64 = "0.21.7"
bevy = { version = "0.13", features = ["serialize"] }
bevy-inspector-egui = "0.23.4"
Expand All @@ -54,7 +58,9 @@ egui = "0.26"
egui-picking = { path = "crates/egui-picking" }
eyre = "0.6"
futures = "0.3.30"
header-parsing.path = "crates/header-parsing"
hex-literal = "0.4.1"
http = "1.1.0"
http-body-util = "0.1.2"
jose-jwk = { version = "0.1.2", default-features = false }
lightyear = "0.12"
Expand All @@ -67,6 +73,7 @@ rand_xoshiro = "0.6.0"
random-number = "0.1.8"
replicate-client.path = "crates/replicate/client"
replicate-server.path = "crates/replicate/server"
reqwest = { version = "0.12.7", features = ["rustls-tls"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.114"
slotmap = "1.0.7"
Expand All @@ -85,15 +92,8 @@ uuid = "1.7.0"
wtransport = "0.1.13"

[workspace.dependencies.derive_more]
version = "0.99"
version = "1.0.0"
default-features = false
features = [
"add",
"deref",
"deref_mut",
"mul",
"from",
]

[workspace.dependencies.opus]
git = "https://github.com/Schmarni-Dev/opus-rs"
Expand Down
11 changes: 9 additions & 2 deletions apps/identity_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,26 @@ description = "Self-custodial identity using did:web"
publish = false

[dependencies]
axum.workspace = true
arc-swap.workspace = true
axum = { workspace = true, features = [] }
axum-extra = { workspace = true, features = ["cookie"] }
axum-macros.workspace = true
clap = { workspace = true, features = ["derive", "env"] }
color-eyre.workspace = true
derive_more = { workspace = true, features = ["debug"] }
did-simple.workspace = true
header-parsing.workspace = true
http-body-util.workspace = true
jose-jwk = { workspace = true, default-features = false }
jsonwebtoken = { version = "9.3.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
serde.workspace = true
serde_json.workspace = true
sqlx = { version = "0.8.0", features = ["runtime-tokio", "tls-rustls", "sqlite", "uuid", "migrate"] }
thiserror.workspace = true
tokio = { workspace = true, features = ["full"] }
tower-http = { workspace = true, features = ["trace"] }
tower-http = { workspace = true, features = ["trace", "fs"] }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tracing.workspace = true
uuid = { workspace = true, features = ["std", "v4", "serde"] }
Expand Down
141 changes: 141 additions & 0 deletions apps/identity_server/src/google_jwks_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use std::{sync::Arc, time::Duration};

use arc_swap::ArcSwap;
use axum::async_trait;
use color_eyre::{eyre::WrapErr as _, Result, Section};
use jsonwebtoken::jwk::JwkSet;
use reqwest::Url;
use tracing::info;

/// Retrieves the latest JWKs for an external service.
///
/// Example: This can be used to get the JWKs from google, located at
/// <https://www.googleapis.com/oauth2/v3/certs>
///
/// This provider exists to support mocking of the external interface, for the purposes
/// of testing.
#[derive(Debug)]
pub struct JwksProvider {
#[cfg(not(test))]
provider: HttpProvider,
#[cfg(test)]
provider: Box<dyn JwksProviderT>,
}

impl JwksProvider {
pub fn google(client: reqwest::Client) -> Self {
Self {
#[cfg(not(test))]
provider: HttpProvider::google(client),
#[cfg(test)]
provider: Box::new(HttpProvider::google(client)),
}
}
pub async fn get(&self) -> Result<Arc<CachedJwks>> {
self.provider.get().await
}
}

#[async_trait]
trait JwksProviderT: std::fmt::Debug + Send + Sync + 'static {
/// Gets the latest JWKS for google.
async fn get(&self) -> Result<Arc<CachedJwks>>;
}

#[derive(Debug, Eq, PartialEq)]
pub struct CachedJwks {
jwks: JwkSet,
expires_at: std::time::Instant,
}

impl CachedJwks {
/// Creates an empty set of JWKs, which is already expired.
fn new_expired() -> Self {
let now = std::time::Instant::now();
let expires_at = now.checked_sub(Duration::from_secs(1)).unwrap_or(now);
Self {
jwks: JwkSet { keys: vec![] },
expires_at,
}
}

pub fn jwks(&self) -> &JwkSet {
&self.jwks
}

fn is_expired(&self) -> bool {
self.expires_at <= std::time::Instant::now()
}
}

/// Uses http to retrieve the JWKs.
#[derive(Debug)]
struct HttpProvider {
url: Url,
client: reqwest::Client,
cached_jwks: ArcSwap<CachedJwks>,
}

impl HttpProvider {
/// Creates a provider that requests the JWKS over HTTP from google's url.
pub fn google(client: reqwest::Client) -> Self {
// Creates immediately expired empty keyset
Self {
client,
url: "https://www.googleapis.com/oauth2/v3/certs"
.try_into()
.unwrap(),
cached_jwks: ArcSwap::new(Arc::new(CachedJwks::new_expired())),
}
}
}

#[async_trait]
impl JwksProviderT for HttpProvider {
/// Usually this is instantly ready with the JWKS, but if the cached value doesn't
/// exist
/// or is out of date, it will await on the new value.
async fn get(&self) -> Result<Arc<CachedJwks>> {
let cached_jwks = self.cached_jwks.load();
if !cached_jwks.is_expired() {
return Ok(cached_jwks.to_owned());
}
let response = self
.client
.get(self.url.clone())
.send()
.await
.wrap_err("failed to initiate get request for certs")
.with_note(|| format!("url was {}", self.url))?;
let expires_at = {
if let Some(duration) =
header_parsing::time_until_max_age(response.headers())
{
std::time::Instant::now() + duration
} else {
std::time::Instant::now()
}
};
let serialized_keys = response
.bytes()
.await
.wrap_err("failed to get response body")?;
let jwks: JwkSet = serde_json::from_slice(&serialized_keys)
.wrap_err("unexpected response, expected a JWKS")?;
let cached_jwks = Arc::new(CachedJwks { jwks, expires_at });
self.cached_jwks.store(Arc::clone(&cached_jwks));
info!("cached JWKs: {cached_jwks:?}");
Ok(cached_jwks)
}
}

/// Always provides the same JWKs.
#[derive(Debug, Clone)]
struct StaticProvider(Arc<CachedJwks>);

#[async_trait]
impl JwksProviderT for StaticProvider {
async fn get(&self) -> Result<Arc<CachedJwks>> {
Ok(Arc::clone(&self.0))
}
}
11 changes: 11 additions & 0 deletions apps/identity_server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod google_jwks_provider;
pub mod jwk;
pub mod oauth;
pub mod v1;

mod uuid;
Expand Down Expand Up @@ -28,6 +30,7 @@ impl MigratedDbPool {
#[derive(Debug)]
pub struct RouterConfig {
pub v1: crate::v1::RouterConfig,
pub oauth: crate::oauth::OAuthConfig,
}

impl RouterConfig {
Expand All @@ -37,9 +40,17 @@ impl RouterConfig {
.build()
.await
.wrap_err("failed to build v1 router")?;

let oauth = self
.oauth
.build()
.await
.wrap_err("failed to build oauth router")?;

Ok(axum::Router::new()
.route("/", get(root))
.nest("/api/v1", v1)
.nest("/oauth2", oauth)
.layer(TraceLayer::new_for_http()))
}
}
Expand Down
22 changes: 17 additions & 5 deletions apps/identity_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::net::{Ipv6Addr, SocketAddr};

use clap::Parser as _;
use color_eyre::eyre::Context as _;
use identity_server::MigratedDbPool;
use identity_server::{google_jwks_provider::JwksProvider, MigratedDbPool};
use std::path::PathBuf;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
Expand All @@ -13,6 +13,10 @@ struct Cli {
port: u16,
#[clap(long, env, default_value = "identities.db")]
db_path: PathBuf,
/// The Google API OAuth2 Client ID.
/// See https://developers.google.com/identity/gsi/web/guides/get-google-api-clientid
#[clap(long, env)]
google_client_id: String,
}

#[tokio::main]
Expand Down Expand Up @@ -43,15 +47,23 @@ async fn main() -> color_eyre::Result<()> {
.await
.wrap_err("failed to migrate db pool")?
};
let reqwest_client = reqwest::Client::new();

let v1_cfg = identity_server::v1::RouterConfig {
uuid_provider: Default::default(),
db_pool,
};
let router = identity_server::RouterConfig { v1: v1_cfg }
.build()
.await
.wrap_err("failed to build router")?;
let oauth_cfg = identity_server::oauth::OAuthConfig {
google_client_id: cli.google_client_id,
google_jwks_provider: JwksProvider::google(reqwest_client.clone()),
};
let router = identity_server::RouterConfig {
v1: v1_cfg,
oauth: oauth_cfg,
}
.build()
.await
.wrap_err("failed to build router")?;

let listener = tokio::net::TcpListener::bind(SocketAddr::new(
Ipv6Addr::UNSPECIFIED.into(),
Expand Down
Loading

0 comments on commit 01c8ced

Please sign in to comment.