Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interrupt main thread on Exception in sub thread #1674

Merged
merged 16 commits into from
Jan 10, 2025
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ jobs:
- TEST=torch_frontend
- TEST=torch_internal_frontend
- TEST=torch_util
- TEST=threading

steps:
- uses: actions/checkout@v4
Expand Down
13 changes: 13 additions & 0 deletions returnn/util/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ def excepthook(exc_type, exc_obj, exc_tb):

sys.excepthook = excepthook

def threading_excepthook(args, /):
"""
Thread-specific excepthook to ensure the main thread is killed on unhandled exceptions in sub threads.
"""
log_out = log.v1 or sys.stdout
print(
f"Unhandled exception in thread {threading.current_thread()}, going to interrupt main thread:", file=log_out
)
better_exchook(args.exc_type, args.exc_value, args.exc_traceback, autodebugshell=False, file=log_out)
thread.interrupt_main()

threading.excepthook = threading_excepthook

from returnn.util.basic import to_bool

if os.environ.get("DEBUG_WARN_WITH_TRACEBACK") and to_bool(os.environ.get("DEBUG_WARN_WITH_TRACEBACK")):
Expand Down
88 changes: 88 additions & 0 deletions tests/test_threading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
test threading-specific functionality
"""

from __future__ import annotations

import _setup_test_env # noqa
from subprocess import Popen, PIPE, STDOUT, CalledProcessError
import os
import sys
import unittest
import tempfile
import textwrap
from returnn.util import better_exchook


__my_dir__ = os.path.dirname(os.path.abspath(__file__))
__base_dir__ = os.path.dirname(__my_dir__)
__main_entry__ = __base_dir__ + "/rnn.py"
py = sys.executable


def run(args, input=None):
"""run subproc"""
args = list(args)
print("run:", args)
# RETURNN by default outputs on stderr, so just merge both together
p = Popen(args, stdout=PIPE, stderr=STDOUT, stdin=PIPE)
out, _ = p.communicate(input=input)
print("Return code is %i" % p.returncode)
print("std out/err:\n---\n%s\n---\n" % out.decode("utf8"))
if p.returncode != 0:
raise CalledProcessError(cmd=args, returncode=p.returncode, output=out)
return out.decode("utf8")


def test_thread_exc_hook():
config = textwrap.dedent(
"""
#!rnn.py

backend = "torch" # just require any backend
NeoLegends marked this conversation as resolved.
Show resolved Hide resolved
log_verbosity = 5
task = "nop"

import threading

t = threading.Thread(target=lambda: 1/0)
t.start()
NeoLegends marked this conversation as resolved.
Show resolved Hide resolved
t.join()
"""
)
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(config)
f.flush()

try:
output = run([py, __main_entry__, f.name])
except CalledProcessError as exc:
out = exc.output.decode("utf8")
assert "ZeroDivisionError" in out
assert "KeyboardInterrupt" in out # this is the result from the main thread being killed
else:
assert False, f"Expected RETURNN to crash, but got {output}."


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1:
for k, v in sorted(globals().items()):
if k.startswith("test_"):
print("-" * 40)
print("Executing: %s" % k)
try:
v()
except unittest.SkipTest as exc:
print("SkipTest:", exc)
print("-" * 40)
print("Finished all tests.")
else:
assert len(sys.argv) >= 2
for arg in sys.argv[1:]:
print("Executing: %s" % arg)
if arg in globals():
globals()[arg]() # assume function and execute
else:
eval(arg) # assume Python code and execute
# better_exchook.dump_all_thread_tracebacks()
Loading