Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for queries with invalid characters #502

Merged
merged 11 commits into from
Oct 8, 2024
3 changes: 1 addition & 2 deletions metagraph/src/common/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ namespace utils {
size_t segment_length) {
std::vector<bool> mask(array.size(), false);
size_t last_occurrence
= std::find(array.data(), array.data() + array.size(), label)
- array.data();
= std::find(array.begin(), array.end(), label) - array.begin();

for (size_t i = last_occurrence; i < array.size(); ++i) {
if (array[i] == label)
Expand Down
6 changes: 3 additions & 3 deletions metagraph/src/graph/representation/canonical_dbg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ ::map_to_nodes_sequentially(std::string_view sequence,
path.reserve(sequence.size() - get_k() + 1);

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->map_to_nodes_with_rc<>(sequence, [&](node_index node, bool orientation) {
sshash->map_to_nodes_with_rc<true>(sequence, [&](node_index node, bool orientation) {
adamant-pwn marked this conversation as resolved.
Show resolved Hide resolved
callback(node && orientation ? reverse_complement(node) : node);
}, terminate);
return;
Expand Down Expand Up @@ -180,7 +180,7 @@ void CanonicalDBG::call_outgoing_kmers(node_index node,
}

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->call_outgoing_kmers_with_rc<>(node, [&](node_index next, char c, bool orientation) {
sshash->call_outgoing_kmers_with_rc<true>(node, [&](node_index next, char c, bool orientation) {
callback(orientation ? reverse_complement(next) : next, c);
});
return;
Expand Down Expand Up @@ -273,7 +273,7 @@ void CanonicalDBG::call_incoming_kmers(node_index node,
}

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->call_incoming_kmers_with_rc<>(node, [&](node_index prev, char c, bool orientation) {
sshash->call_incoming_kmers_with_rc<true>(node, [&](node_index prev, char c, bool orientation) {
callback(orientation ? reverse_complement(prev) : prev, c);
});
return;
Expand Down
54 changes: 38 additions & 16 deletions metagraph/src/graph/representation/hash/dbg_sshash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/seq_tools/reverse_complement.hpp"
#include "common/threads/threading.hpp"
#include "common/logger.hpp"
#include "common/algorithms.hpp"
#include "kmer/kmer_extractor.hpp"


Expand Down Expand Up @@ -99,32 +100,53 @@ void DBGSSHash::add_sequence(std::string_view sequence,
throw std::logic_error("adding sequences not supported");
}

template <bool with_rc>
void DBGSSHash::map_to_nodes_with_rc(std::string_view sequence,
const std::function<void(node_index, bool)>& callback,
const std::function<bool()>& terminate) const {
if (terminate() || sequence.size() < k_)
template <bool with_rc, class Dict>
void map_to_nodes_with_rc_impl(size_t k,
const Dict &dict,
std::string_view sequence,
const std::function<void(sshash::lookup_result)>& callback,
const std::function<bool()>& terminate) {
size_t n = sequence.size();
if (terminate() || n < k)
return;

if (!num_nodes()) {
for (size_t i = 0; i < sequence.size() - k_ + 1 && !terminate(); ++i) {
callback(npos, false);
if (!dict.size()) {
for (size_t i = 0; i + k <= sequence.size() && !terminate(); ++i) {
callback(sshash::lookup_result());
}
return;
}

using kmer_t = get_kmer_t<Dict>;

std::vector<bool> invalid_char(n);
for (size_t i = 0; i < n; ++i) {
invalid_char[i] = !kmer_t::is_valid(sequence[i]);
}

auto invalid_kmer = utils::drag_and_mark_segments(invalid_char, true, k);

kmer_t uint_kmer = sshash::util::string_to_uint_kmer<kmer_t>(sequence.data(), k - 1);
uint_kmer.pad_char();
for (size_t i = k - 1; i < n && !terminate(); ++i) {
uint_kmer.drop_char();
uint_kmer.kth_char_or(k - 1, kmer_t::char_to_uint(sequence[i]));
callback(invalid_kmer[i] ? sshash::lookup_result()
: dict.lookup_advanced_uint(uint_kmer, with_rc));
}
}

template <bool with_rc>
void DBGSSHash::map_to_nodes_with_rc(std::string_view sequence,
const std::function<void(node_index, bool)>& callback,
const std::function<bool()>& terminate) const {
std::visit([&](const auto &dict) {
using kmer_t = get_kmer_t<decltype(dict)>;
kmer_t uint_kmer = sshash::util::string_to_uint_kmer<kmer_t>(sequence.data(), k_ - 1);
uint_kmer.pad_char();
for (size_t i = k_ - 1; i < sequence.size() && !terminate(); ++i) {
uint_kmer.drop_char();
uint_kmer.kth_char_or(k_ - 1, kmer_t::char_to_uint(sequence[i]));
auto res = dict.lookup_advanced_uint(uint_kmer, with_rc);
map_to_nodes_with_rc_impl<with_rc>(k_, dict, sequence, [&](sshash::lookup_result res) {
callback(sshash_to_graph_index(res.kmer_id), res.kmer_orientation);
}
}, terminate);
}, dict_);
}

template
void DBGSSHash::map_to_nodes_with_rc<true>(std::string_view,
const std::function<void(node_index, bool)>&,
Expand Down
1 change: 1 addition & 0 deletions metagraph/tests/annotation/test_aligner_labeled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class LabeledAlignerTest : public ::testing::Test {};

typedef ::testing::Types<std::pair<DBGHashFast, annot::ColumnCompressed<>>,
std::pair<DBGSuccinct, annot::ColumnCompressed<>>,
std::pair<DBGSSHash, annot::ColumnCompressed<>>,
std::pair<DBGHashFast, annot::RowFlatAnnotator>,
std::pair<DBGSuccinct, annot::RowFlatAnnotator>,
std::pair<DBGSuccinct, annot::RowDiffColumnAnnotator>> FewGraphAnnotationPairTypes;
Expand Down
7 changes: 3 additions & 4 deletions metagraph/tests/annotation/test_annotated_dbg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
#include "gtest/gtest.h"

#include "../test_helpers.hpp"
#include "../graph/all/test_dbg_helpers.hpp"

#include "common/threads/threading.hpp"
#include "common/vectors/bit_vector_dyn.hpp"
#include "common/vectors/vector_algorithm.hpp"
#include "annotation/representation/column_compressed/annotate_column_compressed.hpp"
#include "graph/representation/bitmap/dbg_bitmap.hpp"
#include "graph/representation/hash/dbg_hash_string.hpp"
#include "graph/representation/hash/dbg_hash_ordered.hpp"
#include "graph/representation/hash/dbg_hash_fast.hpp"

#define protected public
#define private public
Expand Down Expand Up @@ -987,6 +984,7 @@ typedef ::testing::Types<std::pair<DBGBitmap, annot::ColumnCompressed<>>,
std::pair<DBGHashOrdered, annot::ColumnCompressed<>>,
std::pair<DBGHashFast, annot::ColumnCompressed<>>,
std::pair<DBGSuccinct, annot::ColumnCompressed<>>,
std::pair<DBGSSHash, annot::ColumnCompressed<>>,
std::pair<DBGBitmap, annot::RowFlatAnnotator>,
std::pair<DBGHashString, annot::RowFlatAnnotator>,
std::pair<DBGHashOrdered, annot::RowFlatAnnotator>,
Expand Down Expand Up @@ -1016,6 +1014,7 @@ class AnnotatedDBGNoNTest : public ::testing::Test {};
typedef ::testing::Types<std::pair<DBGBitmap, annot::ColumnCompressed<>>,
std::pair<DBGHashOrdered, annot::ColumnCompressed<>>,
std::pair<DBGHashFast, annot::ColumnCompressed<>>,
std::pair<DBGSSHash, annot::ColumnCompressed<>>,
std::pair<DBGBitmap, annot::RowFlatAnnotator>,
std::pair<DBGHashOrdered, annot::RowFlatAnnotator>,
std::pair<DBGHashFast, annot::RowFlatAnnotator>,
Expand Down
1 change: 1 addition & 0 deletions metagraph/tests/annotation/test_annotated_dbg_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGBitmap, ColumnCompres
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGHashOrdered, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGHashFast, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGHashString, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGSSHash, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);

template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGSuccinct, RowFlatAnnotator>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGBitmap, RowFlatAnnotator>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
Expand Down
5 changes: 3 additions & 2 deletions metagraph/tests/graph/all/test_dbg_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ void writeFastaFile(const std::vector<std::string>& sequences, const std::string

fastaFile.close();
}

template <>
std::shared_ptr<DeBruijnGraph>
build_graph<DBGSSHash>(uint64_t k,
Expand All @@ -154,8 +155,8 @@ build_graph<DBGSSHash>(uint64_t k,
if (sequences.empty())
return std::make_shared<DBGSSHash>(k, mode);

// use DBGHashString to get contigs for SSHash
auto string_graph = build_graph<DBGHashString>(k, sequences, mode);
// use DBGHashFast to get contigs for SSHash
auto string_graph = build_graph<DBGHashFast>(k, sequences, mode);

std::vector<std::string> contigs;
size_t num_kmers = 0;
Expand Down
Loading