-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#40 add oauth2-backend-cats module with SttpOauth2ClientCredentialsCa…
…tsBackend
- Loading branch information
Bartłomiej Wierciński
committed
Mar 9, 2021
1 parent
9c6a66d
commit c263a0c
Showing
5 changed files
with
311 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
21 changes: 21 additions & 0 deletions
21
oauth2-backend-cats/src/main/scala/com/ocadotechnology/sttp/oauth2/backend/Cache.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package com.ocadotechnology.sttp.oauth2.backend | ||
|
||
import cats.effect.Sync | ||
import cats.implicits._ | ||
import cats.effect.concurrent.Ref | ||
|
||
trait Cache[F[_], A] { | ||
def get: F[Option[A]] | ||
def set(a: A): F[Unit] | ||
} | ||
|
||
object Cache { | ||
|
||
def refCache[F[_]: Sync, A]: F[Cache[F, A]] = Ref[F].of(Option.empty[A]).map { ref => | ||
new Cache[F, A] { | ||
override def get: F[Option[A]] = ref.get | ||
override def set(a: A): F[Unit] = ref.set(Some(a)) | ||
} | ||
} | ||
|
||
} |
112 changes: 112 additions & 0 deletions
112
...cala/com/ocadotechnology/sttp/oauth2/backend/SttpOauth2ClientCredentialsCatsBackend.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
package com.ocadotechnology.sttp.oauth2.backend | ||
|
||
import cats.Monad | ||
import cats.data.OptionT | ||
import cats.effect.Clock | ||
import cats.effect.Concurrent | ||
import cats.effect.concurrent.Semaphore | ||
import cats.implicits._ | ||
import com.ocadotechnology.sttp.oauth2.ClientCredentialsProvider | ||
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse | ||
import com.ocadotechnology.sttp.oauth2.Secret | ||
import com.ocadotechnology.sttp.oauth2.backend.SttpOauth2ClientCredentialsCatsBackend.TokenWithExpiryInstant | ||
import com.ocadotechnology.sttp.oauth2.common.Scope | ||
import eu.timepit.refined.types.string.NonEmptyString | ||
import sttp.capabilities.Effect | ||
import sttp.client3.DelegateSttpBackend | ||
import sttp.client3.Request | ||
import sttp.client3.Response | ||
import sttp.client3.SttpBackend | ||
import sttp.model.Uri | ||
|
||
import java.time.Instant | ||
import scala.concurrent.duration.MILLISECONDS | ||
|
||
class SttpOauth2ClientCredentialsCatsBackend[F[_]: Monad: Clock, P] private ( | ||
delegate: SttpBackend[F, P], | ||
clientCredentialsProvider: ClientCredentialsProvider[F], | ||
cache: Cache[F, TokenWithExpiryInstant], | ||
semaphore: Semaphore[F], | ||
val scope: Scope | ||
) extends DelegateSttpBackend(delegate) { | ||
|
||
override def send[T, R >: P with Effect[F]](request: Request[T, R]): F[Response[T]] = for { | ||
token <- semaphore.withPermit(resolveToken) | ||
response <- delegate.send(request.auth.bearer(token.value)) | ||
} yield response | ||
|
||
private val resolveToken: F[Secret[String]] = | ||
OptionT(cache.get) | ||
.product(OptionT.liftF(getCurrentInstant)) | ||
.filter { case (TokenWithExpiryInstant(_, expiryInstant), currentInstant) => currentInstant isBefore expiryInstant } | ||
.map(_._1) | ||
.getOrElseF(requestAndSaveToken) | ||
.map(_.token) | ||
|
||
private def requestAndSaveToken: F[TokenWithExpiryInstant] = | ||
clientCredentialsProvider.requestToken(scope).flatMap(calculateExpiryTime).flatTap(cache.set) | ||
|
||
private def calculateExpiryTime(response: AccessTokenResponse): F[TokenWithExpiryInstant] = | ||
getCurrentInstant.map(_ plusMillis response.expiresIn.toMillis).map(TokenWithExpiryInstant(response.accessToken, _)) | ||
|
||
private def getCurrentInstant: F[Instant] = Clock[F].realTime(MILLISECONDS).map(Instant.ofEpochMilli) | ||
|
||
} | ||
|
||
object SttpOauth2ClientCredentialsCatsBackend { | ||
final case class TokenWithExpiryInstant(token: Secret[String], expiryInstant: Instant) | ||
|
||
def apply[F[_]: Concurrent: Clock, P]( | ||
tokenUrl: Uri, | ||
tokenIntrospectionUrl: Uri, | ||
clientId: NonEmptyString, | ||
clientSecret: Secret[String] | ||
)( | ||
scope: Scope | ||
)( | ||
implicit backend: SttpBackend[F, P] | ||
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] = { | ||
val clientCredentialsProvider = ClientCredentialsProvider.instance(tokenUrl, tokenIntrospectionUrl, clientId, clientSecret) | ||
usingClientCredentialsProvider(clientCredentialsProvider)(scope) | ||
} | ||
|
||
/** Keep in mind that the given implicit `backend` may be different than this one used in `clientCredentialsProvider` | ||
*/ | ||
def usingClientCredentialsProvider[F[_]: Concurrent: Clock, P]( | ||
clientCredentialsProvider: ClientCredentialsProvider[F] | ||
)( | ||
scope: Scope | ||
)( | ||
implicit backend: SttpBackend[F, P] | ||
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] = | ||
Cache.refCache[F, TokenWithExpiryInstant].flatMap(usingClientCredentialsProviderAndCache(clientCredentialsProvider, _)(scope)) | ||
|
||
def usingCache[F[_]: Concurrent: Clock, P]( | ||
cache: Cache[F, TokenWithExpiryInstant] | ||
)( | ||
tokenUrl: Uri, | ||
tokenIntrospectionUrl: Uri, | ||
clientId: NonEmptyString, | ||
clientSecret: Secret[String] | ||
)( | ||
scope: Scope | ||
)( | ||
implicit backend: SttpBackend[F, P] | ||
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] = { | ||
val clientCredentialsProvider = ClientCredentialsProvider.instance(tokenUrl, tokenIntrospectionUrl, clientId, clientSecret) | ||
usingClientCredentialsProviderAndCache(clientCredentialsProvider, cache)(scope) | ||
} | ||
|
||
/** Keep in mind that the given implicit `backend` may be different than this one used in `clientCredentialsProvider` | ||
*/ | ||
def usingClientCredentialsProviderAndCache[F[_]: Concurrent: Clock, P]( | ||
clientCredentialsProvider: ClientCredentialsProvider[F], | ||
cache: Cache[F, TokenWithExpiryInstant] | ||
)( | ||
scope: Scope | ||
)( | ||
implicit backend: SttpBackend[F, P] | ||
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] = | ||
Semaphore(n = 1).map(new SttpOauth2ClientCredentialsCatsBackend(backend, clientCredentialsProvider, cache, _, scope)) | ||
|
||
} |
162 changes: 162 additions & 0 deletions
162
.../com/ocadotechnology/sttp/oauth2/backend/SttpOauth2ClientCredentialsCatsBackendSpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
package com.ocadotechnology.sttp.oauth2.backend | ||
|
||
import cats.effect.ContextShift | ||
import cats.effect.IO | ||
import cats.effect.Timer | ||
import cats.implicits._ | ||
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse | ||
import com.ocadotechnology.sttp.oauth2.Secret | ||
import com.ocadotechnology.sttp.oauth2.common.Scope | ||
import eu.timepit.refined.collection.NonEmpty | ||
import eu.timepit.refined.refineMV | ||
import eu.timepit.refined.types.string.NonEmptyString | ||
import org.scalatest.matchers.should.Matchers | ||
import org.scalatest.wordspec.AsyncWordSpec | ||
import sttp.client3._ | ||
import sttp.client3.asynchttpclient.cats.AsyncHttpClientCatsBackend | ||
import sttp.client3.testing.SttpBackendStub | ||
import sttp.client3.testing._ | ||
import sttp.model.HeaderNames.Authorization | ||
import sttp.model._ | ||
|
||
import scala.concurrent.ExecutionContext | ||
import scala.concurrent.duration._ | ||
|
||
class SttpOauth2ClientCredentialsCatsBackendSpec extends AsyncWordSpec with Matchers { | ||
implicit override val executionContext: ExecutionContext = ExecutionContext.global | ||
implicit val contextShift: ContextShift[IO] = IO.contextShift(executionContext) | ||
implicit val timer: Timer[IO] = IO.timer(executionContext) | ||
|
||
"SttpOauth2ClientCredentialsBackend" when { | ||
val tokenUrl: Uri = uri"https://authserver.org/oauth2/token" | ||
val clientId: NonEmptyString = refineMV[NonEmpty]("clientid") | ||
val clientSecret: Secret[String] = Secret("secret") | ||
val scope: Scope = refineMV("scope") | ||
|
||
val testAppUrl: Uri = uri"https://testapp.org/test" | ||
|
||
"TestApp is invoked once" should { | ||
"request a token. add the token to the TestApp request" in { | ||
val accessToken: Secret[String] = Secret("token") | ||
implicit val mockBackend: SttpBackendStub[IO, Any] = AsyncHttpClientCatsBackend | ||
.stub[IO] | ||
.whenTokenIsRequested() | ||
.thenRespond(Right(AccessTokenResponse(accessToken, "domain", 100.seconds, scope))) | ||
.whenTestAppIsRequestedWithToken(accessToken) | ||
.thenRespondOk() | ||
|
||
for { | ||
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope) | ||
response <- backend.send(basicRequest.get(testAppUrl).response(asStringAlways)) | ||
} yield response.code shouldBe StatusCode.Ok | ||
}.unsafeToFuture() | ||
} | ||
|
||
"TestApp is invoked twice sequentially" should { | ||
"first invocation is requesting a token, second invocation is getting the token from the cache. add the token to the both TestApp requests" in { | ||
val accessToken: Secret[String] = Secret("token") | ||
implicit val recordingMockBackend: RecordingSttpBackend[IO, Any] = new RecordingSttpBackend( | ||
AsyncHttpClientCatsBackend | ||
.stub[IO] | ||
.whenTokenIsRequested() | ||
.thenRespond(Right(AccessTokenResponse(accessToken, "domain", 100.seconds, scope))) | ||
.whenTestAppIsRequestedWithToken(accessToken) | ||
.thenRespondOk() | ||
) | ||
|
||
for { | ||
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope) | ||
invokeTestApp = backend.send(basicRequest.get(testAppUrl).response(asStringAlways)) | ||
response1 <- invokeTestApp | ||
response2 <- invokeTestApp | ||
} yield { | ||
response1.code shouldBe StatusCode.Ok | ||
response2.code shouldBe StatusCode.Ok | ||
recordingMockBackend.invocationCountByUri shouldBe Map(tokenUrl -> 1, testAppUrl -> 2) | ||
} | ||
}.unsafeToFuture() | ||
} | ||
|
||
"TestApp is invoked twice in parallel" should { | ||
"first invocation is requesting a token, second invocation is waiting for token response and getting the token from the cache. add the token to the both TestApp requests" in { | ||
val accessToken: Secret[String] = Secret("token") | ||
implicit val recordingMockBackend: RecordingSttpBackend[IO, Any] = new RecordingSttpBackend( | ||
AsyncHttpClientCatsBackend | ||
.stub[IO] | ||
.whenTokenIsRequested() | ||
.thenRespondF(IO.sleep(200.millis).as(Response.ok(Right(AccessTokenResponse(accessToken, "domain", 100.seconds, scope))))) | ||
.whenTestAppIsRequestedWithToken(accessToken) | ||
.thenRespondOk() | ||
) | ||
|
||
for { | ||
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope) | ||
invokeTestApp = backend.send(basicRequest.get(testAppUrl).response(asStringAlways)) | ||
(response1, response2) <- (invokeTestApp, invokeTestApp).parTupled | ||
} yield { | ||
response1.code shouldBe StatusCode.Ok | ||
response2.code shouldBe StatusCode.Ok | ||
recordingMockBackend.invocationCountByUri shouldBe Map(tokenUrl -> 1, testAppUrl -> 2) | ||
} | ||
}.unsafeToFuture() | ||
} | ||
|
||
"TestApp is invoked after token expires" should { | ||
"first invocation is requesting a token, second invocation is requesting a token, because the previous token is expired. add the token to the both TestApp requests" in { | ||
val accessToken1: Secret[String] = Secret("token1") | ||
val accessToken2: Secret[String] = Secret("token2") | ||
implicit val recordingMockBackend: RecordingSttpBackend[IO, Any] = new RecordingSttpBackend( | ||
AsyncHttpClientCatsBackend | ||
.stub[IO] | ||
.whenTokenIsRequested() | ||
.thenRespondCyclic( | ||
Right(AccessTokenResponse(accessToken1, "domain", 100.millis, scope)), | ||
Right(AccessTokenResponse(accessToken2, "domain", 100.millis, scope)) | ||
) | ||
.whenTestAppIsRequestedWithToken(accessToken1) | ||
.thenRespond("body1") | ||
.whenTestAppIsRequestedWithToken(accessToken2) | ||
.thenRespond("body2") | ||
) | ||
|
||
for { | ||
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope) | ||
invokeTestApp = backend.send(basicRequest.get(testAppUrl).response(asStringAlways)) | ||
response1 <- invokeTestApp | ||
_ <- IO.sleep(200.millis) | ||
response2 <- invokeTestApp | ||
} yield { | ||
response1.code shouldBe StatusCode.Ok | ||
response1.body shouldBe "body1" | ||
response2.code shouldBe StatusCode.Ok | ||
response2.body shouldBe "body2" | ||
recordingMockBackend.invocationCountByUri shouldBe Map(tokenUrl -> 2, testAppUrl -> 2) | ||
} | ||
}.unsafeToFuture() | ||
} | ||
|
||
implicit class SttpBackendStubOps[F[_], P](val backend: SttpBackendStub[F, P]) { | ||
import backend.WhenRequest | ||
|
||
def whenTokenIsRequested(): WhenRequest = backend.whenRequestMatches { request => | ||
request.method == Method.POST && | ||
request.uri == tokenUrl && | ||
request.forceBodyAsString == "grant_type=client_credentials&" + | ||
s"client_id=${clientId.value}&" + | ||
s"client_secret=${clientSecret.value}&" + | ||
s"scope=${scope.value}" | ||
} | ||
|
||
def whenTestAppIsRequestedWithToken(accessToken: Secret[String]): WhenRequest = backend.whenRequestMatches { request => | ||
request.method == Method.GET && | ||
request.uri == testAppUrl && | ||
request.headers.contains(Header(Authorization, s"Bearer ${accessToken.value}")) | ||
} | ||
} | ||
|
||
implicit class RecordingSttpBackendOps[F[_], P](backend: RecordingSttpBackend[F, P]) { | ||
def invocationCountByUri: Map[Uri, Int] = backend.allInteractions.groupBy(_._1.uri).fmap(_.size) | ||
} | ||
} | ||
|
||
} |