Skip to content

Commit

Permalink
Merge pull request #33014 from vespa-engine/toregge/detect-if-query-o…
Browse files Browse the repository at this point in the history
…perators-need-ranking

Detect if query operators need ranking.
  • Loading branch information
geirst authored Dec 9, 2024
2 parents a0bbb32 + cf0f99c commit ff2bf65
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 2 deletions.
47 changes: 47 additions & 0 deletions searchcore/src/tests/proton/matching/query_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,53 @@ TEST(QueryTest, global_filter_is_calculated_and_handled)
}
}

bool query_needs_ranking(const std::string& stack_dump)
{
Query query;
query.buildTree(stack_dump, "", ViewResolver(), plain_index_env);
return query.needs_ranking();

}

// A query consisting of a single plain string term must not request ranking.
TEST(QueryTest, normal_term_doesnt_need_ranking)
{
    QueryBuilder<ProtonNodeTypes> query_builder;
    query_builder.addStringTerm("xyz", "f1", 1, Weight(1));
    auto root = query_builder.build();
    const auto stack_dump = StackDumpCreator::create(*root);
    EXPECT_FALSE(query_needs_ranking(stack_dump));
}

// A WeakAnd operator wrapping a string term must request ranking.
TEST(QueryTest, weak_and_term_needs_ranking)
{
    QueryBuilder<ProtonNodeTypes> query_builder;
    query_builder.addWeakAnd(1, 10, "f1");
    query_builder.addStringTerm("xyz", "f1", 1, Weight(1));
    auto root = query_builder.build();
    const auto stack_dump = StackDumpCreator::create(*root);
    EXPECT_TRUE(query_needs_ranking(stack_dump));
}

// A WeightedSetTerm with one child term must request ranking.
TEST(QueryTest, weighted_set_term_needs_ranking)
{
    QueryBuilder<ProtonNodeTypes> query_builder;
    auto& weighted_set = query_builder.addWeightedSetTerm(1, "f1", 1, Weight(1));
    weighted_set.addTerm("xyz", Weight(1));
    auto root = query_builder.build();
    const auto stack_dump = StackDumpCreator::create(*root);
    EXPECT_TRUE(query_needs_ranking(stack_dump));
}

// A DotProduct term with one child term must request ranking.
TEST(QueryTest, dot_product_term_needs_ranking)
{
    QueryBuilder<ProtonNodeTypes> query_builder;
    auto& dot_product = query_builder.addDotProduct(1, "f1", 1, Weight(1));
    dot_product.addTerm("xyz", Weight(1));
    auto root = query_builder.build();
    const auto stack_dump = StackDumpCreator::create(*root);
    EXPECT_TRUE(query_needs_ranking(stack_dump));
}

// A WandTerm with one child term must request ranking.
TEST(QueryTest, wand_term_needs_ranking)
{
    QueryBuilder<ProtonNodeTypes> query_builder;
    auto& wand_term = query_builder.addWandTerm(1, "f1", 1, Weight(1), 10, 0, 1.0);
    wand_term.addTerm("xyz", Weight(1));
    auto root = query_builder.build();
    const auto stack_dump = StackDumpCreator::create(*root);
    EXPECT_TRUE(query_needs_ranking(stack_dump));
}

} // namespace
} // namespace proton::matching

