diff --git a/src/prefiltering/CacheFriendlyOperations.cpp b/src/prefiltering/CacheFriendlyOperations.cpp index 102362bbb..139c8ab53 100644 --- a/src/prefiltering/CacheFriendlyOperations.cpp +++ b/src/prefiltering/CacheFriendlyOperations.cpp @@ -36,7 +36,7 @@ CacheFriendlyOperations::~CacheFriendlyOperations(){ template size_t CacheFriendlyOperations::findDuplicates(IndexEntryLocal **input, CounterResult *output, - size_t outputSize, unsigned short indexFrom, unsigned short indexTo, bool computeTotalScore) { + size_t outputSize, unsigned short indexFrom, unsigned short indexTo, bool computeTotalScore) { do { setupBinPointer(); CounterResult *lastPosition = (binDataFrame + BINCOUNT * binSize) - 1; @@ -58,12 +58,16 @@ size_t CacheFriendlyOperations::mergeElementsByScore(CounterResult *inp } template -size_t CacheFriendlyOperations::mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N) { +size_t CacheFriendlyOperations::mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N, const bool keepScoredHits) { do { setupBinPointer(); hashElements(inputOutputArray, N); } while(checkForOverflowAndResizeArray(false) == true); // overflowed occurred - return mergeDiagonalDuplicates(inputOutputArray); + if(keepScoredHits){ + return mergeDiagonalKeepScoredHitsDuplicates(inputOutputArray); + }else{ + return mergeDiagonalDuplicates(inputOutputArray); + } } template @@ -93,6 +97,7 @@ size_t CacheFriendlyOperations::mergeDiagonalDuplicates(CounterResult * --n; } // combine diagonals + // we keep only the last diagonal element for (size_t n = 0; n < currBinSize; n++) { const CounterResult &element = binStartPos[n]; const unsigned int hashBinElement = element.id >> (MASK_0_5_BIT); @@ -109,6 +114,40 @@ size_t CacheFriendlyOperations::mergeDiagonalDuplicates(CounterResult * return doubleElementCount; } + +template +size_t CacheFriendlyOperations::mergeDiagonalKeepScoredHitsDuplicates(CounterResult *output) { + size_t doubleElementCount = 0; + const CounterResult *bin_ref_pointer = binDataFrame; + // duplicateBitArray is already zero'd from findDuplicates + + for (size_t bin = 0; bin < BINCOUNT; bin++) { + const CounterResult *binStartPos = (bin_ref_pointer + bin * binSize); + const size_t currBinSize = (bins[bin] - binStartPos); + // write diagonals + 1 in reverse order in the byte array + for (size_t n = 0; n < currBinSize; n++) { + const unsigned int element = binStartPos[n].id >> (MASK_0_5_BIT); + duplicateBitArray[element] = static_cast(binStartPos[n].diagonal) + 1; + } + // combine diagonals + // we keep only the last diagonal element + size_t n = currBinSize - 1; + while (n != static_cast(-1)) { + const CounterResult &element = binStartPos[n]; + const unsigned int hashBinElement = element.id >> (MASK_0_5_BIT); + output[doubleElementCount].id = element.id; + output[doubleElementCount].count = element.count; + output[doubleElementCount].diagonal = element.diagonal; +// std::cout << output[doubleElementCount].id << " " << (int)output[doubleElementCount].count << " " << (int)static_cast(output[doubleElementCount].diagonal) << std::endl; + // memory overflow can not happen since input array = output array + doubleElementCount += (output[doubleElementCount].count != 0 || duplicateBitArray[hashBinElement] != static_cast(binStartPos[n].diagonal)) ? 1 : 0; + duplicateBitArray[hashBinElement] = static_cast(element.diagonal); + --n; + } + } + return doubleElementCount; +} + template size_t CacheFriendlyOperations::mergeScoreDuplicates(CounterResult *output) { size_t doubleElementCount = 0; @@ -211,12 +250,12 @@ size_t CacheFriendlyOperations::findDuplicates(CounterResult *output, s output[doubleElementCount].id = element; output[doubleElementCount].count = 0; output[doubleElementCount].diagonal = tmpElementBuffer[n].diagonal; - // const unsigned char diagonal = static_cast(tmpElementBuffer[n].diagonal); + // const unsigned char diagonal = static_cast(tmpElementBuffer[n].diagonal); // memory overflow can not happen since input array = output array - // if(duplicateBitArray[hashBinElement] != tmpElementBuffer[n].diagonal){ - // std::cout << "seq="<< output[doubleElementCount].id << "\tDiag=" << (int) output[doubleElementCount].diagonal - // << " dup.Array=" << (int)duplicateBitArray[hashBinElement] << " tmp.Arr="<< (int)tmpElementBuffer[n].diagonal << std::endl; - // } + // if(duplicateBitArray[hashBinElement] != tmpElementBuffer[n].diagonal){ + // std::cout << "seq="<< output[doubleElementCount].id << "\tDiag=" << (int) output[doubleElementCount].diagonal + // << " dup.Array=" << (int)duplicateBitArray[hashBinElement] << " tmp.Arr="<< (int)tmpElementBuffer[n].diagonal << std::endl; + // } doubleElementCount += (duplicateBitArray[hashBinElement] != static_cast(tmpElementBuffer[n].diagonal)) ? 1 : 0; duplicateBitArray[hashBinElement] = static_cast(tmpElementBuffer[n].diagonal); } diff --git a/src/prefiltering/CacheFriendlyOperations.h b/src/prefiltering/CacheFriendlyOperations.h index 1206e1f08..efbfa75af 100644 --- a/src/prefiltering/CacheFriendlyOperations.h +++ b/src/prefiltering/CacheFriendlyOperations.h @@ -81,7 +81,7 @@ class CacheFriendlyOperations { size_t mergeElementsByScore(CounterResult *inputOutputArray, const size_t N); // merge elements in CounterResult by diagonal, combines elements with same ids that occur after each other - size_t mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N); + size_t mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N, const bool keepScoredHits = false); size_t keepMaxScoreElementOnly(CounterResult *inputOutputArray, const size_t N); @@ -124,6 +124,8 @@ class CacheFriendlyOperations { size_t mergeDiagonalDuplicates(CounterResult *output); + size_t mergeDiagonalKeepScoredHitsDuplicates(CounterResult *output); + size_t keepMaxElement(CounterResult *output); }; diff --git a/src/prefiltering/QueryMatcher.cpp b/src/prefiltering/QueryMatcher.cpp index 395faba15..4666bcfe8 100644 --- a/src/prefiltering/QueryMatcher.cpp +++ b/src/prefiltering/QueryMatcher.cpp @@ -97,7 +97,9 @@ std::pair QueryMatcher::matchQuery(Sequence *querySeq, unsigned } else { memset(compositionBias, 0, sizeof(float) * querySeq->L); } - + if(diagonalScoring == true){ + ungappedAlignment->createProfile(querySeq, compositionBias); + } size_t resultSize = match(querySeq, compositionBias); if (hook != NULL) { resultSize = hook->afterDiagonalMatchingHook(*this, resultSize); @@ -105,7 +107,7 @@ std::pair QueryMatcher::matchQuery(Sequence *querySeq, unsigned std::pair queryResult; if (diagonalScoring) { // write diagonal scores in count value - ungappedAlignment->processQuery(querySeq, compositionBias, foundDiagonals, resultSize); + ungappedAlignment->align(foundDiagonals, resultSize); memset(scoreSizes, 0, SCORE_RANGE * sizeof(unsigned int)); CounterResult * resultReadPos = foundDiagonals; CounterResult * resultWritePos = foundDiagonals + resultSize; @@ -267,24 +269,17 @@ size_t QueryMatcher::match(Sequence *seq, float *compositionBias) { //std::cout << seq->getDbKey() << std::endl; //idx.printKmer(index[kmerPos], kmerSize, kmerSubMat->num2aa); //std::cout << "\t" << current_i << "\t"<< index[kmerPos] << std::endl; - //for (size_t i = 0; i < seqListSize; i++) { - // char diag = entries[i].position_j - current_i; - // std::cout << "(" << entries[i].seqId << " " << (int) diag << ")\t"; - //} +// for (size_t i = 0; i < seqListSize; i++) { +// if(23865 == entries[i].seqId ){ +// char diag = entries[i].position_j - current_i; +// std::cout << "(" << entries[i].seqId << " " << (int) diag << ")\t"; +// } +// } //std::cout << std::endl; // detected overflow while matching if ((sequenceHits + seqListSize) >= lastSequenceHit) { stats->diagonalOverflow = true; - // realloc foundDiagonals if only 10% of memory left - if((foundDiagonalsSize - overflowHitCount) < 0.1 * foundDiagonalsSize){ - foundDiagonalsSize *= 1.5; - foundDiagonals = (CounterResult*) realloc(foundDiagonals, foundDiagonalsSize * sizeof(CounterResult)); - if(foundDiagonals == NULL){ - Debug(Debug::ERROR) << "Out of memory in QueryMatcher::match\n"; - EXIT(EXIT_FAILURE); - } - } // last pointer indexPointer[current_i + 1] = sequenceHits; //std::cout << "Overflow in i=" << indexStart << std::endl; @@ -292,10 +287,19 @@ size_t QueryMatcher::match(Sequence *seq, float *compositionBias) { foundDiagonals + overflowHitCount, foundDiagonalsSize - overflowHitCount, indexStart, current_i, (diagonalScoring == false)); - + // this happens only if we have two overflows in a row if (overflowHitCount != 0) { - // merge lists, hitCount is max. dbSize so there can be no overflow in mergeElements - overflowHitCount = mergeElements(foundDiagonals, hitCount + overflowHitCount); + if(diagonalScoring == true){ + overflowHitCount = mergeElements(foundDiagonals, hitCount + overflowHitCount, true); + // align the new diaognals + ungappedAlignment->align(foundDiagonals, overflowHitCount); + // We keep only the maximal diagonal scoring hit, so the max number of hits is DBsize + overflowHitCount = keepMaxScoreElementOnly(foundDiagonals, overflowHitCount); + } else { + // in case of scoring we just sum up in mergeElements, so the max number of hits is DBsize + // merge lists, hitCount is max. dbSize so there can be no overflow in mergeElements + overflowHitCount = mergeElements(foundDiagonals, hitCount + overflowHitCount); + } } else { overflowHitCount = hitCount; } @@ -463,11 +467,11 @@ size_t QueryMatcher::findDuplicates(IndexEntryLocal **hitsByIndex, return localResultSize; } -size_t QueryMatcher::mergeElements(CounterResult *foundDiagonals, size_t hitCounter) { +size_t QueryMatcher::mergeElements(CounterResult *foundDiagonals, size_t hitCounter, bool keepScoredHits) { size_t overflowHitCount = 0; #define MERGE_CASE(x) \ case x: overflowHitCount = diagonalScoring ? \ - cachedOperation##x->mergeElementsByDiagonal(foundDiagonals,hitCounter) : \ + cachedOperation##x->mergeElementsByDiagonal(foundDiagonals,hitCounter, keepScoredHits) : \ cachedOperation##x->mergeElementsByScore(foundDiagonals,hitCounter); \ break; diff --git a/src/prefiltering/QueryMatcher.h b/src/prefiltering/QueryMatcher.h index 773141332..08b9ebd5b 100644 --- a/src/prefiltering/QueryMatcher.h +++ b/src/prefiltering/QueryMatcher.h @@ -258,8 +258,7 @@ class QueryMatcher { size_t findDuplicates(IndexEntryLocal **hitsByIndex, CounterResult *output, size_t outputSize, unsigned short indexFrom, unsigned short indexTo, bool computeTotalScore); - - size_t mergeElements(CounterResult *foundDiagonals, size_t hitCounter); + size_t mergeElements(CounterResult *foundDiagonals, size_t hitCounter, bool keepHitsWithCounts = false); size_t keepMaxScoreElementOnly(CounterResult *foundDiagonals, size_t resultSize); diff --git a/src/prefiltering/UngappedAlignment.cpp b/src/prefiltering/UngappedAlignment.cpp index 01e4365c9..54cef6f50 100644 --- a/src/prefiltering/UngappedAlignment.cpp +++ b/src/prefiltering/UngappedAlignment.cpp @@ -22,13 +22,8 @@ UngappedAlignment::~UngappedAlignment() { delete [] score_arr; } -void UngappedAlignment::processQuery(Sequence *seq, - float *biasCorrection, - CounterResult *results, - size_t resultSize) { - createProfile(seq, biasCorrection, subMatrix->subMatrix); - queryLen = seq->L; - computeScores(queryProfile, seq->L, results, resultSize); +void UngappedAlignment::align(CounterResult *results, size_t resultSize) { + computeScores(queryProfile, queryLen, results, resultSize); } @@ -290,7 +285,7 @@ void UngappedAlignment::scoreDiagonalAndUpdateHits(const char * queryProfile, // update score for(size_t hitIdx = 0; hitIdx < hitSize; hitIdx++){ hits[seqs[hitIdx].id]->count = static_cast(std::min(static_cast(255), - score_arr[hitIdx])); + score_arr[hitIdx])); if(seqs[hitIdx].seqLen == 1){ std::pair dbSeq = sequenceLookup->getSequence(hits[hitIdx]->id); if(dbSeq.second >= 32768){ @@ -344,6 +339,10 @@ void UngappedAlignment::computeScores(const char *queryProfile, // continue; // } const unsigned short currDiag = results[i].diagonal; + // skip results that already have a diagonal score + if(results[i].count != 0){ + continue; + } diagonalMatches[currDiag * DIAGONALBINSIZE + diagonalCounter[currDiag]] = &results[i]; diagonalCounter[currDiag]++; if(diagonalCounter[currDiag] == DIAGONALBINSIZE) { @@ -384,9 +383,8 @@ void UngappedAlignment::extractScores(unsigned int *score_arr, simd_int score) { void UngappedAlignment::createProfile(Sequence *seq, - float * biasCorrection, - short **subMat) { - + float * biasCorrection) { + queryLen = seq->L; if(Parameters::isEqualDbtype(seq->getSequenceType(), Parameters::DBTYPE_HMM_PROFILE)) { memset(queryProfile, 0, (Sequence::PROFILE_AA_SIZE + 1) * seq->L); }else{ @@ -409,7 +407,7 @@ void UngappedAlignment::createProfile(Sequence *seq, for (int pos = 0; pos < seq->L; pos++) { unsigned int aaIdx = seq->numSequence[pos]; for (int i = 0; i < subMatrix->alphabetSize; i++) { - queryProfile[pos * (Sequence::PROFILE_AA_SIZE + 1) + i] = (subMat[aaIdx][i] + aaCorrectionScore[pos]); + queryProfile[pos * (Sequence::PROFILE_AA_SIZE + 1) + i] = (subMatrix->subMatrix[aaIdx][i] + aaCorrectionScore[pos]); } } } diff --git a/src/prefiltering/UngappedAlignment.h b/src/prefiltering/UngappedAlignment.h index d35067495..79e5e9de7 100644 --- a/src/prefiltering/UngappedAlignment.h +++ b/src/prefiltering/UngappedAlignment.h @@ -18,10 +18,12 @@ class UngappedAlignment { ~UngappedAlignment(); + void createProfile(Sequence *seq, float *biasCorrection); + // This function computes the diagonal score for each CounterResult object // it assigns the diagonal score to the CounterResult object - void processQuery(Sequence *seq, float *compositionBias, CounterResult *results, - size_t resultSize); + void align(CounterResult *results, + size_t resultSize); int scoreSingelSequenceByCounterResult(CounterResult &result); @@ -90,8 +92,6 @@ class UngappedAlignment { void extractScores(unsigned int *score_arr, simd_int score); - void createProfile(Sequence *seq, float *biasCorrection, short **subMat); - int computeSingelSequenceScores(const char *queryProfile, const unsigned int queryLen, std::pair &dbSeq, int diagonal, unsigned int minDistToDiagonal);