Skip to content

Commit

Permalink
Added error propagation to main thread
Browse files Browse the repository at this point in the history
  • Loading branch information
vladvildanov committed Dec 20, 2024
1 parent a9c200c commit 4527bf0
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
28 changes: 28 additions & 0 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,10 @@ def listen(self, event: AfterPooledConnectionsInstantiationEvent):

if event.client_type == ClientType.SYNC:
event.credential_provider.on_next(self._re_auth)
event.credential_provider.on_error(self._raise_on_error)
else:
event.credential_provider.on_next(self._re_auth_async)
event.credential_provider.on_error(self._raise_on_error_async)

def _re_auth(self, token):
for pool in self._event.connection_pools:
Expand All @@ -258,6 +260,12 @@ async def _re_auth_async(self, token):
for pool in self._event.connection_pools:
await pool.re_auth_callback(token)

def _raise_on_error(self, error: Exception):
raise error

async def _raise_on_error_async(self, error: Exception):
raise error


class RegisterReAuthForSingleConnection(EventListenerInterface):
"""
Expand All @@ -273,8 +281,10 @@ def listen(self, event: AfterSingleConnectionInstantiationEvent):

if event.client_type == ClientType.SYNC:
event.connection.credential_provider.on_next(self._re_auth)
event.connection.credential_provider.on_error(self._raise_on_error)
else:
event.connection.credential_provider.on_next(self._re_auth_async)
event.connection.credential_provider.on_error(self._raise_on_error_async)

def _re_auth(self, token):
with self._event.connection_lock:
Expand All @@ -286,6 +296,12 @@ async def _re_auth_async(self, token):
await self._event.connection.send_command('AUTH', token.try_get('oid'), token.get_value())
await self._event.connection.read_response()

def _raise_on_error(self, error: Exception):
raise error

async def _raise_on_error_async(self, error: Exception):
raise error


class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
def __init__(self):
Expand All @@ -295,11 +311,15 @@ def listen(self, event: AfterAsyncClusterInstantiationEvent):
if isinstance(event.credential_provider, StreamingCredentialProvider):
self._event = event
event.credential_provider.on_next(self._re_auth)
event.credential_provider.on_error(self._raise_on_error)

async def _re_auth(self, token: TokenInterface):
for key in self._event.nodes:
await self._event.nodes[key].re_auth_callback(token)

async def _raise_on_error(self, error: Exception):
raise error


class RegisterReAuthForPubSub(EventListenerInterface):
def __init__(self):
Expand All @@ -320,8 +340,10 @@ def listen(self, event: AfterPubSubConnectionInstantiationEvent):

if self._client_type == ClientType.SYNC:
self._connection.credential_provider.on_next(self._re_auth)
self._connection.credential_provider.on_error(self._raise_on_error)
else:
self._connection.credential_provider.on_next(self._re_auth_async)
self._connection.credential_provider.on_error(self._raise_on_error_async)

def _re_auth(self, token: TokenInterface):
with self._connection_lock:
Expand All @@ -336,3 +358,9 @@ async def _re_auth_async(self, token: TokenInterface):
await self._connection.read_response()

await self._connection_pool.re_auth_callback(token)

def _raise_on_error(self, error: Exception):
raise error

async def _raise_on_error_async(self, error: Exception):
raise error
37 changes: 37 additions & 0 deletions tests/test_asyncio/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from redis import AuthenticationError, DataError, ResponseError, RedisError
from redis.asyncio import Redis, Connection, ConnectionPool
from redis.asyncio.retry import Retry
from redis.auth.err import RequestTokenErr
from redis.exceptions import ConnectionError
from redis.backoff import NoBackoff
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
Expand Down Expand Up @@ -534,6 +535,42 @@ async def re_auth_callback(token):
call('AUTH', auth_token.try_get('oid'), auth_token.get_value())
])

@pytest.mark.parametrize(
"credential_provider",
[
{
"cred_provider_class": EntraIdCredentialsProvider,
"cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005},
"mock_idp": True,
}
],
indirect=True,
)
async def test_fails_on_token_renewal(self, credential_provider):
credential_provider._token_mgr._idp.request_token.side_effect = [
RequestTokenErr,
RequestTokenErr,
RequestTokenErr,
RequestTokenErr
]
mock_connection = Mock(spec=Connection)
mock_connection.retry = Retry(NoBackoff(), 0)
mock_another_connection = Mock(spec=Connection)
mock_pool = Mock(spec=ConnectionPool)
mock_pool.connection_kwargs = {
"credential_provider": credential_provider,
}
mock_pool.get_connection.return_value = mock_connection
mock_pool._available_connections = [mock_connection, mock_another_connection]

await Redis(
connection_pool=mock_pool,
credential_provider=credential_provider,
)

with pytest.raises(RequestTokenErr):
await credential_provider.get_credentials()


@pytest.mark.asyncio
@pytest.mark.onlynoncluster
Expand Down
38 changes: 38 additions & 0 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from redis import AuthenticationError, DataError, ResponseError, Redis, asyncio
from redis.asyncio import Redis as AsyncRedis, Connection
from redis.asyncio import ConnectionPool as AsyncConnectionPool
from redis.auth.err import RequestTokenErr
from redis.auth.idp import IdentityProviderInterface
from redis.exceptions import ConnectionError, RedisError
from redis.backoff import NoBackoff
Expand Down Expand Up @@ -512,6 +513,43 @@ def re_auth_callback(token):
call('AUTH', auth_token.try_get('oid'), auth_token.get_value())
])

@pytest.mark.parametrize(
"credential_provider",
[
{
"cred_provider_class": EntraIdCredentialsProvider,
"cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005},
"mock_idp": True,
}
],
indirect=True,
)
def test_fails_on_token_renewal(self, credential_provider):
credential_provider._token_mgr._idp.request_token.side_effect = [
RequestTokenErr,
RequestTokenErr,
RequestTokenErr,
RequestTokenErr
]
mock_connection = Mock(spec=ConnectionInterface)
mock_connection.retry = Retry(NoBackoff(), 0)
mock_another_connection = Mock(spec=ConnectionInterface)
mock_pool = Mock(spec=ConnectionPool)
mock_pool.connection_kwargs = {
"credential_provider": credential_provider,
}
mock_pool.get_connection.return_value = mock_connection
mock_pool._available_connections = [mock_connection, mock_another_connection]
mock_pool._lock = threading.Lock()

Redis(
connection_pool=mock_pool,
credential_provider=credential_provider,
)

with pytest.raises(RequestTokenErr):
credential_provider.get_credentials()


@pytest.mark.onlynoncluster
@pytest.mark.cp_integration
Expand Down

0 comments on commit 4527bf0

Please sign in to comment.