Commit de94d25

Merge branch 'pygup_connect' into 'master'
Pygup connect

See merge request algorithm/megalodon!60
marcus1487 committed Jan 21, 2021
2 parents 7efbbfb + 3ee21d2 commit de94d25
Showing 2 changed files with 110 additions and 88 deletions.
157 changes: 92 additions & 65 deletions megalodon/backends.py
@@ -480,6 +480,47 @@ def get_model_info_from_fast5(read):
 
         self._parse_minimal_alphabet_info()
 
+    def _pyguppy_client_init(self):
+        self.client = self.pyguppy_GuppyBasecallerClient(
+            '{}:{}'.format(GUPPY_HOST, self.params.pyguppy.port),
+            self.params.pyguppy.config, **PYGUPPY_CLIENT_KWARGS)
+
+    def _pyguppy_client_connect(
+            self, max_reconnect_attempts=PYGUPPY_MAX_RECONNECT_ATTEMPTS,
+            per_try_sleep=0.1):
+        if self.client is None:
+            self._pyguppy_client_init()
+        n_attempts = 0
+        err_str = ''
+        while n_attempts < max_reconnect_attempts:
+            try:
+                LOGGER.debug('Connecting to server')
+                self.client.connect()
+                return
+            except (ConnectionError, ValueError, RuntimeError) as e:
+                err_str = str(e)
+                n_attempts += 1
+                sleep(per_try_sleep)
+        LOGGER.debug(
+            'Failed to connect to Guppy server after {} attempts'.format(
+                max_reconnect_attempts))
+        self._pyguppy_client_disconnect()
+        raise mh.MegaError(
+            'Error connecting to Guppy server. Undefined error: "{}"'.format(
+                err_str))
+
+    def _pyguppy_client_assert_connected(self):
+        if self.client is None:
+            try:
+                self._pyguppy_client_connect()
+            except mh.MegaError:
+                raise mh.MegaError(
+                    'Unable to establish connection to Guppy server.')
+
+    def _pyguppy_client_disconnect(self):
+        self.client.disconnect()
+        self.client = None
+
     def _load_pyguppy(self, init_sig_len=1000):
         def _check_guppy_version(pyguppy_version_str):
             try:
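
Note: the four helpers above centralize the pyguppy client lifecycle: lazy construction, connection with bounded retries, a connected-state guard, and teardown. The retry idiom in isolation looks like the minimal sketch below; `connect_with_retries`, its parameters, and the client stand-in are illustrative names, not part of this commit.

# Minimal sketch of the bounded-retry connect idiom used above
# (hypothetical client stand-in; not part of this commit).
from time import sleep

def connect_with_retries(client, max_attempts=5, per_try_sleep=0.1):
    err_str = ''
    for _ in range(max_attempts):
        try:
            client.connect()
            return
        except (ConnectionError, ValueError, RuntimeError) as e:
            # remember the last error and retry after a short pause
            err_str = str(e)
            sleep(per_try_sleep)
    raise RuntimeError('failed to connect after {} attempts: "{}"'.format(
        max_attempts, err_str))
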
@@ -574,11 +615,8 @@ def get_server_port():
                 pyguppy=self.params.pyguppy._replace(port=used_port))
 
         def set_pyguppy_model_attributes():
-            init_client = self.pyguppy_GuppyBasecallerClient(
-                '{}:{}'.format(GUPPY_HOST, self.params.pyguppy.port),
-                self.params.pyguppy.config, **PYGUPPY_CLIENT_KWARGS)
+            self._pyguppy_client_connect()
             try:
-                init_client.connect()
                 init_read = [(SIGNAL_DATA(
                     fast5_fn='init_test_read', read_id='init_test_read',
                     raw_len=0, dacs=np.zeros(init_sig_len, dtype=np.int16),
@@ -587,7 +625,7 @@ def set_pyguppy_model_attributes():
                         mh.CHAN_INFO_DIGI: 1}), None)]
                 try:
                     init_called_read, _, _ = next(self.pyguppy_basecall(
-                        init_client, init_read))
+                        init_read))
                 except mh.MegaError:
                     raise mh.MegaError(
                         'Failed to run test read with Guppy. See Guppy logs ' +
@@ -606,7 +644,7 @@ def set_pyguppy_model_attributes():
                         'Error connecting to Guppy server. Undefined error: ' +
                         str(e))
             finally:
-                init_client.disconnect()
+                self._pyguppy_client_disconnect()
 
             if init_called_read.model_type not in COMPAT_GUPPY_MODEL_TYPES:
                 raise mh.MegaError((
@@ -623,6 +661,7 @@ def set_pyguppy_model_attributes():
         LOGGER.info('Loading guppy basecalling backend')
         self.model_type = PYGUPPY_NAME
         self.process_devices = [None, ] * self.num_proc
+        self.client = None
 
         # load necessary packages and store in object attrs
         self.pyguppy_retries = max(
@@ -671,12 +710,7 @@ def prep_model_worker(self, device=None):
                 raise mh.MegaError('Error setting CUDA GPU device.')
             self.model = self.model.eval()
         elif self.model_type == PYGUPPY_NAME:
-            # open guppy client interface (None indicates using config
-            # from server)
-            self.client = self.pyguppy_GuppyBasecallerClient(
-                '{}:{}'.format(GUPPY_HOST, self.params.pyguppy.port),
-                self.params.pyguppy.config, **PYGUPPY_CLIENT_KWARGS)
-            self.client.connect()
+            self._pyguppy_client_connect()
 
     def extract_signal_info(self, fast5_fp, read_id, extract_dacs=False):
         """ Extract signal information from fast5 file pointer.
@@ -775,72 +809,59 @@ def _softmax_mod_weights(self, raw_mod_weights):
                 raw_mod_weights[:, lab_indices]))
         return np.concatenate(mod_layers, axis=1)
 
-    def pyguppy_basecall(self, client, reads_batch, failed_reads_q=None):
+    def pyguppy_basecall(self, reads_batch, failed_reads_q=None):
         def do_sleep():
             # function for profiling purposes
             sleep(PYGUPPY_PER_TRY_TIMEOUT)
 
         def get_completed_reads():
-            do_retry = True
-            n_reconnect_attempts = 0
             comp_reads = err_str = None
-            while do_retry:
-                try:
-                    try:
-                        comp_reads = client.get_completed_reads()
-                        do_retry = False
-                    except ConnectionError as e:
-                        if n_reconnect_attempts < PYGUPPY_MAX_RECONNECT_ATTEMPTS:
-                            LOGGER.debug('Reconnecting to server (get reads)')
-                            client.connect()
-                            n_reconnect_attempts += 1
-                        else:
-                            err_str = ('Pyguppy get completed reads connection ' +
-                                       'error "{}"').format(str(e))
-                            do_retry = False
-                except RuntimeError as e:
-                    err_str = ('Pyguppy get completed reads invalid error ' +
-                               '"{}"').format(str(e))
-                    do_retry = False
+            try:
+                try:
+                    self._pyguppy_client_assert_connected()
+                except mh.MegaError as e:
+                    err_str = str(e)
+                else:
+                    comp_reads = self.client.get_completed_reads()
+            except ConnectionError:
+                try:
+                    self._pyguppy_client_connect()
+                except mh.MegaError as e:
+                    err_str = str(e)
+            except RuntimeError as e:
+                err_str = ('Pyguppy get completed reads invalid error '
+                           '"{}"').format(str(e))
             return comp_reads, err_str

         saved_input_data = {}
         completed_reads = []
         for sig_info, seq_summ_info in reads_batch:
             err_str = None
-            read_sent = False
-            n_reconnect_attempts = 0
             pyguppy_read = get_pyguppy_read(
                 sig_info.read_id, sig_info.dacs, sig_info.channel_info)
-            while not read_sent:
-                try:
-                    try:
-                        read_sent = client.pass_read(pyguppy_read)
-                    except ValueError as e:
-                        err_str = ('Pyguppy pass read malformed error ' +
-                                   '"{}"').format(str(e))
-                        read_sent = True
-                except ConnectionError as e:
-                    # attempt to reconnect to server when connection error
-                    # occurs
-                    if n_reconnect_attempts < PYGUPPY_MAX_RECONNECT_ATTEMPTS:
-                        LOGGER.debug('Reconnecting to server (pass read)')
-                        client.connect()
-                        n_reconnect_attempts += 1
-                    else:
-                        err_str = ('Pyguppy pass read connection error ' +
-                                   '"{}"').format(str(e))
-                        read_sent = True
-                except RuntimeError as e:
-                    err_str = ('Pyguppy pass read undefined error ' +
-                               '"{}"').format(str(e))
-                    read_sent = True
-                # get completed reads while sending reads so server doesn't
-                # back up indefinitely
-                iter_comp_reads = get_completed_reads()[0]
-                if iter_comp_reads is not None:
-                    completed_reads.extend(iter_comp_reads)
+            try:
+                try:
+                    self._pyguppy_client_assert_connected()
+                except mh.MegaError as e:
+                    err_str = str(e)
+                else:
+                    read_sent = self.client.pass_read(pyguppy_read)
+                    if not read_sent:
+                        err_str = 'Guppy server unable to receive read'
+            except ValueError as e:
+                err_str = ('Pyguppy pass read malformed error ' +
+                           '"{}"').format(str(e))
+            except ConnectionError as e:
+                err_str = ('Pyguppy pass read connection error "{}"').format(
+                    str(e))
+                try:
+                    self._pyguppy_client_connect()
+                except mh.MegaError as e:
+                    err_str = ('Pyguppy pass read connection error (unable to '
+                               're-connect) "{}"').format(str(e))
+            except RuntimeError as e:
+                err_str = ('Pyguppy pass read undefined error ' +
+                           '"{}"').format(str(e))
             if err_str is None:
                 saved_input_data[sig_info.read_id] = (sig_info, seq_summ_info)
             else:
@@ -854,13 +875,19 @@ def get_completed_reads():
                 LOGGER.debug('{} BasecallingFailed "{}"'.format(
                     sig_info.read_id, err_str))
 
+            # get completed reads while sending reads so server doesn't
+            # back up indefinitely
+            comp_reads = get_completed_reads()[0]
+            if comp_reads is not None:
+                completed_reads.extend(comp_reads)
+
         # yield reads that have been called already
         for called_read in completed_reads:
             read_id = called_read["metadata"]["read_id"]
             try:
                 sig_info, seq_summ_info = saved_input_data[read_id]
             except KeyError:
-                # read submitted in last batch now finished
+                # read submitted in previous batch now finished
                 LOGGER.debug('{} timeout read finished'.format(read_id))
                 continue
             LOGGER.debug('{} BasecallingCompleted'.format(read_id))
@@ -880,7 +907,7 @@ def get_completed_reads():
             try:
                 sig_info, seq_summ_info = saved_input_data[read_id]
             except KeyError:
-                # read submitted in last batch now finished
+                # read submitted in previous batch now finished
                 LOGGER.debug('{} timeout read finished'.format(read_id))
                 continue
             LOGGER.debug('{} BasecallingCompleted'.format(read_id))
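
Note: both the send path (pass_read) and the receive path (get_completed_reads) above now share one idiom: check the connection guard, make the call in an else block so its errors are not mistaken for guard failures, and reconnect once on ConnectionError so the next batch can proceed. A self-contained sketch of that idiom, with a hypothetical client_holder object standing in for the backend:

def call_with_reconnect(client_holder, request):
    # client_holder is a hypothetical object with ensure_connected(),
    # reconnect() and call(); failures are reported via err_str, not raised
    result = err_str = None
    try:
        try:
            client_holder.ensure_connected()
        except RuntimeError as e:
            err_str = str(e)
        else:
            result = client_holder.call(request)
    except ConnectionError:
        try:
            client_holder.reconnect()
        except RuntimeError as e:
            err_str = str(e)
    return result, err_str
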
@@ -995,7 +1022,7 @@ def _run_pyguppy_backend(
                 'initialization.')
 
         for called_read, sig_info, seq_summ_info in self.pyguppy_basecall(
-                self.client, reads_batch, failed_reads_q):
+                reads_batch, failed_reads_q):
             try:
                 yield self._postprocess_pyguppy_called_read(
                     called_read, sig_info, seq_summ_info, return_post_w_mods,
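
Note: taken together, the backends.py changes move connection state onto the backend object, so callers no longer thread a client handle through every call. A rough usage sketch under that reading (the surrounding scaffolding, including reads_batch and process(), is hypothetical):

def run_worker(backend, reads_batch, process):
    # backend: a pyguppy-backed ModelInfo-like object from this module;
    # reads_batch and process() are placeholder inputs for illustration
    backend.prep_model_worker(device=None)  # connects the shared client
    try:
        for called_read, sig_info, seq_summ_info in backend.pyguppy_basecall(
                reads_batch):
            process(called_read, sig_info, seq_summ_info)
    finally:
        backend._pyguppy_client_disconnect()  # drop client and connection
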
41 changes: 18 additions & 23 deletions megalodon/megalodon.py
@@ -598,26 +598,27 @@ def _get_fail_queue(
 def wait_for_completion(
         files_p, extract_sig_ps, signal_q, proc_reads_ps, map_read_ts,
         getter_qpcs, aux_failed_q, input_info):
-    def kill_all_proc():
-        for p in [files_p, ] + proc_reads_ps:
+    def kill_all_proc(msgs=None):
+        for p in [files_p, ] + extract_sig_ps + proc_reads_ps:
             if p.is_alive():
                 p.terminate()
                 p.join()
         for q in list(getter_qpcs.values()):
             if q.proc.is_alive():
                 q.proc.terminate()
                 q.proc.join()
+        sleep(0.01)
+        if msgs is not None:
+            for msg in msgs:
+                LOGGER.error(msg)
+            sys.exit(1)
 
     try:
         # wait for file enumeration process to finish first
         while files_p.is_alive():
             try:
                 aux_err = aux_failed_q.get(block=False)
-                kill_all_proc()
-                sleep(0.01)
-                for msg in aux_err:
-                    LOGGER.error(msg)
-                sys.exit(1)
+                kill_all_proc(aux_err)
             except queue.Empty:
                 # TODO check for failed workers and create mechanism to restart
                 sleep(1)
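
Note: folding the terminate/log/exit sequence into kill_all_proc(msgs=None) removes three near-identical teardown copies in this function. The consolidated pattern, reduced to a standalone sketch (the process list and logger setup are placeholders):

import sys
import logging
from time import sleep

LOGGER = logging.getLogger('example')

def kill_procs(procs, msgs=None):
    # terminate any still-live worker processes, then optionally
    # log the collected error messages and exit non-zero
    for p in procs:
        if p.is_alive():
            p.terminate()
            p.join()
    sleep(0.01)
    if msgs is not None:
        for msg in msgs:
            LOGGER.error(msg)
        sys.exit(1)
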
@@ -627,15 +628,17 @@ def kill_all_proc():
 
         # wait for signal extraction to finish next
         while any(p.is_alive() for p in extract_sig_ps):
+            # if no worker processes are alive run will stall
+            if all(not p.is_alive() for p in proc_reads_ps):
+                # ensure this is really a stalled run and not lagging to
+                # close other processes
+                sleep(1)
+                if any(p.is_alive() for p in extract_sig_ps):
+                    kill_all_proc()
             try:
                 aux_err = aux_failed_q.get(block=False)
-                kill_all_proc()
-                sleep(0.01)
-                for msg in aux_err:
-                    LOGGER.error(msg)
-                sys.exit(1)
+                kill_all_proc(aux_err)
             except queue.Empty:
                 # TODO check for failed workers and create mechanism to restart
                 sleep(1)
         LOGGER.debug('JoiningMain: SignalExtractors')
         for extract_sig_p in extract_sig_ps:
@@ -651,24 +654,16 @@ def kill_all_proc():
             try:
                 aux_err = aux_failed_q.get(block=False)
                 # if an auxiliary process fails exit megalodon
-                kill_all_proc()
-                sleep(0.01)
-                for msg in aux_err:
-                    LOGGER.error(msg)
-                sys.exit(1)
+                kill_all_proc(aux_err)
             except queue.Empty:
                 # check that a getter queue has not failed with a segfault
                 for g_name, getter_qpc in getter_qpcs.items():
                     if not getter_qpc.proc.is_alive() and \
                             not signal_q.queue.empty():
-                        kill_all_proc()
-                        sleep(0.01)
-                        LOGGER.error((
-                            '{} Getter queue has unexpectedly died likely ' +
-                            'via a segfault error. Please log this ' +
-                            'issue.').format(g_name))
-                        sys.exit(1)
-                # TODO check for failed workers and create mechanism to restart
+                        kill_all_proc([(
+                            '{} Getter queue has unexpectedly died likely ' +
+                            'via a segfault error. Please log this ' +
+                            'issue.').format(g_name)])
                 sleep(1)
 
         LOGGER.debug('JoiningMain: Workers')
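
Note: the second wait loop above also gains stall detection: when every read-processing worker has died while signal extractors are still alive, the pipeline can never drain, so it is shut down instead of hanging. That check, reduced to a standalone sketch (the process lists are placeholders):

from time import sleep

def run_is_stalled(extract_sig_ps, proc_reads_ps):
    # a run is stalled when producers remain alive but every consumer died
    if all(not p.is_alive() for p in proc_reads_ps):
        # re-check after a grace period to avoid racing a normal shutdown
        sleep(1)
        return any(p.is_alive() for p in extract_sig_ps)
    return False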