diff --git a/.cspell.json b/.cspell.json index 5c81bd0d5f9..804f8242bf8 100644 --- a/.cspell.json +++ b/.cspell.json @@ -4,6 +4,7 @@ "abstractmethods", "arange", "astype", + "asyncssh", "autotune", "autotuning", "bincount", @@ -12,11 +13,13 @@ "commandline", "conda", "configspace", + "coro", "dataframe", "devcontainer", "discretization", "discretize", "drivername", + "dstpath", "dtype", "duckdb", "emukit", @@ -31,6 +34,7 @@ "iterrows", "jsonschema", "jupyterlab", + "keepalive", "kwargs", "libmamba", "linalg", @@ -47,10 +51,12 @@ "obvs", "perc", "pinv", + "poweroff", "pylint", "pyplot", "pytest", "Quickstart", + "refcnt", "rexec", "rootfs", "runhistory", @@ -63,11 +69,13 @@ "skopt", "smac", "sqlalchemy", + "srcpaths", "subcmd", "subschema", "subschemas", "tolist", "tunables", + "xdist", "xlabel", "ylabel" ] diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 21694125e8e..d3448860c34 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -17,6 +17,7 @@ "postCreateCommand": "/opt/conda/bin/conda env update --solver=libmamba -v -n mlos -f ${containerWorkspaceFolder}/conda-envs/mlos.yml", // Various mounting, run, post-create, and user settings "containerEnv": { + "LOCAL_USER_NAME": "${localEnv:USER}${localEnv:USERNAME}", "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}", "CONTAINER_WORKSPACE_FOLDER": "${containerWorkspaceFolder}", "SSH_AUTH_SOCK": "${localEnv:SSH_AUTH_SOCK}", diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index db29f6ef56d..a1dcaf1478d 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -24,6 +24,9 @@ services: env_file: ../.env + extra_hosts: + - "host.docker.internal:host-gateway" + docs-www: build: context: ./tmp diff --git a/.devcontainer/scripts/run-devcontainer.sh b/.devcontainer/scripts/run-devcontainer.sh index a1ef8df5934..2d8be436516 100755 --- a/.devcontainer/scripts/run-devcontainer.sh +++ b/.devcontainer/scripts/run-devcontainer.sh @@ -28,6 +28,7 @@ else docker_gid=$(stat -c%g /var/run/docker.sock) fi +set -x mkdir -p "/tmp/$container_name/dc/shellhistory" docker run -it --rm \ --name "$container_name" \ diff --git a/.github/workflows/devcontainer.yml b/.github/workflows/devcontainer.yml index 82607a47459..384f5c5ac7c 100644 --- a/.github/workflows/devcontainer.yml +++ b/.github/workflows/devcontainer.yml @@ -90,6 +90,7 @@ jobs: --env LOCAL_WORKSPACE_FOLDER=$(pwd) \ --env PYTEST_EXTRA_OPTIONS=$PYTEST_EXTRA_OPTIONS \ --workdir /workspaces/MLOS \ + --add-host host.docker.internal:host-gateway \ --name mlos-devcontainer mlos-devcontainer sleep 1800 - name: Fixup vscode uid/gid in the running container timeout-minutes: 3 diff --git a/.vscode/settings.json b/.vscode/settings.json index eac28f1e77f..fff265d85f5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -134,6 +134,7 @@ }, "python.testing.pytestArgs": [ "-n1", // don't run tests in parallel inside vscode - makes attaching the debugger more cumbersome + "--dist=no", "--log-level=DEBUG", "." 
    ],
diff --git a/conda-envs/mlos-3.10.yml b/conda-envs/mlos-3.10.yml
index 4b314c2e352..447d396b464 100644
--- a/conda-envs/mlos-3.10.yml
+++ b/conda-envs/mlos-3.10.yml
@@ -32,6 +32,7 @@ dependencies:
     - types-colorama
     - types-jsonschema
     - types-pygments
+    - types-pytest-lazy-fixture
     - types-requests
     - types-setuptools
     - "--editable ../mlos_core[full-tests]"
diff --git a/conda-envs/mlos-3.11.yml b/conda-envs/mlos-3.11.yml
index 333bc75448a..496c31260c6 100644
--- a/conda-envs/mlos-3.11.yml
+++ b/conda-envs/mlos-3.11.yml
@@ -32,6 +32,7 @@ dependencies:
     - types-colorama
     - types-jsonschema
     - types-pygments
+    - types-pytest-lazy-fixture
     - types-requests
     - types-setuptools
     - "--editable ../mlos_core[full-tests]"
diff --git a/conda-envs/mlos-3.8.yml b/conda-envs/mlos-3.8.yml
index 89f8ef0897b..94c457746fd 100644
--- a/conda-envs/mlos-3.8.yml
+++ b/conda-envs/mlos-3.8.yml
@@ -32,6 +32,7 @@ dependencies:
     - types-colorama
     - types-jsonschema
     - types-pygments
+    - types-pytest-lazy-fixture
     - types-requests
     - types-setuptools
     - "--editable ../mlos_core[full-tests]"
diff --git a/conda-envs/mlos-3.9.yml b/conda-envs/mlos-3.9.yml
index 46cc07d4a13..aa36117214c 100644
--- a/conda-envs/mlos-3.9.yml
+++ b/conda-envs/mlos-3.9.yml
@@ -32,6 +32,7 @@ dependencies:
     - types-colorama
     - types-jsonschema
     - types-pygments
+    - types-pytest-lazy-fixture
     - types-requests
     - types-setuptools
     - "--editable ../mlos_core[full-tests]"
diff --git a/conda-envs/mlos-windows.yml b/conda-envs/mlos-windows.yml
index b3f284bf566..ce115a571f4 100644
--- a/conda-envs/mlos-windows.yml
+++ b/conda-envs/mlos-windows.yml
@@ -36,6 +36,7 @@ dependencies:
     - types-colorama
     - types-jsonschema
     - types-pygments
+    - types-pytest-lazy-fixture
     - types-requests
     - types-setuptools
     - "--editable ../mlos_core[full-tests]"
diff --git a/conda-envs/mlos.yml b/conda-envs/mlos.yml
index e5098695a36..8bfcef06ce2 100644
--- a/conda-envs/mlos.yml
+++ b/conda-envs/mlos.yml
@@ -30,6 +30,7 @@ dependencies:
     - types-colorama
     - types-jsonschema
     - types-pygments
+    - types-pytest-lazy-fixture
     - types-requests
     - types-setuptools
     - "--editable ../mlos_core[full-tests]"
diff --git a/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/common-defs-subschemas.json b/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/common-defs-subschemas.json
new file mode 100644
index 00000000000..e8919db4e3f
--- /dev/null
+++ b/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/common-defs-subschemas.json
@@ -0,0 +1,47 @@
+{
+    "$schema": "https://json-schema.org/draft/2020-12/schema",
+    "$id": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/common-defs-subschemas.json",
+    "title": "mlos_bench SSH Service common defs config",
+    "description": "mlos_bench SSH Service common defs config",
+    "$defs": {
+        "ssh_service_config": {
+            "description": "SSH Service config.",
+            "type": "object",
+            "properties": {
+                "ssh_request_timeout": {
+                    "description": "Request timeout in seconds, or null to disable timeout.",
+                    "type": ["null", "number"],
+                    "minimum": 1
+                },
+                "ssh_keepalive_interval": {
+                    "description": "Time interval in seconds between keepalive packets sent to the remote machine(s), or null to disable keepalives.",
+                    "type": ["null", "number"],
+                    "minimum": 1
+                },
+                "ssh_port": {
+                    "description": "Default port to use when connecting to the remote machine(s).",
+                    "type": "integer",
+                    "minimum": 1,
+                    "maximum": 65535,
+                    "examples": [22]
+                },
+                "ssh_username": {
+                    "description": "Default username to use when connecting to the remote machine(s).
If null or unspecified, will default to the value in ~/.ssh/config or the current user if not provided in another config.", + "type": ["null", "string"], + "examples": ["root"] + }, + "ssh_priv_key_path": { + "$comment": "TODO: add support for multiple private keys or dynamically fetched private keys.", + "description": "Optional path to the private key to use when connecting to the remote machine(s). Keys are automatically searched for if not specified. Note: these should be passwordless or else an ssh-agent should be available via the SSH_AUTH_SOCK environment variable.", + "type": "string", + "examples": ["~/.ssh/id_rsa", "~/.ssh/id_ecdsa"] + }, + "ssh_known_hosts_path": { + "description": "Path to known_hosts file. Set to null to disable host key checking.", + "type": ["null", "string"], + "examples": [null, "~/.ssh/known_hosts"] + } + } + } + } +} diff --git a/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-fileshare-service-subschema.json b/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-fileshare-service-subschema.json new file mode 100644 index 00000000000..df1a85a6daf --- /dev/null +++ b/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-fileshare-service-subschema.json @@ -0,0 +1,21 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-fileshare-service-subschema.json", + "title": "mlos_bench SSH Fileshare Service config", + "description": "config for an mlos_bench SSH Fileshare Service", + "type": "object", + "properties": { + "class": { + "enum": [ + "mlos_bench.services.remote.ssh.SshFileShareService", + "mlos_bench.services.remote.ssh.ssh_fileshare.SshFileShareService" + ] + }, + "config": { + "$ref": "./common-defs-subschemas.json#/$defs/ssh_service_config", + "minProperties": 1, + "unevaluatedProperties": false + } + }, + "required": ["class"] +} diff --git a/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-host-service-subschema.json b/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-host-service-subschema.json new file mode 100644 index 00000000000..78803c0c5e5 --- /dev/null +++ b/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-host-service-subschema.json @@ -0,0 +1,37 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/services/remote/ssh/ssh-host-service-subschema.json", + "title": "mlos_bench SSH Host Service config", + "description": "config for an mlos_bench SSH Host Service", + "type": "object", + "properties": { + "class": { + "enum": [ + "mlos_bench.services.remote.ssh.SshHostService", + "mlos_bench.services.remote.ssh.ssh_host_service.SshHostService" + ] + }, + "config": { + "allOf": [ + { + "$ref": "./common-defs-subschemas.json#/$defs/ssh_service_config" + }, + { + "type": "object", + "properties": { + "ssh_shell": { + "type": "string", + "examples": [ + "/bin/bash", + "/bin/sh" + ] + } + } + } + ], + "minProperties": 1, + "unevaluatedProperties": false + } + }, + "required": ["class"] +} diff --git a/mlos_bench/mlos_bench/config/schemas/services/service-schema.json b/mlos_bench/mlos_bench/config/schemas/services/service-schema.json index ce280670b1c..ecf55478e35 100644 --- a/mlos_bench/mlos_bench/config/schemas/services/service-schema.json +++ b/mlos_bench/mlos_bench/config/schemas/services/service-schema.json @@ -34,6 +34,12 @@ { "$ref": 
"./remote/mock/mock-auth-service-subschema.json" }, + { + "$ref": "./remote/ssh/ssh-host-service-subschema.json" + }, + { + "$ref": "./remote/ssh/ssh-fileshare-service-subschema.json" + }, { "$ref": "./remote/azure/azure-auth-service-subschema.json" }, diff --git a/mlos_bench/mlos_bench/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/services/remote/ssh/__init__.py index 87eac00d1ac..2ab1705a74d 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/__init__.py @@ -4,4 +4,10 @@ # """SSH remote service.""" -# TODO +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService +from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService + +__all__ = [ + "SshHostService", + "SshFileShareService", +] diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py new file mode 100644 index 00000000000..f478e528d27 --- /dev/null +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -0,0 +1,107 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +A collection functions for interacting with SSH servers as file shares. +""" + +from enum import Enum +from typing import Tuple, Union + +import logging + +from asyncssh import scp, SFTPError, SSHClientConnection + +from mlos_bench.services.base_fileshare import FileShareService +from mlos_bench.services.remote.ssh.ssh_service import SshService +from mlos_bench.util import merge_parameters + +_LOG = logging.getLogger(__name__) + + +class CopyMode(Enum): + """ + Copy mode enum. + """ + + DOWNLOAD = 1 + UPLOAD = 2 + + +class SshFileShareService(FileShareService, SshService): + """A collection of functions for interacting with SSH servers as file shares.""" + + async def _start_file_copy(self, params: dict, mode: CopyMode, + local_path: str, remote_path: str, + recursive: bool = True) -> None: + # pylint: disable=too-many-arguments + """ + Starts a file copy operation + + Parameters + ---------- + params : dict + Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). + mode : CopyMode + Whether to download or upload the file. + local_path : str + Local path to the file/dir. + remote_path : str + Remote path to the file/dir. + recursive : bool, optional + _description_, by default True + + Raises + ------ + OSError + If the local OS returns an error. + SFTPError + If the remote OS returns an error. 
+ """ + connection, _ = await self._get_client_connection(params) + srcpaths: Union[str, Tuple[SSHClientConnection, str]] + dstpath: Union[str, Tuple[SSHClientConnection, str]] + if mode == CopyMode.DOWNLOAD: + srcpaths = (connection, remote_path) + dstpath = local_path + elif mode == CopyMode.UPLOAD: + srcpaths = local_path + dstpath = (connection, remote_path) + else: + raise ValueError(f"Unknown copy mode: {mode}") + return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) + + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + params = merge_parameters( + dest=self.config.copy(), + source=params, + required_keys=[ + "ssh_hostname", + ] + ) + super().download(params, remote_path, local_path, recursive) + file_copy_future = self._run_coroutine( + self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive)) + try: + file_copy_future.result() + except (OSError, SFTPError) as ex: + _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex) + raise ex + + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + params = merge_parameters( + dest=self.config.copy(), + source=params, + required_keys=[ + "ssh_hostname", + ] + ) + super().upload(params, local_path, remote_path, recursive) + file_copy_future = self._run_coroutine( + self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)) + try: + file_copy_future.result() + except (OSError, SFTPError) as ex: + _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex) + raise ex diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py new file mode 100644 index 00000000000..f67e71b7225 --- /dev/null +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -0,0 +1,274 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +A collection Service functions for managing hosts via SSH. +""" + +from concurrent.futures import Future +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import logging + +from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError + +from mlos_bench.environments.status import Status +from mlos_bench.services.base_service import Service +from mlos_bench.services.remote.ssh.ssh_service import SshService +from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec +from mlos_bench.services.types.os_ops_type import SupportsOSOps +from mlos_bench.util import merge_parameters + +_LOG = logging.getLogger(__name__) + + +class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): + """ + Helper methods to manage machines via SSH. + """ + + # pylint: disable=too-many-instance-attributes + + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): + """ + Create a new instance of an SSH Service. + + Parameters + ---------- + config : dict + Free-format dictionary that contains the benchmark environment + configuration. + global_config : dict + Free-format dictionary of global parameters. + parent : Service + Parent service that can provide mixin functions. 
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py
new file mode 100644
index 00000000000..f67e71b7225
--- /dev/null
+++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py
@@ -0,0 +1,274 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""
+A collection of Service functions for managing hosts via SSH.
+"""
+
+from concurrent.futures import Future
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import logging
+
+from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError
+
+from mlos_bench.environments.status import Status
+from mlos_bench.services.base_service import Service
+from mlos_bench.services.remote.ssh.ssh_service import SshService
+from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
+from mlos_bench.services.types.os_ops_type import SupportsOSOps
+from mlos_bench.util import merge_parameters
+
+_LOG = logging.getLogger(__name__)
+
+
+class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
+    """
+    Helper methods to manage machines via SSH.
+    """
+
+    # pylint: disable=too-many-instance-attributes
+
+    def __init__(self,
+                 config: Optional[Dict[str, Any]] = None,
+                 global_config: Optional[Dict[str, Any]] = None,
+                 parent: Optional[Service] = None,
+                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
+        """
+        Create a new instance of an SSH Service.
+
+        Parameters
+        ----------
+        config : dict
+            Free-format dictionary that contains the benchmark environment
+            configuration.
+        global_config : dict
+            Free-format dictionary of global parameters.
+        parent : Service
+            Parent service that can provide mixin functions.
+        methods : Union[Dict[str, Callable], List[Callable], None]
+            New methods to register with the service.
+        """
+        super().__init__(
+            config, global_config, parent,
+            self.merge_methods(methods, [
+                self.shutdown,
+                self.reboot,
+                self.wait_os_operation,
+                self.remote_exec,
+                self.get_remote_exec_results,
+            ]))
+        self._shell = self.config.get("ssh_shell", "/bin/bash")
+
+    async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess:
+        """
+        Runs a command asynchronously on a host via SSH.
+
+        Parameters
+        ----------
+        params : dict
+            Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
+        script : Iterable[str]
+            Lines of the script to run via the shell.
+        env_params : dict
+            Environment variables to set before running the script.
+
+        Returns
+        -------
+        SSHCompletedProcess
+            Returns the result of the command.
+        """
+        if isinstance(script, str):
+            # Script should be an iterable of lines, not an iterable string.
+            script = [script]
+        connection, _ = await self._get_client_connection(params)
+        # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values.
+        # Handle transferring environment variables by making a script to set them.
+        env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()]
+        script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()]
+        # Note: connection.run() uses "exec" with a shell by default.
+        return await connection.run('\n'.join(script_lines),
+                                    check=False,
+                                    timeout=self._request_timeout,
+                                    env=env_params)
+
+    def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]:
+        """
+        Start running a command on remote host OS.
+
+        Parameters
+        ----------
+        script : Iterable[str]
+            A list of lines to execute as a script on a remote VM.
+        config : dict
+            Flat dictionary of (key, value) pairs of parameters.
+            They usually come from `const_args` and `tunable_params`
+            properties of the Environment.
+        env_params : dict
+            Parameters to pass as *shell* environment variables into the script.
+            This is usually a subset of `config` with some possible conversions.
+
+        Returns
+        -------
+        result : (Status, dict)
+            A pair of Status and result.
+            Status is one of {PENDING, SUCCEEDED, FAILED}
+        """
+        config = merge_parameters(
+            dest=self.config.copy(),
+            source=config,
+            required_keys=[
+                "ssh_hostname",
+            ]
+        )
+        config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params))
+        return (Status.PENDING, config)
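The two halves of the asynchronous API above fit together as follows (a usage sketch, assuming an `SshHostService` instance named `host_service` that has already been entered as a context, with a hypothetical host):

```python
from mlos_bench.environments.status import Status

params = {"ssh_hostname": "my-test-host", "ssh_username": "root"}   # hypothetical

# Kick off the script; the returned dict carries the Future under
# the "asyncRemoteExecResultsFuture" key.
(status, results_info) = host_service.remote_exec(
    script=["echo hello", "hostname"],
    config=params,
    env_params={"MY_VAR": "42"},
)
assert status == Status.PENDING

# Block on that Future (via get_remote_exec_results(), defined next) and collect the output.
(status, results) = host_service.get_remote_exec_results(results_info)
if status == Status.SUCCEEDED:
    print(results["stdout"])
```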
+
+    def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]:
+        """
+        Get the results of the asynchronously running command.
+
+        Parameters
+        ----------
+        config : dict
+            Flat dictionary of (key, value) pairs of tunable parameters.
+            Must have the "asyncRemoteExecResultsFuture" key to get the results.
+            If the key is not present, a ValueError is raised.
+
+        Returns
+        -------
+        result : (Status, dict)
+            A pair of Status and result.
+            Status is one of {SUCCEEDED, FAILED}
+        """
+        future = config.get("asyncRemoteExecResultsFuture")
+        if not future:
+            raise ValueError("Missing 'asyncRemoteExecResultsFuture'.")
+        assert isinstance(future, Future)
+        result = None
+        try:
+            result = future.result(timeout=self._request_timeout)
+            assert isinstance(result, SSHCompletedProcess)
+            stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
+            stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
+            return (
+                Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED,
+                {
+                    "stdout": stdout,
+                    "stderr": stderr,
+                    "ssh_completed_process_result": result,
+                },
+            )
+        except (ConnectionLost, DisconnectError, ProcessError, TimeoutError) as ex:
+            _LOG.error("Failed to get remote exec results: %s", ex)
+            return (Status.FAILED, {"result": result})
+
+    def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]:
+        """
+        Try a list of OS commands on the remote host until one succeeds.
+
+        Parameters
+        ----------
+        cmd_opts_list : List[str]
+            List of commands to try to execute.
+        params : dict
+            The params used to connect to the host.
+
+        Returns
+        -------
+        result : (Status, dict={})
+            A pair of Status and result.
+            Status is one of {PENDING, SUCCEEDED, FAILED}
+        """
+        config = merge_parameters(
+            dest=self.config.copy(),
+            source=params,
+            required_keys=[
+                "ssh_hostname",
+            ]
+        )
+        cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list])
+        script = rf"""
+            if [[ $EUID -ne 0 ]]; then
+                sudo=$(command -v sudo)
+                sudo=${{sudo:+$sudo -n}}
+            fi
+
+            for cmd in {cmd_opts}; do
+                $sudo $cmd && exit 0
+            done
+
+            echo 'ERROR: Failed to shutdown/reboot the system.'
+            exit 1
+        """
+        return self.remote_exec(script, config, env_params={})
+
+    def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
+        """
+        Initiates a (graceful) shutdown of the Host/VM OS.
+
+        Parameters
+        ----------
+        params: dict
+            Flat dictionary of (key, value) pairs of tunable parameters.
+        force : bool
+            If True, force stop the Host/VM.
+
+        Returns
+        -------
+        result : (Status, dict={})
+            A pair of Status and result.
+            Status is one of {PENDING, SUCCEEDED, FAILED}
+        """
+        cmd_opts_list = [
+            'shutdown -h now',
+            'poweroff',
+            'halt -p',
+            'systemctl poweroff',
+        ]
+        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
+
+    def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
+        """
+        Initiates a (graceful) reboot of the Host/VM OS.
+
+        Parameters
+        ----------
+        params: dict
+            Flat dictionary of (key, value) pairs of tunable parameters.
+        force : bool
+            If True, force restart the Host/VM.
+
+        Returns
+        -------
+        result : (Status, dict={})
+            A pair of Status and result.
+            Status is one of {PENDING, SUCCEEDED, FAILED}
+        """
+        cmd_opts_list = [
+            'shutdown -r now',
+            'reboot',
+            'halt --reboot',
+            'systemctl reboot',
+            'kill -9 1' if force else 'kill 1',
+        ]
+        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
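The OS-ops wrappers follow the same pending/wait pattern; a sketch of rebooting a host and waiting for the command to finish (same assumptions as above: an in-context `host_service` and a hypothetical host):

```python
from mlos_bench.environments.status import Status

params = {"ssh_hostname": "my-test-host"}   # hypothetical

(status, op_info) = host_service.reboot(params=params, force=False)
assert status == Status.PENDING

# wait_os_operation() (defined next) simply delegates to get_remote_exec_results();
# the connection may drop as the host goes down, which surfaces as FAILED.
(status, results) = host_service.wait_os_operation(op_info)
print(status, results.get("stderr"))
```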
+
+    def wait_os_operation(self, params: dict) -> Tuple[Status, dict]:
+        """
+        Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
+        Returns FAILED if the operation times out.
+
+        Parameters
+        ----------
+        params: dict
+            Flat dictionary of (key, value) pairs of tunable parameters.
+            Must have the "asyncRemoteExecResultsFuture" key to get the results.
+            If the key is not present, a ValueError is raised.
+
+        Returns
+        -------
+        result : (Status, dict)
+            A pair of Status and result.
+            Status is one of {SUCCEEDED, FAILED}
+            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
+        """
+        return self.get_remote_exec_results(params)
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py
new file mode 100644
index 00000000000..21db889bb59
--- /dev/null
+++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py
@@ -0,0 +1,368 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""
+A collection of functions and classes for managing SSH connections to remote hosts.
+"""
+
+from abc import ABCMeta
+from asyncio import Event as CoroEvent, Lock as CoroLock
+from concurrent.futures import Future
+from types import TracebackType
+from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
+from threading import current_thread
+
+import logging
+import os
+import sys
+
+import asyncssh
+
+from asyncssh.connection import SSHClientConnection
+
+from mlos_bench.services.base_service import Service
+from mlos_bench.event_loop_context import EventLoopContext
+
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
+
+
+_LOG = logging.getLogger(__name__)
+
+
+class SshClient(asyncssh.SSHClient):
+    """
+    Wrapper around SSHClient to help provide connection caching and reconnect logic.
+
+    Used by the SshService to try and maintain a single connection to hosts,
+    handle reconnects if possible, and use that to run commands rather than
+    reconnect for each command.
+    """
+
+    _CONNECTION_PENDING = 'INIT'
+    _CONNECTION_LOST = 'LOST'
+
+    def __init__(self, *args: tuple, **kwargs: dict):
+        self._connection_id: str = SshClient._CONNECTION_PENDING
+        self._connection: Optional[SSHClientConnection] = None
+        self._conn_event: CoroEvent = CoroEvent()
+        super().__init__(*args, **kwargs)
+
+    def __repr__(self) -> str:
+        return self._connection_id
+
+    @staticmethod
+    def id_from_connection(connection: SSHClientConnection) -> str:
+        """Gets a unique id repr for the connection."""
+        return f"{connection._username}@{connection._host}:{connection._port}"  # pylint: disable=protected-access
+
+    @staticmethod
+    def id_from_params(connect_params: dict) -> str:
+        """Gets a unique id repr for the connection."""
+        return f"{connect_params.get('username')}@{connect_params['host']}:{connect_params.get('port')}"
+
+    def connection_made(self, conn: SSHClientConnection) -> None:
+        """
+        Override hook provided by asyncssh.SSHClient.
+
+        Changes the connection_id from _CONNECTION_PENDING to a unique id repr.
+        """
+        self._conn_event.clear()
+        _LOG.debug("%s: Connection made by %s: %s",
+                   current_thread().name, conn._options.env, conn)  # pylint: disable=protected-access
+        self._connection_id = SshClient.id_from_connection(conn)
+        self._connection = conn
+        self._conn_event.set()
+        return super().connection_made(conn)
+
+    def connection_lost(self, exc: Optional[Exception]) -> None:
+        self._conn_event.clear()
+        _LOG.debug("%s: %s", current_thread().name, "connection_lost")
+        if exc is None:
+            _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc)
+        else:
+            _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc)
+        self._connection_id = SshClient._CONNECTION_LOST
+        self._connection = None
+        self._conn_event.set()
+        return super().connection_lost(exc)
+
+    async def connection(self) -> Optional[SSHClientConnection]:
+        """
+        Waits for the connection to be established or lost and returns it
+        (or None if the connection was lost).
+        """
+        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
+        await self._conn_event.wait()
+        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
+        return self._connection
+
+
+class SshClientCache:
+    """
+    Manages a cache of SshClient connections.
+    Note: Only one per event loop thread is supported.
+    See additional details in SshService comments.
+    """
+
+    def __init__(self) -> None:
+        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
+        self._cache_lock = CoroLock()
+        self._refcnt: int = 0
+
+    def __str__(self) -> str:
+        return str(self._cache)
+
+    def __len__(self) -> int:
+        return len(self._cache)
+
+    def enter(self) -> None:
+        """
+        Manages the cache lifecycle with reference counting.
+        To be used in the __enter__ method of a caller's context manager.
+        """
+        self._refcnt += 1
+
+    def exit(self) -> None:
+        """
+        Manages the cache lifecycle with reference counting.
+        To be used in the __exit__ method of a caller's context manager.
+        """
+        self._refcnt -= 1
+        if self._refcnt <= 0:
+            self.cleanup()
+
+    async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]:
+        """
+        Gets a (possibly cached) client connection.
+
+        Parameters
+        ----------
+        connect_params: dict
+            Parameters to pass to asyncssh.create_connection.
+
+        Returns
+        -------
+        Tuple[SSHClientConnection, SshClient]
+            A tuple of (SSHClientConnection, SshClient).
+        """
+        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
+        async with self._cache_lock:
+            connection_id = SshClient.id_from_params(connect_params)
+            client: Union[None, SshClient, asyncssh.SSHClient]
+            _, client = self._cache.get(connection_id, (None, None))
+            if client:
+                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
+                connection = await client.connection()
+                if not connection:
+                    _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id)
+                    self._cache.pop(connection_id)
+                    # Try to reconnect next.
+                else:
+                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
+            if connection_id not in self._cache:
+                _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id)
+                connection, client = await asyncssh.create_connection(SshClient, **connect_params)
+                assert isinstance(client, SshClient)
+                self._cache[connection_id] = (connection, client)
+                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
+            return self._cache[connection_id]
+
+    def cleanup(self) -> None:
+        """
+        Closes all cached connections.
+        """
+        for (connection, _) in self._cache.values():
+            connection.close()
+        self._cache = {}
+
+
+class SshService(Service, metaclass=ABCMeta):
+    """
+    Base class for SSH services.
+    """
+
+    # AsyncSSH requires an asyncio event loop to be running to work.
+    # However, running that event loop blocks the main thread.
+    # To avoid having to change our entire API to use async/await, all the way
+    # up the stack, we run the event loop that runs any async code in a
+    # background thread and submit async code to it using
+    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
+    # This is a bit of a hack, but it works for now.
+    #
+    # The event loop is created on demand and shared across all SshService
+    # instances, hence we need to lock it when doing the creation/cleanup,
+    # or later, during context enter and exit.
+    #
+    # We ran tests to ensure that multiple requests can still be executing
+    # concurrently inside that event loop so there should be no practical
+    # performance loss for our initial cases even with just a single background
+    # thread running the event loop.
+    #
+    # Note: the tests were run to confirm that this works with two threads.
+    # Using a larger thread pool requires a bit more work since asyncssh
+    # requires that run() requests are submitted to the same event loop handler
+    # that the connection was made on.
+    # In that case, each background thread should get its own SshClientCache.
+
+    # Maintain just one event loop thread for all SshService instances.
+    # But only keep it running while they are within a context.
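The threading model described in the comment block above boils down to `asyncio.run_coroutine_threadsafe` against an event loop owned by a background thread. A standalone sketch of that mechanism (illustrative; the real plumbing lives in `EventLoopContext`):

```python
import asyncio
import threading

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()


async def greet(name: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for real async I/O (e.g., an ssh command)
    return f"hello {name}"


# Submitted from the synchronous caller's thread; returns a concurrent.futures.Future.
future = asyncio.run_coroutine_threadsafe(greet("mlos"), loop)
print(future.result(timeout=5))

loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=5)
```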
+    _EVENT_LOOP_CONTEXT = EventLoopContext()
+    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()
+
+    _REQUEST_TIMEOUT: Optional[float] = None    # seconds
+
+    def __init__(self,
+                 config: Optional[Dict[str, Any]] = None,
+                 global_config: Optional[Dict[str, Any]] = None,
+                 parent: Optional[Service] = None,
+                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
+        super().__init__(config, global_config, parent, methods)
+
+        # Make sure that the values we allow overriding on a per-connection
+        # basis are present in the config so merge_parameters can do its thing.
+        self.config.setdefault('ssh_port', None)
+        assert isinstance(self.config['ssh_port'], (int, type(None)))
+        self.config.setdefault('ssh_username', None)
+        assert isinstance(self.config['ssh_username'], (str, type(None)))
+        self.config.setdefault('ssh_priv_key_path', None)
+        assert isinstance(self.config['ssh_priv_key_path'], (str, type(None)))
+
+        # None can be used to disable the request timeout.
+        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
+        self._request_timeout = float(self._request_timeout) if self._request_timeout is not None else None
+
+        # Prep an initial connect_params.
+        self._connect_params: dict = {
+            # In general scripted commands shouldn't need a pty and having one
+            # available can confuse some commands, though we may need to make
+            # this configurable in the future.
+            'request_pty': False,
+            # By default disable known_hosts checking (since most VMs are expected to be dynamically created).
+            'known_hosts': None,
+        }
+
+        if 'ssh_known_hosts_path' in self.config:
+            self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_path", None)
+            if isinstance(self._connect_params['known_hosts'], str):
+                known_hosts_file = os.path.expanduser(self._connect_params['known_hosts'])
+                if not os.path.exists(known_hosts_file):
+                    raise ValueError(f"ssh_known_hosts_path {known_hosts_file} does not exist")
+                self._connect_params['known_hosts'] = known_hosts_file
+        if self._connect_params['known_hosts'] is None:
+            _LOG.info("%s known_hosts checking is disabled per config.", self)
+
+        if 'ssh_keepalive_interval' in self.config:
+            keepalive_interval = self.config.get('ssh_keepalive_interval')
+            self._connect_params['keepalive_interval'] = None if keepalive_interval is None else int(keepalive_interval)
+
+    def _enter_context(self) -> "SshService":
+        # Start the background thread if it's not already running.
+        assert not self._in_context
+        SshService._EVENT_LOOP_CONTEXT.enter()
+        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
+        super()._enter_context()
+        return self
+
+    def _exit_context(self, ex_type: Optional[Type[BaseException]],
+                      ex_val: Optional[BaseException],
+                      ex_tb: Optional[TracebackType]) -> Literal[False]:
+        # Stop the background thread if it's not needed anymore and potentially
+        # cleanup the cache as well.
+        assert self._in_context
+        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
+        SshService._EVENT_LOOP_CONTEXT.exit()
+        return super()._exit_context(ex_type, ex_val, ex_tb)
+
+    @classmethod
+    def clear_client_cache(cls) -> None:
+        """
+        Clears the cache of client connections.
+        Note: This may cause in-flight operations to fail.
+        """
+        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()
+
+    CoroReturnType = TypeVar('CoroReturnType')
+    if sys.version_info >= (3, 9):
+        FutureReturnType: TypeAlias = Future[CoroReturnType]
+    else:
+        FutureReturnType: TypeAlias = Future
+
+    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
+        """
+        Runs the given coroutine in the background event loop thread.
+
+        Parameters
+        ----------
+        coro : Coroutine[Any, Any, CoroReturnType]
+            The coroutine to run.
+
+        Returns
+        -------
+        Future[CoroReturnType]
+            A future that will be completed when the coroutine completes.
+        """
+        assert self._in_context
+        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)
+
+    def _get_connect_params(self, params: dict) -> dict:
+        """
+        Produces a dict of connection parameters for asyncssh.create_connection.
+
+        Parameters
+        ----------
+        params : dict
+            Additional connection parameters specific to this host.
+
+        Returns
+        -------
+        dict
+            A dict of connection parameters for asyncssh.create_connection.
+        """
+        # Setup default connect_params dict for all SshClients we might need to create.
+
+        # Note: None is an acceptable value for several of these, in which case
+        # reasonable defaults or values from ~/.ssh/config will take effect.
+
+        # Start with the base config params.
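+        # For illustration, with a service config of {"ssh_username": "root", "ssh_port": 22}
+        # and per-host params of {"ssh_hostname": "my-vm", "ssh_priv_key_path": "~/.ssh/id_rsa"}
+        # (hypothetical values), the merged result would look roughly like:
+        #   {'request_pty': False, 'known_hosts': None, 'host': 'my-vm', 'port': 22,
+        #    'username': 'root', 'client_keys': ['/home/user/.ssh/id_rsa']}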
+ connect_params = self._connect_params.copy() + + connect_params['host'] = params['ssh_hostname'] # required + + if params.get('ssh_port'): + connect_params['port'] = int(params.pop('ssh_port')) + elif self.config['ssh_port']: + connect_params['port'] = int(self.config['ssh_port']) + + if 'ssh_username' in params: + connect_params['username'] = str(params.pop('ssh_username')) + elif self.config['ssh_username']: + connect_params['username'] = str(self.config['ssh_username']) + + priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path']) + if priv_key_file: + priv_key_file = os.path.expanduser(priv_key_file) + if not os.path.exists(priv_key_file): + raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist") + connect_params['client_keys'] = [priv_key_file] + + return connect_params + + async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]: + """ + Gets a (possibly cached) SshClient (connection) for the given connection params. + + Parameters + ---------- + params : dict + Optional override connection parameters. + + Returns + ------- + Tuple[SSHClientConnection, SshClient] + The connection and client objects. + """ + assert self._in_context + return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params)) diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index 4bc71d810b0..eef6b63e13d 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -7,11 +7,36 @@ Used to make mypy happy about multiple conftest.py modules. """ +from logging import debug, warning +from subprocess import run from typing import Optional +import filecmp +import os +import socket +import shutil + +import pytest + from mlos_bench.util import get_class_from_name +# A decorator for tests that require docker. +# Use with @requires_docker above a test_...() function. +DOCKER = shutil.which('docker') +if DOCKER: + cmd = run("docker builder inspect default", shell=True, check=False, capture_output=True) + stdout = cmd.stdout.decode() + if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line): + debug("Docker is available but missing support for targeting linux platform.") + DOCKER = None +requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.') + +# A decorator for tests that require ssh. +# Use with @requires_ssh above a test_...() function. +SSH = shutil.which('ssh') +requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.') + # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. SEED = 42 @@ -39,3 +64,70 @@ def check_class_name(obj: object, expected_class_name: str) -> bool: """ full_class_name = obj.__class__.__module__ + "." + obj.__class__.__name__ return full_class_name == try_resolve_class_name(expected_class_name) + + +def check_socket(host: str, port: int, timeout: float = 1.0) -> bool: + """ + Test to see if a socket is open. 
+ + Parameters + ---------- + host : str + port : int + timeout: float + + Returns + ------- + bool + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(timeout) # seconds + result = sock.connect_ex((host, port)) + return result == 0 + + +def resolve_host_name(host: str) -> Optional[str]: + """ + Resolves the host name to an IP address. + + Parameters + ---------- + host : str + + Returns + ------- + str + """ + try: + return socket.gethostbyname(host) + except socket.gaierror: + return None + + +def are_dir_trees_equal(dir1: str, dir2: str) -> bool: + """ + Compare two directories recursively. Files in each directory are + assumed to be equal if their names and contents are equal. + + @param dir1: First directory path + @param dir2: Second directory path + + @return: True if the directory trees are the same and + there were no errors while accessing the directories or files, + False otherwise. + """ + # See Also: https://stackoverflow.com/a/6681395 + dirs_cmp = filecmp.dircmp(dir1, dir2) + if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0: + warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}") + return False + (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) + if len(mismatch) > 0 or len(errors) > 0: + warning(f"Found differences in files:\n{mismatch}\n{errors}") + return False + for common_dir in dirs_cmp.common_dirs: + new_dir1 = os.path.join(dir1, common_dir) + new_dir2 = os.path.join(dir2, common_dir) + if not are_dir_trees_equal(new_dir1, new_dir2): + return False + return True diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index 48377833e61..75b2ae0cbee 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -8,7 +8,6 @@ from typing import Callable, List, Optional -from glob import iglob import os import sys diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index 0533046a2cc..2e836f3d8a0 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -54,6 +54,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" + # pylint: disable=too-complex config = config_loader_service.load_config(config_path, ConfigSchema.CLI) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_fileshare_service-bad-port-type.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_fileshare_service-bad-port-type.jsonc new file mode 100644 index 00000000000..2b0485bde53 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_fileshare_service-bad-port-type.jsonc @@ -0,0 +1,7 @@ +{ + "class": "mlos_bench.services.remote.ssh.SshFileShareService", + "config": { + "ssh_username": "someuser", + "ssh_port": null // bad port type + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_host_service-bad-known_hosts-type.jsonc 
b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_host_service-bad-known_hosts-type.jsonc new file mode 100644 index 00000000000..0ee7df52239 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_host_service-bad-known_hosts-type.jsonc @@ -0,0 +1,11 @@ +{ + "class": "mlos_bench.services.remote.ssh.SshHostService", + "config": { + "ssh_username": "someuser", + "ssh_known_hosts_path": [ + "array", + "not", + "supported" + ] + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_host_service-bad_request_timeout.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_host_service-bad_request_timeout.jsonc new file mode 100644 index 00000000000..c32e3337794 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/invalid/ssh_host_service-bad_request_timeout.jsonc @@ -0,0 +1,7 @@ +{ + "class": "mlos_bench.services.remote.ssh.SshHostService", + "config": { + "ssh_username": "someuser", + "ssh_request_timeout": 0.0 + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/unhandled/ssh_fileshare_service-extras.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/unhandled/ssh_fileshare_service-extras.jsonc new file mode 100644 index 00000000000..287ce0bb52b --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/unhandled/ssh_fileshare_service-extras.jsonc @@ -0,0 +1,7 @@ +{ + "class": "mlos_bench.services.remote.ssh.SshFileShareService", + "config": { + "ssh_username": "someuser", + "extra": "invalid" + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/unhandled/ssh_host_service-extras.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/unhandled/ssh_host_service-extras.jsonc new file mode 100644 index 00000000000..77419c19b09 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/bad/unhandled/ssh_host_service-extras.jsonc @@ -0,0 +1,7 @@ +{ + "class": "mlos_bench.services.remote.ssh.SshHostService", + "config": { + "ssh_username": "someuser", + "extra": "invalid" + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/full/ssh_fileshare_service-full.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/full/ssh_fileshare_service-full.jsonc new file mode 100644 index 00000000000..34376397a00 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/full/ssh_fileshare_service-full.jsonc @@ -0,0 +1,13 @@ +{ + // "$schema": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/services/service-schema.json", + "description": "SSH Host Service configuration.", + + "class": "mlos_bench.services.remote.ssh.ssh_fileshare.SshFileShareService", + "config": { + "ssh_username": "someuser", + "ssh_port": 22, + "ssh_priv_key_path": "~/.ssh/id_rsa", + "ssh_known_hosts_path": "~/.ssh/known_hosts", + "ssh_request_timeout": 90 + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/full/ssh_host_service-full.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/full/ssh_host_service-full.jsonc new file mode 100644 index 00000000000..909d76e6a4b --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/full/ssh_host_service-full.jsonc @@ -0,0 +1,15 @@ +{ + // "$schema": 
"https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/services/service-schema.json", + "description": "SSH Host Service configuration.", + + "class": "mlos_bench.services.remote.ssh.ssh_host_service.SshHostService", + "config": { + "ssh_username": "someuser", + "ssh_port": 2222, + "ssh_priv_key_path": "~/.ssh/id_rsa", + "ssh_known_hosts_path": null, + "ssh_request_timeout": 90, + "ssh_keepalive_interval": 10, + "ssh_shell": "/bin/bash" + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/partial/ssh_fileshare_service-partial.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/partial/ssh_fileshare_service-partial.jsonc new file mode 100644 index 00000000000..6d7b61fc2b4 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/partial/ssh_fileshare_service-partial.jsonc @@ -0,0 +1,3 @@ +{ + "class": "mlos_bench.services.remote.ssh.SshFileShareService" +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/partial/ssh_host_service-partial.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/partial/ssh_host_service-partial.jsonc new file mode 100644 index 00000000000..8bd71c54d4f --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test-cases/good/partial/ssh_host_service-partial.jsonc @@ -0,0 +1,8 @@ +{ + "description": "SSH Host Service configuration.", + + "class": "mlos_bench.services.remote.ssh.SshHostService", + "config": { + "ssh_username": "someuser" + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 67fdcbdfbd6..1dbe256cbd3 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -17,6 +17,7 @@ from mlos_bench.services.base_service import Service from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.local.temp_dir_context import TempDirContextService +from mlos_bench.services.remote.ssh.ssh_service import SshService from mlos_bench.tests import try_resolve_class_name from mlos_bench.tests.config.schemas import (get_schema_test_cases, @@ -37,6 +38,7 @@ NON_CONFIG_SERVICE_CLASSES = { ConfigPersistenceService, # configured thru the launcher cli args TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. + SshService, # ABCMeta abstract base class } expected_service_class_names = [subclass.__module__ + "." + subclass.__name__ diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index b9b2f117254..5a15cf623e6 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -6,7 +6,9 @@ Common fixtures for mock TunableGroups and Environment objects. """ -from typing import Any, Dict +from typing import Any, Dict, List + +import os import json5 as json import pytest @@ -139,3 +141,40 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: }, tunables=tunable_groups ) + + +# Fixtures to configure the pytest-docker plugin. + + +@pytest.fixture(scope="session") +def docker_compose_file(pytestconfig: pytest.Config) -> List[str]: + """ + Returns the path to the docker-compose file. 
+ + Parameters + ---------- + pytestconfig : pytest.Config + + Returns + ------- + str + Path to the docker-compose file. + """ + _ = pytestconfig # unused + return [ + os.path.join(os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml"), + # Add additional configs as necessary here. + ] + + +@pytest.fixture(scope="session") +def docker_compose_project_name() -> str: + """ + Returns the name of the docker-compose project. + + Returns + ------- + str + Name of the docker-compose project. + """ + return f"mlos_bench-test-{os.getpid()}" diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index da989819d5f..c534efd7189 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -22,6 +22,7 @@ class EventLoopContextCaller: """ Simple class to test the EventLoopContext. + See Also: SshService """ EVENT_LOOP_CONTEXT = EventLoopContext() diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/.gitignore b/mlos_bench/mlos_bench/tests/services/remote/ssh/.gitignore new file mode 100644 index 00000000000..690bae05069 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/.gitignore @@ -0,0 +1 @@ +id_rsa diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/Dockerfile b/mlos_bench/mlos_bench/tests/services/remote/ssh/Dockerfile new file mode 100644 index 00000000000..e8becec9f46 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/Dockerfile @@ -0,0 +1,16 @@ +# Basic Dockerfile for testing the SSH service. +FROM debian:latest + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + openssh-server openssh-client \ + sudo +ARG PORT=2254 +EXPOSE ${PORT} +RUN echo "Port ${PORT}" > /etc/ssh/sshd_config.d/local.conf \ + && echo "PermitRootLogin prohibit-password" >> /etc/ssh/sshd_config.d/local.conf \ + && ssh-keygen -t rsa -N '' -f /root/.ssh/id_rsa \ + && cat /root/.ssh/id_rsa.pub > /root/.ssh/authorized_keys + +ENV TIMEOUT=180 +CMD ["/bin/bash", "-eux", "-c", "trap : TERM INT; service ssh start && sleep ${TIMEOUT:-180} & wait"] diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md b/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md new file mode 100644 index 00000000000..3c74e77be37 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md @@ -0,0 +1,28 @@ +# SshServices Testing + +The "unit" tests for the `SshService` classes are more functional than other unit tests in that we don't merely mock them out, but actually setup small SSH servers with `docker compose` and interact with them using the `SshHostService` and `SshFileShareService`. + +To do this, we make use of the `pytest-docker` plugin to bring up the services defined in the [`docker-compose.yml`](./docker-compose.yml) file in this directory. + +There are two services defined in that config: + +1. `ssh-server` +2. `alt-server` + +We rely on `docker compose` to map their internal container service ports to random ports on the host. +Hence, when connecting, we need to look up these ports on demand using something akin to `docker compose port`. 
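For example, the lookup can be scripted roughly as follows (a sketch mirroring what `SshTestServerInfo.get_port()` below does; the project and service names are illustrative):

```python
import subprocess


def lookup_host_port(project: str, service: str, container_port: int) -> int:
    """Ask docker compose which host port a container port was mapped to."""
    out = subprocess.run(
        ["docker", "compose", "-p", project, "port", service, str(container_port)],
        check=True, capture_output=True, text=True,
    ).stdout.strip()
    return int(out.rsplit(":", 1)[1])   # e.g., "0.0.0.0:49154" -> 49154


print(lookup_host_port("mlos_bench-test-manual", "ssh-server", 2254))
```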
+Because of complexities of networking in different development environments (especially for Docker on WSL2 for Windows), we may also have to connect to a different host address than `localhost` (e.g., `host.docker.internal`, which is dynamically requested as a part of the [devcontainer](../../../../../../.devcontainer/docker-compose.yml) setup).
+
+Both containers run the same image, which is dynamically built from the [`Dockerfile`](./Dockerfile).
+This will dynamically generate a passphrase-less ssh key (`id_rsa`) stored inside the image that can be `docker cp`-ed out and then used to authenticate `ssh` clients into that instance.
+
+These are brought up as session fixtures under a unique (PID-based) compose project name for each `pytest` invocation, but only when docker is detected on the host (via the `@requires_docker` decorator we define in [`mlos_bench/tests/__init__.py`](../../../__init__.py)), else those tests are skipped.
+
+> For manual testing, to bring up/down the test infrastructure the [`up.sh`](./up.sh) and [`down.sh`](./down.sh) scripts can be used, which assign a known project name.
+
+In the case of `pytest`, since the `SshService` base class implements a shared connection cache that we wish to test, and testing "rebooting" of servers (containers) is also necessary, tests are run serially across a single worker by using the `pytest-xdist` plugin's `--dist loadgroup` feature and the `@pytest.mark.xdist_group("ssh_test_server")` decorator.
+In some cases we explicitly call the Python garbage collector via `gc.collect()` to make sure that the shared cache cleanup handler is operating as expected.
+
+## See Also
+
+Notes in the [`SshService`](../../../../services/remote/ssh/ssh_service.py) implementation.
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py
new file mode 100644
index 00000000000..22ffc65628f
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py
@@ -0,0 +1,77 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""
+Common data classes for the SSH service tests.
+"""
+
+from dataclasses import dataclass
+from subprocess import run
+from typing import Optional
+
+from pytest_docker.plugin import Services as DockerServices
+
+from mlos_bench.tests import check_socket
+
+
+# The SSH test server port and name.
+# See Also: docker-compose.yml
+SSH_TEST_SERVER_PORT = 2254
+SSH_TEST_SERVER_NAME = 'ssh-server'
+ALT_TEST_SERVER_NAME = 'alt-server'
+
+
+@dataclass
+class SshTestServerInfo:
+    """
+    A data class for SshTestServerInfo.
+    """
+
+    compose_project_name: str
+    service_name: str
+    hostname: str
+    username: str
+    id_rsa_path: str
+    _port: Optional[int] = None
+
+    def get_port(self, uncached: bool = False) -> int:
+        """
+        Gets the port that the SSH test server is listening on.
+
+        Note: this value can change when the service restarts, so we can't rely on the DockerServices.
+ """ + if self._port is None or uncached: + port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", + shell=True, check=True, capture_output=True) + self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) + return self._port + + def to_ssh_service_config(self, uncached: bool = False) -> dict: + """Convert to a config dict for SshService.""" + return { + "ssh_hostname": self.hostname, + "ssh_port": self.get_port(uncached), + "ssh_username": self.username, + "ssh_priv_key_path": self.id_rsa_path, + } + + def to_connect_params(self, uncached: bool = False) -> dict: + """ + Convert to a connect_params dict for SshClient. + See Also: mlos_bench.services.remote.ssh.ssh_service.SshService._get_connect_params() + """ + return { + "host": self.hostname, + "port": self.get_port(uncached), + "username": self.username, + } + + +def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None: + """Wait until a docker service is ready.""" + docker_services.wait_until_responsive( + check=lambda: check_socket(hostname, port), + timeout=30.0, + pause=0.5, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py new file mode 100644 index 00000000000..26349d50342 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py @@ -0,0 +1,123 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Fixtures for the SSH service tests. +""" + +from typing import Generator +from subprocess import run + +import os +import sys +import tempfile + +import pytest +from pytest_docker.plugin import Services as DockerServices + +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService +from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService + +from mlos_bench.tests import resolve_host_name +from mlos_bench.tests.services.remote.ssh import (SshTestServerInfo, + ALT_TEST_SERVER_NAME, + SSH_TEST_SERVER_NAME, + wait_docker_service_socket) + +# pylint: disable=redefined-outer-name + +HOST_DOCKER_NAME = 'host.docker.internal' + + +@pytest.fixture(scope="session") +def ssh_test_server_hostname() -> str: + """Returns the local hostname to use to connect to the test ssh server.""" + if sys.platform == 'win32': + # Docker (Desktop) for Windows (WSL2) uses a special networking magic + # to refer to the host machine when exposing ports. + return 'localhost' + # On Linux, if we're running in a docker container, we can use the + # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. + if resolve_host_name(HOST_DOCKER_NAME): + return HOST_DOCKER_NAME + # Otherwise, assume we're executing directly inside conda on the host. + return 'localhost' + + +@pytest.fixture(scope="session") +def ssh_test_server(ssh_test_server_hostname: str, + docker_compose_project_name: str, + docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: + """ + Fixture for getting the ssh test server services setup via docker-compose + using pytest-docker. + + Yields the (hostname, port, username, id_rsa_path) of the test server. + + Once the session is over, the docker containers are torn down, and the + temporary file holding the dynamically generated private key of the test + server is deleted. + """ + # Get a copy of the ssh id_rsa key from the test ssh server. 
+    with tempfile.NamedTemporaryFile() as id_rsa_file:
+        ssh_test_server_info = SshTestServerInfo(
+            compose_project_name=docker_compose_project_name,
+            service_name=SSH_TEST_SERVER_NAME,
+            hostname=ssh_test_server_hostname,
+            username='root',
+            id_rsa_path=id_rsa_file.name)
+        wait_docker_service_socket(docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port())
+        id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa"
+        docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}"
+        # Note: don't pass check=True here so that we can raise a more descriptive error below.
+        cmd = run(docker_cp_cmd.split(), cwd=os.path.dirname(__file__), capture_output=True, text=True)
+        if cmd.returncode != 0:
+            raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container: {str(cmd.stderr)}")
+        os.chmod(id_rsa_file.name, 0o600)
+        yield ssh_test_server_info
+        # The NamedTemporaryFile is deleted on context exit.
+
+
+@pytest.fixture(scope="session")
+def alt_test_server(ssh_test_server: SshTestServerInfo,
+                    docker_services: DockerServices) -> SshTestServerInfo:
+    """
+    Fixture for getting the second ssh test server info from the docker-compose.yml.
+    See additional notes in the ssh_test_server fixture above.
+    """
+    # Note: The alt-server uses the same image as the ssh-server container, so
+    # the id_rsa key and username should match.
+    # Only the host port allocated to it is different.
+    alt_test_server_info = SshTestServerInfo(
+        compose_project_name=ssh_test_server.compose_project_name,
+        service_name=ALT_TEST_SERVER_NAME,
+        hostname=ssh_test_server.hostname,
+        username=ssh_test_server.username,
+        id_rsa_path=ssh_test_server.id_rsa_path)
+    wait_docker_service_socket(docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port())
+    return alt_test_server_info
+
+
+@pytest.fixture
+def ssh_host_service(ssh_test_server: SshTestServerInfo) -> SshHostService:
+    """Generic SshHostService fixture."""
+    return SshHostService(
+        config={
+            "ssh_username": ssh_test_server.username,
+            "ssh_priv_key_path": ssh_test_server.id_rsa_path,
+        },
+        global_config={},
+        parent=None,
+    )
+
+
+@pytest.fixture
+def ssh_fileshare_service() -> SshFileShareService:
+    """Generic SshFileShareService fixture."""
+    return SshFileShareService(
+        config={
+            # Left blank to make sure we test per connection overrides.
+        },
+        global_config={},
+        parent=None,
+    )
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/docker-compose.yml b/mlos_bench/mlos_bench/tests/services/remote/ssh/docker-compose.yml
new file mode 100644
index 00000000000..d4f84fb5266
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/docker-compose.yml
@@ -0,0 +1,50 @@
+name: mlos_bench-test-ssh-server
+services:
+  ssh-server:
+    hostname: ssh-server
+    attach: false
+    build:
+      context: .
+      dockerfile: Dockerfile
+      args:
+        - PORT=${PORT:-2254}
+      tags:
+        - mlos_bench-test-ssh-server:latest
+    image: mlos_bench-test-ssh-server:latest
+    ports:
+      # To allow multiple instances of this to coexist, instead of explicitly
+      # mapping the port, let it get assigned randomly on the host.
+      #- ${PORT:-2254}:${PORT:-2254}
+      - ${PORT:-2254}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    environment:
+      # Let the environment variable TIMEOUT override the default.
+      - TIMEOUT=${TIMEOUT:-180}
+  # Also start a second server for testing multiple instances.
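+  # The alt-server below reuses the mlos_bench-test-ssh-server:latest image
+  # built by the ssh-server service above (hence the depends_on), so the
+  # baked-in ssh key and root user are identical.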
+  alt-server:
+    depends_on:
+      - ssh-server
+    hostname: alt-server
+    attach: false
+    restart: always
+    image: mlos_bench-test-ssh-server:latest
+    ports:
+      - ${PORT:-2254}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    environment:
+      # Let the environment variable TIMEOUT override the default.
+      - TIMEOUT=${TIMEOUT:-180}
+  # Check that we can connect to the server from the client.
+  ssh-client:
+    depends_on:
+      - ssh-server
+    # Quoted so YAML doesn't parse it as the boolean false.
+    restart: "no"
+    image: mlos_bench-test-ssh-server:latest
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    # Implicitly uses root and the key in that image.
+    #command: ["ssh", "-o", "StrictHostKeyChecking=accept-new", "-p", "${PORT:-2254}", "host.docker.internal", "hostname"]
+    command: ["ssh", "-o", "StrictHostKeyChecking=accept-new", "-p", "${PORT:-2254}", "ssh-server", "hostname"]
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/down.sh b/mlos_bench/mlos_bench/tests/services/remote/ssh/down.sh
new file mode 100755
index 00000000000..a8a8b85cfdb
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/down.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+##
+## Copyright (c) Microsoft Corporation.
+## Licensed under the MIT License.
+##
+
+# A script to stop the SSH server in a container and remove the SSH keys.
+# For pytest, the fixture in conftest.py will handle this for us using the
+# pytest-docker plugin, but for manual testing, this script can be used.
+
+set -eu
+
+scriptdir=$(dirname "$(readlink -f "$0")")
+cd "$scriptdir"
+
+PROJECT_NAME="mlos_bench-test-manual"
+
+docker compose -p "$PROJECT_NAME" down
+rm -f ./id_rsa
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py
new file mode 100644
index 00000000000..f6a836601f4
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py
@@ -0,0 +1,204 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""
+Tests for mlos_bench.services.remote.ssh.ssh_fileshare
+"""
+
+from contextlib import contextmanager
+from os.path import basename
+from pathlib import Path
+from tempfile import _TemporaryFileWrapper     # pylint: disable=import-private-name
+from typing import Any, Dict, Generator, List
+
+import os
+import tempfile
+
+import pytest
+
+from asyncssh import SFTPError
+
+from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
+from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
+from mlos_bench.util import path_join
+
+from mlos_bench.tests import are_dir_trees_equal, requires_docker
+from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
+
+
+@contextmanager
+def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
+    """
+    Provides a context manager for a temporary file that can be closed and
+    still unlinked.
+
+    Since Windows doesn't allow us to reopen the file while it's still open, we
+    need to handle deletion ourselves separately.
+
+    Parameters
+    ----------
+    kwargs: dict
+        Args to pass to the NamedTemporaryFile constructor.
+ + Returns + ------- + context manager for a temporary file + """ + fname = None + try: + with tempfile.NamedTemporaryFile(delete=False, **kwargs) as temp_file: + fname = temp_file.name + yield temp_file + finally: + if fname: + os.unlink(fname) + + +@pytest.mark.xdist_group("ssh_test_server") +@requires_docker +def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: + """Test the SshFileShareService single file download/upload.""" + with ssh_fileshare_service: + config = ssh_test_server.to_ssh_service_config() + + remote_file_path = "/tmp/test_ssh_fileshare_single_file" + lines = [ + "foo", + "bar", + ] + lines = [line + "\n" for line in lines] + + # 1. Write a local file and upload it. + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + temp_file.writelines(lines) + temp_file.flush() + temp_file.close() + + ssh_fileshare_service.upload( + params=config, + local_path=temp_file.name, + remote_path=remote_file_path, + ) + + # 2. Download the remote file and compare the contents. + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + temp_file.close() + ssh_fileshare_service.download( + params=config, + remote_path=remote_file_path, + local_path=temp_file.name, + ) + # Download will replace the inode at that name, so we need to reopen the file. + with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + read_lines = temp_file_h.readlines() + assert read_lines == lines + + +@pytest.mark.xdist_group("ssh_test_server") +@requires_docker +def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: + """Test the SshFileShareService recursive download/upload.""" + with ssh_fileshare_service: + config = ssh_test_server.to_ssh_service_config() + + remote_file_path = "/tmp/test_ssh_fileshare_recursive_dir" + files_lines: Dict[str, List[str]] = { + "file-a.txt": [ + "a", + "1", + ], + "file-b.txt": [ + "b", + "2", + ], + "subdir/foo.txt": [ + "foo", + "bar", + ], + } + files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()} + + with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: + # Setup the directory structure. + for (file_path, lines) in files_lines.items(): + path = Path(tempdir1, file_path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, mode='w+t', encoding='utf-8') as temp_file: + temp_file.writelines(lines) + temp_file.flush() + assert os.path.getsize(path) > 0 + + # Copy that structure over to the remote server. + ssh_fileshare_service.upload( + params=config, + local_path=f"{tempdir1}", + remote_path=f"{remote_file_path}", + recursive=True, + ) + + # Copy the remote structure back to the local machine. + ssh_fileshare_service.download( + params=config, + remote_path=f"{remote_file_path}", + local_path=f"{tempdir2}", + recursive=True, + ) + + # Compare both. + # Note: remote dir name is appended to target. 
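+            # e.g. uploading tempdir1 to /tmp/test_ssh_fileshare_recursive_dir and
+            # downloading that back into tempdir2 should yield
+            # tempdir2/test_ssh_fileshare_recursive_dir/ mirroring tempdir1's contents.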
+            assert are_dir_trees_equal(tempdir1, path_join(tempdir2, basename(remote_file_path)))
+
+
+@pytest.mark.xdist_group("ssh_test_server")
+@requires_docker
+def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo,
+                                         ssh_fileshare_service: SshFileShareService) -> None:
+    """Test the SshFileShareService downloading a single file that doesn't exist."""
+    with ssh_fileshare_service:
+        config = ssh_test_server.to_ssh_service_config()
+
+        canary_str = "canary"
+
+        with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file:
+            temp_file.writelines([canary_str])
+            temp_file.flush()
+            temp_file.close()
+
+            with pytest.raises(SFTPError):
+                ssh_fileshare_service.download(
+                    params=config,
+                    remote_path="/tmp/file-dne.txt",
+                    local_path=temp_file.name,
+                )
+            # The failed download should have left the local file untouched.
+            with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h:
+                read_lines = temp_file_h.readlines()
+                assert read_lines == [canary_str]
+
+
+@pytest.mark.xdist_group("ssh_test_server")
+@requires_docker
+def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo,
+                                       ssh_host_service: SshHostService,
+                                       ssh_fileshare_service: SshFileShareService) -> None:
+    """Test the SshFileShareService uploading a single file that doesn't exist."""
+    with ssh_host_service, ssh_fileshare_service:
+        config = ssh_test_server.to_ssh_service_config()
+
+        path = '/tmp/upload-file-src-dne.txt'
+        with pytest.raises(OSError):
+            ssh_fileshare_service.upload(
+                params=config,
+                remote_path=path,
+                local_path=path,
+            )
+        # Check that the failed upload didn't create the remote file either.
+        (status, results) = ssh_host_service.remote_exec(
+            script=[f"[[ ! -e {path} ]]; echo $?"],
+            config=config,
+            env_params={},
+        )
+        (status, results) = ssh_host_service.get_remote_exec_results(results)
+        assert status.is_succeeded()
+        assert str(results["stdout"]).strip() == "0"
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py
new file mode 100644
index 00000000000..c1e8e853ca0
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py
@@ -0,0 +1,206 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""
+Tests for mlos_bench.services.remote.ssh.ssh_host_service
+"""
+
+import time
+
+import pytest
+from pytest_docker.plugin import Services as DockerServices
+
+from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
+from mlos_bench.services.remote.ssh.ssh_service import SshClient
+
+from mlos_bench.tests import requires_docker
+from mlos_bench.tests.services.remote.ssh import (SshTestServerInfo,
+                                                  ALT_TEST_SERVER_NAME,
+                                                  SSH_TEST_SERVER_NAME,
+                                                  wait_docker_service_socket)
+
+
+@requires_docker
+@pytest.mark.xdist_group("ssh_test_server")
+def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
+                                 alt_test_server: SshTestServerInfo,
+                                 ssh_host_service: SshHostService) -> None:
+    """
+    Test the SshHostService remote_exec.
+
+    This checks the state of the service across multiple invocations in order
+    to exercise the internal connection cache handling logic as well.
+    """
+    # pylint: disable=protected-access
+    with ssh_host_service:
+        config = ssh_test_server.to_ssh_service_config()
+
+        # Note: this may interact with state from other tests, so we group them
+        # into the same xdist_group, run with --dist loadgroup.
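+        # The client cache is keyed on the connection params (here the
+        # host/port/username from to_connect_params()), so a freshly entered
+        # service context should have no entry for this server yet.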
+ connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params()) + assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None + connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id) + assert connection_client is None + + (status, results_info) = ssh_host_service.remote_exec( + script=["hostname"], + config=config, + env_params={}, + ) + assert status.is_pending() + assert "asyncRemoteExecResultsFuture" in results_info + status, results = ssh_host_service.get_remote_exec_results(results_info) + assert status.is_succeeded() + assert results["stdout"].strip() == SSH_TEST_SERVER_NAME + + # Check that the client caching is behaving as expected. + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + assert connection is not None + assert connection._username == ssh_test_server.username + assert connection._host == ssh_test_server.hostname + assert connection._port == ssh_test_server.get_port() + local_port = connection._local_port + assert local_port + assert client is not None + assert client._conn_event.is_set() + + # Connect to a different server. + (status, results_info) = ssh_host_service.remote_exec( + script=["hostname"], + config=alt_test_server.to_ssh_service_config(), + env_params={ + "UNUSED": "unused", # unused, making sure it doesn't carry over with cached connections + }, + ) + assert status.is_pending() + assert "asyncRemoteExecResultsFuture" in results_info + status, results = ssh_host_service.get_remote_exec_results(results_info) + assert status.is_succeeded() + assert results["stdout"].strip() == ALT_TEST_SERVER_NAME + + # Test reusing the existing connection. + (status, results_info) = ssh_host_service.remote_exec( + script=["echo BAR=$BAR && echo UNUSED=$UNUSED && false"], + config=config, + # Also test interacting with environment_variables. + env_params={ + "BAR": "bar", + }, + ) + status, results = ssh_host_service.get_remote_exec_results(results_info) + assert status.is_failed() # should retain exit code from "false" + stdout = str(results["stdout"]) + assert stdout.splitlines() == [ + "BAR=bar", + "UNUSED=", + ] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + assert connection._local_port == local_port + + # Close the connection (gracefully) + connection.close() + + # Try and reconnect and see if it detects the closed connection and starts over. + (status, results_info) = ssh_host_service.remote_exec( + script=[ + # Test multi-string scripts. + "echo FOO=$FOO\n", + # Test multi-line strings. + "echo BAR=$BAR\necho BAZ=$BAZ", + ], + config=config, + # Also test interacting with environment_variables. + env_params={ + 'FOO': 'foo', + }, + ) + status, results = ssh_host_service.get_remote_exec_results(results_info) + assert status.is_succeeded() + stdout = str(results["stdout"]) + lines = stdout.splitlines() + assert lines == [ + "FOO=foo", + "BAR=", + "BAZ=", + ] + # Make sure it looks like we reconnected. + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + assert connection._local_port != local_port + + # Make sure the cache is cleaned up on context exit. 
+ assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0 + + +@requires_docker +@pytest.mark.parametrize("graceful", [True, False]) +@pytest.mark.xdist_group("ssh_test_server") +def test_ssh_service_reboot(docker_services: DockerServices, + alt_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + graceful: bool) -> None: + """ + Test the SshHostService reboot. + """ + # Note: rebooting changes the port number unfortunately, but makes it + # easier to check for success. + # Also, it may cause issues with other parallel unit tests, so we run it as + # a part of the same unit test for now. + with ssh_host_service: + alt_test_server_ssh_service_config = alt_test_server.to_ssh_service_config() + (status, results_info) = ssh_host_service.remote_exec( + script=[ + 'echo "sleeping..."', + 'sleep 30', + 'echo "should not reach this point"' + ], + config=alt_test_server_ssh_service_config, + env_params={}, + ) + assert status.is_pending() + # Wait a moment for that to start in the background thread. + time.sleep(0.5) + + # Now try to restart the server (gracefully). + # TODO: Test graceful vs. forceful. + if graceful: + (status, reboot_results_info) = ssh_host_service.reboot(params=alt_test_server_ssh_service_config) + assert status.is_pending() + + (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info) + # NOTE: reboot/shutdown ops mostly return FAILED, even though the reboot succeeds. + print(f"reboot status: {status} {reboot_results_info}") + else: + (status, kill_results_info) = ssh_host_service.remote_exec( + script=["kill -9 1; kill -9 -1"], + config=alt_test_server_ssh_service_config, + env_params={}, + ) + (status, kill_results_info) = ssh_host_service.get_remote_exec_results(kill_results_info) + print(f"kill status: {status} {kill_results_info}") + + # TODO: Check for decent error handling on disconnects. + status, results = ssh_host_service.get_remote_exec_results(results_info) + assert status.is_failed() + stdout = str(results["stdout"]) + assert "sleeping" in stdout + assert "should not reach this point" not in stdout + + # Give docker some time to restart the service after the "reboot". + # Note: this relies on having `restart: always` in the docker-compose.yml file. + time.sleep(1) + + # try to reconnect and see if the port changed + alt_test_server_ssh_service_config_new = alt_test_server.to_ssh_service_config(uncached=True) + assert alt_test_server_ssh_service_config_new["ssh_port"] != alt_test_server_ssh_service_config["ssh_port"] + + wait_docker_service_socket(docker_services, alt_test_server.hostname, alt_test_server_ssh_service_config_new["ssh_port"]) + + (status, results_info) = ssh_host_service.remote_exec( + script=["hostname"], + config=alt_test_server_ssh_service_config_new, + env_params={}, + ) + status, results = ssh_host_service.get_remote_exec_results(results_info) + assert status.is_succeeded() + assert results["stdout"].strip() == ALT_TEST_SERVER_NAME diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py new file mode 100644 index 00000000000..14fc015be3f --- /dev/null +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -0,0 +1,105 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.services.remote.ssh.SshService base class. 
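+
+In particular, these cover the shared background event loop thread and the
+SSH client connection cache lifecycle handling.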
+""" + +import asyncio +import time + +from subprocess import run +from threading import Thread + +import pytest +from pytest_lazyfixture import lazy_fixture + +from mlos_bench.services.remote.ssh.ssh_service import SshService +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService +from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService + +from mlos_bench.tests import requires_docker, requires_ssh, check_socket, resolve_host_name +from mlos_bench.tests.services.remote.ssh import SshTestServerInfo, ALT_TEST_SERVER_NAME, SSH_TEST_SERVER_NAME + + +@requires_docker +@requires_ssh +@pytest.mark.xdist_group("ssh_test_server") +@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [ + (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), + (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), +]) +def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, + server_name: str) -> None: + """Check for the pytest-docker ssh test infra.""" + assert ssh_test_server_info.service_name == server_name + + ip_addr = resolve_host_name(ssh_test_server_info.hostname) + assert ip_addr is not None + + local_port = ssh_test_server_info.get_port() + assert check_socket(ip_addr, local_port) + ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \ + + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \ + + f"-p {local_port} {ssh_test_server_info.hostname} hostname" + cmd = run(ssh_cmd.split(), + capture_output=True, + text=True, + check=True) + assert cmd.stdout.strip() == server_name + + +@pytest.mark.xdist_group("ssh_test_server") +def test_ssh_service_context_handler() -> None: + """ + Test the SSH service context manager handling. + See Also: test_event_loop_context + """ + # pylint: disable=protected-access + + # Should start with no event loop thread. + assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread is None + + # The background thread should only be created upon context entry. + ssh_host_service = SshHostService(config={}, global_config={}, parent=None) + assert ssh_host_service + assert not ssh_host_service._in_context + assert ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread is None + + # After we enter the SshService instance context, we should have a background thread. + with ssh_host_service: + assert ssh_host_service._in_context + assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable] + # Give the thread a chance to start. + # Mostly important on the underpowered Windows CI machines. + time.sleep(0.25) + assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None + + ssh_fileshare_service = SshFileShareService(config={}, global_config={}, parent=None) + assert ssh_fileshare_service + assert not ssh_fileshare_service._in_context + + with ssh_fileshare_service: + assert ssh_fileshare_service._in_context + assert ssh_host_service._in_context + assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \ + is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \ + is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread + assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + + assert not ssh_fileshare_service._in_context + # And that instance should be unusable after we are outside the context. 
+    with pytest.raises(AssertionError):
+        ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1))
+
+    # The background thread should remain running since we have another context still open.
+    assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread)    # type: ignore[unreachable]
+    assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None
+
+
+if __name__ == '__main__':
+    # For debugging on Windows, which has issues with pytest detection in vscode.
+    pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service"])
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/up.sh b/mlos_bench/mlos_bench/tests/services/remote/ssh/up.sh
new file mode 100755
index 00000000000..42bc984e6e5
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/up.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+##
+## Copyright (c) Microsoft Corporation.
+## Licensed under the MIT License.
+##
+
+# A script to start the SSH server in a container and copy the SSH keys from it.
+# For pytest, the fixture in conftest.py will handle this for us using the
+# pytest-docker plugin, but for manual testing, this script can be used.
+
+set -eu
+set -x
+
+scriptdir=$(dirname "$(readlink -f "$0")")
+cd "$scriptdir"
+
+PROJECT_NAME="mlos_bench-test-manual"
+
+#docker compose -p "$PROJECT_NAME" build
+export TIMEOUT=infinity
+# Note: detach so that the follow-up commands below can run against the containers.
+docker compose -p "$PROJECT_NAME" up --build --remove-orphans --detach
+docker compose -p "$PROJECT_NAME" exec ssh-server service ssh start
+docker compose -p "$PROJECT_NAME" cp ssh-server:/root/.ssh/id_rsa ./id_rsa
+chmod 0600 ./id_rsa
+set +x
+
+echo "OK: private key available at '$scriptdir/id_rsa'. Connect to the ssh-server container at the following port:"
+docker compose -p "$PROJECT_NAME" port ssh-server ${PORT:-2254} | cut -d: -f2
+echo "INFO: And this port for the alt-server container:"
+docker compose -p "$PROJECT_NAME" port alt-server ${PORT:-2254} | cut -d: -f2
diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py
index 445cfa727d0..fdc8db0b1e3 100644
--- a/mlos_bench/setup.py
+++ b/mlos_bench/setup.py
@@ -28,6 +28,7 @@
 extra_requires: Dict[str, List[str]] = {    # pylint: disable=consider-using-namedtuple-or-dataclass
     # Additional tools for extra functionality.
     'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'],
+    'ssh': ['asyncssh'],
     'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'],
     'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'],
     'storage-sql-postgres': ['sqlalchemy', 'psycopg2'],
@@ -47,6 +48,8 @@
     'pytest-xdist',
     'pytest-cov',
     'pytest-local-badge',
+    'pytest-lazy-fixture',
+    'pytest-docker',
 ]

 # pylint: disable=duplicate-code
diff --git a/setup.cfg b/setup.cfg
index e7399ba4916..e2a3f753c05 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -38,6 +38,7 @@ addopts =
     -vv -svl
    --ff --nf
     -n auto
+    --dist loadgroup
     # --log-level=DEBUG
# Moved these to Makefile (coverage is expensive and we only need it in the pipelines generally).
#--cov=mlos_core --cov-report=xml @@ -48,6 +49,7 @@ filterwarnings = ignore:.*(Please leave at default or explicitly set .size=None).*:DeprecationWarning:smac:0 ignore:.*(Trying to register a configuration that was not previously suggested).*:UserWarning:.*llamatune.*:0 ignore:.*(DISPLAY environment variable is set).*:UserWarning:.*conftest.*:0 + ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*test_ssh_service.*:0 # # mypy static type checker configs @@ -86,6 +88,9 @@ ignore_missing_imports = True [mypy-pytest] ignore_missing_imports = True +[mypy-pytest_docker.*] +ignore_missing_imports = True + # https://github.com/scikit-learn/scikit-learn/issues/16705 [mypy-sklearn.*] ignore_missing_imports = True