Skip to content

Commit

Permalink
adds tests and small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fabricebrito committed Dec 13, 2024
1 parent dc65c7f commit 540f28a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 25 deletions.
4 changes: 2 additions & 2 deletions calrissian/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ def __init__(self, kwargs=None):
self.pod_serviceaccount = None
self.tool_logs_basepath = None
self.max_gpus = None
self.no_network_access_pod_label = None
self.network_access_pod_label = None
self.no_network_access_pod_labels = None
self.network_access_pod_labels = None
return super(CalrissianRuntimeContext, self).__init__(kwargs)
35 changes: 17 additions & 18 deletions calrissian/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def add_emptydir_volume_binding(self, name, target):

class KubernetesPodBuilder(object):

def __init__(self, name, builder, container_image, environment, volume_mounts, volumes, command_line, stdout, stderr, stdin, labels, nodeselectors, security_context, serviceaccount, no_network_access_pod_label={}, network_access_pod_label={}):
def __init__(self, name, builder, container_image, environment, volume_mounts, volumes, command_line, stdout, stderr, stdin, labels, nodeselectors, security_context, serviceaccount, no_network_access_pod_labels=None, network_access_pod_labels=None):
self.name = name
self.builder = builder
self.cwl_version = self.builder.cwlVersion
Expand All @@ -215,8 +215,8 @@ def __init__(self, name, builder, container_image, environment, volume_mounts, v
self.nodeselectors = nodeselectors
self.security_context = security_context
self.serviceaccount = serviceaccount
self.no_network_access_pod_label = no_network_access_pod_label
self.network_access_pod_label = network_access_pod_label
self.no_network_access_pod_labels = no_network_access_pod_labels
self.network_access_pod_labels = network_access_pod_labels
self.requirements = {} if self.builder.requirements is None else self.builder.requirements
self.hints = [] if self.builder.hints is None else self.builder.hints

Expand Down Expand Up @@ -350,15 +350,14 @@ def pod_labels(self):
network_access = False

for requirement in self.requirements:
if requirement["class"] in ["NetworkAccess"]:
network_access = requirement.get("networkAccess")
if "class" in requirement.keys() and requirement["class"] in ["NetworkAccess"]:
network_access = True if requirement.get("networkAccess") == "true" else False
break

if not network_access and self.no_network_access_pod_label:
self.labels = {**self.labels, **self.no_network_access_pod_label}
if not network_access and self.no_network_access_pod_labels:
self.labels = {**self.labels, **self.no_network_access_pod_labels}

if network_access and self.network_access_pod_label:
self.labels = {**self.labels, **self.network_access_pod_label}
if network_access and self.network_access_pod_labels:
self.labels = {**self.labels, **self.network_access_pod_labels}

return {str(k): str(v) for k, v in self.labels.items()}

Expand Down Expand Up @@ -536,15 +535,15 @@ def get_pod_labels(self, runtimeContext):
else:
return {}

def get_network_access_pod_label(self, runtimeContext):
if runtimeContext.network_access_pod_label:
return read_yaml(runtimeContext.network_access_pod_label)
def get_network_access_pod_labels(self, runtimeContext):
if runtimeContext.network_access_pod_labels:
return read_yaml(runtimeContext.network_access_pod_labels)
else:
return {}

def get_no_network_access_pod_label(self, runtimeContext):
if runtimeContext.no_network_access_pod_label:
return read_yaml(runtimeContext.no_network_access_pod_label)
def get_no_network_access_pod_labels(self, runtimeContext):
if runtimeContext.no_network_access_pod_labels:
return read_yaml(runtimeContext.no_network_access_pod_labels)
else:
return {}

Expand Down Expand Up @@ -621,8 +620,8 @@ def create_kubernetes_runtime(self, runtimeContext):
self.get_pod_nodeselectors(runtimeContext),
self.get_security_context(runtimeContext),
self.get_pod_serviceaccount(runtimeContext),
self.get_no_network_access_pod_label(runtimeContext),
self.get_network_access_pod_label(runtimeContext),
self.get_no_network_access_pod_labels(runtimeContext),
self.get_network_access_pod_labels(runtimeContext),
)
built = k8s_builder.build()
log.debug('{}\n{}{}\n'.format('-' * 80, yaml.dump(built), '-' * 80))
Expand Down
34 changes: 29 additions & 5 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class KubernetesPodBuilderTestCase(TestCase):

