Skip to content

Commit

Permalink
switch to object instead of trait
Browse files Browse the repository at this point in the history
  • Loading branch information
marctalbott committed Jan 13, 2025
1 parent 9a7e442 commit 814a783
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import org.broadinstitute.dsde.workbench.leonardo.{
}
import org.broadinstitute.dsde.workbench.model.{TraceId, UserInfo, WorkbenchEmail}

trait SamUtils[F[_]] {
val samService: SamService[F]

def checkRuntimeAction(userInfo: UserInfo,
cloudContext: CloudContext,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction,
userEmail: Option[WorkbenchEmail] = None
object SamUtils {
def checkRuntimeAction[F[_]](samService: SamService[F],
userInfo: UserInfo,
cloudContext: CloudContext,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction,
userEmail: Option[WorkbenchEmail] = None
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
checkActionInternal(
samService,
userInfo.accessToken,
userEmail.getOrElse(userInfo.userEmail),
samResourceId,
Expand All @@ -45,13 +45,15 @@ trait SamUtils[F[_]] {
RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database")
)

def checkRuntimeAction(userInfo: UserInfo,
workspaceId: WorkspaceId,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction
def checkRuntimeAction[F[_]](samService: SamService[F],
userInfo: UserInfo,
workspaceId: WorkspaceId,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
checkActionInternal(
samService,
userInfo.accessToken,
userInfo.userEmail,
samResourceId,
Expand All @@ -60,14 +62,16 @@ trait SamUtils[F[_]] {
RuntimeNotFoundByWorkspaceIdException(workspaceId, runtimeName, "Not found in database")
)

def checkDiskAction(userInfo: UserInfo,
cloudContext: CloudContext,
diskName: DiskName,
samResourceId: SamResourceId,
action: SamResourceAction,
traceId: TraceId
def checkDiskAction[F[_]](samService: SamService[F],
userInfo: UserInfo,
cloudContext: CloudContext,
diskName: DiskName,
samResourceId: SamResourceId,
action: SamResourceAction,
traceId: TraceId
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
checkActionInternal(
samService,
userInfo.accessToken,
userInfo.userEmail,
samResourceId,
Expand All @@ -76,13 +80,15 @@ trait SamUtils[F[_]] {
DiskNotFoundException(cloudContext, diskName, traceId)
)

def checkDiskAction(userInfo: UserInfo,
diskId: DiskId,
samResourceId: SamResourceId,
action: SamResourceAction,
traceId: TraceId
def checkDiskAction[F[_]](samService: SamService[F],
userInfo: UserInfo,
diskId: DiskId,
samResourceId: SamResourceId,
action: SamResourceAction,
traceId: TraceId
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
checkActionInternal(
samService,
userInfo.accessToken,
userInfo.userEmail,
samResourceId,
Expand All @@ -91,12 +97,13 @@ trait SamUtils[F[_]] {
DiskNotFoundByIdException(diskId, traceId)
)

private def checkActionInternal(userToken: OAuth2BearerToken,
userEmail: WorkbenchEmail,
samResourceId: SamResourceId,
actionToCheck: SamResourceAction,
resourceReadAction: SamResourceAction,
notFoundException: LeoException
private def checkActionInternal[F[_]](samService: SamService[F],
userToken: OAuth2BearerToken,
userEmail: WorkbenchEmail,
samResourceId: SamResourceId,
actionToCheck: SamResourceAction,
resourceReadAction: SamResourceAction,
notFoundException: LeoException
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
samService
.checkAuthorized(userToken.token, samResourceId, actionToCheck)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@ class DiskServiceInterp[F[_]: Parallel](config: PersistentDiskConfig,
publisherQueue: Queue[F, LeoPubsubMessage],
googleDiskService: Option[GoogleDiskService[F]],
googleProjectDAO: Option[GoogleProjectDAO],
val samService: SamService[F]
samService: SamService[F]
)(implicit
F: Async[F],
dbReference: DbReference[F],
ec: ExecutionContext
) extends DiskService[F]
with SamUtils[F] {
) extends DiskService[F] {

override def createDisk(
userInfo: UserInfo,
Expand Down Expand Up @@ -178,12 +177,13 @@ class DiskServiceInterp[F[_]: Parallel](config: PersistentDiskConfig,
for {
ctx <- as.ask
resp <- DiskServiceDbQueries.getGetPersistentDiskResponse(cloudContext, diskName, ctx.traceId).transaction
_ <- checkDiskAction(userInfo,
cloudContext,
diskName,
resp.samResource,
PersistentDiskAction.ReadPersistentDisk,
ctx.traceId
_ <- SamUtils.checkDiskAction(samService,
userInfo,
cloudContext,
diskName,
resp.samResource,
PersistentDiskAction.ReadPersistentDisk,
ctx.traceId
)
} yield resp

Expand Down Expand Up @@ -233,12 +233,13 @@ class DiskServiceInterp[F[_]: Parallel](config: PersistentDiskConfig,
disk <- diskOpt.fold(F.raiseError[PersistentDisk](DiskNotFoundException(cloudContext, diskName, ctx.traceId)))(
F.pure
)
_ <- checkDiskAction(userInfo,
cloudContext,
diskName,
disk.samResource,
PersistentDiskAction.DeletePersistentDisk,
ctx.traceId
_ <- SamUtils.checkDiskAction(samService,
userInfo,
cloudContext,
diskName,
disk.samResource,
PersistentDiskAction.DeletePersistentDisk,
ctx.traceId
)
// throw 409 if the disk is not deletable
_ <-
Expand Down Expand Up @@ -297,12 +298,13 @@ class DiskServiceInterp[F[_]: Parallel](config: PersistentDiskConfig,
.getGetPersistentDiskResponse(cloudContext, disk.name, ctx.traceId)
.transaction

_ <- checkDiskAction(userInfo,
cloudContext,
dbdisk.name,
dbdisk.samResource,
PersistentDiskAction.DeletePersistentDisk,
ctx.traceId
_ <- SamUtils.checkDiskAction(samService,
userInfo,
cloudContext,
dbdisk.name,
dbdisk.samResource,
PersistentDiskAction.DeletePersistentDisk,
ctx.traceId
)

// Mark the resource as deleted in Leo's DB
Expand Down Expand Up @@ -339,12 +341,13 @@ class DiskServiceInterp[F[_]: Parallel](config: PersistentDiskConfig,
disk <- diskOpt.fold(F.raiseError[PersistentDisk](DiskNotFoundException(cloudContext, diskName, ctx.traceId)))(
F.pure
)
_ <- checkDiskAction(userInfo,
cloudContext,
diskName,
disk.samResource,
PersistentDiskAction.ModifyPersistentDisk,
ctx.traceId
_ <- SamUtils.checkDiskAction(samService,
userInfo,
cloudContext,
diskName,
disk.samResource,
PersistentDiskAction.ModifyPersistentDisk,
ctx.traceId
)
// throw 400 if UpdateDiskRequest new size is smaller than disk's current size
_ <-
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ import scala.concurrent.ExecutionContext
class DiskV2ServiceInterp[F[_]: Parallel](
publisherQueue: Queue[F, LeoPubsubMessage],
wsmClientProvider: WsmApiClientProvider[F],
val samService: SamService[F]
samService: SamService[F]
)(implicit
F: Async[F],
dbReference: DbReference[F],
ec: ExecutionContext,
log: StructuredLogger[F]
) extends DiskV2Service[F]
with SamUtils[F] {
) extends DiskV2Service[F] {

// backwards compatible with v1 getDisk route
override def getDisk(userInfo: UserInfo, diskId: DiskId)(implicit
Expand All @@ -45,7 +44,13 @@ class DiskV2ServiceInterp[F[_]: Parallel](
_ <- F.fromOption(diskResp.workspaceId, DiskWithoutWorkspaceException(diskId, ctx.traceId))

// check that user has read action on disk
_ <- checkDiskAction(userInfo, diskId, diskResp.samResource, PersistentDiskAction.ReadPersistentDisk, ctx.traceId)
_ <- SamUtils.checkDiskAction(samService,
userInfo,
diskId,
diskResp.samResource,
PersistentDiskAction.ReadPersistentDisk,
ctx.traceId
)
} yield diskResp

override def deleteDisk(userInfo: UserInfo, diskId: DiskId)(implicit
Expand All @@ -57,7 +62,13 @@ class DiskV2ServiceInterp[F[_]: Parallel](

disk <- F.fromOption(diskOpt, DiskNotFoundByIdException(diskId, ctx.traceId))

_ <- checkDiskAction(userInfo, diskId, disk.samResource, PersistentDiskAction.DeletePersistentDisk, ctx.traceId)
_ <- SamUtils.checkDiskAction(samService,
userInfo,
diskId,
disk.samResource,
PersistentDiskAction.DeletePersistentDisk,
ctx.traceId
)
_ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done auth call for delete azure disk permission")))

// check that workspaceId is not null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,14 @@ class ProxyService(
samDAO: SamDAO[IO],
googleTokenCache: Cache[IO, String, (UserInfo, Instant)],
samResourceCache: Cache[IO, SamResourceCacheKey, (Option[String], Option[AppAccessScope])],
val samService: SamService[IO]
samService: SamService[IO]
)(implicit
val system: ActorSystem,
executionContext: ExecutionContext,
dbRef: DbReference[IO],
loggerIO: StructuredLogger[IO],
metrics: OpenTelemetryMetrics[IO]
) extends LazyLogging
with SamUtils[IO] {
) extends LazyLogging {
val httpsConnectionContext = ConnectionContext.httpsClient(sslContext)
val clientConnectionSettings =
ClientConnectionSettings(system).withTransport(ClientTransport.withCustomResolver(proxyResolver.resolveAkka))
Expand Down Expand Up @@ -271,7 +270,13 @@ class ProxyService(
ctx <- ev.ask[AppContext]

samResource <- getCachedRuntimeSamResource(RuntimeCacheKey(cloudContext, runtimeName))
_ <- checkRuntimeAction(userInfo, cloudContext, runtimeName, samResource, RuntimeAction.ConnectToRuntime)
_ <- SamUtils.checkRuntimeAction(samService,
userInfo,
cloudContext,
runtimeName,
samResource,
RuntimeAction.ConnectToRuntime
)

hostStatus <- getRuntimeTargetHost(cloudContext, runtimeName)
_ <- hostStatus match {
Expand Down
Loading

0 comments on commit 814a783

Please sign in to comment.