diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala index 4532e9ecd..5664bd166 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala @@ -76,8 +76,13 @@ trait DirectoryDAO { def acceptTermsOfService(userId: WorkbenchUserId, tosVersion: String, samRequestContext: SamRequestContext): IO[Boolean] def rejectTermsOfService(userId: WorkbenchUserId, tosVersion: String, samRequestContext: SamRequestContext): IO[Boolean] - def getUserTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[Option[SamUserTos]] - def getUserTermsOfServiceVersion(userId: WorkbenchUserId, tosVersion: Option[String], samRequestContext: SamRequestContext): IO[Option[SamUserTos]] + def getUserTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext, action: Option[String] = None): IO[Option[SamUserTos]] + def getUserTermsOfServiceVersion( + userId: WorkbenchUserId, + tosVersion: Option[String], + samRequestContext: SamRequestContext, + action: Option[String] = None + ): IO[Option[SamUserTos]] def getUserTermsOfServiceHistory(userId: WorkbenchUserId, samRequestContext: SamRequestContext, limit: Integer): IO[List[SamUserTos]] def createPetManagedIdentity(petManagedIdentity: PetManagedIdentity, samRequestContext: SamRequestContext): IO[PetManagedIdentity] def loadPetManagedIdentity(petManagedIdentityId: PetManagedIdentityId, samRequestContext: SamRequestContext): IO[Option[PetManagedIdentity]] diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala index fd9ea6d00..502cd6e6f 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala @@ -649,20 +649,32 @@ class PostgresDirectoryDAO(protected val writeDbRef: DbReference, protected val } // When no tosVersion is specified, return the latest TosRecord for the user - override def getUserTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[Option[SamUserTos]] = - getUserTermsOfServiceVersion(userId, None, samRequestContext) - - override def getUserTermsOfServiceVersion(userId: WorkbenchUserId, tosVersion: Option[String], samRequestContext: SamRequestContext): IO[Option[SamUserTos]] = + override def getUserTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext, action: Option[String] = None): IO[Option[SamUserTos]] = + getUserTermsOfServiceVersion(userId, None, samRequestContext, action) + + override def getUserTermsOfServiceVersion( + userId: WorkbenchUserId, + tosVersion: Option[String], + samRequestContext: SamRequestContext, + action: Option[String] = None + ): IO[Option[SamUserTos]] = readOnlyTransaction("getUserTermsOfService", samRequestContext) { implicit session => val tosTable = TosTable.syntax val column = TosTable.column val versionConstraint = if (tosVersion.isDefined) samsqls"and ${column.version} = ${tosVersion.get}" else samsqls"" + val actionConstraint = action match { + case Some(a) => samsqls"and ${column.action} = ${a}" + case None => samsqls"" + } + val loadUserTosQuery = samsql"""select ${tosTable.resultAll} from ${TosTable as tosTable} - where ${column.samUserId} = $userId $versionConstraint + where ${column.samUserId} = $userId + $versionConstraint + $actionConstraint order by ${column.createdAt} desc limit 1""" diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/TosService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/TosService.scala index b8fd53c8d..63171de0d 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/TosService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/TosService.scala @@ -107,7 +107,7 @@ class TosService( } private def ensureLatestTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[SamUserTos] = for { - maybeTermsOfServiceRecord <- directoryDao.getUserTermsOfService(userId, samRequestContext) + maybeTermsOfServiceRecord <- directoryDao.getUserTermsOfService(userId, samRequestContext, Option(TosTable.ACCEPT)) latestUserTermsOfService <- maybeTermsOfServiceRecord .map(IO.pure) .getOrElse( diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala index 82e951611..e94a28668 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala @@ -330,14 +330,22 @@ class MockDirectoryDAO(val groups: mutable.Map[WorkbenchGroupIdentity, Workbench true } - override def getUserTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[Option[SamUserTos]] = + override def getUserTermsOfService(userId: WorkbenchUserId, samRequestContext: SamRequestContext, action: Option[String]): IO[Option[SamUserTos]] = loadUser(userId, samRequestContext).map { case None => None case Some(_) => - userTermsOfService.get(userId) + if (action.isDefined) { + userTermsOfService.get(userId).filter(_.action == action.get) + } else + userTermsOfService.get(userId) } - override def getUserTermsOfServiceVersion(userId: WorkbenchUserId, tosVersion: Option[String], samRequestContext: SamRequestContext): IO[Option[SamUserTos]] = + override def getUserTermsOfServiceVersion( + userId: WorkbenchUserId, + tosVersion: Option[String], + samRequestContext: SamRequestContext, + action: Option[String] + ): IO[Option[SamUserTos]] = loadUser(userId, samRequestContext).map { case None => None case Some(_) => diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDaoBuilder.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDaoBuilder.scala index 6a622e5a3..ad8a09e40 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDaoBuilder.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDaoBuilder.scala @@ -116,8 +116,10 @@ case class MockDirectoryDaoBuilder() extends IdiomaticMockito { def withAcceptedTermsOfServiceForUser(samUser: SamUser, tosVersion: String): MockDirectoryDaoBuilder = { makeUserExist(samUser) val samUserTos = SamUserTos(samUser.id, tosVersion, TosTable.ACCEPT, Instant.now) - mockedDirectoryDAO.getUserTermsOfService(eqTo(samUser.id), any[SamRequestContext]) returns IO(Option(samUserTos)) - mockedDirectoryDAO.getUserTermsOfServiceVersion(eqTo(samUser.id), eqTo(Some(tosVersion)), any[SamRequestContext]) returns IO(Option(samUserTos)) + mockedDirectoryDAO.getUserTermsOfService(eqTo(samUser.id), any[SamRequestContext], any[Option[String]]) returns IO(Option(samUserTos)) + mockedDirectoryDAO.getUserTermsOfServiceVersion(eqTo(samUser.id), eqTo(Some(tosVersion)), any[SamRequestContext], any[Option[String]]) returns IO( + Option(samUserTos) + ) this } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala index 27624ef36..5c5933625 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala @@ -1664,6 +1664,18 @@ class PostgresDirectoryDAOSpec extends RetryableAnyFreeSpec with Matchers with B val userTos = dao.getUserTermsOfService(user.id, samRequestContext).unsafeRunSync() userTos should be(None) } + "returns acceptances" in { + assume(databaseEnabled, databaseEnabledClue) + val user = Generator.genWorkbenchUserGoogle.sample.get + dao.createUser(user, samRequestContext).unsafeRunSync() + + dao.acceptTermsOfService(user.id, tosConfig.version, samRequestContext).unsafeRunSync() shouldBe true + dao.rejectTermsOfService(user.id, tosConfig.version, samRequestContext).unsafeRunSync() shouldBe true + + // Assert + val userTos = dao.getUserTermsOfService(user.id, samRequestContext, action = Option(TosTable.ACCEPT)).unsafeRunSync() + userTos should be(Some(SamUserTos(user.id, tosConfig.version, TosTable.ACCEPT, Instant.now()))) + } } "checkStatus" - {