From f3d166ab4ce1f726d5c0bd41e096d9a6344cb59b Mon Sep 17 00:00:00 2001 From: Francesco Cogno Date: Fri, 26 Jun 2020 16:07:07 +0200 Subject: [PATCH] Implemented refresh token code --- .../examples/device_code_flow.rs | 5 +- azure_sdk_auth_aad/src/device_code_flow.rs | 23 +++++-- azure_sdk_auth_aad/src/lib.rs | 3 + azure_sdk_auth_aad/src/prelude.rs | 1 + azure_sdk_auth_aad/src/refresh_token.rs | 12 ++-- azure_sdk_auth_aad/src/responses/mod.rs | 2 + .../src/responses/refresh_token_response.rs | 67 +++++++++++++++++++ azure_sdk_auth_aad/src/traits.rs | 16 +++++ 8 files changed, 113 insertions(+), 16 deletions(-) create mode 100644 azure_sdk_auth_aad/src/prelude.rs create mode 100644 azure_sdk_auth_aad/src/responses/mod.rs create mode 100644 azure_sdk_auth_aad/src/responses/refresh_token_response.rs create mode 100644 azure_sdk_auth_aad/src/traits.rs diff --git a/azure_sdk_auth_aad/examples/device_code_flow.rs b/azure_sdk_auth_aad/examples/device_code_flow.rs index 8b3669fa7..d19943e73 100644 --- a/azure_sdk_auth_aad/examples/device_code_flow.rs +++ b/azure_sdk_auth_aad/examples/device_code_flow.rs @@ -5,7 +5,6 @@ use futures::stream::StreamExt; use oauth2::ClientId; use std::env; use std::error::Error; -use std::sync::Arc; #[tokio::main] async fn main() -> Result<(), Box> { @@ -31,7 +30,7 @@ async fn main() -> Result<(), Box> { &client_id, &[ &format!( - "https://{}.blob.core.windows.net/.default", + "https://{}.blob.core.windows.net/user_impersonation", storage_account_name ), "offline_access", @@ -49,7 +48,7 @@ async fn main() -> Result<(), Box> { // return, besides errors, a success meaning either // Success or Pending. The loop will continue until we // get either a Success or an error. - let mut stream = Box::pin(device_code_flow.stream(&client)); + let mut stream = Box::pin(device_code_flow.stream()); let mut authorization = None; while let Some(resp) = stream.next().await { println!("{:?}", resp); diff --git a/azure_sdk_auth_aad/src/device_code_flow.rs b/azure_sdk_auth_aad/src/device_code_flow.rs index 30513285f..123ae9e5c 100644 --- a/azure_sdk_auth_aad/src/device_code_flow.rs +++ b/azure_sdk_auth_aad/src/device_code_flow.rs @@ -4,8 +4,8 @@ use azure_sdk_core::errors::AzureError; use futures::stream::unfold; use log::debug; pub use oauth2::{ClientId, ClientSecret}; +use std::borrow::Cow; use std::convert::TryInto; -use std::sync::Arc; use std::time::Duration; use url::form_urlencoded; @@ -21,7 +21,9 @@ pub struct DeviceCodePhaseOneResponse<'a> { // from the Azure answer. They will be added // manually after deserialization #[serde(skip)] - tenant_id: &'a str, + client: Option<&'a reqwest::Client>, + #[serde(skip)] + tenant_id: Cow<'a, str>, // we store the ClientId as string instead of // the original type because it does not // implement Default and it's in another @@ -30,17 +32,22 @@ pub struct DeviceCodePhaseOneResponse<'a> { client_id: String, } -pub async fn begin_authorize_device_code_flow<'a, 'b>( +pub async fn begin_authorize_device_code_flow<'a, 'b, T>( client: &'a reqwest::Client, - tenant_id: &'a str, + tenant_id: T, client_id: &'a ClientId, scopes: &'b [&'b str], -) -> Result, AzureError> { +) -> Result, AzureError> +where + T: Into>, +{ let mut encoded = form_urlencoded::Serializer::new(String::new()); let encoded = encoded.append_pair("client_id", client_id.as_str()); let encoded = encoded.append_pair("scope", &scopes.join(" ")); let encoded = encoded.finish(); + let tenant_id = tenant_id.into(); + debug!("encoded ==> {}", encoded); let url = url::Url::parse(&format!( @@ -69,6 +76,7 @@ pub async fn begin_authorize_device_code_flow<'a, 'b>( expires_in: device_code_reponse.expires_in, interval: device_code_reponse.interval, message: device_code_reponse.message, + client: Some(client), tenant_id, client_id: client_id.as_str().to_string(), }) @@ -92,7 +100,6 @@ impl<'a> DeviceCodePhaseOneResponse<'a> { pub fn stream<'b>( &'b self, - client: &'b reqwest::Client, ) -> impl futures::Stream> + 'b + '_ { #[derive(Debug, Clone, PartialEq)] enum NextState { @@ -123,7 +130,9 @@ impl<'a> DeviceCodePhaseOneResponse<'a> { let encoded = encoded.append_pair("device_code", &self.device_code); let encoded = encoded.finish(); - let result = match client + let result = match self + .client + .unwrap() .post(&uri) .header("ContentType", "application/x-www-form-urlencoded") .body(encoded) diff --git a/azure_sdk_auth_aad/src/lib.rs b/azure_sdk_auth_aad/src/lib.rs index baf7d37b7..213859bb2 100644 --- a/azure_sdk_auth_aad/src/lib.rs +++ b/azure_sdk_auth_aad/src/lib.rs @@ -23,10 +23,13 @@ pub mod errors; mod refresh_token; pub use refresh_token::*; mod naive_server; +mod traits; pub use crate::device_code_flow::*; pub use crate::device_code_responses::*; use futures::TryFutureExt; +mod responses; pub use naive_server::naive_server; +mod prelude; #[derive(Debug)] pub struct AuthObj { diff --git a/azure_sdk_auth_aad/src/prelude.rs b/azure_sdk_auth_aad/src/prelude.rs new file mode 100644 index 000000000..6a613054d --- /dev/null +++ b/azure_sdk_auth_aad/src/prelude.rs @@ -0,0 +1 @@ +pub use crate::traits::*; diff --git a/azure_sdk_auth_aad/src/refresh_token.rs b/azure_sdk_auth_aad/src/refresh_token.rs index ef8701285..23bff6f92 100644 --- a/azure_sdk_auth_aad/src/refresh_token.rs +++ b/azure_sdk_auth_aad/src/refresh_token.rs @@ -1,7 +1,8 @@ +use crate::responses::RefreshTokenResponse; use azure_sdk_core::errors::AzureError; use log::debug; use oauth2::{AccessToken, ClientId, ClientSecret}; -use std::sync::Arc; +use std::convert::TryInto; use url::form_urlencoded; pub async fn exchange_refresh_token( @@ -10,7 +11,7 @@ pub async fn exchange_refresh_token( client_id: &ClientId, client_secret: Option<&ClientSecret>, refresh_token: &AccessToken, -) -> Result<(), AzureError> { +) -> Result { let mut encoded = form_urlencoded::Serializer::new(String::new()); let encoded = encoded.append_pair("grant_type", "refresh_token"); let encoded = encoded.append_pair("client_id", client_id.as_str()); @@ -23,7 +24,7 @@ pub async fn exchange_refresh_token( let encoded = encoded.append_pair("refresh_token", refresh_token.secret()); let encoded = encoded.finish(); - println!("encoded ==> {}", encoded); + debug!("encoded ==> {}", encoded); let url = url::Url::parse(&format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", @@ -40,8 +41,7 @@ pub async fn exchange_refresh_token( .text() .await .map_err(|e| AzureError::GenericErrorWithText(e.to_string()))?; + debug!("{}", ret); - println!("{}", ret); - - Ok(()) + Ok(ret.try_into()?) } diff --git a/azure_sdk_auth_aad/src/responses/mod.rs b/azure_sdk_auth_aad/src/responses/mod.rs new file mode 100644 index 000000000..69f0ae0c3 --- /dev/null +++ b/azure_sdk_auth_aad/src/responses/mod.rs @@ -0,0 +1,2 @@ +mod refresh_token_response; +pub use refresh_token_response::RefreshTokenResponse; diff --git a/azure_sdk_auth_aad/src/responses/refresh_token_response.rs b/azure_sdk_auth_aad/src/responses/refresh_token_response.rs new file mode 100644 index 000000000..6d8c3e0d6 --- /dev/null +++ b/azure_sdk_auth_aad/src/responses/refresh_token_response.rs @@ -0,0 +1,67 @@ +use crate::prelude::*; +use oauth2::AccessToken; +use std::convert::TryInto; + +#[derive(Debug, Clone)] +pub struct RefreshTokenResponse { + token_type: String, + scopes: Vec, + expires_in: u64, + ext_expires_in: u64, + access_token: AccessToken, + refresh_token: AccessToken, +} + +impl TryInto for String { + type Error = serde_json::Error; + + fn try_into(self) -> Result { + // we use a temp struct to deserialize the scope into + // the scopes vec at later time + #[derive(Debug, Clone, Deserialize)] + pub struct _RefreshTokenResponse<'a> { + token_type: String, + scope: &'a str, + expires_in: u64, + ext_expires_in: u64, + access_token: AccessToken, + refresh_token: AccessToken, + } + + serde_json::from_str::<_RefreshTokenResponse>(&self).map(|rtr| RefreshTokenResponse { + token_type: rtr.token_type, + scopes: rtr.scope.split(' ').map(|s| s.to_owned()).collect(), + expires_in: rtr.expires_in, + ext_expires_in: rtr.ext_expires_in, + access_token: rtr.access_token, + refresh_token: rtr.refresh_token, + }) + } +} + +impl BearerToken for RefreshTokenResponse { + fn token_type(&self) -> &str { + &self.token_type + } + fn scopes(&self) -> &[String] { + &self.scopes + } + fn expires_in(&self) -> u64 { + self.expires_in + } + fn access_token(&self) -> &AccessToken { + &self.access_token + } +} + +impl RefreshToken for RefreshTokenResponse { + fn refresh_token(&self) -> &AccessToken { + &self.refresh_token + } +} + +impl ExtExpiresIn for RefreshTokenResponse { + fn ext_expires_in(&self) -> u64 { + self.ext_expires_in + } +} diff --git a/azure_sdk_auth_aad/src/traits.rs b/azure_sdk_auth_aad/src/traits.rs new file mode 100644 index 000000000..876cc0b7e --- /dev/null +++ b/azure_sdk_auth_aad/src/traits.rs @@ -0,0 +1,16 @@ +use oauth2::AccessToken; + +pub trait BearerToken { + fn token_type(&self) -> &str; + fn scopes(&self) -> &[String]; + fn expires_in(&self) -> u64; + fn access_token(&self) -> &AccessToken; +} + +pub trait RefreshToken { + fn refresh_token(&self) -> &AccessToken; +} + +pub trait ExtExpiresIn { + fn ext_expires_in(&self) -> u64; +}