Skip to content

Commit

Permalink
Expand connection retry logic (#69)
Browse files Browse the repository at this point in the history
* generalize ssh connect with retry

* move exec script to file

* update changelog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing `RuntimeError`; add default err message

* set exit_status=0 on mock func

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add codecov.yml to ignore exec.py

* use class attrs for max attempts and wait time

* actually use `self.retry_connect` flag

* rename method

* add connection retry attempts test

* functional test for failing workflow

* fix functional test fail in *ssh* electron

* using async tmpfile

* add failed task handling test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* comments to clarify new test

* more comments and clarification

* remove class attrs, add init args

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comment

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
araghukas and pre-commit-ci[bot] authored Oct 31, 2023
1 parent 10f867e commit 6a8d8d7
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 70 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Changed

- Expand connection retry logic to cover more cases
- Move exec script into a separate file

## [0.23.0] - 2023-10-20

### Changed
Expand Down
3 changes: 3 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ignore:
# This script is read into a string and formatted. Never executed directly.
- covalent_ssh_plugin/exec.py
46 changes: 46 additions & 0 deletions covalent_ssh_plugin/exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Load task `fn` from pickle. Run it. Save the result.
"""
import os
import sys
from pathlib import Path

result = None
exception = None

# NOTE: Paths must be substituted-in here by the executor.
remote_result_file = Path("{remote_result_file}").resolve()
remote_function_file = Path("{remote_function_file}").resolve()
current_remote_workdir = Path("{current_remote_workdir}").resolve()

try:
# Make sure cloudpickle is available.
import cloudpickle as pickle
except Exception as e:
import pickle

with open(remote_result_file, "wb") as f_out:
pickle.dump((None, e), f_out)
sys.exit(1) # Error.

current_dir = os.getcwd()

# Read the function object and arguments from pickle file.
with open(remote_function_file, "rb") as f_in:
fn, args, kwargs = pickle.load(f_in)

try:
# Execute the task `fn` inside the remote workdir.
current_remote_workdir.mkdir(parents=True, exist_ok=True)
os.chdir(current_remote_workdir)

result = fn(*args, **kwargs)

except Exception as e:
exception = e
finally:
os.chdir(current_dir)

# Save the result to pickle file.
with open(remote_result_file, "wb") as f_out:
pickle.dump((result, exception), f_out)
151 changes: 81 additions & 70 deletions covalent_ssh_plugin/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ class SSHExecutor(RemoteExecutor):
then the execution is run on the local machine.
remote_workdir: The working directory on the remote server used for storing files produced from workflows.
create_unique_workdir: Whether to create unique sub-directories for each node / task / electron.
poll_freq: Number of seconds to wait for before retrying the result poll
do_cleanup: Whether to delete all the intermediate files or not
poll_freq: Number of seconds to wait for before retrying the result poll.
do_cleanup: Whether to delete all the intermediate files after execution.
max_connection_attempts: Maximum number of attempts to establish SSH connection.
retry_wait_time: Time to wait (in seconds) before reattempting connection.
"""

def __init__(
Expand All @@ -85,6 +87,8 @@ def __init__(
poll_freq: int = 15,
do_cleanup: bool = True,
retry_connect: bool = True,
max_connection_attempts: int = 5,
retry_wait_time: int = 5,
) -> None:

remote_cache = (
Expand Down Expand Up @@ -113,6 +117,8 @@ def __init__(

self.do_cleanup = do_cleanup
self.retry_connect = retry_connect
self.max_connection_attempts = max_connection_attempts
self.retry_wait_time = retry_wait_time

ssh_key_file = ssh_key_file or get_config("executors.ssh.ssh_key_file")
self.ssh_key_file = str(Path(ssh_key_file).expanduser().resolve())
Expand All @@ -124,7 +130,7 @@ def _write_function_files(
args: list,
kwargs: dict,
current_remote_workdir: str = ".",
) -> None:
) -> Tuple[str, str, str, str, str]:
"""
Helper function to pickle the function to be executed to file, and write the
python script which calls the function.
Expand All @@ -143,53 +149,26 @@ def _write_function_files(
with open(function_file, "wb") as f_out:
pickle.dump((fn, args, kwargs), f_out)
remote_function_file = os.path.join(self.remote_cache, f"function_{operation_id}.pkl")
remote_result_file = os.path.join(self.remote_cache, f"result_{operation_id}.pkl")

# Write the code that the remote server will use to execute the function.

message = f"Function file names:\nLocal function file: {function_file}\n"
message += f"Remote function file: {remote_function_file}"
app_log.debug(message)

remote_result_file = os.path.join(self.remote_cache, f"result_{operation_id}.pkl")
exec_script = "\n".join(
[
"import os",
"import sys",
"from pathlib import Path",
"",
"result = None",
"exception = None",
"",
"try:",
" import cloudpickle as pickle",
"except Exception as e:",
" import pickle",
f" with open('{remote_result_file}','wb') as f_out:",
" pickle.dump((None, e), f_out)",
" exit()",
"",
f"with open('{remote_function_file}', 'rb') as f_in:",
" fn, args, kwargs = pickle.load(f_in)",
" current_dir = os.getcwd()",
" try:",
f" Path('{current_remote_workdir}').mkdir(parents=True, exist_ok=True)",
f" os.chdir('{current_remote_workdir}')",
" result = fn(*args, **kwargs)",
" except Exception as e:",
" exception = e",
" finally:",
" os.chdir(current_dir)",
"",
"",
f"with open('{remote_result_file}','wb') as f_out:",
" pickle.dump((result, exception), f_out)",
"",
]
)
exec_blank = Path(__file__).parent / "exec.py"
script_file = os.path.join(self.cache_dir, f"exec_{operation_id}.py")
remote_script_file = os.path.join(self.remote_cache, f"exec_{operation_id}.py")
with open(script_file, "w") as f_out:
f_out.write(exec_script)

with open(exec_blank, "r", encoding="utf-8") as f_blank:
exec_script = f_blank.read().format(
remote_result_file=remote_result_file,
remote_function_file=remote_function_file,
current_remote_workdir=current_remote_workdir,
)
with open(script_file, "w", encoding="utf-8") as f_out:
f_out.write(exec_script)

return (
function_file,
Expand Down Expand Up @@ -228,9 +207,9 @@ def _on_ssh_fail(
app_log.error(message)
raise RuntimeError(message)

async def _client_connect(self) -> Tuple[bool, asyncssh.SSHClientConnection]:
async def _client_connect(self) -> Tuple[bool, Optional[asyncssh.SSHClientConnection]]:
"""
Helper function for connecting to the remote host through the paramiko module.
Attempts connection to the remote host.
Args:
None
Expand All @@ -241,35 +220,65 @@ async def _client_connect(self) -> Tuple[bool, asyncssh.SSHClientConnection]:

ssh_success = False
conn = None
if os.path.exists(self.ssh_key_file):
retries = 6 if self.retry_connect else 1
for _ in range(retries):
try:
conn = await asyncssh.connect(
self.hostname,
username=self.username,
client_keys=[self.ssh_key_file],
known_hosts=None,
)

ssh_success = True
except (socket.gaierror, ValueError, TimeoutError, ConnectionRefusedError) as e:
app_log.error(e)
if not os.path.exists(self.ssh_key_file):
message = f"SSH key file {self.ssh_key_file} does not exist."
app_log.error(message)
raise RuntimeError(message)

if conn is not None:
break
try:
conn = await self._attempt_client_connect()
ssh_success = conn is not None
except (socket.gaierror, ValueError, TimeoutError) as e:
app_log.error(e)

await asyncio.sleep(5)
return ssh_success, conn

if conn is None and not self.run_local_on_ssh_fail:
raise RuntimeError("Could not connect to remote host.")
async def _attempt_client_connect(self) -> Optional[asyncssh.SSHClientConnection]:
"""
Helper function that catches specific errors and retries connecting to the remote host.
else:
message = f"no SSH key file found at {self.ssh_key_file}. Cannot connect to host."
app_log.error(message)
raise RuntimeError(message)
Args:
None. Retry count and wait time come from `self.max_connection_attempts` and `self.retry_wait_time`.
return ssh_success, conn
Returns:
An `SSHClientConnection` object if successful, None otherwise.
"""

# Retry connecting if any of these errors happen:
_retry_errs = (
ConnectionRefusedError,
OSError, # e.g. Network unreachable
)

address = f"{self.username}@{self.hostname}"
attempt_max = self.max_connection_attempts

attempt = 0
while attempt < attempt_max:

try:
# Exit here if the connection is successful.
return await asyncssh.connect(
self.hostname,
username=self.username,
client_keys=[self.ssh_key_file],
known_hosts=None,
)
except _retry_errs as err:

if not self.retry_connect:
app_log.error(f"{err} ({address} | retry disabled).")
raise err

app_log.warning(f"{err} ({address} | retry {attempt+1}/{attempt_max})")
await asyncio.sleep(self.retry_wait_time)

finally:
attempt += 1

# Failed to connect to client.
return None

async def cleanup(
self,
Expand All @@ -290,7 +299,7 @@ async def cleanup(
script_file: Path to the script file to be deleted locally
result_file: Path to the result file to be deleted locally
remote_function_file: Path to the function file to be deleted on remote
remote_script_file: Path to the sccript file to be deleted on remote
remote_script_file: Path to the script file to be deleted on remote
remote_result_file: Path to the result file to be deleted on remote
Returns:
Expand Down Expand Up @@ -540,9 +549,11 @@ async def run(
app_log.debug("Running function file in remote machine...")
result = await self.submit_task(conn, remote_script_file)

if result_err := result.stderr.strip():
app_log.warning(result_err)
return self._on_ssh_fail(function, args, kwargs, result_err)
if result.exit_status != 0:
message = result.stderr.strip()
message = message or f"Task exited with nonzero exit status {result.exit_status}."
app_log.warning(message)
return self._on_ssh_fail(function, args, kwargs, message)

if not await self._poll_task(conn, remote_result_file):
message = (
Expand Down
20 changes: 20 additions & 0 deletions tests/functional_tests/basic_workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,23 @@ def basic_workflow(a, b):
print(result)

assert status == str(ct.status.COMPLETED)


@pytest.mark.functional_tests
def test_basic_workflow_failure():
    """Verify that a workflow whose electron raises finishes with FAILED status."""

    @ct.electron(executor="ssh")
    def join_words(a, b):
        # Deliberately raise inside the task so the workflow must fail.
        raise Exception(f"{', '.join([a, b])} -- but something went wrong!")

    @ct.lattice
    def basic_workflow_that_will_fail(a, b):
        return join_words(a, b)

    # Dispatch the workflow and block until a final result is available.
    dispatch_id = ct.dispatch(basic_workflow_that_will_fail)("Hello", "World")
    result = ct.get_result(dispatch_id=dispatch_id, wait=True)
    status = str(result.status)

    print(result)

    # The exception raised in the electron must surface as a FAILED workflow.
    assert status == str(ct.status.FAILED)
Loading

0 comments on commit 6a8d8d7

Please sign in to comment.