-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1437 from vespa-engine/bratseth/reranker
Add reranker example app
- Loading branch information
Showing
14 changed files
with
721 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
*.iml | ||
.idea/ | ||
target/ | ||
src/main/application/security/ | ||
*.pem |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
<!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> | ||
|
||
![Vespa logo](https://vespa.ai/assets/vespa-logo-color.png) | ||
|
||
# Reranker sample application | ||
|
||
A stateless application which reranks results obtained from another Vespa application. | ||
While this does not result in good performance and is not recommended for production, | ||
it is useful when you want to quickly do ranking experiments without rewriting application data. | ||
|
||
## Usage | ||
|
||
1. Make sure the application to rerank has a | ||
[token endpoint](https://cloud.vespa.ai/en/security/guide#application-key). | ||
2. `vespa clone examples/reranker` | ||
3. Add the endpoint and any defaults to the reranker config in `src/main/application/services.xml` | ||
(parameters can also be passed in the request). | ||
4. Add the model(s) to use for reranking to the `models` directory. | ||
5. `mvn install && vespa deploy` | ||
6. Issue queries. All request parameters including the token header will be passed through to the application to be reranked. | ||
|
||
Example requests: | ||
|
||
Minimal: | ||
|
||
vespa query "select * from sources * where album contains 'to'" --header "Authorization: Bearer [your token]" | ||
|
||
Passing all reranking parameters: | ||
|
||
vespa query "select * from sources * where album contains 'to'" --header "Authorization: Bearer [your token]" rerank.model=xgboost_model_example rerank.hits=100 profile=firstPhase |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
<?xml version="1.0"?> | ||
<!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 | ||
http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<groupId>ai.vespa.examples</groupId> | ||
<artifactId>reranker</artifactId> <!-- Note: When changing this, also change bundle names in services.xml --> | ||
<version>1.0.0</version> | ||
<packaging>container-plugin</packaging> | ||
<parent> | ||
<groupId>com.yahoo.vespa</groupId> | ||
<artifactId>cloud-tenant-base</artifactId> | ||
<version>[8,9)</version> <!-- Use the latest Vespa release on each build --> | ||
<relativePath/> | ||
</parent> | ||
<properties> | ||
<bundle-plugin.failOnWarnings>true</bundle-plugin.failOnWarnings> | ||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> | ||
<test.hide>true</test.hide> | ||
</properties> | ||
<dependencies> | ||
<dependency> | ||
<groupId>com.yahoo.vespa</groupId> | ||
<artifactId>container</artifactId> | ||
<scope>provided</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.apache.httpcomponents.client5</groupId> | ||
<artifactId>httpclient5</artifactId> | ||
<version>5.3.1</version> | ||
<exclusions> | ||
<exclusion> | ||
<groupId>org.slf4j</groupId> | ||
<artifactId>slf4j-api</artifactId> | ||
</exclusion> | ||
</exclusions> | ||
</dependency> | ||
<dependency> | ||
<groupId>com.fasterxml.jackson.core</groupId> | ||
<artifactId>jackson-databind</artifactId> | ||
<scope>provided</scope> | ||
</dependency> | ||
</dependencies> | ||
</project> |
19 changes: 19 additions & 0 deletions
19
examples/reranker/src/main/application/models/xgboost_model_example.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[ | ||
{ "nodeid": 0, "depth": 0, "split": "fieldMatch(album).proximity", "split_condition": 0.75, "yes": 1, "no": 2, "missing": 2, "children": [ | ||
{ "nodeid": 1, "depth": 1, "split": "fieldMatch(album).completeness", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 3, "children": [ | ||
{ "nodeid": 3, "leaf": 0.9 }, | ||
{ "nodeid": 4, "leaf": 0.8 } | ||
]}, | ||
{ "nodeid": 2, "depth": 1, "split": "fieldMatch(artist).proximity", "split_condition": 0.5, "yes": 5, "no": 6, "missing": 6, "children": [ | ||
{ "nodeid": 5, "leaf": 0.7 }, | ||
{ "nodeid": 6, "leaf": 0.6 } | ||
]} | ||
]}, | ||
{ "nodeid": 0, "depth": 0, "split": "fieldMatch(album).proximity", "split_condition": 0.25, "yes": 1, "no": 2, "missing": 1, "children": [ | ||
{ "nodeid": 1, "depth": 1, "split": "fieldMatch(artist).completeness", "split_condition": 0.125, "yes": 3, "no": 4, "missing": 4, "children": [ | ||
{ "nodeid": 3, "leaf": 0.5 }, | ||
{ "nodeid": 4, "leaf": 0.4 } | ||
]}, | ||
{ "nodeid": 2, "leaf": 0.3 } | ||
]} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
<?xml version="1.0" encoding="utf-8" ?> | ||
<!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> | ||
<services version="1.0" xmlns:deploy="vespa" xmlns:preprocess="properties"> | ||
|
||
<container id="default" version="1.0"> | ||
|
||
<config name="ai.vespa.example.reranker.reranker"> | ||
<endpoint>https://f237494d.ae82d729.z.vespa-app.cloud/</endpoint> | ||
<rerank> | ||
<hits>100</hits> | ||
<profile>firstPhase</profile> | ||
<model>xgboost_model_example</model> | ||
</rerank> | ||
</config> | ||
|
||
<model-evaluation/> | ||
|
||
<search> | ||
<chain id="default" inherits="native"> | ||
<searcher id="ai.vespa.example.reranker.RerankingSearcher" bundle="reranker"/> | ||
<searcher id="ai.vespa.example.reranker.VespaSearcher" bundle="reranker"/> | ||
</chain> | ||
</search> | ||
|
||
<nodes count="1"/> | ||
</container> | ||
|
||
</services> |
77 changes: 77 additions & 0 deletions
77
examples/reranker/src/main/java/ai/vespa/example/reranker/RerankingSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
package ai.vespa.example.reranker; | ||
|
||
import ai.vespa.models.evaluation.FunctionEvaluator; | ||
import ai.vespa.models.evaluation.ModelsEvaluator; | ||
import com.yahoo.processing.request.CompoundName; | ||
import com.yahoo.search.Query; | ||
import com.yahoo.search.Result; | ||
import com.yahoo.search.Searcher; | ||
import com.yahoo.search.result.FeatureData; | ||
import com.yahoo.search.result.Hit; | ||
import com.yahoo.search.searchchain.Execution; | ||
import com.yahoo.tensor.Tensor; | ||
|
||
import java.util.Iterator; | ||
|
||
/** | ||
* A searcher which can rerank results from another Vespa application. | ||
* | ||
* @author bratseth | ||
*/ | ||
public class RerankingSearcher extends Searcher { | ||
|
||
public static final CompoundName rerankHitsParameter = new CompoundName("rerank.hits"); | ||
public static final CompoundName rerankModelParameter = new CompoundName("rerank.model"); | ||
|
||
private final ModelsEvaluator modelsEvaluator; | ||
|
||
private final int defaultRerankHits; | ||
private final String defaultRerankProfile; | ||
private final String defaultRerankModel; | ||
|
||
public RerankingSearcher(RerankerConfig config, ModelsEvaluator modelsEvaluator) { | ||
this.modelsEvaluator = modelsEvaluator; | ||
|
||
this.defaultRerankHits = config.rerank().hits(); | ||
this.defaultRerankProfile = config.rerank().profile(); | ||
this.defaultRerankModel = config.rerank().model(); | ||
} | ||
|
||
@Override | ||
public Result search(Query query, Execution execution) { | ||
query.setHits(Math.max(query.getHits(), query.properties().getInteger(rerankHitsParameter, defaultRerankHits))); | ||
if (query.getRanking().getProfile().equals("default")) | ||
query.getRanking().setProfile(defaultRerankProfile); | ||
|
||
Result result = execution.search(query); | ||
rerank(result, query.properties().getString(rerankModelParameter, defaultRerankModel)); | ||
return result; | ||
} | ||
|
||
private void rerank(Result result, String rerankModel) { | ||
for (Iterator<Hit> i = result.hits().unorderedDeepIterator(); i.hasNext(); ) { | ||
Hit hit = i.next(); | ||
if ( ! hit.isAuxiliary()) | ||
rerank(hit, rerankModel); | ||
} | ||
} | ||
|
||
private void rerank(Hit hit, String rerankModel) { | ||
FunctionEvaluator evaluator = modelsEvaluator.evaluatorOf(rerankModel); | ||
|
||
FeatureData features = (FeatureData)hit.getField("summaryfeatures"); | ||
if (features == null) | ||
throw new IllegalArgumentException("Missing 'summaryfeatures' field in " + hit + | ||
". Use a rank profile with a 'summary-features' block, using '" + | ||
hit.getQuery().getRanking().getProfile() + "'"); | ||
for (String featureName : features.featureNames()) { | ||
if (featureName.equals("vespa.summaryFeatures.cached")) continue; | ||
if (evaluator.context().arguments().contains(featureName)) | ||
evaluator.bind(featureName, features.getTensor(featureName)); | ||
} | ||
Tensor result = evaluator.evaluate(); | ||
hit.setRelevance(result.asDouble()); | ||
} | ||
|
||
} |
90 changes: 90 additions & 0 deletions
90
examples/reranker/src/main/java/ai/vespa/example/reranker/ResultReader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
package ai.vespa.example.reranker; | ||
|
||
import com.fasterxml.jackson.core.JsonFactory; | ||
import com.fasterxml.jackson.core.JsonParser; | ||
import com.fasterxml.jackson.databind.JsonNode; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import com.yahoo.search.Query; | ||
import com.yahoo.search.Result; | ||
import com.yahoo.search.result.ErrorMessage; | ||
import com.yahoo.search.result.FeatureData; | ||
import com.yahoo.search.result.Hit; | ||
import com.yahoo.tensor.Tensor; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Iterator; | ||
import java.util.Map; | ||
|
||
/** | ||
* Converts a JSON result from a Vespa backend to Hits in a Result. | ||
* | ||
* @author bratseth | ||
*/ | ||
class ResultReader { | ||
|
||
void read(String resultJson, Result result) { | ||
// Create ObjectMapper instance | ||
ObjectMapper objectMapper = new ObjectMapper(); | ||
JsonFactory factory = new JsonFactory(); | ||
|
||
try (JsonParser parser = factory.createParser(resultJson)) { | ||
// Read the tree structure from the JSON | ||
JsonNode jsonRoot = objectMapper.readTree(parser); | ||
JsonNode rootNode = jsonRoot.get("root"); | ||
if (rootNode == null) | ||
throw new IllegalArgumentException("Expected a 'root' object in the JSON, got: " + jsonRoot); | ||
|
||
if (rootNode.get("fields") != null && rootNode.get("fields").get("totalCount") != null) | ||
result.setTotalHitCount(rootNode.get("fields").get("totalCount").asInt()); | ||
|
||
if (rootNode.get("errors") != null) | ||
rootNode.get("errors").forEach(hit -> result.hits().addError(readError(hit))); | ||
if (rootNode.get("children") != null) | ||
rootNode.get("children").forEach(hit -> result.hits().add(readHit(hit, result.getQuery()))); | ||
} catch (IOException e) { | ||
throw new IllegalArgumentException("Could not read result JSON", e); | ||
} | ||
} | ||
|
||
ErrorMessage readError(JsonNode errorObject) { | ||
return new ErrorMessage(errorObject.get("code").asInt(), | ||
errorObject.get("summary").asText(), | ||
errorObject.get("message") != null ? errorObject.get("message").asText() : null); | ||
} | ||
|
||
Hit readHit(JsonNode hitObject, Query query) { | ||
Hit hit = new Hit(hitObject.get("id").asText(), hitObject.get("relevance").asDouble(), query); | ||
// TODO: Source | ||
for (Iterator<Map.Entry<String, JsonNode>> i = hitObject.get("fields").fields(); i.hasNext(); ) { | ||
var fieldEntry = i.next(); | ||
if ("matchfeatures".equals(fieldEntry.getKey())) | ||
hit.setField("matchfeatures", readFeatureData(fieldEntry.getValue())); | ||
if ("summaryfeatures".equals(fieldEntry.getKey())) | ||
hit.setField("summaryfeatures", readFeatureData(fieldEntry.getValue())); | ||
else | ||
hit.setField(fieldEntry.getKey(), toValue(fieldEntry.getValue())); | ||
} | ||
return hit; | ||
} | ||
|
||
FeatureData readFeatureData(JsonNode featureDataObject) { | ||
Map<String, Tensor> features = new HashMap<>(); | ||
for (Iterator<Map.Entry<String, JsonNode>> i = featureDataObject.fields(); i.hasNext(); ) { | ||
var fieldEntry = i.next(); | ||
features.put(fieldEntry.getKey(), Tensor.from(fieldEntry.getValue().asDouble())); // TODO: Parse tensors | ||
} | ||
return new FeatureData(features); | ||
} | ||
|
||
public Object toValue(JsonNode fieldValue) { | ||
return switch (fieldValue.getNodeType()) { | ||
case NUMBER -> fieldValue.asDouble(); | ||
case STRING -> fieldValue.asText(); | ||
case BOOLEAN -> fieldValue.asBoolean(); | ||
default -> fieldValue.asText(); | ||
}; | ||
} | ||
|
||
} |
83 changes: 83 additions & 0 deletions
83
examples/reranker/src/main/java/ai/vespa/example/reranker/VespaClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
package ai.vespa.example.reranker; | ||
|
||
import com.yahoo.container.jdisc.HttpRequest; | ||
import org.apache.hc.client5.http.classic.methods.HttpGet; | ||
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; | ||
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; | ||
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; | ||
import org.apache.hc.core5.http.ClassicHttpResponse; | ||
import org.apache.hc.core5.http.ParseException; | ||
import org.apache.hc.core5.http.io.HttpClientResponseHandler; | ||
import org.apache.hc.core5.http.io.entity.EntityUtils; | ||
import org.apache.hc.core5.net.URIBuilder; | ||
|
||
import java.io.IOException; | ||
import java.net.URISyntaxException; | ||
import java.util.Map; | ||
|
||
/** | ||
* A client which can talk to a Vespa applications *token* endpoint. | ||
* This is multithread safe. | ||
* | ||
* @author bratseth | ||
*/ | ||
class VespaClient { | ||
|
||
private final String tokenEndpoint; | ||
private final CloseableHttpClient httpClient; | ||
|
||
public VespaClient(String tokenEndpoint) { | ||
this.tokenEndpoint = tokenEndpoint; | ||
this.httpClient = HttpClientBuilder.create() | ||
.setConnectionManager(PoolingHttpClientConnectionManagerBuilder | ||
.create() | ||
.build()) | ||
.setUserAgent("vespa") | ||
.disableCookieManagement() | ||
.disableAutomaticRetries() | ||
.disableAuthCaching() | ||
.build(); | ||
} | ||
|
||
public Response search(HttpRequest request, Map<String, Object> overridingProperties) throws IOException { | ||
try { | ||
String authorizationHeader = request.getHeader("Authorization"); | ||
if (authorizationHeader == null || !authorizationHeader.startsWith("Bearer ")) | ||
throw new IllegalArgumentException("Request must have an 'Authorization' header with the value " + | ||
"'Bearer $your_token'"); | ||
// String tokenHc = "vespa_cloud_dNpDIa7RkNntm0AkvKWNlA0cFydFa4W3GlV6HOGQTuf"; | ||
// String authorizationHeader = "Bearer " + authorizationHeader; | ||
var uriBuilder = new URIBuilder(tokenEndpoint); | ||
uriBuilder.setPath("/search/"); | ||
for (var property : request.propertyMap().entrySet()) | ||
uriBuilder.addParameter(property.getKey(), property.getValue()); | ||
for (var property : overridingProperties.entrySet()) | ||
uriBuilder.addParameter(property.getKey(), property.getValue().toString()); | ||
var get = new HttpGet(uriBuilder.build()); | ||
get.addHeader("Authorization", authorizationHeader); | ||
return httpClient.execute(get, new ResponseHandler()); | ||
} | ||
catch (URISyntaxException e) { | ||
throw new IllegalStateException(e); | ||
} | ||
} | ||
|
||
public record Response(int statusCode, String responseBody) {} | ||
|
||
// Custom ResponseHandler to handle the response | ||
public static class ResponseHandler implements HttpClientResponseHandler<Response> { | ||
|
||
@Override | ||
public Response handleResponse(ClassicHttpResponse response) { | ||
String responseBody; | ||
try { | ||
responseBody = EntityUtils.toString(response.getEntity()); | ||
} catch (IOException | ParseException e) { | ||
throw new IllegalStateException(e); | ||
} | ||
return new Response(response.getCode(), responseBody); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.