From e684bf4963d3f6099f7ff73564e9d7d89d21bf39 Mon Sep 17 00:00:00 2001 From: Xander Date: Thu, 26 Oct 2023 17:52:30 +0200 Subject: [PATCH] Openidconnect implementation --- src/controllers/oauth_controller.rs | 60 ++++++++++++++++++++++++++--- src/models/session.rs | 4 ++ tests/oauth.rs | 10 +++-- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/controllers/oauth_controller.rs b/src/controllers/oauth_controller.rs index 87d15fd0..a790b4d3 100644 --- a/src/controllers/oauth_controller.rs +++ b/src/controllers/oauth_controller.rs @@ -1,3 +1,4 @@ +use chrono::Utc; use rocket::form::Form; use rocket::http::{Cookie, CookieJar}; use rocket::response::{Redirect, Responder}; @@ -161,6 +162,7 @@ pub struct UserToken { pub client_id: i32, pub client_name: String, pub redirect_uri: String, + pub scope: Option, } #[get("/oauth/grant")] @@ -215,6 +217,7 @@ async fn authorization_granted( let authorization_code = token_store .create_token(UserToken { user_id: user.id, + scope: state.scope.clone(), username: user.username.clone(), client_id: state.client_id.clone(), client_name: state.client_name.clone(), @@ -236,10 +239,22 @@ fn authorization_denied(state: AuthState) -> Redirect { )) } +#[derive(Serialize, Debug)] +pub struct IDToken { + sub: String, + iss: String, + aud: String, + exp: i64, + iat: i64, + nickname: String, + email: String, +} + #[derive(Serialize, Debug)] pub struct TokenSuccess { access_token: String, token_type: String, + id_token: Option, expires_in: i64, } @@ -252,6 +267,15 @@ pub struct TokenFormData { client_secret: Option, } +fn create_jwt(id_token: IDToken) -> String { + let header = base64::encode("{\"alg\": \"none\"}"); + let payload = base64::encode_config( + serde_json::to_string(&id_token).unwrap(), + base64::URL_SAFE_NO_PAD, + ); + format!("{}.{}.", header, payload) +} + #[post("/oauth/token", data = "
")] pub async fn token( auth: Option, @@ -306,13 +330,39 @@ pub async fn token( ))) } else { let user = User::find(token.user_id, &db).await?; - let session = - Session::create_client_session(&user, &client, &config, &db) - .await?; + let id_token = token + .scope + .as_ref() + .map(|scope| -> Option { + match scope.contains("openid") { + true => Some(create_jwt(IDToken { + sub: user.id.to_string(), + iss: config.base_url().to_string(), + aud: client.name.clone(), + iat: Utc::now().timestamp(), + exp: Utc::now().timestamp() + + config.client_session_seconds, + nickname: user.username.clone(), + email: user.email.clone(), + })), + false => None, + } + }) + .flatten(); + + let session = Session::create_client_session( + &user, + &client, + token.scope, + &config, + &db, + ) + .await?; Ok(Json(TokenSuccess { access_token: session.key.unwrap().clone(), - token_type: String::from("bearer"), - expires_in: config.client_session_seconds, + token_type: String::from("bearer"), + id_token, + expires_in: config.client_session_seconds, })) } } diff --git a/src/models/session.rs b/src/models/session.rs index 33c0fa81..22100a8c 100644 --- a/src/models/session.rs +++ b/src/models/session.rs @@ -48,6 +48,7 @@ pub struct NewSession { pub client_id: Option, pub created_at: NaiveDateTime, pub expires_at: NaiveDateTime, + pub scope: Option, } impl Session { @@ -64,6 +65,7 @@ impl Session { key: None, created_at, expires_at, + scope: None, }; db.run(move |conn| { conn.transaction(|conn| { @@ -82,6 +84,7 @@ impl Session { pub async fn create_client_session( user: &User, client: &Client, + scope: Option, conf: &Config, db: &DbConn, ) -> Result { @@ -94,6 +97,7 @@ impl Session { key: Some(key), created_at, expires_at, + scope, }; db.run(move |conn| { conn.transaction(|conn| { diff --git a/tests/oauth.rs b/tests/oauth.rs index 6d77814d..f7056964 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -32,6 +32,7 @@ fn get_param(param_name: &str, query: &String) -> Option { async fn normal_flow() { common::as_visitor(async move |http_client, db| { let redirect_uri = "https://example.com/redirect/me/here"; + let scope = None; let client_id = "test"; let client_state = "anarchy (╯°□°)╯ ┻━┻"; let user_username = "batman"; @@ -219,10 +220,11 @@ async fn normal_flow() { let authorization_code = token_store .create_token(UserToken { - user_id: user.id, - username: user.username.clone(), - client_id: client.id, - client_name: client.name, + scope, + user_id: user.id, + username: user.username.clone(), + client_id: client.id, + client_name: client.name, redirect_uri: String::from(redirect_uri), }) .await;