diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml index 42b9e5363..8a91e5cdf 100644 --- a/.github/workflows/tag.yml +++ b/.github/workflows/tag.yml @@ -13,11 +13,6 @@ on: default: false required: false type: string - print-tag: - description: "Echo generated tag to console" - default: "true" - required: false - type: string release-branches: description: "Default branch (main, develop, etc)" default: 'main' @@ -30,6 +25,9 @@ on: new-tag: description: "The value of the newly created tag" value: ${{ jobs.tag-job.outputs.new-tag }} + app-version: + description: "The app version" + value: ${{ jobs.tag-job.outputs.app-version }} secrets: BROADBOT_TOKEN: required: true @@ -44,6 +42,7 @@ jobs: outputs: tag: ${{ steps.tag.outputs.tag }} new-tag: ${{ steps.tag.outputs.new_tag }} + app-version: ${{ steps.output-version.outputs.app-version }} steps: - name: Checkout current code uses: actions/checkout@v3 @@ -60,7 +59,14 @@ jobs: DRY_RUN: ${{ inputs.dry-run }} RELEASE_BRANCHES: ${{ inputs.release-branches }} WITH_V: true - - name: Echo generated tag to console - if: ${{ inputs.print-tag == 'true' }} + - name: Output app version + id: output-version run: | - echo "Newly created version tag: '${{ steps.tag.outputs.new_tag }}'" + # See https://broadworkbench.atlassian.net/browse/QA-2282 for context + if [[ -z "${{ steps.tag.outputs.new_tag }}" ]]; then + echo "App version tag for this commit has already been dispatched: '${{ steps.tag.outputs.tag }}'" + echo "app-version=${{ steps.tag.outputs.tag }}" >> $GITHUB_OUTPUT + else + echo "New app version tag: '${{ steps.tag.outputs.new_tag }}'" + echo "app-version=${{ steps.tag.outputs.new_tag }}" >> $GITHUB_OUTPUT + fi diff --git a/.github/workflows/verify_consumer_pacts.yml b/.github/workflows/verify_consumer_pacts.yml index c7c54b6ee..29935e8f3 100644 --- a/.github/workflows/verify_consumer_pacts.yml +++ b/.github/workflows/verify_consumer_pacts.yml @@ -151,7 +151,7 @@ jobs: # for publishing the results of provider verification. if [[ -z "${{ inputs.pb-event-type }}" ]]; then echo "PROVIDER_BRANCH=${{ env.CURRENT_BRANCH }}" >> $GITHUB_ENV - echo "PROVIDER_VERSION=${{ needs.regulated-tag-job.outputs.new-tag }}" >> $GITHUB_ENV + echo "PROVIDER_VERSION=${{ needs.regulated-tag-job.outputs.app-version }}" >> $GITHUB_ENV else echo "PROVIDER_VERSION=${{ env.PROVIDER_TAG }}" >> $GITHUB_ENV fi diff --git a/automation/project/Dependencies.scala b/automation/project/Dependencies.scala index 66f2cb8b4..304c365f3 100644 --- a/automation/project/Dependencies.scala +++ b/automation/project/Dependencies.scala @@ -7,9 +7,9 @@ object Dependencies { val akkaV = "2.6.19" val akkaHttpV = "10.2.2" - val workbenchLibV = "a0519cb" + val workbenchLibV = "d16cba9" val workbenchGoogleV = s"0.30-$workbenchLibV" - val workbenchGoogle2V = s"0.34-$workbenchLibV" + val workbenchGoogle2V = s"0.36-$workbenchLibV" val workbenchServiceTestV = "2.0-5863cbd" val excludeWorkbenchModel = ExclusionRule(organization = "org.broadinstitute.dsde.workbench", name = "workbench-model_" + scalaV) diff --git a/pact4s/src/test/resources/reference.conf b/pact4s/src/test/resources/reference.conf index 69bd71152..01b10882c 100644 --- a/pact4s/src/test/resources/reference.conf +++ b/pact4s/src/test/resources/reference.conf @@ -119,8 +119,6 @@ testStuff = { oidc { authorityEndpoint = "https://accounts.google.com" oidcClientId = "some-client" - oidcClientSecret = "some-secret" - legacyGoogleClientId = "another-client" } liquibase { diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 3dbe9cd01..c7a786ba2 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -11,14 +11,14 @@ object Dependencies { val postgresDriverVersion = "42.7.2" val sentryVersion = "6.15.0" - val workbenchLibV = "1c0cf92" // If updating this, make sure googleStorageLocal in test dependencies is up-to-date + val workbenchLibV = "d16cba9" // If updating this, make sure googleStorageLocal in test dependencies is up-to-date val workbenchUtilV = s"0.10-$workbenchLibV" val workbenchUtil2V = s"0.9-$workbenchLibV" val workbenchModelV = s"0.19-$workbenchLibV" val workbenchGoogleV = s"0.30-$workbenchLibV" val workbenchGoogle2V = s"0.36-$workbenchLibV" val workbenchNotificationsV = s"0.6-$workbenchLibV" - val workbenchOauth2V = s"0.5-$workbenchLibV" + val workbenchOauth2V = s"0.7-$workbenchLibV" val monocleVersion = "2.0.5" val crlVersion = "1.2.30-SNAPSHOT" val tclVersion = "1.0.5-SNAPSHOT" diff --git a/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changelog.xml b/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changelog.xml index 1d6741951..c9ba57fb0 100644 --- a/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changelog.xml +++ b/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changelog.xml @@ -28,6 +28,6 @@ + - diff --git a/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changesets/20240417_action_managed_identities.xml b/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changesets/20240417_action_managed_identities.xml new file mode 100644 index 000000000..756cdcbd0 --- /dev/null +++ b/src/main/resources/org/broadinstitute/dsde/sam/liquibase/changesets/20240417_action_managed_identities.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf index c78dcee3d..77b9bd2e6 100644 --- a/src/main/resources/reference.conf +++ b/src/main/resources/reference.conf @@ -1550,6 +1550,152 @@ resourceTypes = { allowLeaving = false reuseIds = true } + + resource-access-constraint = { + actionPatterns = { + use = { + description = "use this RAC for a SAM resource" + } + add_lock = { + description = "add a lock to this RAC" + } + delete = { + description = "delete this RAC" + } + read_policies = { + description = "view all policies and policy details for this RAC" + } + "share_policy::owner" = { + description = "change the membership of the owner policy for this RAC" + } + } + ownerRoleName = "owner" + roles = { + owner = { + roleActions = ["use", "add_lock", "delete", "read_policies", "share_policy::owner"] + } + } + allowLeaving = false + reuseIds = true + } + + lock = { + actionPatterns = { + use = { + description = "use this lock for a resource access constraint" + } + delete = { + description = "delete this lock" + } + read_policies = { + description = "view all policies and policy details for this lock" + } + "share_policy::owner" = { + description = "change the membership of the owner policy for this lock" + } + } + ownerRoleName = "owner" + roles = { + owner = { + roleActions = ["use", "delete", "read_policies", "share_policy::owner"] + } + } + allowLeaving = false + reuseIds = true + } + + private_azure_container_registry = { + actionPatterns = { + delete = { + description = "Delete this private acr" + } + read_policies = { + description = "view all policies and policy details for this private acr" + } + identify = { + description = "use the identity that has access to this private acr" + } + "share_policy::admin" = { + description = "change the membership of the admin policy for this private acr" + } + "share_policy::user" = { + description = "change the membership of the user policy for this private acr" + } + } + ownerRoleName = "admin" + roles = { + admin = { + roleActions = ["delete", "read_policies", "use", "share_policy::admin", "share_policy::user", "identify"] + } + user = { + roleActions = ["identify"] + } + } + allowLeaving = false + reuseIds = true + } + + private_azure_storage_account = { + actionPatterns = { + delete = { + description = "Delete this private azure storage account" + } + read_policies = { + description = "view all policies and policy details for this private azure storage account" + } + identify = { + description = "use the identity that has access to this private azure storage account" + } + "share_policy::admin" = { + description = "change the membership of the admin policy for this private azure storage account" + } + "share_policy::user" = { + description = "change the membership of the user policy for this private azure storage account" + } + } + ownerRoleName = "admin" + roles = { + admin = { + roleActions = ["delete", "read_policies", "use", "share_policy::admin", "share_policy::user", "identify"] + } + user = { + roleActions = ["identify"] + } + } + allowLeaving = false + reuseIds = true + } + + azure_managed_identity = { + actionPatterns = { + delete = { + description = "Delete this azure managed identity" + } + read_policies = { + description = "view all policies and policy details for this azure managed identity" + } + identify = { + description = "use the identity that has access to this azure managed identity" + } + "share_policy::admin" = { + description = "change the membership of the admin policy for this azure managed identity" + } + "share_policy::user" = { + description = "change the membership of the user policy for this azure managed identity" + } + } + ownerRoleName = "admin" + roles = { + admin = { + roleActions = ["delete", "read_policies", "use", "share_policy::admin", "share_policy::user", "identify"] + } + user = { + roleActions = ["identify"] + } + } + allowLeaving = false + reuseIds = true + } } diff --git a/src/main/resources/sam.conf b/src/main/resources/sam.conf index d4104ce91..d9bd1d758 100644 --- a/src/main/resources/sam.conf +++ b/src/main/resources/sam.conf @@ -37,8 +37,6 @@ termsOfService { oidc { authorityEndpoint = ${?OIDC_AUTHORITY_ENDPOINT} oidcClientId = ${?OIDC_CLIENT_ID} - oidcClientSecret = ${?OIDC_CLIENT_SECRET} - legacyGoogleClientId = ${?LEGACY_GOOGLE_CLIENT_ID} } schemaLock { diff --git a/src/main/resources/swagger/api-docs.yaml b/src/main/resources/swagger/api-docs.yaml index 183f8d991..dd17565fa 100755 --- a/src/main/resources/swagger/api-docs.yaml +++ b/src/main/resources/swagger/api-docs.yaml @@ -11,14 +11,8 @@ info: servers: - url: / security: - - googleoauth: - - openid - - email - - profile - oidc: - openid - - email - - profile paths: /api/admin/v1/user/{userId}: get: @@ -1479,7 +1473,7 @@ paths: content: 'application/json': schema: - $ref: '#/components/schemas/GetOrCreatePetManagedIdentityRequest' + $ref: '#/components/schemas/GetOrCreateManagedIdentityRequest' required: true responses: 200: @@ -1528,7 +1522,7 @@ paths: content: 'application/json': schema: - $ref: '#/components/schemas/GetOrCreatePetManagedIdentityRequest' + $ref: '#/components/schemas/GetOrCreateManagedIdentityRequest' required: true responses: 200: @@ -1559,6 +1553,125 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorReport' + /api/azure/v1/actionManagedIdentity/{billingProfileId}/{resourceTypeName}/{resourceId}/{action}: + post: + tags: + - Azure + summary: creates an action managed identity that the calling user has access to + operationId: createActionManagedIdentity + parameters: + - name: billingProfileId + in: path + description: Billing Profile if the Managed Resource Group to create the Action Managed Identity in + required: true + schema: + type: string + - name: resourceTypeName + in: path + description: Type of resource + required: true + schema: + type: string + - name: resourceId + in: path + description: Id of resource + required: true + schema: + type: string + - name: action + in: path + description: Action to create the managed identity for + required: true + schema: + type: string + responses: + 200: + description: Action managed identity already exists + content: + application/json: + schema: + $ref: '#/components/schemas/ActionManagedIdentityResponse' + 201: + description: Successfully created the action managed identity + content: + application/json: + schema: + $ref: '#/components/schemas/ActionManagedIdentityResponse' + 400: + description: Invalid or incomplete request body + content: { } + 403: + description: Caller does not have the action on the resource + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorReport' + 404: + description: Resource does not exist + content: { } + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorReport' + /api/azure/v1/actionManagedIdentity/{resourceTypeName}/{resourceId}/{action}: + get: + tags: + - Azure + summary: gets an action managed identity that the calling user has access to + operationId: getActionManagedIdentity + parameters: + - name: resourceTypeName + in: path + description: Type of resource + required: true + schema: + type: string + - name: resourceId + in: path + description: Id of resource + required: true + schema: + type: string + - name: action + in: path + description: Action to get the managed identity for + required: true + schema: + type: string + responses: + 200: + description: Successfully retrieved the action managed identity without creating it + content: + application/json: + schema: + $ref: '#/components/schemas/ActionManagedIdentityResponse' + 201: + description: Successfully created the action managed identity + content: + application/json: + schema: + type: string + 400: + description: Invalid or incomplete request body + content: { } + 403: + description: Caller does not have the action on the resource + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorReport' + 404: + description: Resource does not exist + content: { } + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorReport' + /api/resources/v2: get: tags: @@ -4210,7 +4323,7 @@ components: type: string description: the name of the Azure managed resource group example: my-managed-resource-group - GetOrCreatePetManagedIdentityRequest: + GetOrCreateManagedIdentityRequest: required: - tenantId - subscriptionId @@ -4229,7 +4342,31 @@ components: type: string description: the name of the Azure managed resource group example: my-managed-resource-group - description: specifies a request for a pet managed identity + description: specifies a request for a managed identity + ActionManagedIdentityId: + type: object + required: + - resourceId + - action + - billingProfileId + properties: + resourceId: + $ref: '#/components/schemas/FullyQualifiedResourceId' + action: + type: string + billingProfileId: + type: string + ActionManagedIdentityResponse: + type: object + properties: + id: + $ref: '#/components/schemas/ActionManagedIdentityId' + objectId: + type: string + displayName: + type: string + managedResourceGroupCoordinates: + $ref: '#/components/schemas/ManagedResourceGroupCoordinates' User: type: object required: @@ -4495,22 +4632,7 @@ components: items: type: string securitySchemes: - googleoauth: - type: oauth2 - flows: - implicit: - authorizationUrl: https://accounts.google.com/o/oauth2/auth - scopes: - openid: open id authorization - email: email authorization - profile: profile authorization oidc: - type: oauth2 - flows: - authorizationCode: - authorizationUrl: /oauth2/authorize - tokenUrl: /oauth2/token - scopes: - openid: open id authorization - email: email authorization - profile: profile authorization + type: openIdConnect + openIdConnectUrl: OPEN_ID_CONNECT_URL + x-tokenName: id_token \ No newline at end of file diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/Boot.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/Boot.scala index 11a60e555..2f9b647d2 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/Boot.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/Boot.scala @@ -34,7 +34,7 @@ import org.broadinstitute.dsde.workbench.google.{ } import org.broadinstitute.dsde.workbench.google2.{GoogleStorageInterpreter, GoogleStorageService} import org.broadinstitute.dsde.workbench.model.WorkbenchEmail -import org.broadinstitute.dsde.workbench.oauth2.{ClientId, ClientSecret, OpenIDConnectConfiguration} +import org.broadinstitute.dsde.workbench.oauth2.{ClientId, OpenIDConnectConfiguration} import org.broadinstitute.dsde.workbench.sam.api.{LivenessRoutes, SamRoutes, StandardSamUserDirectives} import org.broadinstitute.dsde.workbench.sam.azure.{AzureService, CrlService} import org.broadinstitute.dsde.workbench.sam.config.AppConfig.AdminConfig @@ -140,8 +140,6 @@ object Boot extends IOApp with LazyLogging { OpenIDConnectConfiguration[IO]( appConfig.oidcConfig.authorityEndpoint, ClientId(appConfig.oidcConfig.clientId), - oidcClientSecret = appConfig.oidcConfig.clientSecret.map(ClientSecret), - extraGoogleClientId = appConfig.oidcConfig.legacyGoogleClientId.map(ClientId), extraAuthParams = Some("prompt=login") ) ) @@ -408,6 +406,9 @@ object Boot extends IOApp with LazyLogging { )(implicit actorSystem: ActorSystem): AppDependencies = { val resourceTypeMap = config.resourceTypes.map(rt => rt.name -> rt).toMap val policyEvaluatorService = PolicyEvaluatorService(config.emailDomain, resourceTypeMap, accessPolicyDAO, directoryDAO) + val azureService = config.azureServicesConfig.map { azureConfig => + new AzureService(new CrlService(azureConfig, config.janitorConfig), directoryDAO, azureManagedResourceGroupDAO) + } val resourceService = new ResourceService( resourceTypeMap, policyEvaluatorService, @@ -415,11 +416,19 @@ object Boot extends IOApp with LazyLogging { directoryDAO, cloudExtensionsInitializer.cloudExtensions, config.emailDomain, - config.adminConfig.allowedEmailDomains + config.adminConfig.allowedEmailDomains, + azureService ) val tosService = new TosService(cloudExtensionsInitializer.cloudExtensions, directoryDAO, config.termsOfServiceConfig) val userService = - new UserService(directoryDAO, cloudExtensionsInitializer.cloudExtensions, config.blockedEmailDomains, tosService, config.azureServicesConfig) + new UserService( + directoryDAO, + cloudExtensionsInitializer.cloudExtensions, + config.blockedEmailDomains, + tosService, + config.azureServicesConfig, + Seq(config.emailDomain) + ) val statusService = new StatusService(directoryDAO, cloudExtensionsInitializer.cloudExtensions, 10 seconds, 1 minute) val managedGroupService = @@ -433,9 +442,7 @@ object Boot extends IOApp with LazyLogging { config.emailDomain ) val samApplication = SamApplication(userService, resourceService, statusService, tosService) - val azureService = config.azureServicesConfig.map { azureConfig => - new AzureService(new CrlService(azureConfig, config.janitorConfig), directoryDAO, azureManagedResourceGroupDAO) - } + cloudExtensionsInitializer match { case GoogleExtensionsInitializer(googleExt, synchronizer) => val routes = new SamRoutes( diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureModel.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureModel.scala index 176b0bd94..c7bc5b2d6 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureModel.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureModel.scala @@ -2,8 +2,9 @@ package org.broadinstitute.dsde.workbench.sam.azure import org.broadinstitute.dsde.workbench.model.WorkbenchIdentityJsonSupport._ import org.broadinstitute.dsde.workbench.model._ -import org.broadinstitute.dsde.workbench.sam.model.{ResourceAction, ResourceId} +import org.broadinstitute.dsde.workbench.sam.model.{FullyQualifiedResourceId, ResourceAction, ResourceId} import spray.json.DefaultJsonProtocol._ +import org.broadinstitute.dsde.workbench.sam.model.api.SamJsonSupport._ object AzureJsonSupport { implicit val tenantIdFormat = ValueObjectFormat(TenantId.apply) @@ -23,6 +24,12 @@ object AzureJsonSupport { implicit val petManagedIdentityFormat = jsonFormat3(PetManagedIdentity.apply) implicit val managedResourceGroupCoordinatesFormat = jsonFormat3(ManagedResourceGroupCoordinates.apply) + + implicit val billingProfileIdFormat = ValueObjectFormat(BillingProfileId.apply) + + implicit val actionManagedIdentityIdFormat = jsonFormat3(ActionManagedIdentityId.apply) + + implicit val actionManagedIdentityFormat = jsonFormat4(ActionManagedIdentity.apply) } final case class TenantId(value: String) extends ValueObject @@ -56,6 +63,14 @@ final case class PetManagedIdentityId( final case class PetManagedIdentity(id: PetManagedIdentityId, objectId: ManagedIdentityObjectId, displayName: ManagedIdentityDisplayName) +final case class ActionManagedIdentityId(resourceId: FullyQualifiedResourceId, action: ResourceAction, billingProfileId: BillingProfileId) + +final case class ActionManagedIdentity( + id: ActionManagedIdentityId, + objectId: ManagedIdentityObjectId, + displayName: ManagedIdentityDisplayName, + managedResourceGroupCoordinates: ManagedResourceGroupCoordinates +) object AzureExtensions { val resourceId = ResourceId("azure") val getPetManagedIdentityAction = ResourceAction("get_pet_managed_identity") diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureRoutes.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureRoutes.scala index 81a2604bb..7d15dfebf 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureRoutes.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureRoutes.scala @@ -17,7 +17,7 @@ import org.broadinstitute.dsde.workbench.sam.service.CloudExtensions import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext import spray.json.JsString -trait AzureRoutes extends SecurityDirectives with LazyLogging with SamRequestContextDirectives { +trait AzureRoutes extends SecurityDirectives with LazyLogging with SamRequestContextDirectives with SamModelDirectives { val azureService: Option[AzureService] def azureRoutes(samUser: SamUser, samRequestContext: SamRequestContext): Route = @@ -57,6 +57,67 @@ trait AzureRoutes extends SecurityDirectives with LazyLogging with SamRequestCon } } } ~ + pathPrefix("actionManagedIdentity") { + path(Segment / Segment / Segment / Segment) { (bpId, resourceTypeName, resourceId, action) => + val billingProfileId = BillingProfileId(bpId) + val resource = FullyQualifiedResourceId(ResourceTypeName(resourceTypeName), ResourceId(resourceId)) + val resourceAction = ResourceAction(action) + + withNonAdminResourceType(resource.resourceTypeName) { resourceType => + if (!resourceType.actionPatterns.map(ap => ResourceAction(ap.value)).contains(resourceAction)) { + throw new WorkbenchExceptionWithErrorReport(ErrorReport(StatusCodes.NotFound, s"action $action not found")) + } + pathEndOrSingleSlash { + postWithTelemetry( + samRequestContext, + "billingProfileId" -> billingProfileId, + "resourceType" -> resource.resourceTypeName, + "resource" -> resource.resourceId, + "action" -> resourceAction + ) { + requireAction(resource, resourceAction, samUser.id, samRequestContext) { + complete { + service.getOrCreateActionManagedIdentity(resource, resourceAction, billingProfileId, samRequestContext).map { case (ami, created) => + val status = if (created) StatusCodes.Created else StatusCodes.OK + status -> ami + } + } + } + } + } + } + } ~ + path(Segment / Segment / Segment) { (resourceTypeName, resourceId, action) => + val resource = FullyQualifiedResourceId(ResourceTypeName(resourceTypeName), ResourceId(resourceId)) + val resourceAction = ResourceAction(action) + + withNonAdminResourceType(resource.resourceTypeName) { resourceType => + if (!resourceType.actionPatterns.map(ap => ResourceAction(ap.value)).contains(resourceAction)) { + throw new WorkbenchExceptionWithErrorReport(ErrorReport(StatusCodes.NotFound, s"action $action not found")) + } + pathEndOrSingleSlash { + getWithTelemetry( + samRequestContext, + "resourceType" -> resource.resourceTypeName, + "resource" -> resource.resourceId, + "action" -> resourceAction + ) { + requireAction(resource, resourceAction, samUser.id, samRequestContext) { + complete { + service.getActionManagedIdentity(resource, resourceAction, samRequestContext).map { + case Some(actionManagedIdentity) => StatusCodes.OK -> actionManagedIdentity + case None => + throw new WorkbenchExceptionWithErrorReport( + ErrorReport(StatusCodes.NotFound, s"Action Managed identity for [$resourceAction] on [$resource] not found") + ) + } + } + } + } + } + } + } + } ~ path("billingProfile" / Segment / "managedResourceGroup") { billingProfileId => val billingProfileResourceId = ResourceId(billingProfileId) val billingProfileIdParam = "billingProfileId" -> billingProfileResourceId diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureService.scala index 58a29abd4..37cfae4bf 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureService.scala @@ -5,6 +5,7 @@ import bio.terra.cloudres.azure.resourcemanager.common.Defaults import bio.terra.cloudres.azure.resourcemanager.msi.data.CreateUserAssignedManagedIdentityRequestData import cats.data.OptionT import cats.effect.IO +import cats.implicits.toTraverseOps import com.azure.core.management.Region import com.azure.core.util.Context import com.azure.resourcemanager.managedapplications.models.Application @@ -14,13 +15,18 @@ import org.broadinstitute.dsde.workbench.model.{ErrorReport, WorkbenchEmail, Wor import org.broadinstitute.dsde.workbench.sam._ import org.broadinstitute.dsde.workbench.sam.config.ManagedAppPlan import org.broadinstitute.dsde.workbench.sam.dataAccess.{AzureManagedResourceGroupDAO, DirectoryDAO} +import org.broadinstitute.dsde.workbench.sam.model.{FullyQualifiedResourceId, ResourceAction} import org.broadinstitute.dsde.workbench.sam.model.api.SamUser import org.broadinstitute.dsde.workbench.sam.util.OpenTelemetryIOUtils._ import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext import scala.jdk.CollectionConverters._ -class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureManagedResourceGroupDAO: AzureManagedResourceGroupDAO) { +class AzureService( + crlService: CrlService, + directoryDAO: DirectoryDAO, + azureManagedResourceGroupDAO: AzureManagedResourceGroupDAO +) { // Tag on the MRG to specify the Sam billing-profile id private val billingProfileTag = "terra.billingProfileId" @@ -31,7 +37,7 @@ class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureMana private val managedAppValidationFailure = new WorkbenchExceptionWithErrorReport( ErrorReport( StatusCodes.Forbidden, - "Specified manged resource group invalid. Possible reasons include resource group does not exist, it is not " + + "Specified managed resource group invalid. Possible reasons include resource group does not exist, it is not " + "associated to an application, the application's plan is not supported or the user is not listed as authorized." ) ) @@ -71,6 +77,10 @@ class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureMana _ <- IO.raiseWhen(existing.isEmpty)( new WorkbenchExceptionWithErrorReport(ErrorReport(StatusCodes.NotFound, s"managed resource group for profile ${billingProfileId} not found")) ) + actionManagedIdentities <- directoryDAO.getAllActionManagedIdentitiesForBillingProfile(billingProfileId, samRequestContext) + _ <- actionManagedIdentities.toList.traverse { ami => + deleteActionManagedIdentity(ami.id, samRequestContext) + } _ <- azureManagedResourceGroupDAO.deleteManagedResourceGroup( billingProfileId, samRequestContext @@ -126,6 +136,79 @@ class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureMana createdPet <- directoryDAO.createPetManagedIdentity(petToCreate, samRequestContext) } yield (createdPet, true) + def getOrCreateActionManagedIdentity( + resource: FullyQualifiedResourceId, + resourceAction: ResourceAction, + billingProfileId: BillingProfileId, + samRequestContext: SamRequestContext + ): IO[(ActionManagedIdentity, Boolean)] = { + val id = ActionManagedIdentityId(resource, resourceAction, billingProfileId) + for { + existingAmiOpt <- directoryDAO.loadActionManagedIdentity(id, samRequestContext) + ami <- existingAmiOpt match { + // pet exists in Sam DB - return it + case Some(p) => IO.pure((p, false)) + // pet does not exist in Sam DB - create it + case None => createActionManagedIdentity(id, samRequestContext) + } + } yield ami + } + + def getActionManagedIdentity( + resource: FullyQualifiedResourceId, + resourceAction: ResourceAction, + samRequestContext: SamRequestContext + ): IO[Option[ActionManagedIdentity]] = directoryDAO.loadActionManagedIdentity(resource, resourceAction, samRequestContext) + + private def createActionManagedIdentity( + id: ActionManagedIdentityId, + samRequestContext: SamRequestContext + ): IO[(ActionManagedIdentity, Boolean)] = + for { + mrgOpt <- azureManagedResourceGroupDAO.getManagedResourceGroupByBillingProfileId(id.billingProfileId, samRequestContext) + mrg <- IO.fromOption(mrgOpt)( + new WorkbenchExceptionWithErrorReport( + ErrorReport(StatusCodes.NotFound, s"Managed Resource Group with Billing Profile ID [${id.billingProfileId}] does not exist") + ) + ) + mrgCoordinates = mrg.managedResourceGroupCoordinates + // mapping the result of the validate call to ensure that validation happens before anything is created in Azure + validatedMrgCoordinates <- validateManagedResourceGroup(mrgCoordinates, samRequestContext).map(_ => mrg.managedResourceGroupCoordinates) + msiManager <- crlService.buildMsiManager(validatedMrgCoordinates.tenantId, validatedMrgCoordinates.subscriptionId) + mrgManager <- crlService.buildResourceManager(validatedMrgCoordinates.tenantId, validatedMrgCoordinates.subscriptionId) + amiName = toManagedIdentityNameFromAmiId(id) + region <- getRegionFromMrg(validatedMrgCoordinates, mrgManager, samRequestContext) + context = managedIdentityContext(validatedMrgCoordinates, amiName, region) + azureUami <- traceIOWithContext("createUAMI", samRequestContext) { _ => + IO( + // note that this will not fail when the UAMI already exists + msiManager + .identities() + .define(amiName.value) + .withRegion(region) + .withExistingResourceGroup(validatedMrgCoordinates.managedResourceGroupName.value) + .withTags(managedIdentityTags(id).asJava) + .create(context) + ) + } + amiToCreate = ActionManagedIdentity(id, ManagedIdentityObjectId(azureUami.id()), ManagedIdentityDisplayName(azureUami.name()), validatedMrgCoordinates) + createdAmi <- directoryDAO.createActionManagedIdentity(amiToCreate, samRequestContext) + } yield (createdAmi, true) + + def deleteActionManagedIdentity(id: ActionManagedIdentityId, samRequestContext: SamRequestContext): IO[Unit] = + for { + existing <- directoryDAO.loadActionManagedIdentity(id, samRequestContext) + _ <- existing + .map { ami => + for { + msiManager <- crlService.buildMsiManager(ami.managedResourceGroupCoordinates.tenantId, ami.managedResourceGroupCoordinates.subscriptionId) + _ <- IO(msiManager.identities().deleteById(ami.objectId.value)) + _ <- directoryDAO.deleteActionManagedIdentity(ami.id, samRequestContext) + } yield {} + } + .getOrElse(IO.unit) + } yield {} + private def getRegionFromMrg(mrgCoords: ManagedResourceGroupCoordinates, mrgManager: ResourceManager, samRequestContext: SamRequestContext) = traceIOWithContext("getRegionFromMrg", samRequestContext) { _ => IO(mrgManager.resourceGroups().getByName(mrgCoords.managedResourceGroupName.value).region()) @@ -179,7 +262,7 @@ class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureMana resourceManager <- crlService.buildResourceManager(mrgCoords.tenantId, mrgCoords.subscriptionId) mrg <- lookupMrg(mrgCoords, resourceManager) appManager <- crlService.buildApplicationManager(mrgCoords.tenantId, mrgCoords.subscriptionId) - appsInSubscription <- IO(appManager.applications().list().asScala) + appsInSubscription <- IO(appManager.applications().list().asScala.toSeq) managedApp <- IO.fromOption(appsInSubscription.find(_.managedResourceGroupId() == mrg.id()))(managedAppValidationFailure) plan <- validatePlan(managedApp, crlService.getManagedAppPlans) _ <- if (validateUser) validateAuthorizedAppUser(managedApp, plan, samRequestContext) else IO.unit @@ -234,6 +317,9 @@ class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureMana private def managedIdentityTags(user: SamUser): Map[String, String] = Map("samUserId" -> user.id.value, "samUserEmail" -> user.email.value) + private def managedIdentityTags(amiId: ActionManagedIdentityId): Map[String, String] = + Map("resourceTypeName" -> amiId.resourceId.resourceTypeName.value, "resourceId" -> amiId.resourceId.resourceId.value, "action" -> amiId.action.value) + private def managedIdentityContext(mrgCoords: ManagedResourceGroupCoordinates, petName: ManagedIdentityDisplayName, region: Region): Context = Defaults.buildContext( CreateUserAssignedManagedIdentityRequestData @@ -249,4 +335,12 @@ class AzureService(crlService: CrlService, directoryDAO: DirectoryDAO, azureMana private def toManagedIdentityNameFromUser(user: SamUser): ManagedIdentityDisplayName = ManagedIdentityDisplayName(s"pet-${user.id.value}") + def toManagedIdentityNameFromAmiId(amiId: ActionManagedIdentityId): ManagedIdentityDisplayName = { + // Managed Identity Names are limited to 24 characters + val actionPart = if (amiId.action.value.length > 11) amiId.action.value.substring(0, 11) else amiId.action.value + val resourceIdPart = + if (amiId.resourceId.resourceId.value.length > 12) amiId.resourceId.resourceId.value.substring(0, 12) else amiId.resourceId.resourceId.value + ManagedIdentityDisplayName(s"$resourceIdPart-$actionPart") + } + } diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/AppConfig.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/AppConfig.scala index 5ff886fff..8af2a2859 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/AppConfig.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/AppConfig.scala @@ -37,9 +37,7 @@ object AppConfig { implicit val oidcReader: ValueReader[OidcConfig] = ValueReader.relative { config => OidcConfig( config.getString("authorityEndpoint"), - config.getString("oidcClientId"), - config.as[Option[String]]("oidcClientSecret"), - config.as[Option[String]]("legacyGoogleClientId") + config.getString("oidcClientId") ) } diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/OidcConfig.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/OidcConfig.scala index aa79e156f..aa9d7dc14 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/OidcConfig.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/config/OidcConfig.scala @@ -1,3 +1,3 @@ package org.broadinstitute.dsde.workbench.sam.config -case class OidcConfig(authorityEndpoint: String, clientId: String, clientSecret: Option[String], legacyGoogleClientId: Option[String]) +case class OidcConfig(authorityEndpoint: String, clientId: String) diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala index e25b6c20f..1e324c695 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/DirectoryDAO.scala @@ -3,9 +3,16 @@ package org.broadinstitute.dsde.workbench.sam.dataAccess import cats.effect.IO import org.broadinstitute.dsde.workbench.model._ import org.broadinstitute.dsde.workbench.model.google.ServiceAccountSubjectId -import org.broadinstitute.dsde.workbench.sam.azure.{ManagedIdentityObjectId, PetManagedIdentity, PetManagedIdentityId} +import org.broadinstitute.dsde.workbench.sam.azure.{ + ActionManagedIdentity, + ActionManagedIdentityId, + BillingProfileId, + ManagedIdentityObjectId, + PetManagedIdentity, + PetManagedIdentityId +} import org.broadinstitute.dsde.workbench.sam.model.api.{AdminUpdateUserRequest, SamUser, SamUserAttributes} -import org.broadinstitute.dsde.workbench.sam.model.{BasicWorkbenchGroup, SamUserTos} +import org.broadinstitute.dsde.workbench.sam.model.{BasicWorkbenchGroup, FullyQualifiedResourceId, ResourceAction, SamUserTos} import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext import java.time.Instant @@ -88,6 +95,34 @@ trait DirectoryDAO { def createPetManagedIdentity(petManagedIdentity: PetManagedIdentity, samRequestContext: SamRequestContext): IO[PetManagedIdentity] def loadPetManagedIdentity(petManagedIdentityId: PetManagedIdentityId, samRequestContext: SamRequestContext): IO[Option[PetManagedIdentity]] def getUserFromPetManagedIdentity(petManagedIdentityObjectId: ManagedIdentityObjectId, samRequestContext: SamRequestContext): IO[Option[SamUser]] + + def createActionManagedIdentity(actionManagedIdentity: ActionManagedIdentity, samRequestContext: SamRequestContext): IO[ActionManagedIdentity] + + def loadActionManagedIdentity(actionManagedIdentityId: ActionManagedIdentityId, samRequestContext: SamRequestContext): IO[Option[ActionManagedIdentity]] + + def loadActionManagedIdentity( + resource: FullyQualifiedResourceId, + action: ResourceAction, + samRequestContext: SamRequestContext + ): IO[Option[ActionManagedIdentity]] + + def updateActionManagedIdentity(actionManagedIdentity: ActionManagedIdentity, samRequestContext: SamRequestContext): IO[ActionManagedIdentity] + + def deleteActionManagedIdentity(actionManagedIdentityId: ActionManagedIdentityId, samRequestContext: SamRequestContext): IO[Unit] + + def getAllActionManagedIdentitiesForResource( + resourceId: FullyQualifiedResourceId, + samRequestContext: SamRequestContext + ): IO[Seq[ActionManagedIdentity]] + + def deleteAllActionManagedIdentitiesForResource(resourceId: FullyQualifiedResourceId, samRequestContext: SamRequestContext): IO[Unit] + + def getAllActionManagedIdentitiesForBillingProfile( + billingProfileId: BillingProfileId, + samRequestContext: SamRequestContext + ): IO[Seq[ActionManagedIdentity]] + def deleteAllActionManagedIdentitiesForBillingProfile(billingProfileId: BillingProfileId, samRequestContext: SamRequestContext): IO[Unit] + def setUserRegisteredAt(userId: WorkbenchUserId, registeredAt: Instant, samRequestContext: SamRequestContext): IO[Unit] def getUserAttributes(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[Option[SamUserAttributes]] diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala index 0b3e63e9c..f659eb72d 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAO.scala @@ -6,7 +6,19 @@ import com.typesafe.scalalogging.LazyLogging import org.broadinstitute.dsde.workbench.model._ import org.broadinstitute.dsde.workbench.model.google.{GoogleProject, ServiceAccount, ServiceAccountSubjectId} import org.broadinstitute.dsde.workbench.sam._ -import org.broadinstitute.dsde.workbench.sam.azure.{ManagedIdentityObjectId, PetManagedIdentity, PetManagedIdentityId} +import org.broadinstitute.dsde.workbench.sam.azure.{ + ActionManagedIdentity, + ActionManagedIdentityId, + BillingProfileId, + ManagedIdentityDisplayName, + ManagedIdentityObjectId, + ManagedResourceGroupCoordinates, + ManagedResourceGroupName, + PetManagedIdentity, + PetManagedIdentityId, + SubscriptionId, + TenantId +} import org.broadinstitute.dsde.workbench.sam.db.SamParameterBinderFactory._ import org.broadinstitute.dsde.workbench.sam.db.SamTypeBinders._ import org.broadinstitute.dsde.workbench.sam.db._ @@ -990,6 +1002,278 @@ class PostgresDirectoryDAO(protected val writeDbRef: DbReference, protected val userRecordOpt.map(UserTable.unmarshalUserRecord) } + override def createActionManagedIdentity(actionManagedIdentity: ActionManagedIdentity, samRequestContext: SamRequestContext): IO[ActionManagedIdentity] = + serializableWriteTransaction("createActionManagedIdentity", samRequestContext) { implicit session => + val actionManagedIdentityColumn = ActionManagedIdentityTable.column + val resourceTable = ResourceTable.syntax + val resourceTypeTable = ResourceTypeTable.syntax + val resourceActionTable = ResourceActionTable.syntax + val managedResourceGroupTable = AzureManagedResourceGroupTable.syntax + + samsql"""insert into ${ActionManagedIdentityTable.table} + ( + ${actionManagedIdentityColumn.resourceId}, + ${actionManagedIdentityColumn.resourceActionId}, + ${actionManagedIdentityColumn.managedResourceGroupId}, + ${actionManagedIdentityColumn.objectId}, + ${actionManagedIdentityColumn.displayName} + ) + values ( + (select ${resourceTable.result.id} from ${ResourceTable as resourceTable} left join ${ResourceTypeTable as resourceTypeTable} on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} where ${resourceTable.name} = ${actionManagedIdentity.id.resourceId.resourceId} and ${resourceTypeTable.name} = ${actionManagedIdentity.id.resourceId.resourceTypeName}), + (select ${resourceActionTable.result.id} from ${ResourceActionTable as resourceActionTable} left join ${ResourceTypeTable as resourceTypeTable} on ${resourceActionTable.resourceTypeId} = ${resourceTypeTable.id} where ${resourceActionTable.action} = ${actionManagedIdentity.id.action} and ${resourceTypeTable.name} = ${actionManagedIdentity.id.resourceId.resourceTypeName}), + (select ${managedResourceGroupTable.result.id} + from ${AzureManagedResourceGroupTable as managedResourceGroupTable} + where ${managedResourceGroupTable.billingProfileId} = ${actionManagedIdentity.id.billingProfileId}), + ${actionManagedIdentity.objectId}, + ${actionManagedIdentity.displayName} + )""" + .update() + .apply() + actionManagedIdentity + } + + type TableSyntax[A] = scalikejdbc.QuerySQLSyntaxProvider[scalikejdbc.SQLSyntaxSupport[A], A] + + override def loadActionManagedIdentity( + actionManagedIdentityId: ActionManagedIdentityId, + samRequestContext: SamRequestContext + ): IO[Option[ActionManagedIdentity]] = + readOnlyTransaction("loadActionManagedIdentity", samRequestContext) { implicit session => + implicit val actionManagedIdentityTable: TableSyntax[ActionManagedIdentityRecord] = ActionManagedIdentityTable.syntax + implicit val managedResourceGroupTable: TableSyntax[AzureManagedResourceGroupRecord] = AzureManagedResourceGroupTable.syntax + implicit val resourceActionTable: TableSyntax[ResourceActionRecord] = ResourceActionTable.syntax + implicit val resourceTable: TableSyntax[ResourceRecord] = ResourceTable.syntax + implicit val resourceTypeTable: TableSyntax[ResourceTypeRecord] = ResourceTypeTable.syntax + + val loadActionManagedIdentityQuery = + samsql"""select ${resourceTable.result.name}, + ${resourceTypeTable.result.name}, + ${resourceActionTable.result.action}, + ${managedResourceGroupTable.result.tenantId}, + ${managedResourceGroupTable.result.subscriptionId}, + ${managedResourceGroupTable.result.managedResourceGroupName}, + ${managedResourceGroupTable.result.billingProfileId}, + ${actionManagedIdentityTable.result.objectId}, + ${actionManagedIdentityTable.result.displayName} + from ${ActionManagedIdentityTable as actionManagedIdentityTable} + left join ${AzureManagedResourceGroupTable as managedResourceGroupTable} + on ${actionManagedIdentityTable.managedResourceGroupId} = ${managedResourceGroupTable.id} + left join ${ResourceActionTable as resourceActionTable} + on ${actionManagedIdentityTable.resourceActionId} = ${resourceActionTable.id} + left join ${ResourceTable as resourceTable} + on ${actionManagedIdentityTable.resourceId} = ${resourceTable.id} + left join ${ResourceTypeTable as resourceTypeTable} + on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} + where ${resourceTable.name} = ${actionManagedIdentityId.resourceId.resourceId} + and ${resourceTypeTable.name} = ${actionManagedIdentityId.resourceId.resourceTypeName} + and ${managedResourceGroupTable.id} = ${actionManagedIdentityTable.managedResourceGroupId} + and ${resourceActionTable.action} = ${actionManagedIdentityId.action}""" + + loadActionManagedIdentityQuery.map(unmarshalActionManagedIdentity).single().apply() + } + + def loadActionManagedIdentity( + resource: FullyQualifiedResourceId, + action: ResourceAction, + samRequestContext: SamRequestContext + ): IO[Option[ActionManagedIdentity]] = + readOnlyTransaction("loadActionManagedIdentityForResourceAction", samRequestContext) { implicit session => + implicit val actionManagedIdentityTable: TableSyntax[ActionManagedIdentityRecord] = ActionManagedIdentityTable.syntax + implicit val managedResourceGroupTable: TableSyntax[AzureManagedResourceGroupRecord] = AzureManagedResourceGroupTable.syntax + implicit val resourceActionTable: TableSyntax[ResourceActionRecord] = ResourceActionTable.syntax + implicit val resourceTable: TableSyntax[ResourceRecord] = ResourceTable.syntax + implicit val resourceTypeTable: TableSyntax[ResourceTypeRecord] = ResourceTypeTable.syntax + + val loadActionManagedIdentityQuery = + samsql"""select ${resourceTable.result.name}, + ${resourceTypeTable.result.name}, + ${resourceActionTable.result.action}, + ${managedResourceGroupTable.result.tenantId}, + ${managedResourceGroupTable.result.subscriptionId}, + ${managedResourceGroupTable.result.managedResourceGroupName}, + ${managedResourceGroupTable.result.billingProfileId}, + ${actionManagedIdentityTable.result.objectId}, + ${actionManagedIdentityTable.result.displayName} + from ${ActionManagedIdentityTable as actionManagedIdentityTable} + left join ${AzureManagedResourceGroupTable as managedResourceGroupTable} + on ${actionManagedIdentityTable.managedResourceGroupId} = ${managedResourceGroupTable.id} + left join ${ResourceActionTable as resourceActionTable} + on ${actionManagedIdentityTable.resourceActionId} = ${resourceActionTable.id} + left join ${ResourceTable as resourceTable} + on ${actionManagedIdentityTable.resourceId} = ${resourceTable.id} + left join ${ResourceTypeTable as resourceTypeTable} + on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} + where ${resourceTable.name} = ${resource.resourceId} + and ${resourceTypeTable.name} = ${resource.resourceTypeName} + and ${resourceActionTable.action} = $action""" + + loadActionManagedIdentityQuery.map(unmarshalActionManagedIdentity).single().apply() + } + + override def updateActionManagedIdentity(actionManagedIdentity: ActionManagedIdentity, samRequestContext: SamRequestContext): IO[ActionManagedIdentity] = + serializableWriteTransaction("updateActionManagedIdentity", samRequestContext) { implicit session => + val actionManagedIdentityColumn = ActionManagedIdentityTable.column + val resourceTable = ResourceTable.syntax + val resourceTypeTable = ResourceTypeTable.syntax + val resourceActionTable = ResourceActionTable.syntax + val managedResourceGroupTable = AzureManagedResourceGroupTable.syntax + + val updateAmiQuery = + samsql""" + update ${ActionManagedIdentityTable.table} + set + ${actionManagedIdentityColumn.objectId} = ${actionManagedIdentity.objectId}, + ${actionManagedIdentityColumn.displayName} = ${actionManagedIdentity.displayName} + where + ${actionManagedIdentityColumn.resourceId} = (select ${resourceTable.result.id} from ${ResourceTable as resourceTable} left join ${ResourceTypeTable as resourceTypeTable} on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} where ${resourceTable.name} = ${actionManagedIdentity.id.resourceId.resourceId} and ${resourceTypeTable.name} = ${actionManagedIdentity.id.resourceId.resourceTypeName}) + and ${actionManagedIdentityColumn.resourceActionId} = (select ${resourceActionTable.result.id} + from ${ResourceActionTable as resourceActionTable} + left join ${ResourceTypeTable as resourceTypeTable} on ${resourceActionTable.resourceTypeId} = ${resourceTypeTable.id} + where ${resourceActionTable.action} = ${actionManagedIdentity.id.action} + and ${resourceTypeTable.name} = ${actionManagedIdentity.id.resourceId.resourceTypeName}) + and ${actionManagedIdentityColumn.managedResourceGroupId} = (select ${managedResourceGroupTable.result.id} + from ${AzureManagedResourceGroupTable as managedResourceGroupTable} + where ${managedResourceGroupTable.billingProfileId} = ${actionManagedIdentity.id.billingProfileId}) + """ + val updated = updateAmiQuery.update().apply() + if (updated != 1) { + throw new WorkbenchException(s"Update cannot be applied because ${actionManagedIdentity.id} does not exist") + } + + actionManagedIdentity + } + + override def deleteActionManagedIdentity(actionManagedIdentityId: ActionManagedIdentityId, samRequestContext: SamRequestContext): IO[Unit] = + serializableWriteTransaction("deleteActionManagedIdentity", samRequestContext) { implicit session => + val actionManagedIdentityTable = ActionManagedIdentityTable.syntax + val resourceTable = ResourceTable.syntax + val resourceTypeTable = ResourceTypeTable.syntax + val resourceActionTable = ResourceActionTable.syntax + val managedResourceGroupTable = AzureManagedResourceGroupTable.syntax + + val deleteActionManagedIdentityQuery = + samsql"""delete from ${ActionManagedIdentityTable.table} + where ${actionManagedIdentityTable.resourceId} = (select ${resourceTable.result.id} + from ${ResourceTable as resourceTable} + left join ${ResourceTypeTable as resourceTypeTable} on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} + where ${resourceTable.name} = ${actionManagedIdentityId.resourceId.resourceId} + and ${resourceTypeTable.name} = ${actionManagedIdentityId.resourceId.resourceTypeName}) + and ${actionManagedIdentityTable.managedResourceGroupId} = (select ${managedResourceGroupTable.result.id} + from ${AzureManagedResourceGroupTable as managedResourceGroupTable} + where ${managedResourceGroupTable.billingProfileId} = ${actionManagedIdentityId.billingProfileId}) + and ${actionManagedIdentityTable.resourceActionId} = (select ${resourceActionTable.result.id} + from ${ResourceActionTable as resourceActionTable} + left join ${ResourceTypeTable as resourceTypeTable} on ${resourceActionTable.resourceTypeId} = ${resourceTypeTable.id} + where ${resourceActionTable.action} = ${actionManagedIdentityId.action} + and ${resourceTypeTable.name} = ${actionManagedIdentityId.resourceId.resourceTypeName}) + """ + if (deleteActionManagedIdentityQuery.update().apply() != 1) { + throw new WorkbenchException(s"${actionManagedIdentityId} cannot be deleted because it already does not exist") + } + } + + override def getAllActionManagedIdentitiesForResource( + resourceId: FullyQualifiedResourceId, + samRequestContext: SamRequestContext + ): IO[Seq[ActionManagedIdentity]] = + readOnlyTransaction("loadActionManagedIdentitiesForResource", samRequestContext) { implicit session => + implicit val actionManagedIdentityTable: TableSyntax[ActionManagedIdentityRecord] = ActionManagedIdentityTable.syntax + implicit val managedResourceGroupTable: TableSyntax[AzureManagedResourceGroupRecord] = AzureManagedResourceGroupTable.syntax + implicit val resourceActionTable: TableSyntax[ResourceActionRecord] = ResourceActionTable.syntax + implicit val resourceTable: TableSyntax[ResourceRecord] = ResourceTable.syntax + implicit val resourceTypeTable: TableSyntax[ResourceTypeRecord] = ResourceTypeTable.syntax + + val listActionManagedIdentitysQuery = + samsql"""select ${resourceTable.result.name}, ${resourceTypeTable.result.name}, ${resourceActionTable.result.action}, ${managedResourceGroupTable.result.tenantId}, ${managedResourceGroupTable.result.subscriptionId}, ${managedResourceGroupTable.result.managedResourceGroupName}, ${managedResourceGroupTable.result.billingProfileId}, ${actionManagedIdentityTable.result.objectId}, ${actionManagedIdentityTable.result.displayName} + from ${ActionManagedIdentityTable as actionManagedIdentityTable} + left join ${ResourceActionTable as resourceActionTable} + on ${actionManagedIdentityTable.resourceActionId} = ${resourceActionTable.id} + left join ${AzureManagedResourceGroupTable as managedResourceGroupTable} + on ${actionManagedIdentityTable.managedResourceGroupId} = ${managedResourceGroupTable.id} + left join ${ResourceTable as resourceTable} + on ${actionManagedIdentityTable.resourceId} = ${resourceTable.id} + left join ${ResourceTypeTable as resourceTypeTable} + on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} + where ${resourceTable.name} = ${resourceId.resourceId} + and ${resourceTypeTable.name} = ${resourceId.resourceTypeName} + """ + + listActionManagedIdentitysQuery.map(unmarshalActionManagedIdentity).list().apply() + } + + override def deleteAllActionManagedIdentitiesForResource(resourceId: FullyQualifiedResourceId, samRequestContext: SamRequestContext): IO[Unit] = + serializableWriteTransaction("deleteAllActionManagedIdentitiesForResource", samRequestContext) { implicit session => + val actionManagedIdentityTable = ActionManagedIdentityTable.syntax + val resourceTable = ResourceTable.syntax + val resourceTypeTable = ResourceTypeTable.syntax + val deleteActionManagedIdentityQuery = + samsql"""delete from ${ActionManagedIdentityTable.table} + where ${actionManagedIdentityTable.resourceId} = (select ${resourceTable.result.id} from ${ResourceTable as resourceTable} left join ${ResourceTypeTable as resourceTypeTable} on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} where ${resourceTable.name} = ${resourceId.resourceId} and ${resourceTypeTable.name} = ${resourceId.resourceTypeName})""" + deleteActionManagedIdentityQuery.update().apply() + } + + override def getAllActionManagedIdentitiesForBillingProfile( + billingProfileId: BillingProfileId, + samRequestContext: SamRequestContext + ): IO[Seq[ActionManagedIdentity]] = + readOnlyTransaction("loadActionManagedIdentitiesForResource", samRequestContext) { implicit session => + implicit val actionManagedIdentityTable: TableSyntax[ActionManagedIdentityRecord] = ActionManagedIdentityTable.syntax + implicit val managedResourceGroupTable: TableSyntax[AzureManagedResourceGroupRecord] = AzureManagedResourceGroupTable.syntax + implicit val resourceActionTable: TableSyntax[ResourceActionRecord] = ResourceActionTable.syntax + implicit val resourceTable: TableSyntax[ResourceRecord] = ResourceTable.syntax + implicit val resourceTypeTable: TableSyntax[ResourceTypeRecord] = ResourceTypeTable.syntax + + val listActionManagedIdentitysQuery = + samsql"""select ${resourceTable.result.name}, ${resourceTypeTable.result.name}, ${resourceActionTable.result.action}, ${managedResourceGroupTable.result.tenantId}, ${managedResourceGroupTable.result.subscriptionId}, ${managedResourceGroupTable.result.managedResourceGroupName}, ${managedResourceGroupTable.result.billingProfileId}, ${actionManagedIdentityTable.result.objectId}, ${actionManagedIdentityTable.result.displayName} + from ${ActionManagedIdentityTable as actionManagedIdentityTable} + left join ${ResourceActionTable as resourceActionTable} + on ${actionManagedIdentityTable.resourceActionId} = ${resourceActionTable.id} + left join ${AzureManagedResourceGroupTable as managedResourceGroupTable} + on ${actionManagedIdentityTable.managedResourceGroupId} = ${managedResourceGroupTable.id} + left join ${ResourceTable as resourceTable} + on ${actionManagedIdentityTable.resourceId} = ${resourceTable.id} + left join ${ResourceTypeTable as resourceTypeTable} + on ${resourceTable.resourceTypeId} = ${resourceTypeTable.id} + where ${managedResourceGroupTable.billingProfileId} = $billingProfileId + """ + + listActionManagedIdentitysQuery.map(unmarshalActionManagedIdentity).list().apply() + } + override def deleteAllActionManagedIdentitiesForBillingProfile(billingProfileId: BillingProfileId, samRequestContext: SamRequestContext): IO[Unit] = + serializableWriteTransaction("deleteAllActionManagedIdentitiesForManagedResourceGroup", samRequestContext) { implicit session => + val actionManagedIdentityTable = ActionManagedIdentityTable.syntax + val managedResourceGroupTable = AzureManagedResourceGroupTable.syntax + val deleteActionManagedIdentityQuery = + samsql"""delete from ${ActionManagedIdentityTable.table} + where ${actionManagedIdentityTable.managedResourceGroupId} = (select ${managedResourceGroupTable.result.id} + from ${AzureManagedResourceGroupTable as managedResourceGroupTable} + where ${managedResourceGroupTable.billingProfileId} = $billingProfileId) + """ + deleteActionManagedIdentityQuery.update().apply() + } + + private def unmarshalActionManagedIdentity(rs: WrappedResultSet)(implicit + resourceTable: TableSyntax[ResourceRecord], + resourceTypeTable: TableSyntax[ResourceTypeRecord], + resourceActionTable: TableSyntax[ResourceActionRecord], + actionManagedIdentityTable: TableSyntax[ActionManagedIdentityRecord], + managedResourceGroupTable: TableSyntax[AzureManagedResourceGroupRecord] + ) = + ActionManagedIdentity( + ActionManagedIdentityId( + FullyQualifiedResourceId(rs.get[ResourceTypeName](resourceTypeTable.resultName.name), rs.get[ResourceId](resourceTable.resultName.name)), + rs.get[ResourceAction](resourceActionTable.resultName.action), + rs.get[BillingProfileId](managedResourceGroupTable.resultName.billingProfileId) + ), + rs.get[ManagedIdentityObjectId](actionManagedIdentityTable.resultName.objectId), + rs.get[ManagedIdentityDisplayName](actionManagedIdentityTable.resultName.displayName), + ManagedResourceGroupCoordinates( + rs.get[TenantId](managedResourceGroupTable.resultName.tenantId), + rs.get[SubscriptionId](managedResourceGroupTable.resultName.subscriptionId), + rs.get[ManagedResourceGroupName](managedResourceGroupTable.resultName.managedResourceGroupName) + ) + ) + override def setUserRegisteredAt(userId: WorkbenchUserId, registeredAt: Instant, samRequestContext: SamRequestContext): IO[Unit] = serializableWriteTransaction("setUserRegisteredAt", samRequestContext) { implicit session => val u = UserTable.column diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/db/tables/ActionManagedIdentityTable.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/db/tables/ActionManagedIdentityTable.scala new file mode 100644 index 000000000..d91522998 --- /dev/null +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/db/tables/ActionManagedIdentityTable.scala @@ -0,0 +1,28 @@ +package org.broadinstitute.dsde.workbench.sam.db.tables + +import org.broadinstitute.dsde.workbench.sam.azure._ +import org.broadinstitute.dsde.workbench.sam.db.SamTypeBinders +import scalikejdbc._ + +final case class ActionManagedIdentityRecord( + resourceId: ResourcePK, + resourceActionId: ResourceActionPK, + managedResourceGroupId: ManagedResourceGroupPK, + objectId: ManagedIdentityObjectId, + displayName: ManagedIdentityDisplayName +) + +object ActionManagedIdentityTable extends SQLSyntaxSupportWithDefaultSamDB[ActionManagedIdentityRecord] { + override def tableName: String = "SAM_ACTION_MANAGED_IDENTITY" + + import SamTypeBinders._ + def apply(e: ResultName[ActionManagedIdentityRecord])(rs: WrappedResultSet): ActionManagedIdentityRecord = ActionManagedIdentityRecord( + rs.get(e.resourceId), + rs.get(e.resourceActionId), + rs.get(e.managedResourceGroupId), + rs.get(e.objectId), + rs.get(e.displayName) + ) + + def apply(p: SyntaxProvider[ActionManagedIdentityRecord])(rs: WrappedResultSet): ActionManagedIdentityRecord = apply(p.resultName)(rs) +} diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupService.scala index 716346f53..83117303b 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupService.scala @@ -50,7 +50,13 @@ class ManagedGroupService( ) validateGroupName(groupId.value) + val groupEmail = WorkbenchEmail(constructEmail(groupId.value)) for { + _ <- directoryDAO.loadSubjectFromEmail(groupEmail, samRequestContext).flatMap { + case Some(_) => + IO.raiseError(new WorkbenchExceptionWithErrorReport(ErrorReport(StatusCodes.Conflict, s"subject with email $groupEmail already exists"))) + case None => IO.pure(()) + } managedGroup <- resourceService.createResource( managedGroupType, groupId, diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala index 4a58bc621..8517b70fb 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceService.scala @@ -11,6 +11,7 @@ import org.broadinstitute.dsde.workbench.model._ import org.broadinstitute.dsde.workbench.sam._ import org.broadinstitute.dsde.workbench.sam.audit.SamAuditModelJsonSupport._ import org.broadinstitute.dsde.workbench.sam.audit._ +import org.broadinstitute.dsde.workbench.sam.azure.AzureService import org.broadinstitute.dsde.workbench.sam.dataAccess.{AccessPolicyDAO, DirectoryDAO, LoadResourceAuthDomainResult} import org.broadinstitute.dsde.workbench.sam.model._ import org.broadinstitute.dsde.workbench.sam.model.api.{ @@ -39,7 +40,8 @@ class ResourceService( private val directoryDAO: DirectoryDAO, private val cloudExtensions: CloudExtensions, val emailDomain: String, - private val allowedAdminEmailDomains: Set[String] + private val allowedAdminEmailDomains: Set[String], + private val azureService: Option[AzureService] = None )(implicit val executionContext: ExecutionContext) extends LazyLogging { @@ -333,6 +335,7 @@ class ResourceService( // remove from cloud first so a failure there does not leave sam in a bad state _ <- cloudDeletePolicies(resource, samRequestContext) + _ <- deleteActionManagedIdentitiesForResource(resource, samRequestContext) _ <- accessPolicyDAO.deleteAllResourcePolicies(resource, samRequestContext) _ <- maybeDeleteResource(resource, samRequestContext) @@ -340,6 +343,19 @@ class ResourceService( _ <- AuditLogger.logAuditEventIO(samRequestContext, ResourceEvent(ResourceDeleted, resource)) } yield () + private def deleteActionManagedIdentitiesForResource(resource: FullyQualifiedResourceId, samRequestContext: SamRequestContext): IO[Unit] = + azureService + .map { service => + for { + actionManagedIdentities <- directoryDAO.getAllActionManagedIdentitiesForResource(resource, samRequestContext) + _ <- actionManagedIdentities.toList.traverse { ami => + service.deleteActionManagedIdentity(ami.id, samRequestContext) + } + _ <- directoryDAO.deleteAllActionManagedIdentitiesForResource(resource, samRequestContext) + } yield () + } + .getOrElse(IO.unit) + /** Check if a resource has any children. If so, then throw a 400. */ def checkNoChildren(resource: FullyQualifiedResourceId, samRequestContext: SamRequestContext): IO[Unit] = listResourceChildren(resource, samRequestContext) map { list => diff --git a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala index d49ec08cc..8490a5175 100644 --- a/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala +++ b/src/main/scala/org/broadinstitute/dsde/workbench/sam/service/UserService.scala @@ -29,7 +29,8 @@ class UserService( val cloudExtensions: CloudExtensions, blockedEmailDomains: Seq[String], tosService: TosService, - azureConfig: Option[AzureServicesConfig] = None + azureConfig: Option[AzureServicesConfig] = None, + nonInvitableDomains: Seq[String] = Seq.empty )(implicit val executionContext: ExecutionContext ) extends LazyLogging { @@ -247,7 +248,7 @@ class UserService( def inviteUser(inviteeEmail: WorkbenchEmail, samRequestContext: SamRequestContext): IO[UserStatusDetails] = for { - _ <- validateEmailAddress(inviteeEmail, blockedEmailDomains) + _ <- validateEmailAddress(inviteeEmail, blockedEmailDomains, nonInvitableDomains) existingSubject <- directoryDAO.loadSubjectFromEmail(inviteeEmail, samRequestContext) createdUser <- existingSubject match { case None => createUserInternal(SamUser(genWorkbenchUserId(System.currentTimeMillis()), None, inviteeEmail, None, false), samRequestContext) @@ -436,14 +437,26 @@ class UserService( // moved this method from the UserService companion object into this class // because Mockito would not let us spy/mock the static method - def validateEmailAddress(email: WorkbenchEmail, blockedEmailDomains: Seq[String]): IO[Unit] = + def validateEmailAddress(email: WorkbenchEmail, blockedEmailDomains: Seq[String], nonInvitableDomain: Seq[String]): IO[Unit] = email.value match { - case emailString if blockedEmailDomains.exists(domain => emailString.endsWith("@" + domain) || emailString.endsWith("." + domain)) => + case emailString if matchesBadDomain(emailString, blockedEmailDomains) => IO.raiseError(new WorkbenchExceptionWithErrorReport(ErrorReport(StatusCodes.BadRequest, s"email domain not permitted [${email.value}]"))) + case emailString if matchesBadDomain(emailString, nonInvitableDomain) => + IO.raiseError( + new WorkbenchExceptionWithErrorReport( + ErrorReport( + StatusCodes.BadRequest, + s"Email domain cannot be invited [${email.value}]. If you are trying to invite a group, please make sure that group exists before adding it to a resource policy." + ) + ) + ) case UserService.emailRegex() => IO.unit case _ => IO.raiseError(new WorkbenchExceptionWithErrorReport(ErrorReport(StatusCodes.BadRequest, s"invalid email address [${email.value}]"))) } + private def matchesBadDomain(emailString: String, badDomains: Seq[String]): Boolean = + badDomains.exists(domain => emailString.endsWith("@" + domain) || emailString.endsWith("." + domain)) + def getUserAllowances(samUser: SamUser, samRequestContext: SamRequestContext): IO[SamUserAllowances] = for { tosStatus <- tosService.getTermsOfServiceComplianceStatus(samUser, samRequestContext) diff --git a/src/test/resources/reference.conf b/src/test/resources/reference.conf index 68371feb3..641f41cfb 100644 --- a/src/test/resources/reference.conf +++ b/src/test/resources/reference.conf @@ -103,7 +103,7 @@ testStuff = { azure { tenantId = "fad90753-2022-4456-9b0a-c7e5b934e408" subscriptionId = "f557c728-871d-408c-a28b-eb6b2141a087" - managedResourceGroupName = "e2e-8n6xqg" + managedResourceGroupName = "e2e-n6bgy8" } } @@ -111,8 +111,6 @@ testStuff = { oidc { authorityEndpoint = "https://accounts.google.com" oidcClientId = "some-client" - oidcClientSecret = "some-secret" - legacyGoogleClientId = "another-client" } liquibase { @@ -199,4 +197,4 @@ resourceTypes { landing-zone = { reuseIds = false } -} \ No newline at end of file +} diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/TestSupport.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/TestSupport.scala index 04f44ae44..aab7064d4 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/TestSupport.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/TestSupport.scala @@ -225,6 +225,7 @@ object TestSupport extends TestSupport { if (databaseEnabled) { dbRef.inLocalTransaction { implicit session => val tables = List( + ActionManagedIdentityTable, PolicyActionTable, PolicyRoleTable, PolicyTable, diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/api/TestSamRoutes.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/api/TestSamRoutes.scala index 392d6829a..c3a8cd542 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/api/TestSamRoutes.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/api/TestSamRoutes.scala @@ -203,7 +203,8 @@ object TestSamRoutes { mockResourceService.initResourceTypes(samRequestContext).unsafeRunSync() val mockStatusService = new StatusService(directoryDAO, cloudXtns) - val azureService = new AzureService(crlService.getOrElse(MockCrlService(Option(user))), directoryDAO, new MockAzureManagedResourceGroupDAO) + val azureService = + new AzureService(crlService.getOrElse(MockCrlService(Option(user))), directoryDAO, new MockAzureManagedResourceGroupDAO) new TestSamRoutes( mockResourceService, policyEvaluatorService, diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceSpec.scala index 785f58e11..19e5af2d4 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceSpec.scala @@ -7,19 +7,47 @@ import com.azure.resourcemanager.managedapplications.models.Plan import org.broadinstitute.dsde.workbench.model.{WorkbenchEmail, WorkbenchExceptionWithErrorReport} import org.broadinstitute.dsde.workbench.sam.Generator.genWorkbenchUserAzure import org.broadinstitute.dsde.workbench.sam.TestSupport._ -import org.broadinstitute.dsde.workbench.sam.dataAccess.{MockAzureManagedResourceGroupDAO, MockDirectoryDAO, PostgresDirectoryDAO} -import org.broadinstitute.dsde.workbench.sam.model.{UserStatus, UserStatusDetails} -import org.broadinstitute.dsde.workbench.sam.service.{NoExtensions, TosService, UserService} -import org.broadinstitute.dsde.workbench.sam.{ConnectedTest, Generator} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfterAll +import org.broadinstitute.dsde.workbench.sam.api.TestSamRoutes.SamResourceActionPatterns +import org.broadinstitute.dsde.workbench.sam.dataAccess.{ + AccessPolicyDAO, + DirectoryDAO, + MockAzureManagedResourceGroupDAO, + MockDirectoryDAO, + PostgresAccessPolicyDAO, + PostgresAzureManagedResourceGroupDAO, + PostgresDirectoryDAO +} +import org.broadinstitute.dsde.workbench.sam.model.{ + FullyQualifiedResourceId, + ResourceAction, + ResourceActionPattern, + ResourceId, + ResourceRole, + ResourceRoleName, + ResourceType, + ResourceTypeName, + UserStatus, + UserStatusDetails +} +import org.broadinstitute.dsde.workbench.sam.service.{NoExtensions, PolicyEvaluatorService, ResourceService, TosService, UserService} +import org.broadinstitute.dsde.workbench.sam.{ConnectedTest, Generator, TestSupport} +import org.mockito.scalatest.MockitoSugar +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.concurrent.ScalaFutures import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers +import java.util.UUID import scala.jdk.CollectionConverters._ -class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFlatSpecLike with Matchers with ScalaFutures with BeforeAndAfterAll { +class AzureServiceSpec(_system: ActorSystem) + extends TestKit(_system) + with AnyFlatSpecLike + with Matchers + with ScalaFutures + with BeforeAndAfterAll + with BeforeAndAfterEach + with MockitoSugar { implicit val ec = scala.concurrent.ExecutionContext.global implicit val ioRuntime = cats.effect.unsafe.IORuntime.global @@ -33,6 +61,9 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl super.afterAll() } + override def beforeEach(): Unit = + TestSupport.truncateAll + "AzureService" should "create a pet managed identity" taggedAs ConnectedTest in { val azureServicesConfig = appConfig.azureServicesConfig val janitorConfig = appConfig.janitorConfig @@ -45,6 +76,7 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl val tosService = new TosService(NoExtensions, directoryDAO, tosConfig) val userService = new UserService(directoryDAO, NoExtensions, Seq.empty, tosService) val azureTestConfig = config.getConfig("testStuff.azure") + setUpResources(directoryDAO) val azureService = new AzureService(crlService, directoryDAO, new MockAzureManagedResourceGroupDAO) // create user @@ -107,13 +139,154 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl // this is a best effort -- it will be deleted anyway by Janitor msiManager.identities().deleteById(azureRes.id()) } + def setUpResources(directoryDAO: DirectoryDAO): (ResourceService, ResourceType, ResourceAction) = { + lazy val policyDAO: AccessPolicyDAO = new PostgresAccessPolicyDAO(TestSupport.dbRef, TestSupport.dbRef) + val emailDomain = "example.com" + val ownerRoleName = ResourceRoleName("owner") + val viewAction = ResourceAction("view") + + val defaultResourceTypeActions = + Set(ResourceAction("alter_policies"), ResourceAction("delete"), ResourceAction("read_policies"), viewAction, ResourceAction("non_owner_action")) + val defaultResourceTypeActionPatterns = Set( + SamResourceActionPatterns.alterPolicies, + SamResourceActionPatterns.delete, + SamResourceActionPatterns.readPolicies, + ResourceActionPattern("view", "", false), + ResourceActionPattern("non_owner_action", "", false) + ) + val defaultResourceType = ResourceType( + ResourceTypeName(UUID.randomUUID().toString), + defaultResourceTypeActionPatterns, + Set( + ResourceRole(ownerRoleName, defaultResourceTypeActions - ResourceAction("non_owner_action")), + ResourceRole(ResourceRoleName("other"), Set(ResourceAction("view"), ResourceAction("non_owner_action"))) + ), + ownerRoleName + ) + + val resourceTypes = Map( + defaultResourceType.name -> defaultResourceType + ) + val policyEvaluatorService = PolicyEvaluatorService(emailDomain, resourceTypes, policyDAO, directoryDAO) + val resourceService = + new ResourceService(resourceTypes, policyEvaluatorService, policyDAO, directoryDAO, NoExtensions, emailDomain, Set("test.firecloud.org")) + (resourceService, defaultResourceType, viewAction) + } + + it should "create and delete an action managed identity" taggedAs ConnectedTest in { + val azureServicesConfig = appConfig.azureServicesConfig + val janitorConfig = appConfig.janitorConfig + + assume(azureServicesConfig.isDefined && janitorConfig.enabled, "-- skipping Azure test") + + // create dependencies + val directoryDAO = new PostgresDirectoryDAO(dbRef, dbRef) + val crlService = new CrlService(azureServicesConfig.get, janitorConfig) + val tosService = new TosService(NoExtensions, directoryDAO, tosConfig) + val userService = new UserService(directoryDAO, NoExtensions, Seq.empty, tosService) + val azureManagedResourceGroupDAO = new PostgresAzureManagedResourceGroupDAO(TestSupport.dbRef, TestSupport.dbRef) + val azureTestConfig = config.getConfig("testStuff.azure") + val (resourceService, defaultResourceType, viewAction) = setUpResources(directoryDAO) + val azureService = new AzureService(crlService, directoryDAO, azureManagedResourceGroupDAO) + + // create user + val defaultUser = Generator.genWorkbenchUserAzure.sample.map(_.copy(email = WorkbenchEmail("hermione.owner@test.firecloud.org"))).get + val userStatus = userService.createUser(defaultUser, samRequestContext).unsafeRunSync() + userStatus shouldBe UserStatus( + UserStatusDetails(defaultUser.id, defaultUser.email), + Map("tosAccepted" -> false, "adminEnabled" -> true, "ldap" -> true, "allUsersGroup" -> true, "google" -> true) + ) + + // user should exist in postgres + directoryDAO.loadUser(defaultUser.id, samRequestContext).unsafeRunSync() shouldBe Some(defaultUser.copy(enabled = true)) + + // Create the resource type and resource + val resourceName = ResourceId("resource") + val resource = FullyQualifiedResourceId(defaultResourceType.name, resourceName) + resourceService.createResourceType(defaultResourceType, samRequestContext).unsafeRunSync() + runAndWait(resourceService.createResource(defaultResourceType, resourceName, defaultUser, samRequestContext)) + + // Create the "billing profile" resource. There's no actual need for it to be a "billing profile", we just need a resource to attach the managed resource group to. + val billingProfileId = BillingProfileId("de38969d-f41b-4b80-99ba-db481e6db1cf") + val billingProfileResource = runAndWait(resourceService.createResource(defaultResourceType, billingProfileId.asResourceId, defaultUser, samRequestContext)) + + // action managed identity should not exist in postgres + val tenantId = TenantId(azureTestConfig.getString("tenantId")) + val subscriptionId = SubscriptionId(azureTestConfig.getString("subscriptionId")) + val managedResourceGroupName = ManagedResourceGroupName(azureTestConfig.getString("managedResourceGroupName")) + val mrgCoordinates = ManagedResourceGroupCoordinates(tenantId, subscriptionId, managedResourceGroupName) + val managedResourceGroup = ManagedResourceGroup(mrgCoordinates, billingProfileId) + runAndWait(azureService.createManagedResourceGroup(managedResourceGroup, samRequestContext.copy(samUser = Some(defaultUser)))) + + val actionManagedIdentityId = + ActionManagedIdentityId(resource, viewAction, billingProfileId) + directoryDAO.loadActionManagedIdentity(actionManagedIdentityId, samRequestContext).unsafeRunSync() shouldBe None + + // managed identity should not exist in Azure + val msiManager = crlService.buildMsiManager(tenantId, subscriptionId).unsafeRunSync() + msiManager.identities().listByResourceGroup(managedResourceGroupName.value).asScala.toList.exists { i => + i.name() === azureService.toManagedIdentityNameFromAmiId(actionManagedIdentityId) + } shouldBe false + + // create action managed identity + val (res, created) = azureService + .getOrCreateActionManagedIdentity(resource, viewAction, billingProfileId, samRequestContext.copy(samUser = Some(defaultUser))) + .unsafeRunSync() + created shouldBe true + res.id shouldBe actionManagedIdentityId + res.displayName shouldBe ManagedIdentityDisplayName(s"${resource.resourceId.value}-${viewAction.value}") + + // action managed identity should now exist in postgres + directoryDAO.loadActionManagedIdentity(actionManagedIdentityId, samRequestContext).unsafeRunSync() shouldBe Some(res) + + // managed identity should now exist in azure + val azureRes = msiManager.identities().getById(res.objectId.value) + azureRes should not be null + azureRes.tenantId() shouldBe tenantId.value + azureRes.resourceGroupName() shouldBe managedResourceGroupName.value + azureRes.id() shouldBe res.objectId.value + azureRes.name() shouldBe res.displayName.value + + // call getOrCreate again + val (res2, created2) = azureService.getOrCreateActionManagedIdentity(resource, viewAction, billingProfileId, samRequestContext).unsafeRunSync() + created2 shouldBe false + res2 shouldBe res + + // pet should still exist in postgres and azure + directoryDAO.loadActionManagedIdentity(actionManagedIdentityId, samRequestContext).unsafeRunSync() shouldBe Some(res2) + val azureRes2 = msiManager.identities().getById(res.objectId.value) + azureRes2 should not be null + azureRes2.tenantId() shouldBe tenantId.value + azureRes2.resourceGroupName() shouldBe managedResourceGroupName.value + azureRes2.id() shouldBe res2.objectId.value + azureRes2.name() shouldBe res2.displayName.value + + // delete action managed identity + azureService.deleteActionManagedIdentity(actionManagedIdentityId, samRequestContext).unsafeRunSync() + + // action managed identity should not exist in postgres + directoryDAO.loadActionManagedIdentity(actionManagedIdentityId, samRequestContext).unsafeRunSync() shouldBe None + + // managed identity should not exist in Azure + msiManager.identities().listByResourceGroup(managedResourceGroupName.value).asScala.toList.exists { i => + i.name() === azureService.toManagedIdentityNameFromAmiId(actionManagedIdentityId) + } shouldBe false + + // delete managed identity from Azure + // this is a best effort -- it will be deleted anyway by Janitor + msiManager.identities().deleteById(azureRes.id()) + } "createManagedResourceGroup" should "create a managed resource group" in { val user = Generator.genWorkbenchUserAzure.sample val managedResourceGroup = Generator.genManagedResourceGroup.sample.get val mockMrgDAO = new MockAzureManagedResourceGroupDAO val svc = - new AzureService(MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), new MockDirectoryDAO(), mockMrgDAO) + new AzureService( + MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), + new MockDirectoryDAO(), + mockMrgDAO + ) svc.createManagedResourceGroup(managedResourceGroup, samRequestContext.copy(samUser = user)).unsafeRunSync() mockMrgDAO.mrgs should contain(managedResourceGroup) @@ -141,7 +314,7 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl BillingProfileId("de38969d-f41b-4b80-99ba-db481e6db1cf") ) - val user = Generator.genWorkbenchUserAzure.sample.map(_.copy(email = WorkbenchEmail("rtitlefireclouddev@gmail.com"))) + val user = Generator.genWorkbenchUserAzure.sample.map(_.copy(email = WorkbenchEmail("hermione.owner@test.firecloud.org"))) azureService.createManagedResourceGroup(managedResourceGroup, samRequestContext.copy(samUser = user)).unsafeRunSync() mockMrgDAO.mrgs should contain(managedResourceGroup) @@ -152,7 +325,11 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl val managedResourceGroup = Generator.genManagedResourceGroup.sample.get val mockMrgDAO = new MockAzureManagedResourceGroupDAO val svc = - new AzureService(MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), new MockDirectoryDAO(), mockMrgDAO) + new AzureService( + MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), + new MockDirectoryDAO(), + mockMrgDAO + ) mockMrgDAO.insertManagedResourceGroup(managedResourceGroup.copy(billingProfileId = BillingProfileId("no the same")), samRequestContext).unsafeRunSync() val err = intercept[WorkbenchExceptionWithErrorReport] { @@ -167,7 +344,11 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl val managedResourceGroup = Generator.genManagedResourceGroup.sample.get val mockMrgDAO = new MockAzureManagedResourceGroupDAO val svc = - new AzureService(MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), new MockDirectoryDAO(), mockMrgDAO) + new AzureService( + MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), + new MockDirectoryDAO(), + mockMrgDAO + ) mockMrgDAO .insertManagedResourceGroup( @@ -308,7 +489,11 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl val managedResourceGroup = Generator.genManagedResourceGroup.sample.get val mockMrgDAO = new MockAzureManagedResourceGroupDAO val svc = - new AzureService(MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), new MockDirectoryDAO(), mockMrgDAO) + new AzureService( + MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), + new MockDirectoryDAO(), + mockMrgDAO + ) val err = intercept[WorkbenchExceptionWithErrorReport] { svc @@ -362,7 +547,11 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl val managedResourceGroup = Generator.genManagedResourceGroup.sample.get val mockMrgDAO = new MockAzureManagedResourceGroupDAO val svc = - new AzureService(MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), new MockDirectoryDAO(), mockMrgDAO) + new AzureService( + MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), + new MockDirectoryDAO(), + mockMrgDAO + ) mockMrgDAO.insertManagedResourceGroup(managedResourceGroup, samRequestContext).unsafeRunSync() svc.deleteManagedResourceGroup(managedResourceGroup.billingProfileId, samRequestContext.copy(samUser = user)).unsafeRunSync() @@ -374,7 +563,11 @@ class AzureServiceSpec(_system: ActorSystem) extends TestKit(_system) with AnyFl val mockMrgDAO = new MockAzureManagedResourceGroupDAO val managedResourceGroup = Generator.genManagedResourceGroup.sample.get val svc = - new AzureService(MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), new MockDirectoryDAO(), mockMrgDAO) + new AzureService( + MockCrlService(user, managedResourceGroup.managedResourceGroupCoordinates.managedResourceGroupName), + new MockDirectoryDAO(), + mockMrgDAO + ) val err = intercept[WorkbenchExceptionWithErrorReport] { svc.deleteManagedResourceGroup(managedResourceGroup.billingProfileId, samRequestContext.copy(samUser = user)).unsafeRunSync() diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceUnitSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceUnitSpec.scala new file mode 100644 index 000000000..ad8f1eae6 --- /dev/null +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/azure/AzureServiceUnitSpec.scala @@ -0,0 +1,225 @@ +package org.broadinstitute.dsde.workbench.sam.azure + +import cats.effect.IO +import cats.effect.unsafe.implicits.global +import com.azure.core.http.rest.PagedIterable +import com.azure.core.management.Region +import com.azure.core.util.Context +import com.azure.resourcemanager.managedapplications.ApplicationManager +import com.azure.resourcemanager.managedapplications.models.{Application, Applications, Plan} +import com.azure.resourcemanager.msi.MsiManager +import com.azure.resourcemanager.msi.models.Identity.DefinitionStages +import com.azure.resourcemanager.msi.models.{Identities, Identity} +import com.azure.resourcemanager.resources.ResourceManager +import com.azure.resourcemanager.resources.models.{ResourceGroup, ResourceGroups} +import org.broadinstitute.dsde.workbench.sam.config.ManagedAppPlan +import org.broadinstitute.dsde.workbench.sam.dataAccess.{AzureManagedResourceGroupDAO, DirectoryDAO} +import org.broadinstitute.dsde.workbench.sam.model.{FullyQualifiedResourceId, ResourceAction, ResourceId, ResourceTypeName} +import org.broadinstitute.dsde.workbench.sam.{Generator, PropertyBasedTesting, TestSupport} +import org.mockito.scalatest.MockitoSugar +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.matchers.should.Matchers + +import java.util.UUID + +class AzureServiceUnitSpec extends AnyFreeSpec with Matchers with ScalaFutures with TestSupport with MockitoSugar with PropertyBasedTesting { + + private val dummyUser = Generator.genWorkbenchUserBoth.sample.get + + "AzureService" - { + "Action Managed Identities" - { + "create an Action Managed Identity" in { + // Arrange + val mockCrlService = mock[CrlService] + val mockDirectoryDAO = mock[DirectoryDAO] + val mockAzureManagedResourceGroupDAO = mock[AzureManagedResourceGroupDAO] + val mockMsiManager = mock[MsiManager] + val mockApplicationManager = mock[ApplicationManager] + val mockApplications = mock[Applications] + val mockPagedResponse = mock[PagedIterable[Application]] + val mockApplication = mock[Application] + val mockPlan = mock[Plan] + val mockResourceManager = mock[ResourceManager] + val mockResourceGroups = mock[ResourceGroups] + val mockResourceGroup = mock[ResourceGroup] + val mockIdentities = mock[Identities] + val mockIdentitiesBlank = mock[DefinitionStages.Blank] + val mockIdentityWithGroup = mock[DefinitionStages.WithGroup] + val mockIdentityWithCreate = mock[DefinitionStages.WithCreate] + val mockIdentity = mock[Identity] + + val azureService = new AzureService(mockCrlService, mockDirectoryDAO, mockAzureManagedResourceGroupDAO) + + val testMrgCoordinates = ManagedResourceGroupCoordinates( + TenantId(UUID.randomUUID().toString), + SubscriptionId(UUID.randomUUID().toString), + ManagedResourceGroupName(UUID.randomUUID().toString) + ) + val testBillingProfileId = BillingProfileId(UUID.randomUUID().toString) + val testAction = ResourceAction("testAction") + val testResource = FullyQualifiedResourceId(ResourceTypeName("testResourceType"), ResourceId("testResource")) + val testActionManagedIdentityId = ActionManagedIdentityId(testResource, testAction, testBillingProfileId) + val testDisplayName = azureService.toManagedIdentityNameFromAmiId(testActionManagedIdentityId) + val testObjectId = ManagedIdentityObjectId(UUID.randomUUID().toString) + val testActionManagedIdentity = ActionManagedIdentity(testActionManagedIdentityId, testObjectId, testDisplayName, testMrgCoordinates) + val testManagedAppPlan = ManagedAppPlan("testPlan", "testPublisher", UUID.randomUUID().toString) + val testSamRequestContext = samRequestContext.copy(samUser = Some(dummyUser)) + val testMrgId = UUID.randomUUID().toString + + when(mockDirectoryDAO.loadActionManagedIdentity(testActionManagedIdentityId, testSamRequestContext)).thenReturn(IO.pure(None)) + when(mockAzureManagedResourceGroupDAO.getManagedResourceGroupByBillingProfileId(testBillingProfileId, testSamRequestContext)) + .thenReturn(IO.pure(Some(ManagedResourceGroup(testMrgCoordinates, testBillingProfileId)))) + when(mockCrlService.buildMsiManager(testMrgCoordinates.tenantId, testMrgCoordinates.subscriptionId)).thenReturn(IO.pure(mockMsiManager)) + when(mockCrlService.buildApplicationManager(testMrgCoordinates.tenantId, testMrgCoordinates.subscriptionId)).thenReturn(IO.pure(mockApplicationManager)) + when(mockCrlService.getManagedAppPlans).thenReturn(Seq(testManagedAppPlan)) + when(mockApplicationManager.applications()).thenReturn(mockApplications) + when(mockApplications.list()).thenReturn(mockPagedResponse) + when(mockPagedResponse.iterator()).thenReturn(java.util.List.of(mockApplication).iterator()) + when(mockApplication.plan()).thenReturn(mockPlan) + when(mockApplication.managedResourceGroupId()).thenReturn(testMrgId) + when(mockApplication.parameters()).thenReturn(java.util.Map.of(testManagedAppPlan.authorizedUserKey, java.util.Map.of("value", dummyUser.email.value))) + when(mockPlan.name()).thenReturn(testManagedAppPlan.name) + when(mockPlan.publisher()).thenReturn(testManagedAppPlan.publisher) + when(mockCrlService.buildResourceManager(testMrgCoordinates.tenantId, testMrgCoordinates.subscriptionId)).thenReturn(IO.pure(mockResourceManager)) + when(mockResourceManager.resourceGroups()).thenReturn(mockResourceGroups) + when(mockResourceGroups.getByName(testMrgCoordinates.managedResourceGroupName.value)).thenReturn(mockResourceGroup) + when(mockResourceGroup.region()).thenReturn(Region.US_EAST) + when(mockResourceGroup.id()).thenReturn(testMrgId) + when(mockMsiManager.identities()).thenReturn(mockIdentities) + when(mockIdentities.define(testDisplayName.value)).thenReturn(mockIdentitiesBlank) + when(mockIdentitiesBlank.withRegion(Region.US_EAST)).thenReturn(mockIdentityWithGroup) + when(mockIdentityWithGroup.withExistingResourceGroup(testMrgCoordinates.managedResourceGroupName.value)).thenReturn(mockIdentityWithCreate) + when(mockIdentityWithCreate.withTags(any[java.util.Map[String, String]])).thenReturn(mockIdentityWithCreate) + when(mockIdentityWithCreate.create(any[Context])).thenReturn(mockIdentity) + when(mockIdentity.id()).thenReturn(testObjectId.value) + when(mockIdentity.name()).thenReturn(testDisplayName.value) + when(mockDirectoryDAO.createActionManagedIdentity(testActionManagedIdentity, testSamRequestContext)).thenReturn(IO.pure(testActionManagedIdentity)) + + // Act + val (ami, created) = + azureService.getOrCreateActionManagedIdentity(testResource, testAction, testBillingProfileId, testSamRequestContext).unsafeRunSync() + + // Assert + ami should be(testActionManagedIdentity) + created should be(true) + } + + "retrieve an existing Action Managed Identity" in { + // Arrange + val mockCrlService = mock[CrlService] + val mockDirectoryDAO = mock[DirectoryDAO] + val mockAzureManagedResourceGroupDAO = mock[AzureManagedResourceGroupDAO] + + val azureService = new AzureService(mockCrlService, mockDirectoryDAO, mockAzureManagedResourceGroupDAO) + + val testMrgCoordinates = ManagedResourceGroupCoordinates( + TenantId(UUID.randomUUID().toString), + SubscriptionId(UUID.randomUUID().toString), + ManagedResourceGroupName(UUID.randomUUID().toString) + ) + val testBillingProfileId = BillingProfileId(UUID.randomUUID().toString) + val testAction = ResourceAction("testAction") + val testResource = FullyQualifiedResourceId(ResourceTypeName("testResourceType"), ResourceId("testResource")) + val testActionManagedIdentityId = ActionManagedIdentityId(testResource, testAction, testBillingProfileId) + val testDisplayName = azureService.toManagedIdentityNameFromAmiId(testActionManagedIdentityId) + val testObjectId = ManagedIdentityObjectId(UUID.randomUUID().toString) + val testActionManagedIdentity = ActionManagedIdentity(testActionManagedIdentityId, testObjectId, testDisplayName, testMrgCoordinates) + + when(mockDirectoryDAO.loadActionManagedIdentity(testActionManagedIdentityId, samRequestContext)).thenReturn(IO.pure(Option(testActionManagedIdentity))) + + // Act + val (ami, created) = + azureService.getOrCreateActionManagedIdentity(testResource, testAction, testBillingProfileId, samRequestContext).unsafeRunSync() + + // Assert + ami should be(testActionManagedIdentity) + created should be(false) + verify(mockCrlService, never).buildMsiManager(testMrgCoordinates.tenantId, testMrgCoordinates.subscriptionId) + } + + "delete an existing Action Managed Identity" in { + // Arrange + val mockCrlService = mock[CrlService] + val mockDirectoryDAO = mock[DirectoryDAO] + val mockAzureManagedResourceGroupDAO = mock[AzureManagedResourceGroupDAO] + val mockMsiManager = mock[MsiManager] + val mockIdentities = mock[Identities] + + val azureService = new AzureService(mockCrlService, mockDirectoryDAO, mockAzureManagedResourceGroupDAO) + + val testMrgCoordinates = ManagedResourceGroupCoordinates( + TenantId(UUID.randomUUID().toString), + SubscriptionId(UUID.randomUUID().toString), + ManagedResourceGroupName(UUID.randomUUID().toString) + ) + val testBillingProfileId = BillingProfileId(UUID.randomUUID().toString) + val testAction = ResourceAction("testAction") + val testResource = FullyQualifiedResourceId(ResourceTypeName("testResourceType"), ResourceId("testResource")) + val testActionManagedIdentityId = ActionManagedIdentityId(testResource, testAction, testBillingProfileId) + val testDisplayName = azureService.toManagedIdentityNameFromAmiId(testActionManagedIdentityId) + val testObjectId = ManagedIdentityObjectId(UUID.randomUUID().toString) + val testActionManagedIdentity = ActionManagedIdentity(testActionManagedIdentityId, testObjectId, testDisplayName, testMrgCoordinates) + + when(mockDirectoryDAO.loadActionManagedIdentity(testActionManagedIdentityId, samRequestContext)).thenReturn(IO.pure(Option(testActionManagedIdentity))) + when(mockCrlService.buildMsiManager(testMrgCoordinates.tenantId, testMrgCoordinates.subscriptionId)).thenReturn(IO.pure(mockMsiManager)) + when(mockMsiManager.identities()).thenReturn(mockIdentities) + doNothing.when(mockIdentities).deleteById(testObjectId.value) + when(mockDirectoryDAO.deleteActionManagedIdentity(testActionManagedIdentityId, samRequestContext)).thenReturn(IO.pure(())) + + // Act & Assert + azureService.deleteActionManagedIdentity(testActionManagedIdentityId, samRequestContext).unsafeRunSync() + } + } + + "Managed Resource Groups" - { + "delete action managed identities when deleting managed resource group" in { + // Arrange + val mockCrlService = mock[CrlService] + val mockDirectoryDAO = mock[DirectoryDAO] + val mockAzureManagedResourceGroupDAO = mock[AzureManagedResourceGroupDAO] + val mockMsiManager = mock[MsiManager] + val mockIdentities = mock[Identities] + + val azureService = new AzureService(mockCrlService, mockDirectoryDAO, mockAzureManagedResourceGroupDAO) + + val testMrgCoordinates = ManagedResourceGroupCoordinates( + TenantId(UUID.randomUUID().toString), + SubscriptionId(UUID.randomUUID().toString), + ManagedResourceGroupName(UUID.randomUUID().toString) + ) + val testBillingProfileId = BillingProfileId(UUID.randomUUID().toString) + val testManagedResourceGroup = ManagedResourceGroup(testMrgCoordinates, testBillingProfileId) + val testAction = ResourceAction("testAction") + val testResource = FullyQualifiedResourceId(ResourceTypeName("testResourceType"), ResourceId("testResource")) + val testActionManagedIdentityId = ActionManagedIdentityId(testResource, testAction, testBillingProfileId) + val testDisplayName = azureService.toManagedIdentityNameFromAmiId(testActionManagedIdentityId) + val testObjectId = ManagedIdentityObjectId(UUID.randomUUID().toString) + val testActionManagedIdentity = ActionManagedIdentity(testActionManagedIdentityId, testObjectId, testDisplayName, testMrgCoordinates) + + val testAction2 = ResourceAction("testAction2") + val testDisplayName2 = azureService.toManagedIdentityNameFromAmiId(testActionManagedIdentityId) + val testObjectId2 = ManagedIdentityObjectId(UUID.randomUUID().toString) + val testActionManagedIdentityId2 = ActionManagedIdentityId(testResource, testAction2, testBillingProfileId) + val testActionManagedIdentity2 = ActionManagedIdentity(testActionManagedIdentityId2, testObjectId2, testDisplayName2, testMrgCoordinates) + + when(mockDirectoryDAO.getAllActionManagedIdentitiesForBillingProfile(testActionManagedIdentityId.billingProfileId, samRequestContext)) + .thenReturn(IO.pure(Seq(testActionManagedIdentity, testActionManagedIdentity2))) + when(mockAzureManagedResourceGroupDAO.getManagedResourceGroupByBillingProfileId(testBillingProfileId, samRequestContext)) + .thenReturn(IO.pure(Some(testManagedResourceGroup))) + when(mockCrlService.buildMsiManager(testMrgCoordinates.tenantId, testMrgCoordinates.subscriptionId)).thenReturn(IO.pure(mockMsiManager)) + when(mockMsiManager.identities()).thenReturn(mockIdentities) + doNothing.when(mockIdentities).deleteById(testObjectId.value) + doNothing.when(mockIdentities).deleteById(testObjectId2.value) + when(mockDirectoryDAO.loadActionManagedIdentity(testActionManagedIdentityId, samRequestContext)).thenReturn(IO.pure(Some(testActionManagedIdentity))) + when(mockDirectoryDAO.loadActionManagedIdentity(testActionManagedIdentityId2, samRequestContext)).thenReturn(IO.pure(Some(testActionManagedIdentity2))) + when(mockDirectoryDAO.deleteActionManagedIdentity(testActionManagedIdentityId, samRequestContext)).thenReturn(IO.pure(())) + when(mockDirectoryDAO.deleteActionManagedIdentity(testActionManagedIdentityId2, samRequestContext)).thenReturn(IO.pure(())) + when(mockAzureManagedResourceGroupDAO.deleteManagedResourceGroup(testBillingProfileId, samRequestContext)).thenReturn(IO.pure(1)) + + // Act & Assert + azureService.deleteManagedResourceGroup(testBillingProfileId, samRequestContext).unsafeRunSync() + } + } + } +} diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala index 2ebef8165..58b4aa77c 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/MockDirectoryDAO.scala @@ -6,10 +6,17 @@ import cats.implicits._ import org.broadinstitute.dsde.workbench.model._ import org.broadinstitute.dsde.workbench.model.google.ServiceAccountSubjectId import org.broadinstitute.dsde.workbench.sam._ -import org.broadinstitute.dsde.workbench.sam.azure.{ManagedIdentityObjectId, PetManagedIdentity, PetManagedIdentityId} +import org.broadinstitute.dsde.workbench.sam.azure.{ + ActionManagedIdentity, + ActionManagedIdentityId, + BillingProfileId, + ManagedIdentityObjectId, + PetManagedIdentity, + PetManagedIdentityId +} import org.broadinstitute.dsde.workbench.sam.db.tables.TosTable import org.broadinstitute.dsde.workbench.sam.model.api.{AdminUpdateUserRequest, SamUser, SamUserAttributes} -import org.broadinstitute.dsde.workbench.sam.model.{AccessPolicy, BasicWorkbenchGroup, SamUserTos} +import org.broadinstitute.dsde.workbench.sam.model.{AccessPolicy, BasicWorkbenchGroup, FullyQualifiedResourceId, ResourceAction, SamUserTos} import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext import java.time.Instant @@ -397,6 +404,38 @@ class MockDirectoryDAO(val groups: mutable.Map[WorkbenchGroupIdentity, Workbench override def getUserFromPetManagedIdentity(petManagedIdentityObjectId: ManagedIdentityObjectId, samRequestContext: SamRequestContext): IO[Option[SamUser]] = IO.pure(None) + override def createActionManagedIdentity(actionManagedIdentity: ActionManagedIdentity, samRequestContext: SamRequestContext): IO[ActionManagedIdentity] = ??? + + override def loadActionManagedIdentity( + actionManagedIdentityId: ActionManagedIdentityId, + samRequestContext: SamRequestContext + ): IO[Option[ActionManagedIdentity]] = ??? + + override def loadActionManagedIdentity( + resource: FullyQualifiedResourceId, + action: ResourceAction, + samRequestContext: SamRequestContext + ): IO[Option[ActionManagedIdentity]] = ??? + + override def updateActionManagedIdentity(actionManagedIdentity: ActionManagedIdentity, samRequestContext: SamRequestContext): IO[ActionManagedIdentity] = ??? + + override def deleteActionManagedIdentity(actionManagedIdentityId: ActionManagedIdentityId, samRequestContext: SamRequestContext): IO[Unit] = ??? + + override def getAllActionManagedIdentitiesForResource( + resourceId: FullyQualifiedResourceId, + samRequestContext: SamRequestContext + ): IO[Seq[ActionManagedIdentity]] = IO.pure(Seq.empty) + + override def deleteAllActionManagedIdentitiesForResource(resourceId: FullyQualifiedResourceId, samRequestContext: SamRequestContext): IO[Unit] = + ??? + + override def getAllActionManagedIdentitiesForBillingProfile( + billingProfileId: BillingProfileId, + samRequestContext: SamRequestContext + ): IO[Seq[ActionManagedIdentity]] = IO.pure(Seq.empty) + + override def deleteAllActionManagedIdentitiesForBillingProfile(billingProfileId: BillingProfileId, samRequestContext: SamRequestContext): IO[Unit] = ??? + override def setUserRegisteredAt(userId: WorkbenchUserId, registeredAt: Instant, samRequestContext: SamRequestContext): IO[Unit] = loadUser(userId, samRequestContext).map { case None => diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala index 182f22dbc..f277f595f 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/dataAccess/PostgresDirectoryDAOSpec.scala @@ -26,6 +26,7 @@ import scala.concurrent.duration._ class PostgresDirectoryDAOSpec extends RetryableAnyFreeSpec with Matchers with BeforeAndAfterEach with TimeMatchers with OptionValues { val dao = new PostgresDirectoryDAO(TestSupport.dbRef, TestSupport.dbRef) val policyDAO = new PostgresAccessPolicyDAO(TestSupport.dbRef, TestSupport.dbRef) + val azureManagedResourceGroupDAO = new PostgresAzureManagedResourceGroupDAO(TestSupport.dbRef, TestSupport.dbRef) val defaultGroupName: WorkbenchGroupName = WorkbenchGroupName("group") val defaultGroup: BasicWorkbenchGroup = BasicWorkbenchGroup(defaultGroupName, Set.empty, WorkbenchEmail("foo@bar.com")) @@ -66,6 +67,27 @@ class PostgresDirectoryDAOSpec extends RetryableAnyFreeSpec with Matchers with B public = false ) + val defaultTenantId = TenantId("testTenant") + val defaultSubscriptionId = SubscriptionId(UUID.randomUUID().toString) + val defaultManagedResourceGroupName = ManagedResourceGroupName("mrg-test") + val defaultManagedResourceGroupCoordinates = ManagedResourceGroupCoordinates(defaultTenantId, defaultSubscriptionId, defaultManagedResourceGroupName) + val defaultBillingProfileId = BillingProfileId(UUID.randomUUID().toString) + val defaultBillingProfileResource = defaultResource.copy(resourceId = defaultBillingProfileId.asResourceId) + val defaultManagedResourceGroup = ManagedResourceGroup(defaultManagedResourceGroupCoordinates, defaultBillingProfileId) + + val defaultActionManagedIdentities: Set[ActionManagedIdentity] = Set(readAction, writeAction).map(action => + ActionManagedIdentity( + ActionManagedIdentityId( + FullyQualifiedResourceId(defaultResource.resourceTypeName, defaultResource.resourceId), + action, + defaultBillingProfileId + ), + ManagedIdentityObjectId(UUID.randomUUID().toString), + ManagedIdentityDisplayName(s"whoCares-$action"), + defaultManagedResourceGroupCoordinates + ) + ) + override protected def beforeEach(): Unit = TestSupport.truncateAll @@ -1886,5 +1908,77 @@ class PostgresDirectoryDAOSpec extends RetryableAnyFreeSpec with Matchers with B retrievedAttributes should be(Some(upsertedAttributes)) } } + + "Action Managed Identities" - { + "can be individually created, read, updated, and deleted" in { + assume(databaseEnabled, databaseEnabledClue) + policyDAO.createResourceType(resourceType, samRequestContext).unsafeRunSync() + policyDAO.createResource(defaultResource, samRequestContext).unsafeRunSync() + policyDAO.createResource(defaultBillingProfileResource, samRequestContext).unsafeRunSync() + azureManagedResourceGroupDAO.insertManagedResourceGroup(defaultManagedResourceGroup, samRequestContext).unsafeRunSync() + + defaultActionManagedIdentities.map(dao.createActionManagedIdentity(_, samRequestContext).unsafeRunSync()) + + val readActionManagedIdentity = defaultActionManagedIdentities.find(_.id.action == readAction) + val loadedReadActionManagedIdentity = dao.loadActionManagedIdentity(readActionManagedIdentity.get.id, samRequestContext).unsafeRunSync() + loadedReadActionManagedIdentity should be(readActionManagedIdentity) + + val writeActionManagedIdentity = defaultActionManagedIdentities.find(_.id.action == writeAction) + val loadedWriteActionManagedIdentity = dao.loadActionManagedIdentity(writeActionManagedIdentity.get.id, samRequestContext).unsafeRunSync() + loadedWriteActionManagedIdentity should be(writeActionManagedIdentity) + + val updatedActionManagedIdentity = writeActionManagedIdentity.get.copy( + objectId = ManagedIdentityObjectId(UUID.randomUUID().toString), + displayName = ManagedIdentityDisplayName("newDisplayName") + ) + + dao.updateActionManagedIdentity(updatedActionManagedIdentity, samRequestContext).unsafeRunSync() + + val loadedUpdatedActionManagedIdentity = dao.loadActionManagedIdentity(updatedActionManagedIdentity.id, samRequestContext).unsafeRunSync() + loadedUpdatedActionManagedIdentity should be(Some(updatedActionManagedIdentity)) + + dao.deleteActionManagedIdentity(readActionManagedIdentity.get.id, samRequestContext).unsafeRunSync() + dao.deleteActionManagedIdentity(writeActionManagedIdentity.get.id, samRequestContext).unsafeRunSync() + + dao.loadActionManagedIdentity(readActionManagedIdentity.get.id, samRequestContext).unsafeRunSync() should be(None) + dao.loadActionManagedIdentity(writeActionManagedIdentity.get.id, samRequestContext).unsafeRunSync() should be(None) + } + + "can be loaded for a resource and action" in { + assume(databaseEnabled, databaseEnabledClue) + policyDAO.createResourceType(resourceType, samRequestContext).unsafeRunSync() + policyDAO.createResource(defaultResource, samRequestContext).unsafeRunSync() + policyDAO.createResource(defaultBillingProfileResource, samRequestContext).unsafeRunSync() + azureManagedResourceGroupDAO.insertManagedResourceGroup(defaultManagedResourceGroup, samRequestContext).unsafeRunSync() + + defaultActionManagedIdentities.map(dao.createActionManagedIdentity(_, samRequestContext).unsafeRunSync()) + + val readActionManagedIdentity = defaultActionManagedIdentities.find(_.id.action == readAction) + val loadedReadActionManagedIdentity = dao.loadActionManagedIdentity(defaultResource.fullyQualifiedId, readAction, samRequestContext).unsafeRunSync() + loadedReadActionManagedIdentity should be(readActionManagedIdentity) + + val writeActionManagedIdentity = defaultActionManagedIdentities.find(_.id.action == writeAction) + val loadedWriteActionManagedIdentity = dao.loadActionManagedIdentity(defaultResource.fullyQualifiedId, writeAction, samRequestContext).unsafeRunSync() + loadedWriteActionManagedIdentity should be(writeActionManagedIdentity) + } + + "can be read, and deleted en mass for a resource" in { + assume(databaseEnabled, databaseEnabledClue) + policyDAO.createResourceType(resourceType, samRequestContext).unsafeRunSync() + policyDAO.createResource(defaultResource, samRequestContext).unsafeRunSync() + policyDAO.createResource(defaultBillingProfileResource, samRequestContext).unsafeRunSync() + azureManagedResourceGroupDAO.insertManagedResourceGroup(defaultManagedResourceGroup, samRequestContext).unsafeRunSync() + + defaultActionManagedIdentities.map(dao.createActionManagedIdentity(_, samRequestContext).unsafeRunSync()) + + val bothLoadedServiceAccounts = + dao.getAllActionManagedIdentitiesForResource(defaultResource.fullyQualifiedId, samRequestContext).unsafeRunSync().toSet + bothLoadedServiceAccounts should be(defaultActionManagedIdentities) + + dao.deleteAllActionManagedIdentitiesForResource(defaultResource.fullyQualifiedId, samRequestContext).unsafeRunSync() + + dao.getAllActionManagedIdentitiesForResource(defaultResource.fullyQualifiedId, samRequestContext).unsafeRunSync() should be(Seq.empty) + } + } } } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupServiceSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupServiceSpec.scala index 85b90173a..e734a510b 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupServiceSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ManagedGroupServiceSpec.scala @@ -152,7 +152,7 @@ class ManagedGroupServiceSpec val exception = intercept[WorkbenchExceptionWithErrorReport] { runAndWait(managedGroupService.createManagedGroup(ResourceId(groupName), dummyUser, samRequestContext = samRequestContext)) } - exception.getMessage should include("A resource of this type and name already exists") + exception.getMessage should include(s"subject with email $groupName@$testDomain already exists") managedGroupService.loadManagedGroup(resourceId, samRequestContext).unsafeRunSync() shouldEqual None } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala index 5c4801a9d..636116e8e 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/ResourceServiceSpec.scala @@ -13,8 +13,28 @@ import org.broadinstitute.dsde.workbench.sam import org.broadinstitute.dsde.workbench.sam.Generator._ import org.broadinstitute.dsde.workbench.sam.TestSupport.{databaseEnabled, databaseEnabledClue} import org.broadinstitute.dsde.workbench.sam.audit._ +import org.broadinstitute.dsde.workbench.sam.azure.{ + ActionManagedIdentity, + ActionManagedIdentityId, + AzureService, + BillingProfileId, + ManagedIdentityDisplayName, + ManagedIdentityObjectId, + ManagedResourceGroup, + ManagedResourceGroupCoordinates, + ManagedResourceGroupName, + SubscriptionId, + TenantId +} import org.broadinstitute.dsde.workbench.sam.config.AppConfig.resourceTypeReader -import org.broadinstitute.dsde.workbench.sam.dataAccess.{AccessPolicyDAO, DirectoryDAO, PostgresAccessPolicyDAO, PostgresDirectoryDAO} +import org.broadinstitute.dsde.workbench.sam.dataAccess.{ + AccessPolicyDAO, + AzureManagedResourceGroupDAO, + DirectoryDAO, + PostgresAccessPolicyDAO, + PostgresAzureManagedResourceGroupDAO, + PostgresDirectoryDAO +} import org.broadinstitute.dsde.workbench.sam.model._ import org.broadinstitute.dsde.workbench.sam.model.api._ import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext @@ -51,6 +71,7 @@ class ResourceServiceSpec lazy val dirDAO: DirectoryDAO = new PostgresDirectoryDAO(TestSupport.dbRef, TestSupport.dbRef) lazy val policyDAO: AccessPolicyDAO = new PostgresAccessPolicyDAO(TestSupport.dbRef, TestSupport.dbRef) + lazy val azureManagedResourceGroupDAO: AzureManagedResourceGroupDAO = new PostgresAzureManagedResourceGroupDAO(TestSupport.dbRef, TestSupport.dbRef) private val ownerRoleName = ResourceRoleName("owner") @@ -137,6 +158,10 @@ class ResourceServiceSpec private val constrainableService = new ResourceService(constrainableResourceTypes, constrainablePolicyEvaluatorService, policyDAO, dirDAO, NoExtensions, emailDomain, Set.empty) + val mockAzureService = mock[AzureService] + private val serviceWithAzure = + new ResourceService(resourceTypes, policyEvaluatorService, policyDAO, dirDAO, NoExtensions, emailDomain, Set("test.firecloud.org"), Some(mockAzureService)) + val managedGroupService = new ManagedGroupService(constrainableService, constrainablePolicyEvaluatorService, constrainableResourceTypes, policyDAO, dirDAO, NoExtensions, emailDomain) @@ -1779,6 +1804,54 @@ class ResourceServiceSpec testDeleteResource(managedGroupResourceType) } + it should "delete any action managed identites for the resource while it deletes the resource" in { + assume(databaseEnabled, databaseEnabledClue) + + val resource = FullyQualifiedResourceId(defaultResourceType.name, ResourceId("my-resource")) + // There's no actual need for it to be a real "billing profile", we just need a resource to attach the managed resource group to. + val billingProfileResource = FullyQualifiedResourceId(defaultResourceType.name, ResourceId(UUID.randomUUID().toString)) + val ownerRoleActions = defaultResourceType.roles.find(_.roleName == defaultResourceType.ownerRoleName).get.actions + + val managedResourceGroupCoordinates = ManagedResourceGroupCoordinates( + TenantId(UUID.randomUUID().toString), + SubscriptionId(UUID.randomUUID().toString), + ManagedResourceGroupName(UUID.randomUUID().toString) + ) + + val managedResourceGroup = ManagedResourceGroup(managedResourceGroupCoordinates, BillingProfileId(billingProfileResource.resourceId.value)) + + serviceWithAzure.createResourceType(defaultResourceType, samRequestContext).unsafeRunSync() + runAndWait(serviceWithAzure.createResource(defaultResourceType, resource.resourceId, dummyUser, samRequestContext)) + runAndWait(serviceWithAzure.createResource(defaultResourceType, billingProfileResource.resourceId, dummyUser, samRequestContext)) + runAndWait(azureManagedResourceGroupDAO.insertManagedResourceGroup(managedResourceGroup, samRequestContext)) + ownerRoleActions.foreach { action => + val ami = ActionManagedIdentity( + ActionManagedIdentityId( + resource, + action, + BillingProfileId(billingProfileResource.resourceId.value) + ), + ManagedIdentityObjectId(UUID.randomUUID().toString), + ManagedIdentityDisplayName(s"${resource.resourceId.value}-${action.value}"), + managedResourceGroupCoordinates + ) + runAndWait(dirDAO.createActionManagedIdentity(ami, samRequestContext)) + } + + assert(dirDAO.getAllActionManagedIdentitiesForResource(resource, samRequestContext).unsafeRunSync().nonEmpty) + + when(mockAzureService.deleteActionManagedIdentity(any[ActionManagedIdentityId], any[SamRequestContext])).thenReturn(IO.unit) + runAndWait(serviceWithAzure.deleteResource(resource, samRequestContext)) + + assert(dirDAO.getAllActionManagedIdentitiesForResource(resource, samRequestContext).unsafeRunSync().isEmpty) + ownerRoleActions.foreach { action => + verify(mockAzureService).deleteActionManagedIdentity( + argThat((arg: ActionManagedIdentityId) => arg.action.equals(action) && arg.resourceId.equals(resource)), + eqTo(samRequestContext) + ) + } + } + private def testDeleteResource(resourceType: ResourceType) = { val parentResource = FullyQualifiedResourceId(resourceType.name, ResourceId("my-resource-parent")) val childResource = FullyQualifiedResourceId(resourceType.name, ResourceId("my-resource-child")) diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala index 131f636d4..d00f6491e 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpec.scala @@ -8,7 +8,7 @@ import cats.effect.IO import cats.effect.unsafe.implicits.{global => globalEc} import org.broadinstitute.dsde.workbench.model._ import org.broadinstitute.dsde.workbench.sam.Generator.{arbNonPetEmail => _, _} -import org.broadinstitute.dsde.workbench.sam.TestSupport.{databaseEnabled, databaseEnabledClue} +import org.broadinstitute.dsde.workbench.sam.TestSupport.{databaseEnabled, databaseEnabledClue, truncateAll} import org.broadinstitute.dsde.workbench.sam.dataAccess.{DirectoryDAO, PostgresDirectoryDAO} import org.broadinstitute.dsde.workbench.sam.google.GoogleExtensions import org.broadinstitute.dsde.workbench.sam.matchers.BeSameUserMatcher.beSameUserAs @@ -35,7 +35,7 @@ import scala.concurrent.duration._ // TODO: continue breaking down old UserServiceSpec tests into nested suites // See: https://www.scalatest.org/scaladoc/3.2.3/org/scalatest/Suite.html -class UserServiceSpec(_system: ActorSystem) extends TestKit(_system) with Suite with BeforeAndAfterAll { +class UserServiceSpec(_system: ActorSystem) extends TestKit(_system) with Suite with BeforeAndAfterAll with BeforeAndAfterEach { override def nestedSuites: IndexedSeq[Suite] = IndexedSeq( new CreateUserSpec, @@ -54,6 +54,11 @@ class UserServiceSpec(_system: ActorSystem) extends TestKit(_system) with Suite TestKit.shutdownActorSystem(system) super.afterAll() } + + override def beforeEach(): Unit = { + truncateAll + super.beforeEach() + } } // This test suite is deprecated. It is still used and still has valid tests in it, but it should be broken out @@ -641,7 +646,7 @@ class OldUserServiceSpec(_system: ActorSystem) implicit val arbEmail: Arbitrary[WorkbenchEmail] = Arbitrary(genEmail) forAll { email: WorkbenchEmail => - assert(service.validateEmailAddress(email, Seq.empty).attempt.unsafeRunSync().isRight) + assert(service.validateEmailAddress(email, Seq.empty, Seq.empty).attempt.unsafeRunSync().isRight) } } @@ -662,7 +667,7 @@ class OldUserServiceSpec(_system: ActorSystem) implicit val arbEmail: Arbitrary[WorkbenchEmail] = Arbitrary(genEmail) forAll { email: WorkbenchEmail => - assert(service.validateEmailAddress(email, Seq.empty).attempt.unsafeRunSync().isLeft) + assert(service.validateEmailAddress(email, Seq.empty, Seq.empty).attempt.unsafeRunSync().isLeft) } } @@ -683,7 +688,7 @@ class OldUserServiceSpec(_system: ActorSystem) implicit val arbEmail: Arbitrary[WorkbenchEmail] = Arbitrary(genEmail) forAll { email: WorkbenchEmail => - assert(service.validateEmailAddress(email, Seq.empty).attempt.unsafeRunSync().isLeft) + assert(service.validateEmailAddress(email, Seq.empty, Seq.empty).attempt.unsafeRunSync().isLeft) } } @@ -697,12 +702,17 @@ class OldUserServiceSpec(_system: ActorSystem) implicit val arbEmail: Arbitrary[WorkbenchEmail] = Arbitrary(genEmail) forAll { email: WorkbenchEmail => - assert(service.validateEmailAddress(email, Seq.empty).attempt.unsafeRunSync().isLeft) + assert(service.validateEmailAddress(email, Seq.empty, Seq.empty).attempt.unsafeRunSync().isLeft) } } it should "reject blocked email domain" in { - assert(service.validateEmailAddress(WorkbenchEmail("foo@splat.bar.com"), Seq("bar.com")).attempt.unsafeRunSync().isLeft) - assert(service.validateEmailAddress(WorkbenchEmail("foo@bar.com"), Seq("bar.com")).attempt.unsafeRunSync().isLeft) + assert(service.validateEmailAddress(WorkbenchEmail("foo@splat.bar.com"), Seq("bar.com"), Seq.empty).attempt.unsafeRunSync().isLeft) + assert(service.validateEmailAddress(WorkbenchEmail("foo@bar.com"), Seq("bar.com"), Seq.empty).attempt.unsafeRunSync().isLeft) + } + + it should "reject an un-invitable email domain" in { + assert(service.validateEmailAddress(WorkbenchEmail("foo@splat.bar.com"), Seq.empty, Seq("bar.com")).attempt.unsafeRunSync().isLeft) + assert(service.validateEmailAddress(WorkbenchEmail("foo@bar.com"), Seq.empty, Seq("bar.com")).attempt.unsafeRunSync().isLeft) } } diff --git a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpecs/CreateUserSpec.scala b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpecs/CreateUserSpec.scala index f6953a893..3940a4823 100644 --- a/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpecs/CreateUserSpec.scala +++ b/src/test/scala/org/broadinstitute/dsde/workbench/sam/service/UserServiceSpecs/CreateUserSpec.scala @@ -383,6 +383,23 @@ class CreateUserSpec extends UserServiceTestTraits { describe("An invited User") { + it("should not be able to be invited with a non-invitable domain") { + // Arrange + val nonInvitableDomain = "non-invitable-domain.com" + val invitedGoogleUser = genWorkbenchUserGoogle.sample.get.copy(email = WorkbenchEmail(s"user@$nonInvitableDomain")) + val directoryDAO = MockDirectoryDaoBuilder(allUsersGroup).build + val cloudExtensions = MockCloudExtensionsBuilder(allUsersGroup).build + val userService = new UserService(directoryDAO, cloudExtensions, Seq.empty, defaultTosService, None, Seq(nonInvitableDomain)) + + // Act + val exception = intercept[WorkbenchExceptionWithErrorReport] { + runAndWait(userService.inviteUser(invitedGoogleUser.email, samRequestContext)) + } + + // Assert + exception.errorReport.message should include("Email domain cannot be invited") + } + it("should be able to be invited with no marketing consent") { // Arrange val invitedGoogleUser = genWorkbenchUserGoogle.sample.get