From 4c5148c062942ed367d78866d4ab5ca1bf21cdff Mon Sep 17 00:00:00 2001 From: dvoet Date: Wed, 3 Jul 2024 10:02:17 -0400 Subject: [PATCH] PROD-972 Limit google group syncs (#1475) --- .../sam/dataAccess/AccessPolicyDAO.scala | 8 ++ .../dataAccess/PostgresAccessPolicyDAO.scala | 29 ++++++- .../sam/google/GoogleExtensions.scala | 20 +++-- .../dsde/workbench/sam/model/SamModel.scala | 4 +- .../sam/service/CloudExtensions.scala | 14 +++- .../sam/service/ResourceService.scala | 8 +- .../workbench/sam/service/UserService.scala | 2 +- .../sam/dataAccess/MockAccessPolicyDAO.scala | 1 + .../PostgresAccessPolicyDAOSpec.scala | 68 +++++++++++++++-- .../sam/google/GoogleExtensionSpec.scala | 12 +-- .../service/MockCloudExtensionsBuilder.scala | 4 +- .../sam/service/ResourceServiceSpec.scala | 75 ++++++++++++------- .../StatefulMockCloudExtensionsBuilder.scala | 4 +- .../sam/service/UserServiceSpec.scala | 4 +- 14 files changed, 194 insertions(+), 59 deletions(-) diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/AccessPolicyDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/AccessPolicyDAO.scala index e43b566bb..bb3397830 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/AccessPolicyDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/AccessPolicyDAO.scala @@ -31,8 +31,16 @@ trait AccessPolicyDAO { samRequestContext: SamRequestContext ): IO[Unit] + /** Lists policies on resources that are constrained by the given group. If relevantMembers is provided, only policies that contain any of the relevantMembers + * (directly or indirectly) will be returned. + * @param groupId + * the group to constrain by + * @param relevantMembers + * if provided, only policies that contain any of the relevantMembers (directly or indirectly) will be returned, if empty, all policies will be returned + */ def listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup( groupId: WorkbenchGroupIdentity, + relevantMembers: Set[WorkbenchSubject], samRequestContext: SamRequestContext ): IO[Set[FullyQualifiedPolicyId]] diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAO.scala index 74ec4df86..23aa83d48 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAO.scala @@ -641,6 +641,7 @@ class PostgresAccessPolicyDAO( override def listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup( groupId: WorkbenchGroupIdentity, + relevantMembers: Set[WorkbenchSubject], samRequestContext: SamRequestContext ): IO[Set[FullyQualifiedPolicyId]] = readOnlyTransaction("listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup", samRequestContext) { implicit session => @@ -667,14 +668,40 @@ class PostgresAccessPolicyDAO( and ${policy.resource.resourceTypeName} = ${rt.name}""" } + // if relevantMembers is empty, assume all members are relevant and don't join on GroupMemberFlatTable + // otherwise, only include policies where the member is in the relevantMembers set + // need to account for both member groups and users + val pu = GroupMemberFlatTable.syntax("pu") + val (relevantMembersJoin, relevantMembersCondition) = if (relevantMembers.isEmpty) { + (samsqls"", samsqls"") + } else { + val groupPKs = queryForGroupPKs(relevantMembers) + val groupCondition = if (groupPKs.isEmpty) { + samsqls"false" + } else { + samsqls"${pu.memberGroupId} in (${groupPKs})" + } + + val userIds = collectUserIds(relevantMembers) + val userCondition = if (userIds.isEmpty) { + samsqls"false" + } else { + samsqls"${pu.memberUserId} in (${userIds})" + } + + (samsqls"""join ${GroupMemberFlatTable as pu} on ${pu.groupId} = ${p.groupId}""", samsqls"""and (${userCondition} or ${groupCondition})""") + } + samsql""" select ${rt.result.name}, ${r.result.name}, ${p.result.name} from ${ResourceTable as r} join ${ResourceTypeTable as rt} on ${r.resourceTypeId} = ${rt.id} join ${PolicyTable as p} on ${r.id} = ${p.resourceId} join ${GroupTable as g} on ${p.groupId} = ${g.id} + ${relevantMembersJoin} where ${r.id} in (${constrainedResourcesPKs}) - and ${g.synchronizedDate} is not null""" + and ${g.synchronizedDate} is not null + ${relevantMembersCondition}""" .map(rs => FullyQualifiedPolicyId( FullyQualifiedResourceId(rs.get[ResourceTypeName](rt.resultName.name), rs.get[ResourceId](r.resultName.name)), diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensions.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensions.scala index 4113114a8..b3eda9b72 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensions.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensions.scala @@ -188,7 +188,7 @@ class GoogleExtensions( /* - managed groups and access policies are both "groups" - - You can have a bunch of resources constrained an auth domain (a collection of managed groups). + - You can have a bunch of resources constrained by an auth domain (a collection of managed groups). - A user must be a member of the auth domain in order to access some actions on the resources in that auth domain. - The user must be a member of all groups in an auth domain in order to access a resource - An access policy is specific to a single resource @@ -207,7 +207,11 @@ class GoogleExtensions( see GoogleGroupSynchronizer for the background process that does the group synchronization */ - override def onGroupUpdate(groupIdentities: Seq[WorkbenchGroupIdentity], samRequestContext: SamRequestContext): IO[Unit] = + override def onGroupUpdate( + groupIdentities: Seq[WorkbenchGroupIdentity], + relevantMembers: Set[WorkbenchSubject], + samRequestContext: SamRequestContext + ): IO[Unit] = for { start <- clock.monotonic // only sync groups that have been synchronized in the past @@ -219,14 +223,14 @@ class GoogleExtensions( messages <- previouslySyncedIds.traverse { // it is a group that isn't an access policy, could be a managed group case groupName: WorkbenchGroupName => - makeConstrainedResourceAccessPolicyMessages(groupName, samRequestContext).map(_ :+ groupName.toJson.compactPrint) + makeConstrainedResourceAccessPolicyMessages(groupName, relevantMembers, samRequestContext).map(_ :+ groupName.toJson.compactPrint) // it is the admin or member access policy of a managed group case accessPolicyId @ FullyQualifiedPolicyId( FullyQualifiedResourceId(ManagedGroupService.managedGroupTypeName, id), ManagedGroupService.adminPolicyName | ManagedGroupService.memberPolicyName ) => - makeConstrainedResourceAccessPolicyMessages(accessPolicyId, samRequestContext).map(_ :+ accessPolicyId.toJson.compactPrint) + makeConstrainedResourceAccessPolicyMessages(accessPolicyId, relevantMembers, samRequestContext).map(_ :+ accessPolicyId.toJson.compactPrint) // it is an access policy on a resource that's not a managed group case accessPolicyId: FullyQualifiedPolicyId => IO.pure(List(accessPolicyId.toJson.compactPrint)) @@ -245,7 +249,11 @@ class GoogleExtensions( ) } - private def makeConstrainedResourceAccessPolicyMessages(groupIdentity: WorkbenchGroupIdentity, samRequestContext: SamRequestContext): IO[List[String]] = + private def makeConstrainedResourceAccessPolicyMessages( + groupIdentity: WorkbenchGroupIdentity, + relevantMembers: Set[WorkbenchSubject], + samRequestContext: SamRequestContext + ): IO[List[String]] = // start with a group for { // get all the ancestors of that group @@ -262,7 +270,7 @@ class GoogleExtensions( // get all access policies on any resource that is constrained by the groups constrainedResourceAccessPolicyIds <- managedGroupIds.toList.traverse( - accessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(_, samRequestContext) + accessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(_, relevantMembers, samRequestContext) ) // return messages for all the affected access policies and the original group we started with diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/model/SamModel.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/model/SamModel.scala index de3449b87..b101ee2bd 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/model/SamModel.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/model/SamModel.scala @@ -192,7 +192,9 @@ object RolesAndActions { policyName: AccessPolicyName, resourceTypeName: ResourceTypeName, resourceId: ResourceId -) +) { + def toFullyQualifiedPolicyId: FullyQualifiedPolicyId = FullyQualifiedPolicyId(FullyQualifiedResourceId(resourceTypeName, resourceId), policyName) +} @Lenses case class AccessPolicyName(value: String) extends ValueObject @Lenses final case class CreateResourceRequest( diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/CloudExtensions.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/CloudExtensions.scala index 938224b50..ce96e83d8 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/CloudExtensions.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/CloudExtensions.scala @@ -31,7 +31,13 @@ trait CloudExtensions { def publishGroup(id: WorkbenchGroupName): Future[Unit] - def onGroupUpdate(groupIdentities: Seq[WorkbenchGroupIdentity], samRequestContext: SamRequestContext): IO[Unit] + /** This method is called when a group is updated. + * @param groupIdentities + * the identities of the groups that were updated + * @param relevantMembers + * the members of the groups that were added or removed or empty if the members are not known + */ + def onGroupUpdate(groupIdentities: Seq[WorkbenchGroupIdentity], relevantMembers: Set[WorkbenchSubject], samRequestContext: SamRequestContext): IO[Unit] def onGroupDelete(groupEmail: WorkbenchEmail): IO[Unit] @@ -74,7 +80,11 @@ trait NoExtensions extends CloudExtensions { override def publishGroup(id: WorkbenchGroupName): Future[Unit] = Future.successful(()) - override def onGroupUpdate(groupIdentities: Seq[WorkbenchGroupIdentity], samRequestContext: SamRequestContext): IO[Unit] = IO.unit + override def onGroupUpdate( + groupIdentities: Seq[WorkbenchGroupIdentity], + relevantMembers: Set[WorkbenchSubject], + samRequestContext: SamRequestContext + ): IO[Unit] = IO.unit override def onGroupDelete(groupEmail: WorkbenchEmail): IO[Unit] = IO.unit diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala index dafd42dbc..039f5732c 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala @@ -288,7 +288,7 @@ class ResourceService( policies <- listResourcePolicies(resource, samRequestContext) _ <- accessPolicyDAO.addResourceAuthDomain(resource, authDomains, samRequestContext) _ <- policies.traverse(p => directoryDAO.updateGroupUpdatedDateAndVersionWithSession(FullyQualifiedPolicyId(resource, p.policyName), samRequestContext)) - _ <- cloudExtensions.onGroupUpdate(policies.map(p => FullyQualifiedPolicyId(resource, p.policyName)), samRequestContext) + _ <- cloudExtensions.onGroupUpdate(policies.map(p => FullyQualifiedPolicyId(resource, p.policyName)), Set.empty, samRequestContext) authDomains <- loadResourceAuthDomain(resource, samRequestContext) } yield authDomains @@ -678,11 +678,13 @@ class ResourceService( private def onPolicyUpdate(policyId: FullyQualifiedPolicyId, originalPolicies: Iterable[AccessPolicy], samRequestContext: SamRequestContext): IO[Unit] = for { updatedPolicies <- accessPolicyDAO.listAccessPolicies(policyId.resource, samRequestContext) - changeEvents = createAccessChangeEvents(policyId.resource, originalPolicies, updatedPolicies) + removedMembers = originalPolicies.flatMap(_.members).toSet -- updatedPolicies.flatMap(_.members).toSet + addedMembers = updatedPolicies.flatMap(_.members).toSet -- originalPolicies.flatMap(_.members).toSet + changeEvents = createAccessChangeEvents(policyId.resource, originalPolicies, updatedPolicies) _ <- AuditLogger.logAuditEventIO(samRequestContext, changeEvents.toSeq: _*) - _ <- cloudExtensions.onGroupUpdate(Seq(policyId), samRequestContext).attempt.flatMap { + _ <- cloudExtensions.onGroupUpdate(Seq(policyId), removedMembers ++ addedMembers, samRequestContext).attempt.flatMap { case Left(regrets) => IO(logger.error(s"error calling cloudExtensions.onGroupUpdate for $policyId", regrets)) case Right(_) => IO.unit } diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala index 8490a5175..98d29c576 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala @@ -227,7 +227,7 @@ class UserService( for { _ <- updateInvitedUser(userToRegister, samRequestContext) groups <- directoryDAO.listUserDirectMemberships(userToRegister.id, samRequestContext) - _ <- cloudExtensions.onGroupUpdate(groups, samRequestContext) + _ <- cloudExtensions.onGroupUpdate(groups, Set(invitedUserId), samRequestContext) } yield userToRegister } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockAccessPolicyDAO.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockAccessPolicyDAO.scala index a3afb1b5d..41c1f7a38 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockAccessPolicyDAO.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockAccessPolicyDAO.scala @@ -95,6 +95,7 @@ class MockAccessPolicyDAO(private val resourceTypes: mutable.Map[ResourceTypeNam override def listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup( groupId: WorkbenchGroupIdentity, + relevantMembers: Set[WorkbenchSubject], samRequestContext: SamRequestContext ): IO[Set[FullyQualifiedPolicyId]] = IO { val groupName: WorkbenchGroupName = groupId match { diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAOSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAOSpec.scala index ee723aceb..178c99b22 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAOSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresAccessPolicyDAOSpec.scala @@ -668,6 +668,14 @@ class PostgresAccessPolicyDAOSpec extends AnyFreeSpec with Matchers with BeforeA "can find all synced policies for resources with the group in its auth domain" in { assume(databaseEnabled, databaseEnabledClue) + val user = Generator.genWorkbenchUserBoth.sample.get + dirDao.createUser(user, samRequestContext).unsafeRunSync() + val groupUser = Generator.genWorkbenchUserBoth.sample.get + dirDao.createUser(groupUser, samRequestContext).unsafeRunSync() + + val group = BasicWorkbenchGroup(Generator.genWorkbenchGroupName.sample.get, Set(groupUser.id), Generator.genNonPetEmail.sample.get) + dirDao.createGroup(group, samRequestContext = samRequestContext).unsafeRunSync() + dao.createResourceType(resourceType, samRequestContext).unsafeRunSync() val secondResourceType = resourceType.copy(name = ResourceTypeName("superAwesomeResourceType")) dao.createResourceType(secondResourceType, samRequestContext).unsafeRunSync() @@ -681,7 +689,7 @@ class PostgresAccessPolicyDAOSpec extends AnyFreeSpec with Matchers with BeforeA val resource2FullyQualifiedId = FullyQualifiedResourceId(secondResourceType.name, ResourceId("resource2")) val policy1 = AccessPolicy( FullyQualifiedPolicyId(resource1FullyQualifiedId, AccessPolicyName("policyName1")), - Set.empty, + Set(user.id), WorkbenchEmail("policy1@email.com"), resourceType.roles.map(_.roleName), Set(readAction, writeAction), @@ -690,7 +698,7 @@ class PostgresAccessPolicyDAOSpec extends AnyFreeSpec with Matchers with BeforeA ) val policy2 = AccessPolicy( FullyQualifiedPolicyId(resource1FullyQualifiedId, AccessPolicyName("policyName2")), - Set.empty, + Set(group.id, user.id), WorkbenchEmail("policy2@email.com"), resourceType.roles.map(_.roleName), Set(readAction, writeAction), @@ -699,25 +707,69 @@ class PostgresAccessPolicyDAOSpec extends AnyFreeSpec with Matchers with BeforeA ) val policy3 = AccessPolicy( FullyQualifiedPolicyId(resource2FullyQualifiedId, AccessPolicyName("policyName3")), - Set.empty, + Set(group.id), WorkbenchEmail("policy3@email.com"), secondResourceType.roles.map(_.roleName), Set(readAction, writeAction), Set.empty, false ) + val policy4 = AccessPolicy( + FullyQualifiedPolicyId(resource2FullyQualifiedId, AccessPolicyName("policyName4")), + Set.empty, + WorkbenchEmail("policy4@email.com"), + secondResourceType.roles.map(_.roleName), + Set(readAction, writeAction), + Set.empty, + false + ) val resource1 = Resource(resource1FullyQualifiedId.resourceTypeName, resource1FullyQualifiedId.resourceId, Set(sharedAuthDomain.id), Set(policy1, policy2)) val resource2 = - Resource(resource2FullyQualifiedId.resourceTypeName, resource2FullyQualifiedId.resourceId, Set(sharedAuthDomain.id, otherGroup.id), Set(policy3)) + Resource( + resource2FullyQualifiedId.resourceTypeName, + resource2FullyQualifiedId.resourceId, + Set(sharedAuthDomain.id, otherGroup.id), + Set(policy3, policy4) + ) dao.createResource(resource1, samRequestContext).unsafeRunSync() dao.createResource(resource2, samRequestContext).unsafeRunSync() dirDao.updateSynchronizedDateAndVersion(policy1, samRequestContext).unsafeRunSync() dirDao.updateSynchronizedDateAndVersion(policy3, samRequestContext).unsafeRunSync() + dirDao.updateSynchronizedDateAndVersion(policy4, samRequestContext).unsafeRunSync() - dao.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(sharedAuthDomain.id, samRequestContext).unsafeRunSync() should contain theSameElementsAs Set( + // finds all synced policies when no members specified + dao + .listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(sharedAuthDomain.id, Set.empty, samRequestContext) + .unsafeRunSync() should contain theSameElementsAs Set( policy1.id, + policy3.id, + policy4.id + ) + // finds only relevant synced policies when user and group specified + dao + .listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(sharedAuthDomain.id, Set(user.id, group.id), samRequestContext) + .unsafeRunSync() should contain theSameElementsAs Set( + policy1.id, + policy3.id + ) + // finds only relevant synced policies when user specified + dao + .listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(sharedAuthDomain.id, Set(user.id), samRequestContext) + .unsafeRunSync() should contain theSameElementsAs Set( + policy1.id + ) + // finds only relevant synced policies when group specified + dao + .listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(sharedAuthDomain.id, Set(group.id), samRequestContext) + .unsafeRunSync() should contain theSameElementsAs Set( + policy3.id + ) + // finds only relevant synced policies when user in group specified + dao + .listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(sharedAuthDomain.id, Set(groupUser.id), samRequestContext) + .unsafeRunSync() should contain theSameElementsAs Set( policy3.id ) } @@ -728,13 +780,15 @@ class PostgresAccessPolicyDAOSpec extends AnyFreeSpec with Matchers with BeforeA val group = BasicWorkbenchGroup(WorkbenchGroupName("boringGroup"), Set.empty, WorkbenchEmail("notAnAuthDomain@insecure.biz")) dirDao.createGroup(group, samRequestContext = samRequestContext).unsafeRunSync() - dao.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(group.id, samRequestContext).unsafeRunSync() shouldEqual Set.empty + dao.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(group.id, Set.empty, samRequestContext).unsafeRunSync() shouldEqual Set.empty } "returns an empty list if group doesn't exist" in { assume(databaseEnabled, databaseEnabledClue) - dao.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName("notEvenReal"), samRequestContext).unsafeRunSync() shouldEqual Set.empty + dao + .listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName("notEvenReal"), Set.empty, samRequestContext) + .unsafeRunSync() shouldEqual Set.empty } } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensionSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensionSpec.scala index 94130b099..9b53ae49f 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensionSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/google/GoogleExtensionSpec.scala @@ -1082,10 +1082,10 @@ class GoogleExtensionSpec(_system: ActorSystem) when(mockGoogleGroupSyncPubSubDAO.publishMessages(any[String], any[Seq[MessageRequest]])).thenReturn(Future.successful(())) // mock responses for onManagedGroupUpdate - when(mockAccessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName(managedGroupId), samRequestContext)) + when(mockAccessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName(managedGroupId), Set.empty, samRequestContext)) .thenReturn(IO.pure(Set(ownerRPN, readerRPN))) - runAndWait(googleExtensions.onGroupUpdate(Seq(managedGroupRPN), samRequestContext)) + runAndWait(googleExtensions.onGroupUpdate(Seq(managedGroupRPN), Set.empty, samRequestContext)) verify(mockGoogleGroupSyncPubSubDAO, times(1)).publishMessages(any[String], any[Seq[MessageRequest]]) } @@ -1136,10 +1136,10 @@ class GoogleExtensionSpec(_system: ActorSystem) .thenReturn(IO.pure(Set(managedGroupRPN).asInstanceOf[Set[WorkbenchGroupIdentity]])) // mock responses for onManagedGroupUpdate - when(mockAccessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName(managedGroupId), samRequestContext)) + when(mockAccessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName(managedGroupId), Set.empty, samRequestContext)) .thenReturn(IO.pure(Set(ownerRPN, readerRPN))) - runAndWait(googleExtensions.onGroupUpdate(Seq(WorkbenchGroupName(subGroupId)), samRequestContext)) + runAndWait(googleExtensions.onGroupUpdate(Seq(WorkbenchGroupName(subGroupId)), Set.empty, samRequestContext)) verify(mockGoogleGroupSyncPubSubDAO, times(1)).publishMessages(any[String], any[Seq[MessageRequest]]) } @@ -1190,10 +1190,10 @@ class GoogleExtensionSpec(_system: ActorSystem) .thenReturn(IO.pure(Set(managedGroupRPN).asInstanceOf[Set[WorkbenchGroupIdentity]])) // mock responses for onManagedGroupUpdate - when(mockAccessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName(managedGroupId), samRequestContext)) + when(mockAccessPolicyDAO.listSyncedAccessPolicyIdsOnResourcesConstrainedByGroup(WorkbenchGroupName(managedGroupId), Set.empty, samRequestContext)) .thenReturn(IO.pure(Set(ownerRPN, readerRPN))) - runAndWait(googleExtensions.onGroupUpdate(Seq(WorkbenchGroupName(subGroupId)), samRequestContext)) + runAndWait(googleExtensions.onGroupUpdate(Seq(WorkbenchGroupName(subGroupId)), Set.empty, samRequestContext)) verify(mockGoogleGroupSyncPubSubDAO, times(1)).publishMessages(any[String], any[Seq[MessageRequest]]) } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/MockCloudExtensionsBuilder.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/MockCloudExtensionsBuilder.scala index dbba80809..f8ca060d4 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/MockCloudExtensionsBuilder.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/MockCloudExtensionsBuilder.scala @@ -2,7 +2,7 @@ package org.broadinstitute.dsde.workbench.sam.service import cats.effect.IO import org.broadinstitute.dsde.workbench.model.google.GoogleProject -import org.broadinstitute.dsde.workbench.model.{WorkbenchEmail, WorkbenchGroup, WorkbenchGroupIdentity, WorkbenchUserId} +import org.broadinstitute.dsde.workbench.model.{WorkbenchEmail, WorkbenchGroup, WorkbenchGroupIdentity, WorkbenchSubject, WorkbenchUserId} import org.broadinstitute.dsde.workbench.sam.dataAccess.DirectoryDAO import org.broadinstitute.dsde.workbench.sam.model.api.SamUser import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext @@ -22,7 +22,7 @@ case class MockCloudExtensionsBuilder(allUsersGroup: WorkbenchGroup) extends Idi mockedCloudExtensions.getUserStatus(any[SamUser]) returns IO(false) mockedCloudExtensions.onUserEnable(any[SamUser], any[SamRequestContext]) returns IO.unit mockedCloudExtensions.onUserDisable(any[SamUser], any[SamRequestContext]) returns IO.unit - mockedCloudExtensions.onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[SamRequestContext]) returns IO.unit + mockedCloudExtensions.onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[Set[WorkbenchSubject]], any[SamRequestContext]) returns IO.unit mockedCloudExtensions.onUserDelete(any[WorkbenchUserId], any[SamRequestContext]) returns IO.unit mockedCloudExtensions.allSubSystems returns Set.empty mockedCloudExtensions.checkStatus returns Map.empty diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala index 8eadd0331..ef8f251ec 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala @@ -1264,19 +1264,27 @@ class ResourceServiceSpec val policyId = FullyQualifiedPolicyId(FullyQualifiedResourceId(defaultResourceType.name, ResourceId("testR")), AccessPolicyName("testA")) val accessPolicy = AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false) - // setup existing policy with no members - when(mockAccessPolicyDAO.listAccessPolicies(ArgumentMatchers.eq(policyId.resource), any[SamRequestContext])).thenReturn(IO.pure(LazyList(accessPolicy))) - - // function calls that should pass but what they return does not matter - when(mockAccessPolicyDAO.overwritePolicy(any[AccessPolicy], any[SamRequestContext])).thenReturn(IO.pure(accessPolicy)) - when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext])).thenReturn(IO.unit) - // overwrite policy with members in memberPolicy val memberPolicy = FullyQualifiedPolicyId(FullyQualifiedResourceId(defaultResourceType.name, ResourceId("testMemberR")), AccessPolicyName("testB")) val memberPolicyIdSet = Set( PolicyIdentifiers(memberPolicy.accessPolicyName, memberPolicy.resource.resourceTypeName, memberPolicy.resource.resourceId) ) + + val updatedPolicy = accessPolicy.copy(members = memberPolicyIdSet.map(_.toFullyQualifiedPolicyId)) + when(mockAccessPolicyDAO.listAccessPolicies(ArgumentMatchers.eq(policyId.resource), any[SamRequestContext])).thenReturn( + IO.pure(LazyList(accessPolicy)), // first call with empty membership + IO.pure(LazyList(updatedPolicy)) // second call with updated membership + ) + when(mockAccessPolicyDAO.overwritePolicy(any[AccessPolicy], any[SamRequestContext])).thenReturn(IO.pure(updatedPolicy)) + when( + mockCloudExtensions.onGroupUpdate( + ArgumentMatchers.eq(Seq(policyId)), + ArgumentMatchers.eq(memberPolicyIdSet.map(_.toFullyQualifiedPolicyId)), + any[SamRequestContext] + ) + ).thenReturn(IO.unit) + runAndWait( resourceService.overwritePolicy( defaultResourceType, @@ -1287,7 +1295,8 @@ class ResourceServiceSpec ) ) - verify(mockCloudExtensions, Mockito.timeout(500)).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.timeout(500)) + .onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(memberPolicyIdSet.map(_.toFullyQualifiedPolicyId)), any[SamRequestContext]) } it should "call CloudExtensions.onGroupUpdate when members change" in { @@ -1308,13 +1317,14 @@ class ResourceServiceSpec val member = WorkbenchUserId("testU") val accessPolicy = AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false) - // setup existing policy with a member + // setup existing policy with a member and second call without that member + val originalAccessPolicy = AccessPolicy.members.set(Set(member))(accessPolicy) when(mockAccessPolicyDAO.listAccessPolicies(ArgumentMatchers.eq(policyId.resource), any[SamRequestContext])) - .thenReturn(IO.pure(LazyList(AccessPolicy.members.set(Set(member))(accessPolicy)))) + .thenReturn(IO.pure(LazyList(originalAccessPolicy)), IO.pure(LazyList(accessPolicy))) // function calls that should pass but what they return does not matter when(mockAccessPolicyDAO.overwritePolicy(ArgumentMatchers.eq(accessPolicy), any[SamRequestContext])).thenReturn(IO.pure(accessPolicy)) - when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext])).thenReturn(IO.unit) + when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext])).thenReturn(IO.unit) // overwrite policy with no members runAndWait( @@ -1327,7 +1337,8 @@ class ResourceServiceSpec ) ) - verify(mockCloudExtensions, Mockito.timeout(500)).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.timeout(500)) + .onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext]) } it should "not call CloudExtensions.onGroupUpdate when members don't change" in { @@ -1361,7 +1372,7 @@ class ResourceServiceSpec ) ) - verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[Set[WorkbenchSubject]], any[SamRequestContext]) } "overwriteAdminPolicy" should "succeed with a valid request" in { @@ -1495,18 +1506,22 @@ class ResourceServiceSpec val policyId = FullyQualifiedPolicyId(FullyQualifiedResourceId(defaultResourceType.name, ResourceId("testR")), AccessPolicyName("testA")) val member = WorkbenchUserId("testU") - // setup existing policy with a member + // setup existing policy with a member and second call without that member when(mockAccessPolicyDAO.listAccessPolicies(ArgumentMatchers.eq(policyId.resource), any[SamRequestContext])) - .thenReturn(IO.pure(LazyList(AccessPolicy(policyId, Set(member), WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false)))) + .thenReturn( + IO.pure(LazyList(AccessPolicy(policyId, Set(member), WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false))), + IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false))) + ) // function calls that should pass but what they return does not matter when(mockAccessPolicyDAO.overwritePolicyMembers(ArgumentMatchers.eq(policyId), ArgumentMatchers.eq(Set.empty), any[SamRequestContext])).thenReturn(IO.unit) - when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext])).thenReturn(IO.unit) + when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext])).thenReturn(IO.unit) // overwrite policy members with empty set runAndWait(resourceService.overwritePolicyMembers(policyId, Set.empty, samRequestContext)) - verify(mockCloudExtensions, Mockito.timeout(500)).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.timeout(500)) + .onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext]) } it should "not call CloudExtensions.onGroupUpdate when members don't change" in { @@ -1532,7 +1547,7 @@ class ResourceServiceSpec // overwrite policy members with empty set runAndWait(resourceService.overwritePolicyMembers(policyId, Set.empty, samRequestContext)) - verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[Set[WorkbenchSubject]], any[SamRequestContext]) } it should "succeed with a regex action" in { @@ -2086,12 +2101,16 @@ class ResourceServiceSpec // return value true at the end indicates group changed when(mockDirectoryDAO.addGroupMember(ArgumentMatchers.eq(policyId), ArgumentMatchers.eq(member), any[SamRequestContext])).thenReturn(IO.pure(true)) - when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext])).thenReturn(IO.unit) + when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext])).thenReturn(IO.unit) when(mockAccessPolicyDAO.listAccessPolicies(ArgumentMatchers.eq(policyId.resource), any[SamRequestContext])) - .thenReturn(IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false)))) + .thenReturn( + IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false))), + IO.pure(LazyList(AccessPolicy(policyId, Set(member), WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false))) + ) runAndWait(resourceService.addSubjectToPolicy(policyId, member, samRequestContext)) - verify(mockCloudExtensions, Mockito.timeout(500)).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.timeout(500)) + .onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext]) } it should "not call CloudExtensions.onGroupUpdate when member added but is already there" in { @@ -2117,7 +2136,7 @@ class ResourceServiceSpec .thenReturn(IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false)))) runAndWait(resourceService.addSubjectToPolicy(policyId, member, samRequestContext)) - verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[Set[WorkbenchSubject]], any[SamRequestContext]) } "removeSubjectFromPolicy" should "call CloudExtensions.onGroupUpdate when member removed" in { @@ -2139,12 +2158,16 @@ class ResourceServiceSpec // return value true at the end indicates group changed when(mockDirectoryDAO.removeGroupMember(ArgumentMatchers.eq(policyId), ArgumentMatchers.eq(member), any[SamRequestContext])).thenReturn(IO.pure(true)) - when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext])).thenReturn(IO.unit) + when(mockCloudExtensions.onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext])).thenReturn(IO.unit) when(mockAccessPolicyDAO.listAccessPolicies(ArgumentMatchers.eq(policyId.resource), any[SamRequestContext])) - .thenReturn(IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false)))) + .thenReturn( + IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false))), + IO.pure(LazyList(AccessPolicy(policyId, Set(member), WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false))) + ) runAndWait(resourceService.removeSubjectFromPolicy(policyId, member, samRequestContext)) - verify(mockCloudExtensions, Mockito.timeout(1000)).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.timeout(1000)) + .onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), ArgumentMatchers.eq(Set(member)), any[SamRequestContext]) } it should "not call CloudExtensions.onGroupUpdate when member removed but wasn't there to start with" in { @@ -2170,7 +2193,7 @@ class ResourceServiceSpec .thenReturn(IO.pure(LazyList(AccessPolicy(policyId, Set.empty, WorkbenchEmail(""), Set.empty, Set.empty, Set.empty, false)))) runAndWait(resourceService.removeSubjectFromPolicy(policyId, member, samRequestContext)) - verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[SamRequestContext]) + verify(mockCloudExtensions, Mockito.after(500).never).onGroupUpdate(ArgumentMatchers.eq(Seq(policyId)), any[Set[WorkbenchSubject]], any[SamRequestContext]) } "initResourceTypes" should "do the happy path" in { diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/StatefulMockCloudExtensionsBuilder.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/StatefulMockCloudExtensionsBuilder.scala index 8cbe8d59b..1819cd459 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/StatefulMockCloudExtensionsBuilder.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/StatefulMockCloudExtensionsBuilder.scala @@ -2,7 +2,7 @@ package org.broadinstitute.dsde.workbench.sam.service import cats.effect.IO import cats.effect.unsafe.implicits.global -import org.broadinstitute.dsde.workbench.model.{WorkbenchGroup, WorkbenchGroupIdentity, WorkbenchUserId} +import org.broadinstitute.dsde.workbench.model.{WorkbenchGroup, WorkbenchGroupIdentity, WorkbenchSubject, WorkbenchUserId} import org.broadinstitute.dsde.workbench.sam.dataAccess.DirectoryDAO import org.broadinstitute.dsde.workbench.sam.model.api.SamUser import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext @@ -67,7 +67,7 @@ case class StatefulMockCloudExtensionsBuilder(directoryDAO: DirectoryDAO) extend lenient() .doReturn(IO.unit) .when(mockedCloudExtensions) - .onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[SamRequestContext]) + .onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[Set[WorkbenchSubject]], any[SamRequestContext]) lenient() .doReturn(IO.unit) diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala index d00f6491e..33f62462b 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala @@ -139,7 +139,7 @@ class OldUserServiceMockSpec(_system: ActorSystem) when(googleExtensions.getUserStatus(any[SamUser])).thenReturn(IO(true)) when(googleExtensions.onUserDisable(any[SamUser], any[SamRequestContext])).thenReturn(IO.unit) when(googleExtensions.onUserEnable(any[SamUser], any[SamRequestContext])).thenReturn(IO.unit) - when(googleExtensions.onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[SamRequestContext])).thenReturn(IO.unit) + when(googleExtensions.onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[Set[WorkbenchSubject]], any[SamRequestContext])).thenReturn(IO.unit) mockTosService = mock[TosService](RETURNS_SMART_NULLS) when(mockTosService.getTermsOfServiceComplianceStatus(any[SamUser], any[SamRequestContext])) @@ -324,7 +324,7 @@ class OldUserServiceSpec(_system: ActorSystem) when(googleExtensions.getUserStatus(any[SamUser])).thenReturn(IO.pure(true)) when(googleExtensions.onUserDisable(any[SamUser], any[SamRequestContext])).thenReturn(IO.unit) when(googleExtensions.onUserEnable(any[SamUser], any[SamRequestContext])).thenReturn(IO.unit) - when(googleExtensions.onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[SamRequestContext])).thenReturn(IO.unit) + when(googleExtensions.onGroupUpdate(any[Seq[WorkbenchGroupIdentity]], any[Set[WorkbenchSubject]], any[SamRequestContext])).thenReturn(IO.unit) tos = new TosService(googleExtensions, dirDAO, TestSupport.tosConfig) service = new UserService(dirDAO, googleExtensions, Seq(blockedDomain), tos)