diff --git a/CHANGES.md b/CHANGES.md new file mode 100644 index 0000000..1b93139 --- /dev/null +++ b/CHANGES.md @@ -0,0 +1,18 @@ +# Versioning + +`minject` uses semantic versioning. To learn more about semantic versioning, see the [semantic versioning specification](https://semver.org/#semantic-versioning-200). + +# Changelog + +## v1.1.0-beta.1 + +Add support for async Python. This version introduces the following methods and decorators: + +- `Registry.__aenter__` +- `Registry.__aexit__` +- `Registry.aget` +- `@async_context` + +## v1.0.0 + +- Initial Release diff --git a/minject/__init__.py b/minject/__init__.py index fbd7319..738c842 100644 --- a/minject/__init__.py +++ b/minject/__init__.py @@ -44,7 +44,7 @@ def __init__(self, api, path): ... something = registry['something_i_need'] """ -__version__ = "1.0.0" +__version__ = "1.1.0-beta.1" from . import inject from .inject_attrs import inject_define as define, inject_field as field diff --git a/minject/asyncio_extensions.py b/minject/asyncio_extensions.py new file mode 100644 index 0000000..b694548 --- /dev/null +++ b/minject/asyncio_extensions.py @@ -0,0 +1,47 @@ +""" +This module provides fallback implementations of asyncio features that +are not available in Python 3.7. +""" + +try: + # Python 3.7 mypy raises attr-defined error for to_thread, so + # we must ignore it here. + from asyncio import to_thread # type: ignore[attr-defined] +# This is copy pasted from here: https://github.com/python/cpython/blob/03775472cc69e150ced22dc30334a7a202fc0380/Lib/asyncio/threads.py#L1-L25 +except ImportError: + """High-level support for working with threads in asyncio""" + + import contextvars + import functools + from asyncio import events + + # Minject Specific Edit: I commented out the following line + # and moved it out of the try - except block + # __all__ = "to_thread", + + # Minject Specific Edit: I removed the '/' from the function signature, + # as this is not supported in python 3.7 (added in python 3.8). + # The '/' forces that "func" be passed positionally (I.E. first argument + # to to_thread). Users of this extension must be careful to pass the argument + # to "func" positionally, or there could be different behavior + # when using minject in python 3.7 and python 3.9+. I added a "type: ignore" + # comment to silence mypy errors related to defining a function with a + # different signature. + # original asyncio source: "async def to_thread(func, /, *args, **kwargs):" + async def to_thread(func, *args, **kwargs): # type: ignore + """Asynchronously run function *func* in a separate thread. + + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Return a coroutine that can be awaited to get the eventual result of *func*. + """ + loop = events.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + +__all__ = "to_thread" diff --git a/minject/inject.py b/minject/inject.py index 49ac4fa..aea9d75 100644 --- a/minject/inject.py +++ b/minject/inject.py @@ -2,11 +2,25 @@ import itertools import os -from typing import Any, Callable, Dict, Optional, Sequence, Type, TypeVar, Union, cast, overload +from typing import ( + Any, + Callable, + Dict, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import TypeGuard, assert_type -from typing_extensions import TypeGuard +from minject.asyncio_extensions import to_thread +from minject.types import _AsyncContext -from .metadata import RegistryMetadata, _gen_meta, _get_meta +from .metadata import _INJECT_METADATA_ATTR, RegistryMetadata, _gen_meta, _get_meta from .model import ( Deferred, DeferredAny, @@ -14,10 +28,11 @@ Resolver, resolve_value, ) -from .types import _MinimalMappingProtocol +from .types import _AsyncContext, _MinimalMappingProtocol T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) +T_async_context = TypeVar("T_async_context", bound=_AsyncContext) R = TypeVar("R") @@ -70,6 +85,19 @@ def wrap(cls: Type[T]) -> Type[T]: return wrap +def async_context(cls: Type[T_async_context]) -> Type[T_async_context]: + """ + Declare that a class is as an async context manager + that can be initialized by the registry through aget(). This + is to distinguish the class from an async context manager that + should not be initialized by the registry (an example of + this being asyncio.Lock). + """ + meta = _gen_meta(cls) + meta.is_async_context = True + return cls + + def define( base_class: Type[T], _close: Optional[Callable[[T], None]] = None, @@ -78,7 +106,9 @@ def define( """Create a new registry key based on a class and optional bindings.""" meta = _get_meta(base_class) if meta: - meta = RegistryMetadata(base_class, bindings=dict(meta.bindings)) + meta = RegistryMetadata( + base_class, is_async_context=meta.is_async_context, bindings=dict(meta.bindings) + ) meta.update_bindings(**bindings) else: meta = RegistryMetadata(base_class, bindings=bindings) @@ -91,6 +121,29 @@ def _is_type(key: "RegistryKey[T]") -> TypeGuard[Type[T]]: return isinstance(key, type) +def _is_key_async(key: "RegistryKey[T]") -> bool: + """ + Check whether a registry key is an "async", or in other words + marked for async initialization within the registry with @async_context. + If a key is "async", it can be initialized through Registry.aget. + """ + # At present, we only consider objects with RegistryMetadata.is_async_context + # set to True to be "async", or able to be initialized through Registry.aget. + # In the future, we likely will support initializing both async and non-async + # objects through aget, but we are deferring implementing this until + # we have a bit more experience using the async Registry API. + if isinstance(key, str): + return False + elif isinstance(key, RegistryMetadata): + return key.is_async_context + else: + assert_type(key, Type[T]) + inject_metadata = _get_meta(key) + if inject_metadata is None: + return False + return inject_metadata.is_async_context + + class _RegistryReference(Deferred[T_co]): """Reference to an object in the registry to be loaded later. (you should not instantiate this class directly, instead use the @@ -103,6 +156,11 @@ def __init__(self, key: "RegistryKey[T_co]") -> None: def resolve(self, registry_impl: Resolver) -> T_co: return registry_impl.resolve(self._key) + async def aresolve(self, registry_impl: Resolver) -> T_co: + if _is_key_async(self._key): + return await registry_impl._aresolve(self._key) + return await to_thread(registry_impl.resolve, self._key) + @property def type_of_object_referenced_in_key(self) -> "Type[T_co]": if type(self.key) == RegistryMetadata: @@ -188,6 +246,9 @@ def resolve(self, registry_impl: Resolver) -> T_co: kwargs[key] = resolve_value(registry_impl, arg) return self.func()(*args, **kwargs) + async def aresolve(self, registry_impl: Resolver) -> T_co: + raise NotImplementedError("Have not implemented async registry function") + def func(self) -> Callable[..., T_co]: return self._func @@ -273,6 +334,9 @@ def resolve(self, registry_impl: Resolver) -> T_co: else: return cast(T_co, self._default) + async def aresolve(self, registry_impl: Resolver) -> T_co: + return await to_thread(self.resolve, registry_impl) + @property def key(self) -> Optional[str]: return self._key @@ -316,6 +380,9 @@ def resolve(self, registry_impl: Resolver) -> T_co: return self._default return cast(T_co, sub) + async def aresolve(self, registry_impl: Resolver) -> T_co: + return await to_thread(self.resolve, registry_impl) + def nested_config( keys: Union[Sequence[str], str], default: Union[T, _RaiseKeyError] = RAISE_KEY_ERROR @@ -375,5 +442,8 @@ class _RegistrySelf(Deferred[Resolver]): def resolve(self, registry_impl: Resolver) -> Resolver: return registry_impl + async def aresolve(self, registry_impl: Resolver) -> Resolver: + return await to_thread(self.resolve, registry_impl) + self_tag = _RegistrySelf() diff --git a/minject/metadata.py b/minject/metadata.py index 0f84eb5..52e7a69 100644 --- a/minject/metadata.py +++ b/minject/metadata.py @@ -14,7 +14,7 @@ TypeVar, ) -from .model import DeferredAny, RegistryKey, resolve_value +from .model import DeferredAny, RegistryKey, aresolve_value, resolve_value from .types import Kwargs if TYPE_CHECKING: @@ -106,12 +106,14 @@ def __init__( cls: Type[T_co], close: Optional[Callable[[T_co], None]] = None, bindings: Optional[Kwargs] = None, + is_async_context: bool = False, ): self._cls = cls self._bindings = bindings or {} self._close = close self._interfaces = [cls for cls in inspect.getmro(cls) if cls is not object] + self.is_async_context = is_async_context @property def interfaces(self) -> Sequence[Type]: @@ -159,6 +161,16 @@ def _init_object(self, obj: T_co, registry_impl: "Registry") -> None: # type: i self._cls.__init__(obj, **init_kwargs) + async def _ainit_object(self, obj: T_co, registry_impl: "Registry") -> None: # type: ignore[misc] + """ + asynchronous version of _init_object. Calls _aresolve instead + of _resolve. + """ + init_kwargs = {} + for name_, value in self._bindings.items(): + init_kwargs[name_] = await aresolve_value(registry_impl, value) + self._cls.__init__(obj, **init_kwargs) + def _close_object(self, obj: T_co) -> None: # type: ignore[misc] if self._close: self._close(obj) diff --git a/minject/model.py b/minject/model.py index c2969a8..16bcca3 100644 --- a/minject/model.py +++ b/minject/model.py @@ -34,6 +34,21 @@ class Resolver(abc.ABC): def resolve(self, key: "RegistryKey[T]") -> T: ... + async def _aresolve(self, key: "RegistryKey[T]") -> T: + """ + Resolve a key into an instance of that key. The key must be marked + with the @async_context decorator. + """ + raise NotImplementedError("Please implement _aresolve.") + + async def _push_async_context(self, key: Any) -> Any: + """ + Push an async context onto the context stack maintained by the Resolver. + This is necessary to enter/close the context of an object + marked with @async_context. + """ + raise NotImplementedError + @property @abc.abstractmethod def config(self) -> RegistryConfigWrapper: @@ -52,6 +67,15 @@ class Deferred(abc.ABC, Generic[T_co]): def resolve(self, registry_impl: Resolver) -> T_co: ... + @abc.abstractmethod + async def aresolve(self, registry_impl: Resolver) -> T_co: + """ + Resolve a deferred object into an instance of the object. The object, + may or may not be asynchronous. If the object is asynchronous (marked with @async_context), + resolve the object asynchronously. Otherwise, resolve synchronously. + """ + ... + Resolvable = Union[Deferred[T_co], T_co] # Union of Deferred and Any is just Any, but want to call out that a Deffered is quite common @@ -69,3 +93,13 @@ def resolve_value(registry_impl: Resolver, value: Resolvable[T]) -> T: return value.resolve(registry_impl) else: return value + + +async def aresolve_value(registry_impl: Resolver, value: Resolvable[T]) -> T: + """ + Async version of resolve_value, which calls aresolve on Deferred instances. + """ + if isinstance(value, Deferred): + return await value.aresolve(registry_impl) + else: + return value diff --git a/minject/registry.py b/minject/registry.py index 0802fc7..bc800a5 100644 --- a/minject/registry.py +++ b/minject/registry.py @@ -1,14 +1,19 @@ """The Registry itself is a runtime collection of initialized classes.""" import functools import logging +from contextlib import AsyncExitStack +from textwrap import dedent from threading import RLock -from typing import Callable, Dict, Generic, Iterable, List, Optional, TypeVar, Union, cast +from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, TypeVar, Union, cast from typing_extensions import Concatenate, ParamSpec +from minject.asyncio_extensions import to_thread +from minject.inject import _is_key_async, _RegistryReference, reference + from .config import RegistryConfigWrapper, RegistryInitConfig from .metadata import RegistryMetadata, _get_meta, _get_meta_from_key -from .model import RegistryKey, Resolvable, Resolver, resolve_value +from .model import RegistryKey, Resolvable, Resolver, aresolve_value, resolve_value LOG = logging.getLogger(__name__) @@ -76,6 +81,9 @@ def __init__(self, config: Optional[RegistryInitConfig] = None): self._lock = RLock() + self._async_context_stack: AsyncExitStack = AsyncExitStack() + self._async_entered = False + if config is not None: self._config._from_dict(config) @@ -86,6 +94,26 @@ def config(self) -> RegistryConfigWrapper: def resolve(self, key: "RegistryKey[T]") -> T: return self[key] + async def _aresolve(self, key: "RegistryKey[T]") -> T: + result = await self.aget(key) + if result is None: + raise KeyError(key, "could not be resolved") + return result + + async def _push_async_context(self, key: Any) -> Any: + result = await self._async_context_stack.enter_async_context(key) + if result is not key: + raise ValueError( + dedent( + """ + Classes decorated with @async_context must + return the same value from __aenter__ as they do + from their constructor. Hint: __aenter__ should + return self. + """ + ).strip() + ) + @_synchronized def close(self) -> None: """Close all objects contained in the registry.""" @@ -177,6 +205,38 @@ def _register_by_metadata(self, meta: RegistryMetadata[T]) -> RegistryWrapper[T] return wrapper + async def _aregister_by_metadata(self, meta: RegistryMetadata[T]) -> RegistryWrapper[T]: + """ + async version of _register_by_metadata. Calls _ainit_object instead of _init_object. + """ + LOG.debug("registering %s", meta) + + # allocate the object (but don't initialize yet) + obj = meta._new_object() + + # add to the registry (done before init in case of circular reference) + wrapper = self._set_by_metadata(meta, obj, _global=False) + + success = False + try: + # initialize the object + await meta._ainit_object(obj, self) + # add to our list of all objects (this MUST happen after init so + # any references come earlier in sequence and are destroyed first) + self._objects.append(wrapper) + + # after creating an object, enter the objects context + # if it is marked with the @async_context decorator. + if meta.is_async_context: + await self._push_async_context(obj) + + success = True + finally: + if not success: + self._remove_by_metadata(meta, wrapper, _global=False) + + return wrapper + @_synchronized def _get_by_metadata( self, meta: RegistryMetadata[T], default: Optional[Union[T, _AutoOrNone]] = AUTO_OR_NONE @@ -198,6 +258,17 @@ def _get_by_metadata( else: return None + async def _aget_by_metadata(self, meta: RegistryMetadata[T]) -> Optional[RegistryWrapper[T]]: + """ + async version of _get_by_metadata. The default argument has been removed + from the signature, as there is no use case for it at the time of writing. + Please make a feature request if you need this functionality. + """ + if meta in self._by_meta: + return self._by_meta[meta] + + return await self._aregister_by_metadata(meta) + @_synchronized def __len__(self) -> int: return len(self._objects) @@ -238,6 +309,11 @@ def get( Returns: The requested object or default if not found. """ + if _is_key_async(key): + raise AssertionError( + "cannot use synchronous get on async object (object marked with @async_context)" + ) + if key == object: return None # NEVER auto-init plain object @@ -245,21 +321,90 @@ def get( return _unwrap(self._by_name.get(key, RegistryWrapper(cast(T, default)))) meta = _get_meta_from_key(key) + maybe_class = self._get_if_already_in_registry(key, meta) + if maybe_class is not None: + return maybe_class - if isinstance(key, type): - # if a type has metadata attached to it as an attribute, - # the registry must use that metadata to construct the object - # or query for a constructed object. This is because the user - # has intentionally added metadata to the class, and thus - # we should not use metadata inherited from interfaces. - if _get_meta(key, include_bases=False) is not None: - return _unwrap(self._get_by_metadata(meta, default)) + return _unwrap(self._get_by_metadata(meta, default)) - obj_list = self._by_iface.get(key) - if obj_list: - return _unwrap(obj_list[0]) + async def aget(self, key: "RegistryKey[T]") -> Optional[T]: + """ + Resolve objects marked with the @async_context decorator. + """ + if not _is_key_async(key): + raise AssertionError("key must be async to use aget") - return _unwrap(self._get_by_metadata(meta, default)) + if not self._async_entered: + raise AssertionError("cannot use aget outside of async context") + + meta = _get_meta_from_key(key) + maybe_initialized_obj = self._get_if_already_in_registry(key, meta) + if maybe_initialized_obj is not None: + return maybe_initialized_obj + + initialized_obj = await self._aget_by_metadata(meta) + return _unwrap(initialized_obj) + + def _get_if_already_in_registry( + self, key: "RegistryKey[T]", meta: "RegistryMetadata[T]" + ) -> Optional[T]: + # retrieve the class metdata, and the metadata of class + # without inherited metadata + meta = _get_meta_from_key(key) + + # if the class has already been registered, return it + if meta in self._by_meta: + return _unwrap(self._by_meta[meta]) + + # following checks only apply if key is a class + if not isinstance(key, type): + return None + + # If the class (key) has no metadata, but an object exists in the + # registry that is a concrete subtype of the class, return that + # object. If the class has metadata, we must use the metadata to + # construct the class, and we should not check the registry for + # concrete subtypes. A user must specify metadata for a class itself + # in order to force the registry to use that metadata to construct the + # class, inherited metadata alone does not prevent the registry + # from returning a concrete subtype of the class. + meta_no_bases = _get_meta(key, include_bases=False) + obj_list = self._by_iface.get(key) + if meta_no_bases is None and obj_list: + return _unwrap(obj_list[0]) + + # nothing has been registered for this metadata yet + return None + + async def __aenter__(self) -> "Registry": + """ + Mark a registry instance as ready for resolving async objects. + """ + if self._async_entered: + raise AssertionError( + "Attempting to enter registry context while already in context. This should not happen." + ) + self._async_entered = True + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """ + Closes the registry. Closes all contexts on the registry's context stack + and then closes the registry itself with regisry.close(). + """ + if not self._async_entered: + raise AssertionError( + "Attempting to exit registry context while not in context. This should not happen." + ) + self._async_entered = False + # close all objects in the registry + try: + await self._async_context_stack.aclose() + finally: + # as we currently only support async -> sync transitions, + # a sync class cannot depend on an async class, and thus we may + # safely close all sync classes after closing all async classes. + await to_thread(self.close) def __getitem__(self, key: "RegistryKey[T]") -> T: """Get an object from the registry by a key. diff --git a/minject/types.py b/minject/types.py index b6e9ab1..ac568a9 100644 --- a/minject/types.py +++ b/minject/types.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, TypeVar +from types import TracebackType +from typing import Any, Dict, Type, TypeVar -from typing_extensions import Protocol, runtime_checkable +from typing_extensions import Protocol, Self, runtime_checkable Arg = Any Kwargs = Dict[str, Arg] @@ -22,3 +23,19 @@ def __getitem__(self, key: K_contra) -> V_co: def __contains__(self, key: K_contra) -> bool: ... + + +class _AsyncContext(Protocol): + """ + Protocol for an object that can be marked with the @async_context + decorator. This is any async context manager that return Self from + it's __aenter__ method. + """ + + async def __aenter__(self: Self) -> Self: + ... + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + ... diff --git a/pyproject.toml b/pyproject.toml index be8c948..fcd235c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "pytest-cov", "pytest-mock", "pytest-xdist", + 'pytest-asyncio', "typing", ] @@ -43,6 +44,7 @@ dependencies = [ 'pytest', 'pytest-mock', 'pytest-randomly', + 'pytest-asyncio', 'pytest-rerunfailures', 'pytest-xdist[psutil]', ] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..4088045 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode=auto diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..c01e9b9 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,347 @@ +from types import TracebackType +from typing import Dict, Type + +import pytest + +from minject.inject import async_context, bind, config, define, nested_config, reference +from minject.registry import AUTO_OR_NONE, Registry + +TEXT = "we love tests" + + +@pytest.fixture +def registry() -> Registry: + config_dict: Dict[str, str] = {"a": {"b": "c"}, 1: 2} + r = Registry(config=config_dict) + return r + + +class MyDependencyNotSpecifiedAsync: + def __init__(self) -> None: + self.in_context = False + + async def __aenter__(self) -> "MyDependencyNotSpecifiedAsync": + self.in_context = True + return self + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + del exc_type, exc_value, traceback + self.in_context = False + + +@async_context +class MyDependencyAsync: + def __init__(self) -> None: + self.in_context = False + + async def __aenter__(self) -> "MyDependencyAsync": + self.in_context = True + return self + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + del exc_type, exc_value, traceback + self.in_context = False + + +@async_context +class MyAsyncAPIContextCounter: + def __init__(self) -> None: + self.in_context = False + self.entered_context_counter = 0 + self.exited_context_counter = 0 + + async def __aenter__(self) -> "MyAsyncAPIContextCounter": + self.in_context = True + self.entered_context_counter += 1 + return self + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + del exc_type, exc_value, traceback + self.exited_context_counter += 1 + self.in_context = False + + +@bind(dep_async=reference(MyDependencyAsync)) +class MySyncClassWithAsyncDependency: + def __init__(self, dep_async: MyDependencyAsync) -> None: + self.dep_async = dep_async + + +@bind(_close=lambda self: self.close()) +class MySyncClassWithCloseMethod: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +@async_context +@bind(sync_close_dep=reference(MySyncClassWithCloseMethod)) +class MyAsyncClassWithSyncCloseDependency: + def __init__( + self, sync_close_dep: "MySyncClassWithCloseMethod", will_throw: bool = False + ) -> None: + self.sync_close_dep = sync_close_dep + self.entered = False + self.will_throw = will_throw + + async def __aenter__(self) -> "MyAsyncClassWithSyncCloseDependency": + self.entered = True + return self + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + del exc_type, exc_value, traceback + self.entered = False + if self.will_throw: + raise ValueError("This is a test error") + + +@async_context +@bind(text=TEXT) +@bind(dep_async=reference(MyDependencyAsync)) +@bind(dep_not_specified=reference(MyDependencyNotSpecifiedAsync)) +@bind(dep_context_counter=reference(MyAsyncAPIContextCounter)) +class MyAsyncApi: + def __init__( + self, + text: str, + dep_async: MyDependencyAsync, + dep_not_specified: MyDependencyNotSpecifiedAsync, + dep_context_counter: MyAsyncAPIContextCounter, + ) -> None: + self.text = text + self.in_context = False + self.dep_async = dep_async + self.dep_not_specified = dep_not_specified + self.dep_context_counter = dep_context_counter + + async def __aenter__(self) -> "MyAsyncApi": + self.in_context = True + return self + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + del exc_type, exc_value, traceback + self.in_context = False + + +MY_ASYNC_API_DEFINE = define( + MyAsyncApi, + text=TEXT, + dep_async=reference(MyDependencyAsync), + dep_not_specified=reference(MyDependencyNotSpecifiedAsync), + dep_context_counter=reference(MyAsyncAPIContextCounter), +) + + +@async_context +@bind(nested=nested_config("a.b")) +@bind(flat=config(1)) +class MyAsyncApiWithConfig: + def __init__(self, nested: str, flat: int) -> None: + self.nested = nested + self.flat = flat + + async def __aenter__(self) -> "MyAsyncApiWithConfig": + return self + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: + pass + + +@async_context +class BadContextManager: + def __init__(self) -> None: + pass + + async def __aenter__(self) -> 1: + return 1 + + async def __aexit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ): + pass + + +async def test_async_registry_simple(registry: Registry) -> None: + async with registry as r: + my_api = await r.aget(MyDependencyAsync) + assert my_api.in_context == True + + assert my_api.in_context == False + + +async def test_async_registry_recursive(registry: Registry) -> None: + async with registry as r: + my_api = await r.aget(MyAsyncApi) + assert my_api.text == TEXT + assert my_api.in_context == True + assert my_api.dep_async.in_context == True + assert my_api.dep_not_specified.in_context == False + assert my_api.dep_context_counter.in_context == True + assert my_api.dep_context_counter.entered_context_counter == 1 + + assert my_api.in_context == False + assert my_api.dep_async.in_context == False + assert my_api.dep_not_specified.in_context == False + assert my_api.dep_context_counter.in_context == False + assert my_api.dep_context_counter.exited_context_counter == 1 + + +async def test_multiple_instantiation_child(registry: Registry) -> None: + my_api: MyAsyncApi + async with registry as r: + my_api = await r.aget(MyAsyncApi) + my_api_2 = await r.aget(MyAsyncApi) + my_api_3 = await r.aget(MyAsyncApi) + assert my_api is my_api_2 is my_api_3 + + assert my_api.dep_context_counter.entered_context_counter == 1 + + assert my_api.dep_context_counter.exited_context_counter == 1 + + +async def test_multiple_instantiation_top_level(registry: Registry) -> None: + my_counter: MyAsyncAPIContextCounter + async with registry as r: + my_counter = await r.aget(MyAsyncAPIContextCounter) + my_api_2 = await r.aget(MyAsyncAPIContextCounter) + my_api_3 = await r.aget(MyAsyncAPIContextCounter) + assert my_counter is my_api_2 is my_api_3 + assert my_counter.entered_context_counter == 1 + assert my_counter.exited_context_counter == 1 + + +async def test_multiple_instantiation_mixed(registry: Registry) -> None: + my_counter: MyAsyncAPIContextCounter + async with registry as r: + my_counter = await r.aget(MyAsyncAPIContextCounter) + assert my_counter.entered_context_counter == 1 + await r.aget(MyAsyncApi) + assert my_counter.entered_context_counter == 1 + assert my_counter.exited_context_counter == 1 + + +async def test_async_context_outside_context_manager(registry: Registry) -> None: + with pytest.raises(AssertionError): + # attempting to instantiate a class + # marked with @async_context without + # being in an async context should + # raise an error + _ = await registry.aget(MyAsyncApi) + + +async def test_try_instantiate_async_class_with_sync_api(registry: Registry) -> None: + with pytest.raises(AssertionError): + # attempting to instantiate a class + # marked with @async_context using sync API + # should raise an error + _ = registry[MyDependencyAsync] + + with pytest.raises(AssertionError): + # still throws an error even when registry context + # has been entered + async with registry as r: + _ = r[MyAsyncApi] + + +async def test_context_manager_aenter_must_return_self(registry: Registry) -> None: + """ + Async context manager must return self from aenter method, + throw value error otherwise. + """ + async with registry as r: + with pytest.raises(ValueError): + _ = await r.aget(BadContextManager) + + +async def test_config_in_async(registry: Registry) -> None: + async with registry as r: + r = await r.aget(MyAsyncApiWithConfig) + assert r.nested == "c" + assert r.flat == 2 + + +async def test_entering_already_entered_registry_throws(registry: Registry) -> None: + async with registry as r: + with pytest.raises(AssertionError): + async with r: + pass + + +async def test_define(registry: Registry) -> None: + async with registry as r: + my_api = await r.aget(MY_ASYNC_API_DEFINE) + assert my_api.text == TEXT + assert my_api.in_context == True + assert my_api.dep_async.in_context == True + assert my_api.dep_not_specified.in_context == False + assert my_api.dep_context_counter.in_context == True + assert my_api.dep_context_counter.entered_context_counter == 1 + + assert my_api.in_context == False + assert my_api.dep_async.in_context == False + assert my_api.dep_not_specified.in_context == False + assert my_api.dep_context_counter.in_context == False + assert my_api.dep_context_counter.exited_context_counter == 1 + + +def test_get_item_sync_class_async_dependency_throws(registry: Registry) -> None: + with pytest.raises(AssertionError): + _ = registry[MySyncClassWithAsyncDependency] + + +def test_get_sync_class_async_dependency_throws(registry: Registry) -> None: + with pytest.raises(AssertionError): + _ = registry.get(MySyncClassWithAsyncDependency, AUTO_OR_NONE) + + +async def test_exit_logic_success(registry: Registry) -> None: + async with registry as r: + my_cls = await r.aget(MyAsyncClassWithSyncCloseDependency) + assert my_cls.entered == True + assert my_cls.sync_close_dep.closed == False + + assert my_cls.entered == False + assert my_cls.sync_close_dep.closed == True + + +async def test_exit_logic_failure(registry: Registry) -> None: + with pytest.raises(ValueError): + async with registry as r: + bindings = define( + MyAsyncClassWithSyncCloseDependency, + sync_close_dep=reference(MySyncClassWithCloseMethod), + will_throw=True, + ) + my_cls = await r.aget(bindings) + assert my_cls.entered == True + assert my_cls.sync_close_dep.closed == False + + assert my_cls.entered == False + assert my_cls.sync_close_dep.closed == True + + +async def test_async_contains(registry: Registry) -> None: + async with registry as r: + assert (MyAsyncApi in r) is False + assert (MyDependencyAsync in r) is False + assert (MyDependencyNotSpecifiedAsync in r) is False + + _ = await r.aget(MyAsyncApi) + + assert (MyAsyncApi in r) is True + assert (MyDependencyAsync in r) is True + assert (MyDependencyNotSpecifiedAsync in r) is True