Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add taint to user and worker nodes #2605

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5000f06
save progress
Adam-D-Lewis Jun 26, 2024
7ce8555
Merge branch 'develop' into node-taint
Adam-D-Lewis Aug 16, 2024
a661514
fix node taint check
Adam-D-Lewis Aug 16, 2024
fb55fab
Merge branch 'develop' into node-taint
Adam-D-Lewis Aug 19, 2024
7f1800d
fix node taints on gcp
Adam-D-Lewis Aug 19, 2024
40940f6
add latest changes
Adam-D-Lewis Aug 19, 2024
cdac5c6
merge develop
Adam-D-Lewis Aug 21, 2024
6382c7b
allow daemonsets to run on user node group
Adam-D-Lewis Aug 21, 2024
e9d9dd9
recreate node groups when taints change
Adam-D-Lewis Aug 21, 2024
c55cd5f
quick attempt to get scheduler running on tanted worker node group
Adam-D-Lewis Aug 21, 2024
57e6e09
Merge branch 'main' into node-taint
Adam-D-Lewis Oct 25, 2024
a1370c9
add default options to options_handler
Adam-D-Lewis Oct 25, 2024
0e7e11c
add comments
Adam-D-Lewis Oct 28, 2024
adb9d74
rename variable
Adam-D-Lewis Oct 31, 2024
7944071
add comment
Adam-D-Lewis Oct 31, 2024
fa81fb9
make work for all providers
Adam-D-Lewis Oct 31, 2024
da9fd82
move var back
Adam-D-Lewis Oct 31, 2024
6a1f81d
move var back
Adam-D-Lewis Oct 31, 2024
b4c08f3
move var back
Adam-D-Lewis Oct 31, 2024
9bae2a1
move var back
Adam-D-Lewis Oct 31, 2024
b3dbeda
add reference
Adam-D-Lewis Oct 31, 2024
97858d0
refactor
Adam-D-Lewis Nov 1, 2024
4ac7b9c
various fixes for aws and azure providers
Adam-D-Lewis Nov 1, 2024
480647b
Merge branch 'main' into node-taint
Adam-D-Lewis Nov 1, 2024
f6b9a4f
add taint conversion for AWS
Adam-D-Lewis Nov 4, 2024
e752a3a
add DEFAULT_.*_TAINT vars
Adam-D-Lewis Nov 4, 2024
59daa0c
clean up fixed TODOs
Adam-D-Lewis Nov 4, 2024
e05f143
more clean up
Adam-D-Lewis Nov 4, 2024
3a4ae6b
Merge branch 'main' into node-taint
Adam-D-Lewis Nov 4, 2024
f3cb2e9
fix test
Adam-D-Lewis Nov 4, 2024
b125e8c
fix test error
Adam-D-Lewis Nov 4, 2024
8f9f846
Merge branch 'main' into node-taint
dcmcand Nov 15, 2024
2264558
merge main
Adam-D-Lewis Dec 30, 2024
747a293
add test
Adam-D-Lewis Dec 30, 2024
964f360
Merge branch 'main' into node-taint
Adam-D-Lewis Dec 30, 2024
4f48462
Merge branch 'main' into node-taint
Adam-D-Lewis Jan 7, 2025
459ac01
small cleanup
Adam-D-Lewis Jan 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 158 additions & 35 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator

from _nebari import constants
from _nebari.provider import opentofu
Expand Down Expand Up @@ -39,15 +39,75 @@ class ExistingInputVars(schema.Base):
kube_context: str


class NodeGroup(schema.Base):
    """Node group settings shared by the provider-specific node group classes.

    ``taints`` accepts ``schema.Taint`` objects (or anything pydantic can
    coerce into one) as well as ``"key=value:Effect"`` strings. ``None``
    means the taints were not configured.
    """

    instance: str
    min_nodes: Annotated[int, Field(ge=0)] = 0
    max_nodes: Annotated[int, Field(ge=1)] = 1
    taints: Optional[List[schema.Taint]] = None

    @field_validator("taints", mode="before")
    def validate_taint_strings(cls, taints: Optional[List[Any]]):
        """Parse ``"key=value:Effect"`` strings into ``schema.Taint`` objects.

        Non-string entries are passed through untouched so that pydantic can
        validate them against ``schema.Taint`` afterwards.

        Raises:
            ValueError: if a string entry is not of the form
                ``key=value:Effect``.
        """
        if taints is None:
            return taints

        # Keys may carry a DNS-style prefix (e.g. "nvidia.com/gpu"), so allow
        # ".", "/" and "-" in the key and "." and "-" in the value.
        taint_str_regex = re.compile(r"([\w./-]+)=([\w.-]+):(\w+)")
        parsed_taints = []
        for taint in taints:
            if not isinstance(taint, str):
                parsed_taints.append(taint)
                continue

            # fullmatch: anything trailing the effect makes the string invalid
            # instead of being silently dropped.
            match = taint_str_regex.fullmatch(taint.strip())
            if not match:
                raise ValueError(f"Invalid taint string: {taint}")
            # Distinct local names; the original rebound `taints` here while
            # it was still being iterated.
            key, value, effect = match.groups()
            parsed_taints.append(schema.Taint(key=key, value=value, effect=effect))

        return parsed_taints


