Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom token response in AuthorizationCodeGrant #104

Merged
merged 8 commits into from
Jun 19, 2021
Merged
10 changes: 8 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,14 @@ val testDependencies = Seq(
"io.circe" %% "circe-literal" % Versions.circe
).map(_ % Test)

val mimaSettings =
mimaPreviousArtifacts := previousStableVersion.value.map(organization.value %% moduleName.value % _).toSet
val mimaSettings =
mimaPreviousArtifacts := {
val onlyPatchChanged = previousStableVersion.value.flatMap(CrossVersion.partialVersion) == CrossVersion.partialVersion(version.value)
if(onlyPatchChanged)
previousStableVersion.value.map(organization.value %% moduleName.value % _).toSet
else
Set.empty
}

lazy val oauth2 = project.settings(
name := "sttp-oauth2",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.ocadotechnology.sttp.oauth2

import cats.syntax.all._
import cats.implicits._
import com.ocadotechnology.sttp.oauth2.common._
import io.circe.parser.decode
import sttp.client3._
Expand All @@ -10,6 +10,7 @@ import sttp.monad.syntax._

import AuthorizationCodeProvider.Config
import sttp.model.HeaderNames
import io.circe.Decoder

object AuthorizationCode {

Expand All @@ -35,16 +36,16 @@ object AuthorizationCode {
.addParam("client_id", clientId)
.addParam("redirect_uri", redirectUri)

private def convertAuthCodeToUser[F[_], UriType](
private def convertAuthCodeToUser[F[_], UriType, RT <: OAuth2TokenResponse.Basic: Decoder](
tokenUri: Uri,
authCode: String,
redirectUri: String,
clientId: String,
clientSecret: Secret[String]
)(
implicit backend: SttpBackend[F, Any]
): F[Oauth2TokenResponse] = {
implicit val F: MonadError[F] = backend.responseMonad
): F[RT] = {
implicit val ME: MonadError[F] = backend.responseMonad
backend
.send {
basicRequest
Expand All @@ -53,8 +54,15 @@ object AuthorizationCode {
.response(asString)
.header(HeaderNames.Accept, "application/json")
}
.map(_.body.leftMap(new RuntimeException(_)).flatMap(decode[Oauth2TokenResponse]).toTry)
.flatMap(backend.responseMonad.fromTry)
.flatMap{ response =>
ME.fromTry(
response
.body
.leftMap(new RuntimeException(_))
.flatMap(decode[RT])
.toTry
)
}
}

private def tokenRequestParams(authCode: String, redirectUri: String, clientId: String, clientSecret: String) =
Expand All @@ -66,15 +74,15 @@ object AuthorizationCode {
"code" -> authCode
)

private def performTokenRefresh[F[_], UriType](
private def performTokenRefresh[F[_], UriType, RT <: OAuth2TokenResponse.Basic: Decoder](
tokenUri: Uri,
refreshToken: String,
clientId: String,
clientSecret: Secret[String],
scopeOverride: ScopeSelection
)(
implicit backend: SttpBackend[F, Any]
): F[Oauth2TokenResponse] = {
): F[RT] = {
implicit val F: MonadError[F] = backend.responseMonad
backend
.send {
Expand All @@ -83,8 +91,7 @@ object AuthorizationCode {
.body(refreshTokenRequestParams(refreshToken, clientId, clientSecret.value, scopeOverride.toRequestMap))
.response(asString)
}
.map(_.body.leftMap(new RuntimeException(_)).flatMap(decode[RefreshTokenResponse]).toTry)
.map(_.map(_.toOauth2Token(refreshToken)))
.map(_.body.leftMap(new RuntimeException(_)).flatMap(decode[RT]).toTry)
.flatMap(backend.responseMonad.fromTry)
}

Expand All @@ -106,16 +113,16 @@ object AuthorizationCode {
): Uri =
prepareLoginLink(baseUrl, clientId, redirectUri.toString, state.getOrElse(""), scopes, path.values)

def authCodeToToken[F[_]](
def authCodeToToken[F[_], RT <: OAuth2TokenResponse.Basic: Decoder](
tokenUri: Uri,
redirectUri: Uri,
clientId: String,
clientSecret: Secret[String],
authCode: String
)(
implicit backend: SttpBackend[F, Any]
): F[Oauth2TokenResponse] =
convertAuthCodeToUser(tokenUri, authCode, redirectUri.toString, clientId, clientSecret)
): F[RT] =
convertAuthCodeToUser[F, Uri, RT](tokenUri, authCode, redirectUri.toString, clientId, clientSecret)

def logoutLink[F[_]](
baseUrl: Uri,
Expand All @@ -126,15 +133,15 @@ object AuthorizationCode {
): Uri =
prepareLogoutLink(baseUrl, clientId, postLogoutRedirect.getOrElse(redirectUri).toString(), path.values)

def refreshAccessToken[F[_]](
def refreshAccessToken[F[_], RT <: OAuth2TokenResponse.Basic: Decoder](
tokenUri: Uri,
clientId: String,
clientSecret: Secret[String],
refreshToken: String,
scopeOverride: ScopeSelection = ScopeSelection.KeepExisting
)(
implicit backend: SttpBackend[F, Any]
): F[Oauth2TokenResponse] =
performTokenRefresh(tokenUri, refreshToken, clientId, clientSecret, scopeOverride)
): F[RT] =
performTokenRefresh[F, Uri, RT](tokenUri, refreshToken, clientId, clientSecret, scopeOverride)

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import eu.timepit.refined.refineV
import eu.timepit.refined.string.Url
import sttp.client3._
import sttp.model.Uri
import io.circe.Decoder

/** Provides set of functions to simplify oauth2 identity provider integration.
* Use the `instance` companion object method to create instances.
Expand Down Expand Up @@ -36,21 +37,25 @@ trait AuthorizationCodeProvider[UriType, F[_]] {

/** Returns token details wrapped in effect
*
* @tparam TokenType type that models token response. It must implement MinimalStructurem, and have io.circe.Decoder instance.
* Predefined implementations: OAuth2TokenResponse and ExtendedOAuth2TokenResponse
* @param authCode code provided by oauth2 provider redirect,
* after user is authenticated correctly
* @return Oauth2TokenResponse details containing user info and additional information
* @return TokenType details containing user info and additional information
*/
def authCodeToToken(authCode: String): F[Oauth2TokenResponse]
def authCodeToToken[TokenType <: OAuth2TokenResponse.Basic: Decoder](authCode: String): F[TokenType]

/** Performs the token refresh on oauth2 provider nad returns new token details wrapped in effect
*
* @tparam TokenType type that models token response. It must implement MinimalStructurem, and have io.circe.Decoder instance.
* Predefined implementations: OAuth2TokenResponse and ExtendedOAuth2TokenResponse
* @param refreshToken value from refresh_token field of previous access token
* @param scope optional parameter for overriding token scope, useful to narrow down the scope
* when not provided or ScopeSelection.KeepExisting passed,
* the new token will be issued for the same scope as the previous one
* @return Oauth2TokenResponse details containing user info and additional information
* @return TokenType details containing user info and additional information
*/
def refreshAccessToken(refreshToken: String, scope: ScopeSelection = ScopeSelection.KeepExisting): F[Oauth2TokenResponse]
def refreshAccessToken[TokenType <: OAuth2TokenResponse.Basic: Decoder](refreshToken: String, scope: ScopeSelection = ScopeSelection.KeepExisting): F[TokenType]
}

object AuthorizationCodeProvider {
Expand Down Expand Up @@ -81,6 +86,12 @@ object AuthorizationCodeProvider {
tokenPath = Path(List(Segment("oauth2"), Segment("token")))
)

val GitHub = Config(
loginPath = Path(List(Segment("login"), Segment("oauth"), Segment("authorize"))),
logoutPath = Path(List(Segment("logout"))),
tokenPath = Path(List(Segment("login"), Segment("oauth"), Segment("access_token")))
)

// Other predefined configurations for well-known oauth2 providers could be placed here
}

Expand All @@ -106,9 +117,9 @@ object AuthorizationCodeProvider {
.toString
)

override def authCodeToToken(authCode: String): F[Oauth2TokenResponse] =
override def authCodeToToken[TT <: OAuth2TokenResponse.Basic: Decoder](authCode: String): F[TT] =
AuthorizationCode
.authCodeToToken(tokenUri, redirectUri, clientId, clientSecret, authCode)
.authCodeToToken[F, TT](tokenUri, redirectUri, clientId, clientSecret, authCode)

override def logoutLink(postLogoutRedirect: Option[Refined[String, Url]]): Refined[String, Url] =
refineV[Url].unsafeFrom[String](
Expand All @@ -117,10 +128,10 @@ object AuthorizationCodeProvider {
.toString
)

override def refreshAccessToken(
override def refreshAccessToken[TT <: OAuth2TokenResponse.Basic: Decoder](
refreshToken: String,
scopeOverride: ScopeSelection = ScopeSelection.KeepExisting
): F[Oauth2TokenResponse] =
): F[TT] =
AuthorizationCode
.refreshAccessToken(tokenUri, clientId, clientSecret, refreshToken, scopeOverride)

Expand All @@ -142,18 +153,18 @@ object AuthorizationCodeProvider {
AuthorizationCode
.loginLink(baseUrl, redirectUri, clientId, state, scope, pathsConfig.loginPath)

override def authCodeToToken(authCode: String): F[Oauth2TokenResponse] =
override def authCodeToToken[TT <: OAuth2TokenResponse.Basic: Decoder](authCode: String): F[TT] =
AuthorizationCode
.authCodeToToken(tokenUri, redirectUri, clientId, clientSecret, authCode)

override def logoutLink(postLogoutRedirect: Option[Uri]): Uri =
AuthorizationCode
.logoutLink(baseUrl, redirectUri, clientId, postLogoutRedirect, pathsConfig.logoutPath)

override def refreshAccessToken(
override def refreshAccessToken[TT <: OAuth2TokenResponse.Basic: Decoder](
refreshToken: String,
scopeOverride: ScopeSelection = ScopeSelection.KeepExisting
): F[Oauth2TokenResponse] =
): F[TT] =
AuthorizationCode
.refreshAccessToken(tokenUri, clientId, clientSecret, refreshToken, scopeOverride)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2Error

object OAuth2Token {

type Response = Either[Error, Oauth2TokenResponse]
// TODO: should be changed to Response[A] and allow custom responses, like in AuthorizationCodeGrant
type Response = Either[Error, ExtendedOAuth2TokenResponse]

private implicit val bearerTokenResponseDecoder: Decoder[Either[OAuth2Error, Oauth2TokenResponse]] =
circe.eitherOrFirstError[Oauth2TokenResponse, OAuth2Error](
Decoder[Oauth2TokenResponse],
private implicit val bearerTokenResponseDecoder: Decoder[Either[OAuth2Error, ExtendedOAuth2TokenResponse]] =
circe.eitherOrFirstError[ExtendedOAuth2TokenResponse, OAuth2Error](
Decoder[ExtendedOAuth2TokenResponse],
Decoder[OAuth2Error]
)

val response: ResponseAs[Response, Any] =
common.responseWithCommonError[Oauth2TokenResponse]
common.responseWithCommonError[ExtendedOAuth2TokenResponse]

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,61 @@ import io.circe.Decoder

import scala.concurrent.duration.FiniteDuration

case class Oauth2TokenResponse(
case class OAuth2TokenResponse(
accessToken: Secret[String],
scope: String,
tokenType: String,
expiresIn: Option[FiniteDuration],
refreshToken: Option[String]
) extends OAuth2TokenResponse.Basic

object OAuth2TokenResponse {
import com.ocadotechnology.sttp.oauth2.circe._

/** Miminal structure as required by RFC https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
* Token response is described in https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 as follows:
* access_token
* REQUIRED. The access token issued by the authorization server.
*
*token_type
* REQUIRED. The type of the token issued as described in
* Section 7.1. Value is case insensitive.
*
*expires_in
* RECOMMENDED. The lifetime in seconds of the access token. For
* example, the value "3600" denotes that the access token will
* expire in one hour from the time the response was generated.
* If omitted, the authorization server SHOULD provide the
* expiration time via other means or document the default value.
*
*refresh_token
* OPTIONAL. The refresh token, which can be used to obtain new
* access tokens using the same authorization grant as described
* in Section 6.
*
*scope
* OPTIONAL, if identical to the scope requested by the client;
* otherwise, REQUIRED. The scope of the access token as
* described by Section 3.3.
*/
trait Basic {
def accessToken: Secret[String]
def tokenType: String
}

implicit val decoder: Decoder[OAuth2TokenResponse] =
Decoder.forProduct5(
"access_token",
"scope",
"token_type",
"expires_in",
"refresh_token"
)(OAuth2TokenResponse.apply)

}

// @deprecated("This model will be removed in next release", "0.10.0")
case class ExtendedOAuth2TokenResponse(
accessToken: Secret[String],
refreshToken: String,
expiresIn: FiniteDuration,
Expand All @@ -16,12 +70,12 @@ case class Oauth2TokenResponse(
securityLevel: Long,
userId: String,
tokenType: String
)
) extends OAuth2TokenResponse.Basic

object Oauth2TokenResponse {
object ExtendedOAuth2TokenResponse {
import com.ocadotechnology.sttp.oauth2.circe._

implicit val decoder: Decoder[Oauth2TokenResponse] =
implicit val decoder: Decoder[ExtendedOAuth2TokenResponse] =
Decoder.forProduct11(
"access_token",
"refresh_token",
Expand All @@ -34,6 +88,6 @@ object Oauth2TokenResponse {
"security_level",
"user_id",
"token_type"
)(Oauth2TokenResponse.apply)
)(ExtendedOAuth2TokenResponse.apply)

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import sttp.model.Uri
import cats.syntax.all._

trait PasswordGrantProvider[F[_]] {
def requestToken(user: User, scope: Scope): F[Oauth2TokenResponse]
def requestToken(user: User, scope: Scope): F[ExtendedOAuth2TokenResponse]
}

object PasswordGrantProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ private[oauth2] final case class RefreshTokenResponse(
) {

def toOauth2Token(oldRefreshToken: String) =
Oauth2TokenResponse(
ExtendedOAuth2TokenResponse(
accessToken,
refreshToken.getOrElse(oldRefreshToken),
expiresIn,
Expand Down
Loading