Skip to content

Commit

Permalink
Merge branch 'DOR-424_adapter_trimming_in_basecaller' into 'master'
Browse files Browse the repository at this point in the history
DOR-424 Adapter trimming in basecaller

Closes DOR-424

See merge request machine-learning/dorado!749
  • Loading branch information
kdolan1973 committed Dec 4, 2023
2 parents 6518087 + f30cece commit 30e639c
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 22 deletions.
70 changes: 60 additions & 10 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "nn/CRFModelConfig.h"
#include "nn/ModBaseRunner.h"
#include "nn/Runners.h"
#include "read_pipeline/AdapterDetectorNode.h"
#include "read_pipeline/AlignerNode.h"
#include "read_pipeline/BarcodeClassifierNode.h"
#include "read_pipeline/HtsReader.h"
Expand Down Expand Up @@ -77,6 +78,8 @@ void setup(std::vector<std::string> args,
const std::vector<std::string>& barcode_kits,
bool barcode_both_ends,
bool barcode_no_trim,
bool adapter_no_trim,
bool primer_no_trim,
const std::string& barcode_sample_sheet,
const std::optional<std::string>& custom_kit,
const std::optional<std::string>& custom_seqs,
Expand Down Expand Up @@ -126,9 +129,10 @@ void setup(std::vector<std::string> args,
auto read_groups = DataLoader::load_read_groups(data_path, model_name, modbase_model_names,
recursive_file_loading);

bool adapter_trimming_enabled = (!adapter_no_trim || !primer_no_trim);
const auto thread_allocations = utils::default_thread_allocations(
int(num_devices), !remora_runners.empty() ? int(num_remora_threads) : 0, enable_aligner,
!barcode_kits.empty());
!barcode_kits.empty(), adapter_trimming_enabled);

std::unique_ptr<const utils::SampleSheet> sample_sheet;
BarcodingInfo::FilterSet allowed_barcodes;
Expand Down Expand Up @@ -167,6 +171,11 @@ void setup(std::vector<std::string> args,
barcode_both_ends, barcode_no_trim, std::move(allowed_barcodes),
std::move(custom_kit), std::move(custom_seqs));
}
// When adapter and/or primer trimming is enabled, append an AdapterDetectorNode after the
// current sink and make it the new sink, so that nodes added afterwards are attached to it.
// The node receives the adapter thread allocation plus two flags derived from the
// "no trim" switches (presumably "trim adapters" and "trim primers" - confirm against the
// AdapterDetectorNode constructor).
if (adapter_trimming_enabled) {
current_sink_node = pipeline_desc.add_node<AdapterDetectorNode>(
{current_sink_node}, thread_allocations.adapter_threads, !adapter_no_trim,
!primer_no_trim);
}
current_sink_node = pipeline_desc.add_node<ReadFilterNode>(
{current_sink_node}, min_qscore, default_parameters.min_sequence_length,
std::unordered_set<std::string>{}, thread_allocations.read_filter_threads);
Expand Down Expand Up @@ -386,9 +395,19 @@ int basecaller(int argc, char* argv[]) {
.default_value(false)
.implicit_value(true);
parser.visible.add_argument("--no-trim")
.help("Skip barcode trimming. If option is not chosen, trimming is enabled.")
.help("Skip trimming of barcodes, adapters, and primers. If option is not chosen, "
"trimming of all three is enabled.")
.default_value(false)
.implicit_value(true);
// Register the --trim option. It takes a string naming what to trim; the default is the
// empty string. The help text below is user-facing.
parser.visible.add_argument("--trim")
.help("Specify what to trim. Options are 'none', 'all', 'adapters', and 'primers'. "
"Default behavior is to trim all detected adapters, primers, or barcodes. "
"Choose 'adapters' to just trim adapters. The 'primers' choice will trim "
"adapters and "
"primers, but not barcodes. The 'none' choice is equivalent to using --no-trim. "
"Note that "
"this only applies to DNA. RNA adapters are always trimmed.")
.default_value(std::string(""));
parser.visible.add_argument("--sample-sheet")
.help("Path to the sample sheet to use.")
.default_value(std::string(""));
Expand All @@ -399,9 +418,9 @@ int basecaller(int argc, char* argv[]) {
.help("Path to file with custom barcode sequences.")
.default_value(std::nullopt);
parser.visible.add_argument("--estimate-poly-a")
.help("Estimate poly-A/T tail lengths (beta feature). Primarily meant "
"for cDNA and "
"dRNA use cases.")
.help("Estimate poly-A/T tail lengths (beta feature). Primarily meant for cDNA and "
"dRNA use cases. Note that if this flag is set, then adapter/primer detection "
"will be disabled.")
.default_value(false)
.implicit_value(true);

Expand Down Expand Up @@ -482,6 +501,38 @@ int basecaller(int argc, char* argv[]) {
output_mode = HtsWriter::OutputMode::UBAM;
}

// Resolve --no-trim, --trim and --estimate-poly-a into three independent "do not trim"
// flags (barcodes, primers, adapters). All flags start as false, i.e. everything is
// trimmed unless one of the options below says otherwise.
bool no_trim_barcodes = false, no_trim_primers = false, no_trim_adapters = false;
auto trim_options = parser.visible.get<std::string>("--trim");
if (parser.visible.get<bool>("--no-trim")) {
// --no-trim and a non-empty --trim value are mutually exclusive.
if (!trim_options.empty()) {
spdlog::error("Only one of --no-trim and --trim can be used.");
std::exit(EXIT_FAILURE);
}
no_trim_barcodes = no_trim_primers = no_trim_adapters = true;
}
// Map the --trim value onto the flags:
//   "none"     -> trim nothing (same effect as --no-trim)
//   "primers"  -> only barcode trimming is disabled
//   "adapters" -> barcode and primer trimming are disabled
//   "" / "all" -> flags are left unchanged
// Any other value is rejected.
if (trim_options == "none") {
no_trim_barcodes = no_trim_primers = no_trim_adapters = true;
} else if (trim_options == "primers") {
no_trim_barcodes = true;
} else if (trim_options == "adapters") {
no_trim_barcodes = no_trim_primers = true;
} else if (!trim_options.empty() && trim_options != "all") {
spdlog::error("Unsupported --trim value '{}'.", trim_options);
std::exit(EXIT_FAILURE);
}
// Poly-A estimation forces primer and adapter trimming off. Explicitly asking for a
// --trim mode that includes them is treated as a conflict and is an error; otherwise
// the two flags are set and an informational message is logged. The barcode flag is
// not modified here.
if (parser.visible.get<bool>("--estimate-poly-a")) {
if (trim_options == "primers" || trim_options == "adapters" || trim_options == "all") {
spdlog::error(
"--trim cannot be used with options 'primers', 'adapters', or 'all', "
"if you are also using --estimate-poly-a.");
std::exit(EXIT_FAILURE);
}
no_trim_primers = no_trim_adapters = true;
spdlog::info(
"Estimation of poly-a has been requested, so adapter/primer trimming has been "
"disabled.");
}

if (parser.visible.is_used("--kit-name") && parser.visible.is_used("--barcode-arrangement")) {
spdlog::error(
"--kit-name and --barcode-arrangement cannot be used together. Please provide only "
Expand Down Expand Up @@ -548,11 +599,10 @@ int basecaller(int argc, char* argv[]) {
parser.hidden.get<std::string>("--dump_stats_filter"),
parser.visible.get<std::string>("--resume-from"),
parser.visible.get<std::vector<std::string>>("--kit-name"),
parser.visible.get<bool>("--barcode-both-ends"),
parser.visible.get<bool>("--no-trim"),
parser.visible.get<std::string>("--sample-sheet"), std::move(custom_kit),
std::move(custom_seqs), resume_parser, parser.visible.get<bool>("--estimate-poly-a"),
model_selection);
parser.visible.get<bool>("--barcode-both-ends"), no_trim_barcodes, no_trim_adapters,
no_trim_primers, parser.visible.get<std::string>("--sample-sheet"),
std::move(custom_kit), std::move(custom_seqs), resume_parser,
parser.visible.get<bool>("--estimate-poly-a"), model_selection);
} catch (const std::exception& e) {
spdlog::error("{}", e.what());
return 1;
Expand Down
4 changes: 3 additions & 1 deletion dorado/demux/AdapterDetector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ const std::vector<Primer> primers = {
{"PCS110_forward",
"TCGCCTACCGTGACAAGAAAGTTGTCGGTGTCTTTGTGACTTGCCTGTCGCTCTATCTTCAGAGGAGAGTCCGCCGCCCGCAAGTTT"},
{"PCS110_reverse", "ATCGCCTACCGTGACAAGAAAGTTGTCGGTGTCTTTGTGTTTCTGTTGGTGCTGATATTGCTTT"},
{"RAD", "GCTTGGGTGTTTAACCGTTTTCGCATTTATCGTGAAACGCTTTCGCGTTTTTCGTGCGCCGCTTCA"}};
// Not included because it is too similar to RBK barcode flank
// {"RAD", "GCTTGGGTGTTTAACCGTTTTCGCATTTATCGTGAAACGCTTTCGCGTTTTTCGTGCGCCGCTTCA"}
};

} // namespace

Expand Down
1 change: 0 additions & 1 deletion dorado/read_pipeline/AdapterDetectorNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ void AdapterDetectorNode::process_read(SimplexRead& read) {
auto primer_res = m_detector.find_primers(read.read_common.seq);
primer_trim_interval = Trimmer::determine_trim_interval(primer_res, seqlen);
}
read.read_common.pre_trim_seq_length = read.read_common.seq.length();
if (m_trim_adapters || m_trim_primers) {
std::pair<int, int> trim_interval = adapter_trim_interval;
trim_interval.first = std::max(trim_interval.first, primer_trim_interval.first);
Expand Down
1 change: 0 additions & 1 deletion dorado/read_pipeline/BarcodeClassifierNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ void BarcodeClassifierNode::barcode(SimplexRead& read) {
barcoding_info->allowed_barcodes);
read.read_common.barcode = generate_barcode_string(bc_res);
read.read_common.barcoding_result = std::make_shared<BarcodeScoreResult>(std::move(bc_res));
read.read_common.pre_trim_seq_length = read.read_common.seq.length();
if (barcoding_info->trim) {
read.read_common.barcode_trim_interval = Trimmer::determine_trim_interval(
*read.read_common.barcoding_result, int(read.read_common.seq.length()));
Expand Down
1 change: 1 addition & 0 deletions dorado/read_pipeline/BasecallerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ void BasecallerNode::working_reads_manager() {
utils::stitch_chunks(read_common_data, working_read->called_chunks);
read_common_data.model_name = m_model_name;
read_common_data.mean_qscore_start_pos = m_mean_qscore_start_pos;
read_common_data.pre_trim_seq_length = read_common_data.seq.length();

if (m_rna) {
std::reverse(read_common_data.seq.begin(), read_common_data.seq.end());
Expand Down
17 changes: 9 additions & 8 deletions dorado/utils/parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace dorado::utils {
ThreadAllocations default_thread_allocations(int num_devices,
int num_remora_threads,
bool enable_aligner,
bool enable_barcoder) {
bool enable_barcoder,
bool adapter_trimming) {
const int max_threads = std::thread::hardware_concurrency();
ThreadAllocations allocs;
allocs.writer_threads = num_devices * 2;
Expand All @@ -24,13 +25,13 @@ ThreadAllocations default_thread_allocations(int num_devices,
allocs.splitter_node_threads);
int remaining_threads = max_threads - total_threads_used;
remaining_threads = std::max(num_devices * 10, remaining_threads);
// Divide up work equally between the aligner and barcoder nodes if both are enabled,
// otherwise both get all the remaining threads.
if (enable_aligner || enable_barcoder) {
allocs.aligner_threads =
remaining_threads * enable_aligner / (enable_aligner + enable_barcoder);
allocs.barcoder_threads =
remaining_threads * enable_barcoder / (enable_aligner + enable_barcoder);
// Split the remaining threads evenly among whichever of the aligner, barcoder, and
// adapter-trimming nodes are enabled. A disabled node's multiplier is 0, so its thread
// count is computed as 0. Integer division is used, so any remainder is left unassigned.
if (enable_aligner || enable_barcoder || adapter_trimming) {
// The enclosing condition guarantees at least one flag is true, so number_enabled >= 1
// and the divisions below cannot divide by zero.
int number_enabled = int(enable_aligner) + int(enable_barcoder) + int(adapter_trimming);
allocs.aligner_threads = remaining_threads * int(enable_aligner) / number_enabled;
allocs.barcoder_threads = remaining_threads * int(enable_barcoder) / number_enabled;
allocs.adapter_threads = remaining_threads * int(adapter_trimming) / number_enabled;
}
return allocs;
};
Expand Down
4 changes: 3 additions & 1 deletion dorado/utils/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ struct ThreadAllocations {
int loader_threads{0};
int aligner_threads{0};
int barcoder_threads{0};
int adapter_threads{0};
};

ThreadAllocations default_thread_allocations(int num_devices,
int num_remora_threads,
bool enable_aligner,
bool enable_barcoder);
bool enable_barcoder,
bool adapter_trimming);

} // namespace dorado::utils

0 comments on commit 30e639c

Please sign in to comment.