Skip to content

Commit

Permalink
Merge branch 'sec_align'
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus1487 committed Feb 1, 2021
2 parents d27e4a3 + a45cfce commit ccd9c5b
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 98 deletions.
9 changes: 9 additions & 0 deletions megalodon/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ def hidden_help(help_msg):
help='Reference FASTA or minimap2 index file used for mapping ' +
'called reads.')

map_grp.add_argument(
'--allow-supplementary-alignments', action='store_true',
help=hidden_help('Allow alignments aside from the primary alignment ' +
'to be processed. Note that this may result in ' +
'multiple modified base calls from the same read ' +
'at the same read and/or reference position.'))
map_grp.add_argument(
'--forward-strand-alignments-only', action='store_true',
help=hidden_help('Only allow forward strand alignments.'))
map_grp.add_argument(
'--cram-reference',
help=hidden_help('FASTA reference file. If --reference is a ' +
Expand Down
128 changes: 79 additions & 49 deletions megalodon/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@
'chrm', 'strand', 'start', 'end', 'q_trim_start', 'q_trim_end'))
MAP_RES = namedtuple('MAP_RES', (
'read_id', 'q_seq', 'ref_seq', 'ctg', 'strand', 'r_st', 'r_en',
'q_st', 'q_en', 'cigar', 'map_sig_start', 'map_sig_end', 'sig_len'))
MAP_RES.__new__.__defaults__ = (None, None, None)
'q_st', 'q_en', 'cigar', 'map_sig_start', 'map_sig_end', 'sig_len',
'map_num'))
MAP_RES.__new__.__defaults__ = (None, None, None, 0)
MAP_SUMM = namedtuple('MAP_SUMM', (
'read_id', 'pct_identity', 'num_align', 'num_match',
'num_del', 'num_ins', 'read_pct_coverage', 'chrom', 'strand',
'start', 'end', 'query_start', 'query_end',
'map_sig_start', 'map_sig_end', 'sig_len'))
'map_sig_start', 'map_sig_end', 'sig_len', 'map_num'))
# Defaults for backwards compatibility when reading
MAP_SUMM.__new__.__defaults__ = (
None, None, None, None, None, None, None, None)
None, None, None, None, None, None, None, None, 0)
MAP_SUMM_TMPLT = (
'{0.read_id}\t{0.pct_identity:.2f}\t{0.num_align}\t{0.num_match}\t' +
'{0.num_del}\t{0.num_ins}\t{0.read_pct_coverage:.2f}\t{0.chrom}\t' +
'{0.strand}\t{0.start}\t{0.end}\t{0.query_start}\t{0.query_end}\t' +
'{0.map_sig_start}\t{0.map_sig_end}\t{0.sig_len}\n')
'{0.map_sig_start}\t{0.map_sig_end}\t{0.sig_len}\t{0.map_num}\n')
MAP_SUMM_TYPES = dict(zip(
MAP_SUMM._fields,
(str, float, int, int, int, int, float, str, str, int, int,
int, int, int, int, int)))
int, int, int, int, int, int)))

