Skip to content

Commit

Permalink
test(webserver): create mock oauth client and write unit test for oau…
Browse files Browse the repository at this point in the history
…th login flow
  • Loading branch information
boxbeam committed May 28, 2024
1 parent 5aedf9c commit 0007c88
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 28 deletions.
30 changes: 30 additions & 0 deletions ee/tabby-webserver/src/oauth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,33 @@ pub fn new_oauth_client(
OAuthProvider::Github => Arc::new(GithubClient::new(auth)),
}
}

#[cfg(test)]
pub mod test_client {
use super::*;

pub struct TestOAuthClient {
pub access_token_response: fn() -> Result<String>,
pub user_email: String,
pub user_name: String,
}

#[async_trait]
impl OAuthClient for TestOAuthClient {
async fn exchange_code_for_token(&self, _code: String) -> Result<String> {
(self.access_token_response)()
}

async fn fetch_user_email(&self, _access_token: &str) -> Result<String> {
Ok(self.user_email.clone())
}

async fn fetch_user_full_name(&self, _access_token: &str) -> Result<String> {
Ok(self.user_name.clone())
}

async fn get_authorization_url(&self) -> Result<String> {
Ok("https://example.com".into())
}

Check warning on line 60 in ee/tabby-webserver/src/oauth/mod.rs

View check run for this annotation

Codecov / codecov/patch

ee/tabby-webserver/src/oauth/mod.rs#L58-L60

Added lines #L58 - L60 were not covered by tests
}
}
132 changes: 104 additions & 28 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use super::graphql_pagination_to_filter;
use crate::{
bail,
jwt::{generate_jwt, validate_jwt},
oauth,
oauth::{self, OAuthClient},
};

#[derive(Clone)]
Expand Down Expand Up @@ -421,27 +421,21 @@ impl AuthenticationService for AuthenticationServiceImpl {
provider: OAuthProvider,
) -> std::result::Result<OAuthResponse, OAuthError> {
let client = oauth::new_oauth_client(provider, Arc::new(self.clone()));
let access_token = client.exchange_code_for_token(code).await?;
let email = client.fetch_user_email(&access_token).await?;
let name = client.fetch_user_full_name(&access_token).await?;
let license = self
.license
.read()
.await
.context("Failed to read license info")?;
let user_id =
get_or_create_oauth_user(&license, &self.db, &self.setting, &self.mail, &email, &name)
.await?;

let refresh_token = self.db.create_refresh_token(user_id).await?;

let access_token = generate_jwt(user_id.as_id()).map_err(|_| OAuthError::Unknown)?;

let resp = OAuthResponse {
access_token,
refresh_token,
};
Ok(resp)
oauth_login(
client,
code,
&self.db,
&*self.setting,
&license,
&*self.mail,
)
.await

Check warning on line 438 in ee/tabby-webserver/src/service/auth.rs

View check run for this annotation

Codecov / codecov/patch

ee/tabby-webserver/src/service/auth.rs#L430-L438

Added lines #L430 - L438 were not covered by tests
}

async fn read_oauth_credential(
Expand Down Expand Up @@ -515,11 +509,35 @@ impl AuthenticationService for AuthenticationServiceImpl {
}
}

async fn oauth_login(
client: Arc<dyn OAuthClient>,
code: String,
db: &DbConn,
setting: &dyn SettingService,
license: &LicenseInfo,
mail: &dyn EmailService,
) -> Result<OAuthResponse, OAuthError> {
let access_token = client.exchange_code_for_token(code).await?;
let email = client.fetch_user_email(&access_token).await?;
let name = client.fetch_user_full_name(&access_token).await?;
let user_id = get_or_create_oauth_user(&license, &db, setting, mail, &email, &name).await?;

let refresh_token = db.create_refresh_token(user_id).await?;

let access_token = generate_jwt(user_id.as_id()).map_err(|_| OAuthError::Unknown)?;

let resp = OAuthResponse {
access_token,
refresh_token,
};
Ok(resp)
}

async fn get_or_create_oauth_user(
license: &LicenseInfo,
db: &DbConn,
setting: &Arc<dyn SettingService>,
mail: &Arc<dyn EmailService>,
setting: &dyn SettingService,
mail: &dyn EmailService,
email: &str,
name: &str,
) -> Result<i64, OAuthError> {
Expand Down Expand Up @@ -707,11 +725,14 @@ mod tests {
use serial_test::serial;
use tabby_schema::{
juniper::relay::{self, Connection},
license::{LicenseInfo, LicenseStatus},
license::{LicenseInfo, LicenseStatus, LicenseType},
};

use super::*;
use crate::service::email::{new_email_service, testutils::TestEmailServer};
use crate::{
oauth::test_client::TestOAuthClient,
service::email::{new_email_service, testutils::TestEmailServer},
};

#[test]
fn test_password_hash() {
Expand Down Expand Up @@ -941,8 +962,8 @@ mod tests {
let res = get_or_create_oauth_user(
&license,
&service.db,
&setting,
&service.mail,
&*setting,
&*service.mail,
"[email protected]",
"",
)
Expand All @@ -958,8 +979,8 @@ mod tests {
let res = get_or_create_oauth_user(
&license,
&service.db,
&setting,
&service.mail,
&*setting,
&*service.mail,
"[email protected]",
"Example User",
)
Expand All @@ -976,8 +997,8 @@ mod tests {
let res = get_or_create_oauth_user(
&license,
&service.db,
&setting,
&service.mail,
&*setting,
&*service.mail,
"[email protected]",
"",
)
Expand All @@ -993,8 +1014,8 @@ mod tests {
let res = get_or_create_oauth_user(
&license,
&service.db,
&setting,
&service.mail,
&*setting,
&*service.mail,
"[email protected]",
"User 3 by Invitation",
)
Expand Down Expand Up @@ -1460,4 +1481,59 @@ mod tests {
assert_eq!(cred.client_id, "id");
assert_eq!(cred.client_secret, "secret");
}

#[tokio::test]
async fn test_oauth_login() {
let service = test_authentication_service().await;
let license = LicenseInfo {
r#type: LicenseType::Enterprise,
status: LicenseStatus::Ok,
seats: 1000,
seats_used: 0,
issued_at: None,
expires_at: None,
};

let client = Arc::new(TestOAuthClient {
access_token_response: || Ok("faketoken".into()),
user_email: "[email protected]".into(),
user_name: "user".into(),
});

service
.create_invitation("[email protected]".into())
.await
.unwrap();

let response = oauth_login(
client,
"fakecode".into(),
&service.db,
&*service.setting,
&license,
&*service.mail,
)
.await
.unwrap();

assert!(!response.access_token.is_empty());

let client = Arc::new(TestOAuthClient {
access_token_response: || Err(anyhow!("bad auth")),
user_email: "[email protected]".into(),
user_name: "user".into(),
});

let response = oauth_login(
client,
"fakecode".into(),
&service.db,
&*service.setting,
&license,
&*service.mail,
)
.await;

assert!(response.is_err());
}
}

0 comments on commit 0007c88

Please sign in to comment.