diff --git a/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 b/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 index 210246992b..222524d2d6 100644 --- a/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 @@ -8,20 +8,6 @@ {% endfor %} export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => { -{% if component.name == "CodeBlock" and "language" in component.props %} - if (language) { - (async () => { - try { - const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${language}`); - SyntaxHighlighter.registerLanguage(language, module.default); - } catch (error) { - console.error(`Error importing language module for ${language}:`, error); - } - })(); - - - } -{% endif %} {% for hook in component.hooks %} {{ hook }} {% endfor %} diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index 53761284a7..9d5052df17 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -8,13 +8,14 @@ from reflex.components.component import Component, ComponentNamespace from reflex.components.core.cond import color_mode_cond from reflex.components.lucide.icon import Icon +from reflex.components.markdown.markdown import _LANGUAGE, MarkdownComponentMap from reflex.components.radix.themes.components.button import Button from reflex.components.radix.themes.layout.box import Box from reflex.constants.colors import Color from reflex.event import set_clipboard from reflex.style import Style from reflex.utils import console, format -from reflex.utils.imports import ImportDict, ImportVar +from reflex.utils.imports import ImportVar from reflex.vars.base import LiteralVar, Var, VarData LiteralCodeLanguage = Literal[ @@ -378,7 +379,7 @@ class Theme: setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme)) -class CodeBlock(Component): +class CodeBlock(Component, MarkdownComponentMap): """A code block.""" library = "react-syntax-highlighter@15.6.1" @@ -417,39 +418,6 @@ class CodeBlock(Component): # A custom copy button to override the default one. copy_button: Optional[Union[bool, Component]] = None - def add_imports(self) -> ImportDict: - """Add imports for the CodeBlock component. - - Returns: - The import dict. - """ - imports_: ImportDict = {} - - if ( - self.language is not None - and (language_without_quotes := str(self.language).replace('"', "")) - in LiteralCodeLanguage.__args__ # type: ignore - ): - imports_[ - f"react-syntax-highlighter/dist/cjs/languages/prism/{language_without_quotes}" - ] = [ - ImportVar( - tag=format.to_camel_case(language_without_quotes), - is_default=True, - install=False, - ) - ] - - return imports_ - - def _get_custom_code(self) -> Optional[str]: - if ( - self.language is not None - and (language_without_quotes := str(self.language).replace('"', "")) - in LiteralCodeLanguage.__args__ # type: ignore - ): - return f"{self.alias}.registerLanguage('{language_without_quotes}', {format.to_camel_case(language_without_quotes)})" - @classmethod def create( cls, @@ -534,8 +502,8 @@ def _render(self): theme = self.theme - out.add_props(style=theme).remove_props("theme", "code").add_props( - children=self.code + out.add_props(style=theme).remove_props("theme", "code", "language").add_props( + children=self.code, language=_LANGUAGE ) return out @@ -543,6 +511,46 @@ def _render(self): def _exclude_props(self) -> list[str]: return ["can_copy", "copy_button"] + @classmethod + def _get_language_registration_hook(cls) -> str: + """Get the hook to register the language. + + Returns: + The hook to register the language. + """ + return f""" + if ({str(_LANGUAGE)}) {{ + (async () => {{ + try {{ + const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{{str(_LANGUAGE)}}}`); + SyntaxHighlighter.registerLanguage({str(_LANGUAGE)}, module.default); + }} catch (error) {{ + console.error(`Error importing language module for ${{{str(_LANGUAGE)}}}:`, error); + }} + }})(); + }} +""" + + @classmethod + def get_component_map_custom_code(cls) -> str: + """Get the custom code for the component. + + Returns: + The custom code for the component. + """ + return cls._get_language_registration_hook() + + def add_hooks(self) -> list[str | Var]: + """Add hooks for the component. + + Returns: + The hooks for the component. + """ + return [ + f"const {str(_LANGUAGE)} = {str(self.language)}", + self._get_language_registration_hook(), + ] + class CodeblockNamespace(ComponentNamespace): """Namespace for the CodeBlock component.""" diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi index 765ae47f3d..da89195ce3 100644 --- a/reflex/components/datadisplay/code.pyi +++ b/reflex/components/datadisplay/code.pyi @@ -7,10 +7,10 @@ import dataclasses from typing import Any, ClassVar, Dict, Literal, Optional, Union, overload from reflex.components.component import Component, ComponentNamespace +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.constants.colors import Color from reflex.event import BASE_STATE, EventType from reflex.style import Style -from reflex.utils.imports import ImportDict from reflex.vars.base import Var LiteralCodeLanguage = Literal[ @@ -349,8 +349,7 @@ for theme_name in dir(Theme): continue setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme)) -class CodeBlock(Component): - def add_imports(self) -> ImportDict: ... +class CodeBlock(Component, MarkdownComponentMap): @overload @classmethod def create( # type: ignore @@ -984,6 +983,9 @@ class CodeBlock(Component): ... def add_style(self): ... + @classmethod + def get_component_map_custom_code(cls) -> str: ... + def add_hooks(self) -> list[str | Var]: ... class CodeblockNamespace(ComponentNamespace): themes = Theme diff --git a/reflex/components/datadisplay/shiki_code_block.py b/reflex/components/datadisplay/shiki_code_block.py index 4a3e05d0e3..2b4e1f5063 100644 --- a/reflex/components/datadisplay/shiki_code_block.py +++ b/reflex/components/datadisplay/shiki_code_block.py @@ -12,6 +12,7 @@ from reflex.components.core.cond import color_mode_cond from reflex.components.el.elements.forms import Button from reflex.components.lucide.icon import Icon +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.props import NoExtrasAllowedProps from reflex.components.radix.themes.layout.box import Box from reflex.event import run_script, set_clipboard @@ -528,7 +529,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) -class ShikiCodeBlock(Component): +class ShikiCodeBlock(Component, MarkdownComponentMap): """A Code block.""" library = "/components/shiki/code" diff --git a/reflex/components/datadisplay/shiki_code_block.pyi b/reflex/components/datadisplay/shiki_code_block.pyi index 2b8b0d3851..92546ee4fd 100644 --- a/reflex/components/datadisplay/shiki_code_block.pyi +++ b/reflex/components/datadisplay/shiki_code_block.pyi @@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.base import Base from reflex.components.component import Component, ComponentNamespace +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.props import NoExtrasAllowedProps from reflex.event import BASE_STATE, EventType from reflex.style import Style @@ -350,7 +351,7 @@ class ShikiJsTransformer(ShikiBaseTransformers): fns: list[FunctionStringVar] style: Optional[Style] -class ShikiCodeBlock(Component): +class ShikiCodeBlock(Component, MarkdownComponentMap): @overload @classmethod def create( # type: ignore diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index b790bf7a12..376cb8bd67 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -2,25 +2,18 @@ from __future__ import annotations +import dataclasses import textwrap from functools import lru_cache from hashlib import md5 -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Sequence, Union from reflex.components.component import Component, CustomComponent -from reflex.components.radix.themes.layout.list import ( - ListItem, - OrderedList, - UnorderedList, -) -from reflex.components.radix.themes.typography.heading import Heading -from reflex.components.radix.themes.typography.link import Link -from reflex.components.radix.themes.typography.text import Text from reflex.components.tags.tag import Tag from reflex.utils import types from reflex.utils.imports import ImportDict, ImportVar from reflex.vars.base import LiteralVar, Var -from reflex.vars.function import ARRAY_ISARRAY +from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg from reflex.vars.number import ternary_operation # Special vars used in the component map. @@ -28,6 +21,7 @@ _PROPS = Var(_js_expr="...props") _PROPS_IN_TAG = Var(_js_expr="{...props}") _MOCK_ARG = Var(_js_expr="", _var_type=str) +_LANGUAGE = Var(_js_expr="_language", _var_type=str) # Special remark plugins. _REMARK_MATH = Var(_js_expr="remarkMath") @@ -53,7 +47,15 @@ def get_base_component_map() -> dict[str, Callable]: The base component map. """ from reflex.components.datadisplay.code import CodeBlock + from reflex.components.radix.themes.layout.list import ( + ListItem, + OrderedList, + UnorderedList, + ) from reflex.components.radix.themes.typography.code import Code + from reflex.components.radix.themes.typography.heading import Heading + from reflex.components.radix.themes.typography.link import Link + from reflex.components.radix.themes.typography.text import Text return { "h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"), @@ -74,6 +76,67 @@ def get_base_component_map() -> dict[str, Callable]: } +@dataclasses.dataclass() +class MarkdownComponentMap: + """Mixin class for handling custom component maps in Markdown components.""" + + _explicit_return: bool = dataclasses.field(default=False) + + @classmethod + def get_component_map_custom_code(cls) -> str: + """Get the custom code for the component map. + + Returns: + The custom code for the component map. + """ + return "" + + @classmethod + def create_map_fn_var( + cls, + fn_body: Var | None = None, + fn_args: Sequence[str] | None = None, + explicit_return: bool | None = None, + ) -> Var: + """Create a function Var for the component map. + + Args: + fn_body: The formatted component as a string. + fn_args: The function arguments. + explicit_return: Whether to use explicit return syntax. + + Returns: + The function Var for the component map. + """ + fn_args = fn_args or cls.get_fn_args() + fn_body = fn_body if fn_body is not None else cls.get_fn_body() + explicit_return = explicit_return or cls._explicit_return + + return ArgsFunctionOperation.create( + args_names=(DestructuredArg(fields=tuple(fn_args)),), + return_expr=fn_body, + explicit_return=explicit_return, + ) + + @classmethod + def get_fn_args(cls) -> Sequence[str]: + """Get the function arguments for the component map. + + Returns: + The function arguments as a list of strings. + """ + return ["node", _CHILDREN._js_expr, _PROPS._js_expr] + + @classmethod + def get_fn_body(cls) -> Var: + """Get the function body for the component map. + + Returns: + The function body as a string. + """ + return Var(_js_expr="undefined", _var_type=None) + + class Markdown(Component): """A markdown component.""" @@ -153,9 +216,6 @@ def add_imports(self) -> ImportDict | list[ImportDict]: Returns: The imports for the markdown component. """ - from reflex.components.datadisplay.code import CodeBlock, Theme - from reflex.components.radix.themes.typography.code import Code - return [ { "": "katex/dist/katex.min.css", @@ -179,10 +239,71 @@ def add_imports(self) -> ImportDict | list[ImportDict]: component(_MOCK_ARG)._get_all_imports() # type: ignore for component in self.component_map.values() ], - CodeBlock.create(theme=Theme.light)._get_imports(), - Code.create()._get_imports(), ] + def _get_tag_map_fn_var(self, tag: str) -> Var: + return self._get_map_fn_var_from_children(self.get_component(tag), tag) + + def format_component_map(self) -> dict[str, Var]: + """Format the component map for rendering. + + Returns: + The formatted component map. + """ + components = { + tag: self._get_tag_map_fn_var(tag) + for tag in self.component_map + if tag not in ("code", "codeblock") + } + + # Separate out inline code and code blocks. + components["code"] = self._get_inline_code_fn_var() + + return components + + def _get_inline_code_fn_var(self) -> Var: + """Get the function variable for inline code. + + This function creates a Var that represents a function to handle + both inline code and code blocks in markdown. + + Returns: + The Var for inline code. + """ + # Get any custom code from the codeblock and code components. + custom_code_list = self._get_map_fn_custom_code_from_children( + self.get_component("codeblock") + ) + custom_code_list.extend( + self._get_map_fn_custom_code_from_children(self.get_component("code")) + ) + + codeblock_custom_code = "\n".join(custom_code_list) + + # Format the code to handle inline and block code. + formatted_code = f""" +const match = (className || '').match(/language-(?.*)/); +const {str(_LANGUAGE)} = match ? match[1] : ''; +{codeblock_custom_code}; + return inline ? ( + {self.format_component("code")} + ) : ( + {self.format_component("codeblock", language=_LANGUAGE)} + ); + """.replace("\n", " ") + + return MarkdownComponentMap.create_map_fn_var( + fn_args=( + "node", + "inline", + "className", + _CHILDREN._js_expr, + _PROPS._js_expr, + ), + fn_body=Var(_js_expr=formatted_code), + explicit_return=True, + ) + def get_component(self, tag: str, **props) -> Component: """Get the component for a tag and props. @@ -239,43 +360,53 @@ def format_component(self, tag: str, **props) -> str: """ return str(self.get_component(tag, **props)).replace("\n", "") - def format_component_map(self) -> dict[str, Var]: - """Format the component map for rendering. + def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var: + """Create a function Var for the component map for the specified tag. + + Args: + component: The component to check for custom code. + tag: The tag of the component. Returns: - The formatted component map. + The function Var for the component map. """ - components = { - tag: Var( - _js_expr=f"(({{node, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => ({self.format_component(tag)}))" - ) - for tag in self.component_map - } - - # Separate out inline code and code blocks. - components["code"] = Var( - _js_expr=f"""(({{node, inline, className, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => {{ - const match = (className || '').match(/language-(?.*)/); - const language = match ? match[1] : ''; - if (language) {{ - (async () => {{ - try {{ - const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{language}}`); - SyntaxHighlighter.registerLanguage(language, module.default); - }} catch (error) {{ - console.error(`Error importing language module for ${{language}}:`, error); - }} - }})(); - }} - return inline ? ( - {self.format_component("code")} - ) : ( - {self.format_component("codeblock", language=Var(_js_expr="language", _var_type=str))} - ); - }})""".replace("\n", " ") + formatted_component = Var( + _js_expr=f"({self.format_component(tag)})", _var_type=str ) + if isinstance(component, MarkdownComponentMap): + return component.create_map_fn_var(fn_body=formatted_component) - return components + # fallback to the default fn Var creation if the component is not a MarkdownComponentMap. + return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component) + + def _get_map_fn_custom_code_from_children(self, component) -> list[str]: + """Recursively get markdown custom code from children components. + + Args: + component: The component to check for custom code. + + Returns: + A list of markdown custom code strings. + """ + custom_code_list = [] + if isinstance(component, MarkdownComponentMap): + custom_code_list.append(component.get_component_map_custom_code()) + + # If the component is a custom component(rx.memo), obtain the underlining + # component and get the custom code from the children. + if isinstance(component, CustomComponent): + custom_code_list.extend( + self._get_map_fn_custom_code_from_children( + component.component_fn(*component.get_prop_vars()) + ) + ) + elif isinstance(component, Component): + for child in component.children: + custom_code_list.extend( + self._get_map_fn_custom_code_from_children(child) + ) + + return custom_code_list @staticmethod def _component_map_hash(component_map) -> str: @@ -288,12 +419,12 @@ def _get_component_map_name(self) -> str: return f"ComponentMap_{self.component_map_hash}" def _get_custom_code(self) -> str | None: - hooks = set() + hooks = {} for _component in self.component_map.values(): comp = _component(_MOCK_ARG) hooks.update(comp._get_all_hooks_internal()) hooks.update(comp._get_all_hooks()) - formatted_hooks = "\n".join(hooks) + formatted_hooks = "\n".join(hooks.keys()) return f""" function {self._get_component_map_name()} () {{ {formatted_hooks} diff --git a/reflex/components/markdown/markdown.pyi b/reflex/components/markdown/markdown.pyi index 9878e6181d..1c329fb8cb 100644 --- a/reflex/components/markdown/markdown.pyi +++ b/reflex/components/markdown/markdown.pyi @@ -3,8 +3,9 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ +import dataclasses from functools import lru_cache -from typing import Any, Callable, Dict, Optional, Union, overload +from typing import Any, Callable, Dict, Optional, Sequence, Union, overload from reflex.components.component import Component from reflex.event import BASE_STATE, EventType @@ -16,6 +17,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str) _PROPS = Var(_js_expr="...props") _PROPS_IN_TAG = Var(_js_expr="{...props}") _MOCK_ARG = Var(_js_expr="", _var_type=str) +_LANGUAGE = Var(_js_expr="_language", _var_type=str) _REMARK_MATH = Var(_js_expr="remarkMath") _REMARK_GFM = Var(_js_expr="remarkGfm") _REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages") @@ -27,6 +29,21 @@ NO_PROPS_TAGS = ("ul", "ol", "li") @lru_cache def get_base_component_map() -> dict[str, Callable]: ... +@dataclasses.dataclass() +class MarkdownComponentMap: + @classmethod + def get_component_map_custom_code(cls) -> str: ... + @classmethod + def create_map_fn_var( + cls, + fn_body: Var | None = None, + fn_args: Sequence[str] | None = None, + explicit_return: bool | None = None, + ) -> Var: ... + @classmethod + def get_fn_args(cls) -> Sequence[str]: ... + @classmethod + def get_fn_body(cls) -> Var: ... class Markdown(Component): @overload @@ -82,6 +99,6 @@ class Markdown(Component): ... def add_imports(self) -> ImportDict | list[ImportDict]: ... + def format_component_map(self) -> dict[str, Var]: ... def get_component(self, tag: str, **props) -> Component: ... def format_component(self, tag: str, **props) -> str: ... - def format_component_map(self) -> dict[str, Var]: ... diff --git a/reflex/components/radix/themes/layout/list.py b/reflex/components/radix/themes/layout/list.py index d83fd168b1..96fa169a0f 100644 --- a/reflex/components/radix/themes/layout/list.py +++ b/reflex/components/radix/themes/layout/list.py @@ -8,6 +8,7 @@ from reflex.components.core.foreach import Foreach from reflex.components.el.elements.typography import Li, Ol, Ul from reflex.components.lucide.icon import Icon +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.radix.themes.typography.text import Text from reflex.vars.base import Var @@ -36,7 +37,7 @@ ] -class BaseList(Component): +class BaseList(Component, MarkdownComponentMap): """Base class for ordered and unordered lists.""" tag = "ul" @@ -154,7 +155,7 @@ def create( ) -class ListItem(Li): +class ListItem(Li, MarkdownComponentMap): """Display an item of an ordered or unordered list.""" @classmethod diff --git a/reflex/components/radix/themes/layout/list.pyi b/reflex/components/radix/themes/layout/list.pyi index 9c983c2ffd..b42f689b9e 100644 --- a/reflex/components/radix/themes/layout/list.pyi +++ b/reflex/components/radix/themes/layout/list.pyi @@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, Literal, Optional, Union, overload from reflex.components.component import Component, ComponentNamespace from reflex.components.el.elements.typography import Li, Ol, Ul +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.event import BASE_STATE, EventType from reflex.style import Style from reflex.vars.base import Var @@ -29,7 +30,7 @@ LiteralListStyleTypeOrdered = Literal[ "katakana", ] -class BaseList(Component): +class BaseList(Component, MarkdownComponentMap): @overload @classmethod def create( # type: ignore @@ -393,7 +394,7 @@ class OrderedList(BaseList, Ol): """ ... -class ListItem(Li): +class ListItem(Li, MarkdownComponentMap): @overload @classmethod def create( # type: ignore diff --git a/reflex/components/radix/themes/typography/code.py b/reflex/components/radix/themes/typography/code.py index ca19859d32..ab610b5053 100644 --- a/reflex/components/radix/themes/typography/code.py +++ b/reflex/components/radix/themes/typography/code.py @@ -7,13 +7,14 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.vars.base import Var from ..base import LiteralAccentColor, LiteralVariant, RadixThemesComponent from .base import LiteralTextSize, LiteralTextWeight -class Code(elements.Code, RadixThemesComponent): +class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap): """A block level extended quotation.""" tag = "Code" diff --git a/reflex/components/radix/themes/typography/code.pyi b/reflex/components/radix/themes/typography/code.pyi index a211b97c47..0276eb9822 100644 --- a/reflex/components/radix/themes/typography/code.pyi +++ b/reflex/components/radix/themes/typography/code.pyi @@ -7,13 +7,14 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.components.core.breakpoints import Breakpoints from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.event import BASE_STATE, EventType from reflex.style import Style from reflex.vars.base import Var from ..base import RadixThemesComponent -class Code(elements.Code, RadixThemesComponent): +class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap): @overload @classmethod def create( # type: ignore diff --git a/reflex/components/radix/themes/typography/heading.py b/reflex/components/radix/themes/typography/heading.py index 03e1097176..ce1eaa68f5 100644 --- a/reflex/components/radix/themes/typography/heading.py +++ b/reflex/components/radix/themes/typography/heading.py @@ -7,13 +7,14 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight -class Heading(elements.H1, RadixThemesComponent): +class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap): """A foundational text primitive based on the element.""" tag = "Heading" diff --git a/reflex/components/radix/themes/typography/heading.pyi b/reflex/components/radix/themes/typography/heading.pyi index a58b4ebf42..b5cb5c9d3b 100644 --- a/reflex/components/radix/themes/typography/heading.pyi +++ b/reflex/components/radix/themes/typography/heading.pyi @@ -7,13 +7,14 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.components.core.breakpoints import Breakpoints from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.event import BASE_STATE, EventType from reflex.style import Style from reflex.vars.base import Var from ..base import RadixThemesComponent -class Heading(elements.H1, RadixThemesComponent): +class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap): @overload @classmethod def create( # type: ignore diff --git a/reflex/components/radix/themes/typography/link.py b/reflex/components/radix/themes/typography/link.py index 6e3d2f9834..1cc6735362 100644 --- a/reflex/components/radix/themes/typography/link.py +++ b/reflex/components/radix/themes/typography/link.py @@ -12,6 +12,7 @@ from reflex.components.core.colors import color from reflex.components.core.cond import cond from reflex.components.el.elements.inline import A +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.next.link import NextLink from reflex.utils.imports import ImportDict from reflex.vars.base import Var @@ -24,7 +25,7 @@ next_link = NextLink.create() -class Link(RadixThemesComponent, A, MemoizationLeaf): +class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap): """A semantic element for navigation between pages.""" tag = "Link" diff --git a/reflex/components/radix/themes/typography/link.pyi b/reflex/components/radix/themes/typography/link.pyi index 8e3cfb959c..db963c6dfe 100644 --- a/reflex/components/radix/themes/typography/link.pyi +++ b/reflex/components/radix/themes/typography/link.pyi @@ -8,6 +8,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.components.component import MemoizationLeaf from reflex.components.core.breakpoints import Breakpoints from reflex.components.el.elements.inline import A +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.next.link import NextLink from reflex.event import BASE_STATE, EventType from reflex.style import Style @@ -19,7 +20,7 @@ from ..base import RadixThemesComponent LiteralLinkUnderline = Literal["auto", "hover", "always", "none"] next_link = NextLink.create() -class Link(RadixThemesComponent, A, MemoizationLeaf): +class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap): def add_imports(self) -> ImportDict: ... @overload @classmethod diff --git a/reflex/components/radix/themes/typography/text.py b/reflex/components/radix/themes/typography/text.py index e3576360a2..1663ddedfa 100644 --- a/reflex/components/radix/themes/typography/text.py +++ b/reflex/components/radix/themes/typography/text.py @@ -10,6 +10,7 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent @@ -37,7 +38,7 @@ ] -class Text(elements.Span, RadixThemesComponent): +class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap): """A foundational text primitive based on the element.""" tag = "Text" diff --git a/reflex/components/radix/themes/typography/text.pyi b/reflex/components/radix/themes/typography/text.pyi index a5d023d37e..824348b428 100644 --- a/reflex/components/radix/themes/typography/text.pyi +++ b/reflex/components/radix/themes/typography/text.pyi @@ -8,6 +8,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Breakpoints from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.event import BASE_STATE, EventType from reflex.style import Style from reflex.vars.base import Var @@ -35,7 +36,7 @@ LiteralType = Literal[ "sup", ] -class Text(elements.Span, RadixThemesComponent): +class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap): @overload @classmethod def create( # type: ignore diff --git a/reflex/event.py b/reflex/event.py index e51d1cc073..a64d4d6c18 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -45,6 +45,7 @@ from reflex.vars.base import LiteralVar, Var from reflex.vars.function import ( ArgsFunctionOperation, + FunctionArgs, FunctionStringVar, FunctionVar, VarOperationCall, @@ -1643,7 +1644,7 @@ def create( _js_expr="", _var_type=EventChain, _var_data=_var_data, - _args_names=arg_def, + _args=FunctionArgs(arg_def), _return_expr=invocation.call( LiteralVar.create([LiteralVar.create(event) for event in value.events]), arg_def_expr, diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 49ef996149..98f3b23358 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -4,8 +4,9 @@ import dataclasses import sys -from typing import Any, Callable, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union +from reflex.utils import format from reflex.utils.types import GenericType from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock @@ -126,6 +127,36 @@ def create( ) +@dataclasses.dataclass(frozen=True) +class DestructuredArg: + """Class for destructured arguments.""" + + fields: Tuple[str, ...] = tuple() + rest: Optional[str] = None + + def to_javascript(self) -> str: + """Convert the destructured argument to JavaScript. + + Returns: + The destructured argument in JavaScript. + """ + return format.wrap( + ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""), + "{", + "}", + ) + + +@dataclasses.dataclass( + frozen=True, +) +class FunctionArgs: + """Class for function arguments.""" + + args: Tuple[Union[str, DestructuredArg], ...] = tuple() + rest: Optional[str] = None + + @dataclasses.dataclass( eq=False, frozen=True, @@ -134,8 +165,9 @@ def create( class ArgsFunctionOperation(CachedVarOperation, FunctionVar): """Base class for immutable function defined via arguments and return expression.""" - _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) _return_expr: Union[Var, Any] = dataclasses.field(default=None) + _explicit_return: bool = dataclasses.field(default=False) @cached_property_no_lock def _cached_var_name(self) -> str: @@ -144,13 +176,31 @@ def _cached_var_name(self) -> str: Returns: The name of the var. """ - return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" + arg_names_str = ", ".join( + [ + arg if isinstance(arg, str) else arg.to_javascript() + for arg in self._args.args + ] + ) + (f", ...{self._args.rest}" if self._args.rest else "") + + return_expr_str = str(LiteralVar.create(self._return_expr)) + + # Wrap return expression in curly braces if explicit return syntax is used. + return_expr_str_wrapped = ( + format.wrap(return_expr_str, "{", "}") + if self._explicit_return + else return_expr_str + ) + + return f"(({arg_names_str}) => {return_expr_str_wrapped})" @classmethod def create( cls, - args_names: Tuple[str, ...], + args_names: Sequence[Union[str, DestructuredArg]], return_expr: Var | Any, + rest: str | None = None, + explicit_return: bool = False, _var_type: GenericType = Callable, _var_data: VarData | None = None, ) -> ArgsFunctionOperation: @@ -159,6 +209,8 @@ def create( Args: args_names: The names of the arguments. return_expr: The return expression of the function. + rest: The name of the rest argument. + explicit_return: Whether to use explicit return syntax. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -168,8 +220,9 @@ def create( _js_expr="", _var_type=_var_type, _var_data=_var_data, - _args_names=args_names, + _args=FunctionArgs(args=tuple(args_names), rest=rest), _return_expr=return_expr, + _explicit_return=explicit_return, ) diff --git a/tests/units/components/base/test_script.py b/tests/units/components/base/test_script.py index b909b6c617..e9c40188ba 100644 --- a/tests/units/components/base/test_script.py +++ b/tests/units/components/base/test_script.py @@ -62,14 +62,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }})))))}}' + f'onReady={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }}))))}}' in render_dict["props"] ) assert ( - f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }})))))}}' + f'onLoad={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }}))))}}' in render_dict["props"] ) assert ( - f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }})))))}}' + f'onError={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }}))))}}' in render_dict["props"] ) diff --git a/tests/units/components/datadisplay/test_code.py b/tests/units/components/datadisplay/test_code.py index 809c68fe56..6b7168756b 100644 --- a/tests/units/components/datadisplay/test_code.py +++ b/tests/units/components/datadisplay/test_code.py @@ -11,22 +11,3 @@ def test_code_light_dark_theme(theme, expected): code_block = CodeBlock.create(theme=theme) assert code_block.theme._js_expr == expected # type: ignore - - -def generate_custom_code(language, expected_case): - return f"SyntaxHighlighter.registerLanguage('{language}', {expected_case})" - - -@pytest.mark.parametrize( - "language, expected_case", - [ - ("python", "python"), - ("firestore-security-rules", "firestoreSecurityRules"), - ("typescript", "typescript"), - ], -) -def test_get_custom_code(language, expected_case): - code_block = CodeBlock.create(language=language) - assert code_block._get_custom_code() == generate_custom_code( - language, expected_case - ) diff --git a/tests/units/components/markdown/__init__.py b/tests/units/components/markdown/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py new file mode 100644 index 0000000000..866f32ae14 --- /dev/null +++ b/tests/units/components/markdown/test_markdown.py @@ -0,0 +1,190 @@ +from typing import Type + +import pytest + +from reflex.components.component import Component, memo +from reflex.components.datadisplay.code import CodeBlock +from reflex.components.datadisplay.shiki_code_block import ShikiHighLevelCodeBlock +from reflex.components.markdown.markdown import Markdown, MarkdownComponentMap +from reflex.components.radix.themes.layout.box import Box +from reflex.components.radix.themes.typography.heading import Heading +from reflex.vars.base import Var + + +class CustomMarkdownComponent(Component, MarkdownComponentMap): + """A custom markdown component.""" + + tag = "CustomMarkdownComponent" + library = "custom" + + @classmethod + def get_fn_args(cls) -> tuple[str, ...]: + """Return the function arguments. + + Returns: + The function arguments. + """ + return ("custom_node", "custom_children", "custom_props") + + @classmethod + def get_fn_body(cls) -> Var: + """Return the function body. + + Returns: + The function body. + """ + return Var(_js_expr="{return custom_node + custom_children + custom_props}") + + +def syntax_highlighter_memoized_component(codeblock: Type[Component]): + @memo + def code_block(code: str, language: str): + return Box.create( + codeblock.create( + code, + language=language, + class_name="code-block", + can_copy=True, + ), + class_name="relative mb-4", + ) + + def code_block_markdown(*children, **props): + return code_block( + code=children[0], language=props.pop("language", "plain"), **props + ) + + return code_block_markdown + + +@pytest.mark.parametrize( + "fn_body, fn_args, explicit_return, expected", + [ + ( + None, + None, + False, + Var(_js_expr="(({node, children, ...props}) => undefined)"), + ), + ("return node", ("node",), True, Var(_js_expr="(({node}) => {return node})")), + ( + "return node + children", + ("node", "children"), + True, + Var(_js_expr="(({node, children}) => {return node + children})"), + ), + ( + "return node + props", + ("node", "...props"), + True, + Var(_js_expr="(({node, ...props}) => {return node + props})"), + ), + ( + "return node + children + props", + ("node", "children", "...props"), + True, + Var( + _js_expr="(({node, children, ...props}) => {return node + children + props})" + ), + ), + ], +) +def test_create_map_fn_var(fn_body, fn_args, explicit_return, expected): + result = MarkdownComponentMap.create_map_fn_var( + fn_body=Var(_js_expr=fn_body, _var_type=str) if fn_body else None, + fn_args=fn_args, + explicit_return=explicit_return, + ) + assert result._js_expr == expected._js_expr + + +@pytest.mark.parametrize( + ("cls", "fn_body", "fn_args", "explicit_return", "expected"), + [ + ( + MarkdownComponentMap, + None, + None, + False, + Var(_js_expr="(({node, children, ...props}) => undefined)"), + ), + ( + MarkdownComponentMap, + "return node", + ("node",), + True, + Var(_js_expr="(({node}) => {return node})"), + ), + ( + CustomMarkdownComponent, + None, + None, + True, + Var( + _js_expr="(({custom_node, custom_children, custom_props}) => {return custom_node + custom_children + custom_props})" + ), + ), + ( + CustomMarkdownComponent, + "return custom_node", + ("custom_node",), + True, + Var(_js_expr="(({custom_node}) => {return custom_node})"), + ), + ], +) +def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expected): + result = cls.create_map_fn_var( + fn_body=Var(_js_expr=fn_body, _var_type=int) if fn_body else None, + fn_args=fn_args, + explicit_return=explicit_return, + ) + assert result._js_expr == expected._js_expr + + +@pytest.mark.parametrize( + "key,component_map, expected", + [ + ( + "code", + {}, + """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""", + ), + ( + "code", + { + "codeblock": lambda value, **props: ShikiHighLevelCodeBlock.create( + value, **props + ) + }, + """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""", + ), + ( + "h1", + { + "h1": lambda value: CustomMarkdownComponent.create( + Heading.create(value, as_="h1", size="6", margin_y="0.5em") + ) + }, + """(({custom_node, custom_children, custom_props}) => ({children}))""", + ), + ( + "code", + {"codeblock": syntax_highlighter_memoized_component(CodeBlock)}, + """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""", + ), + ( + "code", + { + "codeblock": syntax_highlighter_memoized_component( + ShikiHighLevelCodeBlock + ) + }, + """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""", + ), + ], +) +def test_markdown_format_component(key, component_map, expected): + markdown = Markdown.create("# header", component_map=component_map) + result = markdown.format_component_map() + assert str(result[key]) == expected diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index e4744b9fb9..a2485d10eb 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -844,9 +844,9 @@ def get_event_triggers(self) -> Dict[str, Any]: comp = C1.create(on_foo=C1State.mock_handler) assert comp.render()["props"][0] == ( - "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents(" + "onFoo={((__e, _alpha, _bravo, _charlie) => (addEvents(" f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], ' - "[__e, _alpha, _bravo, _charlie], ({ })))))}" + "[__e, _alpha, _bravo, _charlie], ({ }))))}" ) diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 5e26da5d8e..f17b3c4e40 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -222,16 +222,16 @@ def test_event_console_log(): assert spec.handler.fn.__qualname__ == "_call_function" assert spec.args[0][0].equals(Var(_js_expr="function")) assert spec.args[0][1].equals( - Var('(() => ((console["log"]("message"))))', _var_type=Callable) + Var('(() => (console["log"]("message")))', _var_type=Callable) ) assert ( format.format_event(spec) - == 'Event("_call_function", {function:(() => ((console["log"]("message"))))})' + == 'Event("_call_function", {function:(() => (console["log"]("message")))})' ) spec = event.console_log(Var(_js_expr="message")) assert ( format.format_event(spec) - == 'Event("_call_function", {function:(() => ((console["log"](message))))})' + == 'Event("_call_function", {function:(() => (console["log"](message)))})' ) @@ -242,16 +242,16 @@ def test_event_window_alert(): assert spec.handler.fn.__qualname__ == "_call_function" assert spec.args[0][0].equals(Var(_js_expr="function")) assert spec.args[0][1].equals( - Var('(() => ((window["alert"]("message"))))', _var_type=Callable) + Var('(() => (window["alert"]("message")))', _var_type=Callable) ) assert ( format.format_event(spec) - == 'Event("_call_function", {function:(() => ((window["alert"]("message"))))})' + == 'Event("_call_function", {function:(() => (window["alert"]("message")))})' ) spec = event.window_alert(Var(_js_expr="message")) assert ( format.format_event(spec) - == 'Event("_call_function", {function:(() => ((window["alert"](message))))})' + == 'Event("_call_function", {function:(() => (window["alert"](message)))})' ) diff --git a/tests/units/test_var.py b/tests/units/test_var.py index e9fa40faba..5944739213 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -22,7 +22,11 @@ var_operation, var_operation_return, ) -from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar +from reflex.vars.function import ( + ArgsFunctionOperation, + DestructuredArg, + FunctionStringVar, +) from reflex.vars.number import LiteralBooleanVar, LiteralNumberVar, NumberVar from reflex.vars.object import LiteralObjectVar, ObjectVar from reflex.vars.sequence import ( @@ -921,13 +925,13 @@ def test_function_var(): ) assert ( str(manual_addition_func.call(1, 2)) - == '(((a, b) => (({ ["args"] : [a, b], ["result"] : a + b })))(1, 2))' + == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' ) increment_func = addition_func(1) assert ( str(increment_func.call(2)) - == "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))" + == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))" ) create_hello_statement = ArgsFunctionOperation.create( @@ -937,8 +941,24 @@ def test_function_var(): last_name = LiteralStringVar.create("Universe") assert ( str(create_hello_statement.call(f"{first_name} {last_name}")) - == '(((name) => (("Hello, "+name+"!")))("Steven Universe"))' + == '(((name) => ("Hello, "+name+"!"))("Steven Universe"))' + ) + + # Test with destructured arguments + destructured_func = ArgsFunctionOperation.create( + (DestructuredArg(fields=("a", "b")),), + Var(_js_expr="a + b"), + ) + assert ( + str(destructured_func.call({"a": 1, "b": 2})) + == '((({a, b}) => a + b)(({ ["a"] : 1, ["b"] : 2 })))' + ) + + # Test with explicit return + explicit_return_func = ArgsFunctionOperation.create( + ("a", "b"), Var(_js_expr="return a + b"), explicit_return=True ) + assert str(explicit_return_func.call(1, 2)) == "(((a, b) => {return a + b})(1, 2))" def test_var_operation(): diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index f8b6055413..cd1d0179db 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -374,7 +374,7 @@ def test_format_match( events=[EventSpec(handler=EventHandler(fn=mock_event))], args_spec=lambda: [], ), - '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ })))))', + '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ }))))', ), ( EventChain( @@ -395,7 +395,7 @@ def test_format_match( ], args_spec=lambda e: [e.target.value], ), - '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ })))))', + '((_e) => (addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ }))))', ), ( EventChain( @@ -403,7 +403,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"stopPropagation": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true })))))', + '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true }))))', ), ( EventChain( @@ -415,7 +415,7 @@ def test_format_match( ], args_spec=lambda: [], ), - '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ })))))', + '((...args) => (addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ }))))', ), ( EventChain( @@ -423,7 +423,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"preventDefault": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true })))))', + '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true }))))', ), ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'), (Var(_js_expr="var", _var_type=int).guess_type(), "var"),