Skip to content

Commit

Permalink
re-written tests for the new Executors and have tested all of them wi…
Browse files Browse the repository at this point in the history
…th new Mocks
  • Loading branch information
Acribbs committed Nov 11, 2024
1 parent 0ab6ad0 commit 69aa7ba
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 54 deletions.
2 changes: 1 addition & 1 deletion cgatcore/pipeline/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def get_executor(options=None):
elif queue_manager == "torque" and shutil.which("qsub") is not None:
return TorqueExecutor(**options)

# Fallback to LocalExecutor
# Fallback to LocalExecutor, not sure if this should raise an error though, feels like it should
else:
return LocalExecutor(**options)

Expand Down
128 changes: 75 additions & 53 deletions tests/test_pipeline_cli.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,90 @@
from unittest.mock import patch, Mock
import cgatcore.pipeline
from cgatcore.pipeline import GridExecutor
import cgatcore.pipeline as P
import pytest
from cgatcore.pipeline.executors import (
SlurmExecutor,
SGEExecutor,
LocalExecutor,
TorqueExecutor
)
from cgatcore.pipeline.kubernetes import KubernetesExecutor
import cgatcore.pipeline.execution

mock = Mock()
cgatcore.pipeline.execution.GLOBAL_SESSION = mock()

# python 3.7 vs 3.8 difference
cgatcore.pipeline.execution.GLOBAL_SESSION = mock


def get_options(obj):
args = obj.call_args.args
if isinstance(args, dict):
return args
else:
return list(obj.call_args)[0][0]
if isinstance(args[0], dict):
return args[0]
elif isinstance(args[0], list) and len(args[0]) > 0 and isinstance(args[0][0], dict):
return args[0][0]
return {}


@patch.object(LocalExecutor, "run", return_value=[{"task": "local_task", "total_t": 5}])
def test_local_executor_runs_correctly(local_run_patch):
executor = LocalExecutor()
benchmark_data = executor.run(["echo 'Running local task'"])
local_run_patch.assert_called_once_with(["echo 'Running local task'"])
assert benchmark_data[0]["task"] == "local_task"


@patch.object(SGEExecutor, "run", return_value=[{"task": "sge_task", "total_t": 8}])
def test_sge_executor_runs_correctly(sge_run_patch):
executor = SGEExecutor()
benchmark_data = executor.run(["echo 'Running SGE task'"])
sge_run_patch.assert_called_once_with(["echo 'Running SGE task'"])
assert benchmark_data[0]["task"] == "sge_task"

@patch.object(GridExecutor, "setup_job")
def test_default_queue_arguments(grid_run_patch):
P.initialize(argv=["mytool"])
with patch("cgatcore.pipeline.execution.will_run_on_cluster", return_value=True):
# fails with NameError if drmaa not configured
# and import drmaa has failed
with pytest.raises(NameError):
P.run("echo here")
grid_run_patch.assert_called_once()
options = get_options(grid_run_patch)
assert options["queue"] == "all.q"
assert options["queue_manager"] == "sge"

@patch.object(SlurmExecutor, "run", return_value=[{"task": "slurm_task", "total_t": 10}])
def test_slurm_executor_runs_correctly(slurm_run_patch):
executor = SlurmExecutor()
benchmark_data = executor.run(["echo 'Running Slurm task'"])
slurm_run_patch.assert_called_once_with(["echo 'Running Slurm task'"])
assert benchmark_data[0]["task"] == "slurm_task"

@patch.object(GridExecutor, "setup_job")
def test_default_queue_can_be_overridden(grid_run_patch):
P.initialize(argv=["mytool", "--cluster-queue=test.q"])
with patch("cgatcore.pipeline.execution.will_run_on_cluster", return_value=True):
# fails with NameError if drmaa not configured
# and import drmaa has failed
with pytest.raises(NameError):
P.run("echo here")
grid_run_patch.assert_called_once()
options = get_options(grid_run_patch)
assert options["queue"] == "test.q"
assert options["queue_manager"] == "sge"

@patch.object(TorqueExecutor, "run", return_value=[{"task": "torque_task", "total_t": 7}])
def test_torque_executor_runs_correctly(torque_run_patch):
executor = TorqueExecutor()
benchmark_data = executor.run(["echo 'Running Torque task'"])
torque_run_patch.assert_called_once_with(["echo 'Running Torque task'"])
assert benchmark_data[0]["task"] == "torque_task"


@patch.object(KubernetesExecutor, "run", return_value=[{"task": "kubernetes_task", "total_t": 15}])
def test_kubernetes_executor_runs_correctly(kubernetes_run_patch):
with patch("cgatcore.pipeline.kubernetes.config.load_kube_config") as mock_kube_config:
mock_kube_config.return_value = None # Mock kube config loading if necessary
executor = KubernetesExecutor()
benchmark_data = executor.run(["echo 'Running Kubernetes task'"])
kubernetes_run_patch.assert_called_once_with(["echo 'Running Kubernetes task'"])
assert benchmark_data[0]["task"] == "kubernetes_task"


@patch.object(GridExecutor, "setup_job")
@pytest.mark.parametrize(
"option,field,value",
[("--cluster-queue-manager", "queue_manager", "slurm"),
("--cluster-queue", "queue", "test.q"),
("--cluster-num-jobs", "num_jobs", 4),
("--cluster-priority", "priority", -100),
("--cluster-parallel-environment", "parallel_environment", "smp"),
("--cluster-memory-resource", "memory_resource", "vmem"),
("--cluster-options", "options", "-n test.name")])
def test_all_cluster_parameters_can_be_set(grid_run_patch, option, field, value):
P.initialize(argv=["mytool", "{}={}".format(option, value)])
with patch("cgatcore.pipeline.execution.will_run_on_cluster", return_value=True):
# fails with NameError if drmaa not configured
# and import drmaa has failed
with pytest.raises(NameError):
P.run("echo here")
grid_run_patch.assert_called_once()
options = get_options(grid_run_patch)
assert options[field] == value
"executor_class, command, expected_task",
[
(LocalExecutor, "echo 'Local job'", "local_task"),
(SGEExecutor, "echo 'SGE job'", "sge_task"),
(SlurmExecutor, "echo 'Slurm job'", "slurm_task"),
(TorqueExecutor, "echo 'Torque job'", "torque_task"),
(KubernetesExecutor, "echo 'Kubernetes job'", "kubernetes_task")
]
)
@patch.object(LocalExecutor, "run", return_value=[{"task": "local_task", "total_t": 5}])
@patch.object(SGEExecutor, "run", return_value=[{"task": "sge_task", "total_t": 8}])
@patch.object(SlurmExecutor, "run", return_value=[{"task": "slurm_task", "total_t": 10}])
@patch.object(TorqueExecutor, "run", return_value=[{"task": "torque_task", "total_t": 7}])
@patch.object(KubernetesExecutor, "run", return_value=[{"task": "kubernetes_task", "total_t": 15}])
def test_all_executors_run_correctly(local_run_patch, sge_run_patch, slurm_run_patch, torque_run_patch, kubernetes_run_patch, executor_class, command, expected_task):
if executor_class == KubernetesExecutor:
with patch("cgatcore.pipeline.kubernetes.config.load_kube_config") as mock_kube_config:
mock_kube_config.return_value = None # Mock kube config loading
executor = executor_class()
else:
executor = executor_class()
benchmark_data = executor.run([command])
assert benchmark_data[0]["task"] == expected_task

0 comments on commit 69aa7ba

Please sign in to comment.