# Taints given to the "general" node group when its `taints` field is None:
# an empty list, i.e. the general node group is left untainted.
DEFAULT_GENERAL_NODE_GROUP_TAINTS = []
# Taints given to every other node group when its `taints` field is None.
DEFAULT_NODE_GROUP_TAINTS = [
    schema.Taint(key="dedicated", value="nebari", effect="NoSchedule")
]


def set_missing_taints_to_default_taints(
    node_groups: Dict[str, NodeGroup],
) -> Dict[str, NodeGroup]:
    """Fill in default taints for node groups that did not configure any.

    A node group whose ``taints`` is ``None`` receives
    ``DEFAULT_GENERAL_NODE_GROUP_TAINTS`` when it is named ``"general"`` and
    ``DEFAULT_NODE_GROUP_TAINTS`` otherwise. Node groups with explicit taints
    (including an empty list) are left unchanged.

    Args:
        node_groups: mapping of node group name to node group; modified in
            place.

    Returns:
        The same ``node_groups`` mapping.
    """
    for node_group_name, node_group in node_groups.items():
        if node_group.taints is not None:
            continue
        defaults = (
            DEFAULT_GENERAL_NODE_GROUP_TAINTS
            if node_group_name == "general"
            else DEFAULT_NODE_GROUP_TAINTS
        )
        # Hand out a copy so node groups never share (and cannot mutate) the
        # module-level default list.
        node_group.taints = list(defaults)
    return node_groups


class GCPNodeGroupInputVars(schema.Base):
    """Per-node-group input variables for the GCP infrastructure stage."""

    name: str
    instance_type: str
    min_size: int
    max_size: int
    node_taints: List[dict]
    labels: Dict[str, str]
    preemptible: bool
    guest_accelerators: List["GCPGuestAccelerator"]

    @field_validator("node_taints", mode="before")
    def convert_taints(cls, value: Optional[List[schema.Taint]]):
        """Convert ``schema.Taint`` objects into ``key``/``value``/``effect`` dicts.

        The effect is rewritten from the ``schema.TaintEffectEnum`` member to
        its upper-snake-case spelling (e.g. ``NO_SCHEDULE``). ``None`` is
        treated as "no taints" and yields an empty list.
        """
        if value is None:
            return []

        # Built once per call rather than once per taint.
        effect_names = {
            schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
            schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
            schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
        }
        return [
            dict(
                key=taint.key,
                value=taint.value,
                effect=effect_names[taint.effect],
            )
            for taint in value
        ]


class GCPPrivateClusterConfig(schema.Base):
enable_private_nodes: bool
Expand Down Expand Up @@ -89,6 +149,11 @@ class AzureNodeGroupInputVars(schema.Base):
instance: str
min_nodes: int
max_nodes: int
node_taints: list[str]

@field_validator("node_taints", mode="before")
def convert_taints(cls, value: Optional[List[schema.Taint]]):
    """Render taints as ``"key=value:Effect"`` strings.

    ``None`` is treated as "no taints" and yields an empty list instead of
    raising ``TypeError``.
    """
    if value is None:
        return []
    return [f"{taint.key}={taint.value}:{taint.effect.value}" for taint in value]


class AzureInputVars(schema.Base):
Expand Down Expand Up @@ -131,6 +196,22 @@ class AWSNodeGroupInputVars(schema.Base):
permissions_boundary: Optional[str] = None
ami_type: Optional[AWSAmiTypes] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None
node_taints: list[dict]

@field_validator("node_taints", mode="before")
Copy link
Member Author

@Adam-D-Lewis Adam-D-Lewis Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is repeated (see line 233 in this file) for GCP and AWS NodeGroupInputVars classes, but that's b/c the format expected by GCP and AWS terraform modules for taints happens to be the same. I think the required formats for the different modules could evolve separately and so I chose to duplicate the code in this case.

