Skip to content

Commit

Permalink
Merge pull request #1437 from vespa-engine/bratseth/reranker
Browse files Browse the repository at this point in the history
Add reranker example app
  • Loading branch information
bratseth authored Jul 12, 2024
2 parents b61a4d1 + 2352823 commit 7f15ad8
Show file tree
Hide file tree
Showing 14 changed files with 721 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/reranker/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.iml
.idea/
target/
src/main/application/security/
*.pem
30 changes: 30 additions & 0 deletions examples/reranker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->

![Vespa logo](https://vespa.ai/assets/vespa-logo-color.png)

# Reranker sample application

A stateless application which reranks results obtained from another Vespa application.
While this does not result in good performance and is not recommended for production,
it is useful when you want to quickly do ranking experiments without rewriting application data.

## Usage

1. Make sure the application to rerank has a
[token endpoint](https://cloud.vespa.ai/en/security/guide#application-key).
2. `vespa clone examples/reranker`
3. Add the endpoint and any defaults to the reranker config in `src/main/application/services.xml`
(parameters can also be passed in the request).
4. Add the model(s) to use for reranking to the `models` directory.
5. `mvn install && vespa deploy`
6. Issue queries. All request parameters, as well as the `Authorization` header carrying the token, are passed through to the application being reranked.

Example requests:

Minimal:

vespa query "select * from sources * where album contains 'to'" --header "Authorization: Bearer [your token]"

Passing all reranking parameters:

vespa query "select * from sources * where album contains 'to'" --header "Authorization: Bearer [your token]" rerank.model=xgboost_model_example rerank.hits=100 profile=firstPhase
46 changes: 46 additions & 0 deletions examples/reranker/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
<?xml version="1.0"?>
<!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>ai.vespa.examples</groupId>
<artifactId>reranker</artifactId> <!-- Note: When changing this, also change bundle names in services.xml -->
<version>1.0.0</version>
<packaging>container-plugin</packaging>
<parent>
<groupId>com.yahoo.vespa</groupId>
<artifactId>cloud-tenant-base</artifactId>
<version>[8,9)</version> <!-- Use the latest Vespa release on each build -->
<relativePath/>
</parent>
<properties>
<bundle-plugin.failOnWarnings>true</bundle-plugin.failOnWarnings>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<test.hide>true</test.hide>
</properties>
<dependencies>
<dependency>
<groupId>com.yahoo.vespa</groupId>
<artifactId>container</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents.client5</groupId>
<artifactId>httpclient5</artifactId>
<version>5.3.1</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<scope>provided</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[
{ "nodeid": 0, "depth": 0, "split": "fieldMatch(album).proximity", "split_condition": 0.75, "yes": 1, "no": 2, "missing": 2, "children": [
{ "nodeid": 1, "depth": 1, "split": "fieldMatch(album).completeness", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 3, "children": [
{ "nodeid": 3, "leaf": 0.9 },
{ "nodeid": 4, "leaf": 0.8 }
]},
{ "nodeid": 2, "depth": 1, "split": "fieldMatch(artist).proximity", "split_condition": 0.5, "yes": 5, "no": 6, "missing": 6, "children": [
{ "nodeid": 5, "leaf": 0.7 },
{ "nodeid": 6, "leaf": 0.6 }
]}
]},
{ "nodeid": 0, "depth": 0, "split": "fieldMatch(album).proximity", "split_condition": 0.25, "yes": 1, "no": 2, "missing": 1, "children": [
{ "nodeid": 1, "depth": 1, "split": "fieldMatch(artist).completeness", "split_condition": 0.125, "yes": 3, "no": 4, "missing": 4, "children": [
{ "nodeid": 3, "leaf": 0.5 },
{ "nodeid": 4, "leaf": 0.4 }
]},
{ "nodeid": 2, "leaf": 0.3 }
]}
]
28 changes: 28 additions & 0 deletions examples/reranker/src/main/application/services.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8" ?>
<!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<services version="1.0" xmlns:deploy="vespa" xmlns:preprocess="properties">

<container id="default" version="1.0">

<config name="ai.vespa.example.reranker.reranker">
<endpoint>https://f237494d.ae82d729.z.vespa-app.cloud/</endpoint>
<rerank>
<hits>100</hits>
<profile>firstPhase</profile>
<model>xgboost_model_example</model>
</rerank>
</config>

<model-evaluation/>

<search>
<chain id="default" inherits="native">
<searcher id="ai.vespa.example.reranker.RerankingSearcher" bundle="reranker"/>
<searcher id="ai.vespa.example.reranker.VespaSearcher" bundle="reranker"/>
</chain>
</search>

<nodes count="1"/>
</container>

</services>
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.example.reranker;

import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.tensor.Tensor;

import java.util.Iterator;

/**
 * Searcher that re-scores the hits produced further down the search chain by
 * evaluating a model from the injected {@link ModelsEvaluator} over each hit's summary features.
 *
 * @author bratseth
 */
public class RerankingSearcher extends Searcher {

    /** Request parameter giving the number of hits to fetch for reranking. */
    public static final CompoundName rerankHitsParameter = new CompoundName("rerank.hits");

    /** Request parameter naming the model to rerank with. */
    public static final CompoundName rerankModelParameter = new CompoundName("rerank.model");

    private final ModelsEvaluator modelsEvaluator;

    // Defaults taken from the reranker config; used when the request does not specify a value
    private final int defaultRerankHits;
    private final String defaultRerankProfile;
    private final String defaultRerankModel;

    public RerankingSearcher(RerankerConfig config, ModelsEvaluator modelsEvaluator) {
        this.modelsEvaluator = modelsEvaluator;

        var rerankDefaults = config.rerank();
        this.defaultRerankHits = rerankDefaults.hits();
        this.defaultRerankProfile = rerankDefaults.profile();
        this.defaultRerankModel = rerankDefaults.model();
    }

    /**
     * Raises the number of requested hits to at least the rerank hit count, substitutes the configured
     * rank profile when the query uses the profile named "default", passes the query on and
     * reranks the hits of the result that comes back.
     */
    @Override
    public Result search(Query query, Execution execution) {
        int requestedHits = query.getHits();
        int hitsToRerank = query.properties().getInteger(rerankHitsParameter, defaultRerankHits);
        query.setHits(Math.max(requestedHits, hitsToRerank));

        var ranking = query.getRanking();
        String requestedProfile = ranking.getProfile();
        if (requestedProfile.equals("default")) {
            ranking.setProfile(defaultRerankProfile);
        }

        Result downstreamResult = execution.search(query);
        String modelName = query.properties().getString(rerankModelParameter, defaultRerankModel);
        rerank(downstreamResult, modelName);
        return downstreamResult;
    }

    /** Reranks every non-auxiliary hit found anywhere in the hit tree of the given result. */
    private void rerank(Result result, String rerankModel) {
        Iterator<Hit> allHits = result.hits().unorderedDeepIterator();
        while (allHits.hasNext()) {
            Hit candidate = allHits.next();
            if (candidate.isAuxiliary()) {
                continue;
            }
            rerank(candidate, rerankModel);
        }
    }

    /**
     * Evaluates the named model over the summary features of the hit and stores the outcome as the hit's relevance.
     *
     * @throws IllegalArgumentException if the hit has no 'summaryfeatures' field
     */
    private void rerank(Hit hit, String rerankModel) {
        // A fresh evaluator is obtained for each hit, before the features are inspected
        FunctionEvaluator modelFunction = modelsEvaluator.evaluatorOf(rerankModel);

        FeatureData summaryFeatures = (FeatureData)hit.getField("summaryfeatures");
        if (summaryFeatures == null) {
            throw new IllegalArgumentException("Missing 'summaryfeatures' field in " + hit +
                                               ". Use a rank profile with a 'summary-features' block, using '" +
                                               hit.getQuery().getRanking().getProfile() + "'");
        }

        for (String name : summaryFeatures.featureNames()) {
            boolean isCacheMarker = name.equals("vespa.summaryFeatures.cached");
            // Only features the model declares as arguments are bound; everything else is ignored
            if ( ! isCacheMarker && modelFunction.context().arguments().contains(name)) {
                modelFunction.bind(name, summaryFeatures.getTensor(name));
            }
        }

        Tensor score = modelFunction.evaluate();
        hit.setRelevance(score.asDouble());
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.example.reranker;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.tensor.Tensor;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
 * Converts a JSON result from a Vespa backend to Hits in a Result.
 *
 * @author bratseth
 */
class ResultReader {

    /** Shared mapper: a configured ObjectMapper is thread safe, so one instance serves all calls. */
    private static final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Parses the given JSON result and adds its total hit count, errors and hits to the given result.
     *
     * @param resultJson the JSON text returned by the backend, expected to contain a 'root' object
     * @param result the result to populate; its query is attached to each hit that is created
     * @throws IllegalArgumentException if the JSON cannot be parsed or has no 'root' object
     */
    void read(String resultJson, Result result) {
        try {
            JsonNode jsonRoot = objectMapper.readTree(resultJson);
            // Empty input gives a null or "missing" node depending on Jackson version: both end up as no root
            JsonNode rootNode = jsonRoot == null ? null : jsonRoot.get("root");
            if (rootNode == null)
                throw new IllegalArgumentException("Expected a 'root' object in the JSON, got: " + jsonRoot);

            JsonNode rootFields = rootNode.get("fields");
            if (rootFields != null && rootFields.get("totalCount") != null)
                result.setTotalHitCount(rootFields.get("totalCount").asLong()); // long: asInt would truncate large counts

            JsonNode errors = rootNode.get("errors");
            if (errors != null)
                errors.forEach(error -> result.hits().addError(readError(error)));

            JsonNode children = rootNode.get("children");
            if (children != null)
                children.forEach(child -> result.hits().add(readHit(child, result.getQuery())));
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not read result JSON", e);
        }
    }

    /** Creates an error message from a JSON error object having 'code', 'summary' and optionally 'message'. */
    ErrorMessage readError(JsonNode errorObject) {
        JsonNode message = errorObject.get("message");
        return new ErrorMessage(errorObject.get("code").asInt(),
                                errorObject.get("summary").asText(),
                                message != null ? message.asText() : null);
    }

    /**
     * Creates a hit from a JSON hit object having 'id', 'relevance' and optionally 'fields'.
     * The 'matchfeatures' and 'summaryfeatures' fields become FeatureData,
     * all other fields are converted by {@link #toValue}.
     */
    Hit readHit(JsonNode hitObject, Query query) {
        Hit hit = new Hit(hitObject.get("id").asText(), hitObject.get("relevance").asDouble(), query);
        // TODO: Source
        JsonNode fields = hitObject.get("fields");
        if (fields == null) return hit; // A hit without fields is legal input: nothing more to copy

        for (Iterator<Map.Entry<String, JsonNode>> i = fields.fields(); i.hasNext(); ) {
            var fieldEntry = i.next();
            String name = fieldEntry.getKey();
            // Exactly one branch per field, so feature data is never overwritten by the generic conversion
            if ("matchfeatures".equals(name) || "summaryfeatures".equals(name))
                hit.setField(name, readFeatureData(fieldEntry.getValue()));
            else
                hit.setField(name, toValue(fieldEntry.getValue()));
        }
        return hit;
    }

    /** Creates feature data from a JSON object mapping feature names to numeric values. */
    FeatureData readFeatureData(JsonNode featureDataObject) {
        Map<String, Tensor> features = new HashMap<>();
        for (Iterator<Map.Entry<String, JsonNode>> i = featureDataObject.fields(); i.hasNext(); ) {
            var fieldEntry = i.next();
            features.put(fieldEntry.getKey(), Tensor.from(fieldEntry.getValue().asDouble())); // TODO: Parse tensors
        }
        return new FeatureData(features);
    }

    /**
     * Converts a JSON field value to a Java value: Numbers become Double, strings String and booleans Boolean.
     * Arrays and objects are returned as their JSON text, since asText() returns an empty string
     * for container nodes, which would silently drop their content.
     * Any other node type is returned as its asText() form.
     */
    public Object toValue(JsonNode fieldValue) {
        return switch (fieldValue.getNodeType()) {
            case NUMBER -> fieldValue.asDouble();
            case STRING -> fieldValue.asText();
            case BOOLEAN -> fieldValue.asBoolean();
            case ARRAY, OBJECT -> fieldValue.toString();
            default -> fieldValue.asText();
        };
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.example.reranker;

import com.yahoo.container.jdisc.HttpRequest;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder;
import org.apache.hc.core5.http.ClassicHttpResponse;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.HttpClientResponseHandler;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.net.URIBuilder;

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Map;

/**
 * A client which can talk to a Vespa application's <i>token</i> endpoint.
 * This is multithread safe.
 *
 * @author bratseth
 */
class VespaClient {

    private final String tokenEndpoint;
    private final CloseableHttpClient httpClient;

    /**
     * Creates a client sending its requests to the given endpoint.
     *
     * @param tokenEndpoint the base URI of the token endpoint of the application to query
     */
    public VespaClient(String tokenEndpoint) {
        this.tokenEndpoint = tokenEndpoint;
        this.httpClient = HttpClientBuilder.create()
                                           .setConnectionManager(PoolingHttpClientConnectionManagerBuilder
                                                                         .create()
                                                                         .build())
                                           .setUserAgent("vespa")
                                           .disableCookieManagement()
                                           .disableAutomaticRetries()
                                           .disableAuthCaching()
                                           .build();
    }

    /**
     * Issues a GET to the '/search/' path of the endpoint, forwarding the properties and the
     * bearer token of the given request.
     *
     * @param request the incoming request, whose properties and 'Authorization' header are forwarded
     * @param overridingProperties properties which replace any same-named property of the request
     * @return the status code and body of the response
     * @throws IllegalArgumentException if the request has no 'Authorization' header starting with 'Bearer '
     * @throws IllegalStateException if the endpoint and parameters do not form a valid URI
     * @throws IOException if the HTTP exchange fails
     */
    public Response search(HttpRequest request, Map<String, Object> overridingProperties) throws IOException {
        try {
            String authorizationHeader = request.getHeader("Authorization");
            if (authorizationHeader == null || !authorizationHeader.startsWith("Bearer "))
                throw new IllegalArgumentException("Request must have an 'Authorization' header with the value " +
                                                   "'Bearer $your_token'");
            var uriBuilder = new URIBuilder(tokenEndpoint);
            uriBuilder.setPath("/search/");
            for (var property : request.propertyMap().entrySet())
                uriBuilder.addParameter(property.getKey(), property.getValue());
            // setParameter removes any parameter already added under the same name, so each override
            // is sent exactly once instead of alongside the value it is meant to replace
            for (var property : overridingProperties.entrySet())
                uriBuilder.setParameter(property.getKey(), property.getValue().toString());
            var get = new HttpGet(uriBuilder.build());
            get.addHeader("Authorization", authorizationHeader);
            return httpClient.execute(get, new ResponseHandler());
        }
        catch (URISyntaxException e) {
            throw new IllegalStateException(e);
        }
    }

    /** The status code and body text of a response from the endpoint. */
    public record Response(int statusCode, String responseBody) {}

    /** Reads the entire response body into a string and pairs it with the response status code. */
    public static class ResponseHandler implements HttpClientResponseHandler<Response> {

        @Override
        public Response handleResponse(ClassicHttpResponse response) {
            String responseBody;
            try {
                var entity = response.getEntity();
                // A response may legitimately carry no entity: report it as an empty body
                responseBody = entity == null ? "" : EntityUtils.toString(entity);
            } catch (IOException | ParseException e) {
                throw new IllegalStateException(e);
            }
            return new Response(response.getCode(), responseBody);
        }
    }

}
Loading

0 comments on commit 7f15ad8

Please sign in to comment.