Skip to content

Commit

Permalink
#40 add oauth2-backend-cats module with SttpOauth2ClientCredentialsCa…
Browse files Browse the repository at this point in the history
…tsBackend
  • Loading branch information
Bartłomiej Wierciński committed Mar 9, 2021
1 parent 9c6a66d commit c263a0c
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
- run: sbt ++${{ matrix.scala }} test mimaReportBinaryIssues

- name: Compress target directories
run: tar cf targets.tar target oauth2/target project/target
run: tar cf targets.tar target oauth2/target oauth2-backend-cats/target project/target

- name: Upload target directories
uses: actions/upload-artifact@v2
Expand Down
17 changes: 15 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ ThisBuild / githubWorkflowEnv ++= List("PGP_PASSPHRASE", "PGP_SECRET", "SONATYPE

val Versions = new {
val catsCore = "2.4.2"
val catsEffect = "2.3.1"
val circe = "0.13.0"
val kindProjector = "0.11.3"
val scalaTest = "3.2.5"
Expand All @@ -77,7 +78,8 @@ val commonDependencies = {
)

val plugins = Seq(
compilerPlugin("org.typelevel" % "kind-projector" % Versions.kindProjector cross CrossVersion.full)
compilerPlugin("org.typelevel" % "kind-projector" % Versions.kindProjector cross CrossVersion.full),
compilerPlugin("com.olegpy" %% "better-monadic-for" % "0.3.1")
)

val sttp = Seq(
Expand Down Expand Up @@ -111,10 +113,21 @@ lazy val oauth2 = project.settings(
mimaSettings
)

lazy val `oauth2-backend-cats` = project
.settings(
name := "sttp-oauth2-backend-cats",
libraryDependencies ++= oauth2Dependencies ++ Seq(
"org.typelevel" %% "cats-effect" % Versions.catsEffect,
"com.softwaremill.sttp.client3" %% "async-http-client-backend-cats" % Versions.sttp % Test
),
mimaSettings
)
.dependsOn(oauth2)

val root = project
.in(file("."))
.settings(
skip in publish := true,
mimaPreviousArtifacts := Set.empty
)
.aggregate(oauth2)
.aggregate(oauth2, `oauth2-backend-cats`)
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.ocadotechnology.sttp.oauth2.backend

import cats.effect.Sync
import cats.implicits._
import cats.effect.concurrent.Ref

trait Cache[F[_], A] {
def get: F[Option[A]]
def set(a: A): F[Unit]
}

object Cache {

def refCache[F[_]: Sync, A]: F[Cache[F, A]] = Ref[F].of(Option.empty[A]).map { ref =>
new Cache[F, A] {
override def get: F[Option[A]] = ref.get
override def set(a: A): F[Unit] = ref.set(Some(a))
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.ocadotechnology.sttp.oauth2.backend

import cats.Monad
import cats.data.OptionT
import cats.effect.Clock
import cats.effect.Concurrent
import cats.effect.concurrent.Semaphore
import cats.implicits._
import com.ocadotechnology.sttp.oauth2.ClientCredentialsProvider
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse
import com.ocadotechnology.sttp.oauth2.Secret
import com.ocadotechnology.sttp.oauth2.backend.SttpOauth2ClientCredentialsCatsBackend.TokenWithExpiryInstant
import com.ocadotechnology.sttp.oauth2.common.Scope
import eu.timepit.refined.types.string.NonEmptyString
import sttp.capabilities.Effect
import sttp.client3.DelegateSttpBackend
import sttp.client3.Request
import sttp.client3.Response
import sttp.client3.SttpBackend
import sttp.model.Uri

import java.time.Instant
import scala.concurrent.duration.MILLISECONDS

class SttpOauth2ClientCredentialsCatsBackend[F[_]: Monad: Clock, P] private (
delegate: SttpBackend[F, P],
clientCredentialsProvider: ClientCredentialsProvider[F],
cache: Cache[F, TokenWithExpiryInstant],
semaphore: Semaphore[F],
val scope: Scope
) extends DelegateSttpBackend(delegate) {

override def send[T, R >: P with Effect[F]](request: Request[T, R]): F[Response[T]] = for {
token <- semaphore.withPermit(resolveToken)
response <- delegate.send(request.auth.bearer(token.value))
} yield response

private val resolveToken: F[Secret[String]] =
OptionT(cache.get)
.product(OptionT.liftF(getCurrentInstant))
.filter { case (TokenWithExpiryInstant(_, expiryInstant), currentInstant) => currentInstant isBefore expiryInstant }
.map(_._1)
.getOrElseF(requestAndSaveToken)
.map(_.token)

private def requestAndSaveToken: F[TokenWithExpiryInstant] =
clientCredentialsProvider.requestToken(scope).flatMap(calculateExpiryTime).flatTap(cache.set)

private def calculateExpiryTime(response: AccessTokenResponse): F[TokenWithExpiryInstant] =
getCurrentInstant.map(_ plusMillis response.expiresIn.toMillis).map(TokenWithExpiryInstant(response.accessToken, _))

private def getCurrentInstant: F[Instant] = Clock[F].realTime(MILLISECONDS).map(Instant.ofEpochMilli)

}

object SttpOauth2ClientCredentialsCatsBackend {
final case class TokenWithExpiryInstant(token: Secret[String], expiryInstant: Instant)

def apply[F[_]: Concurrent: Clock, P](
tokenUrl: Uri,
tokenIntrospectionUrl: Uri,
clientId: NonEmptyString,
clientSecret: Secret[String]
)(
scope: Scope
)(
implicit backend: SttpBackend[F, P]
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] = {
val clientCredentialsProvider = ClientCredentialsProvider.instance(tokenUrl, tokenIntrospectionUrl, clientId, clientSecret)
usingClientCredentialsProvider(clientCredentialsProvider)(scope)
}

/** Keep in mind that the given implicit `backend` may be different than this one used in `clientCredentialsProvider`
*/
def usingClientCredentialsProvider[F[_]: Concurrent: Clock, P](
clientCredentialsProvider: ClientCredentialsProvider[F]
)(
scope: Scope
)(
implicit backend: SttpBackend[F, P]
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] =
Cache.refCache[F, TokenWithExpiryInstant].flatMap(usingClientCredentialsProviderAndCache(clientCredentialsProvider, _)(scope))

def usingCache[F[_]: Concurrent: Clock, P](
cache: Cache[F, TokenWithExpiryInstant]
)(
tokenUrl: Uri,
tokenIntrospectionUrl: Uri,
clientId: NonEmptyString,
clientSecret: Secret[String]
)(
scope: Scope
)(
implicit backend: SttpBackend[F, P]
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] = {
val clientCredentialsProvider = ClientCredentialsProvider.instance(tokenUrl, tokenIntrospectionUrl, clientId, clientSecret)
usingClientCredentialsProviderAndCache(clientCredentialsProvider, cache)(scope)
}

/** Keep in mind that the given implicit `backend` may be different than this one used in `clientCredentialsProvider`
*/
def usingClientCredentialsProviderAndCache[F[_]: Concurrent: Clock, P](
clientCredentialsProvider: ClientCredentialsProvider[F],
cache: Cache[F, TokenWithExpiryInstant]
)(
scope: Scope
)(
implicit backend: SttpBackend[F, P]
): F[SttpOauth2ClientCredentialsCatsBackend[F, P]] =
Semaphore(n = 1).map(new SttpOauth2ClientCredentialsCatsBackend(backend, clientCredentialsProvider, cache, _, scope))

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package com.ocadotechnology.sttp.oauth2.backend

import cats.effect.ContextShift
import cats.effect.IO
import cats.effect.Timer
import cats.implicits._
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse
import com.ocadotechnology.sttp.oauth2.Secret
import com.ocadotechnology.sttp.oauth2.common.Scope
import eu.timepit.refined.collection.NonEmpty
import eu.timepit.refined.refineMV
import eu.timepit.refined.types.string.NonEmptyString
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import sttp.client3._
import sttp.client3.asynchttpclient.cats.AsyncHttpClientCatsBackend
import sttp.client3.testing.SttpBackendStub
import sttp.client3.testing._
import sttp.model.HeaderNames.Authorization
import sttp.model._

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

class SttpOauth2ClientCredentialsCatsBackendSpec extends AsyncWordSpec with Matchers {
implicit override val executionContext: ExecutionContext = ExecutionContext.global
implicit val contextShift: ContextShift[IO] = IO.contextShift(executionContext)
implicit val timer: Timer[IO] = IO.timer(executionContext)

"SttpOauth2ClientCredentialsBackend" when {
val tokenUrl: Uri = uri"https://authserver.org/oauth2/token"
val clientId: NonEmptyString = refineMV[NonEmpty]("clientid")
val clientSecret: Secret[String] = Secret("secret")
val scope: Scope = refineMV("scope")

val testAppUrl: Uri = uri"https://testapp.org/test"

"TestApp is invoked once" should {
"request a token. add the token to the TestApp request" in {
val accessToken: Secret[String] = Secret("token")
implicit val mockBackend: SttpBackendStub[IO, Any] = AsyncHttpClientCatsBackend
.stub[IO]
.whenTokenIsRequested()
.thenRespond(Right(AccessTokenResponse(accessToken, "domain", 100.seconds, scope)))
.whenTestAppIsRequestedWithToken(accessToken)
.thenRespondOk()

for {
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope)
response <- backend.send(basicRequest.get(testAppUrl).response(asStringAlways))
} yield response.code shouldBe StatusCode.Ok
}.unsafeToFuture()
}

"TestApp is invoked twice sequentially" should {
"first invocation is requesting a token, second invocation is getting the token from the cache. add the token to the both TestApp requests" in {
val accessToken: Secret[String] = Secret("token")
implicit val recordingMockBackend: RecordingSttpBackend[IO, Any] = new RecordingSttpBackend(
AsyncHttpClientCatsBackend
.stub[IO]
.whenTokenIsRequested()
.thenRespond(Right(AccessTokenResponse(accessToken, "domain", 100.seconds, scope)))
.whenTestAppIsRequestedWithToken(accessToken)
.thenRespondOk()
)

for {
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope)
invokeTestApp = backend.send(basicRequest.get(testAppUrl).response(asStringAlways))
response1 <- invokeTestApp
response2 <- invokeTestApp
} yield {
response1.code shouldBe StatusCode.Ok
response2.code shouldBe StatusCode.Ok
recordingMockBackend.invocationCountByUri shouldBe Map(tokenUrl -> 1, testAppUrl -> 2)
}
}.unsafeToFuture()
}

"TestApp is invoked twice in parallel" should {
"first invocation is requesting a token, second invocation is waiting for token response and getting the token from the cache. add the token to the both TestApp requests" in {
val accessToken: Secret[String] = Secret("token")
implicit val recordingMockBackend: RecordingSttpBackend[IO, Any] = new RecordingSttpBackend(
AsyncHttpClientCatsBackend
.stub[IO]
.whenTokenIsRequested()
.thenRespondF(IO.sleep(200.millis).as(Response.ok(Right(AccessTokenResponse(accessToken, "domain", 100.seconds, scope)))))
.whenTestAppIsRequestedWithToken(accessToken)
.thenRespondOk()
)

for {
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope)
invokeTestApp = backend.send(basicRequest.get(testAppUrl).response(asStringAlways))
(response1, response2) <- (invokeTestApp, invokeTestApp).parTupled
} yield {
response1.code shouldBe StatusCode.Ok
response2.code shouldBe StatusCode.Ok
recordingMockBackend.invocationCountByUri shouldBe Map(tokenUrl -> 1, testAppUrl -> 2)
}
}.unsafeToFuture()
}

"TestApp is invoked after token expires" should {
"first invocation is requesting a token, second invocation is requesting a token, because the previous token is expired. add the token to the both TestApp requests" in {
val accessToken1: Secret[String] = Secret("token1")
val accessToken2: Secret[String] = Secret("token2")
implicit val recordingMockBackend: RecordingSttpBackend[IO, Any] = new RecordingSttpBackend(
AsyncHttpClientCatsBackend
.stub[IO]
.whenTokenIsRequested()
.thenRespondCyclic(
Right(AccessTokenResponse(accessToken1, "domain", 100.millis, scope)),
Right(AccessTokenResponse(accessToken2, "domain", 100.millis, scope))
)
.whenTestAppIsRequestedWithToken(accessToken1)
.thenRespond("body1")
.whenTestAppIsRequestedWithToken(accessToken2)
.thenRespond("body2")
)

for {
backend <- SttpOauth2ClientCredentialsCatsBackend[IO, Any](tokenUrl, uri"https://unused", clientId, clientSecret)(scope)
invokeTestApp = backend.send(basicRequest.get(testAppUrl).response(asStringAlways))
response1 <- invokeTestApp
_ <- IO.sleep(200.millis)
response2 <- invokeTestApp
} yield {
response1.code shouldBe StatusCode.Ok
response1.body shouldBe "body1"
response2.code shouldBe StatusCode.Ok
response2.body shouldBe "body2"
recordingMockBackend.invocationCountByUri shouldBe Map(tokenUrl -> 2, testAppUrl -> 2)
}
}.unsafeToFuture()
}

implicit class SttpBackendStubOps[F[_], P](val backend: SttpBackendStub[F, P]) {
import backend.WhenRequest

def whenTokenIsRequested(): WhenRequest = backend.whenRequestMatches { request =>
request.method == Method.POST &&
request.uri == tokenUrl &&
request.forceBodyAsString == "grant_type=client_credentials&" +
s"client_id=${clientId.value}&" +
s"client_secret=${clientSecret.value}&" +
s"scope=${scope.value}"
}

def whenTestAppIsRequestedWithToken(accessToken: Secret[String]): WhenRequest = backend.whenRequestMatches { request =>
request.method == Method.GET &&
request.uri == testAppUrl &&
request.headers.contains(Header(Authorization, s"Bearer ${accessToken.value}"))
}
}

implicit class RecordingSttpBackendOps[F[_], P](backend: RecordingSttpBackend[F, P]) {
def invocationCountByUri: Map[Uri, Int] = backend.allInteractions.groupBy(_._1.uri).fmap(_.size)
}
}

}

0 comments on commit c263a0c

Please sign in to comment.