diff --git a/reflex/config.py b/reflex/config.py index 049cc2e834..a57dfd7313 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -545,6 +545,9 @@ class EnvironmentVariables: # Where to save screenshots when tests fail. SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None) + # Optional redis key prefix for the state manager. + REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None) + environment = EnvironmentVariables() diff --git a/reflex/state.py b/reflex/state.py index 94ff35a884..657827b88d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -42,7 +42,7 @@ from typing_extensions import Self from reflex import event -from reflex.config import get_config +from reflex.config import EnvironmentVariables, get_config from reflex.istate.data import RouterData from reflex.istate.storage import ClientStorageBase from reflex.vars.base import ( @@ -3074,6 +3074,26 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes) + + +def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE: + """Prefix the token with the redis prefix. + + Args: + token: The token to prefix. + + Returns: + The prefixed token. + """ + prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get() + if not prefix: + return token + if isinstance(token, bytes): + return prefix.encode() + token + return f"{prefix}{token}" + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3211,7 +3231,7 @@ async def get_state( state = None # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + redis_state = await self.redis.get(prefix_redis_token(token)) if redis_state is not None: # Deserialize the substate. @@ -3266,7 +3286,8 @@ async def set_state( # Check that we're holding the lock. if ( lock_id is not None - and await self.redis.get(self._lock_key(token)) != lock_id + and await self.redis.get(prefix_redis_token(self._lock_key(token))) + != lock_id ): raise LockExpiredError( f"Lock expired for token {token} while processing. Consider increasing " @@ -3297,7 +3318,7 @@ async def set_state( pickle_state = state._serialize() if pickle_state: await self.redis.set( - _substate_key(client_token, state), + prefix_redis_token(_substate_key(client_token, state)), pickle_state, ex=self.token_expiration, ) @@ -3347,7 +3368,7 @@ async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: True if the lock was obtained. """ return await self.redis.set( - lock_key, + prefix_redis_token(lock_key), lock_id, px=self.lock_expiration, nx=True, # only set if it doesn't exist @@ -3382,7 +3403,7 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: while not state_is_locked: # wait for the lock to be released while True: - if not await self.redis.exists(lock_key): + if not await self.redis.exists(prefix_redis_token(lock_key)): break # key was removed, try to get the lock again message = await pubsub.get_message( ignore_subscribe_messages=True, @@ -3423,7 +3444,7 @@ async def _lock(self, token: str): finally: if state_is_locked: # only delete our lock - await self.redis.delete(lock_key) + await self.redis.delete(prefix_redis_token(lock_key)) async def close(self): """Explicitly close the redis connection and connection_pool. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2ce0b7bd52..7ffdcd79d2 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,7 @@ import sys import threading from textwrap import dedent -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock import pytest @@ -40,6 +40,7 @@ StateProxy, StateUpdate, _substate_key, + prefix_redis_token, ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types @@ -1671,7 +1672,7 @@ async def test_state_manager_modify_state( """ async with state_manager.modify_state(substate_token) as state: if isinstance(state_manager, StateManagerRedis): - assert await state_manager.redis.get(f"{token}_lock") + assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() @@ -1681,7 +1682,9 @@ async def test_state_manager_modify_state( state.complex[3] = complex_1 # lock should be dropped after exiting the context if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + assert ( + await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert not state_manager._states_locks[token].locked() @@ -1723,7 +1726,9 @@ async def _coro(): assert (await state_manager.get_state(substate_token)).num1 == exp_num1 if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + assert ( + await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert not state_manager._states_locks[token].locked() @@ -1783,7 +1788,7 @@ async def test_state_manager_lock_expire( @pytest.mark.asyncio async def test_state_manager_lock_expire_contend( - state_manager_redis: StateManager, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str ): """Test that the state manager lock expires and queued waiters proceed. @@ -1825,6 +1830,41 @@ async def _coro_waiter(): assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.fixture(scope="function") +def redis_prefix() -> Generator[str, None, None]: + """Fixture for redis prefix. + + Yields: + A redis prefix. + """ + prefix = "test_prefix" + reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix) + yield prefix + reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(None) + + +@pytest.mark.asyncio +async def test_state_manager_redis_prefix( + state_manager_redis: StateManagerRedis, + substate_token_redis: str, + redis_prefix: str, +): + """Test that the state manager redis prefix is applied correctly. + + Args: + state_manager_redis: A state manager instance. + substate_token_redis: A token + substate name for looking up in state manager. + redis_prefix: A redis prefix. + """ + async with state_manager_redis.modify_state(substate_token_redis) as state: + state.num1 = 42 + + prefixed_token = prefix_redis_token(substate_token_redis) + assert prefixed_token == f"{redis_prefix}{substate_token_redis}" + + assert await state_manager_redis.redis.get(prefixed_token) + + @pytest.fixture(scope="function") def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: """Mock app fixture.