From 429ebcc6f8b451d2ade00891e98d08e1f77e3692 Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Sat, 25 May 2024 15:27:18 +0900 Subject: [PATCH] Update stlite-server --- .../component_request_handler.py | 15 ++- .../py/stlite-server/stlite_server/server.py | 3 +- .../py/stlite-server/tests/test_server.py | 4 + .../kernel/py/stlite-server/tests/testutil.py | 92 +++++++++++++++++++ 4 files changed, 103 insertions(+), 11 deletions(-) create mode 100644 packages/kernel/py/stlite-server/tests/testutil.py diff --git a/packages/kernel/py/stlite-server/stlite_server/component_request_handler.py b/packages/kernel/py/stlite-server/stlite_server/component_request_handler.py index f995303f7..7a0939a6d 100644 --- a/packages/kernel/py/stlite-server/stlite_server/component_request_handler.py +++ b/packages/kernel/py/stlite-server/stlite_server/component_request_handler.py @@ -1,17 +1,18 @@ import logging import mimetypes import os +from typing import Final -from streamlit.components.v1.components import ComponentRegistry +from streamlit.components.types.base_component_registry import BaseComponentRegistry from .handler import Request, RequestHandler, Response -_LOGGER = logging.getLogger(__name__) +_LOGGER: Final = logging.getLogger(__name__) # Mimic streamlit.web.server.component_request_handler.ComponentRequestHandler class ComponentRequestHandler(RequestHandler): - def __init__(self, registry: ComponentRegistry) -> None: + def __init__(self, registry: BaseComponentRegistry) -> None: self._registry = registry def get(self, request: Request, path: str) -> Response: # type: ignore[override] @@ -27,11 +28,7 @@ def get(self, request: Request, path: str) -> Response: # type: ignore[override abspath = os.path.realpath(os.path.join(component_root, filename)) # Do NOT expose anything outside of the component root. - if os.path.commonprefix([component_root, abspath]) != component_root or ( - not os.path.normpath(abspath).startswith( - component_root - ) # this is a recommendation from CodeQL, probably a bit redundant - ): + if os.path.commonpath([component_root, abspath]) != component_root: return Response(status_code=403, headers={}, body="forbidden") try: @@ -54,7 +51,7 @@ def get(self, request: Request, path: str) -> Response: # type: ignore[override ) @staticmethod - def get_content_type(abspath) -> str: + def get_content_type(abspath: str) -> str: """Returns the ``Content-Type`` header to be used for this request. From tornado.web.StaticFileHandler. """ diff --git a/packages/kernel/py/stlite-server/stlite_server/server.py b/packages/kernel/py/stlite-server/stlite_server/server.py index e6d89600e..3d715e740 100644 --- a/packages/kernel/py/stlite-server/stlite_server/server.py +++ b/packages/kernel/py/stlite-server/stlite_server/server.py @@ -5,7 +5,6 @@ from typing import Callable, Final, cast import pyodide -from streamlit.components.v1.components import ComponentRegistry from streamlit.proto.BackMsg_pb2 import BackMsg from streamlit.proto.ForwardMsg_pb2 import ForwardMsg from streamlit.runtime import Runtime, RuntimeConfig, SessionClient @@ -90,7 +89,7 @@ async def start(self) -> None: ), ( make_url_path_regex(base, "component/(.*)"), - ComponentRequestHandler(registry=ComponentRegistry.instance()), + ComponentRequestHandler(registry=self._runtime.component_registry), ), ] self._routes = [(re.compile(pattern), handler) for (pattern, handler) in routes] diff --git a/packages/kernel/py/stlite-server/tests/test_server.py b/packages/kernel/py/stlite-server/tests/test_server.py index 548784933..2055efe6f 100644 --- a/packages/kernel/py/stlite-server/tests/test_server.py +++ b/packages/kernel/py/stlite-server/tests/test_server.py @@ -11,8 +11,10 @@ from streamlit.components.v1.components import declare_component from streamlit.hello import Hello from streamlit.runtime.runtime import Runtime +from streamlit.runtime.scriptrunner.script_run_context import add_script_run_ctx from stlite_server.server import Server +from tests.testutil import create_mock_script_run_ctx @pytest.fixture @@ -59,6 +61,8 @@ def start(): raise exception server = data_from_thread["server"] + add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx()) + yield server server.stop() diff --git a/packages/kernel/py/stlite-server/tests/testutil.py b/packages/kernel/py/stlite-server/tests/testutil.py new file mode 100644 index 000000000..1dde7bd4f --- /dev/null +++ b/packages/kernel/py/stlite-server/tests/testutil.py @@ -0,0 +1,92 @@ +# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stlite: copied from https://github.com/streamlit/streamlit/blob/1.35.0/lib/tests/testutil.py # noqa: E501 + +"""Utility functions to use in our tests.""" + +from streamlit import config +from streamlit.runtime.fragment import MemoryFragmentStorage +from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager +from streamlit.runtime.scriptrunner import ScriptRunContext +from streamlit.runtime.state import SafeSessionState, SessionState + + +def should_skip_pydantic_tests() -> bool: + try: + import pydantic + + return not pydantic.__version__.startswith("1.") + except ImportError: + return True + + +def create_mock_script_run_ctx() -> ScriptRunContext: + """Create a ScriptRunContext for use in tests.""" + return ScriptRunContext( + session_id="mock_session_id", + _enqueue=lambda msg: None, + query_string="mock_query_string", + session_state=SafeSessionState(SessionState(), lambda: None), + uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"), + main_script_path="", + page_script_hash="mock_page_script_hash", + user_info={"email": "mock@test.com"}, + fragment_storage=MemoryFragmentStorage(), + ) + + +def build_mock_config_is_manually_set(overrides_dict): + orig_is_manually_set = config.is_manually_set + + def mock_config_is_manually_set(name): + if name in overrides_dict: + return overrides_dict[name] + return orig_is_manually_set(name) + + return mock_config_is_manually_set + + +def normalize_md(txt: str) -> str: + """Replace newlines *inside paragraphs* with spaces. + + Consecutive lines of text are considered part of the same paragraph + in Markdown. So this function joins those into a single line to make the + test robust to changes in text wrapping. + + NOTE: This function doesn't attempt to be 100% grammatically correct + Markdown! It's just supposed to be "correct enough" for tests to pass. For + example, when we guard "\n\n" from being converted, we really should be + guarding for RegEx("\n\n+") instead. But that doesn't matter for our tests. + """ + # Two newlines in a row should NOT be replaced with a space. + txt = txt.replace("\n\n", "OMG_NEWLINE") + + # Lists should NOT be replaced with a space. + txt = txt.replace("\n*", "OMG_STAR") + txt = txt.replace("\n-", "OMG_HYPHEN") + + # Links broken over two lines should not get an extra space. + txt = txt.replace("]\n(", "OMG_LINK") + + # Convert all remaining newlines into spaces. + txt = txt.replace("\n", " ") + + # Restore everything else. + txt = txt.replace("OMG_NEWLINE", "\n\n") + txt = txt.replace("OMG_STAR", "\n*") + txt = txt.replace("OMG_HYPHEN", "\n-") + txt = txt.replace("OMG_LINK", "](") + + return txt.strip()