def convert_taints(cls, value: Optional[List[schema.Taint]]):
    """Convert ``schema.Taint`` objects into ``key``/``value``/``effect`` dicts.

    The effect is rewritten from the ``schema.TaintEffectEnum`` member to its
    upper-snake-case spelling (e.g. ``NO_SCHEDULE``). ``None`` is treated as
    "no taints" and yields an empty list instead of raising ``TypeError``.
    """
    if value is None:
        return []

    # Built once per call rather than once per taint.
    effect_names = {
        schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
        schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
        schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
    }
    return [
        dict(
            key=taint.key,
            value=taint.value,
            effect=effect_names[taint.effect],
        )
        for taint in value
    ]


def construct_aws_ami_type(
Expand All @@ -157,6 +238,21 @@ def construct_aws_ami_type(

return "AL2_x86_64"

@field_validator("node_taints", mode="before")
def convert_taints(cls, value: Optional[List[schema.Taint]]):
return [
dict(
key=taint.key,
value=taint.value,
effect={
schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
}[taint.effect],
)
for taint in value
]


class AWSInputVars(schema.Base):
name: str
Expand Down Expand Up @@ -270,19 +366,28 @@ class GCPGuestAccelerator(schema.Base):
count: Annotated[int, Field(ge=1)] = 1


class GCPNodeGroup(schema.Base):
instance: str
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
class GCPNodeGroup(NodeGroup):
preemptible: bool = False
labels: Dict[str, str] = {}
guest_accelerators: List[GCPGuestAccelerator] = []


DEFAULT_GCP_NODE_GROUPS = {
"general": GCPNodeGroup(instance="e2-standard-8", min_nodes=1, max_nodes=1),
"user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"general": GCPNodeGroup(
instance="e2-standard-8",
min_nodes=1,
max_nodes=1,
),
"user": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
),
"worker": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
),
}


Expand All @@ -295,7 +400,9 @@ class GoogleCloudPlatformProvider(schema.Base):
kubernetes_version: str
availability_zones: Optional[List[str]] = []
release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL
node_groups: Dict[str, GCPNodeGroup] = DEFAULT_GCP_NODE_GROUPS
node_groups: Annotated[
Dict[str, GCPNodeGroup], AfterValidator(set_missing_taints_to_default_taints)
] = Field(DEFAULT_GCP_NODE_GROUPS, validate_default=True)
tags: Optional[List[str]] = []
networking_mode: str = "ROUTE"
network: str = "default"
Expand Down Expand Up @@ -345,16 +452,26 @@ def _check_input(cls, data: Any) -> Any:
return data


class AzureNodeGroup(schema.Base):
instance: str
min_nodes: int
max_nodes: int
class AzureNodeGroup(NodeGroup):
pass


DEFAULT_AZURE_NODE_GROUPS = {
"general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1),
"user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
"worker": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
"general": AzureNodeGroup(
instance="Standard_D8_v3",
min_nodes=1,
max_nodes=1,
),
"user": AzureNodeGroup(
instance="Standard_D4_v3",
min_nodes=0,
max_nodes=5,
),
"worker": AzureNodeGroup(
instance="Standard_D4_v3",
min_nodes=0,
max_nodes=5,
),
}


Expand All @@ -363,7 +480,9 @@ class AzureProvider(schema.Base):
kubernetes_version: Optional[str] = None
storage_account_postfix: str
resource_group_name: Optional[str] = None
node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS
node_groups: Annotated[
Dict[str, AzureNodeGroup], AfterValidator(set_missing_taints_to_default_taints)
] = Field(DEFAULT_AZURE_NODE_GROUPS, validate_default=True)
storage_account_postfix: str
vnet_subnet_id: Optional[str] = None
private_cluster_enabled: bool = False
Expand Down Expand Up @@ -416,10 +535,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]:
return value if value is None else azure_cloud.validate_tags(value)


class AWSNodeGroup(schema.Base):
instance: str
min_nodes: int = 0
max_nodes: int
class AWSNodeGroup(NodeGroup):
gpu: bool = False
single_subnet: bool = False
permissions_boundary: Optional[str] = None
Expand All @@ -436,12 +552,22 @@ def check_launch_template(cls, values):


