From dc918395ff12c50ff891846303217f6f2e7e74ea Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Dec 2024 16:39:58 +0500 Subject: [PATCH] Allow specifying `vm_service_account` in GCP config (#2110) * Support vm_service_account for GCP instances * Support Shared VPC and vm_service_account for TPUs * Support vm_service_account for GCP gateways * Fix extra dots --- docs/docs/reference/server/config.yml.md | 16 +++++++++++----- .../_internal/core/backends/gcp/compute.py | 5 ++++- .../_internal/core/backends/gcp/resources.py | 9 ++++++++- src/dstack/_internal/core/models/backends/gcp.py | 2 ++ src/dstack/_internal/server/services/config.py | 6 ++++++ 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/docs/docs/reference/server/config.yml.md b/docs/docs/reference/server/config.yml.md index f9d46a149..d1f51168f 100644 --- a/docs/docs/reference/server/config.yml.md +++ b/docs/docs/reference/server/config.yml.md @@ -438,6 +438,7 @@ gcloud projects list --format="json(projectId)" compute.instances.get compute.instances.setLabels compute.instances.setMetadata + compute.instances.setServiceAccount compute.instances.setTags compute.networks.get compute.networks.updatePolicy @@ -511,11 +512,16 @@ gcloud projects list --format="json(projectId)" ``` - - To use a shared VPC, that VPC has to be configured with two additional firewall rules: - - * Allow `INGRESS` traffic on port `22`, with the target tag `dstack-runner-instance` - * Allow `INGRESS` traffic on ports `22`, `80`, `443`, with the target tag `dstack-gateway-instance` + + When using a Shared VPC, ensure there is a firewall rule allowing `INGRESS` traffic on port `22`. + You can limit this rule to `dstack` instances using the `dstack-runner-instance` target tag. + + When using GCP gateways with a Shared VPC, also ensure there is a firewall rule allowing `INGRESS` traffic on ports `22`, `80`, `443`. + You can limit this rule to `dstack` gateway instances using the `dstack-gateway-instance` target tag. + + To use TPUs with a Shared VPC, you need to grant the TPU Service Account in your service project permissions + to manage resources in the host project by granting the "TPU Shared VPC Agent" (roles/tpu.xpnAgent) role + ([more in the GCP docs](https://cloud.google.com/tpu/docs/shared-vpc-networks#vpc-shared-vpc)). ??? info "Private subnets" By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index a30037cbd..fb252de9a 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -195,8 +195,10 @@ def create_instance( authorized_keys=authorized_keys, spot=instance_offer.instance.resources.spot, labels=labels, + network=self.config.vpc_resource_name, subnetwork=subnetwork, allocate_public_ip=allocate_public_ip, + service_account=self.config.vm_service_account, ) create_node_request = tpu_v2.CreateNodeRequest( parent=f"projects/{self.config.project_id}/locations/{zone}", @@ -257,6 +259,7 @@ def create_instance( tags=[gcp_resources.DSTACK_INSTANCE_TAG], instance_name=instance_name, zone=zone, + service_account=self.config.vm_service_account, network=self.config.vpc_resource_name, subnetwork=subnetwork, allocate_public_ip=allocate_public_ip, @@ -425,7 +428,7 @@ def create_gateway( tags=[gcp_resources.DSTACK_GATEWAY_TAG], instance_name=configuration.instance_name, zone=zone, - service_account=None, + service_account=self.config.vm_service_account, network=self.config.vpc_resource_name, subnetwork=subnetwork, ) diff --git a/src/dstack/_internal/core/backends/gcp/resources.py b/src/dstack/_internal/core/backends/gcp/resources.py index 97ef4f040..de50aa023 100644 --- a/src/dstack/_internal/core/backends/gcp/resources.py +++ b/src/dstack/_internal/core/backends/gcp/resources.py @@ -365,22 +365,29 @@ def create_tpu_node_struct( authorized_keys: List[str], spot: bool, labels: Dict[str, str], + network: str = "global/networks/default", subnetwork: Optional[str] = None, allocate_public_ip: bool = True, + service_account: Optional[str] = None, ) -> tpu_v2.Node: node = tpu_v2.Node() if spot: node.scheduling_config = tpu_v2.SchedulingConfig(preemptible=True) node.accelerator_type = instance_name node.runtime_version = "tpu-ubuntu2204-base" - # subnetwork determines the network, so network shouldn't be specified node.network_config = tpu_v2.NetworkConfig( enable_external_ips=allocate_public_ip, + network=network, subnetwork=subnetwork, ) ssh_keys = "\n".join(f"ubuntu:{key}" for key in authorized_keys) node.metadata = {"ssh-keys": ssh_keys, "startup-script": startup_script} node.labels = labels + if service_account is not None: + node.service_account = tpu_v2.ServiceAccount( + email=service_account, + scope=["https://www.googleapis.com/auth/cloud-platform"], + ) return node diff --git a/src/dstack/_internal/core/models/backends/gcp.py b/src/dstack/_internal/core/models/backends/gcp.py index 2e196abcf..9557aa87f 100644 --- a/src/dstack/_internal/core/models/backends/gcp.py +++ b/src/dstack/_internal/core/models/backends/gcp.py @@ -15,6 +15,7 @@ class GCPConfigInfo(CoreModel): vpc_project_id: Optional[str] = None public_ips: Optional[bool] = None nat_check: Optional[bool] = None + vm_service_account: Optional[str] = None tags: Optional[Dict[str, str]] = None @@ -51,6 +52,7 @@ class GCPConfigInfoWithCredsPartial(CoreModel): vpc_project_id: Optional[str] = None public_ips: Optional[bool] nat_check: Optional[bool] = None + vm_service_account: Optional[str] = None tags: Optional[Dict[str, str]] = None diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py index 95ebe1a44..965854cc8 100644 --- a/src/dstack/_internal/server/services/config.py +++ b/src/dstack/_internal/server/services/config.py @@ -239,6 +239,9 @@ class GCPConfig(CoreModel): ) ), ] = None + vm_service_account: Annotated[ + Optional[str], Field(description="The service account associated with provisioned VMs") + ] = None tags: Annotated[ Optional[Dict[str, str]], Field( @@ -276,6 +279,9 @@ class GCPAPIConfig(CoreModel): ) ), ] = None + vm_service_account: Annotated[ + Optional[str], Field(description="The service account associated with provisioned VMs") + ] = None tags: Annotated[ Optional[Dict[str, str]], Field(