Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test pack bits combined with a natively binarizing embedder #33062

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,13 @@ public IndexingProcessor(DocumentTypeManager documentTypeManager,
IlscriptsConfig ilscriptsConfig,
Linguistics linguistics,
ComponentRegistry<Embedder> embedders) {
this(documentTypeManager, new ScriptManager(documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders)));
}

public IndexingProcessor(DocumentTypeManager documentTypeManager,
ScriptManager scriptManager) {
this.documentTypeManager = documentTypeManager;
scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders));
this.scriptManager = scriptManager;
adapterFactory = new SimpleAdapterFactory(new ExpressionSelector());
}

Expand Down Expand Up @@ -132,7 +137,7 @@ private void processRemove(DocumentRemove input, List<DocumentOperation> out) {
out.add(input);
}

private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
private static Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
var map = embedders.allComponentsById().entrySet().stream()
.collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
if (map.size() > 1) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.docprocs.indexing;

import com.yahoo.component.AbstractComponent;
import com.yahoo.document.DataType;
import com.yahoo.document.Document;
import com.yahoo.document.DocumentOperation;
import com.yahoo.document.DocumentPut;
import com.yahoo.document.DocumentType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.DocumentUpdate;
import com.yahoo.document.PositionDataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.update.AssignValueUpdate;
import com.yahoo.document.update.FieldUpdate;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.Tensors;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import org.junit.Test;

import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
Expand Down Expand Up @@ -225,4 +239,61 @@ public void requireThatIndexerForwardsUpdatesOfUnknownType() {
assertSame(input, output);
}

@Test
public void testEmbedBinarizeAndPack() {
var documentTypes = new DocumentTypeManager();
var test = new DocumentType("test");
test.addField("myText", DataType.STRING);
test.addField("embedding", new TensorDataType(TensorType.fromSpec("tensor<int8>(x[16])")));
documentTypes.register(test);

IlscriptsConfig.Builder config = new IlscriptsConfig.Builder();
config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("test")
.content("input myText | embed | binarize | pack_bits | attribute embedding")
.docfield("myText"));
var scripts = new ScriptManager(documentTypes, new IlscriptsConfig(config), null, Map.of("test", new TestEmbedder()));
assertNotNull(scripts.getScript(documentTypes.getDocumentType("test")));

var tester = new IndexingProcessorTester(documentTypes, scripts);
DocumentUpdate input = new DocumentUpdate(test, "id:ns:test::");
input.addFieldUpdate(FieldUpdate.createAssign(test.getField("myText"), new StringFieldValue("my text")));
DocumentUpdate output = (DocumentUpdate)tester.process(input);
FieldUpdate embeddingUpdate = output.getFieldUpdate("embedding");
AssignValueUpdate valueUpdate = (AssignValueUpdate)embeddingUpdate.getValueUpdate(0);
assertEquals(Tensor.from("tensor<int8>(x[16]):[-110, 73, 36, -110, 73, 36, -110, 73, 36, -110, 73, 36, -110, 73, 36, -110]"),
valueUpdate.getValue().getWrappedValue());
}

/** An ebedder which also does its own quantization, similar to HuggingFaceEmbedder. */
static class TestEmbedder extends AbstractComponent implements Embedder {

@Override
public List<Integer> embed(String s, Context context) {
throw new UnsupportedOperationException();
}

@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
if (tensorType.dimensions().size() != 1)
throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension.");
if (!tensorType.dimensions().get(0).isIndexed())
throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
boolean binarize = tensorType.valueType() == TensorType.Value.INT8;
long size = tensorType.dimensions().get(0).size().get();
if (binarize)
size = size * 8;
var embeddedType = new TensorType.Builder().indexed(tensorType.dimensions().get(0).name(), size).build();
var resultBuilder = Tensor.Builder.of(embeddedType);
for (int i = 0; i < size; i++) {
int v = ((i % 3) == 0) ? 1 : 0;
resultBuilder.cell(v, i);
}
var result = resultBuilder.build();
if (binarize)
result = Tensors.packBits(result);
return result;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public IndexingProcessorTester(String configDir) {
indexer = newProcessor("dir:" + configDir);
}

public IndexingProcessorTester(DocumentTypeManager documentTypes, ScriptManager scripts) {
indexer = newProcessor(documentTypes, scripts);
}

public DocumentType getDocumentType(String name) {
return indexer.getDocumentTypeManager().getDocumentType(name);
}
Expand Down Expand Up @@ -70,4 +74,8 @@ private static IndexingProcessor newProcessor(String configId) {
new ComponentRegistry<>());
}

