From 1a92b8a9bc1edde00ade3e381fa4f07fae5a84a4 Mon Sep 17 00:00:00 2001 From: Yoshihiro Sugi Date: Thu, 21 Nov 2024 16:14:35 +0900 Subject: [PATCH] feat: Update oauth_client (#254) * Update oauth_client * Update Dpop client, add parameters for refresh token --- Cargo.lock | 11 + atrium-oauth/oauth-client/Cargo.toml | 1 + atrium-oauth/oauth-client/examples/main.rs | 15 +- atrium-oauth/oauth-client/src/atproto.rs | 266 ++++++++++++++++-- .../oauth-client/src/http_client/dpop.rs | 69 +++-- atrium-oauth/oauth-client/src/lib.rs | 2 +- atrium-oauth/oauth-client/src/oauth_client.rs | 6 +- atrium-oauth/oauth-client/src/server_agent.rs | 14 +- atrium-oauth/oauth-client/src/types.rs | 14 +- .../oauth-client/src/types/request.rs | 10 + 10 files changed, 354 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5dd1b42c..0fde4d43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -717,6 +717,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" dependencies = [ "const-oid", + "pem-rfc7468", "zeroize", ] @@ -779,6 +780,7 @@ dependencies = [ "ff", "generic-array", "group", + "pem-rfc7468", "pkcs8", "rand_core", "sec1", @@ -1840,6 +1842,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" diff --git a/atrium-oauth/oauth-client/Cargo.toml b/atrium-oauth/oauth-client/Cargo.toml index 02596f59..8920ccfc 100644 --- a/atrium-oauth/oauth-client/Cargo.toml +++ b/atrium-oauth/oauth-client/Cargo.toml @@ -36,6 +36,7 @@ trait-variant.workspace = true [dev-dependencies] hickory-resolver.workspace = true +p256 = { workspace = true, features = ["pem"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [features] diff --git a/atrium-oauth/oauth-client/examples/main.rs b/atrium-oauth/oauth-client/examples/main.rs index 40a91a9e..ee211fc4 100644 --- a/atrium-oauth/oauth-client/examples/main.rs +++ b/atrium-oauth/oauth-client/examples/main.rs @@ -2,8 +2,8 @@ use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_P use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}; use atrium_oauth_client::store::state::MemoryStateStore; use atrium_oauth_client::{ - AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, OAuthClient, - OAuthClientConfig, OAuthResolverConfig, + AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient, + OAuthClientConfig, OAuthResolverConfig, Scope, }; use atrium_xrpc::http::Uri; use hickory_resolver::TokioAsyncResolver; @@ -37,7 +37,11 @@ async fn main() -> Result<(), Box> { let http_client = Arc::new(DefaultHttpClient::default()); let config = OAuthClientConfig { client_metadata: AtprotoLocalhostClientMetadata { - redirect_uris: vec!["http://127.0.0.1".to_string()], + redirect_uris: Some(vec![String::from("http://127.0.0.1/callback")]), + scopes: Some(vec![ + Scope::Known(KnownScope::Atproto), + Scope::Known(KnownScope::TransitionGeneric), + ]), }, keys: None, resolver: OAuthResolverConfig { @@ -61,7 +65,10 @@ async fn main() -> Result<(), Box> { .authorize( std::env::var("HANDLE").unwrap_or(String::from("https://bsky.social")), AuthorizeOptions { - scopes: Some(vec![String::from("atproto")]), + scopes: vec![ + Scope::Known(KnownScope::Atproto), + Scope::Known(KnownScope::TransitionGeneric) + ], ..Default::default() } ) diff --git a/atrium-oauth/oauth-client/src/atproto.rs b/atrium-oauth/oauth-client/src/atproto.rs index 94bf4c56..ae23170f 100644 --- a/atrium-oauth/oauth-client/src/atproto.rs +++ b/atrium-oauth/oauth-client/src/atproto.rs @@ -1,6 +1,6 @@ use crate::keyset::Keyset; use crate::types::{OAuthClientMetadata, TryIntoOAuthClientMetadata}; -use atrium_xrpc::http::Uri; +use atrium_xrpc::http::uri::{InvalidUri, Scheme, Uri}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -18,6 +18,22 @@ pub enum Error { EmptyJwks, #[error("`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided")] AuthSigningAlg, + #[error(transparent)] + SerdeHtmlForm(#[from] serde_html_form::ser::Error), + #[error(transparent)] + LocalhostClient(#[from] LocalhostClientError), +} + +#[derive(Error, Debug)] +pub enum LocalhostClientError { + #[error("invalid redirect_uri: {0}")] + Invalid(#[from] InvalidUri), + #[error("loopback client_id must use `http:` redirect_uri")] + NotHttpScheme, + #[error("loopback client_id must not use `localhost` as redirect_uri hostname")] + Localhost, + #[error("loopback client_id must not use loopback addresses as redirect_uri")] + NotLoopbackHost, } pub type Result = core::result::Result; @@ -56,22 +72,37 @@ impl From for String { } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] +#[serde(untagged)] pub enum Scope { + Known(KnownScope), + Unknown(String), +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum KnownScope { + #[serde(rename = "atproto")] Atproto, + #[serde(rename = "transition:generic")] + TransitionGeneric, + #[serde(rename = "transition:chat.bsky")] + TransitionChatBsky, } -impl From for String { - fn from(value: Scope) -> Self { - match value { - Scope::Atproto => String::from("atproto"), +impl AsRef for Scope { + fn as_ref(&self) -> &str { + match self { + Self::Known(KnownScope::Atproto) => "atproto", + Self::Known(KnownScope::TransitionGeneric) => "transition:generic", + Self::Known(KnownScope::TransitionChatBsky) => "transition:chat.bsky", + Self::Unknown(value) => value, } } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct AtprotoLocalhostClientMetadata { - pub redirect_uris: Vec, + pub redirect_uris: Option>, + pub scopes: Option>, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -90,14 +121,46 @@ impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata { type Error = Error; fn try_into_client_metadata(self, _: &Option) -> Result { - if self.redirect_uris.is_empty() { - return Err(Error::EmptyRedirectUris); + // validate redirect_uris + if let Some(redirect_uris) = &self.redirect_uris { + for redirect_uri in redirect_uris { + let uri = redirect_uri.parse::().map_err(LocalhostClientError::Invalid)?; + if uri.scheme() != Some(&Scheme::HTTP) { + return Err(Error::LocalhostClient(LocalhostClientError::NotHttpScheme)); + } + if uri.host() == Some("localhost") { + return Err(Error::LocalhostClient(LocalhostClientError::Localhost)); + } + if uri.host().map_or(true, |host| host != "127.0.0.1" && host != "[::1]") { + return Err(Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)); + } + } + } + // determine client_id + #[derive(serde::Serialize)] + struct Parameters { + #[serde(skip_serializing_if = "Option::is_none")] + redirect_uri: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + scope: Option, + } + let query = serde_html_form::to_string(Parameters { + redirect_uri: self.redirect_uris.clone(), + scope: self + .scopes + .map(|scopes| scopes.iter().map(AsRef::as_ref).collect::>().join(" ")), + })?; + let mut client_id = String::from("http://localhost"); + if !query.is_empty() { + client_id.push_str(&format!("?{query}")); } Ok(OAuthClientMetadata { - client_id: String::from("http://localhost"), + client_id, client_uri: None, - redirect_uris: self.redirect_uris, - scope: None, // will be set to `atproto` + redirect_uris: self + .redirect_uris + .unwrap_or(vec![String::from("http://127.0.0.1/"), String::from("http://[::1]/")]), + scope: None, grant_types: None, // will be set to `authorization_code` and `refresh_token` token_endpoint_auth_method: Some(String::from("none")), dpop_bound_access_tokens: None, // will be set to `true` @@ -121,7 +184,7 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata { if !self.grant_types.contains(&GrantType::AuthorizationCode) { return Err(Error::InvalidGrantTypes); } - if !self.scopes.contains(&Scope::Atproto) { + if !self.scopes.contains(&Scope::Known(KnownScope::Atproto)) { return Err(Error::InvalidScope); } let (jwks_uri, mut jwks) = (self.jwks_uri, None); @@ -150,9 +213,7 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata { redirect_uris: self.redirect_uris, token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()), grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()), - scope: Some( - self.scopes.into_iter().map(|v| v.into()).collect::>().join(" "), - ), + scope: Some(self.scopes.iter().map(AsRef::as_ref).collect::>().join(" ")), dpop_bound_access_tokens: Some(true), jwks_uri, jwks, @@ -160,3 +221,176 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use elliptic_curve::SecretKey; + use jose_jwk::{Jwk, Key, Parameters}; + use p256::pkcs8::DecodePrivateKey; + + const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T +4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P +gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3 +-----END PRIVATE KEY-----"#; + + #[test] + fn test_localhost_client_metadata_default() { + let metadata = AtprotoLocalhostClientMetadata::default(); + assert_eq!( + metadata.try_into_client_metadata(&None).expect("failed to convert metadata"), + OAuthClientMetadata { + client_id: String::from("http://localhost"), + client_uri: None, + redirect_uris: vec![ + String::from("http://127.0.0.1/"), + String::from("http://[::1]/"), + ], + scope: None, + grant_types: None, + token_endpoint_auth_method: Some(AuthMethod::None.into()), + dpop_bound_access_tokens: None, + jwks_uri: None, + jwks: None, + token_endpoint_auth_signing_alg: None, + } + ); + } + + #[test] + fn test_localhost_client_metadata_custom() { + let metadata = AtprotoLocalhostClientMetadata { + redirect_uris: Some(vec![ + String::from("http://127.0.0.1/callback"), + String::from("http://[::1]/callback"), + ]), + scopes: Some(vec![ + Scope::Known(KnownScope::Atproto), + Scope::Known(KnownScope::TransitionGeneric), + Scope::Unknown(String::from("unknown")), + ]), + }; + assert_eq!( + metadata.try_into_client_metadata(&None).expect("failed to convert metadata"), + OAuthClientMetadata { + client_id: String::from("http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=atproto+transition%3Ageneric+unknown"), + client_uri: None, + redirect_uris: vec![ + String::from("http://127.0.0.1/callback"), + String::from("http://[::1]/callback"), + ], + scope: None, + grant_types: None, + token_endpoint_auth_method: Some(AuthMethod::None.into()), + dpop_bound_access_tokens: None, + jwks_uri: None, + jwks: None, + token_endpoint_auth_signing_alg: None, + } + ); + } + + #[test] + fn test_localhost_client_metadata_invalid() { + { + let metadata = AtprotoLocalhostClientMetadata { + redirect_uris: Some(vec![String::from("http://")]), + ..Default::default() + }; + let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail"); + assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::Invalid(_)))); + } + { + let metadata = AtprotoLocalhostClientMetadata { + redirect_uris: Some(vec![String::from("https://127.0.0.1/")]), + ..Default::default() + }; + let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail"); + assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::NotHttpScheme))); + } + { + let metadata = AtprotoLocalhostClientMetadata { + redirect_uris: Some(vec![String::from("http://localhost:8000/")]), + ..Default::default() + }; + let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail"); + assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::Localhost))); + } + { + let metadata = AtprotoLocalhostClientMetadata { + redirect_uris: Some(vec![String::from("http://192.168.0.0/")]), + ..Default::default() + }; + let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail"); + assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::NotLoopbackHost))); + } + } + + #[test] + fn test_client_metadata() { + let metadata = AtprotoClientMetadata { + client_id: String::from("https://example.com/client_metadata.json"), + client_uri: String::from("https://example.com"), + redirect_uris: vec![String::from("https://example.com/callback")], + token_endpoint_auth_method: AuthMethod::PrivateKeyJwt, + grant_types: vec![GrantType::AuthorizationCode], + scopes: vec![Scope::Known(KnownScope::Atproto)], + jwks_uri: None, + token_endpoint_auth_signing_alg: Some(String::from("ES256")), + }; + { + let metadata = metadata.clone(); + let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail"); + assert!(matches!(err, Error::EmptyJwks)); + } + { + let metadata = metadata.clone(); + let secret_key = SecretKey::::from_pkcs8_pem(PRIVATE_KEY) + .expect("failed to parse private key"); + let keys = vec![Jwk { + key: Key::from(&secret_key.into()), + prm: Parameters { kid: Some(String::from("kid00")), ..Default::default() }, + }]; + let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset"); + assert_eq!( + metadata + .try_into_client_metadata(&Some(keyset.clone())) + .expect("failed to convert metadata"), + OAuthClientMetadata { + client_id: String::from("https://example.com/client_metadata.json"), + client_uri: Some(String::from("https://example.com")), + redirect_uris: vec![String::from("https://example.com/callback"),], + scope: Some(String::from("atproto")), + grant_types: Some(vec![String::from("authorization_code")]), + token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()), + dpop_bound_access_tokens: Some(true), + jwks_uri: None, + jwks: Some(keyset.public_jwks()), + token_endpoint_auth_signing_alg: Some(String::from("ES256")), + } + ); + } + } + + #[test] + fn test_scope_serde() { + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] + struct Scopes { + scopes: Vec, + } + + let scopes = Scopes { + scopes: vec![ + Scope::Known(KnownScope::Atproto), + Scope::Known(KnownScope::TransitionGeneric), + Scope::Unknown(String::from("unknown")), + ], + }; + let json = serde_json::to_string(&scopes).expect("failed to serialize scopes"); + assert_eq!(json, r#"{"scopes":["atproto","transition:generic","unknown"]}"#); + let deserialized = + serde_json::from_str::(&json).expect("failed to deserialize scopes"); + assert_eq!(deserialized, scopes); + } +} diff --git a/atrium-oauth/oauth-client/src/http_client/dpop.rs b/atrium-oauth/oauth-client/src/http_client/dpop.rs index 489fc3e8..b92fd621 100644 --- a/atrium-oauth/oauth-client/src/http_client/dpop.rs +++ b/atrium-oauth/oauth-client/src/http_client/dpop.rs @@ -13,6 +13,7 @@ use jose_jwk::{crypto, EcCurves, Jwk, Key}; use rand::rngs::SmallRng; use rand::{RngCore, SeedableRng}; use serde::Deserialize; +use sha2::{Digest, Sha256}; use std::sync::Arc; use thiserror::Error; @@ -40,17 +41,16 @@ where S: SimpleStore, { inner: Arc, - key: Key, - #[allow(dead_code)] - iss: String, + pub(crate) key: Key, nonces: S, + is_auth_server: bool, } impl DpopClient { pub fn new( key: Key, - iss: String, http_client: Arc, + is_auth_server: bool, supported_algs: &Option>, ) -> Result { if let Some(algs) = supported_algs { @@ -66,9 +66,21 @@ impl DpopClient { } } let nonces = MemorySimpleStore::::default(); - Ok(Self { inner: http_client, key, iss, nonces }) + Ok(Self { inner: http_client, key, nonces, is_auth_server }) } - fn build_proof(&self, htm: String, htu: String, nonce: Option) -> Result { +} + +impl DpopClient +where + S: SimpleStore, +{ + fn build_proof( + &self, + htm: String, + htu: String, + ath: Option, + nonce: Option, + ) -> Result { match crypto::Key::try_from(&self.key).map_err(Error::JwkCrypto)? { crypto::Key::P256(crypto::Kind::Secret(secret_key)) => { let mut header = RegisteredHeader::from(Algorithm::Signing(Signing::Es256)); @@ -83,12 +95,7 @@ impl DpopClient { iat: Some(Utc::now().timestamp()), ..Default::default() }, - public: PublicClaims { - htm: Some(htm), - htu: Some(htu), - nonce, - ..Default::default() - }, + public: PublicClaims { htm: Some(htm), htu: Some(htu), ath, nonce }, }; Ok(create_signed_jwt(secret_key.into(), header.into(), claims)?) } @@ -96,14 +103,24 @@ impl DpopClient { } } fn is_use_dpop_nonce_error(&self, response: &Response>) -> bool { - // is auth server? - if response.status() == 400 { - if let Ok(res) = serde_json::from_slice::(response.body()) { - return res.error == "use_dpop_nonce"; - }; + // https://datatracker.ietf.org/doc/html/rfc9449#name-authorization-server-provid + if self.is_auth_server { + if response.status() == 400 { + if let Ok(res) = serde_json::from_slice::(response.body()) { + return res.error == "use_dpop_nonce"; + }; + } + } + // https://datatracker.ietf.org/doc/html/rfc6750#section-3 + // https://datatracker.ietf.org/doc/html/rfc9449#name-resource-server-provided-no + else if response.status() == 401 { + if let Some(www_auth) = + response.headers().get("WWW-Authenticate").and_then(|v| v.to_str().ok()) + { + return www_auth.starts_with("DPoP") + && www_auth.contains(r#"error="use_dpop_nonce""#); + } } - // is resource server? - false } // https://datatracker.ietf.org/doc/html/rfc9449#section-4.2 @@ -115,9 +132,10 @@ impl DpopClient { } } -impl HttpClient for DpopClient +impl HttpClient for DpopClient where T: HttpClient + Send + Sync + 'static, + S: SimpleStore + Send + Sync + 'static, { async fn send_http( &self, @@ -128,9 +146,16 @@ where let nonce_key = uri.authority().unwrap().to_string(); let htm = request.method().to_string(); let htu = uri.to_string(); + // https://datatracker.ietf.org/doc/html/rfc9449#section-4.2 + let ath = request + .headers() + .get("Authorization") + .filter(|v| v.to_str().map_or(false, |s| s.starts_with("DPoP "))) + .map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..]))); let init_nonce = self.nonces.get(&nonce_key).await?; - let init_proof = self.build_proof(htm.clone(), htu.clone(), init_nonce.clone())?; + let init_proof = + self.build_proof(htm.clone(), htu.clone(), ath.clone(), init_nonce.clone())?; request.headers_mut().insert("DPoP", init_proof.parse()?); let response = self.inner.send_http(request.clone()).await?; @@ -151,7 +176,7 @@ where if !self.is_use_dpop_nonce_error(&response) { return Ok(response); } - let next_proof = self.build_proof(htm, htu, next_nonce)?; + let next_proof = self.build_proof(htm, htu, ath, next_nonce)?; request.headers_mut().insert("DPoP", next_proof.parse()?); let response = self.inner.send_http(request).await?; Ok(response) diff --git a/atrium-oauth/oauth-client/src/lib.rs b/atrium-oauth/oauth-client/src/lib.rs index d9a7f071..06071dc7 100644 --- a/atrium-oauth/oauth-client/src/lib.rs +++ b/atrium-oauth/oauth-client/src/lib.rs @@ -12,7 +12,7 @@ mod types; mod utils; pub use atproto::{ - AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, GrantType, Scope, + AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, GrantType, KnownScope, Scope, }; pub use error::{Error, Result}; #[cfg(feature = "default-client")] diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index 25e21b43..e844f00a 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -166,7 +166,7 @@ where response_type: AuthorizationResponseType::Code, redirect_uri, state, - scope: options.scopes.map(|v| v.join(" ")), + scope: Some(options.scopes.iter().map(AsRef::as_ref).collect::>().join(" ")), response_mode: None, code_challenge, code_challenge_method: AuthorizationCodeChallengeMethod::S256, @@ -256,8 +256,6 @@ where // https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 let verifier = URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default())); - let mut hasher = Sha256::new(); - hasher.update(verifier.as_bytes()); - (URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes())), verifier) + (URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier) } } diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index 2a05beff..c9d556f3 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -5,7 +5,8 @@ use crate::keyset::Keyset; use crate::resolver::OAuthResolver; use crate::types::{ OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse, - PushedAuthorizationRequestParameters, TokenGrantType, TokenRequestParameters, TokenSet, + PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, + TokenRequestParameters, TokenSet, }; use crate::utils::{compare_algos, generate_nonce}; use atrium_api::types::string::Datetime; @@ -56,6 +57,7 @@ pub type Result = core::result::Result; #[allow(dead_code)] pub enum OAuthRequest { Token(TokenRequestParameters), + Refresh(RefreshRequestParameters), Revocation, Introspection, PushedAuthorizationRequest(PushedAuthorizationRequestParameters), @@ -65,6 +67,7 @@ impl OAuthRequest { fn name(&self) -> String { String::from(match self { Self::Token(_) => "token", + Self::Refresh(_) => "refresh", Self::Revocation => "revocation", Self::Introspection => "introspection", Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", @@ -72,7 +75,7 @@ impl OAuthRequest { } fn expected_status(&self) -> StatusCode { match self { - Self::Token(_) => StatusCode::OK, + Self::Token(_) | Self::Refresh(_) => StatusCode::OK, Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, _ => unimplemented!(), } @@ -120,8 +123,8 @@ where ) -> Result { let dpop_client = DpopClient::new( dpop_key, - client_metadata.client_id.clone(), http_client, + true, &server_metadata.token_endpoint_auth_signing_alg_values_supported, )?; Ok(Self { server_metadata, client_metadata, dpop_client, resolver, keyset }) @@ -181,6 +184,7 @@ where }; let body = match &request { OAuthRequest::Token(params) => self.build_body(params)?, + OAuthRequest::Refresh(params) => self.build_body(params)?, OAuthRequest::PushedAuthorizationRequest(params) => self.build_body(params)?, _ => unimplemented!(), }; @@ -266,7 +270,9 @@ where } fn endpoint(&self, request: &OAuthRequest) -> Option<&String> { match request { - OAuthRequest::Token(_) => Some(&self.server_metadata.token_endpoint), + OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => { + Some(&self.server_metadata.token_endpoint) + } OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(), OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(), OAuthRequest::PushedAuthorizationRequest(_) => { diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index 45ef9bdb..a5712674 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -4,11 +4,13 @@ mod request; mod response; mod token; +use crate::atproto::{KnownScope, Scope}; pub use client_metadata::{OAuthClientMetadata, TryIntoOAuthClientMetadata}; pub use metadata::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; pub use request::{ AuthorizationCodeChallengeMethod, AuthorizationResponseType, - PushedAuthorizationRequestParameters, TokenGrantType, TokenRequestParameters, + PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, + TokenRequestParameters, }; pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse}; use serde::Deserialize; @@ -36,13 +38,19 @@ impl From for String { #[derive(Debug, Deserialize)] pub struct AuthorizeOptions { pub redirect_uri: Option, - pub scopes: Option>, // TODO: enum? + pub scopes: Vec, pub prompt: Option, + pub state: Option, } impl Default for AuthorizeOptions { fn default() -> Self { - Self { redirect_uri: None, scopes: Some(vec![String::from("atproto")]), prompt: None } + Self { + redirect_uri: None, + scopes: vec![Scope::Known(KnownScope::Atproto)], + prompt: None, + state: None, + } } } diff --git a/atrium-oauth/oauth-client/src/types/request.rs b/atrium-oauth/oauth-client/src/types/request.rs index a5b71474..d8d352e6 100644 --- a/atrium-oauth/oauth-client/src/types/request.rs +++ b/atrium-oauth/oauth-client/src/types/request.rs @@ -49,6 +49,8 @@ pub struct PushedAuthorizationRequestParameters { #[serde(rename_all = "snake_case")] pub enum TokenGrantType { AuthorizationCode, + #[allow(dead_code)] + RefreshToken, } #[derive(Serialize)] @@ -60,3 +62,11 @@ pub struct TokenRequestParameters { // https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 pub code_verifier: String, } + +#[derive(Serialize)] +pub struct RefreshRequestParameters { + // https://datatracker.ietf.org/doc/html/rfc6749#section-6 + pub grant_type: TokenGrantType, + pub refresh_token: String, + pub scope: Option, +}