Skip to content

Commit

Permalink
feat: Update oauth_client (#254)
Browse files Browse the repository at this point in the history
* Update oauth_client

* Update Dpop client, add parameters for refresh token
  • Loading branch information
sugyan authored Nov 21, 2024
1 parent de18c4f commit 1a92b8a
Show file tree
Hide file tree
Showing 10 changed files with 354 additions and 54 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions atrium-oauth/oauth-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ trait-variant.workspace = true

[dev-dependencies]
hickory-resolver.workspace = true
p256 = { workspace = true, features = ["pem"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }

[features]
Expand Down
15 changes: 11 additions & 4 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_P
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::state::MemoryStateStore;
use atrium_oauth_client::{
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, OAuthClient,
OAuthClientConfig, OAuthResolverConfig,
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient,
OAuthClientConfig, OAuthResolverConfig, Scope,
};
use atrium_xrpc::http::Uri;
use hickory_resolver::TokioAsyncResolver;
Expand Down Expand Up @@ -37,7 +37,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let http_client = Arc::new(DefaultHttpClient::default());
let config = OAuthClientConfig {
client_metadata: AtprotoLocalhostClientMetadata {
redirect_uris: vec!["http://127.0.0.1".to_string()],
redirect_uris: Some(vec![String::from("http://127.0.0.1/callback")]),
scopes: Some(vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric),
]),
},
keys: None,
resolver: OAuthResolverConfig {
Expand All @@ -61,7 +65,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.authorize(
std::env::var("HANDLE").unwrap_or(String::from("https://bsky.social")),
AuthorizeOptions {
scopes: Some(vec![String::from("atproto")]),
scopes: vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric)
],
..Default::default()
}
)
Expand Down
266 changes: 250 additions & 16 deletions atrium-oauth/oauth-client/src/atproto.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::keyset::Keyset;
use crate::types::{OAuthClientMetadata, TryIntoOAuthClientMetadata};
use atrium_xrpc::http::Uri;
use atrium_xrpc::http::uri::{InvalidUri, Scheme, Uri};
use serde::{Deserialize, Serialize};
use thiserror::Error;

