Skip to content

Commit

Permalink
Fixed document source and score field mismatch in sorted hybrid queries (#1043)
Browse files Browse the repository at this point in the history

* Fixed mismatch between document source and score fields when sorting is enabled in a hybrid query

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Jan 3, 2025
1 parent f844a78 commit 030e3f4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 19 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import java.util.Objects;
import java.util.Locale;
import java.util.ArrayList;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -44,8 +47,10 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
@Nullable
private FieldDoc after;
private FieldComparator<?> firstComparator;
// bottom would be set to null per shard.
private FieldValueHitQueue.Entry bottom;
// the array stores bottom elements of the min heap of sorted hits for each sub query
@Getter(AccessLevel.PACKAGE)
@VisibleForTesting
private FieldValueHitQueue.Entry fieldValueLeafTrackers[];
@Getter
private int totalHits;
protected int docBase;
Expand All @@ -65,6 +70,7 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
@Getter
protected float maxScore = 0.0f;
protected int[] collectedHits;
private boolean needsInitialization = true;

// searchSortPartOfIndexSort is used to evaluate whether to perform index sort or not.
private Boolean searchSortPartOfIndexSort = null;
Expand Down Expand Up @@ -203,7 +209,7 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
comparators[subQueryNumber].copy(slot, doc);
add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score);
if (queueFull[subQueryNumber]) {
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
} else {
queueFull[subQueryNumber] = true;
Expand All @@ -216,9 +222,9 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException {
// This hit is competitive - replace bottom element in queue & adjustTop
if (numHits > 0) {
comparators[subQueryNumber].copy(bottom.slot, doc);
updateBottom(doc, compoundScores[subQueryNumber]);
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].copy(fieldValueLeafTrackers[subQueryNumber].slot, doc);
updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
}

Expand All @@ -245,14 +251,16 @@ protected boolean thresholdCheck(int doc, int subQueryNumber) throws IOException
The method initializes once per search request.
*/
protected void initializePriorityQueuesWithComparators(LeafReaderContext context, int numberOfSubQueries) throws IOException {
if (compoundScores == null) {
if (needsInitialization) {
compoundScores = new FieldValueHitQueue[numberOfSubQueries];
comparators = new LeafFieldComparator[numberOfSubQueries];
queueFull = new boolean[numberOfSubQueries];
collectedHits = new int[numberOfSubQueries];
for (int i = 0; i < numberOfSubQueries; i++) {
initializeLeafFieldComparators(context, i);
}
fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries];
needsInitialization = false;
}
if (initializeLeafComparatorsPerSegmentOnce) {
for (int i = 0; i < numberOfSubQueries; i++) {
Expand Down Expand Up @@ -369,7 +377,7 @@ private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<Fiel
private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryNumber, float score) {
FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc);
bottomEntry.score = score;
bottom = compoundScore.add(bottomEntry);
fieldValueLeafTrackers[subQueryNumber] = compoundScore.add(bottomEntry);
// The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case
// slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits
// on the current page) and slot = hitsCollected - 1.
Expand All @@ -381,9 +389,9 @@ private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry>
queueFull[subQueryNumber] = isQueueFull;
}

private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore) {
bottom.doc = docBase + doc;
bottom = compoundScore.updateTop();
private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryIndex) {
fieldValueLeafTrackers[subQueryIndex].doc = docBase + doc;
fieldValueLeafTrackers[subQueryIndex] = compoundScore.updateTop();
}

private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -24,6 +24,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
Expand All @@ -35,14 +36,13 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase {
static final String TEXT_FIELD_NAME = "field";
Expand Down Expand Up @@ -127,8 +127,13 @@ public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce
DocIdSetIterator iterator = hybridQueryScorer.iterator();

int doc = iterator.nextDoc();
assertNull(hybridTopFieldDocSortCollector.getFieldValueLeafTrackers());
while (doc != DocIdSetIterator.NO_MORE_DOCS) {
leafCollector.collect(doc);
FieldValueHitQueue.Entry[] fieldValueLeafTrackers = hybridTopFieldDocSortCollector.getFieldValueLeafTrackers();
assertNotNull(fieldValueLeafTrackers);
assertEquals(1, fieldValueLeafTrackers.length);
assertEquals(doc, fieldValueLeafTrackers[0].doc);
doc = iterator.nextDoc();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
Expand Down Expand Up @@ -46,7 +46,7 @@
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {

Expand Down

0 comments on commit 030e3f4

Please sign in to comment.