From 68e900b18a0b8ca40d62c620277bab9fb1453e13 Mon Sep 17 00:00:00 2001 From: Xander Bil <47951455+xerbalind@users.noreply.github.com> Date: Wed, 8 Nov 2023 22:15:18 +0000 Subject: [PATCH] Openid Connect implementation (#262) * Openidconnect implementation --- .gitignore | 1 + Cargo.lock | 79 ++++++ Cargo.toml | 5 + Rocket.toml | 3 + build.rs | 18 ++ keys/.gitkeep | 0 src/config.rs | 1 + src/controllers/oauth_controller.rs | 41 ++- src/errors.rs | 2 + src/jwt.rs | 119 +++++++++ src/lib.rs | 5 + src/models/session.rs | 4 + tests/common/mod.rs | 1 + tests/oauth.rs | 401 ++++++++++++++++------------ 14 files changed, 510 insertions(+), 170 deletions(-) create mode 100644 build.rs create mode 100644 keys/.gitkeep create mode 100644 src/jwt.rs diff --git a/.gitignore b/.gitignore index 0d7383c1..fb35ddba 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ db/* .envrc static/dist/ node_modules/ +keys/*.pem diff --git a/Cargo.lock b/Cargo.lock index 784d581f..93739d80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1044,6 +1044,20 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "9.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "155c4d7e39ad04c172c5e3a99c434ea3b4a7ba7960b38ecd562b270b097cce09" +dependencies = [ + "base64 0.21.5", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1275,6 +1289,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.17" @@ -1420,6 +1455,16 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "pem" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3163d2912b7c3b52d651a055f2c7eec9ba5cd22d26ef75b8dd3a59980b185923" +dependencies = [ + "base64 0.21.5", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.0" @@ -1676,6 +1721,20 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "ring" +version = "0.17.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.48.0", +] + [[package]] name = "rocket" version = "0.5.0-rc.3" @@ -1974,6 +2033,18 @@ dependencies = [ "libc", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "simple_logger" version = "4.2.0" @@ -2428,6 +2499,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.4.1" @@ -2770,9 +2847,11 @@ dependencies = [ "diesel", "diesel-derive-enum", "diesel_migrations", + "jsonwebtoken", "lazy_static", "lettre", "log", + "openssl", "parking_lot", "pwhash", "rand", diff --git a/Cargo.toml b/Cargo.toml index a2e5a6f6..b8090179 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,3 +32,8 @@ tempfile = "3.1" parking_lot = { version = "0.12" } thiserror = "1.0" validator = { version = "0.16", features = [ "derive" ] } +jsonwebtoken = "9.1" +openssl = "0.10" + +[build-dependencies] +openssl = "0.10" diff --git a/Rocket.toml b/Rocket.toml index 9cd17812..1c66465f 100644 --- a/Rocket.toml +++ b/Rocket.toml @@ -17,6 +17,7 @@ maximum_pending_users = 25 [debug] secret_key = "1vwCFFPSdQya895gNiO556SzmfShG6MokstgttLvwjw=" +ec_private_key = "keys/jwt_key.pem" bcrypt_cost = 4 seed_database = true @@ -29,6 +30,8 @@ port = 8000 # Values you want to fill in for production use # admin_email = # Email address to send admin notifications to (e.g. admin@zeus.gent) # secret_key = # used to encrypt cookies (generate a new one!) +# ec_private_key = # Path to ECDSA private key for signing jwt's. Key Algo needs to be ES384 in PKCS#8 form. +# generate by running: openssl ecparam -genkey -noout -name secp384r1 | openssl pkcs8 -topk8 -nocrypt -out ec-private.pem) # base_url = # URL where the application is hosten (e.g. https://auth.zeus.gent) # mail_from = # From header to set when sending emails (e.g. zauth@zeus.gent) # mail_server = # domain of the SMTP server used to send mail (e.g. smtp.zeus.gent) diff --git a/build.rs b/build.rs new file mode 100644 index 00000000..2ab44fc3 --- /dev/null +++ b/build.rs @@ -0,0 +1,18 @@ +use std::fs::File; +use std::io::Write; +use std::path::Path; + +use openssl::ec::{EcGroup, EcKey}; +use openssl::nid::Nid; +use openssl::pkey::PKey; + +fn main() { + let path = Path::new("keys/jwt_key.pem"); + if !path.exists() { + let group = EcGroup::from_curve_name(Nid::SECP384R1).unwrap(); + let pkey = PKey::from_ec_key(EcKey::generate(&group).unwrap()).unwrap(); + let mut f = File::create(path).unwrap(); + let pem = pkey.private_key_to_pem_pkcs8().unwrap(); + f.write_all(&pem).unwrap(); + } +} diff --git a/keys/.gitkeep b/keys/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/src/config.rs b/src/config.rs index be075c9b..713fa493 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,6 +14,7 @@ pub struct Config { pub secure_token_length: usize, pub bcrypt_cost: u32, pub base_url: String, + pub ec_private_key: String, pub mail_queue_size: usize, pub mail_queue_wait_seconds: u64, pub mail_from: String, diff --git a/src/controllers/oauth_controller.rs b/src/controllers/oauth_controller.rs index 87d15fd0..bcecf460 100644 --- a/src/controllers/oauth_controller.rs +++ b/src/controllers/oauth_controller.rs @@ -1,3 +1,4 @@ +use jsonwebtoken::jwk::JwkSet; use rocket::form::Form; use rocket::http::{Cookie, CookieJar}; use rocket::response::{Redirect, Responder}; @@ -10,6 +11,7 @@ use crate::ephemeral::session::UserSession; use crate::errors::Either::{Left, Right}; use crate::errors::*; use crate::http_authentication::BasicAuthentication; +use crate::jwt::JWTBuilder; use crate::models::client::*; use crate::models::session::*; use crate::models::user::*; @@ -161,6 +163,7 @@ pub struct UserToken { pub client_id: i32, pub client_name: String, pub redirect_uri: String, + pub scope: Option, } #[get("/oauth/grant")] @@ -215,6 +218,7 @@ async fn authorization_granted( let authorization_code = token_store .create_token(UserToken { user_id: user.id, + scope: state.scope.clone(), username: user.username.clone(), client_id: state.client_id.clone(), client_name: state.client_name.clone(), @@ -240,6 +244,8 @@ fn authorization_denied(state: AuthState) -> Redirect { pub struct TokenSuccess { access_token: String, token_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + id_token: Option, expires_in: i64, } @@ -258,6 +264,7 @@ pub async fn token( form: Form, config: &State, token_state: &State>, + jwt_builder: &State, db: DbConn, ) -> Result> { let data = form.into_inner(); @@ -306,13 +313,37 @@ pub async fn token( ))) } else { let user = User::find(token.user_id, &db).await?; - let session = - Session::create_client_session(&user, &client, &config, &db) - .await?; + let id_token = token + .scope + .as_ref() + .map(|scope| -> Option { + match scope.contains("openid") { + true => { + jwt_builder.encode_id_token(&client, &user, config).ok() + }, + false => None, + } + }) + .flatten(); + + let session = Session::create_client_session( + &user, + &client, + token.scope, + &config, + &db, + ) + .await?; Ok(Json(TokenSuccess { access_token: session.key.unwrap().clone(), - token_type: String::from("bearer"), - expires_in: config.client_session_seconds, + token_type: String::from("bearer"), + id_token, + expires_in: config.client_session_seconds, })) } } + +#[get("/oauth/jwks")] +pub async fn jwks(jwt_builder: &State) -> Json { + Json(jwt_builder.jwks.clone()) +} diff --git a/src/errors.rs b/src/errors.rs index 3401fab8..a7759371 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -214,6 +214,8 @@ pub enum InternalError { BincodeError(#[from] Box), #[error("B64 decode error")] Base64DecodeError(#[from] base64::DecodeError), + #[error("JWT error")] + JWTError(#[from] jsonwebtoken::errors::Error), } pub type InternalResult = std::result::Result; diff --git a/src/jwt.rs b/src/jwt.rs new file mode 100644 index 00000000..d026a650 --- /dev/null +++ b/src/jwt.rs @@ -0,0 +1,119 @@ +use crate::config::Config; +use crate::errors::{InternalError, LaunchError, Result}; +use crate::models::client::Client; +use crate::models::user::User; +use chrono::Utc; +use jsonwebtoken::jwk::{ + CommonParameters, EllipticCurveKeyParameters, Jwk, JwkSet, +}; +use jsonwebtoken::{encode, EncodingKey, Header}; +use openssl::bn::{BigNum, BigNumContext}; +use openssl::ec::EcKey; +use serde::Serialize; +use std::fs::File; +use std::io::Read; + +pub struct JWTBuilder { + pub key: EncodingKey, + pub header: Header, + pub jwks: JwkSet, +} + +#[derive(Serialize, Debug)] +struct IDToken { + sub: String, + iss: String, + aud: String, + exp: i64, + iat: i64, + preferred_username: String, + email: String, +} + +impl JWTBuilder { + pub fn new(config: &Config) -> Result { + let mut file = File::open(&config.ec_private_key) + .map_err(|err| LaunchError::BadConfigValueType(err.to_string()))?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer) + .map_err(|err| LaunchError::BadConfigValueType(err.to_string()))?; + + let key = EncodingKey::from_ec_pem(&buffer) + .map_err(|err| LaunchError::BadConfigValueType(err.to_string()))?; + let header = Header::new(jsonwebtoken::Algorithm::ES384); + + let private_key = EcKey::private_key_from_pem(&buffer) + .map_err(|err| LaunchError::BadConfigValueType(err.to_string()))?; + + let mut ctx: BigNumContext = BigNumContext::new().unwrap(); + let public_key = private_key.public_key(); + let mut x = BigNum::new().unwrap(); + let mut y = BigNum::new().unwrap(); + public_key + .affine_coordinates(private_key.group(), &mut x, &mut y, &mut ctx) + .expect("x,y coordinates"); + + let jwk = Jwk { + common: CommonParameters { + public_key_use: Some( + jsonwebtoken::jwk::PublicKeyUse::Signature, + ), + key_algorithm: Some( + jsonwebtoken::jwk::KeyAlgorithm::ES384, + ), + key_operations: None, + key_id: None, + x509_url: None, + x509_chain: None, + x509_sha1_fingerprint: None, + x509_sha256_fingerprint: None, + }, + algorithm: jsonwebtoken::jwk::AlgorithmParameters::EllipticCurve( + EllipticCurveKeyParameters { + key_type: jsonwebtoken::jwk::EllipticCurveKeyType::EC, + curve: jsonwebtoken::jwk::EllipticCurve::P384, + x: base64::encode_config( + x.to_vec(), + base64::URL_SAFE_NO_PAD, + ), + y: base64::encode_config( + y.to_vec(), + base64::URL_SAFE_NO_PAD, + ), + }, + ), + }; + + Ok(JWTBuilder { + key, + header, + jwks: JwkSet { + keys: Vec::from([jwk]), + }, + }) + } + + pub fn encode(&self, claims: &T) -> Result { + Ok(encode(&self.header, claims, &self.key) + .map_err(InternalError::from)?) + } + + pub fn encode_id_token( + &self, + client: &Client, + user: &User, + config: &Config, + ) -> Result { + let id_token = IDToken { + sub: user.id.to_string(), + iss: config.base_url().to_string(), + aud: client.name.clone(), + iat: Utc::now().timestamp(), + exp: Utc::now().timestamp() + + config.client_session_seconds, + preferred_username: user.username.clone(), + email: user.email.clone(), + }; + self.encode(&id_token) + } +} diff --git a/src/lib.rs b/src/lib.rs index b6940457..bd2fdd9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,12 +33,14 @@ pub mod db_seed; pub mod ephemeral; pub mod errors; pub mod http_authentication; +pub mod jwt; pub mod mailer; pub mod models; pub mod token_store; pub mod util; use diesel_migrations::MigrationHarness; +use jwt::JWTBuilder; use lettre::message::Mailbox; use rocket::fairing::AdHoc; use rocket::figment::Figment; @@ -90,6 +92,7 @@ fn assemble(rocket: Rocket) -> Rocket { ); let token_store = TokenStore::::new(&config); let mailer = Mailer::new(&config).unwrap(); + let jwt_builder = JWTBuilder::new(&config).expect("config"); let rocket = rocket .mount( @@ -108,6 +111,7 @@ fn assemble(rocket: Rocket) -> Rocket { oauth_controller::grant_get, oauth_controller::grant_post, oauth_controller::token, + oauth_controller::jwks, pages_controller::home_page, sessions_controller::create_session, sessions_controller::new_session, @@ -151,6 +155,7 @@ fn assemble(rocket: Rocket) -> Rocket { .manage(token_store) .manage(mailer) .manage(admin_email) + .manage(jwt_builder) .attach(DbConn::fairing()) .attach(AdHoc::config::()) .attach(AdHoc::on_ignite("Database preparation", prepare_database)); diff --git a/src/models/session.rs b/src/models/session.rs index 33c0fa81..22100a8c 100644 --- a/src/models/session.rs +++ b/src/models/session.rs @@ -48,6 +48,7 @@ pub struct NewSession { pub client_id: Option, pub created_at: NaiveDateTime, pub expires_at: NaiveDateTime, + pub scope: Option, } impl Session { @@ -64,6 +65,7 @@ impl Session { key: None, created_at, expires_at, + scope: None, }; db.run(move |conn| { conn.transaction(|conn| { @@ -82,6 +84,7 @@ impl Session { pub async fn create_client_session( user: &User, client: &Client, + scope: Option, conf: &Config, db: &DbConn, ) -> Result { @@ -94,6 +97,7 @@ impl Session { key: Some(key), created_at, expires_at, + scope, }; db.run(move |conn| { conn.transaction(|conn| { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 05f8ec4b..a8811803 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -44,6 +44,7 @@ pub fn config() -> Config { email_confirmation_token_seconds: 300, secure_token_length: 64, bcrypt_cost: BCRYPT_COST, + ec_private_key: "keys/jwt_key.pem".to_string(), base_url: "example.com".to_string(), mail_queue_size: 10, mail_queue_wait_seconds: 0, diff --git a/tests/oauth.rs b/tests/oauth.rs index 6d77814d..4dad6a29 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -8,6 +8,10 @@ extern crate urlencoding; extern crate zauth; use self::serde_json::Value; +use common::HttpClient; +use jsonwebtoken::jwk::JwkSet; +use jsonwebtoken::DecodingKey; +use jsonwebtoken::Validation; use regex::Regex; use rocket::http::Header; use rocket::http::Status; @@ -17,10 +21,18 @@ use zauth::controllers::oauth_controller::UserToken; use zauth::models::client::{Client, NewClient}; use zauth::models::user::{NewUser, User}; use zauth::token_store::TokenStore; +use zauth::DbConn; mod common; use crate::common::url; +const REDIRECT_URI: &str = "https://example.com/redirect/me/here"; +const CLIENT_ID: &str = "test"; +const CLIENT_STATE: &str = "anarchy (╯°□°)╯ ┻━┻"; +const USER_USERNAME: &str = "batman"; +const USER_PASSWORD: &str = "wolololo"; +const USER_EMAIL: &str = "test@test.com"; + fn get_param(param_name: &str, query: &String) -> Option { Regex::new(&format!("{}=([^&]+)", param_name)) .expect("valid regex") @@ -28,42 +40,176 @@ fn get_param(param_name: &str, query: &String) -> Option { .map(|c| c[1].to_string()) } +async fn create_user(db: &DbConn) -> User { + User::create( + NewUser { + username: String::from(USER_USERNAME), + password: String::from(USER_PASSWORD), + full_name: String::from("abc"), + email: String::from(USER_EMAIL), + ssh_key: Some(String::from("ssh-rsa pqrstuvwxyz")), + not_a_robot: true, + }, + common::BCRYPT_COST, + db, + ) + .await + .expect("user") +} + +async fn create_client(db: &DbConn) -> Client { + let mut client = Client::create( + NewClient { + name: String::from(CLIENT_ID), + }, + &db, + ) + .await + .expect("client created"); + + client.needs_grant = true; + client.redirect_uri_list = String::from(REDIRECT_URI); + client.update(db).await.expect("client updated") +} + +// Test all the usual oauth requests until `access_token/id_token` is retrieved. +async fn get_token( + authorize_url: String, + http_client: &HttpClient, + client: &Client, + user: &User, +) -> Value { + let response = http_client.get(authorize_url).dispatch().await; + assert_eq!(response.status(), Status::Ok); + + // 2. User accepts authorization to client + // Server should respond with login redirect. + let response = http_client + .post("/oauth/authorize") + .body("authorized=true") + .header(ContentType::Form) + .dispatch() + .await; + let login_location = response + .headers() + .get_one("Location") + .expect("Location header"); + + assert!(login_location.starts_with("/login")); + + // 3. User requests the login page + let response = http_client.get(login_location).dispatch().await; + + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.content_type(), Some(ContentType::HTML)); + + // 4. User posts it credentials to the login path + let login_url = "/login"; + let form_body = format!( + "username={}&password={}", + url(&user.username), + url(USER_PASSWORD), + ); + + let response = http_client + .post(login_url) + .body(form_body) + .header(ContentType::Form) + .dispatch() + .await; + + assert_eq!(response.status(), Status::SeeOther); + let grant_location = response + .headers() + .get_one("Location") + .expect("Location header"); + + assert!(grant_location.starts_with("/oauth/grant")); + + // 5. User requests grant page + let response = http_client.get(grant_location).dispatch().await; + + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.content_type(), Some(ContentType::HTML)); + + // 6. User posts to grant page + let grant_url = "/oauth/grant"; + let grant_form_body = String::from("grant=true"); + + let response = http_client + .post(grant_url) + .body(grant_form_body.clone()) + .header(ContentType::Form) + .dispatch() + .await; + + assert_eq!(response.status(), Status::SeeOther); + let redirect_location = response + .headers() + .get_one("Location") + .expect("Location header"); + + let redirect_uri_regex = Regex::new("^([^?]+)?(.*)$").unwrap(); + let (redirect_uri_base, redirect_uri_params) = redirect_uri_regex + .captures(&redirect_location) + .map(|c| (c[1].to_string(), c[2].to_string())) + .unwrap(); + + assert_eq!(redirect_uri_base, REDIRECT_URI); + + let authorization_code = + get_param("code", &redirect_uri_params).expect("authorization code"); + let state = get_param("state", &redirect_uri_params).expect("state"); + + // The client state we've sent in the beginning should be included in + // the redirect back to the OAuth client + assert_eq!( + CLIENT_STATE, + urlencoding::decode(&state).expect("state decoded") + ); + + // Log out user so we don't have their cookies anymore + let response = http_client.post("/logout").dispatch().await; + + assert_eq!(response.status(), Status::SeeOther); + + // 7a. Client requests access code while sending its credentials + // trough HTTP Auth. + let token_url = "/oauth/token"; + let form_body = format!( + "grant_type=authorization_code&code={}&redirect_uri={}", + authorization_code, REDIRECT_URI + ); + + let credentials = + base64::encode(&format!("{}:{}", CLIENT_ID, client.secret)); + + let req = http_client + .post(token_url) + .header(ContentType::Form) + .header(Header::new( + "Authorization", + format!("Basic {}", credentials), + )) + .body(form_body); + + let response = req.dispatch().await; + + assert_eq!(response.status(), Status::Ok); + assert_eq!( + response.content_type().expect("content type"), + ContentType::JSON + ); + + let response_body = response.into_string().await.expect("response body"); + serde_json::from_str(&response_body).expect("response json values") +} + #[rocket::async_test] async fn normal_flow() { common::as_visitor(async move |http_client, db| { - let redirect_uri = "https://example.com/redirect/me/here"; - let client_id = "test"; - let client_state = "anarchy (╯°□°)╯ ┻━┻"; - let user_username = "batman"; - let user_password = "wolololo"; - - let user = User::create( - NewUser { - username: String::from(user_username), - password: String::from(user_password), - full_name: String::from("abc"), - email: String::from("ghi@jkl.mno"), - ssh_key: Some(String::from("ssh-rsa pqrstuvwxyz")), - not_a_robot: true, - }, - common::BCRYPT_COST, - &db, - ) - .await - .expect("user"); - - let mut client = Client::create( - NewClient { - name: String::from(client_id), - }, - &db, - ) - .await - .expect("client created"); - - client.needs_grant = true; - client.redirect_uri_list = String::from(redirect_uri); - let client = client.update(&db).await.expect("client updated"); + let user = create_user(&db).await; + let client = create_client(&db).await; // 1. User is redirected to OAuth server with request params given by // the client @@ -71,141 +217,18 @@ async fn normal_flow() { let authorize_url = format!( "/oauth/authorize?response_type=code&redirect_uri={}&client_id={}&\ state={}", - url(redirect_uri), - url(client_id), - url(client_state) - ); - let response = http_client.get(authorize_url).dispatch().await; - - assert_eq!(response.status(), Status::Ok); - - // 2. User accepts authorization to client - // Server should respond with login redirect. - let response = http_client - .post("/oauth/authorize") - .body("authorized=true") - .header(ContentType::Form) - .dispatch() - .await; - let login_location = response - .headers() - .get_one("Location") - .expect("Location header"); - - assert!(login_location.starts_with("/login")); - - // 3. User requests the login page - let response = http_client.get(login_location).dispatch().await; - - assert_eq!(response.status(), Status::Ok); - assert_eq!(response.content_type(), Some(ContentType::HTML)); - - // 4. User posts it credentials to the login path - let login_url = "/login"; - let form_body = format!( - "username={}&password={}", - url(user_username), - url(user_password), - ); - - let response = http_client - .post(login_url) - .body(form_body) - .header(ContentType::Form) - .dispatch() - .await; - - assert_eq!(response.status(), Status::SeeOther); - let grant_location = response - .headers() - .get_one("Location") - .expect("Location header"); - - assert!(grant_location.starts_with("/oauth/grant")); - - // 5. User requests grant page - let response = http_client.get(grant_location).dispatch().await; - - assert_eq!(response.status(), Status::Ok); - assert_eq!(response.content_type(), Some(ContentType::HTML)); - - // 6. User posts to grant page - let grant_url = "/oauth/grant"; - let grant_form_body = String::from("grant=true"); - - let response = http_client - .post(grant_url) - .body(grant_form_body.clone()) - .header(ContentType::Form) - .dispatch() - .await; - - assert_eq!(response.status(), Status::SeeOther); - let redirect_location = response - .headers() - .get_one("Location") - .expect("Location header"); - - let redirect_uri_regex = Regex::new("^([^?]+)?(.*)$").unwrap(); - let (redirect_uri_base, redirect_uri_params) = redirect_uri_regex - .captures(&redirect_location) - .map(|c| (c[1].to_string(), c[2].to_string())) - .unwrap(); - - assert_eq!(redirect_uri_base, redirect_uri); - - let authorization_code = get_param("code", &redirect_uri_params) - .expect("authorization code"); - let state = get_param("state", &redirect_uri_params).expect("state"); - - // The client state we've sent in the beginning should be included in - // the redirect back to the OAuth client - assert_eq!( - client_state, - urlencoding::decode(&state).expect("state decoded") + url(REDIRECT_URI), + url(CLIENT_ID), + url(CLIENT_STATE) ); - // Log out user so we don't have their cookies anymore - let response = http_client.post("/logout").dispatch().await; - - assert_eq!(response.status(), Status::SeeOther); - - // 7a. Client requests access code while sending its credentials - // trough HTTP Auth. - let token_url = "/oauth/token"; - let form_body = format!( - "grant_type=authorization_code&code={}&redirect_uri={}", - authorization_code, redirect_uri - ); - - let credentials = - base64::encode(&format!("{}:{}", client_id, client.secret)); - - let req = http_client - .post(token_url) - .header(ContentType::Form) - .header(Header::new( - "Authorization", - format!("Basic {}", credentials), - )) - .body(form_body); - - let response = req.dispatch().await; - - assert_eq!(response.status(), Status::Ok); - assert_eq!( - response.content_type().expect("content type"), - ContentType::JSON - ); - - let response_body = - response.into_string().await.expect("response body"); - let data: Value = - serde_json::from_str(&response_body).expect("response json values"); + // Do all the requests until access_token is retrieved. + let data = get_token(authorize_url, &http_client, &client, &user).await; dbg!(&data); assert!(data["access_token"].is_string()); assert!(data["token_type"].is_string()); + assert_eq!(data.get("id_token"), None); assert_eq!(data["token_type"], "bearer"); // 7b. Client requests access code while sending its credentials @@ -219,11 +242,12 @@ async fn normal_flow() { let authorization_code = token_store .create_token(UserToken { + scope: None, user_id: user.id, username: user.username.clone(), client_id: client.id, client_name: client.name, - redirect_uri: String::from(redirect_uri), + redirect_uri: String::from(REDIRECT_URI), }) .await; @@ -231,7 +255,7 @@ async fn normal_flow() { let form_body = format!( "grant_type=authorization_code&code={}&redirect_uri={}&\ client_id={}&client_secret={}", - authorization_code, redirect_uri, client_id, client.secret + authorization_code, REDIRECT_URI, CLIENT_ID, client.secret ); let req = http_client @@ -275,7 +299,54 @@ async fn normal_flow() { serde_json::from_str(&response_body).expect("response json values"); assert!(data["id"].is_number()); - assert_eq!(data["username"], user_username); + assert_eq!(data["username"], USER_USERNAME); + }) + .await; +} + +#[rocket::async_test] +async fn openid_flow() { + common::as_visitor(async move |http_client, db| { + let user = create_user(&db).await; + let client = create_client(&db).await; + + let authorize_url = format!( + "/oauth/authorize?response_type=code&redirect_uri={}&client_id={}&\ + state={}&scope=openid", + url(REDIRECT_URI), + url(CLIENT_ID), + url(CLIENT_STATE) + ); + + let data = get_token(authorize_url, &http_client, &client, &user).await; + + assert!(data["access_token"].is_string()); + assert!(data["token_type"].is_string()); + assert_ne!(data.get("id_token"), None); + assert_eq!(data["token_type"], "bearer"); + + let url = "/oauth/jwks"; + let req = http_client.get(url); + let response = req.dispatch().await; + let response_body = + response.into_string().await.expect("response body"); + let jwk_set: JwkSet = + serde_json::from_str(&response_body).expect("response json values"); + assert_eq!(jwk_set.keys.len(), 1); + + let mut validation = Validation::new(jsonwebtoken::Algorithm::ES384); + validation.set_audience(&[CLIENT_ID]); + validation.set_issuer(&["http://localhost:8000"]); + + let id_token = jsonwebtoken::decode::( + data["id_token"].as_str().unwrap(), + &DecodingKey::from_jwk(&jwk_set.keys.get(0).unwrap()).unwrap(), + &validation, + ) + .expect("id token") + .claims; + assert_eq!(id_token["preferred_username"], USER_USERNAME); + assert_eq!(id_token["email"], USER_EMAIL); }) .await; }