Expand All @@ -18,6 +18,22 @@ pub enum Error {
EmptyJwks,
#[error("`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided")]
AuthSigningAlg,
#[error(transparent)]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error(transparent)]
LocalhostClient(#[from] LocalhostClientError),
}

#[derive(Error, Debug)]
pub enum LocalhostClientError {
#[error("invalid redirect_uri: {0}")]
Invalid(#[from] InvalidUri),
#[error("loopback client_id must use `http:` redirect_uri")]
NotHttpScheme,
#[error("loopback client_id must not use `localhost` as redirect_uri hostname")]
Localhost,
#[error("loopback client_id must not use loopback addresses as redirect_uri")]
NotLoopbackHost,
}

pub type Result<T> = core::result::Result<T, Error>;
Expand Down Expand Up @@ -56,22 +72,37 @@ impl From<GrantType> for String {
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum Scope {
Known(KnownScope),
Unknown(String),
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum KnownScope {
#[serde(rename = "atproto")]
Atproto,
#[serde(rename = "transition:generic")]
TransitionGeneric,
#[serde(rename = "transition:chat.bsky")]
TransitionChatBsky,
}

impl From<Scope> for String {
fn from(value: Scope) -> Self {
match value {
Scope::Atproto => String::from("atproto"),
impl AsRef<str> for Scope {
fn as_ref(&self) -> &str {
match self {
Self::Known(KnownScope::Atproto) => "atproto",
Self::Known(KnownScope::TransitionGeneric) => "transition:generic",
Self::Known(KnownScope::TransitionChatBsky) => "transition:chat.bsky",
Self::Unknown(value) => value,
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AtprotoLocalhostClientMetadata {
pub redirect_uris: Vec<String>,
pub redirect_uris: Option<Vec<String>>,
pub scopes: Option<Vec<Scope>>,
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand All @@ -90,14 +121,46 @@ impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata {
type Error = Error;

fn try_into_client_metadata(self, _: &Option<Keyset>) -> Result<OAuthClientMetadata> {
if self.redirect_uris.is_empty() {
return Err(Error::EmptyRedirectUris);
// validate redirect_uris
if let Some(redirect_uris) = &self.redirect_uris {
for redirect_uri in redirect_uris {
let uri = redirect_uri.parse::<Uri>().map_err(LocalhostClientError::Invalid)?;
if uri.scheme() != Some(&Scheme::HTTP) {
return Err(Error::LocalhostClient(LocalhostClientError::NotHttpScheme));
}
if uri.host() == Some("localhost") {
return Err(Error::LocalhostClient(LocalhostClientError::Localhost));
}
if uri.host().map_or(true, |host| host != "127.0.0.1" && host != "[::1]") {
return Err(Error::LocalhostClient(LocalhostClientError::NotLoopbackHost));
}
}
}
// determine client_id
#[derive(serde::Serialize)]
struct Parameters {
#[serde(skip_serializing_if = "Option::is_none")]
redirect_uri: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
}
let query = serde_html_form::to_string(Parameters {
redirect_uri: self.redirect_uris.clone(),
scope: self
.scopes
.map(|scopes| scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
})?;
let mut client_id = String::from("http://localhost");
if !query.is_empty() {
client_id.push_str(&format!("?{query}"));
}
Ok(OAuthClientMetadata {
client_id: String::from("http://localhost"),
client_id,
client_uri: None,
redirect_uris: self.redirect_uris,
scope: None, // will be set to `atproto`
redirect_uris: self
.redirect_uris
.unwrap_or(vec![String::from("http://127.0.0.1/"), String::from("http://[::1]/")]),
scope: None,
grant_types: None, // will be set to `authorization_code` and `refresh_token`
token_endpoint_auth_method: Some(String::from("none")),
dpop_bound_access_tokens: None, // will be set to `true`
Expand All @@ -121,7 +184,7 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
if !self.grant_types.contains(&GrantType::AuthorizationCode) {
return Err(Error::InvalidGrantTypes);
}
if !self.scopes.contains(&Scope::Atproto) {
if !self.scopes.contains(&Scope::Known(KnownScope::Atproto)) {
return Err(Error::InvalidScope);
}
let (jwks_uri, mut jwks) = (self.jwks_uri, None);
Expand Down Expand Up @@ -150,13 +213,184 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
redirect_uris: self.redirect_uris,
token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()),
grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()),
scope: Some(
self.scopes.into_iter().map(|v| v.into()).collect::<Vec<String>>().join(" "),
),
scope: Some(self.scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
dpop_bound_access_tokens: Some(true),
jwks_uri,
jwks,
token_endpoint_auth_signing_alg: self.token_endpoint_auth_signing_alg,
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use elliptic_curve::SecretKey;
use jose_jwk::{Jwk, Key, Parameters};
use p256::pkcs8::DecodePrivateKey;

const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T
4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P
gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3
-----END PRIVATE KEY-----"#;

#[test]
fn test_localhost_client_metadata_default() {
let metadata = AtprotoLocalhostClientMetadata::default();
assert_eq!(
metadata.try_into_client_metadata(&None).expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: String::from("http://localhost"),
client_uri: None,
redirect_uris: vec![
String::from("http://127.0.0.1/"),
String::from("http://[::1]/"),
],
scope: None,
grant_types: None,
token_endpoint_auth_method: Some(AuthMethod::None.into()),
dpop_bound_access_tokens: None,
jwks_uri: None,
jwks: None,
token_endpoint_auth_signing_alg: None,
}
);
}

#[test]
fn test_localhost_client_metadata_custom() {
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![
String::from("http://127.0.0.1/callback"),
String::from("http://[::1]/callback"),
]),
scopes: Some(vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric),
Scope::Unknown(String::from("unknown")),
]),
};
assert_eq!(
metadata.try_into_client_metadata(&None).expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: String::from("http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=atproto+transition%3Ageneric+unknown"),
client_uri: None,
redirect_uris: vec![
String::from("http://127.0.0.1/callback"),
String::from("http://[::1]/callback"),
],
scope: None,
grant_types: None,
token_endpoint_auth_method: Some(AuthMethod::None.into()),
dpop_bound_access_tokens: None,
jwks_uri: None,
jwks: None,
token_endpoint_auth_signing_alg: None,
}
);
}

#[test]
fn test_localhost_client_metadata_invalid() {
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("http://")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::Invalid(_))));
}
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("https://127.0.0.1/")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::NotHttpScheme)));
}
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("http://localhost:8000/")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::Localhost)));
}
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("http://192.168.0.0/")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)));
}
}

#[test]
fn test_client_metadata() {
let metadata = AtprotoClientMetadata {
client_id: String::from("https://example.com/client_metadata.json"),
client_uri: String::from("https://example.com"),
redirect_uris: vec![String::from("https://example.com/callback")],
token_endpoint_auth_method: AuthMethod::PrivateKeyJwt,
grant_types: vec![GrantType::AuthorizationCode],
scopes: vec![Scope::Known(KnownScope::Atproto)],
jwks_uri: None,
token_endpoint_auth_signing_alg: Some(String::from("ES256")),
};
{
let metadata = metadata.clone();
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::EmptyJwks));
}
{
let metadata = metadata.clone();
let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY)
.expect("failed to parse private key");
let keys = vec![Jwk {
key: Key::from(&secret_key.into()),
prm: Parameters { kid: Some(String::from("kid00")), ..Default::default() },
}];
let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset");
assert_eq!(
metadata
.try_into_client_metadata(&Some(keyset.clone()))
.expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: String::from("https://example.com/client_metadata.json"),
client_uri: Some(String::from("https://example.com")),
redirect_uris: vec![String::from("https://example.com/callback"),],
scope: Some(String::from("atproto")),
grant_types: Some(vec![String::from("authorization_code")]),
token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()),
dpop_bound_access_tokens: Some(true),
jwks_uri: None,
jwks: Some(keyset.public_jwks()),
token_endpoint_auth_signing_alg: Some(String::from("ES256")),
}
);
}
}

#[test]
fn test_scope_serde() {
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct Scopes {
scopes: Vec<Scope>,
}

let scopes = Scopes {
scopes: vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric),
Scope::Unknown(String::from("unknown")),
],
};
let json = serde_json::to_string(&scopes).expect("failed to serialize scopes");
assert_eq!(json, r#"{"scopes":["atproto","transition:generic","unknown"]}"#);
let deserialized =
serde_json::from_str::<Scopes>(&json).expect("failed to deserialize scopes");
assert_eq!(deserialized, scopes);
}
}
Loading

0 comments on commit 1a92b8a

Please sign in to comment.