Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(sketch) extend Lucene's FirstPassGroupingCollector and SearchGroup for Solr's group.skip.second.step use #232

Open
wants to merge 7 commits into
base: SOLR-11831
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@
* tracking the top doc and {@link FieldComparator} slot.
* @lucene.internal */
public class CollectedSearchGroup<T> extends SearchGroup<T> {
int topDoc;
int comparatorSlot;
public int topDoc;
public int comparatorSlot;
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,34 +132,35 @@ public Collection<SearchGroup<T>> getTopGroups(int groupOffset) throws IOExcepti
final Collection<SearchGroup<T>> result = new ArrayList<>();
int upto = 0;
final int sortFieldCount = comparators.length;
assert sortFieldCount > 0; // this must always be true because fields Sort must contain at least a field
for(CollectedSearchGroup<T> group : orderedGroups) {
if (upto++ < groupOffset) {
continue;
}
// System.out.println(" group=" + (group.groupValue == null ? "null" : group.groupValue.toString()));
SearchGroup<T> searchGroup = new SearchGroup<>();
searchGroup.groupValue = group.groupValue;
// We pass this around so that we can get the corresponding solr id when serializing the search group to send to the federator
searchGroup.topDocLuceneId = group.topDoc;
searchGroup.sortValues = new Object[sortFieldCount];
for(int sortFieldIDX=0;sortFieldIDX<sortFieldCount;sortFieldIDX++) {
searchGroup.sortValues[sortFieldIDX] = comparators[sortFieldIDX].value(group.comparatorSlot);
}
searchGroup.topDocScore = Float.NaN;
// if there is the score comparator we want to return the score
for (FieldComparator comparator: comparators){
if (comparator instanceof FieldComparator.RelevanceComparator){
searchGroup.topDocScore = (Float)comparator.value(group.comparatorSlot);
}
}

SearchGroup<T> searchGroup = newSearchGroupFromCollectedSearchGroup(group, comparators, sortFieldCount);
result.add(searchGroup);
}
//System.out.println(" return " + result.size() + " groups");
return result;
}

/**
 * Factory hook that creates the empty {@link SearchGroup} instance which
 * {@code newSearchGroupFromCollectedSearchGroup} subsequently populates.
 * Subclasses may override this to return a {@link SearchGroup} subtype.
 *
 * @return a new, unpopulated search group
 */
protected SearchGroup<T> newSearchGroup() {
return new SearchGroup<>();
}

/**
 * Builds the {@link SearchGroup} returned to callers for one collected group.
 *
 * <p>The instance itself comes from {@link #newSearchGroup()}; this method copies the
 * group value and reads one sort value per comparator from the group's comparator slot.
 *
 * @param group the collected group to convert
 * @param comparators the comparators of the group sort, one per sort field
 * @param sortFieldCount the number of sort values to read (the first
 *        {@code sortFieldCount} comparators are consulted)
 * @return the populated search group
 */
protected SearchGroup<T> newSearchGroupFromCollectedSearchGroup(
    CollectedSearchGroup<T> group,
    FieldComparator<?>[] comparators,
    int sortFieldCount) {
  final SearchGroup<T> result = newSearchGroup();
  result.groupValue = group.groupValue;

  // Read every sort value out of the slot that currently holds this group's top document.
  final Object[] values = new Object[sortFieldCount];
  final int slot = group.comparatorSlot;
  int field = 0;
  while (field < sortFieldCount) {
    values[field] = comparators[field].value(slot);
    field++;
  }
  result.sortValues = values;

  return result;
}

