diff --git a/redis/event.py b/redis/event.py index c44f97f4e5..9a214b0753 100644 --- a/redis/event.py +++ b/redis/event.py @@ -247,8 +247,10 @@ def listen(self, event: AfterPooledConnectionsInstantiationEvent): if event.client_type == ClientType.SYNC: event.credential_provider.on_next(self._re_auth) + event.credential_provider.on_error(self._raise_on_error) else: event.credential_provider.on_next(self._re_auth_async) + event.credential_provider.on_error(self._raise_on_error_async) def _re_auth(self, token): for pool in self._event.connection_pools: @@ -258,6 +260,12 @@ async def _re_auth_async(self, token): for pool in self._event.connection_pools: await pool.re_auth_callback(token) + def _raise_on_error(self, error: Exception): + raise error + + async def _raise_on_error_async(self, error: Exception): + raise error + class RegisterReAuthForSingleConnection(EventListenerInterface): """ @@ -273,8 +281,10 @@ def listen(self, event: AfterSingleConnectionInstantiationEvent): if event.client_type == ClientType.SYNC: event.connection.credential_provider.on_next(self._re_auth) + event.connection.credential_provider.on_error(self._raise_on_error) else: event.connection.credential_provider.on_next(self._re_auth_async) + event.connection.credential_provider.on_error(self._raise_on_error_async) def _re_auth(self, token): with self._event.connection_lock: @@ -286,6 +296,12 @@ async def _re_auth_async(self, token): await self._event.connection.send_command('AUTH', token.try_get('oid'), token.get_value()) await self._event.connection.read_response() + def _raise_on_error(self, error: Exception): + raise error + + async def _raise_on_error_async(self, error: Exception): + raise error + class RegisterReAuthForAsyncClusterNodes(EventListenerInterface): def __init__(self): @@ -295,11 +311,15 @@ def listen(self, event: AfterAsyncClusterInstantiationEvent): if isinstance(event.credential_provider, StreamingCredentialProvider): self._event = event event.credential_provider.on_next(self._re_auth) + event.credential_provider.on_error(self._raise_on_error) async def _re_auth(self, token: TokenInterface): for key in self._event.nodes: await self._event.nodes[key].re_auth_callback(token) + async def _raise_on_error(self, error: Exception): + raise error + class RegisterReAuthForPubSub(EventListenerInterface): def __init__(self): @@ -320,8 +340,10 @@ def listen(self, event: AfterPubSubConnectionInstantiationEvent): if self._client_type == ClientType.SYNC: self._connection.credential_provider.on_next(self._re_auth) + self._connection.credential_provider.on_error(self._raise_on_error) else: self._connection.credential_provider.on_next(self._re_auth_async) + self._connection.credential_provider.on_error(self._raise_on_error_async) def _re_auth(self, token: TokenInterface): with self._connection_lock: @@ -336,3 +358,9 @@ async def _re_auth_async(self, token: TokenInterface): await self._connection.read_response() await self._connection_pool.re_auth_callback(token) + + def _raise_on_error(self, error: Exception): + raise error + + async def _raise_on_error_async(self, error: Exception): + raise error diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 83f339ac71..81e73d685c 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -14,6 +14,7 @@ from redis import AuthenticationError, DataError, ResponseError, RedisError from redis.asyncio import Redis, Connection, ConnectionPool from redis.asyncio.retry import Retry +from redis.auth.err import RequestTokenErr from redis.exceptions import ConnectionError from redis.backoff import NoBackoff from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -534,6 +535,42 @@ async def re_auth_callback(token): call('AUTH', auth_token.try_get('oid'), auth_token.get_value()) ]) + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + async def test_fails_on_token_renewal(self, credential_provider): + credential_provider._token_mgr._idp.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr + ] + mock_connection = Mock(spec=Connection) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_another_connection = Mock(spec=Connection) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection, mock_another_connection] + + await Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + with pytest.raises(RequestTokenErr): + await credential_provider.get_credentials() + @pytest.mark.asyncio @pytest.mark.onlynoncluster diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 94767f18ae..00e691129a 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -17,6 +17,7 @@ from redis import AuthenticationError, DataError, ResponseError, Redis, asyncio from redis.asyncio import Redis as AsyncRedis, Connection from redis.asyncio import ConnectionPool as AsyncConnectionPool +from redis.auth.err import RequestTokenErr from redis.auth.idp import IdentityProviderInterface from redis.exceptions import ConnectionError, RedisError from redis.backoff import NoBackoff @@ -512,6 +513,43 @@ def re_auth_callback(token): call('AUTH', auth_token.try_get('oid'), auth_token.get_value()) ]) + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + def test_fails_on_token_renewal(self, credential_provider): + credential_provider._token_mgr._idp.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr + ] + mock_connection = Mock(spec=ConnectionInterface) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_another_connection = Mock(spec=ConnectionInterface) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection, mock_another_connection] + mock_pool._lock = threading.Lock() + + Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + with pytest.raises(RequestTokenErr): + credential_provider.get_credentials() + @pytest.mark.onlynoncluster @pytest.mark.cp_integration