def setUp(self):
builder = Mock()
builder.cwlVersion = "v1.0"
builder.cwlVersion = "v1.2"
builder.requirements = []
builder.resources = {'cores': 1, 'ram': 1024}
self.name = 'PodName'
Expand All @@ -291,12 +291,12 @@ def setUp(self):
self.nodeselectors = {'disktype': 'ssd', 'cachelevel': 2}
self.security_context = { 'runAsUser': os.getuid(),'runAsGroup': os.getgid() }
self.pod_serviceaccount = "podmanager"
self.no_network_access_pod_label = {}
self.network_access_pod_label = {}
self.no_network_access_pod_labels = {"calrissian-network": "disabled"}
self.network_access_pod_labels = {"calrissian-network": "enabled"}
self.pod_builder = KubernetesPodBuilder(self.name, self.builder, self.container_image, self.environment, self.volume_mounts,
self.volumes, self.command_line, self.stdout, self.stderr, self.stdin,
self.labels, self.nodeselectors, self.security_context, self.pod_serviceaccount,
self.no_network_access_pod_label, self.network_access_pod_label)
self.no_network_access_pod_labels, self.network_access_pod_labels)

@patch('calrissian.job.random_tag')
def test_safe_pod_name(self, mock_random_tag):
Expand Down Expand Up @@ -382,14 +382,37 @@ def test_gpu_hints(self):
}
self.assertEqual(expected, resources)

def test_network_access_1_2(self):
self.pod_builder.cwl_version = "v1.2"
self.pod_builder.requirements = [OrderedDict([("class", "NetworkAccess"), ("networkAccess", "true")])]
self.assertEqual(self.pod_builder.pod_labels(), {"calrissian-network": "enabled", 'key1':'val1', 'key2':'123'})

def test_no_network_access_1_2(self):
self.pod_builder.cwl_version = "v1.2"
self.pod_builder.requirements = [OrderedDict([("class", "NetworkAccess"), ("networkAccess", "false")])]
self.assertEqual(self.pod_builder.pod_labels(), {"calrissian-network": "disabled", 'key1':'val1', 'key2':'123'})

def test_network_access_1_0(self):
self.pod_builder.cwl_version = "v1.0"
self.pod_builder.requirements = [OrderedDict([])]
self.assertEqual(self.pod_builder.pod_labels(), {"calrissian-network": "enabled", 'key1':'val1', 'key2':'123'})

def test_string_labels(self):
self.pod_builder.labels = {'key1': 123}
self.assertEqual(self.pod_builder.pod_labels(), {'key1':'123'})
self.assertEqual(self.pod_builder.pod_labels(), {"calrissian-network": "disabled", 'key1':'123'})

def test_string_nodeselectors(self):
self.pod_builder.nodeselectors = {'cachelevel': 2}
self.assertEqual(self.pod_builder.pod_nodeselectors(), {'cachelevel':'2'})

def test_string_no_network_access_pod_label(self):
self.pod_builder.no_network_access_pod_labels = {"calrissian-network": "disabled"}
self.assertEqual(self.pod_builder.pod_labels(), {"calrissian-network": "disabled", 'key1': 'val1', 'key2': '123'})

def test_string_network_access_pod_label(self):
self.pod_builder.network_access_pod_labels = {"calrissian-network": "enabled"}
self.assertEqual(self.pod_builder.pod_labels(), {"calrissian-network": "disabled", 'key1': 'val1', 'key2': '123'})

def test_init_containers_empty_when_no_stdout_or_stderr(self):
self.pod_builder.stdout = None
self.pod_builder.stderr = None
Expand Down Expand Up @@ -435,6 +458,7 @@ def test_build(self, mock_random_tag):
'labels': {
'key1': 'val1',
'key2': '123',
'calrissian-network': 'disabled',
}
},
'apiVersion': 'v1',
Expand Down

0 comments on commit 540f28a

Please sign in to comment.