diff --git a/reflex/state.py b/reflex/state.py index 56c62c150e..66b1e3cabf 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -11,6 +11,7 @@ import json import pickle import sys +import typing import uuid from abc import ABC, abstractmethod from collections import defaultdict @@ -90,7 +91,13 @@ ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer -from reflex.utils.types import _isinstance, get_origin, override +from reflex.utils.types import ( + _isinstance, + get_origin, + is_union, + override, + value_inside_optional, +) from reflex.vars import VarData if TYPE_CHECKING: @@ -1713,6 +1720,35 @@ async def _process_event( # Get the function to process the event. fn = functools.partial(handler.fn, state) + try: + type_hints = typing.get_type_hints(handler.fn) + except Exception: + type_hints = {} + + for arg, value in list(payload.items()): + hinted_args = type_hints.get(arg, Any) + if hinted_args is Any: + continue + if is_union(hinted_args): + if value is None: + continue + hinted_args = value_inside_optional(hinted_args) + if ( + isinstance(value, dict) + and inspect.isclass(hinted_args) + and ( + dataclasses.is_dataclass(hinted_args) + or issubclass(hinted_args, Base) + ) + ): + payload[arg] = hinted_args(**value) + if isinstance(value, list) and (hinted_args is set or hinted_args is Set): + payload[arg] = set(value) + if isinstance(value, list) and ( + hinted_args is tuple or hinted_args is Tuple + ): + payload[arg] = tuple(value) + # Wrap the function in a try/except block. try: # Handle async functions.