Skip to content

Commit

Permalink
feat: Update XrpcClient, add AuthorizationToken (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan authored Nov 15, 2024
1 parent 4896bb7 commit c892ece
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
22 changes: 13 additions & 9 deletions atrium-api/src/agent/inner.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use super::{Session, SessionStore};
use crate::did_doc::DidDocument;
use crate::types::string::Did;
use crate::types::TryFromUnknown;
use atrium_xrpc::error::{Error, Result, XrpcErrorKind};
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
use crate::types::{string::Did, TryFromUnknown};
use atrium_xrpc::{
error::{Error, Result, XrpcErrorKind},
types::AuthorizationToken,
HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
};
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
use std::{
fmt::Debug,
sync::{Arc, RwLock},
};
use tokio::sync::{Mutex, Notify};

struct WrapperClient<S, T> {
Expand Down Expand Up @@ -72,13 +76,13 @@ where
fn base_uri(&self) -> String {
self.store.get_endpoint()
}
async fn authentication_token(&self, is_refresh: bool) -> Option<String> {
async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
self.store.get_session().await.map(|session| {
if is_refresh {
AuthorizationToken::Bearer(if is_refresh {
session.data.refresh_jwt
} else {
session.data.access_jwt
}
})
})
}
async fn atproto_proxy_header(&self) -> Option<String> {
Expand Down
21 changes: 10 additions & 11 deletions atrium-xrpc/src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use crate::error::Error;
use crate::error::{XrpcError, XrpcErrorKind};
use crate::types::{Header, NSID_REFRESH_SESSION};
use crate::error::{Error, XrpcError, XrpcErrorKind};
use crate::types::{AuthorizationToken, Header, NSID_REFRESH_SESSION};
use crate::{InputDataOrBytes, OutputDataOrBytes, XrpcRequest};
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::fmt::Debug;
use std::future::Future;
use std::{fmt::Debug, future::Future};

/// An abstract HTTP client.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
Expand All @@ -32,9 +30,12 @@ type XrpcResult<O, E> = core::result::Result<OutputDataOrBytes<O>, self::Error<E
pub trait XrpcClient: HttpClient {
/// The base URI of the XRPC server.
fn base_uri(&self) -> String;
/// Get the authentication token to use `Authorization` header.
/// Get the authorization token to use `Authorization` header.
#[allow(unused_variables)]
fn authentication_token(&self, is_refresh: bool) -> impl Future<Output = Option<String>> {
fn authorization_token(
&self,
is_refresh: bool,
) -> impl Future<Output = Option<AuthorizationToken>> {
async { None }
}
/// Get the `atproto-proxy` header.
Expand Down Expand Up @@ -102,12 +103,10 @@ where
builder = builder.header(Header::ContentType, encoding);
}
if let Some(token) = client
.authentication_token(
request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION,
)
.authorization_token(request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION)
.await
{
builder = builder.header(Header::Authorization, format!("Bearer {}", token));
builder = builder.header(Header::Authorization, token);
}
if let Some(proxy) = client.atproto_proxy_header().await {
builder = builder.header(Header::AtprotoProxy, proxy);
Expand Down
20 changes: 18 additions & 2 deletions atrium-xrpc/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
use http::header::{AUTHORIZATION, CONTENT_TYPE};
use http::{HeaderName, Method};
use http::header::{HeaderName, HeaderValue, InvalidHeaderValue, AUTHORIZATION, CONTENT_TYPE};
use http::Method;
use serde::{de::DeserializeOwned, Serialize};

pub(crate) const NSID_REFRESH_SESSION: &str = "com.atproto.server.refreshSession";

pub enum AuthorizationToken {
Bearer(String),
Dpop(String),
}

impl TryFrom<AuthorizationToken> for HeaderValue {
type Error = InvalidHeaderValue;

fn try_from(token: AuthorizationToken) -> Result<Self, Self::Error> {
HeaderValue::from_str(&match token {
AuthorizationToken::Bearer(t) => format!("Bearer {t}"),
AuthorizationToken::Dpop(t) => format!("DPoP {t}"),
})
}
}

/// HTTP headers which can be used in XPRC requests.
pub enum Header {
ContentType,
Expand Down

0 comments on commit c892ece

Please sign in to comment.