diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index bf42488ec..530b89a3e 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -260,6 +260,7 @@ jobs:
           - TEST=torch_frontend
           - TEST=torch_internal_frontend
           - TEST=torch_util
+          - TEST=threading
 
     steps:
     - uses: actions/checkout@v4
diff --git a/returnn/util/debug.py b/returnn/util/debug.py
index 992168c32..24a047906 100644
--- a/returnn/util/debug.py
+++ b/returnn/util/debug.py
@@ -182,6 +182,27 @@ def excepthook(exc_type, exc_obj, exc_tb):
 
     sys.excepthook = excepthook
 
+    def threading_excepthook(args, /):
+        """
+        Thread-specific excepthook to ensure the main thread is killed on unhandled exceptions in sub threads.
+
+        :param args: provides ``exc_type``, ``exc_value`` and ``exc_traceback`` of the unhandled exception,
+            as passed by ``threading`` to ``threading.excepthook``
+        """
+        if args.exc_type is SystemExit:
+            # Same as the default threading.excepthook: a thread ending via sys.exit() is a regular
+            # shutdown of that thread, not a crash, so do not report it and do not kill the main thread.
+            return
+        log_out = log.v1 or sys.stdout
+        print(
+            f"Unhandled exception in thread {threading.current_thread()}, going to interrupt main thread:", file=log_out
+        )
+        better_exchook(args.exc_type, args.exc_value, args.exc_traceback, autodebugshell=False, file=log_out)
+        # Report first, then interrupt: the main thread might terminate the process right away.
+        thread.interrupt_main()
+
+    threading.excepthook = threading_excepthook
+
     from returnn.util.basic import to_bool
 
     if os.environ.get("DEBUG_WARN_WITH_TRACEBACK") and to_bool(os.environ.get("DEBUG_WARN_WITH_TRACEBACK")):
diff --git a/tests/test_threading.py b/tests/test_threading.py
new file mode 100644
index 000000000..6188e9fe2
--- /dev/null
+++ b/tests/test_threading.py
@@ -0,0 +1,99 @@
+"""
+test threading-specific functionality
+"""
+
+from __future__ import annotations
+
+import _setup_test_env  # noqa
+from subprocess import Popen, PIPE, STDOUT, CalledProcessError
+import os
+import sys
+import unittest
+import tempfile
+import textwrap
+from returnn.util import better_exchook
+
+
+__my_dir__ = os.path.dirname(os.path.abspath(__file__))
+__base_dir__ = os.path.dirname(__my_dir__)
+__main_entry__ = __base_dir__ + "/rnn.py"
+py = sys.executable
+
+
+def run(args, input=None):
+    """
+    Run a subprocess, print its return code and merged stdout/stderr, and return the output.
+
+    :param args: command line (program followed by its arguments)
+    :param input: passed on to ``Popen.communicate`` as stdin content, or None
+    :return: merged stdout/stderr of the subprocess, decoded as UTF-8
+    :raises CalledProcessError: if the return code is non-zero. Its ``output`` holds the raw (undecoded) output.
+    """
+    args = list(args)
+    print("run:", args)
+    # RETURNN by default outputs on stderr, so just merge both together
+    p = Popen(args, stdout=PIPE, stderr=STDOUT, stdin=PIPE)
+    out, _ = p.communicate(input=input)
+    print("Return code is %i" % p.returncode)
+    print("std out/err:\n---\n%s\n---\n" % out.decode("utf8"))
+    if p.returncode != 0:
+        raise CalledProcessError(cmd=args, returncode=p.returncode, output=out)
+    return out.decode("utf8")
+
+
+def test_thread_exc_hook():
+    """
+    An unhandled exception in a sub thread (started from within the config) must make RETURNN exit
+    with a non-zero return code, with both the thread exception and the main thread interrupt in the output.
+    """
+    config = textwrap.dedent(
+        """
+        #!rnn.py
+
+        backend = "torch"  # just require any backend
+        log_verbosity = 5
+        task = "nop"
+
+        import threading
+
+        t = threading.Thread(target=lambda: 1/0)
+        t.start()
+        t.join()
+        """
+    )
+    with tempfile.NamedTemporaryFile(mode="w") as f:
+        f.write(config)
+        f.flush()
+
+        try:
+            output = run([py, __main_entry__, f.name])
+        except CalledProcessError as exc:
+            out = exc.output.decode("utf8")
+            assert "ZeroDivisionError" in out
+            assert "KeyboardInterrupt" in out  # this is the result from the main thread being killed
+        else:
+            assert False, f"Expected RETURNN to crash, but got {output}."
+
+
+if __name__ == "__main__":
+    better_exchook.install()
+    if len(sys.argv) <= 1:
+        for k, v in sorted(globals().items()):
+            if k.startswith("test_"):
+                print("-" * 40)
+                print("Executing: %s" % k)
+                try:
+                    v()
+                except unittest.SkipTest as exc:
+                    print("SkipTest:", exc)
+                print("-" * 40)
+        print("Finished all tests.")
+    else:
+        assert len(sys.argv) >= 2
+        for arg in sys.argv[1:]:
+            print("Executing: %s" % arg)
+            if arg in globals():
+                globals()[arg]()  # assume function and execute
+            else:
+                eval(arg)  # assume Python code and execute
+        # better_exchook.dump_all_thread_tracebacks()