MOD_POS_TAG = 'Mm'
MOD_PROB_TAG = 'Ml'
Expand Down Expand Up @@ -112,7 +113,7 @@ def get_map_pos_from_res(map_res):
class MapInfo:
def __init__(
self, aligner, map_fmt, ref_fn, out_dir, do_output_mappings,
samtools_exec, do_sort_mappings, cram_ref_fn):
samtools_exec, do_sort_mappings, cram_ref_fn, allow_supps=False):
if aligner is None:
self.ref_names_and_lens = None
else:
Expand All @@ -130,6 +131,7 @@ def __init__(
self.do_output_mappings = do_output_mappings
self.samtools_exec = samtools_exec
self.do_sort_mappings = do_sort_mappings
self.allow_supps = allow_supps

def open_alignment_out_file(self):
map_fn = '{}.{}'.format(
Expand Down Expand Up @@ -183,24 +185,35 @@ def test_samtools(self):
self.do_sort_mappings = False


def align_read(q_seq, aligner, map_thr_buf, read_id=None):
try:
# enumerate all alignments to avoid memory leak from mappy
r_algn = list(aligner.map(str(q_seq), buf=map_thr_buf))[0]
except IndexError:
# alignment not produced
def align_read(
q_seq, aligner, map_thr_buf, read_id=None, allow_supps=False,
return_tuple=True):
def parse_alignment(r_algn, map_num=0):
ref_seq = aligner.seq(r_algn.ctg, r_algn.r_st, r_algn.r_en)
if r_algn.strand == -1:
ref_seq = mh.revcomp(ref_seq)
r_map_res = MAP_RES(
read_id=read_id, q_seq=q_seq, ref_seq=ref_seq, ctg=r_algn.ctg,
strand=r_algn.strand, r_st=r_algn.r_st, r_en=r_algn.r_en,
q_st=r_algn.q_st, q_en=r_algn.q_en, cigar=r_algn.cigar,
map_num=map_num)
if return_tuple:
return tuple(r_map_res)
return r_map_res

# enumerate all alignments to avoid memory leak from mappy
r_algns = list(aligner.map(str(q_seq), buf=map_thr_buf))
if len(r_algns) == 0:
# no alignments produced
return None

ref_seq = aligner.seq(r_algn.ctg, r_algn.r_st, r_algn.r_en)
if r_algn.strand == -1:
ref_seq = mh.revcomp(ref_seq)
return MAP_RES(
read_id=read_id, q_seq=q_seq, ref_seq=ref_seq, ctg=r_algn.ctg,
strand=r_algn.strand, r_st=r_algn.r_st, r_en=r_algn.r_en,
q_st=r_algn.q_st, q_en=r_algn.q_en, cigar=r_algn.cigar)
if allow_supps:
return [parse_alignment(r_algn, map_num)
for map_num, r_algn in enumerate(r_algns)]
return [parse_alignment(r_algns[0]), ]


def _map_read_worker(aligner, map_conn):
def _map_read_worker(aligner, map_conn, allow_supps):
LOGGER.debug('MappingWorkerStarting')
# get mappy aligner thread buffer
map_thr_buf = mappy.ThreadBuffer()
Expand All @@ -212,11 +225,8 @@ def _map_read_worker(aligner, map_conn):
except EOFError:
LOGGER.debug('MappingWorkerClosing')
break
map_res = align_read(q_seq, aligner, map_thr_buf, read_id)
if map_res is not None:
# only convert to tuple if result is valid
map_res = tuple(map_res)
map_conn.send(map_res)
map_conn.send(align_read(
q_seq, aligner, map_thr_buf, read_id, allow_supps))


def parse_cigar(r_cigar, strand, ref_len):
Expand Down Expand Up @@ -253,26 +263,11 @@ def parse_cigar(r_cigar, strand, ref_len):
return r_to_q_poss


def map_read(
caller_conn, called_read, sig_info, mo_q=None, signal_reversed=False,
rl_cumsum=None):
""" Map read (query) sequence
Returns:
Tuple containing
1) reference sequence (endcoded as int labels)
2) mapping from reference to read positions (after trimming)
3) reference mapping position (including read trimming positions)
4) cigar as produced by mappy
"""
# send seq to _map_read_worker and receive mapped seq and pos
q_seq = called_read.seq[::-1] if signal_reversed else called_read.seq
caller_conn.send((q_seq, sig_info.read_id))
map_res = caller_conn.recv()
if map_res is None:
raise mh.MegaError('No alignment')
def process_mapping(
map_res, called_read, sig_info, mo_q, signal_reversed, rl_cumsum):
map_res = MAP_RES(*map_res)
# add signal coordinates to mapping output if run-length cumsum provided
# add signal coordinates to mapping output if run-length cumsum
# provided
if rl_cumsum is not None:
# convert query start and end to signal-anchored locations
# Note that for signal_reversed reads, the start will be larger than
Expand Down Expand Up @@ -307,7 +302,32 @@ def map_read(
raise mh.MegaError('Invalid cigar string encountered.')
map_pos = get_map_pos_from_res(map_res)

return map_res.ref_seq, r_to_q_poss, map_pos, map_res.cigar
return (map_res.ref_seq, r_to_q_poss, map_pos, map_res.cigar,
map_res.map_num)


def map_read(
        caller_conn, called_read, sig_info, mo_q=None, signal_reversed=False,
        rl_cumsum=None):
    """ Map read (query) sequence and process every alignment returned

    The query sequence and read id are sent over ``caller_conn``; the reply
    is either ``None`` (no alignment) or an iterable with one mapping result
    per alignment. Each result is passed to ``process_mapping``.

    Args:
        caller_conn: connection providing ``send`` and ``recv``
        called_read: basecalled read; only ``seq`` is accessed here
        sig_info: read signal information; only ``read_id`` is accessed here
        mo_q: forwarded unchanged to ``process_mapping``
        signal_reversed: if True the query sequence is reversed before
            being sent for mapping
        rl_cumsum: forwarded unchanged to ``process_mapping``

    Returns:
        List with one entry per received alignment. Each entry is the tuple
        produced by ``process_mapping`` containing
        1) reference sequence
        2) mapping from reference to read positions (after trimming)
        3) reference mapping position (including read trimming positions)
        4) cigar as produced by mappy
        5) alignment number (``map_num``) of this mapping

    Raises:
        mh.MegaError: if no alignment was produced for this read
    """
    # send seq to _map_read_worker and receive mapped seq and pos
    q_seq = called_read.seq[::-1] if signal_reversed else called_read.seq
    caller_conn.send((q_seq, sig_info.read_id))
    map_ress = caller_conn.recv()
    if map_ress is None:
        raise mh.MegaError('No alignment')

    return [process_mapping(
        map_res, called_read, sig_info, mo_q, signal_reversed, rl_cumsum)
        for map_res in map_ress]


def compute_pct_identity(cigar):
Expand All @@ -334,6 +354,15 @@ def read_passes_filters(filt_params, read_len, q_st, q_en, cigar):
return True


def get_map_flag(strand, map_num):
    """ Compute the alignment record flag for a mapping

    Args:
        strand: integer strand of the alignment; -1 marks the reverse strand
        map_num: alignment number for the read; values greater than 0 are
            flagged as supplementary

    Returns:
        Integer bit flag: 0x10 (16) set for reverse strand alignments and
        0x800 (2048) set when ``map_num`` is greater than 0.
    """
    flag = 0
    if strand == -1:
        # SAM FLAG bit 0x10: sequence is reverse complemented
        flag |= 0x10
    if map_num > 0:
        # SAM FLAG bit 0x800: supplementary alignment
        flag |= 0x800
    return flag


def _get_map_queue(mo_q, mo_conn, map_info, ref_out_info, aux_failed_q):
def write_alignment(map_res):
# convert tuple back to namedtuple
Expand All @@ -354,7 +383,7 @@ def write_alignment(map_res):
a = prepare_mapping(
map_res.read_id,
q_seq if map_res.strand == 1 else mh.revcomp(q_seq),
flag=0 if map_res.strand == 1 else 16,
flag=get_map_flag(map_res.strand, map_res.map_num),
ref_id=map_fp.get_tid(map_res.ctg), ref_st=map_res.r_st,
cigartuples=[(op, op_l) for op_l, op in map_res.cigar],
tags=[('NM', nalign - nmatch)])
Expand All @@ -369,12 +398,13 @@ def write_alignment(map_res):
strand=mh.int_strand_to_str(map_res.strand), start=map_res.r_st,
end=map_res.r_st + nalign - nins, query_start=map_res.q_st,
query_end=map_res.q_en, map_sig_start=map_res.map_sig_start,
map_sig_end=map_res.map_sig_end, sig_len=map_res.sig_len)
map_sig_end=map_res.map_sig_end, sig_len=map_res.sig_len,
map_num=map_res.map_num)
summ_fp.write(MAP_SUMM_TMPLT.format(r_map_summ))

if ref_out_info.do_output.pr_refs and read_passes_filters(
ref_out_info.filt_params, len(map_res.q_seq), map_res.q_st,
map_res.q_en, map_res.cigar):
map_res.q_en, map_res.cigar) and map_res.map_num == 0:
pr_ref_fp.write('>{}\n{}\n'.format(
map_res.read_id, map_res.ref_seq))

