Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement redis key prefix for StateManagerRedis #4307

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,9 @@ class EnvironmentVariables:
# Where to save screenshots when tests fail.
SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None)

# Optional redis key prefix for the state manager.
REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None)


environment = EnvironmentVariables()

Expand Down
35 changes: 28 additions & 7 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from typing_extensions import Self

from reflex import event
from reflex.config import get_config
from reflex.config import EnvironmentVariables, get_config
from reflex.istate.data import RouterData
from reflex.istate.storage import ClientStorageBase
from reflex.vars.base import (
Expand Down Expand Up @@ -3074,6 +3074,26 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration


TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)


def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE:
"""Prefix the token with the redis prefix.

Args:
token: The token to prefix.

Returns:
The prefixed token.
"""
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
if not prefix:
return token
if isinstance(token, bytes):
return prefix.encode() + token
return f"{prefix}{token}"


class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""

Expand Down Expand Up @@ -3211,7 +3231,7 @@ async def get_state(
state = None

# Fetch the serialized substate from redis.
redis_state = await self.redis.get(token)
redis_state = await self.redis.get(prefix_redis_token(token))

if redis_state is not None:
# Deserialize the substate.
Expand Down Expand Up @@ -3266,7 +3286,8 @@ async def set_state(
# Check that we're holding the lock.
if (
lock_id is not None
and await self.redis.get(self._lock_key(token)) != lock_id
and await self.redis.get(prefix_redis_token(self._lock_key(token)))
!= lock_id
):
raise LockExpiredError(
f"Lock expired for token {token} while processing. Consider increasing "
Expand Down Expand Up @@ -3297,7 +3318,7 @@ async def set_state(
pickle_state = state._serialize()
if pickle_state:
await self.redis.set(
_substate_key(client_token, state),
prefix_redis_token(_substate_key(client_token, state)),
pickle_state,
ex=self.token_expiration,
)
Expand Down Expand Up @@ -3347,7 +3368,7 @@ async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
True if the lock was obtained.
"""
return await self.redis.set(
lock_key,
prefix_redis_token(lock_key),
lock_id,
px=self.lock_expiration,
nx=True, # only set if it doesn't exist
Expand Down Expand Up @@ -3382,7 +3403,7 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
while not state_is_locked:
# wait for the lock to be released
while True:
if not await self.redis.exists(lock_key):
if not await self.redis.exists(prefix_redis_token(lock_key)):
break # key was removed, try to get the lock again
message = await pubsub.get_message(
ignore_subscribe_messages=True,
Expand Down Expand Up @@ -3423,7 +3444,7 @@ async def _lock(self, token: str):
finally:
if state_is_locked:
# only delete our lock
await self.redis.delete(lock_key)
await self.redis.delete(prefix_redis_token(lock_key))

async def close(self):
"""Explicitly close the redis connection and connection_pool.
Expand Down
50 changes: 45 additions & 5 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
import threading
from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock

import pytest
Expand Down Expand Up @@ -40,6 +40,7 @@
StateProxy,
StateUpdate,
_substate_key,
prefix_redis_token,
)
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
Expand Down Expand Up @@ -1671,7 +1672,7 @@ async def test_state_manager_modify_state(
"""
async with state_manager.modify_state(substate_token) as state:
if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock")
assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked()
Expand All @@ -1681,7 +1682,9 @@ async def test_state_manager_modify_state(
state.complex[3] = complex_1
# lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
assert (
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert not state_manager._states_locks[token].locked()

Expand Down Expand Up @@ -1723,7 +1726,9 @@ async def _coro():
assert (await state_manager.get_state(substate_token)).num1 == exp_num1

if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
assert (
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert not state_manager._states_locks[token].locked()
Expand Down Expand Up @@ -1783,7 +1788,7 @@ async def test_state_manager_lock_expire(

@pytest.mark.asyncio
async def test_state_manager_lock_expire_contend(
state_manager_redis: StateManager, token: str, substate_token_redis: str
state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
):
"""Test that the state manager lock expires and queued waiters proceed.

Expand Down Expand Up @@ -1825,6 +1830,41 @@ async def _coro_waiter():
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1


@pytest.fixture(scope="function")
def redis_prefix() -> Generator[str, None, None]:
"""Fixture for redis prefix.

Yields:
A redis prefix.
"""
prefix = "test_prefix"
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix)
yield prefix
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(None)


@pytest.mark.asyncio
async def test_state_manager_redis_prefix(
state_manager_redis: StateManagerRedis,
substate_token_redis: str,
redis_prefix: str,
):
"""Test that the state manager redis prefix is applied correctly.

Args:
state_manager_redis: A state manager instance.
substate_token_redis: A token + substate name for looking up in state manager.
redis_prefix: A redis prefix.
"""
async with state_manager_redis.modify_state(substate_token_redis) as state:
state.num1 = 42

prefixed_token = prefix_redis_token(substate_token_redis)
assert prefixed_token == f"{redis_prefix}{substate_token_redis}"

assert await state_manager_redis.redis.get(prefixed_token)


@pytest.fixture(scope="function")
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
"""Mock app fixture.
Expand Down
Loading