Expand Down
7 changes: 5 additions & 2 deletions searchcore/src/vespa/searchcore/proton/matching/matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,11 @@ numThreads(size_t hits, size_t minHits) {

bool
willNeedRanking(const SearchRequest & request, const GroupingContext & groupingContext,
std::optional<search::feature_t> first_phase_rank_score_drop_limit)
std::optional<search::feature_t> first_phase_rank_score_drop_limit, bool query_needs_ranking)
{
if (query_needs_ranking) {
return true;
}
return (groupingContext.needRanking() || (request.maxhits != 0))
&& (request.sortSpec.empty() ||
(request.sortSpec.find("[rank]") != std::string::npos) ||
Expand Down Expand Up @@ -281,7 +284,7 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl
MatchParams params(searchContext.getDocIdLimit(), heapSize, arraySize, first_phase_rank_score_drop_limit,
second_phase_rank_score_drop_limit,
request.offset, request.maxhits, !_rankSetup->getSecondPhaseRank().empty(),
willNeedRanking(request, groupingContext, first_phase_rank_score_drop_limit));
willNeedRanking(request, groupingContext, first_phase_rank_score_drop_limit, mtf->query().needs_ranking()));

ResultProcessor rp(attrContext, metaStore, sessionMgr, groupingContext, sessionId,
request.sortSpec, params.offset, params.hits);
Expand Down
26 changes: 26 additions & 0 deletions searchcore/src/vespa/searchcore/proton/matching/query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <vespa/searchlib/common/geo_location_spec.h>
#include <vespa/searchlib/engine/trace.h>
#include <vespa/searchlib/parsequery/stackdumpiterator.h>
#include <vespa/searchlib/query/tree/templatetermvisitor.h>
#include <vespa/searchlib/queryeval/intermediate_blueprints.h>
#include <vespa/vespalib/util/issue.h>
#include <vespa/vespalib/util/thread_bundle.h>
Expand All @@ -32,6 +33,7 @@ using search::fef::MatchDataLayout;
using search::query::LocationTerm;
using search::query::Node;
using search::query::QueryTreeCreator;
using search::query::TemplateTermVisitor;
using search::query::Weight;
using search::queryeval::AndBlueprint;
using search::queryeval::AndNotBlueprint;
Expand Down Expand Up @@ -141,6 +143,27 @@ void exchange_location_nodes(const string &location_str,
}
}

/*
 * Query tree visitor that records whether any visited node is an operator
 * that needs ranking.
 *
 * The WeakAnd, WeightedSetTerm, DotProduct and WandTerm query operators need
 * ranking since doUnpack is used to update the threshold during query
 * evaluation. All other term types leave the flag untouched.
 */
class NeedsRankingVisitor : public TemplateTermVisitor<NeedsRankingVisitor, ProtonNodeTypes>
{
    // Set to true once any operator that needs ranking has been visited; never reset.
    bool _needs_ranking = false;

public:
    NeedsRankingVisitor() = default;

    // Result of the traversal: true if at least one such operator was seen.
    bool needs_ranking() const noexcept { return _needs_ranking; }

    // WeakAnd is not handled through visitTerm, so its visit() is overridden directly.
    void visit(ProtonNodeTypes::WeakAnd&) override { _needs_ranking = true; }

    // Term types that need ranking; these non-template overloads are
    // preferred over the generic template below for their exact node type.
    void visitTerm(ProtonNodeTypes::WeightedSetTerm&) { _needs_ranking = true; }
    void visitTerm(ProtonNodeTypes::DotProduct&) { _needs_ranking = true; }
    void visitTerm(ProtonNodeTypes::WandTerm&) { _needs_ranking = true; }

    // Every other term type: nothing to record.
    template <class TermNode>
    void visitTerm(TermNode&) { }
};

} // namespace

Query::Query() = default;
Expand All @@ -160,6 +183,9 @@ Query::buildTree(std::string_view stack, const string &location,
_query_tree = UnpackingIteratorsOptimizer::optimize(std::move(_query_tree), bool(_whiteListBlueprint), always_mark_phrase_expensive);
ResolveViewVisitor resolve_visitor(resolver, indexEnv);
_query_tree->accept(resolve_visitor);
NeedsRankingVisitor need_ranking_visitor;
_query_tree->accept(need_ranking_visitor);
_needs_ranking = need_ranking_visitor.needs_ranking();
return true;
} else {
Issue::report("invalid query");
Expand Down
2 changes: 2 additions & 0 deletions searchcore/src/vespa/searchcore/proton/matching/query.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Query
Blueprint::UP _blueprint;
Blueprint::UP _whiteListBlueprint;
std::vector<GeoLocationSpec> _locations;
bool _needs_ranking = false;

public:
/** Convenience typedef. */
Expand Down Expand Up @@ -148,6 +149,7 @@ class Query
*/
Blueprint::HitEstimate estimate() const;
const Blueprint * peekRoot() const { return _blueprint.get(); }
bool needs_ranking() const noexcept { return _needs_ranking; }
};

}

0 comments on commit ff2bf65

Please sign in to comment.