Expand Down
105 changes: 68 additions & 37 deletions megalodon/megalodon.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,36 +150,10 @@ def interpolate_sig_pos(r_to_q_poss, mapped_rl_cumsum):
return ref_to_block


def process_read(
getter_qpcs, caller_conn, bc_res, ref_out_info, vars_info, mods_info,
bc_info):
""" Workhorse per-read megalodon function (connects all the parts)
"""
(sig_info, seq_summ_info, called_read, rl_cumsum, can_post, post_w_mods,
mods_scores) = bc_res
if bc_info.do_output.any:
# convert seq_summ_info to tuple since namedtuples can't be
# pickled for passing through a queue.
if bc_info.rev_sig:
# sequence is stored internally in sequencing direction. Send
# to basecall output in reference direction.
getter_qpcs[mh.BC_NAME].queue.put((
sig_info.read_id, called_read.seq[::-1],
called_read.qual[::-1], mods_scores, tuple(seq_summ_info)))
else:
getter_qpcs[mh.BC_NAME].queue.put((
sig_info.read_id, called_read.seq, called_read.qual,
mods_scores, tuple(seq_summ_info)))

# if no mapping connection return after basecalls are passed out
if caller_conn is None:
return

# map read and record mapping from reference to query positions
map_q = getter_qpcs[mh.MAP_NAME].queue \
if mh.MAP_NAME in getter_qpcs else None
r_ref_seq, r_to_q_poss, r_ref_pos, r_cigar = mapping.map_read(
caller_conn, called_read, sig_info, map_q, bc_info.rev_sig, rl_cumsum)
def process_mapping(
getter_qpcs, ref_out_info, vars_info, mods_info, bc_info, sig_info,
called_read, rl_cumsum, can_post, post_w_mods, r_ref_seq, r_to_q_poss,
r_ref_pos, r_cigar, map_num):
np_ref_seq = mh.seq_to_int(r_ref_seq, error_on_invalid=False)

