"""
Load task `fn` from pickle. Run it. Save the result.

This file is a template rather than a ready-to-run script: the executor reads
it as text, fills in the three quoted path placeholders below with
`str.format`, and writes the formatted text out as the script for a task.
Any other literal curly brace added to this file would be interpreted by
`str.format` and would therefore have to be doubled.

The result file always holds a pickled `(result, exception)` pair.
"""
import os
import sys
from pathlib import Path

# Values saved to the result file; each stays None unless set further down.
result = None
exception = None

# NOTE: Paths must be substituted-in here by the executor.
remote_result_file = Path("{remote_result_file}").resolve()
remote_function_file = Path("{remote_function_file}").resolve()
current_remote_workdir = Path("{current_remote_workdir}").resolve()

try:
    # Make sure cloudpickle is available.
    import cloudpickle as pickle
except Exception as e:
    # Fall back to the standard library pickle only to report the import
    # failure through the result file, then stop without running the task.
    import pickle

    with open(remote_result_file, "wb") as f_out:
        pickle.dump((None, e), f_out)
    sys.exit(1)  # Error.

# Remember where we started so the working directory can be restored later.
current_dir = os.getcwd()

# Read the function object and arguments from pickle file.
# This step is not guarded by a try block: if it raises, the script ends with
# an uncaught exception and no result file is written.
with open(remote_function_file, "rb") as f_in:
    fn, args, kwargs = pickle.load(f_in)

try:
    # Execute the task `fn` inside the remote workdir.
    current_remote_workdir.mkdir(parents=True, exist_ok=True)
    os.chdir(current_remote_workdir)

    result = fn(*args, **kwargs)

except Exception as e:
    # Keep the error so it is saved alongside the (None) result below.
    exception = e
finally:
    # Always return to the original directory, whether or not the task failed.
    os.chdir(current_dir)

# Save the result to pickle file.
with open(remote_result_file, "wb") as f_out:
    pickle.dump((result, exception), f_out)