private static IndexingProcessor newProcessor(DocumentTypeManager documentTypes, ScriptManager scripts) {
return new IndexingProcessor(documentTypes, scripts);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected void doExecute(ExecutionContext context) {

/** Returns the type this requires when producing the given output type. */
private TensorType inputType(TensorType givenType) {
var builder = new TensorType.Builder(TensorType.Value.INT8); // Any larger value type is also permissible
var builder = new TensorType.Builder(TensorType.Value.DOUBLE); // Any value type is permissible
for (var d : givenType.dimensions())
builder.dimension(d.size().isPresent() ? d.withSize(d.size().get() * 8) : d);
return builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void testEmbedAndBinarize() {
@Test
public void testEmbedBinarizeAndPack_bits() {
var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", -111)));
tester.testStatement("input myText | embed | binarize | pack_bits | attribute 'myTensor'", "input text", "tensor<int8>(x[2])", "[58, 192]");
tester.testStatement("input myText | embed | binarize | pack_bits | attribute 'myTensor'", "input text", "tensor<int8>(x[2])", "[58, -64]");
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ public void deconstruct() {
tokenizer.close();
}

@SuppressWarnings("unchecked")
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
if (tensorType.dimensions().size() != 1) {
Expand Down Expand Up @@ -213,6 +212,7 @@ private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddin
/**
* Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
*/
// TODO: Call Tensors.packBits instead. It is more general and faster.
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);
BitSet bitSet = new BitSet(8);
Expand Down
2 changes: 1 addition & 1 deletion vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ private static String cellToString(Map.Entry<TensorAddress, Double> cell, Tensor
int hashCode();

/**
* Implement here to make this work across implementations.
* Implemented here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
*/
static boolean equals(Tensor a, Tensor b) {
Expand Down
19 changes: 9 additions & 10 deletions vespajlib/src/main/java/com/yahoo/tensor/Tensors.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.yahoo.api.annotations.Beta;

import java.util.Arrays;
import java.util.BitSet;
import java.util.Iterator;

/**
Expand Down Expand Up @@ -47,9 +48,10 @@ public static Tensor toSparse(Tensor tensor, String ... dimensions) {
}

/**
* Converts any tensor containing only ones and zeroes into one where each consecutive 8 values in the
* dense dimension are packed into a single byte. As a consequence the output type of this is a tensor
* where the dense dimension is 1/8th as large.
* Converts any tensor into one where each consecutive 8 values in the
* dense dimension are packed into a single byte,
* by setting a bit to 1 when the tensor has a positive value and 0 otherwise.
* As a consequence the output type of this is a tensor where the dense dimension is 1/8th as large.
*
* @throws IllegalArgumentException if the tensor has the wrong type or contains any other value than 0 or 1
*/
Expand All @@ -71,7 +73,7 @@ public static Tensor packBits(Tensor tensor) {
int packedValue = 0;
for (int j = 0; j < 8 && i < indexed.size(); j++)
packedValue = packInto(packedValue, indexed.get(i), j, i++);
builder.cell(packedValue, packedIndex);
builder.cell((byte)packedValue, packedIndex);
}
}
else if (tensor instanceof MixedTensor mixed) {
Expand All @@ -81,7 +83,7 @@ else if (tensor instanceof MixedTensor mixed) {
int packedValue = 0;
for (int j = 0; j < 8 && i < denseSubspace.cells.length; j++)
packedValue = packInto(packedValue, denseSubspace.cells[i], j, i++);
builder.cell(packedAddress, packedValue);
builder.cell(packedAddress, (byte)packedValue);
}
}
}
Expand All @@ -93,13 +95,10 @@ else if (tensor instanceof MixedTensor mixed) {
}

private static int packInto(int packedValue, double value, int bitPosition, long sourcePosition) {
if (value == 0.0)
if (value <= 0.0)
return packedValue;
else if (value == 1.0)
return packedValue | ( 1 << ( 7 - bitPosition ));
else
throw new IllegalArgumentException("The tensor to be packed can only contain 0 or 1 values, " +
"but has " + value + " at position " + sourcePosition);
return packedValue | ( 1 << ( 7 - bitPosition ));
}

}
22 changes: 8 additions & 14 deletions vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import org.junit.jupiter.api.Test;

import java.util.BitSet;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

Expand All @@ -27,15 +29,15 @@ void testToSparse() {

@Test
void testPackBits() {
assertPacked("tensor<int8>(x[2]):[129,14]", "tensor(x[16]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]");
assertPacked("tensor<int8>(x[2]):[129,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1]");
assertPacked("tensor<int8>(x[1]):[128]", "tensor(x[1]):[1]");
assertPacked("tensor<int8>(key{},x[2]):{a:[129,14], b:[12, 7]}",
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[16]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]");
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,2,3]");
assertPacked("tensor<int8>(x[1]):[-128]", "tensor(x[1]):[1]");
assertPacked("tensor<int8>(key{},x[2]):{a:[-127,14], b:[12, 7]}",
"tensor(key{},x[16]):{a:[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]," +
" b:[0,0,0,0,1,1,0,0, 0,0,0,0,0,1,1,1]}");
assertPacked("tensor<int8>(key{},x[1]):{a:[160],b:[32]}",
assertPacked("tensor<int8>(key{},x[1]):{a:[-96],b:[32]}",
"tensor(key{},x[3]):{a:[1,0,1],b:[0,0,1]}");
assertPacked("tensor<int8>(key{},x[1]):{a:[128]}", "tensor(key{}, x[1]):{a:[1]}");
assertPacked("tensor<int8>(key{},x[1]):{a:[-128]}", "tensor(key{}, x[1]):{a:[1]}");

try {
Tensors.packBits(Tensor.from("tensor(x[1],y[1]):[1]"));
Expand All @@ -45,14 +47,6 @@ void testPackBits() {
assertEquals("packBits requires a tensor with one dense dimensions, but got tensor(x[1],y[1])",
e.getMessage());
}
try {
Tensors.packBits(Tensor.from("tensor(x[3]):[0, 1, 2]"));
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("The tensor to be packed can only contain 0 or 1 values, but has 2.0 at position 2",
e.getMessage());
}
}

void assertConvertedToSparse(String inputType, String outputType, String tensorValue, String ... dimensions) {
Expand Down
Loading