diff --git a/shub_workflow/graph/__init__.py b/shub_workflow/graph/__init__.py
index 69dd202..1db52b5 100644
--- a/shub_workflow/graph/__init__.py
+++ b/shub_workflow/graph/__init__.py
@@ -107,13 +107,16 @@ def _setup_starting_jobs(self) -> None:
             wait_for: List[TaskId] = self.get_jobdict(taskid).get("wait_for", [])
             self._add_pending_job(taskid, wait_for=tuple(wait_for))
 
-        for taskid in list(self.__pending_jobs.keys()):
-            if taskid in self.__completed_jobs:
-                jobid, outcome = self.__completed_jobs[taskid]
-                self.__pending_jobs.pop(taskid)
-                self._check_completed_job(taskid, jobid, outcome)
-            elif taskid in self.__running_jobs:
-                self.__pending_jobs.pop(taskid)
+        initial_pending_jobs: Set[TaskId] = set()
+        while initial_pending_jobs != set(self.__pending_jobs.keys()):
+            initial_pending_jobs = set(self.__pending_jobs.keys())
+            for taskid in list(self.__pending_jobs.keys()):
+                if taskid in self.__completed_jobs:
+                    jobid, outcome = self.__completed_jobs[taskid]
+                    self.__pending_jobs.pop(taskid)
+                    self._check_completed_job(taskid, jobid, outcome)
+                elif taskid in self.__running_jobs:
+                    self.__pending_jobs.pop(taskid)
 
     def _fill_available_resources(self):
         """
@@ -360,6 +363,7 @@ def handle_retry(self, job: TaskId, outcome: str) -> bool:
             retries = jobconf.get("retries", 0)
             if retries > 0:
                 self._add_pending_job(job, is_retry=True)
+                self.__completed_jobs.pop(job, None)
                 jobconf["retries"] -= 1
                 logger.warning(
                     "Will retry job %s (outcome: %s, number of retries left: %s)",
diff --git a/tests/test_graph_manager.py b/tests/test_graph_manager.py
index da0e09d..e0414e6 100644
--- a/tests/test_graph_manager.py
+++ b/tests/test_graph_manager.py
@@ -2419,3 +2419,63 @@ def test_resume_finished_with_retry(self, mocked_get_jobs):
             units=None,
             project_id=None,
         )
+
+    def test_resume_running_second_level(self, mocked_get_jobs):
+        """Test that second level running tasks are acquired and not rescheduled"""
+        mocked_get_jobs.side_effect = [
+            [{"tags": ["NAME=test", "FLOW_ID=34ab"]}],  # call to determine if there is resuming
+            [
+                {
+                    "tags": ["PARENT_NAME=test", "FLOW_ID=34ab", f"TASK_ID=jobB.{i}"],
+                    "key": f"999/2/{i+1}",
+                }
+                for i in (1, 3)
+            ]
+            + [
+                {"tags": ["PARENT_NAME=test", "FLOW_ID=34ab", "TASK_ID=jobC"], "key": "999/3/1"}
+            ],  # call to get running jobs
+            [
+                {
+                    "tags": ["PARENT_NAME=test", "FLOW_ID=34ab", f"TASK_ID=jobA.{i}"],
+                    "key": f"999/1/{i+1}",
+                    "close_reason": "finished",
+                }
+                for i in (0, 1, 2, 3)
+            ]
+            + [
+                {
+                    "tags": ["PARENT_NAME=test", "FLOW_ID=34ab", f"TASK_ID=jobB.{i}"],
+                    "key": f"999/2/{i+1}",
+                    "close_reason": "finished",
+                }
+                for i in (0, 2)
+            ],  # call to get finished jobs
+        ]
+        with script_args(["--flow-id=34ab", "--root-jobs"]):
+            manager = TestManager3()
+        manager._on_start()
+        self.assertTrue(manager.is_resumed)
+
+        manager.is_finished = lambda x: None
+
+        manager.schedule_script = Mock()
+        # first loop. jobB partially running, jobC running
+        result = next(manager._run_loops())
+        self.assertTrue(result)
+        self.assertEqual(manager.schedule_script.call_count, 0)
+
+        # second loop. Still no change
+        manager.is_finished = lambda x: None
+        result = next(manager._run_loops())
+        self.assertTrue(result)
+        self.assertEqual(manager.schedule_script.call_count, 0)
+
+        # Third loop. jobC completes
+        manager.is_finished = lambda x: Outcome("finished") if x == "999/3/1" else None
+        manager.schedule_script.side_effect = [
+            "999/4/1",
+        ]
+        result = next(manager._run_loops())
+        self.assertTrue(result)
+        manager.schedule_script.assert_called_with(["commandD"], tags=["TASK_ID=jobD"], units=None, project_id=None)
+        self.assertEqual(manager.schedule_script.call_count, 1)
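
Note on the `_setup_starting_jobs()` change (not part of the diff): on resume, `_check_completed_job()` for an already-finished task can add that task's downstream tasks to the pending set, and those may themselves already be finished or running, so a single pass over a snapshot of the pending keys misses them and they would be rescheduled. The new loop keeps sweeping until the pending set stops changing. The `handle_retry()` change removes the stale entry from `__completed_jobs`, presumably so a retried job is not immediately re-acquired as finished. The following is a minimal, self-contained sketch of the sweep-until-stable idea under toy state; `completed`, `running`, `next_tasks` and `acquired` are illustrative stand-ins, not real `GraphManager` attributes, and `next_tasks` only approximates what `_check_completed_job()` unlocks.

```python
from typing import Dict, List, Set

# Toy resume snapshot (hypothetical): jobA and jobB already finished in the
# previous run, jobC is still running, jobD has not started yet.
completed: Set[str] = {"jobA", "jobB"}
running: Set[str] = {"jobC"}
# Rough stand-in for what completing a task unlocks in the graph.
next_tasks: Dict[str, List[str]] = {"jobA": ["jobB"], "jobB": ["jobC"], "jobC": ["jobD"]}

pending: Set[str] = {"jobA"}  # only the root task is pending right after setup
acquired: Set[str] = set()    # tasks recognised as already done/running (must not be rescheduled)

previous: Set[str] = set()
while previous != set(pending):  # sweep until the pending set stabilises
    previous = set(pending)
    for taskid in list(pending):
        if taskid in completed:
            pending.discard(taskid)
            acquired.add(taskid)
            # completing a task unlocks the next level, which may also be done already
            pending.update(next_tasks.get(taskid, []))
        elif taskid in running:
            pending.discard(taskid)
            acquired.add(taskid)

print(sorted(acquired))  # ['jobA', 'jobB', 'jobC']
print(sorted(pending))   # [] -- nothing left to (re)schedule at startup
```

A single pass over the same toy state stops after acquiring jobA, leaving jobB pending even though it already finished, which is the kind of spurious rescheduling `test_resume_running_second_level` guards against.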