Fix issue where subpipelines may get stuck due to insufficient task schedulers by raising an error when the total number of subpipelines is greater than the maximum allowable task schedulers.

PiperOrigin-RevId: 660011775
kmonte authored and tfx-copybara committed Aug 6, 2024
1 parent 5e90c67 commit d5253cb
Showing 5 changed files with 121 additions and 27 deletions.
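In effect, pipeline run creation now fails fast rather than letting a run stall mid-execution. A rough sketch of the caller-facing behavior, assuming the PipelineState.new API and the error added in the diff below (mlmd_handle and pipeline are illustrative placeholders):

from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.utils import status as status_lib

# `pipeline` is assumed to contain more nested subpipelines than the
# environment's max_running_task_schedulers allows; `mlmd_handle` is an open
# MLMD connection.
try:
  pstate.PipelineState.new(mlmd_handle, pipeline)
except status_lib.StatusNotOkError as e:
  # The run is rejected up front with FAILED_PRECONDITION instead of
  # getting stuck at runtime.
  assert e.code == status_lib.Code.FAILED_PRECONDITION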
5 changes: 4 additions & 1 deletion tfx/orchestration/experimental/core/env.py
@@ -14,6 +14,7 @@
"""For environment specific extensions."""

import abc
import sys
from typing import Optional, Sequence

from tfx.orchestration.experimental.core import orchestration_options
@@ -166,7 +167,9 @@ def get_orchestration_options(
self, pipeline: pipeline_pb2.Pipeline
) -> orchestration_options.OrchestrationOptions:
del pipeline
return orchestration_options.OrchestrationOptions()
return orchestration_options.OrchestrationOptions(
max_running_task_schedulers=sys.maxsize
)

def label_and_tag_pipeline_run(
self, mlmd_handle, pipeline_id, pipeline_run_id, labels, tags
5 changes: 5 additions & 0 deletions tfx/orchestration/experimental/core/orchestration_options.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Orchestration options."""

import sys
import attr


@@ -27,6 +28,10 @@ class OrchestrationOptions:
failures.
deadline_secs: Only applicable to sync pipelines. If non-zero, a pipeline
run is aborted if the execution duration exceeds deadline_secs seconds.
max_running_task_schedulers: The total number of task schedulers that may be
running at a time. Note this is a GLOBAL limit across all concurrent runs,
subpipeline runs, etc. for a given orchestrator.
"""
fail_fast: bool = False
deadline_secs: int = 0
max_running_task_schedulers: int = sys.maxsize
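A minimal sketch of how a deployment might opt into a finite limit, mirroring the TestEnv pattern in the test changes below; the subclass name and the value 100 are illustrative, not part of this commit:

from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import orchestration_options
from tfx.proto.orchestration import pipeline_pb2


class LimitedEnv(env._DefaultEnv):  # Hypothetical environment subclass.

  def get_orchestration_options(
      self, pipeline: pipeline_pb2.Pipeline
  ) -> orchestration_options.OrchestrationOptions:
    del pipeline
    # Cap the number of concurrently running task schedulers; the default
    # remains sys.maxsize (effectively unbounded).
    return orchestration_options.OrchestrationOptions(
        max_running_task_schedulers=100
    )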
33 changes: 33 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -14,6 +14,7 @@
"""Pipeline state management functionality."""

import base64
import collections
import contextlib
import copy
import dataclasses
@@ -515,6 +516,38 @@ def new(
Raises:
status_lib.StatusNotOkError: If a pipeline with same UID already exists.
"""
num_subpipelines = 0
to_process = collections.deque([pipeline])
while to_process:
p = to_process.popleft()
for node in p.nodes:
if node.WhichOneof('node') == 'sub_pipeline':
num_subpipelines += 1
to_process.append(node.sub_pipeline)
# If the maximum number of running task schedulers is less than the number of
# subpipelines, subpipelines may get stuck.
# This is because when scheduling a subpipeline, its start node and end node
# are scheduled immediately, potentially causing contention where the end node
# is waiting on some intermediary node to finish, but the intermediary node
# cannot be scheduled because the end node's scheduler is still running.
# Note that this number is an upper bound - if subpipelines depend on each
# other, the number of schedulers needed concurrently will be lower.
max_task_schedulers = (
env.get_env()
.get_orchestration_options(pipeline)
.max_running_task_schedulers
)
if max_task_schedulers < num_subpipelines:
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION,
message=(
f'The maximum number of task schedulers ({max_task_schedulers})'
f' is less than the number of subpipelines ({num_subpipelines}).'
' Please set the maximum number of task schedulers to at least'
f' {num_subpipelines} in'
' OrchestrationOptions.max_running_task_schedulers.'
),
)
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
context = context_lib.register_context_if_not_exists(
mlmd_handle,
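Because the check walks nested pipelines, every level of nesting counts toward the limit. A small illustrative helper (not part of this commit) showing the sizing rule the check enforces:

def count_subpipelines(pipeline) -> int:
  """Counts sub_pipeline nodes at every nesting depth (illustrative only)."""
  count = 0
  for node in pipeline.nodes:
    if node.WhichOneof('node') == 'sub_pipeline':
      count += 1 + count_subpipelines(node.sub_pipeline)
  return count

# A pipeline wrapping a subpipeline that itself wraps another subpipeline
# yields 2, so max_running_task_schedulers must be at least 2 for
# PipelineState.new to accept the run.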
65 changes: 54 additions & 11 deletions tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -15,6 +15,7 @@

import dataclasses
import os
import sys
import time
from typing import List
from unittest import mock
@@ -26,6 +27,7 @@
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import event_observer
from tfx.orchestration.experimental.core import orchestration_options
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.experimental.core import task_gen_utils
@@ -36,6 +38,7 @@
from tfx.proto.orchestration import run_state_pb2
from tfx.utils import json_utils
from tfx.utils import status as status_lib

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

@@ -155,9 +158,20 @@ def test_node_state_json(self):

class TestEnv(env._DefaultEnv):

def __init__(self, base_dir, max_str_len):
def __init__(self, base_dir, max_str_len, max_task_schedulers):
self.base_dir = base_dir
self.max_str_len = max_str_len
self.max_task_schedulers = max_task_schedulers

def get_orchestration_options(
self, pipeline: pipeline_pb2.Pipeline
) -> orchestration_options.OrchestrationOptions:
super_options = super().get_orchestration_options(pipeline)
return orchestration_options.OrchestrationOptions(
fail_fast=super_options.fail_fast,
deadline_secs=super_options.deadline_secs,
max_running_task_schedulers=self.max_task_schedulers,
)

def get_base_dir(self):
return self.base_dir
@@ -276,6 +290,33 @@ def test_new_pipeline_state_with_sub_pipelines(self):
],
)

def test_new_pipeline_state_with_sub_pipelines_fails_when_not_enough_task_schedulers(
self,
):
with TestEnv(None, 20000, 1), self._mlmd_connection as m:
pstate._active_owned_pipelines_exist = False
pipeline = _test_pipeline('pipeline1')
# Add 2 additional layers of sub pipelines. Note that there is no normal
# pipeline node in the first pipeline layer.
_add_sub_pipeline(
pipeline,
'sub_pipeline1',
sub_pipeline_nodes=['Trainer'],
sub_pipeline_run_id='sub_pipeline1_run0',
)
_add_sub_pipeline(
pipeline.nodes[0].sub_pipeline,
'sub_pipeline2',
sub_pipeline_nodes=['Trainer'],
sub_pipeline_run_id='sub_pipeline1_sub_pipeline2_run0',
)
with self.assertRaisesRegex(
status_lib.StatusNotOkError,
'The maximum number of task schedulers',
) as e:
pstate.PipelineState.new(m, pipeline)
self.assertEqual(e.exception.code, status_lib.Code.FAILED_PRECONDITION)

def test_load_pipeline_state(self):
with self._mlmd_connection as m:
pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -770,7 +811,9 @@ def test_initiate_node_start_stop(self, mock_time):
def recorder(event):
events.append(event)

with TestEnv(None, 2000), event_observer.init(), self._mlmd_connection as m:
with TestEnv(
None, 2000, sys.maxsize
), event_observer.init(), self._mlmd_connection as m:
event_observer.register_observer(recorder)

pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -900,7 +943,7 @@ def recorder(event):
@mock.patch.object(pstate, 'time')
def test_get_node_states_dict(self, mock_time):
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1120,7 +1163,7 @@ def test_pipeline_view_get_pipeline_run_state(self, mock_time):
@mock.patch.object(pstate, 'time')
def test_pipeline_view_get_node_run_states(self, mock_time):
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1205,7 +1248,7 @@ def test_pipeline_view_get_node_run_states(self, mock_time):
@mock.patch.object(pstate, 'time')
def test_pipeline_view_get_node_run_state_history(self, mock_time):
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1252,7 +1295,7 @@ def test_node_state_for_skipped_nodes_in_partial_pipeline_run(
):
"""Tests that nodes marked to be skipped have the right node state and previous node state."""
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1371,7 +1414,7 @@ def test_load_all_with_list_options(self):
def test_get_previous_node_run_states_for_skipped_nodes(self, mock_time):
"""Tests that nodes marked to be skipped have the right previous run state."""
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1498,7 +1541,7 @@ def test_create_and_load_concurrent_pipeline_runs(self):
)

def test_get_pipeline_and_node(self):
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1516,7 +1559,7 @@ def test_get_pipeline_and_node(self):
)

def test_get_pipeline_and_node_not_found(self):
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1594,7 +1637,7 @@ def test_save_with_max_str_len(self):
state=pstate.NodeState.COMPLETE,
)
}
with TestEnv(None, 20):
with TestEnv(None, 20, sys.maxsize):
execution = metadata_store_pb2.Execution()
proxy = pstate._NodeStatesProxy(execution)
proxy.set(node_states)
@@ -1605,7 +1648,7 @@
),
json_utils.dumps(node_states_without_state_history),
)
with TestEnv(None, 2000):
with TestEnv(None, 2000, sys.maxsize):
execution = metadata_store_pb2.Execution()
proxy = pstate._NodeStatesProxy(execution)
proxy.set(node_states)
40 changes: 25 additions & 15 deletions tfx/orchestration/experimental/core/task_manager.py
@@ -137,12 +137,14 @@ class TaskManager:
TaskManager instance can be used as a context manager:
"""

def __init__(self,
mlmd_handle: metadata.Metadata,
task_queue: tq.TaskQueue,
max_active_task_schedulers: int,
max_dequeue_wait_secs: float = _MAX_DEQUEUE_WAIT_SECS,
process_all_queued_tasks_before_exit: bool = False):
def __init__(
self,
mlmd_handle: metadata.Metadata,
task_queue: tq.TaskQueue,
max_active_task_schedulers: int,
max_dequeue_wait_secs: float = _MAX_DEQUEUE_WAIT_SECS,
process_all_queued_tasks_before_exit: bool = False,
):
"""Constructs `TaskManager`.
Args:
@@ -160,7 +162,8 @@ def __init__(self,
self._task_queue = task_queue
self._max_dequeue_wait_secs = max_dequeue_wait_secs
self._process_all_queued_tasks_before_exit = (
process_all_queued_tasks_before_exit)
process_all_queued_tasks_before_exit
)

self._tm_lock = threading.Lock()
self._stop_event = threading.Event()
@@ -216,8 +219,10 @@ def exception(self) -> Optional[BaseException]:
if self._main_future is None:
raise RuntimeError('Task manager context not entered.')
if not self._main_future.done():
raise RuntimeError('Task manager main thread not done; call should be '
'conditioned on `done` returning `True`.')
raise RuntimeError(
'Task manager main thread not done; call should be '
'conditioned on `done` returning `True`.'
)
return self._main_future.exception()

def _main(self) -> None:
@@ -271,7 +276,8 @@ def _handle_exec_node_task(self, task: task_lib.ExecNodeTask) -> None:
if node_uid in self._scheduler_by_node_uid:
raise RuntimeError(
'Cannot create multiple task schedulers for the same task; '
'task_id: {}'.format(task.task_id))
'task_id: {}'.format(task.task_id)
)
scheduler = _SchedulerWrapper(
typing.cast(
ts.TaskScheduler[task_lib.ExecNodeTask],
@@ -294,13 +300,16 @@ def _handle_cancel_node_task(self, task: task_lib.CancelNodeTask) -> None:
if scheduler is None:
logging.info(
'No task scheduled for node uid: %s. The task might have already '
'completed before it could be cancelled.', task.node_uid)
'completed before it could be cancelled.',
task.node_uid,
)
else:
scheduler.cancel(cancel_task=task)
self._task_queue.task_done(task)

def _process_exec_node_task(self, scheduler: _SchedulerWrapper,
task: task_lib.ExecNodeTask) -> None:
def _process_exec_node_task(
self, scheduler: _SchedulerWrapper, task: task_lib.ExecNodeTask
) -> None:
"""Processes an `ExecNodeTask` using the given task scheduler."""
# This is a blocking call to the scheduler which can take a long time to
# complete for some types of task schedulers. The scheduler is expected to
Expand All @@ -318,7 +327,7 @@ def _process_exec_node_task(self, scheduler: _SchedulerWrapper,
code=status_lib.Code.UNKNOWN,
message=''.join(
traceback.format_exception(*sys.exc_info(), limit=1),
)
),
)
result = ts.TaskSchedulerResult(status=status)
logging.info(
@@ -414,5 +423,6 @@ def _cleanup(self, final: bool = False) -> None:
'Exception %d (out of %d):',
i,
len(exceptions),
exc_info=(type(e), e, e.__traceback__))
exc_info=(type(e), e, e.__traceback__),
)
raise TasksProcessingError(exceptions)
