Skip to content

Commit

Permalink
Allow specifying vm_service_account in GCP config (#2110)
Browse files Browse the repository at this point in the history
* Support vm_service_account for GCP instances

* Support Shared VPC and vm_service_account for TPUs

* Support vm_service_account for GCP gateways

* Fix extra dots
  • Loading branch information
r4victor authored Dec 17, 2024
1 parent b43074e commit dc91839
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 7 deletions.
16 changes: 11 additions & 5 deletions docs/docs/reference/server/config.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ gcloud projects list --format="json(projectId)"
compute.instances.get
compute.instances.setLabels
compute.instances.setMetadata
compute.instances.setServiceAccount
compute.instances.setTags
compute.networks.get
compute.networks.updatePolicy
Expand Down Expand Up @@ -511,11 +512,16 @@ gcloud projects list --format="json(projectId)"
```

</div>

To use a shared VPC, that VPC has to be configured with two additional firewall rules:

* Allow `INGRESS` traffic on port `22`, with the target tag `dstack-runner-instance`
* Allow `INGRESS` traffic on ports `22`, `80`, `443`, with the target tag `dstack-gateway-instance`

When using a Shared VPC, ensure there is a firewall rule allowing `INGRESS` traffic on port `22`.
You can limit this rule to `dstack` instances using the `dstack-runner-instance` target tag.

When using GCP gateways with a Shared VPC, also ensure there is a firewall rule allowing `INGRESS` traffic on ports `22`, `80`, `443`.
You can limit this rule to `dstack` gateway instances using the `dstack-gateway-instance` target tag.

To use TPUs with a Shared VPC, you need to grant the TPU Service Account in your service project permissions
to manage resources in the host project by granting the "TPU Shared VPC Agent" (roles/tpu.xpnAgent) role
([more in the GCP docs](https://cloud.google.com/tpu/docs/shared-vpc-networks#vpc-shared-vpc)).

??? info "Private subnets"
By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic.
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,10 @@ def create_instance(
authorized_keys=authorized_keys,
spot=instance_offer.instance.resources.spot,
labels=labels,
network=self.config.vpc_resource_name,
subnetwork=subnetwork,
allocate_public_ip=allocate_public_ip,
service_account=self.config.vm_service_account,
)
create_node_request = tpu_v2.CreateNodeRequest(
parent=f"projects/{self.config.project_id}/locations/{zone}",
Expand Down Expand Up @@ -257,6 +259,7 @@ def create_instance(
tags=[gcp_resources.DSTACK_INSTANCE_TAG],
instance_name=instance_name,
zone=zone,
service_account=self.config.vm_service_account,
network=self.config.vpc_resource_name,
subnetwork=subnetwork,
allocate_public_ip=allocate_public_ip,
Expand Down Expand Up @@ -425,7 +428,7 @@ def create_gateway(
tags=[gcp_resources.DSTACK_GATEWAY_TAG],
instance_name=configuration.instance_name,
zone=zone,
service_account=None,
service_account=self.config.vm_service_account,
network=self.config.vpc_resource_name,
subnetwork=subnetwork,
)
Expand Down
9 changes: 8 additions & 1 deletion src/dstack/_internal/core/backends/gcp/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,22 +365,29 @@ def create_tpu_node_struct(
authorized_keys: List[str],
spot: bool,
labels: Dict[str, str],
network: str = "global/networks/default",
subnetwork: Optional[str] = None,
allocate_public_ip: bool = True,
service_account: Optional[str] = None,
) -> tpu_v2.Node:
node = tpu_v2.Node()
if spot:
node.scheduling_config = tpu_v2.SchedulingConfig(preemptible=True)
node.accelerator_type = instance_name
node.runtime_version = "tpu-ubuntu2204-base"
# subnetwork determines the network, so network shouldn't be specified
node.network_config = tpu_v2.NetworkConfig(
enable_external_ips=allocate_public_ip,
network=network,
subnetwork=subnetwork,
)
ssh_keys = "\n".join(f"ubuntu:{key}" for key in authorized_keys)
node.metadata = {"ssh-keys": ssh_keys, "startup-script": startup_script}
node.labels = labels
if service_account is not None:
node.service_account = tpu_v2.ServiceAccount(
email=service_account,
scope=["https://www.googleapis.com/auth/cloud-platform"],
)
return node


Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/backends/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class GCPConfigInfo(CoreModel):
vpc_project_id: Optional[str] = None
public_ips: Optional[bool] = None
nat_check: Optional[bool] = None
vm_service_account: Optional[str] = None
tags: Optional[Dict[str, str]] = None


Expand Down Expand Up @@ -51,6 +52,7 @@ class GCPConfigInfoWithCredsPartial(CoreModel):
vpc_project_id: Optional[str] = None
public_ips: Optional[bool]
nat_check: Optional[bool] = None
vm_service_account: Optional[str] = None
tags: Optional[Dict[str, str]] = None


Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/server/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ class GCPConfig(CoreModel):
)
),
] = None
vm_service_account: Annotated[
Optional[str], Field(description="The service account associated with provisioned VMs")
] = None
tags: Annotated[
Optional[Dict[str, str]],
Field(
Expand Down Expand Up @@ -276,6 +279,9 @@ class GCPAPIConfig(CoreModel):
)
),
] = None
vm_service_account: Annotated[
Optional[str], Field(description="The service account associated with provisioned VMs")
] = None
tags: Annotated[
Optional[Dict[str, str]],
Field(
Expand Down

0 comments on commit dc91839

Please sign in to comment.