DEFAULT_AWS_NODE_GROUPS = {
"general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1),
"general": AWSNodeGroup(
instance="m5.2xlarge",
min_nodes=1,
max_nodes=1,
),
"user": AWSNodeGroup(
instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False
instance="m5.xlarge",
min_nodes=0,
max_nodes=5,
single_subnet=False,
),
"worker": AWSNodeGroup(
instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False
instance="m5.xlarge",
min_nodes=0,
max_nodes=5,
single_subnet=False,
),
}

Expand All @@ -450,7 +576,9 @@ class AmazonWebServicesProvider(schema.Base):
region: str
kubernetes_version: str
availability_zones: Optional[List[str]]
node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS
node_groups: Annotated[
Dict[str, AWSNodeGroup], AfterValidator(set_missing_taints_to_default_taints)
] = Field(DEFAULT_AWS_NODE_GROUPS, validate_default=True)
eks_endpoint_access: Optional[
Literal["private", "public", "public_and_private"]
] = "public"
Expand Down Expand Up @@ -576,16 +704,8 @@ class ExistingProvider(schema.Base):
schema.ProviderEnum.azure: AzureProvider,
}

provider_enum_name_map: Dict[schema.ProviderEnum, str] = {
schema.ProviderEnum.local: "local",
schema.ProviderEnum.existing: "existing",
schema.ProviderEnum.gcp: "google_cloud_platform",
schema.ProviderEnum.aws: "amazon_web_services",
schema.ProviderEnum.azure: "azure",
}

provider_name_abbreviation_map: Dict[str, str] = {
value: key.value for key, value in provider_enum_name_map.items()
value: key.value for key, value in schema.provider_enum_name_map.items()
}

provider_enum_default_node_groups_map: Dict[schema.ProviderEnum, Any] = {
Expand Down Expand Up @@ -625,7 +745,7 @@ def check_provider(cls, data: Any) -> Any:
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
expected_provider_config = schema.provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
Expand Down Expand Up @@ -773,6 +893,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
instance_type=node_group.instance,
min_size=node_group.min_nodes,
max_size=node_group.max_nodes,
node_taints=node_group.taints,
preemptible=node_group.preemptible,
guest_accelerators=node_group.guest_accelerators,
)
Expand Down Expand Up @@ -804,6 +925,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
instance=node_group.instance,
min_nodes=node_group.min_nodes,
max_nodes=node_group.max_nodes,
node_taints=node_group.taints,
)
for name, node_group in self.config.azure.node_groups.items()
},
Expand Down Expand Up @@ -847,6 +969,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
single_subnet=node_group.single_subnet,
permissions_boundary=node_group.permissions_boundary,
launch_template=None,
node_taints=node_group.taints,
ami_type=construct_aws_ami_type(
gpu_enabled=node_group.gpu,
launch_template=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ resource "aws_eks_node_group" "main" {
max_size = var.node_groups[count.index].max_size
}

dynamic "taint" {
for_each = var.node_groups[count.index].node_taints
content {
key = taint.value.key
value = taint.value.value
effect = taint.value.effect
}
}

# Only set launch_template if its node_group counterpart parameter is not null
dynamic "launch_template" {
for_each = var.node_groups[count.index].launch_template != null ? [0] : []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ variable "node_groups" {
single_subnet = bool
launch_template = map(any)
ami_type = string
node_taints = list(object({
key = string
value = string
effect = string
}))
}))
}

Expand Down
5 changes: 5 additions & 0 deletions src/_nebari/stages/infrastructure/template/aws/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ variable "node_groups" {
single_subnet = bool
launch_template = map(any)
ami_type = string
node_taints = list(object({
key = string
value = string
effect = string
}))
}))
}

Expand Down
1 change: 1 addition & 0 deletions src/_nebari/stages/infrastructure/template/azure/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ module "kubernetes" {
instance_type = config.instance
min_size = config.min_nodes
max_size = config.max_nodes
node_taints = config.node_taints
}
]
vnet_subnet_id = var.vnet_subnet_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ resource "azurerm_kubernetes_cluster" "main" {
min_count = var.node_groups[0].min_size
max_count = var.node_groups[0].max_size
max_pods = var.max_pods
# It's not possible to add node_taints to the default node pool. See https://github.com/hashicorp/terraform-provider-azurerm/issues/9183 for more info

orchestrator_version = var.kubernetes_version
node_labels = {
Expand Down Expand Up @@ -84,4 +85,5 @@ resource "azurerm_kubernetes_cluster_node_pool" "node_group" {
orchestrator_version = var.kubernetes_version
tags = var.tags
vnet_subnet_id = var.vnet_subnet_id
node_taints = each.value.node_taints
}
Loading
Loading