Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sttp-oauth2-cache-cats module with CE3 support #173

Merged
merged 3 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
- run: sbt ++${{ matrix.scala }} test docs/mdoc mimaReportBinaryIssues

- name: Compress target directories
run: tar cf targets.tar oauth2-cache-ce2/target target mdoc/target oauth2-cache-future/target oauth2-cache/target oauth2/target project/target
run: tar cf targets.tar oauth2-cache-ce2/target target mdoc/target oauth2-cache-cats/target oauth2-cache-future/target oauth2-cache/target oauth2/target project/target

- name: Upload target directories
uses: actions/upload-artifact@v2
Expand Down
3 changes: 2 additions & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
version = "3.2.0"
runner.dialect = scala213
maxColumn = 140
align.preset = some
align.tokens.add = [
align.tokens."+" = [
{code = "<-", owner = Enumerator.Generator}
]
align.multiline = true
Expand Down
22 changes: 18 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ ThisBuild / githubWorkflowEnv ++= List("PGP_PASSPHRASE", "PGP_SECRET", "SONATYPE

val Versions = new {
val catsCore = "2.6.1"
val catsEffect = "2.3.3"
val catsEffect = "3.3.0"
val catsEffect2 = "2.3.3"
val circe = "0.14.1"
val kindProjector = "0.13.2"
val monix = "3.4.0"
Expand Down Expand Up @@ -117,12 +118,25 @@ lazy val `oauth2-cache` = project
)
.dependsOn(oauth2)

lazy val `oauth2-cache-cats` = project
.settings(
name := "sttp-oauth2-cache-cats",
libraryDependencies ++= Seq(
"org.typelevel" %% "cats-effect-kernel" % Versions.catsEffect,
"org.typelevel" %% "cats-effect-std" % Versions.catsEffect,
"org.typelevel" %% "cats-effect" % Versions.catsEffect % Test,
"org.typelevel" %% "cats-effect-testkit" % Versions.catsEffect % Test
) ++ plugins ++ testDependencies,
mimaPreviousArtifacts := Set.empty
)
.dependsOn(`oauth2-cache`)

lazy val `oauth2-cache-ce2` = project
.settings(
name := "sttp-oauth2-cache-ce2",
libraryDependencies ++= Seq(
"org.typelevel" %% "cats-effect" % Versions.catsEffect,
"org.typelevel" %% "cats-effect-laws" % Versions.catsEffect % Test
"org.typelevel" %% "cats-effect" % Versions.catsEffect2,
"org.typelevel" %% "cats-effect-laws" % Versions.catsEffect2 % Test
) ++ plugins ++ testDependencies,
mimaSettings
)
Expand All @@ -145,4 +159,4 @@ val root = project
mimaPreviousArtifacts := Set.empty
)
// after adding a module remember to regenerate ci.yml using `sbt githubWorkflowGenerate`
.aggregate(oauth2, `oauth2-cache`, `oauth2-cache-ce2`, `oauth2-cache-future`)
.aggregate(oauth2, `oauth2-cache`, `oauth2-cache-cats`, `oauth2-cache-ce2`, `oauth2-cache-future`)
9 changes: 5 additions & 4 deletions docs/client-credentials.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ Caching modules provide cached `AccessTokenProvider`, which can:
- fetch a new token if the previous one expires


| module name | class name | default cache implementation | semaphore | notes |
|------------------------------|--------------------------------------------|---------------------------------|--------------------------------------|-------------------------------------------------|
| `sttp-oauth2-cache-ce2` | `SttpOauth2ClientCredentialsCatsBackend` | `cats-effect2`'s `Ref` | `cats-effect2`'s `Semaphore` | |
| `sttp-oauth2-cache-future` | `SttpOauth2ClientCredentialsFutureBackend` | `monix-execution`'s `AtomicAny` | `monix-execution`'s `AsyncSemaphore` | It only uses submodule of whole `monix` project |
| module name | class name | provided cache implementation | semaphore | notes |
|----------------------------|------------------------------------|---------------------------------|--------------------------------------|-------------------------------------------------|
| `sttp-oauth2-cache-cats` | `CachingAccessTokenProvider` | `cats-effect3`'s `Ref` | `cats-effect2`'s `Semaphore` | |
| `sttp-oauth2-cache-ce2` | `CachingAccessTokenProvider` | `cats-effect2`'s `Ref` | `cats-effect2`'s `Semaphore` | |
| `sttp-oauth2-cache-future` | `FutureCachingAccessTokenProvider` | `monix-execution`'s `AtomicAny` | `monix-execution`'s `AsyncSemaphore` | It only uses submodule of whole `monix` project |

### Cats example

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.ocadotechnology.sttp.oauth2.cache.cats

import cats.data.OptionT
import cats.effect.kernel.Clock
import cats.effect.kernel.Concurrent
import cats.effect.kernel.MonadCancelThrow
import cats.effect.std.Semaphore
import cats.syntax.all._
import com.ocadotechnology.sttp.oauth2.AccessTokenProvider
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken
import com.ocadotechnology.sttp.oauth2.Secret
import com.ocadotechnology.sttp.oauth2.cache.ExpiringCache
import com.ocadotechnology.sttp.oauth2.cache.cats.CachingAccessTokenProvider.TokenWithExpirationTime
import com.ocadotechnology.sttp.oauth2.common.Scope

import java.time.Instant
import scala.concurrent.duration.Duration

final class CachingAccessTokenProvider[F[_]: MonadCancelThrow: Clock](
delegate: AccessTokenProvider[F],
semaphore: Semaphore[F],
tokenCache: ExpiringCache[F, Scope, TokenWithExpirationTime]
) extends AccessTokenProvider[F] {

override def requestToken(scope: Scope): F[ClientCredentialsToken.AccessTokenResponse] =
getFromCache(scope)
.getOrElseF(semaphore.permit.surround(acquireToken(scope))) // semaphore prevents concurrent token fetch from external service

private def acquireToken(scope: Scope) =
getFromCache(scope) // duplicate cache check, to verify if any other thread filled the cache during wait for semaphore permit
.getOrElseF(fetchAndSaveToken(scope))

private def getFromCache(scope: Scope) =
(OptionT(tokenCache.get(scope)), OptionT.liftF(Clock[F].realTimeInstant))
.mapN(_.toAccessTokenResponse(_))

private def fetchAndSaveToken(scope: Scope) =
for {
token <- delegate.requestToken(scope)
tokenWithExpiry <- calculateExpiryInstant(token)
_ <- tokenCache.put(scope, tokenWithExpiry, tokenWithExpiry.expirationTime)
} yield token

private def calculateExpiryInstant(response: ClientCredentialsToken.AccessTokenResponse): F[TokenWithExpirationTime] =
Clock[F].realTimeInstant.map(TokenWithExpirationTime.from(response, _))

}

object CachingAccessTokenProvider {

def apply[F[_]: Concurrent: Clock](
delegate: AccessTokenProvider[F],
tokenCache: ExpiringCache[F, Scope, TokenWithExpirationTime]
): F[CachingAccessTokenProvider[F]] = Semaphore[F](n = 1).map(new CachingAccessTokenProvider[F](delegate, _, tokenCache))

def refCacheInstance[F[_]: Concurrent: Clock](delegate: AccessTokenProvider[F]): F[CachingAccessTokenProvider[F]] =
CatsRefExpiringCache[F, Scope, TokenWithExpirationTime].flatMap(CachingAccessTokenProvider(delegate, _))

final case class TokenWithExpirationTime(
accessToken: Secret[String],
domain: Option[String],
expirationTime: Instant,
scope: Scope
) {

def toAccessTokenResponse(now: Instant): ClientCredentialsToken.AccessTokenResponse = {
val newExpiresIn = Duration.fromNanos(java.time.Duration.between(now, expirationTime).toNanos)
ClientCredentialsToken.AccessTokenResponse(accessToken, domain, newExpiresIn, scope)
}

}

object TokenWithExpirationTime {

def from(token: ClientCredentialsToken.AccessTokenResponse, now: Instant): TokenWithExpirationTime = {
val expirationTime = now.plusNanos(token.expiresIn.toNanos)
TokenWithExpirationTime(token.accessToken, token.domain, expirationTime, token.scope)
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.ocadotechnology.sttp.oauth2.cache.cats

import cats.Monad
import cats.data.OptionT
import cats.effect.kernel.Clock
import cats.effect.kernel.Ref
import cats.implicits._
import com.ocadotechnology.sttp.oauth2.cache.ExpiringCache
import com.ocadotechnology.sttp.oauth2.cache.cats.CatsRefExpiringCache.Entry

import java.time.Instant

final class CatsRefExpiringCache[F[_]: Monad: Clock, K, V] private[cats] (ref: Ref[F, Map[K, Entry[V]]]) extends ExpiringCache[F, K, V] {

override def get(key: K): F[Option[V]] =
OptionT(ref.get.map(_.get(key)))
.product(OptionT.liftF(Clock[F].realTimeInstant))
.flatMapF { case (Entry(value, expiryInstant), now) =>
if (now.isBefore(expiryInstant))
value.some.pure[F]
else
remove(key) *> none[V].pure[F] // cleaning up to save memory
}
.value

override def put(key: K, value: V, expirationTime: Instant): F[Unit] = ref.update(_ + (key -> Entry(value, expirationTime)))

override def remove(key: K): F[Unit] = ref.update(_ - key)
}

object CatsRefExpiringCache {
private[cats] final case class Entry[V](value: V, expirationTime: Instant)

def apply[F[_]: Ref.Make: Monad: Clock, K, V]: F[ExpiringCache[F, K, V]] =
Ref[F].of(Map.empty[K, Entry[V]]).map(new CatsRefExpiringCache(_))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package com.ocadotechnology.sttp.oauth2.cache.cats

import cats.effect.IO
import cats.effect.Ref
import cats.effect.Temporal
import cats.effect.unsafe.implicits.global
import cats.syntax.all._
import com.ocadotechnology.sttp.oauth2.AccessTokenProvider
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse
import com.ocadotechnology.sttp.oauth2.Secret
import com.ocadotechnology.sttp.oauth2.cache.ExpiringCache
import com.ocadotechnology.sttp.oauth2.cache.cats.CachingAccessTokenProvider.TokenWithExpirationTime
import com.ocadotechnology.sttp.oauth2.common.Scope
import eu.timepit.refined.auto._
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import java.time.Instant
import scala.concurrent.duration._

class CachingAccessTokenProviderParallelSpec extends AnyWordSpec with Matchers {
private val testScope: Scope = "test-scope"
private val token = AccessTokenResponse(Secret("secret"), None, 10.seconds, testScope)

private val sleepDuration: FiniteDuration = 1.second

"CachingAccessTokenProvider" should {
"block multiple parallel" in runTest { case (delegate, cachingProvider) =>
for {
_ <- delegate.setToken(testScope, token)
(result1, result2) <- (cachingProvider.requestToken(testScope), cachingProvider.requestToken(testScope)).parTupled
} yield {
result1 shouldBe token.copy(expiresIn = result1.expiresIn)
result2 shouldBe token.copy(expiresIn = result2.expiresIn)
// if both calls would be made in parallel, both would get the same expiresIn from TestAccessTokenProvider.
// When blocking is in place, the second call would be delayed by sleepDuration and would hit the cache,
// which has Instant on top of which new expiresIn would be calculated
diffInExpirations(result1, result2) shouldBe >=(sleepDuration)
}
}

"not block multiple parallel access if its already in cache" in runTest { case (delegate, cachingProvider) =>
for {
_ <- delegate.setToken(testScope, token)
_ <- cachingProvider.requestToken(testScope)
(result1, result2) <- (cachingProvider.requestToken(testScope), cachingProvider.requestToken(testScope)).parTupled
} yield {
result1 shouldBe token.copy(expiresIn = result1.expiresIn)
result2 shouldBe token.copy(expiresIn = result2.expiresIn)
// second call should not be forced to wait sleepDuration, because some active token is already in cache
diffInExpirations(result1, result2) shouldBe <(sleepDuration)
}
}
}

private def diffInExpirations(result1: AccessTokenResponse, result2: AccessTokenResponse): FiniteDuration =
if (result1.expiresIn > result2.expiresIn) result1.expiresIn - result2.expiresIn else result2.expiresIn - result1.expiresIn

def runTest(test: ((TestAccessTokenProvider[IO], AccessTokenProvider[IO])) => IO[Assertion]): Assertion =
prepareTest.flatMap(test).unsafeRunSync()

class DelayingCache[F[_]: Temporal, K, V](delegate: ExpiringCache[F, K, V]) extends ExpiringCache[F, K, V] {
override def get(key: K): F[Option[V]] = delegate.get(key)

override def put(key: K, value: V, expirationTime: Instant): F[Unit] =
Temporal[F].sleep(sleepDuration) *> delegate.put(key, value, expirationTime)

override def remove(key: K): F[Unit] = delegate.remove(key)
}

private def prepareTest: IO[(TestAccessTokenProvider[IO], CachingAccessTokenProvider[IO])] =
for {
state <- Ref.of[IO, TestAccessTokenProvider.State](TestAccessTokenProvider.State.empty)
delegate = TestAccessTokenProvider[IO](state)
cache <- CatsRefExpiringCache[IO, Scope, TokenWithExpirationTime]
delayingCache = new DelayingCache(cache)
cachingProvider <- CachingAccessTokenProvider[IO](delegate, delayingCache)
} yield (delegate, cachingProvider)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.ocadotechnology.sttp.oauth2.cache.cats

import cats.effect.IO
import cats.effect.Ref
import cats.effect.kernel.Outcome.Succeeded
import cats.effect.testkit.TestContext
import cats.effect.testkit.TestInstances
import com.ocadotechnology.sttp.oauth2.AccessTokenProvider
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse
import com.ocadotechnology.sttp.oauth2.Secret
import com.ocadotechnology.sttp.oauth2.cache.cats.CachingAccessTokenProvider.TokenWithExpirationTime
import com.ocadotechnology.sttp.oauth2.common.Scope
import eu.timepit.refined.auto._
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.concurrent.duration._

class CachingAccessTokenProviderSpec extends AnyWordSpec with Matchers with TestInstances {
private implicit val ticker: Ticker = Ticker(TestContext())

private val testScope: Scope = "test-scope"
private val token = AccessTokenResponse(Secret("secret"), None, 10.seconds, testScope)
private val newToken = AccessTokenResponse(Secret("secret2"), None, 20.seconds, testScope)

"CachingAccessTokenProvider" should {
"delegate token retrieval on first call" in runTest { case (delegate, cachingProvider) =>
for {
_ <- delegate.setToken(testScope, token)
result <- cachingProvider.requestToken(testScope)
} yield result shouldBe token
}

"decrease expiresIn in second read" in runTest { case (delegate, cachingProvider) =>
for {
_ <- delegate.setToken(testScope, token)
_ <- cachingProvider.requestToken(testScope)
_ <- IO.sleep(3.seconds)
result <- cachingProvider.requestToken(testScope)
} yield result shouldBe token.copy(expiresIn = 7.seconds)
}

"not refresh token before expiration" in runTest { case (delegate, cachingProvider) =>
for {
_ <- delegate.setToken(testScope, token)
_ <- cachingProvider.requestToken(testScope)
_ <- delegate.setToken(testScope, newToken)
_ <- IO.sleep(10.seconds - 1.milli)
result <- cachingProvider.requestToken(testScope)
} yield result shouldBe token.copy(expiresIn = 1.milli)
}

"ask for token again after expiration" in runTest { case (delegate, cachingProvider) =>
for {
_ <- delegate.setToken(testScope, token)
_ <- cachingProvider.requestToken(testScope)
_ <- delegate.setToken(testScope, newToken)
_ <- IO.sleep(11.seconds)
result <- cachingProvider.requestToken(testScope)
} yield result shouldBe newToken
}

}

def runTest(test: ((TestAccessTokenProvider[IO], AccessTokenProvider[IO])) => IO[Assertion]): Assertion =
unsafeRun(prepareTest.flatMap(test)) match {
case Succeeded(Some(assertion)) => assertion
case wrongResult => fail(s"Test should finish successfully. Instead ended with $wrongResult")
}

private def prepareTest: IO[(TestAccessTokenProvider[IO], CachingAccessTokenProvider[IO])] =
for {
state <- Ref.of[IO, TestAccessTokenProvider.State](TestAccessTokenProvider.State.empty)
delegate = TestAccessTokenProvider[IO](state)
cache <- CatsRefExpiringCache[IO, Scope, TokenWithExpirationTime]
cachingProvider <- CachingAccessTokenProvider[IO](delegate, cache)
} yield (delegate, cachingProvider)

}
Loading