@@ -143,6 +149,7 @@ def _write_function_files( with open(function_file, "wb") as f_out: pickle.dump((fn, args, kwargs), f_out) remote_function_file = os.path.join(self.remote_cache, f"function_{operation_id}.pkl") + remote_result_file = os.path.join(self.remote_cache, f"result_{operation_id}.pkl") # Write the code that the remote server will use to execute the function. @@ -150,46 +157,18 @@ def _write_function_files( message += f"Remote function file: {remote_function_file}" app_log.debug(message) - remote_result_file = os.path.join(self.remote_cache, f"result_{operation_id}.pkl") - exec_script = "\n".join( - [ - "import os", - "import sys", - "from pathlib import Path", - "", - "result = None", - "exception = None", - "", - "try:", - " import cloudpickle as pickle", - "except Exception as e:", - " import pickle", - f" with open('{remote_result_file}','wb') as f_out:", - " pickle.dump((None, e), f_out)", - " exit()", - "", - f"with open('{remote_function_file}', 'rb') as f_in:", - " fn, args, kwargs = pickle.load(f_in)", - " current_dir = os.getcwd()", - " try:", - f" Path('{current_remote_workdir}').mkdir(parents=True, exist_ok=True)", - f" os.chdir('{current_remote_workdir}')", - " result = fn(*args, **kwargs)", - " except Exception as e:", - " exception = e", - " finally:", - " os.chdir(current_dir)", - "", - "", - f"with open('{remote_result_file}','wb') as f_out:", - " pickle.dump((result, exception), f_out)", - "", - ] - ) + exec_blank = Path(__file__).parent / "exec.py" script_file = os.path.join(self.cache_dir, f"exec_{operation_id}.py") remote_script_file = os.path.join(self.remote_cache, f"exec_{operation_id}.py") - with open(script_file, "w") as f_out: - f_out.write(exec_script) + + with open(exec_blank, "r", encoding="utf-8") as f_blank: + exec_script = f_blank.read().format( + remote_result_file=remote_result_file, + remote_function_file=remote_function_file, + current_remote_workdir=current_remote_workdir, + ) + with open(script_file, "w", 
async def _attempt_client_connect(self) -> Optional[asyncssh.SSHClientConnection]:
    """
    Try to open the SSH connection, retrying on connection-level errors.

    Args:
        None. The retry policy is read from the instance attributes
        `max_connection_attempts`, `retry_wait_time` and `retry_connect`.

    Returns:
        An `SSHClientConnection` object if successful, None if every attempt failed.

    Raises:
        OSError: (including `ConnectionRefusedError`) re-raised on the first
            failure when `retry_connect` is False. Any exception that is not
            an `OSError` propagates unchanged and is never retried.
    """

    # Retry connecting if any of these errors happen.
    # `ConnectionRefusedError` is itself an `OSError`; it is listed for readability.
    retry_errs = (
        ConnectionRefusedError,
        OSError,  # e.g. Network unreachable
    )

    address = f"{self.username}@{self.hostname}"
    attempt_max = self.max_connection_attempts

    for attempt in range(1, attempt_max + 1):
        try:
            # Exit here if the connection is successful.
            return await asyncssh.connect(
                self.hostname,
                username=self.username,
                client_keys=[self.ssh_key_file],
                known_hosts=None,
            )
        except retry_errs as err:
            if not self.retry_connect:
                app_log.error(f"{err} ({address} | retry disabled).")
                # Bare `raise` keeps the original traceback intact.
                raise

            app_log.warning(f"{err} ({address} | retry {attempt}/{attempt_max})")

            # Only wait when another attempt will actually follow; sleeping
            # after the final failure would just delay the caller.
            if attempt < attempt_max:
                await asyncio.sleep(self.retry_wait_time)

    # Failed to connect to client.
    return None
+ app_log.warning(message) + return self._on_ssh_fail(function, args, kwargs, message) if not await self._poll_task(conn, remote_result_file): message = ( diff --git a/tests/functional_tests/basic_workflow_test.py b/tests/functional_tests/basic_workflow_test.py index 5f46f99..5f60bf3 100644 --- a/tests/functional_tests/basic_workflow_test.py +++ b/tests/functional_tests/basic_workflow_test.py @@ -27,3 +27,23 @@ def basic_workflow(a, b): print(result) assert status == str(ct.status.COMPLETED) + + +@pytest.mark.functional_tests +def test_basic_workflow_failure(): + @ct.electron(executor="ssh") + def join_words(a, b): + raise Exception(f"{', '.join([a, b])} -- but something went wrong!") + + @ct.lattice + def basic_workflow_that_will_fail(a, b): + return join_words(a, b) + + # Dispatch the workflow + dispatch_id = ct.dispatch(basic_workflow_that_will_fail)("Hello", "World") + result = ct.get_result(dispatch_id=dispatch_id, wait=True) + status = str(result.status) + + print(result) + + assert status == str(ct.status.FAILED) diff --git a/tests/ssh_test.py b/tests/ssh_test.py index 4b8e1a3..acedaf3 100644 --- a/tests/ssh_test.py +++ b/tests/ssh_test.py @@ -110,6 +110,63 @@ def simple_task(x): ) +@pytest.mark.asyncio +async def test_nonzero_exit_status(mocker): + """Test handling nonzero exit status from 'run task' command on remote.""" + + from collections import namedtuple + + # Stand-in for the ssh connection object. + class _FakeConn: + fake_proc = namedtuple("fake_proc", ["stdout", "stderr"]) + + async def run(self, *_): + return _FakeConn.fake_proc("Python 3.8", "") + + async def wait_closed(self): + return True + + # Stand-in for the `run` command result. + class _FakeResultFailed: + stderr = "Fake error message" + exit_status = 1 # <--- This is the important part. + + # Use these for patching. + _conn = _FakeConn() + _result = _FakeResultFailed() + + # Patch anything that requires a real connection. 
@pytest.mark.asyncio
async def test_client_connect_retry_attempts(mocker):
    """Test various outcomes of retrying client connection."""

    mocker.patch("covalent_ssh_plugin.ssh.get_config", side_effect=get_config_mock)

    # Dummy used to patch `asyncssh.connect` calls. It is steered through two
    # function attributes that the test body sets before each scenario:
    #   err_counter   -- failures raised so far; negative means "succeed at once".
    #   succeed_after -- number of failures to raise before succeeding.
    async def _mock_asyncssh_connect(*args, **kwargs):

        # Set `err_counter = -1` to test immediate success.
        if _mock_asyncssh_connect.err_counter < 0:
            return "immediate_connection_object"  # Success.

        # Set `succeed_after` to decide number of failures before success.
        if _mock_asyncssh_connect.err_counter >= _mock_asyncssh_connect.succeed_after:
            return "eventual_connection_object"  # Success.

        # Failures alternate between the two error types.
        if _mock_asyncssh_connect.err_counter % 2 == 0:
            err = ConnectionRefusedError("Pretend connection was refused.")
        else:
            err = OSError("Pretend network unreachable.")

        _mock_asyncssh_connect.err_counter += 1
        raise err

    mocker.patch("asyncssh.connect", side_effect=_mock_asyncssh_connect)

    # Stand-in for ssh key file. Everything below stays inside this context
    # so the file still exists whenever `_client_connect` checks for it.
    async with aiofiles.tempfile.NamedTemporaryFile("w") as f:
        ssh_key_file = f.name

        executor = SSHExecutor(
            username="user",
            hostname="host",
            ssh_key_file=ssh_key_file,
            retry_connect=False,
        )

        # Override the instance attribute so retries do not slow the test down.
        executor.retry_wait_time = 0

        # Test immediate success.
        _mock_asyncssh_connect.err_counter = -1
        ssh_success, conn = await executor._client_connect()
        assert ssh_success is True
        assert conn == "immediate_connection_object"

        # Test eventual success.
        _mock_asyncssh_connect.err_counter = 0
        _mock_asyncssh_connect.succeed_after = 3
        executor.retry_connect = True
        ssh_success, conn = await executor._client_connect()
        assert ssh_success is True
        assert conn == "eventual_connection_object"

        # Test immediate failure.
        _mock_asyncssh_connect.err_counter = 0
        executor.retry_connect = False
        with pytest.raises(ConnectionRefusedError):
            await executor._client_connect()

        # Test eventual failure.
        _mock_asyncssh_connect.err_counter = 0
        _mock_asyncssh_connect.succeed_after = executor.max_connection_attempts + 1
        executor.retry_connect = True
        ssh_success, conn = await executor._client_connect()
        assert ssh_success is False
        assert conn is None