Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hmusta committed Mar 21, 2024
1 parent d57beb7 commit 39f6713
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 92 deletions.
3 changes: 2 additions & 1 deletion metagraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ include_directories(
external-libraries/zlib
external-libraries/sdust
external-libraries/simde-no-tests
external-libraries/sshash/include
external-libraries/sshash/external/pthash/external/essentials/include
${PROJECT_SOURCE_DIR}/src
)

Expand Down Expand Up @@ -320,7 +322,6 @@ IF(APPLE)
ENDIF()
add_subdirectory(external-libraries/spdlog)
add_subdirectory(external-libraries/sshash SYSTEM)
target_include_directories(sshash_static PUBLIC SYSTEM external-libraries/sshash/include)
target_compile_options(test_alphabet PRIVATE -Wno-strict-aliasing)
add_subdirectory(external-libraries/DYNAMIC)
add_subdirectory(external-libraries/zlib)
Expand Down
4 changes: 2 additions & 2 deletions metagraph/src/cli/build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,12 @@ int build_graph(Config *config) {

} else if (config->graph_type == Config::GraphType::SSHASH){

graph.reset(new DBGSSHash(files.at(0), config->k));
graph.reset(new DBGSSHash(files.at(0), config->k, config->graph_mode));
if(files.size() > 1){
logger->error("Only one file for SSHash");
exit(1);
}

}else {
//slower method
switch (config->graph_type) {
Expand Down
22 changes: 22 additions & 0 deletions metagraph/src/graph/representation/canonical_dbg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common/seq_tools/reverse_complement.hpp"
#include "common/logger.hpp"
#include "graph/representation/succinct/dbg_succinct.hpp"
#include "graph/representation/hash/dbg_sshash.hpp"


namespace mtg {
Expand Down Expand Up @@ -62,6 +63,13 @@ ::map_to_nodes_sequentially(std::string_view sequence,
std::vector<node_index> path;
path.reserve(sequence.size() - get_k() + 1);

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->map_to_nodes_with_rc(sequence, [&](node_index node, bool orientation) {
callback(node && orientation ? reverse_complement(node) : node);
}, terminate);
return;
}

// map until the first mismatch
bool stop = false;
graph_->map_to_nodes_sequentially(sequence,
Expand Down Expand Up @@ -171,6 +179,13 @@ void CanonicalDBG::call_outgoing_kmers(node_index node,
return;
}

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->call_outgoing_kmers_with_rc(node, [&](node_index next, char c, bool orientation) {
callback(orientation ? reverse_complement(next) : next, c);
});
return;
}

// includes `$` for DBGSuccinct
const auto &alphabet = graph_->alphabet();

Expand Down Expand Up @@ -257,6 +272,13 @@ void CanonicalDBG::call_incoming_kmers(node_index node,
return;
}

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->call_incoming_kmers_with_rc(node, [&](node_index prev, char c, bool orientation) {
callback(orientation ? reverse_complement(prev) : prev, c);
});
return;
}

// includes `$` for DBGSuccinct
const auto &alphabet = graph_->alphabet();

Expand Down
142 changes: 108 additions & 34 deletions metagraph/src/graph/representation/hash/dbg_sshash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@ DBGSSHash::DBGSSHash(size_t k):k_(k) {
dict_ = std::make_unique<sshash::dictionary>();
}

DBGSSHash::DBGSSHash(std::string const& input_filename, size_t k):k_(k){
DBGSSHash::DBGSSHash(std::string const& input_filename, size_t k, Mode mode):k_(k), mode_(mode) {
sshash::build_configuration build_config;
build_config.k = k;//
// quick fix for value of m... k/2 but odd
build_config.m = (k_+1)/2;
build_config.num_threads = get_num_threads();
if(build_config.m % 2 == 0) build_config.m++;
dict_ = std::make_unique<sshash::dictionary>();
dict_->build(input_filename, build_config);
}
std::string DBGSSHash::file_extension() const { return kExtension; }
size_t DBGSSHash::get_k() const { return k_; }
DeBruijnGraph::Mode DBGSSHash::get_mode() const { return BASIC; }
DeBruijnGraph::Mode DBGSSHash::get_mode() const { return mode_; }

void DBGSSHash::add_sequence(std::string_view sequence,
const std::function<void(node_index)> &on_insertion) {
Expand All @@ -39,15 +40,32 @@ void DBGSSHash::map_to_nodes(std::string_view sequence,
void DBGSSHash ::map_to_nodes_sequentially(std::string_view sequence,
const std::function<void(node_index)> &callback,
const std::function<bool()> &terminate) const {
if (terminate() || sequence.size() < k_)
return;

auto uint_kmer = sshash::util::string_to_uint_kmer(sequence.data(), k_ - 1) << 2;
for (size_t i = k_ - 1; i < sequence.size() && !terminate(); ++i) {
uint_kmer = (uint_kmer >> 2) + (sshash::util::char_to_uint(sequence[i]) << (2 * (k_ - 1)));
callback(dict_->lookup_uint(uint_kmer, false) + 1);
}
}

void DBGSSHash ::map_to_nodes_with_rc(std::string_view sequence,
const std::function<void(node_index, bool)> &callback,
const std::function<bool()> &terminate) const {
sshash::streaming_query_regular_parsing streamer(dict_.get());
streamer.start();
for (size_t i = 0; i + k_ <= sequence.size() && !terminate(); ++i) {
callback(kmer_to_node(sequence.substr(i, k_)));
const char *kmer = sequence.data() + i;
auto res = streamer.lookup_advanced(kmer);
callback(res.kmer_id + 1, res.kmer_orientation);
}
}

DBGSSHash::node_index DBGSSHash::traverse(node_index node, char next_char) const {
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_forward_neighbours(&kmer[0]);
uint64_t ssh_idx = -1;
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_forward_neighbours(&kmer[0], false);
uint64_t ssh_idx = -1;
switch (next_char) {
case 'A':
ssh_idx = nb.forward_A.kmer_id;
Expand All @@ -69,8 +87,8 @@ DBGSSHash::node_index DBGSSHash::traverse(node_index node, char next_char) const

DBGSSHash::node_index DBGSSHash::traverse_back(node_index node, char prev_char) const {
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_backward_neighbours(&kmer[0]);
uint64_t ssh_idx = -1;
sshash::neighbourhood nb = dict_->kmer_backward_neighbours(&kmer[0], false);
uint64_t ssh_idx = -1;
switch (prev_char) {
case 'A':
ssh_idx = nb.backward_A.kmer_id;
Expand All @@ -94,7 +112,6 @@ void DBGSSHash ::adjacent_outgoing_nodes(node_index node,
const std::function<void(node_index)> &callback) const {
assert(node > 0 && node <= num_nodes());
call_outgoing_kmers(node, [&](auto child, char) { callback(child); });

}

void DBGSSHash ::adjacent_incoming_nodes(node_index node,
Expand All @@ -107,37 +124,87 @@ void DBGSSHash ::call_outgoing_kmers(node_index node,
const OutgoingEdgeCallback &callback) const {
assert(node > 0 && node <= num_nodes());

auto prefix = get_node_sequence(node).substr(1);
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_forward_neighbours(kmer.c_str(), false);
if (nb.forward_A.kmer_id != sshash::constants::invalid_uint64)
callback(nb.forward_A.kmer_id + 1, 'A');

if (nb.forward_C.kmer_id != sshash::constants::invalid_uint64)
callback(nb.forward_C.kmer_id + 1, 'C');

for (char c : alphabet_) {
auto next = kmer_to_node(prefix + c);
if (next != npos)
callback(next, c);
}
if (nb.forward_G.kmer_id != sshash::constants::invalid_uint64)
callback(nb.forward_G.kmer_id + 1, 'G');

if (nb.forward_T.kmer_id != sshash::constants::invalid_uint64)
callback(nb.forward_T.kmer_id + 1, 'T');
}


void DBGSSHash ::call_incoming_kmers(node_index node,
const IncomingEdgeCallback &callback) const {
assert(node > 0 && node <= num_nodes());

std::string suffix = get_node_sequence(node);
suffix.pop_back();
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_backward_neighbours(kmer.c_str(), false);
if (nb.backward_A.kmer_id != sshash::constants::invalid_uint64)
callback(nb.backward_A.kmer_id + 1, 'A');

for (char c : alphabet_) {
auto prev = kmer_to_node(c + suffix);
if (prev != npos)
callback(prev, c);
}
if (nb.backward_C.kmer_id != sshash::constants::invalid_uint64)
callback(nb.backward_C.kmer_id + 1, 'C');

if (nb.backward_G.kmer_id != sshash::constants::invalid_uint64)
callback(nb.backward_G.kmer_id + 1, 'G');

if (nb.backward_T.kmer_id != sshash::constants::invalid_uint64)
callback(nb.backward_T.kmer_id + 1, 'T');
}

// Enumerates the forward neighbours of `node` as reported by the dictionary's
// forward-neighbourhood query (issued with its flag set to `true`).
// NOTE(review): that flag presumably enables reverse-complement matching —
// confirm against the sshash API.
//
// For every neighbour that exists, `callback` receives:
//   - the node index (dictionary k-mer id plus one),
//   - the edge character ('A', 'C', 'G' or 'T', visited in that order),
//   - the orientation value reported by the lookup for that neighbour.
void DBGSSHash ::call_outgoing_kmers_with_rc(node_index node,
        const std::function<void(node_index, char, bool)> &callback) const {
    assert(node > 0 && node <= num_nodes());

    const std::string node_seq = DBGSSHash::get_node_sequence(node);
    const sshash::neighbourhood hits
            = dict_->kmer_forward_neighbours(node_seq.c_str(), true);

    // Forward a single lookup result to the caller, skipping absent neighbours.
    const auto report = [&callback](const auto &hit, char label) {
        if (hit.kmer_id == sshash::constants::invalid_uint64)
            return;

        callback(hit.kmer_id + 1, label, hit.kmer_orientation);
    };

    report(hits.forward_A, 'A');
    report(hits.forward_C, 'C');
    report(hits.forward_G, 'G');
    report(hits.forward_T, 'T');
}


// Enumerates the backward neighbours of `node` as reported by the dictionary's
// backward-neighbourhood query (issued with its flag set to `true`).
// NOTE(review): that flag presumably enables reverse-complement matching —
// confirm against the sshash API.
//
// For every neighbour that exists, `callback` receives:
//   - the node index (dictionary k-mer id plus one),
//   - the edge character ('A', 'C', 'G' or 'T', visited in that order),
//   - the orientation value reported by the lookup for that neighbour.
void DBGSSHash ::call_incoming_kmers_with_rc(node_index node,
        const std::function<void(node_index, char, bool)> &callback) const {
    assert(node > 0 && node <= num_nodes());

    const std::string node_seq = DBGSSHash::get_node_sequence(node);
    const sshash::neighbourhood hits
            = dict_->kmer_backward_neighbours(node_seq.c_str(), true);

    // Forward a single lookup result to the caller, skipping absent neighbours.
    const auto report = [&callback](const auto &hit, char label) {
        if (hit.kmer_id == sshash::constants::invalid_uint64)
            return;

        callback(hit.kmer_id + 1, label, hit.kmer_orientation);
    };

    report(hits.backward_A, 'A');
    report(hits.backward_C, 'C');
    report(hits.backward_G, 'G');
    report(hits.backward_T, 'T');
}

size_t DBGSSHash::outdegree(node_index node) const {
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_forward_neighbours(&kmer[0]);
size_t out_deg = bool(nb.forward_A.kmer_id + 1) // change to loop?
+ bool(nb.forward_C.kmer_id + 1)
+ bool(nb.forward_G.kmer_id + 1)
+ bool(nb.forward_T.kmer_id + 1);
sshash::neighbourhood nb = dict_->kmer_forward_neighbours(&kmer[0], false);
size_t out_deg = (nb.forward_A.kmer_id != sshash::constants::invalid_uint64) // change to loop?
+ (nb.forward_C.kmer_id != sshash::constants::invalid_uint64)
+ (nb.forward_G.kmer_id != sshash::constants::invalid_uint64)
+ (nb.forward_T.kmer_id != sshash::constants::invalid_uint64);
return out_deg;
}

Expand All @@ -151,11 +218,11 @@ bool DBGSSHash::has_multiple_outgoing(node_index node) const {

size_t DBGSSHash::indegree(node_index node) const {
std::string kmer = DBGSSHash::get_node_sequence(node);
sshash::neighbourhood nb = dict_->kmer_backward_neighbours(&kmer[0]);
size_t in_deg = bool(nb.backward_A.kmer_id + 1) // change to loop?
+ bool(nb.backward_C.kmer_id + 1)
+ bool(nb.backward_G.kmer_id + 1)
+ bool(nb.backward_T.kmer_id + 1);
sshash::neighbourhood nb = dict_->kmer_backward_neighbours(kmer.c_str(), false);
size_t in_deg = (nb.backward_A.kmer_id != sshash::constants::invalid_uint64) // change to loop?
+ (nb.backward_C.kmer_id != sshash::constants::invalid_uint64)
+ (nb.backward_G.kmer_id != sshash::constants::invalid_uint64)
+ (nb.backward_T.kmer_id != sshash::constants::invalid_uint64);
return in_deg;
}

Expand All @@ -175,8 +242,15 @@ void DBGSSHash::call_kmers(
}

DBGSSHash::node_index DBGSSHash::kmer_to_node(std::string_view kmer) const {
uint64_t ssh_idx = dict_->lookup(kmer.begin(), false);
return ssh_idx + 1;
return num_nodes() ? dict_->lookup(kmer.begin(), false) + 1 : npos;
}

// Looks up `kmer` with the dictionary's lookup_advanced() query (flag set to
// `true`) and returns {node index, orientation}.
// NOTE(review): the flag presumably enables reverse-complement matching —
// confirm against the sshash API.
//
// Returns {npos, false} when the graph has no nodes. Otherwise the first
// element is the dictionary k-mer id plus one and the second is the
// orientation value reported by the lookup.
std::pair<DBGSSHash::node_index, bool> DBGSSHash::kmer_to_node_with_rc(std::string_view kmer) const {
    if (!num_nodes())
        return std::make_pair(npos, false);

    // Pass data() rather than begin(): data() is guaranteed to be a character
    // pointer, whereas the iterator type returned by begin() is
    // implementation-defined and need not convert to one.
    const auto res = dict_->lookup_advanced(kmer.data(), true);
    return std::make_pair(res.kmer_id + 1, res.kmer_orientation);
}

std::string DBGSSHash::get_node_sequence(node_index node) const {
Expand Down
19 changes: 17 additions & 2 deletions metagraph/src/graph/representation/hash/dbg_sshash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace graph {
class DBGSSHash : public DeBruijnGraph {
public:
explicit DBGSSHash(size_t k);
DBGSSHash(std::string const& input_filename, size_t k);
DBGSSHash(std::string const& input_filename, size_t k, Mode mode = BASIC);

~DBGSSHash();

Expand All @@ -35,6 +35,11 @@ class DBGSSHash : public DeBruijnGraph {
const std::function<void(node_index)> &callback,
const std::function<bool()> &terminate = []() { return false; }) const override;

void map_to_nodes_with_rc(
std::string_view sequence,
const std::function<void(node_index, bool)> &callback,
const std::function<bool()> &terminate = []() { return false; }) const;

void adjacent_outgoing_nodes(node_index node,
const std::function<void(node_index)> &callback) const override;

Expand Down Expand Up @@ -74,22 +79,32 @@ class DBGSSHash : public DeBruijnGraph {
bool has_single_incoming(node_index) const override;

node_index kmer_to_node(std::string_view kmer) const override;

std::pair<node_index, bool> kmer_to_node_with_rc(std::string_view kmer) const;

void call_outgoing_kmers(node_index node,
const OutgoingEdgeCallback &callback) const override;

void call_outgoing_kmers_with_rc(node_index node,
const std::function<void(node_index, char, bool)> &callback) const;

void call_incoming_kmers(node_index node,
const IncomingEdgeCallback &callback) const override;

void call_incoming_kmers_with_rc(node_index node,
const std::function<void(node_index, char, bool)> &callback) const;


bool operator==(const DeBruijnGraph &other) const override;

const std::string &alphabet() const override;

const sshash::dictionary& data() const { return *dict_; }

private:
static const std::string alphabet_;
std::unique_ptr<sshash::dictionary> dict_;
size_t k_;
Mode mode_;
};

} // namespace graph
Expand Down
Loading

0 comments on commit 39f6713

Please sign in to comment.