From 4a4fd146980d167aa6ed37ec4dd345272f9d470f Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 3 Dec 2024 22:46:31 -0800 Subject: [PATCH] Refactor kmer map --- src/common/adt/kmer_vector.hpp | 4 + src/common/alignment/kmer_map.hpp | 204 +++++++++++++++++-------- src/common/alignment/kmer_mapper.hpp | 86 ++++------- src/common/paired_info/paired_info.hpp | 2 +- src/common/sequence/rtseq.hpp | 15 +- 5 files changed, 186 insertions(+), 125 deletions(-) diff --git a/src/common/adt/kmer_vector.hpp b/src/common/adt/kmer_vector.hpp index 9d4c1760c..31c98bd0c 100644 --- a/src/common/adt/kmer_vector.hpp +++ b/src/common/adt/kmer_vector.hpp @@ -172,6 +172,10 @@ class KMerVector { return vector_[idx]; } + unsigned K() const { + return K_; + } + private: unsigned K_; size_t size_; diff --git a/src/common/alignment/kmer_map.hpp b/src/common/alignment/kmer_map.hpp index bfa7d9f9e..af1bb744f 100644 --- a/src/common/alignment/kmer_map.hpp +++ b/src/common/alignment/kmer_map.hpp @@ -10,135 +10,213 @@ #include "sequence/rtseq.hpp" +#include "adt/kmer_vector.hpp" + +#include #include #include - -#define XXH_INLINE_ALL -#include "xxh/xxhash.h" +#include +#include namespace debruijn_graph { class KMerMap { - struct str_hash { - std::size_t operator()(const char* key, std::size_t key_size) const { - return XXH3_64bits(key, key_size); + typedef RtSeq Kmer; + typedef typename Kmer::DataType RawSeqData; + + static constexpr uint64_t kThombstone = UINT64_MAX; + + struct RawKMerHash { + using is_transparent = void; + + size_t operator()(const Kmer &k) const noexcept { + return k.GetHash(); + } + + size_t operator()(const RawSeqData *k) const noexcept { + return Kmer::GetHash(k, kmers_.el_size()); + } + + size_t operator()(size_t idx) const noexcept { + return Kmer::GetHash(kmers_[idx], kmers_.el_size()); } + + RawKMerHash(const adt::KMerVector &kmers) noexcept + : kmers_(kmers) {} + + const adt::KMerVector &kmers_; }; - typedef RtSeq Kmer; - typedef RtSeq Seq; - typedef typename 
Seq::DataType RawSeqData; - typedef typename tsl::htrie_map HTMap; + struct RawKMerEq { + using is_transparent = void; + + bool operator()(size_t lhs, size_t rhs) const noexcept { + return lhs == rhs; + } + + bool operator()(size_t lhs, const Kmer &rhs) const noexcept { + return rhs.Eq(kmers_[lhs]); + } + + bool operator()(size_t lhs, const RawSeqData *rhs) const noexcept { + return Kmer::Eq(kmers_[lhs], rhs, kmers_.el_size()); + } + + RawKMerEq(const adt::KMerVector &kmers) noexcept + : kmers_(kmers) {} + + const adt::KMerVector &kmers_; + }; + + + using Mapping = phmap::parallel_flat_hash_map; class iterator : public boost::iterator_facade, + const std::pair, std::forward_iterator_tag, - const std::pair> { + const std::pair> { public: - iterator(unsigned k, HTMap::const_iterator iter) - : k_(k), iter_(iter) {} + iterator(const adt::KMerVector &kmers, + Mapping::const_iterator iter, Mapping::const_iterator end) + : kmers_(kmers), iter_(iter), end_(end) { skip(); } private: friend class boost::iterator_core_access; - void increment() { - ++iter_; + void skip() { + // Skip over singletons (values) + while (iter_ != end_ && iter_->second == kThombstone) + ++iter_; } - + void increment() { ++iter_; skip(); } bool equal(const iterator &other) const { return iter_ == other.iter_; } - const std::pair dereference() const { - iter_.key(key_out_); - Kmer k(k_, (const RawSeqData*)key_out_.data()); - Seq s(k_, (const RawSeqData*)iter_.value()); - return std::make_pair(k, s); + const std::pair dereference() const { + VERIFY(iter_->second != kThombstone); + const adt::KMerVector &ref = kmers_; + return std::pair(Kmer(ref.K(), ref[iter_->first]), + Kmer(ref.K(), ref[iter_->second])); } - unsigned k_; - HTMap::const_iterator iter_; - mutable std::string key_out_; + // So we can easily copy the stuff + std::reference_wrapper> kmers_; + Mapping::const_iterator iter_; + Mapping::const_iterator end_; }; public: KMerMap(unsigned k) - : k_(k) { - rawcnt_ = 
(unsigned)Seq::GetDataSize(k_); + : kmers_(k), + mapping_(0, RawKMerHash(kmers_), RawKMerEq(kmers_)) { } ~KMerMap() { clear(); } - void erase(const Kmer &key) { - auto res = mapping_.find_ks((const char*)key.data(), rawcnt_ * sizeof(RawSeqData)); - if (res == mapping_.end()) - return; - - delete[] res.value(); - mapping_.erase(res); + template + bool erase(const Key &key) { + return mapping_.erase(key); } - void set(const Kmer &key, const Seq &value) { - RawSeqData *rawvalue = nullptr; - auto res = mapping_.find_ks((const char*)key.data(), rawcnt_ * sizeof(RawSeqData)); - if (res == mapping_.end()) { - rawvalue = new RawSeqData[rawcnt_]; - mapping_.insert_ks((const char*)key.data(), rawcnt_ * sizeof(RawSeqData), rawvalue); + template + bool set(const Key1 &key, const Key2 &value) { + // Ok, this is a little bit tricky. First of all, we need to see, if we + // know the indices for both key and value. We start from value, so we + // can save on lookups. + + bool inserted = false; + size_t vhash = mapping_.hash(value); + auto vit = mapping_.find(value, vhash); + size_t vidx = kThombstone; + if (vit == mapping_.end()) { + // We have not seen the value yet, put it into the vector + kmers_.push_back(value); + vidx = kmers_.size() - 1; + // Save the mapping, hash ensures that hash(value) == hash(vidx) + // since hash(vidx) == hash(kmers_[vidx]) + auto [it, emplaced] = mapping_.emplace_with_hash(vhash, vidx, kThombstone); + VERIFY(emplaced); inserted |= emplaced; } else { - rawvalue = res.value(); + vidx = vit->first; } - memcpy(rawvalue, value.data(), rawcnt_ * sizeof(RawSeqData)); - } - bool count(const Kmer &key) const { - return mapping_.count_ks((const char*)key.data(), rawcnt_ * sizeof(RawSeqData)); + // Check, if we know the key + size_t khash = mapping_.hash(key); + auto kit = mapping_.find(key, khash); + size_t kidx = kThombstone; + if (kit == mapping_.end()) { + // Key is not known, put into the vector + // and map its new index to the value index (vidx) +
kmers_.push_back(key); + kidx = kmers_.size() - 1; + auto [it, emplaced] = mapping_.emplace_with_hash(khash, kidx, vidx); + VERIFY(emplaced); inserted |= emplaced; + } else { + // Key is known, just update the value index + // kidx = kit->first; + kit->second = vidx; + } + + return inserted; } - const RawSeqData *find(const Kmer &key) const { - auto res = mapping_.find_ks((const char*)key.data(), rawcnt_ * sizeof(RawSeqData)); + template + bool count(const Key &key) const { + auto res = mapping_.find(key); if (res == mapping_.end()) - return nullptr; + return false; - return res.value(); + return res->second != kThombstone; } - const RawSeqData *find(const RawSeqData *key) const { - auto res = mapping_.find_ks((const char*)key, rawcnt_ * sizeof(RawSeqData)); + template + size_t idx(const Key &key) const { // index of the key's k-mer inside kmers_; the key must be present (VERIFY) + auto it = mapping_.find(key); + VERIFY(it != mapping_.end()); + return it->first; + } + + template + const RawSeqData *find(const Key &key) const { + auto res = mapping_.find(key); if (res == mapping_.end()) return nullptr; - return res.value(); + return res->second != kThombstone ?
kmers_[res->second] : nullptr; } void clear() { // Delete all the values - for (auto it = mapping_.begin(); it != mapping_.end(); ++it) { - VERIFY(it.value() != nullptr); - delete[] it.value(); - it.value() = nullptr; - } + kmers_.clear(); // Delete the mapping and all the keys mapping_.clear(); } size_t size() const { - return mapping_.size(); + size_t sz = 0; + for (const auto &entry : mapping_) + sz += entry.second != kThombstone; + return sz; } iterator begin() const { - return iterator(k_, mapping_.begin()); + return iterator(kmers_, mapping_.begin(), mapping_.end()); } iterator end() const { - return iterator(k_, mapping_.end()); + return iterator(kmers_, mapping_.end(), mapping_.end()); + } + + const auto &kmers() const { + return kmers_; } private: - unsigned k_; - unsigned rawcnt_; - HTMap mapping_; + adt::KMerVector kmers_; + Mapping mapping_; }; } diff --git a/src/common/alignment/kmer_mapper.hpp b/src/common/alignment/kmer_mapper.hpp index 06a442cce..006f7b2b7 100644 --- a/src/common/alignment/kmer_mapper.hpp +++ b/src/common/alignment/kmer_mapper.hpp @@ -14,8 +14,8 @@ #include "assembly_graph/core/action_handlers.hpp" #include "sequence/sequence.hpp" #include "sequence/sequence_tools.hpp" +#include "utils/verify.hpp" -#include #include namespace debruijn_graph { @@ -31,19 +31,18 @@ class KmerMapper : public omnigraph::GraphActionHandler { KMerMap mapping_; bool normalized_; - bool CheckAllDifferent(const Sequence &old_s, const Sequence &new_s) const { - std::set kmers; - Kmer kmer = old_s.start(k_) >> 0; - for (size_t i = k_ - 1; i < old_s.size(); ++i) { - kmer <<= old_s[i]; - kmers.insert(kmer); - } - kmer = new_s.start(k_) >> 0; - for (size_t i = k_ - 1; i < new_s.size(); ++i) { - kmer <<= new_s[i]; - kmers.insert(kmer); + template + const RawSeqData* GetNonTrivialRoot(const Kmer &kmer) const { + const RawSeqData *answer = nullptr; + const RawSeqData *rawval = mapping_.find(kmer); + + size_t step = 0; + while (rawval != nullptr) { + answer = rawval; 
+ rawval = mapping_.find(rawval); + step += 1; } - return kmers.size() == old_s.size() - k_ + 1 + new_s.size() - k_ + 1; + return step > 1 ? answer : nullptr; } public: @@ -54,38 +53,28 @@ class KmerMapper : public omnigraph::GraphActionHandler { normalized_(false) { } - virtual ~KmerMapper() {} + virtual ~KmerMapper() = default; - auto begin() const -> decltype(mapping_.begin()) { - return mapping_.begin(); - } - - auto end() const -> decltype(mapping_.end()) { - return mapping_.end(); - } + auto begin() const { return mapping_.begin(); } + auto end() const { return mapping_.end(); } void Normalize() { if (normalized_) return; - adt::KMerVector all(k_, size()); - for (auto it = begin(); it != end(); ++it) - all.push_back(it->first); - - std::vector roots(all.size(), nullptr); - -# pragma omp parallel for + const adt::KMerVector &all = mapping_.kmers(); + #pragma omp parallel for for (size_t i = 0; i < all.size(); ++i) { - Seq val(k_, all[i]); - roots[i] = GetRoot(val); - } + const RawSeqData *kmer = all[i]; + const RawSeqData *root = GetNonTrivialRoot(kmer); + if (!root) + continue; -# pragma omp parallel for - for (size_t i = 0; i < all.size(); ++i) { - if (roots[i] != nullptr) { - Seq kmer(k_, all[i]); - mapping_.set(kmer, Seq(k_, roots[i])); - } + // This is potentially racy, however we do not insert new values, we + // only change values of the existing ones in a pre-determined way + // (root), the final result is ok. 
+ bool inserted = mapping_.set(kmer, root); + VERIFY_MSG(!inserted, "should never insert new kmers here"); } normalized_ = true; @@ -95,15 +84,6 @@ class KmerMapper : public omnigraph::GraphActionHandler { return k_; } -// void Revert(const Kmer &kmer) { -// Kmer old_value = Substitute(kmer); -// if (old_value != kmer) { -// mapping_.erase(kmer); -// mapping_.set(old_value, kmer); -// normalized_ = false; -// } -// } - void RemapKmers(const Sequence &old_s, const Sequence &new_s) { VERIFY(this->IsAttached()); size_t old_length = old_s.size() - k_ + 1; @@ -150,20 +130,6 @@ class KmerMapper : public omnigraph::GraphActionHandler { RemapKmers(this->g().EdgeNucls(edge1), this->g().EdgeNucls(edge2)); } - const RawSeqData* GetRoot(const Kmer &kmer) const { - const RawSeqData *answer = nullptr; - const RawSeqData *rawval = mapping_.find(kmer); - - while (rawval != nullptr) { - Seq val(k_, rawval); - - answer = rawval; - rawval = mapping_.find(val); - } - return answer; - } - - Kmer Substitute(const Kmer &kmer) const { VERIFY(this->IsAttached()); const auto *rawval = mapping_.find(kmer); diff --git a/src/common/paired_info/paired_info.hpp b/src/common/paired_info/paired_info.hpp index ad9455034..4ba6648ad 100644 --- a/src/common/paired_info/paired_info.hpp +++ b/src/common/paired_info/paired_info.hpp @@ -743,7 +743,7 @@ class PairedIndices { VERIFY(size == data_.size()); - for (int i = 0; i < size; ++i) { + for (size_t i = 0; i < size; ++i) { data_[i].BinRead(str); } } diff --git a/src/common/sequence/rtseq.hpp b/src/common/sequence/rtseq.hpp index 0324831d0..6a0ad3eda 100644 --- a/src/common/sequence/rtseq.hpp +++ b/src/common/sequence/rtseq.hpp @@ -334,7 +334,7 @@ class RuntimeSeq { RuntimeSeq start(size_t K) const { return RuntimeSeq(K, data_.data()); } - + /** * Reads sequence from the file (in the same format as BinWrite writes it) * and returns false if error occured, true otherwise. 
@@ -695,6 +695,19 @@ class RuntimeSeq { return GetHash(data_.data(), GetDataSize(size_), seed); } + static bool Eq(const DataType *lhs, const DataType *rhs, size_t sz) { + for (size_t i = 0; i < sz; ++i) + if (lhs[i] != rhs[i]) + return false; + + return true; + } + + bool Eq(const DataType *data) const { + size_t data_size = GetDataSize(size_); + return Eq(data_.data(), data, data_size); + } + struct hash { size_t operator()(const RuntimeSeq &seq, uint64_t seed = 0) const { return seq.GetHash(seed);