-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test(webserver): create mock oauth client and write unit test for oau…
…th login flow
- Loading branch information
Showing
2 changed files
with
134 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ use super::graphql_pagination_to_filter; | |
use crate::{ | ||
bail, | ||
jwt::{generate_jwt, validate_jwt}, | ||
oauth, | ||
oauth::{self, OAuthClient}, | ||
}; | ||
|
||
#[derive(Clone)] | ||
|
@@ -421,27 +421,21 @@ impl AuthenticationService for AuthenticationServiceImpl { | |
provider: OAuthProvider, | ||
) -> std::result::Result<OAuthResponse, OAuthError> { | ||
let client = oauth::new_oauth_client(provider, Arc::new(self.clone())); | ||
let access_token = client.exchange_code_for_token(code).await?; | ||
let email = client.fetch_user_email(&access_token).await?; | ||
let name = client.fetch_user_full_name(&access_token).await?; | ||
let license = self | ||
.license | ||
.read() | ||
.await | ||
.context("Failed to read license info")?; | ||
let user_id = | ||
get_or_create_oauth_user(&license, &self.db, &self.setting, &self.mail, &email, &name) | ||
.await?; | ||
|
||
let refresh_token = self.db.create_refresh_token(user_id).await?; | ||
|
||
let access_token = generate_jwt(user_id.as_id()).map_err(|_| OAuthError::Unknown)?; | ||
|
||
let resp = OAuthResponse { | ||
access_token, | ||
refresh_token, | ||
}; | ||
Ok(resp) | ||
oauth_login( | ||
client, | ||
code, | ||
&self.db, | ||
&*self.setting, | ||
&license, | ||
&*self.mail, | ||
) | ||
.await | ||
} | ||
|
||
async fn read_oauth_credential( | ||
|
@@ -515,11 +509,35 @@ impl AuthenticationService for AuthenticationServiceImpl { | |
} | ||
} | ||
|
||
async fn oauth_login( | ||
client: Arc<dyn OAuthClient>, | ||
code: String, | ||
db: &DbConn, | ||
setting: &dyn SettingService, | ||
license: &LicenseInfo, | ||
mail: &dyn EmailService, | ||
) -> Result<OAuthResponse, OAuthError> { | ||
let access_token = client.exchange_code_for_token(code).await?; | ||
let email = client.fetch_user_email(&access_token).await?; | ||
let name = client.fetch_user_full_name(&access_token).await?; | ||
let user_id = get_or_create_oauth_user(&license, &db, setting, mail, &email, &name).await?; | ||
|
||
let refresh_token = db.create_refresh_token(user_id).await?; | ||
|
||
let access_token = generate_jwt(user_id.as_id()).map_err(|_| OAuthError::Unknown)?; | ||
|
||
let resp = OAuthResponse { | ||
access_token, | ||
refresh_token, | ||
}; | ||
Ok(resp) | ||
} | ||
|
||
async fn get_or_create_oauth_user( | ||
license: &LicenseInfo, | ||
db: &DbConn, | ||
setting: &Arc<dyn SettingService>, | ||
mail: &Arc<dyn EmailService>, | ||
setting: &dyn SettingService, | ||
mail: &dyn EmailService, | ||
email: &str, | ||
name: &str, | ||
) -> Result<i64, OAuthError> { | ||
|
@@ -707,11 +725,14 @@ mod tests { | |
use serial_test::serial; | ||
use tabby_schema::{ | ||
juniper::relay::{self, Connection}, | ||
license::{LicenseInfo, LicenseStatus}, | ||
license::{LicenseInfo, LicenseStatus, LicenseType}, | ||
}; | ||
|
||
use super::*; | ||
use crate::service::email::{new_email_service, testutils::TestEmailServer}; | ||
use crate::{ | ||
oauth::test_client::TestOAuthClient, | ||
service::email::{new_email_service, testutils::TestEmailServer}, | ||
}; | ||
|
||
#[test] | ||
fn test_password_hash() { | ||
|
@@ -941,8 +962,8 @@ mod tests { | |
let res = get_or_create_oauth_user( | ||
&license, | ||
&service.db, | ||
&setting, | ||
&service.mail, | ||
&*setting, | ||
&*service.mail, | ||
"[email protected]", | ||
"", | ||
) | ||
|
@@ -958,8 +979,8 @@ mod tests { | |
let res = get_or_create_oauth_user( | ||
&license, | ||
&service.db, | ||
&setting, | ||
&service.mail, | ||
&*setting, | ||
&*service.mail, | ||
"[email protected]", | ||
"Example User", | ||
) | ||
|
@@ -976,8 +997,8 @@ mod tests { | |
let res = get_or_create_oauth_user( | ||
&license, | ||
&service.db, | ||
&setting, | ||
&service.mail, | ||
&*setting, | ||
&*service.mail, | ||
"[email protected]", | ||
"", | ||
) | ||
|
@@ -993,8 +1014,8 @@ mod tests { | |
let res = get_or_create_oauth_user( | ||
&license, | ||
&service.db, | ||
&setting, | ||
&service.mail, | ||
&*setting, | ||
&*service.mail, | ||
"[email protected]", | ||
"User 3 by Invitation", | ||
) | ||
|
@@ -1460,4 +1481,59 @@ mod tests { | |
assert_eq!(cred.client_id, "id"); | ||
assert_eq!(cred.client_secret, "secret"); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_oauth_login() { | ||
let service = test_authentication_service().await; | ||
let license = LicenseInfo { | ||
r#type: LicenseType::Enterprise, | ||
status: LicenseStatus::Ok, | ||
seats: 1000, | ||
seats_used: 0, | ||
issued_at: None, | ||
expires_at: None, | ||
}; | ||
|
||
let client = Arc::new(TestOAuthClient { | ||
access_token_response: || Ok("faketoken".into()), | ||
user_email: "[email protected]".into(), | ||
user_name: "user".into(), | ||
}); | ||
|
||
service | ||
.create_invitation("[email protected]".into()) | ||
.await | ||
.unwrap(); | ||
|
||
let response = oauth_login( | ||
client, | ||
"fakecode".into(), | ||
&service.db, | ||
&*service.setting, | ||
&license, | ||
&*service.mail, | ||
) | ||
.await | ||
.unwrap(); | ||
|
||
assert!(!response.access_token.is_empty()); | ||
|
||
let client = Arc::new(TestOAuthClient { | ||
access_token_response: || Err(anyhow!("bad auth")), | ||
user_email: "[email protected]".into(), | ||
user_name: "user".into(), | ||
}); | ||
|
||
let response = oauth_login( | ||
client, | ||
"fakecode".into(), | ||
&service.db, | ||
&*service.setting, | ||
&license, | ||
&*service.mail, | ||
) | ||
.await; | ||
|
||
assert!(response.is_err()); | ||
} | ||
} |