diff --git a/megalodon/backends.py b/megalodon/backends.py
index 5a4a855..43abffb 100755
--- a/megalodon/backends.py
+++ b/megalodon/backends.py
@@ -480,6 +480,47 @@ def get_model_info_from_fast5(read):

         self._parse_minimal_alphabet_info()

+    def _pyguppy_client_init(self):
+        self.client = self.pyguppy_GuppyBasecallerClient(
+            '{}:{}'.format(GUPPY_HOST, self.params.pyguppy.port),
+            self.params.pyguppy.config, **PYGUPPY_CLIENT_KWARGS)
+
+    def _pyguppy_client_connect(
+            self, max_reconnect_attempts=PYGUPPY_MAX_RECONNECT_ATTEMPTS,
+            per_try_sleep=0.1):
+        if self.client is None:
+            self._pyguppy_client_init()
+        n_attempts = 0
+        err_str = ''
+        while n_attempts < max_reconnect_attempts:
+            try:
+                LOGGER.debug('Connecting to server')
+                self.client.connect()
+                return
+            except (ConnectionError, ValueError, RuntimeError) as e:
+                err_str = str(e)
+            n_attempts += 1
+            sleep(per_try_sleep)
+        LOGGER.debug(
+            'Failed to connect to Guppy server after {} attempts'.format(
+                max_reconnect_attempts))
+        self._pyguppy_client_disconnect()
+        raise mh.MegaError(
+            'Error connecting to Guppy server. Undefined error: "{}"'.format(
+                err_str))
+
+    def _pyguppy_client_assert_connected(self):
+        if self.client is None:
+            try:
+                self._pyguppy_client_connect()
+            except mh.MegaError:
+                raise mh.MegaError(
+                    'Unable to establish connection to Guppy server.')
+
+    def _pyguppy_client_disconnect(self):
+        self.client.disconnect()
+        self.client = None
+
     def _load_pyguppy(self, init_sig_len=1000):
         def _check_guppy_version(pyguppy_version_str):
             try:
@@ -574,11 +615,8 @@ def get_server_port():
             pyguppy=self.params.pyguppy._replace(port=used_port))

         def set_pyguppy_model_attributes():
-            init_client = self.pyguppy_GuppyBasecallerClient(
-                '{}:{}'.format(GUPPY_HOST, self.params.pyguppy.port),
-                self.params.pyguppy.config, **PYGUPPY_CLIENT_KWARGS)
+            self._pyguppy_client_connect()
             try:
-                init_client.connect()
                 init_read = [(SIGNAL_DATA(
                     fast5_fn='init_test_read', read_id='init_test_read',
                     raw_len=0, dacs=np.zeros(init_sig_len, dtype=np.int16),
@@ -587,7 +625,7 @@ def set_pyguppy_model_attributes():
                     mh.CHAN_INFO_DIGI: 1}), None)]
             try:
                 init_called_read, _, _ = next(self.pyguppy_basecall(
-                    init_client, init_read))
+                    init_read))
             except mh.MegaError:
                 raise mh.MegaError(
                     'Failed to run test read with Guppy. See Guppy logs ' +
@@ -606,7 +644,7 @@ def set_pyguppy_model_attributes():
                     'Error connecting to Guppy server. Undefined error: ' +
                     str(e))
             finally:
-                init_client.disconnect()
+                self._pyguppy_client_disconnect()

             if init_called_read.model_type not in COMPAT_GUPPY_MODEL_TYPES:
                 raise mh.MegaError((
@@ -623,6 +661,7 @@ def set_pyguppy_model_attributes():
         LOGGER.info('Loading guppy basecalling backend')
         self.model_type = PYGUPPY_NAME
         self.process_devices = [None, ] * self.num_proc
+        self.client = None

         # load necessary packages and store in object attrs
         self.pyguppy_retries = max(
@@ -671,12 +710,7 @@ def prep_model_worker(self, device=None):
                 raise mh.MegaError('Error setting CUDA GPU device.')
             self.model = self.model.eval()
         elif self.model_type == PYGUPPY_NAME:
-            # open guppy client interface (None indicates using config
-            # from server)
-            self.client = self.pyguppy_GuppyBasecallerClient(
-                '{}:{}'.format(GUPPY_HOST, self.params.pyguppy.port),
-                self.params.pyguppy.config, **PYGUPPY_CLIENT_KWARGS)
-            self.client.connect()
+            self._pyguppy_client_connect()

     def extract_signal_info(self, fast5_fp, read_id, extract_dacs=False):
         """ Extract signal information from fast5 file pointer.
@@ -775,72 +809,59 @@ def _softmax_mod_weights(self, raw_mod_weights):
                 raw_mod_weights[:, lab_indices]))
         return np.concatenate(mod_layers, axis=1)

-    def pyguppy_basecall(self, client, reads_batch, failed_reads_q=None):
+    def pyguppy_basecall(self, reads_batch, failed_reads_q=None):
         def do_sleep():
             # function for profiling purposes
             sleep(PYGUPPY_PER_TRY_TIMEOUT)

         def get_completed_reads():
-            do_retry = True
-            n_reconnect_attempts = 0
             comp_reads = err_str = None
-            while do_retry:
+            try:
                 try:
-                    comp_reads = client.get_completed_reads()
-                    do_retry = False
-                except ConnectionError as e:
-                    if n_reconnect_attempts < PYGUPPY_MAX_RECONNECT_ATTEMPTS:
-                        LOGGER.debug('Reconnecting to server (get reads)')
-                        client.connect()
-                        n_reconnect_attempts += 1
-                    else:
-                        err_str = ('Pyguppy pass read connection error ' +
-                                   '"{}"').format(str(e))
-                        do_retry = False
-                    err_str = ('Pyguppy get completed reads connection ' +
-                               'error "{}"').format(str(e))
-                except RuntimeError as e:
-                    err_str = ('Pyguppy get completed reads invalid error ' +
-                               '"{}"').format(str(e))
-                    do_retry = False
+                    self._pyguppy_client_assert_connected()
+                except mh.MegaError as e:
+                    err_str = str(e)
+                else:
+                    comp_reads = self.client.get_completed_reads()
+            except ConnectionError:
+                try:
+                    self._pyguppy_client_connect()
+                except mh.MegaError as e:
+                    err_str = str(e)
+            except RuntimeError as e:
+                err_str = ('Pyguppy get completed reads invalid error '
+                           '"{}"').format(str(e))
             return comp_reads, err_str

         saved_input_data = {}
         completed_reads = []
         for sig_info, seq_summ_info in reads_batch:
             err_str = None
-            read_sent = False
-            n_reconnect_attempts = 0
             pyguppy_read = get_pyguppy_read(
                 sig_info.read_id, sig_info.dacs, sig_info.channel_info)
-            while not read_sent:
+            try:
                 try:
-                    read_sent = client.pass_read(pyguppy_read)
-                except ValueError as e:
-                    err_str = ('Pyguppy pass read malformed error ' +
-                               '"{}"').format(str(e))
-                    read_sent = True
-                except ConnectionError as e:
-                    # attempt to reconnect to server when connection error
-                    # occurs
-                    if n_reconnect_attempts < PYGUPPY_MAX_RECONNECT_ATTEMPTS:
-                        LOGGER.debug('Reconnecting to server (pass read)')
-                        client.connect()
-                        n_reconnect_attempts += 1
-                    else:
-                        err_str = ('Pyguppy pass read connection error ' +
-                                   '"{}"').format(str(e))
-                        read_sent = True
-                except RuntimeError as e:
-                    err_str = ('Pyguppy pass read undefined error ' +
-                               '"{}"').format(str(e))
-                    read_sent = True
-            # get completed reads while sending reads so server doesn't
-            # back up indefinitely
-            iter_comp_reads = get_completed_reads()[0]
-            if iter_comp_reads is not None:
-                completed_reads.extend(iter_comp_reads)
-
+                    self._pyguppy_client_assert_connected()
+                except mh.MegaError as e:
+                    err_str = str(e)
+                else:
+                    read_sent = self.client.pass_read(pyguppy_read)
+                    if not read_sent:
+                        err_str = 'Guppy server unable to receive read'
+            except ValueError as e:
+                err_str = ('Pyguppy pass read malformed error ' +
+                           '"{}"').format(str(e))
+            except ConnectionError as e:
+                err_str = ('Pyguppy pass read connection error "{}"').format(
+                    str(e))
+                try:
+                    self._pyguppy_client_connect()
+                except mh.MegaError as e:
+                    err_str = ('Pyguppy pass read connection error (unable to '
+                               're-connect) "{}"').format(str(e))
+            except RuntimeError as e:
+                err_str = ('Pyguppy pass read undefined error ' +
+                           '"{}"').format(str(e))
             if err_str is None:
                 saved_input_data[sig_info.read_id] = (sig_info, seq_summ_info)
             else:
@@ -854,13 +875,19 @@ def get_completed_reads():
                 LOGGER.debug('{} BasecallingFailed "{}"'.format(
                     sig_info.read_id, err_str))

+        # get completed reads while sending reads so server doesn't
+        # back up indefinitely
+        comp_reads = get_completed_reads()[0]
+        if comp_reads is not None:
+            completed_reads.extend(comp_reads)
+
         # yield reads that have been called already
         for called_read in completed_reads:
             read_id = called_read["metadata"]["read_id"]
             try:
                 sig_info, seq_summ_info = saved_input_data[read_id]
             except KeyError:
-                # read submitted in last batch now finished
+                # read submitted in previous batch now finished
                 LOGGER.debug('{} timeout read finished'.format(read_id))
                 continue
             LOGGER.debug('{} BasecallingCompleted'.format(read_id))
@@ -880,7 +907,7 @@ def get_completed_reads():
             try:
                 sig_info, seq_summ_info = saved_input_data[read_id]
             except KeyError:
-                # read submitted in last batch now finished
+                # read submitted in previous batch now finished
                 LOGGER.debug('{} timeout read finished'.format(read_id))
                 continue
             LOGGER.debug('{} BasecallingCompleted'.format(read_id))
@@ -995,7 +1022,7 @@ def _run_pyguppy_backend(
                 'initialization.')

         for called_read, sig_info, seq_summ_info in self.pyguppy_basecall(
-                self.client, reads_batch, failed_reads_q):
+                reads_batch, failed_reads_q):
             try:
                 yield self._postprocess_pyguppy_called_read(
                     called_read, sig_info, seq_summ_info, return_post_w_mods,
diff --git a/megalodon/megalodon.py b/megalodon/megalodon.py
index 4572b27..95c2da3 100755
--- a/megalodon/megalodon.py
+++ b/megalodon/megalodon.py
@@ -598,8 +598,8 @@ def _get_fail_queue(
 def wait_for_completion(
         files_p, extract_sig_ps, signal_q, proc_reads_ps, map_read_ts,
         getter_qpcs, aux_failed_q, input_info):
-    def kill_all_proc():
-        for p in [files_p, ] + proc_reads_ps:
+    def kill_all_proc(msgs=None):
+        for p in [files_p, ] + extract_sig_ps + proc_reads_ps:
             if p.is_alive():
                 p.terminate()
                 p.join()
@@ -607,17 +607,18 @@ def kill_all_proc():
             if q.proc.is_alive():
                 q.proc.terminate()
                 q.proc.join()
+        sleep(0.01)
+        if msgs is not None:
+            for msg in msgs:
+                LOGGER.error(msg)
+            sys.exit(1)

     try:
         # wait for file enumeration process to finish first
         while files_p.is_alive():
             try:
                 aux_err = aux_failed_q.get(block=False)
-                kill_all_proc()
-                sleep(0.01)
-                for msg in aux_err:
-                    LOGGER.error(msg)
-                sys.exit(1)
+                kill_all_proc(aux_err)
             except queue.Empty:
                 # TODO check for failed workers and create mechanism to restart
                 sleep(1)
@@ -627,15 +628,17 @@ def kill_all_proc():

         # wait for signal extraction to finish next
         while any(p.is_alive() for p in extract_sig_ps):
+            # if no worker processes are alive run will stall
+            if all(not p.is_alive() for p in proc_reads_ps):
+                # ensure this is really a stalled run and not just other
+                # processes lagging while closing
+                sleep(1)
+                if any(p.is_alive() for p in extract_sig_ps):
+                    kill_all_proc()
             try:
                 aux_err = aux_failed_q.get(block=False)
-                kill_all_proc()
-                sleep(0.01)
-                for msg in aux_err:
-                    LOGGER.error(msg)
-                sys.exit(1)
+                kill_all_proc(aux_err)
             except queue.Empty:
-                # TODO check for failed workers and create mechanism to restart
                 sleep(1)
         LOGGER.debug('JoiningMain: SignalExtractors')
         for extract_sig_p in extract_sig_ps:
@@ -651,24 +654,16 @@ def kill_all_proc():
             try:
                 aux_err = aux_failed_q.get(block=False)
                 # if an auxiliary process fails exit megalodon
-                kill_all_proc()
-                sleep(0.01)
-                for msg in aux_err:
-                    LOGGER.error(msg)
-                sys.exit(1)
+                kill_all_proc(aux_err)
             except queue.Empty:
                 # check that a getter queue has not failed with a segfault
                 for g_name, getter_qpc in getter_qpcs.items():
                     if not getter_qpc.proc.is_alive() and \
                             not signal_q.queue.empty():
-                        kill_all_proc()
-                        sleep(0.01)
-                        LOGGER.error((
+                        kill_all_proc([(
                             '{} Getter queue has unexpectedly died likely ' +
                             'via a segfault error. Please log this ' +
-                            'issue.').format(g_name))
-                        sys.exit(1)
-                # TODO check for failed workers and create mechanism to restart
+                            'issue.').format(g_name)])
                 sleep(1)

     LOGGER.debug('JoiningMain: Workers')
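
The patch above centralizes connection management in the `_pyguppy_client_*` helpers: the client is created lazily, connected with a bounded retry loop, and re-checked before every server call, rather than each call site carrying its own reconnect loop. For reference only, a minimal self-contained sketch of that pattern (not part of the patch; `DemoError`, `FlakyServer`, and `DemoClient` are hypothetical stand-ins for `mh.MegaError` and the pyguppy client):

from time import sleep


class DemoError(Exception):
    """Stand-in for mh.MegaError."""


class FlakyServer:
    """Hypothetical server that refuses the first few connection attempts."""

    def __init__(self, failures=2):
        self.failures = failures

    def connect(self):
        if self.failures > 0:
            self.failures -= 1
            raise ConnectionError('server not ready')


class DemoClient:
    """Lazy connection with bounded retries, as in _pyguppy_client_connect."""

    def __init__(self, server, max_attempts=10, per_try_sleep=0.01):
        self.server = server
        self.max_attempts = max_attempts
        self.per_try_sleep = per_try_sleep
        self.connected = False

    def connect(self):
        err_str = ''
        for _ in range(self.max_attempts):
            try:
                self.server.connect()
                self.connected = True
                return
            except (ConnectionError, ValueError, RuntimeError) as e:
                err_str = str(e)
            sleep(self.per_try_sleep)
        self.connected = False
        raise DemoError('Error connecting to server: "{}"'.format(err_str))

    def assert_connected(self):
        # connect on first use or after a disconnect; in the patch, callers
        # such as pyguppy_basecall invoke this before every pass_read or
        # get_completed_reads call and convert failures to an error string
        if not self.connected:
            self.connect()


client = DemoClient(FlakyServer())
client.assert_connected()  # survives the two refused attempts, then connects

As in the patch, call sites only ever invoke the assert-connected helper and report a failure string per read instead of retrying inline, which keeps the per-read error handling flat.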