diff --git a/asyncssh/connection.py b/asyncssh/connection.py index dafeb91..bd30c82 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -179,6 +179,7 @@ _ProtocolFactory = Union[_ClientFactory, _ServerFactory] _Conn = TypeVar('_Conn', 'SSHClientConnection', 'SSHServerConnection') +_ConnSelf = TypeVar('_ConnSelf', bound='SSHConnection') class _TunnelProtocol(Protocol): """Base protocol for connections to tunnel SSH over""" @@ -944,7 +945,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._disable_trivial_auth = False - async def __aenter__(self) -> 'SSHConnection': + async def __aenter__(self: _ConnSelf) -> _ConnSelf: """Allow SSHConnection to be used as an async context manager""" return self