Skip to content

Commit

Permalink
Update stlite-server
Browse files Browse the repository at this point in the history
  • Loading branch information
whitphx committed May 25, 2024
1 parent c41a1e1 commit 429ebcc
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import logging
import mimetypes
import os
from typing import Final

from streamlit.components.v1.components import ComponentRegistry
from streamlit.components.types.base_component_registry import BaseComponentRegistry

from .handler import Request, RequestHandler, Response

_LOGGER = logging.getLogger(__name__)
_LOGGER: Final = logging.getLogger(__name__)


# Mimic streamlit.web.server.component_request_handler.ComponentRequestHandler
class ComponentRequestHandler(RequestHandler):
def __init__(self, registry: ComponentRegistry) -> None:
def __init__(self, registry: BaseComponentRegistry) -> None:
self._registry = registry

def get(self, request: Request, path: str) -> Response: # type: ignore[override]
Expand All @@ -27,11 +28,7 @@ def get(self, request: Request, path: str) -> Response: # type: ignore[override
abspath = os.path.realpath(os.path.join(component_root, filename))

# Do NOT expose anything outside of the component root.
if os.path.commonprefix([component_root, abspath]) != component_root or (
not os.path.normpath(abspath).startswith(
component_root
) # this is a recommendation from CodeQL, probably a bit redundant
):
if os.path.commonpath([component_root, abspath]) != component_root:
return Response(status_code=403, headers={}, body="forbidden")

try:
Expand All @@ -54,7 +51,7 @@ def get(self, request: Request, path: str) -> Response: # type: ignore[override
)

@staticmethod
def get_content_type(abspath) -> str:
def get_content_type(abspath: str) -> str:
"""Returns the ``Content-Type`` header to be used for this request.
From tornado.web.StaticFileHandler.
"""
Expand Down
3 changes: 1 addition & 2 deletions packages/kernel/py/stlite-server/stlite_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Callable, Final, cast

import pyodide
from streamlit.components.v1.components import ComponentRegistry
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime import Runtime, RuntimeConfig, SessionClient
Expand Down Expand Up @@ -90,7 +89,7 @@ async def start(self) -> None:
),
(
make_url_path_regex(base, "component/(.*)"),
ComponentRequestHandler(registry=ComponentRegistry.instance()),
ComponentRequestHandler(registry=self._runtime.component_registry),
),
]
self._routes = [(re.compile(pattern), handler) for (pattern, handler) in routes]
Expand Down
4 changes: 4 additions & 0 deletions packages/kernel/py/stlite-server/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from streamlit.components.v1.components import declare_component
from streamlit.hello import Hello
from streamlit.runtime.runtime import Runtime
from streamlit.runtime.scriptrunner.script_run_context import add_script_run_ctx

from stlite_server.server import Server
from tests.testutil import create_mock_script_run_ctx


@pytest.fixture
Expand Down Expand Up @@ -59,6 +61,8 @@ def start():
raise exception
server = data_from_thread["server"]

add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())

yield server

server.stop()
Expand Down
92 changes: 92 additions & 0 deletions packages/kernel/py/stlite-server/tests/testutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Stlite: copied from https://github.com/streamlit/streamlit/blob/1.35.0/lib/tests/testutil.py # noqa: E501

"""Utility functions to use in our tests."""

from streamlit import config
from streamlit.runtime.fragment import MemoryFragmentStorage
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import ScriptRunContext
from streamlit.runtime.state import SafeSessionState, SessionState


def should_skip_pydantic_tests() -> bool:
try:
import pydantic

return not pydantic.__version__.startswith("1.")
except ImportError:
return True


def create_mock_script_run_ctx() -> ScriptRunContext:
"""Create a ScriptRunContext for use in tests."""
return ScriptRunContext(
session_id="mock_session_id",
_enqueue=lambda msg: None,
query_string="mock_query_string",
session_state=SafeSessionState(SessionState(), lambda: None),
uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
main_script_path="",
page_script_hash="mock_page_script_hash",
user_info={"email": "[email protected]"},
fragment_storage=MemoryFragmentStorage(),
)


def build_mock_config_is_manually_set(overrides_dict):
orig_is_manually_set = config.is_manually_set

def mock_config_is_manually_set(name):
if name in overrides_dict:
return overrides_dict[name]
return orig_is_manually_set(name)

return mock_config_is_manually_set


def normalize_md(txt: str) -> str:
"""Replace newlines *inside paragraphs* with spaces.
Consecutive lines of text are considered part of the same paragraph
in Markdown. So this function joins those into a single line to make the
test robust to changes in text wrapping.
NOTE: This function doesn't attempt to be 100% grammatically correct
Markdown! It's just supposed to be "correct enough" for tests to pass. For
example, when we guard "\n\n" from being converted, we really should be
guarding for RegEx("\n\n+") instead. But that doesn't matter for our tests.
"""
# Two newlines in a row should NOT be replaced with a space.
txt = txt.replace("\n\n", "OMG_NEWLINE")

# Lists should NOT be replaced with a space.
txt = txt.replace("\n*", "OMG_STAR")
txt = txt.replace("\n-", "OMG_HYPHEN")

# Links broken over two lines should not get an extra space.
txt = txt.replace("]\n(", "OMG_LINK")

# Convert all remaining newlines into spaces.
txt = txt.replace("\n", " ")

# Restore everything else.
txt = txt.replace("OMG_NEWLINE", "\n\n")
txt = txt.replace("OMG_STAR", "\n*")
txt = txt.replace("OMG_HYPHEN", "\n-")
txt = txt.replace("OMG_LINK", "](")

return txt.strip()

0 comments on commit 429ebcc

Please sign in to comment.