Skip to content

Commit

Permalink
SOLR-11831: Skip second grouping step if group.limit is 1 (aka Las Ve…
Browse files Browse the repository at this point in the history
…gas patch)

Summary:
In cases where we do grouping and ask for  {{group.limit=1}} only it is possible to skip the second grouping step. In our test datasets it improved speed by around 40%.

Essentially, in the first grouping step each shard returns the top K groups based on the highest scoring document in each group. The top K groups from each shard are merged in the federator and in the second step we ask all the shards to return the top documents from each of the top ranking groups.

If we only want to return the highest scoring document per group we can return the top document id in the first step, merge results in the federator to retain the top K groups and then skip the second grouping step entirely.

QueryComponent: interim 'make it compile (somehow)' change (#228)

add SearchGroupsContainer (#230)

factor out SearchGroupsResultTransformer.serializeOneSearchGroup method (Christine)

Refactor transformToNative adding deserializeOneSearchGroup

 increase GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND use (see also 6bdf87)

Remove error logging in allowSkipSecondGroupingStep

Check that withinGroupSort is a prefix of groupSort

SkipSecondStepSearchGroupShardResponseProcessor.addSearchGroupToShards now leaves ShardDoc.fields null

factor out TopGroupsShardResponseProcessor.fillResultIds method

Add regression test group.main=true and group.format=simple

Improve GroupingSpecification validation adding validate() method

SkipSecondStepSearchResultResultTransformer.serializeOneSearchGroup tweaks
  • Loading branch information
Malvina Josephidou authored and diegoceccarelli committed Aug 2, 2019
1 parent 52b5ec8 commit 4c7e5f1
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,18 @@ public Collection<SearchGroup<T>> getTopGroups(int groupOffset) throws IOExcepti
// System.out.println(" group=" + (group.groupValue == null ? "null" : group.groupValue.toString()));
SearchGroup<T> searchGroup = new SearchGroup<>();
searchGroup.groupValue = group.groupValue;
// We pass this around so that we can get the corresponding solr id when serializing the search group to send to the federator
searchGroup.topDocLuceneId = group.topDoc;
searchGroup.sortValues = new Object[sortFieldCount];
for(int sortFieldIDX=0;sortFieldIDX<sortFieldCount;sortFieldIDX++) {
searchGroup.sortValues[sortFieldIDX] = comparators[sortFieldIDX].value(group.comparatorSlot);
}
// TODO: It should be possible to extend this to handle more than one FieldComparator and other types
if (sortFieldCount > 0 && comparators[0] instanceof FieldComparator.RelevanceComparator ){
searchGroup.topDocScore = (Float)comparators[0].value(group.comparatorSlot);
} else {
searchGroup.topDocScore = -1;
}
result.add(searchGroup);
}
//System.out.println(" return " + result.size() + " groups");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,25 @@ public class SearchGroup<T> {
* been passed to {@link FirstPassGroupingCollector#getTopGroups} */
public Object[] sortValues;

/** The top doc of this group: we track the Lucene id,
* the Solr id and the score of the document */
public Object topDocSolrId;
public float topDocScore;

/** The topDocLuceneId will be null at the federator level because it is unique only at the shard level.
* It is used by the shard to get the corresponding solr id when serializing the search group to send to the federator
*/
public int topDocLuceneId;

@Override
public String toString() {
return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
return "SearchGroup{" +
"groupValue=" + groupValue +
", sortValues=" + Arrays.toString(sortValues) +
", topDocSolrId=" + topDocSolrId +
", topDocScore=" + topDocScore +
", topDocLuceneId=" + topDocLuceneId +
'}';
}

@Override
Expand Down Expand Up @@ -113,6 +129,11 @@ private static class MergedGroup<T> {
public boolean processed;
public boolean inQueue;

/** The top doc of this group:
* the Solr id and the score of the document */
public float topDocScore;
public Object topDocSolrId;

public MergedGroup(T groupValue) {
this.groupValue = groupValue;
}
Expand Down Expand Up @@ -225,6 +246,8 @@ private void updateNextGroup(int topN, ShardIter<T> shard) {
// Start a new group:
//System.out.println(" new");
mergedGroup = new MergedGroup<>(group.groupValue);
mergedGroup.topDocSolrId = group.topDocSolrId;
mergedGroup.topDocScore = group.topDocScore;
mergedGroup.minShardIndex = shard.shardIndex;
assert group.sortValues != null;
mergedGroup.topValues = group.sortValues;
Expand Down Expand Up @@ -262,6 +285,8 @@ private void updateNextGroup(int topN, ShardIter<T> shard) {
if (mergedGroup.inQueue) {
queue.remove(mergedGroup);
}
mergedGroup.topDocScore = group.topDocScore;
mergedGroup.topDocSolrId = group.topDocSolrId;
mergedGroup.topValues = group.sortValues;
mergedGroup.minShardIndex = shard.shardIndex;
queue.add(mergedGroup);
Expand Down Expand Up @@ -308,6 +333,8 @@ public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards,
final SearchGroup<T> newGroup = new SearchGroup<>();
newGroup.groupValue = group.groupValue;
newGroup.sortValues = group.topValues;
newGroup.topDocSolrId = group.topDocSolrId;
newGroup.topDocScore = group.topDocScore;
newTopGroups.add(newGroup);
if (newTopGroups.size() == topN) {
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import org.apache.solr.search.grouping.distributed.requestfactory.StoredFieldsShardRequestFactory;
import org.apache.solr.search.grouping.distributed.requestfactory.TopGroupsShardRequestFactory;
import org.apache.solr.search.grouping.distributed.responseprocessor.SearchGroupShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.responseprocessor.SkipSecondStepSearchGroupShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.responseprocessor.StoredFieldsShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.responseprocessor.TopGroupsShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.shardresultserializer.SearchGroupsResultTransformer;
Expand Down Expand Up @@ -305,17 +306,18 @@ protected void prepareGrouping(ResponseBuilder rb) throws IOException {
groupingSpec.setNeedScore((rb.getFieldFlags() & SolrIndexSearcher.GET_SCORES) != 0);
groupingSpec.setTruncateGroups(params.getBool(GroupParams.GROUP_TRUNCATE, false));

// when group.format=grouped then, validate group.offset
// for group.main=true and group.format=simple, start value is used instead of group.offset
// and start is already validate above for negative values
if (!(groupingSpec.isMain() || groupingSpec.getResponseFormat() == Grouping.Format.simple) &&
groupingSpec.getWithinGroupSortSpec().getOffset() < 0) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "'group.offset' parameter cannot be negative");
if (params.getBool(GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND, GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND_DEFAULT)) {
// skip second step is enabled
groupingSpec.setSkipSecondGroupingStep(true);
// check if reranking is enabled
if (rb.getRankQuery() != null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND+" does not support reranking parameter "+CommonParams.RQ);
}
}
groupingSpec.validate();
}



/**
* Actually run the query
*/
Expand Down Expand Up @@ -547,7 +549,9 @@ protected int groupedDistributedProcess(ResponseBuilder rb) {
} else if (rb.stage < ResponseBuilder.STAGE_EXECUTE_QUERY) {
nextStage = ResponseBuilder.STAGE_EXECUTE_QUERY;
} else if (rb.stage == ResponseBuilder.STAGE_EXECUTE_QUERY) {
shardRequestFactory = new TopGroupsShardRequestFactory();
if (!rb.getGroupingSpec().isSkipSecondGroupingStep()) {
shardRequestFactory = new TopGroupsShardRequestFactory();
}
nextStage = ResponseBuilder.STAGE_GET_FIELDS;
} else if (rb.stage < ResponseBuilder.STAGE_GET_FIELDS) {
nextStage = ResponseBuilder.STAGE_GET_FIELDS;
Expand Down Expand Up @@ -593,10 +597,18 @@ public void handleResponses(ResponseBuilder rb, ShardRequest sreq) {
}
}

protected SearchGroupShardResponseProcessor newSearchGroupShardResponseProcessor(ResponseBuilder rb) {
if (rb.getGroupingSpec().isSkipSecondGroupingStep()) {
return new SkipSecondStepSearchGroupShardResponseProcessor();
} else {
return new SearchGroupShardResponseProcessor();
}
}

protected void handleGroupedResponses(ResponseBuilder rb, ShardRequest sreq) {
ShardResponseProcessor responseProcessor = null;
if ((sreq.purpose & ShardRequest.PURPOSE_GET_TOP_GROUPS) != 0) {
responseProcessor = new SearchGroupShardResponseProcessor();
responseProcessor = newSearchGroupShardResponseProcessor(rb);
} else if ((sreq.purpose & ShardRequest.PURPOSE_GET_TOP_IDS) != 0) {
responseProcessor = new TopGroupsShardResponseProcessor();
} else if ((sreq.purpose & ShardRequest.PURPOSE_GET_FIELDS) != 0) {
Expand Down Expand Up @@ -1286,6 +1298,14 @@ private boolean doProcessSearchByIds(ResponseBuilder rb) throws IOException {
return true;
}

protected SearchGroupsResultTransformer newSearchGroupsResultTransformer(ResponseBuilder rb, SolrIndexSearcher searcher) {
if (rb.getGroupingSpec().isSkipSecondGroupingStep()) {
return new SearchGroupsResultTransformer.SkipSecondStepSearchResultResultTransformer(searcher);
} else {
return new SearchGroupsResultTransformer(searcher);
}
}

private void doProcessGroupedDistributedSearchFirstPhase(ResponseBuilder rb, QueryCommand cmd, QueryResult result) throws IOException {

GroupingSpecification groupingSpec = rb.getGroupingSpec();
Expand Down Expand Up @@ -1315,7 +1335,7 @@ private void doProcessGroupedDistributedSearchFirstPhase(ResponseBuilder rb, Que

CommandHandler commandHandler = topsGroupsActionBuilder.build();
commandHandler.execute();
SearchGroupsResultTransformer serializer = new SearchGroupsResultTransformer(searcher);
SearchGroupsResultTransformer serializer = newSearchGroupsResultTransformer(rb, searcher);

rsp.add("firstPhase", commandHandler.processResult(result, serializer));
rsp.add("totalHitCount", commandHandler.getTotalHitCount());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
*/
package org.apache.solr.search.grouping;

import java.util.Arrays;
import java.util.Collections;
import org.apache.lucene.search.SortField;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.GroupParams;
import org.apache.solr.search.Grouping;
import org.apache.solr.search.SortSpec;

Expand All @@ -36,6 +41,57 @@ public class GroupingSpecification {
private Grouping.Format responseFormat;
private boolean needScore;
private boolean truncateGroups;
/* This is an optimization to skip the second grouping step when groupLimit is 1. The second
* grouping step retrieves the top K documents for each group. This is not necessary when only one
* document per group is required because in the first step every shard sends back the group score given
* by its top document.
*/
private boolean skipSecondGroupingStep;

/**
* Validates the current GropingSpecification.
* It will throw a SolrException the grouping specification is not valid, otherwise
* it will return without side effects.
*/
public void validate(){
if (isSkipSecondGroupingStep()) {
validateSkipSecondGroupingStep(withinGroupSortSpec, groupSortSpec);
}

// when group.format=grouped then, validate group.offset
// for group.main=true and group.format=simple, start value is used instead of group.offset
// and start is already validate above for negative values
if (!(main || responseFormat == Grouping.Format.simple) &&
withinGroupSortSpec.getOffset() < 0) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "'group.offset' parameter cannot be negative");
}
}

private void validateSkipSecondGroupingStep(final SortSpec withinGroupSpecification, final SortSpec groupSort) {
// Only possible if we only want one doc per group
final int limit = withinGroupSpecification.getCount();
final int offset = withinGroupSpecification.getOffset();
if (limit != 1) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND + " does not support " +
GroupParams.GROUP_LIMIT + " != 1 ("+GroupParams.GROUP_LIMIT+" is "+limit+")");
}
if (offset != 0){
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND + " does not support " + GroupParams.GROUP_OFFSET + " != 0 (" +
GroupParams.GROUP_OFFSET + " is "+offset + ")");
}

final SortField[] withinGroupSortFields = withinGroupSpecification.getSort().getSort();
final SortField[] groupSortFields = groupSort.getSort().getSort();

// Within group sort must be the same as group sort because if we skip second step no sorting within group will be done.
// This checks if withinGroupSortFields is a prefix of groupSortFields
if (Collections.indexOfSubList(Arrays.asList(groupSortFields), Arrays.asList(withinGroupSortFields)) != 0) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
GroupParams.GROUP_SKIP_DISTRIBUTED_SECOND + " does not allow the given within/global sort group configuration");
}
}

public String[] getFields() {
return fields;
Expand Down Expand Up @@ -129,4 +185,8 @@ public void setWithinGroupSortSpec(SortSpec withinGroupSortSpec) {
this.withinGroupSortSpec = withinGroupSortSpec;
}

public boolean isSkipSecondGroupingStep() { return skipSecondGroupingStep; }

public void setSkipSecondGroupingStep(boolean skipSecondGroupingStep) { this.skipSecondGroupingStep = skipSecondGroupingStep; }

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ static Collection<SearchGroup<BytesRef>> fromMutable(SchemaField field, Collecti
for (SearchGroup<MutableValue> original : values) {
SearchGroup<BytesRef> converted = new SearchGroup<BytesRef>();
converted.sortValues = original.sortValues;
converted.topDocLuceneId = original.topDocLuceneId;
converted.topDocScore = original.topDocScore;
if (original.groupValue.exists) {
BytesRefBuilder binary = new BytesRefBuilder();
fieldType.readableToIndexed(original.groupValue.toString(), binary);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.solr.handler.component.ShardRequest;
import org.apache.solr.handler.component.ShardResponse;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.search.SortSpec;
import org.apache.solr.search.grouping.distributed.ShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.command.SearchGroupsFieldCommandResult;
Expand All @@ -47,6 +48,14 @@
*/
public class SearchGroupShardResponseProcessor implements ShardResponseProcessor {

protected SearchGroupsResultTransformer newSearchGroupsResultTransformer(SolrIndexSearcher solrIndexSearcher) {
return new SearchGroupsResultTransformer(solrIndexSearcher);
}

protected SearchGroupsContainer newSearchGroupsContainer(ResponseBuilder rb) {
return new SearchGroupsContainer(rb.getGroupingSpec().getFields());
}

@Override
public void process(ResponseBuilder rb, ShardRequest shardRequest) {
SortSpec groupSortSpec = rb.getGroupingSpec().getGroupSortSpec();
Expand All @@ -56,16 +65,14 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) {
assert withinGroupSort != null;

final Map<String, List<Collection<SearchGroup<BytesRef>>>> commandSearchGroups = new HashMap<>(fields.length, 1.0f);
final Map<String, Map<SearchGroup<BytesRef>, Set<String>>> tempSearchGroupToShards = new HashMap<>(fields.length, 1.0f);
for (String field : fields) {
commandSearchGroups.put(field, new ArrayList<Collection<SearchGroup<BytesRef>>>(shardRequest.responses.size()));
tempSearchGroupToShards.put(field, new HashMap<SearchGroup<BytesRef>, Set<String>>());
if (!rb.searchGroupToShards.containsKey(field)) {
rb.searchGroupToShards.put(field, new HashMap<SearchGroup<BytesRef>, Set<String>>());
}
}

SearchGroupsResultTransformer serializer = new SearchGroupsResultTransformer(rb.req.getSearcher());
SearchGroupsResultTransformer serializer = newSearchGroupsResultTransformer(rb.req.getSearcher());
int maxElapsedTime = 0;
int hitCountDuringFirstPhase = 0;

Expand All @@ -75,6 +82,8 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) {
rb.rsp.getValues().add(ShardParams.SHARDS_INFO + ".firstPhase", shardInfo);
}

SearchGroupsContainer searchGroupsContainer = newSearchGroupsContainer(rb);

for (ShardResponse srsp : shardRequest.responses) {
if (shardInfo != null) {
SimpleOrderedMap<Object> nl = new SimpleOrderedMap<>(4);
Expand Down Expand Up @@ -123,15 +132,7 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) {
}

commandSearchGroups.get(field).add(searchGroups);
for (SearchGroup<BytesRef> searchGroup : searchGroups) {
Map<SearchGroup<BytesRef>, Set<String>> map = tempSearchGroupToShards.get(field);
Set<String> shards = map.get(searchGroup);
if (shards == null) {
shards = new HashSet<>();
map.put(searchGroup, shards);
}
shards.add(srsp.getShard());
}
searchGroupsContainer.addSearchGroups(srsp, field, searchGroups);
}
hitCountDuringFirstPhase += (Integer) srsp.getSolrResponse().getResponse().get("totalHitCount");
}
Expand All @@ -143,8 +144,39 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) {
if (mergedTopGroups == null) {
continue;
}
searchGroupsContainer.addMergedSearchGroups(rb, groupField, mergedTopGroups);
searchGroupsContainer.addSearchGroupToShards(rb, groupField, mergedTopGroups);
}
}

protected static class SearchGroupsContainer {

private final Map<String, Map<SearchGroup<BytesRef>, Set<String>>> tempSearchGroupToShards;

public SearchGroupsContainer(String[] fields) {
tempSearchGroupToShards = new HashMap<>(fields.length, 1.0f);
for (String field : fields) {
tempSearchGroupToShards.put(field, new HashMap<SearchGroup<BytesRef>, Set<String>>());
}
}

public void addSearchGroups(ShardResponse srsp, String field, Collection<SearchGroup<BytesRef>> searchGroups) {
for (SearchGroup<BytesRef> searchGroup : searchGroups) {
Map<SearchGroup<BytesRef>, Set<String>> map = tempSearchGroupToShards.get(field);
Set<String> shards = map.get(searchGroup);
if (shards == null) {
shards = new HashSet<>();
map.put(searchGroup, shards);
}
shards.add(srsp.getShard());
}
}

public void addMergedSearchGroups(ResponseBuilder rb, String groupField, Collection<SearchGroup<BytesRef>> mergedTopGroups) {
rb.mergedSearchGroups.put(groupField, mergedTopGroups);
}

public void addSearchGroupToShards(ResponseBuilder rb, String groupField, Collection<SearchGroup<BytesRef>> mergedTopGroups) {
for (SearchGroup<BytesRef> mergedTopGroup : mergedTopGroups) {
rb.searchGroupToShards.get(groupField).put(mergedTopGroup, tempSearchGroupToShards.get(groupField).get(mergedTopGroup));
}
Expand Down
Loading

0 comments on commit 4c7e5f1

Please sign in to comment.