failed_reads_q = getter_qpcs[_FAILED_READ_GETTER_NAME].queue
Expand All @@ -192,7 +166,8 @@ def process_read(
pass_sig_map_filts, sig_info.fast5_fn, sig_info.dacs,
sig_info.scale_params, r_ref_seq, sig_info.stride,
sig_info.read_id, r_to_q_poss, rl_cumsum, r_ref_pos, ref_out_info)
if ref_out_info.do_output.can_sig_maps and pass_sig_map_filts:
if ref_out_info.do_output.can_sig_maps and pass_sig_map_filts and \
map_num == 0:
try:
getter_qpcs[mh.SIG_MAP_NAME].queue.put(
signal_mapping.get_remapping(*sig_map_res[1:]))
Expand Down Expand Up @@ -238,15 +213,55 @@ def process_read(
func=mods.call_read_mods,
args=(r_ref_pos, r_ref_seq, ref_to_block, mapped_post_w_mods,
mods_info, mod_sig_map_q, sig_map_res, bc_info.rev_sig,
sig_info.read_id, failed_reads_q, sig_info.fast5_fn),
sig_info.read_id, failed_reads_q, sig_info.fast5_fn,
map_num),
r_vals=(sig_info.read_id, r_ref_pos.chrm, r_ref_pos.strand,
r_ref_pos.start, r_ref_seq, len(called_read.seq),
r_ref_pos.q_trim_start, r_ref_pos.q_trim_end, r_cigar),
r_ref_pos.q_trim_start, r_ref_pos.q_trim_end, r_cigar,
map_num),
out_q=getter_qpcs[mh.PR_MOD_NAME].queue,
fast5_fn=sig_info.fast5_fn + ':::' + sig_info.read_id,
failed_reads_q=failed_reads_q)


