-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add asyncio event loop context manager (#555)
Extracting out from another refactor of SSH support PR #510
- Loading branch information
Showing
2 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# | ||
""" | ||
EventLoopContext class definition. | ||
""" | ||
|
||
from asyncio import AbstractEventLoop | ||
from concurrent.futures import Future | ||
from typing import Any, Coroutine, Optional, TypeVar | ||
from threading import Lock as ThreadLock, Thread | ||
|
||
import asyncio | ||
import sys | ||
|
||
if sys.version_info >= (3, 10): | ||
from typing import TypeAlias | ||
else: | ||
from typing_extensions import TypeAlias | ||
|
||
|
||
class EventLoopContext: | ||
""" | ||
EventLoopContext encapsulates a background thread for asyncio event | ||
loop processing as an aid for context managers. | ||
There is generally only expected to be one of these, either as a base | ||
class instance if it's specific to that functionality or for the full | ||
mlos_bench process to support parallel trial runners, for instance. | ||
It's enter() and exit() routines are expected to be called from the | ||
caller's context manager routines (e.g., __enter__ and __exit__). | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._event_loop: Optional[AbstractEventLoop] = None | ||
self._event_loop_thread: Optional[Thread] = None | ||
self._event_loop_thread_lock = ThreadLock() | ||
self._event_loop_thread_refcnt: int = 0 | ||
|
||
def _run_event_loop(self) -> None: | ||
""" | ||
Runs the asyncio event loop in a background thread. | ||
""" | ||
assert self._event_loop is not None | ||
asyncio.set_event_loop(self._event_loop) | ||
self._event_loop.run_forever() | ||
|
||
def enter(self) -> None: | ||
""" | ||
Manages starting the background thread for event loop processing. | ||
""" | ||
# Start the background thread if it's not already running. | ||
with self._event_loop_thread_lock: | ||
if not self._event_loop_thread: | ||
assert self._event_loop_thread_refcnt == 0 | ||
assert self._event_loop is None | ||
self._event_loop = asyncio.new_event_loop() | ||
self._event_loop_thread = Thread(target=self._run_event_loop, daemon=True) | ||
self._event_loop_thread.start() | ||
self._event_loop_thread_refcnt += 1 | ||
|
||
def exit(self) -> None: | ||
""" | ||
Manages cleaning up the background thread for event loop processing. | ||
""" | ||
with self._event_loop_thread_lock: | ||
self._event_loop_thread_refcnt -= 1 | ||
assert self._event_loop_thread_refcnt >= 0 | ||
if self._event_loop_thread_refcnt == 0: | ||
assert self._event_loop is not None | ||
self._event_loop.call_soon_threadsafe(self._event_loop.stop) | ||
assert self._event_loop_thread is not None | ||
self._event_loop_thread.join(timeout=1) | ||
if self._event_loop_thread.is_alive(): | ||
raise RuntimeError("Failed to stop event loop thread.") | ||
self._event_loop = None | ||
self._event_loop_thread = None | ||
|
||
CoroReturnType = TypeVar('CoroReturnType') | ||
if sys.version_info >= (3, 9): | ||
FutureReturnType: TypeAlias = Future[CoroReturnType] | ||
else: | ||
FutureReturnType: TypeAlias = Future | ||
|
||
def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: | ||
""" | ||
Runs the given coroutine in the background event loop thread and | ||
returns a Future that can be used to wait for the result. | ||
Parameters | ||
---------- | ||
coro : Coroutine[Any, Any, CoroReturnType] | ||
The coroutine to run. | ||
Returns | ||
------- | ||
Future[CoroReturnType] | ||
A future that will be completed when the coroutine completes. | ||
""" | ||
assert self._event_loop_thread_refcnt > 0 | ||
assert self._event_loop is not None | ||
return asyncio.run_coroutine_threadsafe(coro, self._event_loop) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# | ||
""" | ||
Tests for mlos_bench.event_loop_context background thread logic. | ||
""" | ||
|
||
import asyncio | ||
import time | ||
|
||
from threading import Thread | ||
from types import TracebackType | ||
from typing import Optional, Type | ||
from typing_extensions import Literal | ||
|
||
import pytest | ||
|
||
from mlos_bench.event_loop_context import EventLoopContext | ||
|
||
|
||
class EventLoopContextCaller: | ||
""" | ||
Simple class to test the EventLoopContext. | ||
""" | ||
|
||
EVENT_LOOP_CONTEXT = EventLoopContext() | ||
|
||
def __init__(self, instance_id: int) -> None: | ||
self._id = instance_id | ||
self._in_context = False | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}(id={self._id})" | ||
|
||
def __enter__(self) -> None: | ||
assert not self._in_context | ||
self.EVENT_LOOP_CONTEXT.enter() | ||
self._in_context = True | ||
|
||
def __exit__(self, ex_type: Optional[Type[BaseException]], | ||
ex_val: Optional[BaseException], | ||
ex_tb: Optional[TracebackType]) -> Literal[False]: | ||
assert self._in_context | ||
self.EVENT_LOOP_CONTEXT.exit() | ||
self._in_context = False | ||
return False | ||
|
||
|
||
def test_event_loop_context() -> None: | ||
"""Test event loop context background thread setup/cleanup handling.""" | ||
# pylint: disable=protected-access | ||
|
||
# Should start with no event loop thread. | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None | ||
|
||
# The background thread should only be created upon context entry. | ||
event_loop_caller_instance_1 = EventLoopContextCaller(1) | ||
assert event_loop_caller_instance_1 | ||
assert not event_loop_caller_instance_1._in_context | ||
assert event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread is None | ||
|
||
# After we enter the instance context, we should have a background thread. | ||
with event_loop_caller_instance_1: | ||
assert event_loop_caller_instance_1._in_context | ||
assert isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable] | ||
# Give the thread a chance to start. | ||
# Mostly important on the underpowered Windows CI machines. | ||
time.sleep(0.25) | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread.is_alive() | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 1 | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is not None | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() | ||
|
||
event_loop_caller_instance_2 = EventLoopContextCaller(instance_id=2) | ||
assert event_loop_caller_instance_2 | ||
assert not event_loop_caller_instance_2._in_context | ||
|
||
with event_loop_caller_instance_2: | ||
assert event_loop_caller_instance_2._in_context | ||
assert event_loop_caller_instance_1._in_context | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 | ||
# We should only get one thread for all instances. | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \ | ||
is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \ | ||
is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \ | ||
is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \ | ||
is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop | ||
|
||
assert not event_loop_caller_instance_2._in_context | ||
|
||
# The background thread should remain running since we have another context still open. | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 1 | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is not None | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread.is_alive() | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is not None | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() | ||
|
||
start = time.time() | ||
future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) | ||
assert 0.0 <= time.time() - start < 0.1 | ||
assert future.result(timeout=0.2) == 'foo' | ||
assert 0.1 <= time.time() - start <= 0.2 | ||
|
||
# Once we exit the last context, the background thread should be stopped | ||
# and unusable for running co-routines. | ||
|
||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0 | ||
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is None | ||
|
||
with pytest.raises(AssertionError): | ||
event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1)) | ||
|
||
|
||
if __name__ == '__main__': | ||
# For debugging in Windows which has issues with pytest detection in vscode. | ||
pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"]) |