-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
re-written tests for the new Executors and have tested all of them wi…
…th new Mocks
- Loading branch information
Showing
2 changed files
with
76 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,90 @@ | ||
from unittest.mock import patch, Mock | ||
import cgatcore.pipeline | ||
from cgatcore.pipeline import GridExecutor | ||
import cgatcore.pipeline as P | ||
import pytest | ||
from cgatcore.pipeline.executors import ( | ||
SlurmExecutor, | ||
SGEExecutor, | ||
LocalExecutor, | ||
TorqueExecutor | ||
) | ||
from cgatcore.pipeline.kubernetes import KubernetesExecutor | ||
import cgatcore.pipeline.execution | ||
|
||
mock = Mock() | ||
cgatcore.pipeline.execution.GLOBAL_SESSION = mock() | ||
|
||
# python 3.7 vs 3.8 difference | ||
cgatcore.pipeline.execution.GLOBAL_SESSION = mock | ||
|
||
|
||
def get_options(obj): | ||
args = obj.call_args.args | ||
if isinstance(args, dict): | ||
return args | ||
else: | ||
return list(obj.call_args)[0][0] | ||
if isinstance(args[0], dict): | ||
return args[0] | ||
elif isinstance(args[0], list) and len(args[0]) > 0 and isinstance(args[0][0], dict): | ||
return args[0][0] | ||
return {} | ||
|
||
|
||
@patch.object(LocalExecutor, "run", return_value=[{"task": "local_task", "total_t": 5}]) | ||
def test_local_executor_runs_correctly(local_run_patch): | ||
executor = LocalExecutor() | ||
benchmark_data = executor.run(["echo 'Running local task'"]) | ||
local_run_patch.assert_called_once_with(["echo 'Running local task'"]) | ||
assert benchmark_data[0]["task"] == "local_task" | ||
|
||
|
||
@patch.object(SGEExecutor, "run", return_value=[{"task": "sge_task", "total_t": 8}]) | ||
def test_sge_executor_runs_correctly(sge_run_patch): | ||
executor = SGEExecutor() | ||
benchmark_data = executor.run(["echo 'Running SGE task'"]) | ||
sge_run_patch.assert_called_once_with(["echo 'Running SGE task'"]) | ||
assert benchmark_data[0]["task"] == "sge_task" | ||
|
||
@patch.object(GridExecutor, "setup_job") | ||
def test_default_queue_arguments(grid_run_patch): | ||
P.initialize(argv=["mytool"]) | ||
with patch("cgatcore.pipeline.execution.will_run_on_cluster", return_value=True): | ||
# fails with NameError if drmaa not configured | ||
# and import drmaa has failed | ||
with pytest.raises(NameError): | ||
P.run("echo here") | ||
grid_run_patch.assert_called_once() | ||
options = get_options(grid_run_patch) | ||
assert options["queue"] == "all.q" | ||
assert options["queue_manager"] == "sge" | ||
|
||
@patch.object(SlurmExecutor, "run", return_value=[{"task": "slurm_task", "total_t": 10}]) | ||
def test_slurm_executor_runs_correctly(slurm_run_patch): | ||
executor = SlurmExecutor() | ||
benchmark_data = executor.run(["echo 'Running Slurm task'"]) | ||
slurm_run_patch.assert_called_once_with(["echo 'Running Slurm task'"]) | ||
assert benchmark_data[0]["task"] == "slurm_task" | ||
|
||
@patch.object(GridExecutor, "setup_job") | ||
def test_default_queue_can_be_overridden(grid_run_patch): | ||
P.initialize(argv=["mytool", "--cluster-queue=test.q"]) | ||
with patch("cgatcore.pipeline.execution.will_run_on_cluster", return_value=True): | ||
# fails with NameError if drmaa not configured | ||
# and import drmaa has failed | ||
with pytest.raises(NameError): | ||
P.run("echo here") | ||
grid_run_patch.assert_called_once() | ||
options = get_options(grid_run_patch) | ||
assert options["queue"] == "test.q" | ||
assert options["queue_manager"] == "sge" | ||
|
||
@patch.object(TorqueExecutor, "run", return_value=[{"task": "torque_task", "total_t": 7}]) | ||
def test_torque_executor_runs_correctly(torque_run_patch): | ||
executor = TorqueExecutor() | ||
benchmark_data = executor.run(["echo 'Running Torque task'"]) | ||
torque_run_patch.assert_called_once_with(["echo 'Running Torque task'"]) | ||
assert benchmark_data[0]["task"] == "torque_task" | ||
|
||
|
||
@patch.object(KubernetesExecutor, "run", return_value=[{"task": "kubernetes_task", "total_t": 15}]) | ||
def test_kubernetes_executor_runs_correctly(kubernetes_run_patch): | ||
with patch("cgatcore.pipeline.kubernetes.config.load_kube_config") as mock_kube_config: | ||
mock_kube_config.return_value = None # Mock kube config loading if necessary | ||
executor = KubernetesExecutor() | ||
benchmark_data = executor.run(["echo 'Running Kubernetes task'"]) | ||
kubernetes_run_patch.assert_called_once_with(["echo 'Running Kubernetes task'"]) | ||
assert benchmark_data[0]["task"] == "kubernetes_task" | ||
|
||
|
||
@patch.object(GridExecutor, "setup_job") | ||
@pytest.mark.parametrize( | ||
"option,field,value", | ||
[("--cluster-queue-manager", "queue_manager", "slurm"), | ||
("--cluster-queue", "queue", "test.q"), | ||
("--cluster-num-jobs", "num_jobs", 4), | ||
("--cluster-priority", "priority", -100), | ||
("--cluster-parallel-environment", "parallel_environment", "smp"), | ||
("--cluster-memory-resource", "memory_resource", "vmem"), | ||
("--cluster-options", "options", "-n test.name")]) | ||
def test_all_cluster_parameters_can_be_set(grid_run_patch, option, field, value): | ||
P.initialize(argv=["mytool", "{}={}".format(option, value)]) | ||
with patch("cgatcore.pipeline.execution.will_run_on_cluster", return_value=True): | ||
# fails with NameError if drmaa not configured | ||
# and import drmaa has failed | ||
with pytest.raises(NameError): | ||
P.run("echo here") | ||
grid_run_patch.assert_called_once() | ||
options = get_options(grid_run_patch) | ||
assert options[field] == value | ||
"executor_class, command, expected_task", | ||
[ | ||
(LocalExecutor, "echo 'Local job'", "local_task"), | ||
(SGEExecutor, "echo 'SGE job'", "sge_task"), | ||
(SlurmExecutor, "echo 'Slurm job'", "slurm_task"), | ||
(TorqueExecutor, "echo 'Torque job'", "torque_task"), | ||
(KubernetesExecutor, "echo 'Kubernetes job'", "kubernetes_task") | ||
] | ||
) | ||
@patch.object(LocalExecutor, "run", return_value=[{"task": "local_task", "total_t": 5}]) | ||
@patch.object(SGEExecutor, "run", return_value=[{"task": "sge_task", "total_t": 8}]) | ||
@patch.object(SlurmExecutor, "run", return_value=[{"task": "slurm_task", "total_t": 10}]) | ||
@patch.object(TorqueExecutor, "run", return_value=[{"task": "torque_task", "total_t": 7}]) | ||
@patch.object(KubernetesExecutor, "run", return_value=[{"task": "kubernetes_task", "total_t": 15}]) | ||
def test_all_executors_run_correctly(local_run_patch, sge_run_patch, slurm_run_patch, torque_run_patch, kubernetes_run_patch, executor_class, command, expected_task): | ||
if executor_class == KubernetesExecutor: | ||
with patch("cgatcore.pipeline.kubernetes.config.load_kube_config") as mock_kube_config: | ||
mock_kube_config.return_value = None # Mock kube config loading | ||
executor = executor_class() | ||
else: | ||
executor = executor_class() | ||
benchmark_data = executor.run([command]) | ||
assert benchmark_data[0]["task"] == expected_task |