diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala index 0ad88610c..0b3e63e9c 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala @@ -470,46 +470,41 @@ class PostgresDirectoryDAO(protected val writeDbRef: DbReference, protected val override def updateUserEmail(userId: WorkbenchUserId, email: WorkbenchEmail, samRequestContext: SamRequestContext): IO[Unit] = IO.unit override def updateUser(samUser: SamUser, userUpdate: AdminUpdateUserRequest, samRequestContext: SamRequestContext): IO[Option[SamUser]] = + // NOTE updating emails and 'enabled' status is currently not supported by this method serializableWriteTransaction("updateUser", samRequestContext) { implicit session => val u = UserTable.column - // NOTE updating emails and 'enabled' status is currently not supported by this method - val setColumnsClause = userUpdate match { - case AdminUpdateUserRequest(None, None) => throw new WorkbenchException("Cannot update user with no values.") - case AdminUpdateUserRequest(Some(newAzureB2CId), None) => - samsqls"(${u.azureB2cId}, ${u.updatedAt})" - case AdminUpdateUserRequest(None, Some(newGoogleSubjectId)) => - samsqls"(${u.googleSubjectId}, ${u.updatedAt})" - case AdminUpdateUserRequest(Some(newGoogleSubjectId), Some(newAzureB2CId)) => - samsqls"(${u.googleSubjectId}, ${u.azureB2cId}, ${u.updatedAt})" + if (userUpdate.googleSubjectId.isEmpty && userUpdate.azureB2CId.isEmpty) { + throw new WorkbenchException("Cannot update user with no values.") } - var updatedUser = - samUser.copy( - googleSubjectId = userUpdate.googleSubjectId.orElse(samUser.googleSubjectId), - azureB2CId = userUpdate.azureB2CId.orElse(samUser.azureB2CId), - updatedAt = Instant.now() - ) - val setColumnValuesClause = userUpdate match { - case AdminUpdateUserRequest(None, None) => throw new WorkbenchException("Cannot update user with no values.") - case AdminUpdateUserRequest(Some(AzureB2CId("null")), None) => - updatedUser = updatedUser.copy(azureB2CId = None) - // string interpolation adds quotes around null so we have to special case it here - samsqls"(null, ${Instant.now()})" - case AdminUpdateUserRequest(None, Some(GoogleSubjectId("null"))) => - updatedUser = updatedUser.copy(googleSubjectId = None) - // string interpolation adds quotes around null so we have to special case it here - samsqls"(null, ${Instant.now()})" - case AdminUpdateUserRequest(Some(newGoogleSubjectId), None) => - samsqls"($newGoogleSubjectId, ${Instant.now()})" - case AdminUpdateUserRequest(None, Some(newAzureB2CId)) => - samsqls"($newAzureB2CId, ${Instant.now()})" - case AdminUpdateUserRequest(Some(newGoogleSubjectId), Some(newAzureB2CId)) => - samsqls"($newGoogleSubjectId, $newAzureB2CId, ${Instant.now()})" + val (updateGoogleColumn, updateGoogleValue, returnGoogleValue) = userUpdate.googleSubjectId match { + case None => (None, None, samUser.googleSubjectId) + case Some(GoogleSubjectId("null")) => + (Some(samsqls"${u.googleSubjectId}"), Some(samsqls"null"), None) + case Some(newGoogleSubjectId: GoogleSubjectId) => + (Some(samsqls"${u.googleSubjectId}"), Some(samsqls"$newGoogleSubjectId"), Some(newGoogleSubjectId)) + } + + val (updateAzureB2CColumn, updateAzureB2CValue, returnAzureB2CValue) = userUpdate.azureB2CId match { + case None => (None, None, samUser.azureB2CId) + case Some(AzureB2CId("null")) => + (Some(samsqls"${u.azureB2cId}"), Some(samsqls"null"), None) + case Some(newAzureB2CId: AzureB2CId) => + (Some(samsqls"${u.azureB2cId}"), Some(samsqls"$newAzureB2CId"), Some(newAzureB2CId)) } + // This is a little hacky, but is needed because SQLSyntax's `flatten`, `substring`, and other string-manipulation + // methods transform the SQLSyntax into a String. Thankfully, since we always have an `updatedAt` value, + // we can use it as a base for foldLeft, and then concatenate the rest of the existing values to it + // within the `samsqls` interpolation, preserving the SQLSyntax functionality. + val updateColumns = List(updateGoogleColumn, updateAzureB2CColumn).flatten + .foldLeft(samsqls"${u.updatedAt}")((acc, col) => samsqls"$acc, $col") + val updateValues = List(updateGoogleValue, updateAzureB2CValue).flatten + .foldLeft(samsqls"${Instant.now()}")((acc, col) => samsqls"$acc, $col") + val results = samsql"""update ${UserTable.table} - set ${setColumnsClause} = ${setColumnValuesClause} + set ($updateColumns) = ($updateValues) where ${u.id} = ${samUser.id}""" .update() .apply() @@ -517,7 +512,13 @@ class PostgresDirectoryDAO(protected val writeDbRef: DbReference, protected val if (results != 1) { None } else { - Option(updatedUser) + Option( + samUser.copy( + googleSubjectId = returnGoogleValue, + azureB2CId = returnAzureB2CValue, + updatedAt = Instant.now() + ) + ) } } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala index 54eb15711..182f22dbc 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala @@ -1343,30 +1343,31 @@ class PostgresDirectoryDAOSpec extends RetryableAnyFreeSpec with Matchers with B "update the googleSubjectId for a user" in { assume(databaseEnabled, databaseEnabledClue) val newGoogleSubjectId = GoogleSubjectId("newGoogleSubjectId") - dao.createUser(defaultUser.copy(googleSubjectId = None), samRequestContext).unsafeRunSync() + val user = Generator.genWorkbenchUserAzure.sample.get + dao.createUser(user, samRequestContext).unsafeRunSync() - dao.loadUser(defaultUser.id, samRequestContext).unsafeRunSync().flatMap(_.googleSubjectId) shouldBe None - dao.updateUser(defaultUser, AdminUpdateUserRequest(None, Option(newGoogleSubjectId)), samRequestContext).unsafeRunSync() + dao.loadUser(user.id, samRequestContext).unsafeRunSync().flatMap(_.googleSubjectId) shouldBe None + dao.updateUser(user, AdminUpdateUserRequest(None, Option(newGoogleSubjectId)), samRequestContext).unsafeRunSync() - dao.loadUser(defaultUser.id, samRequestContext).unsafeRunSync().flatMap(_.googleSubjectId) shouldBe Option(newGoogleSubjectId) + dao.loadUser(user.id, samRequestContext).unsafeRunSync().flatMap(_.googleSubjectId) shouldBe Option(newGoogleSubjectId) } "update the azureB2CId for a user" in { assume(databaseEnabled, databaseEnabledClue) val newB2CId = AzureB2CId(UUID.randomUUID().toString) - dao.createUser(defaultUser.copy(azureB2CId = None), samRequestContext).unsafeRunSync() + val user = Generator.genWorkbenchUserGoogle.sample.get + dao.createUser(user, samRequestContext).unsafeRunSync() - dao.loadUser(defaultUser.id, samRequestContext).unsafeRunSync().flatMap(_.azureB2CId) shouldBe None - dao.updateUser(defaultUser, AdminUpdateUserRequest(Option(newB2CId), None), samRequestContext).unsafeRunSync() + dao.loadUser(user.id, samRequestContext).unsafeRunSync().flatMap(_.azureB2CId) shouldBe None + dao.updateUser(user, AdminUpdateUserRequest(Option(newB2CId), None), samRequestContext).unsafeRunSync() - dao.loadUser(defaultUser.id, samRequestContext).unsafeRunSync().flatMap(_.azureB2CId) shouldBe Option(newB2CId) + dao.loadUser(user.id, samRequestContext).unsafeRunSync().flatMap(_.azureB2CId) shouldBe Option(newB2CId) } "sets the updatedAt datetime to the current datetime" in { assume(databaseEnabled, databaseEnabledClue) // Arrange val user = Generator.genWorkbenchUserGoogle.sample.get.copy( - googleSubjectId = None, updatedAt = Instant.parse("2020-02-02T20:20:20Z") ) dao.createUser(user, samRequestContext).unsafeRunSync() @@ -1379,6 +1380,20 @@ class PostgresDirectoryDAOSpec extends RetryableAnyFreeSpec with Matchers with B val loadedUser = dao.loadUser(user.id, samRequestContext).unsafeRunSync() loadedUser.value.updatedAt should beAround(Instant.now()) } + + "will update the googleSubjectId and azureB2CId for a user" in { + assume(databaseEnabled, databaseEnabledClue) + val newGoogleSubjectId = GoogleSubjectId("234567890123456789012") + val newB2CId = AzureB2CId(UUID.randomUUID().toString) + val user = Generator.genWorkbenchUserBoth.sample.get + dao.createUser(user, samRequestContext).unsafeRunSync() + + dao.updateUser(user, AdminUpdateUserRequest(Option(newB2CId), Option(newGoogleSubjectId)), samRequestContext).unsafeRunSync() + + val updatedUser = dao.loadUser(user.id, samRequestContext).unsafeRunSync() + updatedUser.flatMap(_.googleSubjectId) shouldBe Option(newGoogleSubjectId) + updatedUser.flatMap(_.azureB2CId) shouldBe Option(newB2CId) + } } "setGoogleSubjectId" - {