def process_read(
        getter_qpcs, caller_conn, bc_res, ref_out_info, vars_info, mods_info,
        bc_info):
    """ Workhorse per-read megalodon function (connects all the parts)

    Basecall results are first passed to the basecall output queue (when any
    basecall output is requested). If a mapping connection is available the
    read is then mapped and every mapping returned by ``mapping.map_read`` is
    handed to ``process_mapping``.

    Args:
        getter_qpcs: mapping from output name to an object with a ``queue``
        caller_conn: connection to the mapping worker, or None to skip
            mapping
        bc_res: 7-item basecall result (sig_info, seq_summ_info,
            called_read, rl_cumsum, can_post, post_w_mods, mods_scores)
        ref_out_info, vars_info, mods_info: forwarded to ``process_mapping``
        bc_info: basecall settings; ``do_output.any`` and ``rev_sig`` are
            read here

    Returns:
        None
    """
    (sig_info, seq_summ_info, called_read, rl_cumsum, can_post, post_w_mods,
     mods_scores) = bc_res
    if bc_info.do_output.any:
        bc_seq, bc_qual = called_read.seq, called_read.qual
        if bc_info.rev_sig:
            # sequence is stored internally in sequencing direction. Send
            # to basecall output in reference direction.
            bc_seq, bc_qual = bc_seq[::-1], bc_qual[::-1]
        # convert seq_summ_info to tuple since namedtuples can't be
        # pickled for passing through a queue.
        getter_qpcs[mh.BC_NAME].queue.put((
            sig_info.read_id, bc_seq, bc_qual, mods_scores,
            tuple(seq_summ_info)))

    # if no mapping connection return after basecalls are passed out
    if caller_conn is None:
        return

    # map read and record mapping from reference to query positions
    map_q = getter_qpcs[mh.MAP_NAME].queue \
        if mh.MAP_NAME in getter_qpcs else None
    # one iteration per alignment returned for this read
    for (r_ref_seq, r_to_q_poss, r_ref_pos, r_cigar,
         r_map_num) in mapping.map_read(
            caller_conn, called_read, sig_info, map_q, bc_info.rev_sig,
            rl_cumsum):
        process_mapping(
            getter_qpcs, ref_out_info, vars_info, mods_info, bc_info, sig_info,
            called_read, rl_cumsum, can_post, post_w_mods, r_ref_seq,
            r_to_q_poss, r_ref_pos, r_cigar, r_map_num)


########################
# Process reads worker #
########################
Expand Down Expand Up @@ -780,7 +795,8 @@ def process_all_reads(
if aligner is not None:
for ti, map_conn in enumerate(map_conns):
map_read_ts.append(threading.Thread(
target=mapping._map_read_worker, args=(aligner, map_conn),
target=mapping._map_read_worker,
args=(aligner, map_conn, map_info.allow_supps),
daemon=True, name='Mapper{:03d}'.format(ti)))
map_read_ts[-1].start()

Expand All @@ -807,8 +823,22 @@ def parse_aligner_args(args):
LOGGER.error('Provided reference file does not exist or is ' +
'not a file.')
sys.exit(1)
aligner = mappy.Aligner(
str(args.reference), preset=str('map-ont'), best_n=1)
aligner_kwargs = {'preset': str('map-ont')}
if args.allow_supplementary_alignments:
LOGGER.warning(
'--allow-supplementary-alignments option is set. This '
'allows modified base and variant calls to be made from the '
'same read base at multiple reference bases and/or the same '
'reference base from multiple read bases in the same read. '
'This can lead to over-counting of modified base/variant '
'calls and/or spurious extra calls. There are use cases for '
'this behavior, but these caveats should be considered when '
'using this option.')
else:
aligner_kwargs.update({'best_n': 1})
if args.forward_strand_alignments_only:
aligner_kwargs.update({'extra_flags': 0x100000})
aligner = mappy.Aligner(str(args.reference), **aligner_kwargs)
else:
aligner = None
if args.reference is not None:
Expand All @@ -820,7 +850,8 @@ def parse_aligner_args(args):
out_dir=args.output_directory,
do_output_mappings=mh.MAP_NAME in args.outputs,
samtools_exec=args.samtools_executable,
do_sort_mappings=args.sort_mappings, cram_ref_fn=args.cram_reference)
do_sort_mappings=args.sort_mappings, cram_ref_fn=args.cram_reference,
allow_supps=args.allow_supplementary_alignments)
if map_info.do_output_mappings:
try:
map_info.test_open_alignment_out_file()
Expand Down
Loading

0 comments on commit ccd9c5b

Please sign in to comment.