Skip to content

Commit

Permalink
feat(app): use tolerations and affinities from CRC (#1626)
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski authored Sep 28, 2023
1 parent f0c9f20 commit fc52c56
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 55 deletions.
110 changes: 60 additions & 50 deletions renku_notebooks/api/amalthea_patches/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,60 +7,70 @@
from renku_notebooks.api.classes.server import UserServer


def session_tolerations():
patches = []
tolerations = [
{
"key": f"{config.session_get_endpoint_annotations.renku_annotation_prefix}dedicated",
"operator": "Equal",
"value": "user",
"effect": "NoSchedule",
},
*config.sessions.tolerations,
]
patches.append(
{
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/tolerations",
"value": tolerations,
}
],
}
)
return patches
def session_tolerations(server: "UserServer"):
"""Patch for node taint tolerations, the static tolerations from the configuration are ignored
if the tolerations are set in the server options (coming from CRC)."""
if not server.server_options.tolerations:
key = f"{config.session_get_endpoint_annotations.renku_annotation_prefix}dedicated"
tolerations = [
{
"key": key,
"operator": "Equal",
"value": "user",
"effect": "NoSchedule",
},
] + config.sessions.tolerations
return [
{
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/tolerations",
"value": tolerations,
}
],
}
]
return [i.json_patch() for i in server.server_options.tolerations]


def session_affinity():
return [
{
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/affinity",
"value": config.sessions.affinity,
}
],
}
]
def session_affinity(server: "UserServer"):
"""Patch for session affinities, the static affinities from the configuration are ignored
if the affinities are set in the server options (coming from CRC)."""
if not server.server_options.node_affinities:
return [
{
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/affinity",
"value": config.sessions.affinity,
}
],
}
]
return [i.json_patch() for i in server.server_options.node_affinities]


def session_node_selector():
return [
{
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/nodeSelector",
"value": config.sessions.node_selector,
}
],
}
]
def session_node_selector(server: "UserServer"):
"""Patch for a node selector, if node affinities are specified in the server options
(coming from CRC) node selectors in the static configuration are ignored."""
if not server.server_options.node_affinities:
return [
{
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/nodeSelector",
"value": config.sessions.node_selector,
}
],
}
]
return []


def priority_class(server: "UserServer"):
Expand Down
6 changes: 3 additions & 3 deletions renku_notebooks/api/classes/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def _get_session_manifest(self):
patches = list(
chain(
general_patches.test(self),
general_patches.session_tolerations(),
general_patches.session_affinity(),
general_patches.session_node_selector(),
general_patches.session_tolerations(self),
general_patches.session_affinity(self),
general_patches.session_node_selector(self),
general_patches.priority_class(self),
jupyter_server_patches.args(),
jupyter_server_patches.env(self),
Expand Down
96 changes: 94 additions & 2 deletions renku_notebooks/api/schemas/server_options.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,75 @@
from dataclasses import dataclass
from typing import Optional, Callable, Dict, Any
from dataclasses import dataclass, field
from typing import Optional, Callable, Dict, Any, List

from marshmallow import Schema, fields, post_load

from ...config import config
from .custom_fields import ByteSizeField, CpuField, GpuField
from ...errors.programming import ProgrammingError


@dataclass
class NodeAffinity:
"""Node affinity used to schedule a session on specific nodes."""

key: str
required_during_scheduling: bool = False

def json_patch(self) -> Dict[str, Any]:
match_expressions = {
"matchExpressions": {
"key": self.key,
"operator": "Exists",
},
}
if self.required_during_scheduling:
return {
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/affinity/nodeAffinity"
"/requiredDuringSchedulingIgnoredDuringExecution/nodeSelectorTerms/-",
"value": match_expressions,
}
],
}
return {
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/affinity/nodeAffinity"
"/preferredDuringSchedulingIgnoredDuringExecution/-",
"value": {
"weight": 1,
"preference": match_expressions,
},
}
],
}


@dataclass
class Toleration:
"""Toleration used to schedule a session on tainted nodes."""

key: str

def json_patch(self) -> Dict[str, Any]:
return {
"type": "application/json-patch+json",
"patch": [
{
"op": "add",
"path": "/statefulset/spec/template/spec/tolerations/-",
"value": {
"key": self.key,
"operator": "Exists",
},
}
],
}


@dataclass
Expand All @@ -19,6 +84,8 @@ class ServerOptions:
lfs_auto_fetch: bool = False
gigabytes: bool = False
priority_class: Optional[str] = None
node_affinities: List[NodeAffinity] = field(default_factory=list)
tolerations: List[Toleration] = field(default_factory=list)

def __post_init__(self):
if self.default_url is None:
Expand All @@ -29,6 +96,27 @@ def __post_init__(self):
self.storage = 1
elif self.storage is None and not self.gigabytes:
self.storage = 1_000_000_000
if not all([isinstance(i, NodeAffinity) for i in self.node_affinities]):
raise ProgrammingError(
message="Cannot create a ServerOptions dataclass with node "
"affinities that are not of type NodeAffinity"
)
if not all([isinstance(i, Toleration) for i in self.tolerations]):
raise ProgrammingError(
message="Cannot create a ServerOptions dataclass with tolerations "
"that are not of type Toleration"
)
if self.node_affinities is None:
self.node_affinities = []
else:
self.node_affinities = sorted(
self.node_affinities,
key=lambda x: (x.key, x.required_during_scheduling),
)
if self.tolerations is None:
self.tolerations = []
else:
self.tolerations = sorted(self.tolerations, key=lambda x: x.key)

def __compare(
self,
Expand Down Expand Up @@ -107,6 +195,10 @@ def from_resource_class(cls, data: Dict[str, Any]) -> "ServerOptions":
memory=data["memory"] * 1000000000,
gpu=data["gpu"],
storage=data["default_storage"] * 1000000000,
node_affinities=[
NodeAffinity(**a) for a in data.get("node_affinities", [])
],
tolerations=[Toleration(t) for t in data.get("tolerations", [])],
)

@classmethod
Expand Down

0 comments on commit fc52c56

Please sign in to comment.