Skip to content

Commit

Permalink
Fix runtime version for TPU v6e (#2149)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Dec 27, 2024
1 parent 2df74b8 commit aca56fe
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
16 changes: 14 additions & 2 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def create_instance(
authorized_keys=authorized_keys,
spot=instance_offer.instance.resources.spot,
labels=labels,
runtime_version=_get_tpu_runtime_version(instance_offer.instance.name),
network=self.config.vpc_resource_name,
subnetwork=subnetwork,
allocate_public_ip=allocate_public_ip,
Expand Down Expand Up @@ -777,15 +778,26 @@ def _get_tpu_startup_script(authorized_keys: List[str]) -> str:
return startup_script


def _is_tpu(name: str) -> bool:
parts = name.split("-")
def _is_tpu(instance_name: str) -> bool:
parts = instance_name.split("-")
if len(parts) == 2:
version, cores = parts
if version in TPU_VERSIONS and cores.isdigit():
return True
return False


def _get_tpu_runtime_version(instance_name: str) -> str:
tpu_version = _get_tpu_version(instance_name)
if tpu_version == "v6e":
return "v2-alpha-tpuv6e"
return "tpu-ubuntu2204-base"


def _get_tpu_version(instance_name: str) -> str:
return instance_name.split("-")[0]


def _is_single_host_tpu(instance_name: str) -> bool:
parts = instance_name.split("-")
if len(parts) != 2:
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/core/backends/gcp/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def create_tpu_node_struct(
authorized_keys: List[str],
spot: bool,
labels: Dict[str, str],
runtime_version: str = "tpu-ubuntu2204-base",
network: str = "global/networks/default",
subnetwork: Optional[str] = None,
allocate_public_ip: bool = True,
Expand All @@ -375,7 +376,7 @@ def create_tpu_node_struct(
if spot:
node.scheduling_config = tpu_v2.SchedulingConfig(preemptible=True)
node.accelerator_type = instance_name
node.runtime_version = "tpu-ubuntu2204-base"
node.runtime_version = runtime_version
node.network_config = tpu_v2.NetworkConfig(
enable_external_ips=allocate_public_ip,
network=network,
Expand Down

0 comments on commit aca56fe

Please sign in to comment.