Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KeyError when using cylc remove after a reload changed the graph #6516

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cylc/flow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,7 @@ def configure_workflow_state_polling_tasks(self):
"script cannot be defined for automatic" +
" workflow polling task '%s':\n%s" % (l_task, cs))
# Generate the automatic scripting.
for name, tdef in list(self.taskdefs.items()):
for name, tdef in self.taskdefs.items():
if name not in self.workflow_polling_tasks:
continue
rtc = tdef.rtconfig
Expand Down
30 changes: 12 additions & 18 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
from cylc.flow.broadcast_mgr import BroadcastMgr
from cylc.flow.cfgspec.glbl_cfg import glbl_cfg
from cylc.flow.config import WorkflowConfig
from cylc.flow.cycling.loader import get_point
from cylc.flow.data_store_mgr import DataStoreMgr
from cylc.flow.exceptions import (
CommandFailedError,
Expand All @@ -93,9 +92,7 @@
get_user,
is_remote_platform,
)
from cylc.flow.id import (
Tokens,
)
from cylc.flow.id import Tokens
from cylc.flow.log_level import (
verbosity_to_env,
verbosity_to_opts,
Expand Down Expand Up @@ -1105,6 +1102,8 @@ def remove_tasks(
# Mapping of *relative* task IDs to removed flow numbers:
removed: Dict[Tokens, FlowNums] = {}
not_removed: Set[Tokens] = set()
# All the matched tasks (will add applicable active tasks below):
matched_tasks = inactive.copy()
to_kill: List[TaskProxy] = []

for itask in active:
Expand All @@ -1113,6 +1112,7 @@ def remove_tasks(
not_removed.add(itask.tokens.task)
continue
removed[itask.tokens.task] = fnums_to_remove
matched_tasks.add((itask.tdef, itask.point))
if fnums_to_remove == itask.flow_nums:
# Need to remove the task from the pool.
# Spawn next occurrence of xtrigger sequential task (otherwise
Expand All @@ -1123,21 +1123,13 @@ def remove_tasks(
itask.removed = True
itask.flow_nums.difference_update(fnums_to_remove)

# All the matched tasks (including inactive & applicable active tasks):
matched_tasks = {
*removed.keys(),
*(Tokens(cycle=str(cycle), task=task) for task, cycle in inactive),
}

for tokens in matched_tasks:
tdef = self.config.taskdefs[tokens['task']]
for tdef, point in matched_tasks:
tokens = Tokens(cycle=str(point), task=tdef.name)

# Go through any tasks downstream of this matched task to see if
# any need to stand down as a result of this task being removed:
for child in set(itertools.chain.from_iterable(
generate_graph_children(
tdef, get_point(tokens['cycle'])
).values()
generate_graph_children(tdef, point).values()
)):
child_itask = self.pool.get_task(child.point, child.name)
if not child_itask:
Expand Down Expand Up @@ -1173,7 +1165,7 @@ def remove_tasks(
# Check if downstream task should remain spawned:
if (
# Ignoring tasks we are already dealing with:
child_itask.tokens.task in matched_tasks
(child_itask.tdef, child_itask.point) in matched_tasks
or child_itask.state.any_satisfied_prerequisite_outputs()
):
continue
Expand All @@ -1187,11 +1179,14 @@ def remove_tasks(

# Remove the matched tasks from the flows in the DB tables:
db_removed_fnums = self.workflow_db_mgr.remove_task_from_flows(
tokens['cycle'], tokens['task'], flow_nums,
str(point), tdef.name, flow_nums,
)
if db_removed_fnums:
removed.setdefault(tokens, set()).update(db_removed_fnums)

if tokens not in removed:
not_removed.add(tokens)

if to_kill:
self.kill_tasks(to_kill, warn=False)

Expand All @@ -1206,7 +1201,6 @@ def remove_tasks(
)
LOG.info(f"Removed task(s): {', '.join(sorted(tasks_str_list))}")

not_removed.update(matched_tasks.difference(removed))
if not_removed:
fnums_str = (
repr_flow_nums(flow_nums, full=True) if flow_nums else ''
Expand Down
40 changes: 20 additions & 20 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Tuple,
Type,
Union,
cast,
)

from cylc.flow import LOG
Expand Down Expand Up @@ -1335,9 +1336,9 @@ def hold_tasks(self, items: Iterable[str]) -> int:
for itask in itasks:
self.hold_active_task(itask)
# Set inactive tasks to be held:
for name, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(name, cycle, True)
self.tasks_to_hold.update(inactive_tasks)
for tdef, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(tdef.name, cycle, True)
self.tasks_to_hold.add((tdef.name, cycle))
self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold)
LOG.debug(f"Tasks to hold: {self.tasks_to_hold}")
return len(unmatched)
Expand All @@ -1353,9 +1354,9 @@ def release_held_tasks(self, items: Iterable[str]) -> int:
for itask in itasks:
self.release_held_active_task(itask)
# Unhold inactive tasks:
for name, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(name, cycle, False)
self.tasks_to_hold.difference_update(inactive_tasks)
for tdef, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(tdef.name, cycle, False)
self.tasks_to_hold.discard((tdef.name, cycle))
self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold)
LOG.debug(f"Tasks to hold: {self.tasks_to_hold}")
return len(unmatched)
Expand Down Expand Up @@ -1979,8 +1980,7 @@ def set_prereqs_and_outputs(
if not flow:
# default: assign to all active flows
flow_nums = self._get_active_flow_nums()
for name, point in inactive_tasks:
tdef = self.config.get_taskdef(name)
for tdef, point in inactive_tasks:
if prereqs:
self._set_prereqs_tdef(
point, tdef, prereqs, flow_nums, flow_wait)
Expand Down Expand Up @@ -2175,7 +2175,7 @@ def force_trigger_tasks(

"""
# Get matching tasks proxies, and matching inactive task IDs.
existing_tasks, inactive_ids, unmatched = self.filter_task_proxies(
existing_tasks, inactive, unmatched = self.filter_task_proxies(
items, inactive=True, warn_no_active=False,
)

Expand All @@ -2199,15 +2199,15 @@ def force_trigger_tasks(
if not flow:
# default: assign to all active flows
flow_nums = self._get_active_flow_nums()
for name, point in inactive_ids:
if not self.can_be_spawned(name, point):
for tdef, point in inactive:
if not self.can_be_spawned(tdef.name, point):
continue
submit_num, _, prev_fwait = (
self._get_task_history(name, point, flow_nums)
self._get_task_history(tdef.name, point, flow_nums)
)
itask = TaskProxy(
self.tokens,
self.config.get_taskdef(name),
tdef,
point,
flow_nums,
flow_wait=flow_wait,
Expand Down Expand Up @@ -2327,7 +2327,7 @@ def filter_task_proxies(
ids: Iterable[str],
warn_no_active: bool = True,
inactive: bool = False,
) -> 'Tuple[List[TaskProxy], Set[Tuple[str, PointBase]], List[str]]':
) -> 'Tuple[List[TaskProxy], Set[Tuple[TaskDef, PointBase]], List[str]]':
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to note with this change is that the hashability/equality of TaskDefs is purely reference based (i.e. __hash__() and __eq__() fall back to object default). This should be fine as the TaskDefs are always coming from the scheduler.config.taskdefs dictionary, I believe

"""Return task proxies that match names, points, states in items.

Args:
Expand All @@ -2353,7 +2353,7 @@ def filter_task_proxies(
ids,
warn=warn_no_active,
)
inactive_matched: 'Set[Tuple[str, PointBase]]' = set()
inactive_matched: 'Set[Tuple[TaskDef, PointBase]]' = set()
if inactive and unmatched:
inactive_matched, unmatched = self.match_inactive_tasks(
unmatched
Expand All @@ -2364,7 +2364,7 @@ def filter_task_proxies(
def match_inactive_tasks(
self,
ids: Iterable[str],
) -> Tuple[Set[Tuple[str, 'PointBase']], List[str]]:
) -> 'Tuple[Set[Tuple[TaskDef, PointBase]], List[str]]':
"""Match task IDs against task definitions (rather than the task pool).

IDs will be matched providing the ID:
Expand All @@ -2377,7 +2377,7 @@ def match_inactive_tasks(
(matched_tasks, unmatched_tasks)

"""
matched_tasks: 'Set[Tuple[str, PointBase]]' = set()
matched_tasks: 'Set[Tuple[TaskDef, PointBase]]' = set()
unmatched_tasks: 'List[str]' = []
for id_ in ids:
try:
Expand All @@ -2404,8 +2404,8 @@ def match_inactive_tasks(
unmatched_tasks.append(id_)
continue

point_str = tokens['cycle']
name_str = tokens['task']
point_str = cast('str', tokens['cycle'])
name_str = cast('str', tokens['task'])
if name_str not in self.config.taskdefs:
if self.config.find_taskdefs(name_str):
# It's a family name; was not matched by active tasks
Expand All @@ -2427,7 +2427,7 @@ def match_inactive_tasks(
point = get_point(point_str)
taskdef = self.config.taskdefs[name_str]
if taskdef.is_valid_point(point):
matched_tasks.add((taskdef.name, point))
matched_tasks.add((taskdef, point))
else:
LOG.warning(
self.ERR_PREFIX_TASK_NOT_ON_SEQUENCE.format(
Expand Down
3 changes: 1 addition & 2 deletions cylc/flow/taskdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class TaskDef:
def __init__(self, name, rtcfg, start_point, initial_point):
if not TaskID.is_valid_name(name):
raise TaskDefError("Illegal task name: %s" % name)

self.name: str = name
self.rtconfig = rtcfg
self.start_point = start_point
self.initial_point = initial_point
Expand All @@ -192,7 +192,6 @@ def __init__(self, name, rtcfg, start_point, initial_point):
self.external_triggers = []
self.xtrig_labels = {} # {sequence: [labels]}

self.name = name
self.elapsed_times = deque(maxlen=self.MAX_LEN_ELAPSED_TIMES)
self._add_std_outputs()
self.has_abs_triggers = False
Expand Down
2 changes: 1 addition & 1 deletion cylc/flow/workflow_db_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def _put_update_task_x(
self, table_name: str, itask: 'TaskProxy', set_args: 'DbArgDict'
) -> None:
"""Put UPDATE statement for a task_* table."""
where_args = {
where_args: Dict[str, Any] = {
"cycle": str(itask.point),
"name": itask.tdef.name,
}
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_optional_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
get_completion_expression,
)
from cylc.flow.task_state import (
TASK_STATUSES_ACTIVE,
TASK_STATUS_EXPIRED,
TASK_STATUS_PREPARING,
TASK_STATUS_RUNNING,
Expand Down Expand Up @@ -484,7 +483,7 @@ async def test_removed_taskdef(
'R1': 'a'
}
}
}, id_=id_)
}, workflow_id=id_)

# restart the workflow
schd: 'Scheduler' = scheduler(id_)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def test_reload_failure(
async with start(schd):
# corrupt the config by removing the scheduling section
two_conf = {**one_conf, 'scheduling': {}}
flow(two_conf, id_=id_)
flow(two_conf, workflow_id=id_)

# reload the workflow
await commands.run_cmd(commands.reload_workflow(schd))
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cylc.flow.commands import (
force_trigger_tasks,
reload_workflow,
remove_tasks,
run_cmd,
)
Expand Down Expand Up @@ -478,3 +479,37 @@ async def test_kill_running(flow, scheduler, run, complete, reflog):
('1/c', ('1/b',)),
# The a:failed output should not cause 1/q to run
}


async def test_reload_changed_config(flow, scheduler, run, complete):
"""Test that a task is removed from the pool if its configuration changes
to make it no longer match the graph."""
wid = flow({
'scheduling': {
'graph': {
'R1': '''
a => b
a:started => s & b
''',
},
},
'runtime': {
'a': {
'simulation': {
# Ensure 1/a still in pool during reload
'fail cycle points': 'all',
},
},
},
})
schd: Scheduler = scheduler(wid, paused_start=False)
async with run(schd):
await complete(schd, '1/s')
# Change graph then reload
flow('b', workflow_id=wid)
await run_cmd(reload_workflow(schd))
assert schd.config.cfg['scheduling']['graph']['R1'] == 'b'
assert schd.pool.get_task_ids() == {'1/a', '1/b'}

await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL]))
await complete(schd, '1/b')
2 changes: 1 addition & 1 deletion tests/integration/test_stop_after_cycle_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_db_value(schd) -> Optional[str]:

# change the configured cycle point to "2"
config['scheduling']['stop after cycle point'] = '2'
id_ = flow(config, id_=id_)
id_ = flow(config, workflow_id=id_)
schd = scheduler(id_, paused_start=False)
async with run(schd):
# the cycle point should be reloaded from the workflow configuration
Expand Down
10 changes: 5 additions & 5 deletions tests/integration/test_task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ async def test_restart_prereqs(

# Edit the workflow to add a new dependency on "z"
conf['scheduling']['graph']['R1'] = graph_2
id_ = flow(conf, id_=id_)
id_ = flow(conf, workflow_id=id_)

# Restart it
schd = scheduler(id_, run_mode='simulation', paused_start=False)
Expand Down Expand Up @@ -834,7 +834,7 @@ async def test_reload_prereqs(

# Modify flow.cylc to add a new dependency on "z"
conf['scheduling']['graph']['R1'] = graph_2
flow(conf, id_=id_)
flow(conf, workflow_id=id_)

# Reload the workflow config
await commands.run_cmd(commands.reload_workflow(schd))
Expand Down Expand Up @@ -953,7 +953,7 @@ async def test_graph_change_prereq_satisfaction(

# shutdown and change the workflow definiton
conf['scheduling']['graph']['R1'] += '\nb => c'
flow(conf, id_=id_)
flow(conf, workflow_id=id_)
schd = scheduler(id_, run_mode='simulation', paused_start=False)

async with start(schd):
Expand All @@ -966,7 +966,7 @@ async def test_graph_change_prereq_satisfaction(

# Modify flow.cylc to add a new dependency on "b"
conf['scheduling']['graph']['R1'] += '\nb => c'
flow(conf, id_=id_)
flow(conf, workflow_id=id_)

# Reload the workflow config
await commands.run_cmd(commands.reload_workflow(schd))
Expand Down Expand Up @@ -2158,7 +2158,7 @@ async def list_data_store():
].replace('@a', '@c')

# reload
flow(config, id_=id_)
flow(config, workflow_id=id_)
await commands.run_cmd(commands.reload_workflow(schd))

# check xtrigs post-reload
Expand Down
Loading
Loading