@Override
public void setScorer(Scorable scorer) throws IOException {
for (LeafFieldComparator comparator : leafComparators) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,9 @@ public class SearchGroup<T> {
* been passed to {@link FirstPassGroupingCollector#getTopGroups} */
public Object[] sortValues;

/** The top doc of this group: we track the Lucene id,
* the Solr id and the score of the document */
public Object topDocSolrId;
public float topDocScore;

/** The topDocLuceneId will be null at the federator level because it is unique only at the shard level.
* It is used by the shard to get the corresponding solr id when serializing the search group to send to the federator
*/
public int topDocLuceneId;

@Override
public String toString() {
return "SearchGroup{" +
"groupValue=" + groupValue +
", sortValues=" + Arrays.toString(sortValues) +
", topDocSolrId=" + topDocSolrId +
", topDocScore=" + topDocScore +
", topDocLuceneId=" + topDocLuceneId +
'}';
return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
}

@Override
Expand All @@ -92,7 +76,12 @@ public int hashCode() {
return groupValue != null ? groupValue.hashCode() : 0;
}

private static class ShardIter<T> {
/**
* Iterator for all the groups on a shard
*
* @lucene.experimental
*/
protected static class ShardIter<T> {
public final Iterator<SearchGroup<T>> iter;
public final int shardIndex;

Expand All @@ -117,8 +106,16 @@ public String toString() {
}
}

// Holds all shards currently on the same group
private static class MergedGroup<T> {
/**
 * Factory hook that creates the {@link MergedGroup} used while merging shard results
 * for this group's {@code groupValue}. Subclasses may override this to return a
 * {@link MergedGroup} subtype carrying additional per-group state.
 *
 * @return a new merged group keyed on this group's value
 */
protected MergedGroup<T> newMergedGroup() {
return new MergedGroup<>(this.groupValue);
}

/**
* Holds all shards currently on the same group
*
* @lucene.experimental
*/
protected static class MergedGroup<T> {

// groupValue may be null!
public final T groupValue;
Expand All @@ -129,15 +126,29 @@ private static class MergedGroup<T> {
public boolean processed;
public boolean inQueue;

/** The top doc of this group:
* the Solr id and the score of the document */
public float topDocScore;
public Object topDocSolrId;

/**
 * Creates a merged group for the given group value.
 *
 * @param groupValue the value identifying the group; may be {@code null}
 */
public MergedGroup(T groupValue) {
this.groupValue = groupValue;
}

/**
 * Converts this merged group back into a {@link SearchGroup}: the instance is obtained
 * from {@link #newSearchGroup()} and then populated by {@link #fillSearchGroup}.
 */
private SearchGroup<T> toSearchGroup() {
  final SearchGroup<T> converted = newSearchGroup();
  this.fillSearchGroup(converted);
  return converted;
}

/**
 * Factory hook that creates the empty {@link SearchGroup} which {@code toSearchGroup()}
 * hands to {@code fillSearchGroup}. Subclasses may override this to return a
 * {@link SearchGroup} subtype.
 *
 * @return a new, unpopulated search group
 */
protected SearchGroup<T> newSearchGroup() {
return new SearchGroup<T>();
}

/**
 * Copies this merged group's state into the given search group: its sort values are
 * set to the current {@code topValues} and its group value to {@code groupValue}.
 * Subclasses that track extra per-group state can override this to copy it as well.
 *
 * @param searchGroup the search group to populate
 */
protected void fillSearchGroup(SearchGroup<T> searchGroup) {
  searchGroup.sortValues = topValues;
  searchGroup.groupValue = groupValue;
}

/**
 * Replaces this merged group's {@code topValues} with the sort values of the given
 * group. Subclasses that track extra per-group state can override this to refresh it.
 *
 * @param group the search group whose sort values become the new top values
 */
protected void update(SearchGroup<T> group) {
this.topValues = group.sortValues;
}

// Only for assert
private boolean neverEquals(Object _other) {
if (_other instanceof MergedGroup) {
Expand Down Expand Up @@ -245,9 +256,7 @@ private void updateNextGroup(int topN, ShardIter<T> shard) {
if (isNew) {
// Start a new group:
//System.out.println(" new");
mergedGroup = new MergedGroup<>(group.groupValue);
mergedGroup.topDocSolrId = group.topDocSolrId;
mergedGroup.topDocScore = group.topDocScore;
mergedGroup = group.newMergedGroup();
mergedGroup.minShardIndex = shard.shardIndex;
assert group.sortValues != null;
mergedGroup.topValues = group.sortValues;
Expand Down Expand Up @@ -285,9 +294,7 @@ private void updateNextGroup(int topN, ShardIter<T> shard) {
if (mergedGroup.inQueue) {
queue.remove(mergedGroup);
}
mergedGroup.topDocScore = group.topDocScore;
mergedGroup.topDocSolrId = group.topDocSolrId;
mergedGroup.topValues = group.sortValues;
mergedGroup.update(group);
mergedGroup.minShardIndex = shard.shardIndex;
queue.add(mergedGroup);
mergedGroup.inQueue = true;
Expand Down Expand Up @@ -330,11 +337,7 @@ public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards,
group.processed = true;
//System.out.println(" pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues));
if (count++ >= offset) {
final SearchGroup<T> newGroup = new SearchGroup<>();
newGroup.groupValue = group.groupValue;
newGroup.sortValues = group.topValues;
newGroup.topDocSolrId = group.topDocSolrId;
newGroup.topDocScore = group.topDocScore;
final SearchGroup<T> newGroup = group.toSearchGroup();
newTopGroups.add(newGroup);
if (newTopGroups.size() == topN) {
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,7 @@ private void doProcessGroupedDistributedSearchFirstPhase(ResponseBuilder rb, Que
.setGroupSort(groupingSpec.getGroupSortSpec().getSort())
.setTopNGroups(cmd.getOffset() + cmd.getLen())
.setIncludeGroupCount(groupingSpec.isIncludeGroupCount())
.setSkipSecondGroupingStep(groupingSpec.isSkipSecondGroupingStep())
.build()
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search.grouping;

import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.CollectedSearchGroup;
import org.apache.lucene.search.grouping.FirstPassGroupingCollector;
import org.apache.lucene.search.grouping.GroupSelector;
import org.apache.lucene.search.grouping.SearchGroup;

/**
 * A {@link FirstPassGroupingCollector} that gathers extra information so that (in certain scenarios)
 * the second pass grouping can be skipped.
 *
 * <p>Every group returned by this collector is a {@link SolrSearchGroup} that additionally records
 * the Lucene doc id of the group's top document and, when the group sort contains a relevance
 * comparator, that document's score.
 */
public class SkipSecondPassFirstPassGroupingCollector<T> extends FirstPassGroupingCollector<T> {

  /**
   * A {@link SearchGroup} that contains extra information so that (in certain scenarios)
   * the second pass grouping can be skipped.
   */
  public static class SolrSearchGroup<T> extends org.apache.lucene.search.grouping.SearchGroup<T> {

    /** Lucene doc id of the group's top document, as recorded by the collector. */
    public int topDocLuceneId;

    /** Score of the group's top document; {@link Float#NaN} when the sort has no relevance comparator. */
    public float topDocScore;

    /** Solr id of the group's top document; not populated by the collector itself. */
    public Object topDocSolrId;

    /**
     * Creates a {@link SolrMergedGroup} seeded with this group's top-document score and Solr id,
     * so that this information is carried through the merge.
     */
    @Override
    protected MergedGroup<T> newMergedGroup() {
      final SolrMergedGroup<T> mergedGroup = new SolrMergedGroup<>(this.groupValue);
      mergedGroup.topDocScore = this.topDocScore;
      mergedGroup.topDocSolrId = this.topDocSolrId;
      return mergedGroup;
    }

    /**
     * Merged-group counterpart of {@link SolrSearchGroup}: tracks the score and Solr id of the
     * top document alongside the state kept by the base merged group.
     */
    private static class SolrMergedGroup<T> extends org.apache.lucene.search.grouping.SearchGroup.MergedGroup<T> {

      public float topDocScore;
      public Object topDocSolrId;

      public SolrMergedGroup(T groupValue) {
        super(groupValue);
      }

      /** Returns a {@link SolrSearchGroup} so that {@link #fillSearchGroup} can copy the extra fields. */
      @Override
      protected SearchGroup<T> newSearchGroup() {
        return new SolrSearchGroup<T>();
      }

      /**
       * Copies the base state and then the top-document score and Solr id.
       *
       * @param searchGroup the group to populate; must be a {@link SolrSearchGroup}
       */
      @Override
      protected void fillSearchGroup(SearchGroup<T> searchGroup) {
        super.fillSearchGroup(searchGroup);
        // Cast once; the instance is expected to come from newSearchGroup() above.
        final SolrSearchGroup<T> solrSearchGroup = (SolrSearchGroup<T>) searchGroup;
        solrSearchGroup.topDocScore = this.topDocScore;
        solrSearchGroup.topDocSolrId = this.topDocSolrId;
      }

      /**
       * Refreshes the base state and then the top-document score and Solr id from the given group.
       *
       * @param searchGroup the group to take the new state from; must be a {@link SolrSearchGroup}
       */
      @Override
      public void update(SearchGroup<T> searchGroup) {
        super.update(searchGroup);
        final SolrSearchGroup<T> solrSearchGroup = (SolrSearchGroup<T>) searchGroup;
        this.topDocScore = solrSearchGroup.topDocScore;
        this.topDocSolrId = solrSearchGroup.topDocSolrId;
      }

    }

  }

  /**
   * Creates the collector; all arguments are passed through unchanged to
   * {@link FirstPassGroupingCollector}.
   */
  public SkipSecondPassFirstPassGroupingCollector(GroupSelector<T> groupSelector, Sort groupSort, int topNGroups) {
    super(groupSelector, groupSort, topNGroups);
  }

  /** Returns a {@link SolrSearchGroup} so that the top-document fields can be filled in. */
  @Override
  protected SearchGroup<T> newSearchGroup() {
    return new SolrSearchGroup<>();
  }

  /**
   * Builds the search group via the superclass and then records the top document's Lucene id
   * and, if available, its score.
   *
   * @param group the collected group to convert
   * @param comparators the comparators of the group sort
   * @param sortFieldCount the number of sort values to read
   * @return the populated {@link SolrSearchGroup}
   */
  @Override
  protected SearchGroup<T> newSearchGroupFromCollectedSearchGroup(
      CollectedSearchGroup<T> group,
      FieldComparator<?>[] comparators,
      int sortFieldCount) {

    final SearchGroup<T> searchGroup = super.newSearchGroupFromCollectedSearchGroup(group, comparators, sortFieldCount);
    final SolrSearchGroup<T> solrSearchGroup = (SolrSearchGroup<T>) searchGroup;

    solrSearchGroup.topDocLuceneId = group.topDoc;

    // NaN marks "no score available"; it is overwritten only if the sort includes relevance.
    solrSearchGroup.topDocScore = Float.NaN;
    for (FieldComparator<?> comparator : comparators) {
      if (comparator instanceof FieldComparator.RelevanceComparator) {
        solrSearchGroup.topDocScore = (Float) comparator.value(group.comparatorSlot);
      }
    }

    return solrSearchGroup;
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.NumberType;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.grouping.SkipSecondPassFirstPassGroupingCollector;

/**
* this is a transition class: for numeric types we use function-based distributed grouping,
Expand All @@ -50,10 +51,17 @@ static Collection<SearchGroup<BytesRef>> fromMutable(SchemaField field, Collecti
FieldType fieldType = field.getType();
List<SearchGroup<BytesRef>> result = new ArrayList<>(values.size());
for (SearchGroup<MutableValue> original : values) {
SearchGroup<BytesRef> converted = new SearchGroup<BytesRef>();
final SearchGroup<BytesRef> converted;
if (original instanceof SkipSecondPassFirstPassGroupingCollector.SolrSearchGroup) {
SkipSecondPassFirstPassGroupingCollector.SolrSearchGroup<MutableValue> solrOriginal = (SkipSecondPassFirstPassGroupingCollector.SolrSearchGroup<MutableValue>)original;
SkipSecondPassFirstPassGroupingCollector.SolrSearchGroup<BytesRef> solrConverted = new SkipSecondPassFirstPassGroupingCollector.SolrSearchGroup<BytesRef>();
solrConverted.topDocLuceneId = solrOriginal.topDocLuceneId;
solrConverted.topDocScore = solrOriginal.topDocScore;
converted = solrConverted;
} else {
converted = new SearchGroup<BytesRef>();
}
converted.sortValues = original.sortValues;
converted.topDocLuceneId = original.topDocLuceneId;
converted.topDocScore = original.topDocScore;
if (original.groupValue.exists) {
BytesRefBuilder binary = new BytesRefBuilder();
fieldType.readableToIndexed(original.groupValue.toString(), binary);
Expand Down
Loading