From e0de7f9192f9aec1a1c2a1161842e8d912340ec6 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sun, 5 Jan 2025 16:56:40 +0200 Subject: [PATCH] http: simplify test and fix minor issues - simplify tests, and add a test for file retrieval - remove unused loop param - call shutdown before cleanup, as docs require Signed-off-by: Benny Zlotnik --- .../src/jumpstarter_driver_http/driver.py | 17 ++-- .../jumpstarter_driver_http/driver_test.py | 84 ++++++++++++++----- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/contrib/drivers/http/src/jumpstarter_driver_http/driver.py b/contrib/drivers/http/src/jumpstarter_driver_http/driver.py index f5f54dd..060e1a5 100644 --- a/contrib/drivers/http/src/jumpstarter_driver_http/driver.py +++ b/contrib/drivers/http/src/jumpstarter_driver_http/driver.py @@ -1,4 +1,3 @@ -import asyncio import logging import os from dataclasses import dataclass, field @@ -170,7 +169,9 @@ async def start(self): return self.runner = web.AppRunner(self.app) - await self.runner.setup() + if self.runner: + await self.runner.setup() + site = web.TCPSite(self.runner, self.host, self.port) await site.start() logger.info(f"HTTP server started at http://{self.host}:{self.port}") @@ -224,17 +225,17 @@ def get_port(self) -> int: def close(self): if self.runner: try: - loop = asyncio.get_running_loop() - if loop.is_running(): - asyncio.create_task(self._async_cleanup(loop)) - except RuntimeError: - anyio.run(self.runner.cleanup) + if anyio.get_current_task(): + anyio.from_thread.run(self._async_cleanup) + except Exception as e: + logger.warning(f"HTTP server cleanup failed synchronously: {e}") self.runner = None super().close() - async def _async_cleanup(self, loop): + async def _async_cleanup(self): try: if self.runner: + await self.runner.shutdown() await self.runner.cleanup() logger.info("HTTP server cleanup completed asynchronously.") except Exception as e: diff --git a/contrib/drivers/http/src/jumpstarter_driver_http/driver_test.py b/contrib/drivers/http/src/jumpstarter_driver_http/driver_test.py index ed2b08e..2a47da7 100644 --- a/contrib/drivers/http/src/jumpstarter_driver_http/driver_test.py +++ b/contrib/drivers/http/src/jumpstarter_driver_http/driver_test.py @@ -1,34 +1,78 @@ -from pathlib import Path +import os +import uuid from tempfile import TemporaryDirectory +import aiohttp +import anyio import pytest +from anyio import create_memory_object_stream -from jumpstarter.common.utils import serve +from jumpstarter.common.resources import ClientStreamResource from .driver import HttpServer -@pytest.mark.asyncio -async def test_http_server(): - with TemporaryDirectory() as source_dir, TemporaryDirectory() as server_dir: - server = HttpServer(root_dir=server_dir) - await server.start() +@pytest.fixture +def anyio_backend(): + return "asyncio" - with serve(server) as client: - test_content = b"test content" - source_file_path = Path(source_dir) / "test.txt" - source_file_path.write_bytes(test_content) +@pytest.fixture +def temp_dir(): + with TemporaryDirectory() as tmpdir: + yield tmpdir - uploaded_filename_url = client.put_local_file(str(source_file_path)) - assert uploaded_filename_url == f"{client.get_url()}/test.txt" +@pytest.fixture +async def server(temp_dir): + server = HttpServer(root_dir=temp_dir) + await server.start() + try: + yield server + finally: + await server.stop() - files = client.list_files() - assert "test.txt" in files +@pytest.mark.anyio +async def test_http_server(server): + filename = "test.txt" + test_content = b"test content" - deleted_filename = client.delete_file("test.txt") - assert deleted_filename == "test.txt" + send_stream, receive_stream = create_memory_object_stream(max_buffer_size=1024) - files_after_deletion = client.list_files() - assert "test.txt" not in files_after_deletion + resource_uuid = uuid.uuid4() + server.resources[resource_uuid] = receive_stream - await server.stop() + resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") + + async def send_data(): + await send_stream.send(test_content) + await send_stream.aclose() + + async with anyio.create_task_group() as tg: + tg.start_soon(send_data) + + uploaded_url = await server.put_file(filename, resource_handle) + assert uploaded_url == f"{server.get_url()}/{filename}" + + files = server.list_files() + assert filename in files + + async with aiohttp.ClientSession() as session: + async with session.get(uploaded_url) as response: + assert response.status == 200 + retrieved_content = await response.read() + assert retrieved_content == test_content + + deleted_filename = await server.delete_file(filename) + assert deleted_filename == filename + + files_after_deletion = server.list_files() + assert filename not in files_after_deletion + +def test_http_server_host_config(temp_dir): + custom_host = "192.168.1.1" + server = HttpServer(root_dir=temp_dir, host=custom_host) + assert server.get_host() == custom_host + +def test_http_server_root_directory_creation(temp_dir): + new_dir = os.path.join(temp_dir, "new_http_root") + _ = HttpServer(root_dir=new_dir) + assert os.path.exists(new_dir)