From 357eb57eae7afeadcfd94bcf25e2552efe397097 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 2 Oct 2024 16:02:45 +0200 Subject: [PATCH 1/5] Refactor: Extract embedding test code --- .../EmbeddingScriptTestCase.java | 422 +++++++++++++++ .../EmbeddingScriptTester.java | 120 +++++ .../indexinglanguage/ScriptTestCase.java | 494 ------------------ 3 files changed, 542 insertions(+), 494 deletions(-) create mode 100644 indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java create mode 100644 indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java new file mode 100644 index 000000000000..cb9ddcd799b4 --- /dev/null +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java @@ -0,0 +1,422 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.indexinglanguage; + +import com.yahoo.document.ArrayDataType; +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.Array; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.simple.SimpleLinguistics; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext; +import com.yahoo.vespa.indexinglanguage.expressions.VerificationException; +import com.yahoo.vespa.indexinglanguage.parser.ParseException; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class EmbeddingScriptTestCase { + + @Test + public void testEmbed() throws ParseException { + // Test parsing without knowledge of any embedders + String exp = "input myText | embed emb1 | attribute 'myTensor'"; + Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + + Map embedder = Map.of( + "emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor") + ); + testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, + "input text", "[105, 110, 112, 117]"); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, + "input text", "[105, 110, 112, 117]"); + testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, + "input text", "[105, 110, 112, 117]"); + testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, + null, null); + + Map embedders = Map.of( + "emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor"), + "emb2", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", 1) + ); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, + "my input", "[109.0, 121.0, 32.0, 105.0]"); + testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, + "my input", "[110.0, 122.0, 33.0, 106.0]"); + + EmbeddingScriptTester.assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), + "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1, emb2"); + EmbeddingScriptTester.assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), + "Can't find embedder 'emb3'. Valid embedders are emb1, emb2"); + } + + private void testEmbedStatement(String expressionString, Map embedders, String input, String expected) { + try { + var expression = Expression.fromString(expressionString, new SimpleLinguistics(), embedders); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + expression.execute(context); + if (input == null) { + assertFalse(adapter.values.containsKey("myTensor")); + } + else { + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); + } + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testArrayEmbed() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensorArray")); + + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + + var tensorField = new Field("myTensorArray", new ArrayDataType(new TensorDataType(tensorType))); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new ArrayDataType(new TensorDataType(tensorType)), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("myTensorArray")); + var tensorArray = (Array)adapter.values.get("myTensorArray"); + assertEquals(Tensor.from(tensorType, "[102, 105, 114, 115]"), tensorArray.get(0).getTensor().get()); + assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get()); + } + + @Test + public void testArrayEmbedWithConcatenation() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); + var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + + var titleField = new Field("title", DataType.STRING); + adapter.createField(titleField); + adapter.setValue("title", new StringFieldValue("title1")); + + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "{ '0':[116.0, 105.0, 116.0, 108.0], 1:[116.0, 105.0, 116.0, 108.0]}"), + sparseTensor.getTensor().get()); + } + + /** Multiple paragraphs */ + @Test + public void testArrayEmbedTo2dMixedTensor() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); + var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"), + sparseTensor.getTensor().get()); + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + assertEquals("input myTextArray | embed emb1 passage | attribute mySparseTensor", expression.toString()); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("sec")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + // The two "passages" are [first, sec], the middle (d=1) token encodes those letters + assertEquals(Tensor.from(tensorType, + """ + { + {passage:0, token:0, d:0}: 101, + {passage:0, token:0, d:1}: 102, + {passage:0, token:0, d:2}: 103, + {passage:0, token:1, d:0}: 104, + {passage:0, token:1, d:1}: 105, + {passage:0, token:1, d:2}: 106, + {passage:0, token:2, d:0}: 113, + {passage:0, token:2, d:1}: 114, + {passage:0, token:2, d:2}: 115, + {passage:0, token:3, d:0}: 114, + {passage:0, token:3, d:1}: 115, + {passage:0, token:3, d:2}: 116, + {passage:0, token:4, d:0}: 115, + {passage:0, token:4, d:1}: 116, + {passage:0, token:4, d:2}: 117, + {passage:1, token:0, d:0}: 114, + {passage:1, token:0, d:1}: 115, + {passage:1, token:0, d:2}: 116, + {passage:1, token:1, d:0}: 100, + {passage:1, token:1, d:1}: 101, + {passage:1, token:1, d:2}: 102, + {passage:1, token:2, d:0}: 98, + {passage:1, token:2, d:1}: 99, + {passage:1, token:2, d:2}: 100 + } + """), + sparseTensor.getTensor().get()); + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...", + e.getMessage()); + } + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})", + e.getMessage()); + } + } + + @SuppressWarnings("OptionalGetWithoutIsPresent") + @Test + public void testEmbedToSparseTensor() throws ParseException { + Embedder mappedEmbedder = new EmbeddingScriptTester.MockMappedEmbedder("myDocument.mySparseTensor", 0); + Map embedders = Map.of("emb1",mappedEmbedder); + + TensorType tensorType = TensorType.fromSpec("tensor(t{})"); + var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("text", DataType.STRING)); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var text = new StringFieldValue("abc"); + adapter.setValue("text", text); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(text); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"), + sparseTensor.getTensor().get()); + assertEquals("Cached value always set by MockMappedEmbedder is present", + "myCachedValue", context.getCachedValue("myCacheKey")); + } + + /** Multiple paragraphs with sparse encoding (splade style) */ + @Test + public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); + var expression = Expression.fromString("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("my2DSparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("The dimension 'doh' given to embed is not a sparse dimension of the target type tensor(passage{},token{})", + e.getMessage()); + } + } + + /** Multiple paragraphs with sparse encoding (splade style) */ + @Test + @SuppressWarnings("OptionalGetWithoutIsPresent") + public void testArrayEmbedTo2MappedTensor() throws ParseException { + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); + var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'", + new SimpleLinguistics(), + embedders); + assertEquals("input myTextArray | embed emb1 passage | attribute my2DSparseTensor", expression.toString()); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + var tensorField = new Field("my2DSparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("abc")); + array.add(new StringFieldValue("cde")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("my2DSparseTensor")); + var sparse2DTensor = (TensorFieldValue)adapter.values.get("my2DSparseTensor"); + assertEquals(Tensor.from( + tensorType, + "tensor(passage{},token{}):" + + "{{passage:0,token:97}:97.0, " + + "{passage:0,token:98}:98.0, " + + "{passage:0,token:99}:99.0, " + + "{passage:1,token:100}:100.0, " + + "{passage:1,token:101}:101.0, " + + "{passage:1,token:99}:99.0}"), + sparse2DTensor.getTensor().get()); + } + +} diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java new file mode 100644 index 000000000000..2072e7c125a9 --- /dev/null +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java @@ -0,0 +1,120 @@ +package com.yahoo.vespa.indexinglanguage; + +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class EmbeddingScriptTester { + + public static void assertThrows(Runnable r, String expectedMessage) { + try { + r.run(); + fail(); + } catch (IllegalStateException e) { + assertEquals(expectedMessage, e.getMessage()); + } + } + + public static abstract class MockEmbedder implements Embedder { + + final String expectedDestination; + final int addition; + + public MockEmbedder(String expectedDestination, int addition) { + this.expectedDestination = expectedDestination; + this.addition = addition; + } + + @Override + public List embed(String text, Embedder.Context context) { + return null; + } + + void verifyDestination(Embedder.Context context) { + assertEquals(expectedDestination, context.getDestination()); + } + + } + + /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */ + public static class MockIndexedEmbedder extends MockEmbedder { + + public MockIndexedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockIndexedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + var b = Tensor.Builder.of(tensorType); + for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) + b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + return b.build(); + } + + } + + /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */ + public static class MockMappedEmbedder extends MockEmbedder { + + public MockMappedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockMappedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + context.putCachedValue("myCacheKey", "myCachedValue"); + var b = Tensor.Builder.of(tensorType); + for (int i = 0; i < text.length(); i++) + b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition); + return b.build(); + } + + } + + /** + * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input + * char becomes an indexed dimension containing input-1, input, input+1. + */ + public static class MockMixedEmbedder extends MockEmbedder { + + public MockMixedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockMixedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + var b = Tensor.Builder.of(tensorType); + String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name(); + String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name(); + for (int i = 0; i < text.length(); i++) { + for (int j = 0; j < 3; j++) { + b.cell().label(mappedDimension, i) + .label(indexedDimension, j) + .value(text.charAt(i) + addition + j - 1); + } + } + return b.build(); + } + } + +} diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index dd0ec255c356..c4bd69663df7 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -174,498 +174,4 @@ public void testLongHash() throws ParseException { assertEquals(7678158186624760752L, adapter.values.get("myLong").getWrappedValue()); } - @Test - public void testEmbed() throws ParseException { - // Test parsing without knowledge of any embedders - String exp = "input myText | embed emb1 | attribute 'myTensor'"; - Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); - - Map embedder = Map.of( - "emb1", new MockIndexedEmbedder("myDocument.myTensor") - ); - testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); - testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); - testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); - testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, - null, null); - - Map embedders = Map.of( - "emb1", new MockIndexedEmbedder("myDocument.myTensor"), - "emb2", new MockIndexedEmbedder("myDocument.myTensor", 1) - ); - testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, - "my input", "[109.0, 121.0, 32.0, 105.0]"); - testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, - "my input", "[110.0, 122.0, 33.0, 106.0]"); - - assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), - "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1, emb2"); - assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), - "Can't find embedder 'emb3'. Valid embedders are emb1, emb2"); - } - - private void testEmbedStatement(String expressionString, Map embedders, String input, String expected) { - try { - var expression = Expression.fromString(expressionString, new SimpleLinguistics(), embedders); - TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myText", DataType.STRING)); - var tensorField = new Field("myTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - if (input != null) - adapter.setValue("myText", new StringFieldValue(input)); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); - - ExecutionContext context = new ExecutionContext(adapter); - expression.execute(context); - if (input == null) { - assertFalse(adapter.values.containsKey("myTensor")); - } - else { - assertTrue(adapter.values.containsKey("myTensor")); - assertEquals(Tensor.from(tensorType, expected), - ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); - } - } - catch (ParseException e) { - throw new IllegalArgumentException(e); - } - } - - @SuppressWarnings("unchecked") - @Test - public void testArrayEmbed() throws ParseException { - Map embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.myTensorArray")); - - TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - - var tensorField = new Field("myTensorArray", new ArrayDataType(new TensorDataType(tensorType))); - adapter.createField(tensorField); - - var array = new Array(new ArrayDataType(DataType.STRING)); - array.add(new StringFieldValue("first")); - array.add(new StringFieldValue("second")); - adapter.setValue("myTextArray", array); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(new ArrayDataType(new TensorDataType(tensorType)), expression.verify(verificationContext)); - - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(array); - expression.execute(context); - assertTrue(adapter.values.containsKey("myTensorArray")); - var tensorArray = (Array)adapter.values.get("myTensorArray"); - assertEquals(Tensor.from(tensorType, "[102, 105, 114, 115]"), tensorArray.get(0).getTensor().get()); - assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get()); - } - - @Test - public void testArrayEmbedWithConcatenation() throws ParseException { - Map embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); - var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - - var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - - var array = new Array(new ArrayDataType(DataType.STRING)); - array.add(new StringFieldValue("first")); - array.add(new StringFieldValue("second")); - adapter.setValue("myTextArray", array); - - var titleField = new Field("title", DataType.STRING); - adapter.createField(titleField); - adapter.setValue("title", new StringFieldValue("title1")); - - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); - - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(array); - expression.execute(context); - assertTrue(adapter.values.containsKey("mySparseTensor")); - var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); - assertEquals(Tensor.from(tensorType, "{ '0':[116.0, 105.0, 116.0, 108.0], 1:[116.0, 105.0, 116.0, 108.0]}"), - sparseTensor.getTensor().get()); - } - - /** Multiple paragraphs */ - @Test - public void testArrayEmbedTo2dMixedTensor() throws ParseException { - Map embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); - var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - - var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - - var array = new Array(new ArrayDataType(DataType.STRING)); - array.add(new StringFieldValue("first")); - array.add(new StringFieldValue("second")); - adapter.setValue("myTextArray", array); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); - - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(array); - expression.execute(context); - assertTrue(adapter.values.containsKey("mySparseTensor")); - var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); - assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"), - sparseTensor.getTensor().get()); - } - - /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ - @Test - public void testArrayEmbedTo3dMixedTensor() throws ParseException { - Map embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); - var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); - assertEquals("input myTextArray | embed emb1 passage | attribute mySparseTensor", expression.toString()); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - - var array = new Array(new ArrayDataType(DataType.STRING)); - array.add(new StringFieldValue("first")); - array.add(new StringFieldValue("sec")); - adapter.setValue("myTextArray", array); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); - - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(array); - expression.execute(context); - assertTrue(adapter.values.containsKey("mySparseTensor")); - var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); - // The two "passages" are [first, sec], the middle (d=1) token encodes those letters - assertEquals(Tensor.from(tensorType, - """ - { - {passage:0, token:0, d:0}: 101, - {passage:0, token:0, d:1}: 102, - {passage:0, token:0, d:2}: 103, - {passage:0, token:1, d:0}: 104, - {passage:0, token:1, d:1}: 105, - {passage:0, token:1, d:2}: 106, - {passage:0, token:2, d:0}: 113, - {passage:0, token:2, d:1}: 114, - {passage:0, token:2, d:2}: 115, - {passage:0, token:3, d:0}: 114, - {passage:0, token:3, d:1}: 115, - {passage:0, token:3, d:2}: 116, - {passage:0, token:4, d:0}: 115, - {passage:0, token:4, d:1}: 116, - {passage:0, token:4, d:2}: 117, - {passage:1, token:0, d:0}: 114, - {passage:1, token:0, d:1}: 115, - {passage:1, token:0, d:2}: 116, - {passage:1, token:1, d:0}: 100, - {passage:1, token:1, d:1}: 101, - {passage:1, token:1, d:2}: 102, - {passage:1, token:2, d:0}: 98, - {passage:1, token:2, d:1}: 99, - {passage:1, token:2, d:2}: 100 - } - """), - sparseTensor.getTensor().get()); - } - - /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ - @Test - public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException { - Map embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); - var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); - - try { - expression.verify(new VerificationContext(adapter)); - fail("Expected exception"); - } - catch (VerificationException e) { - assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...", - e.getMessage()); - } - } - - /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ - @Test - public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException { - Map embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); - var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); - - try { - expression.verify(new VerificationContext(adapter)); - fail("Expected exception"); - } - catch (VerificationException e) { - assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})", - e.getMessage()); - } - } - - @SuppressWarnings("OptionalGetWithoutIsPresent") - @Test - public void testEmbedToSparseTensor() throws ParseException { - Embedder mappedEmbedder = new MockMappedEmbedder("myDocument.mySparseTensor", 0); - Map embedders = Map.of("emb1",mappedEmbedder); - - TensorType tensorType = TensorType.fromSpec("tensor(t{})"); - var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("text", DataType.STRING)); - - var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - - var text = new StringFieldValue("abc"); - adapter.setValue("text", text); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); - - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(text); - expression.execute(context); - assertTrue(adapter.values.containsKey("mySparseTensor")); - var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); - assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"), - sparseTensor.getTensor().get()); - assertEquals("Cached value always set by MockMappedEmbedder is present", - "myCachedValue", context.getCachedValue("myCacheKey")); - } - - /** Multiple paragraphs with sparse encoding (splade style) */ - @Test - public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws ParseException { - Map embedders = Map.of("emb1", new MockMappedEmbedder("myDocument.my2DSparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); - var expression = Expression.fromString("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'", - new SimpleLinguistics(), - embedders); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - adapter.createField(new Field("my2DSparseTensor", new TensorDataType(tensorType))); - - try { - expression.verify(new VerificationContext(adapter)); - fail("Expected exception"); - } - catch (VerificationException e) { - assertEquals("The dimension 'doh' given to embed is not a sparse dimension of the target type tensor(passage{},token{})", - e.getMessage()); - } - } - - /** Multiple paragraphs with sparse encoding (splade style) */ - @Test - @SuppressWarnings("OptionalGetWithoutIsPresent") - public void testArrayEmbedTo2MappedTensor() throws ParseException { - Map embedders = Map.of("emb1", new MockMappedEmbedder("myDocument.my2DSparseTensor")); - - TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); - var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'", - new SimpleLinguistics(), - embedders); - assertEquals("input myTextArray | embed emb1 passage | attribute my2DSparseTensor", expression.toString()); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); - var tensorField = new Field("my2DSparseTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - - var array = new Array(new ArrayDataType(DataType.STRING)); - array.add(new StringFieldValue("abc")); - array.add(new StringFieldValue("cde")); - adapter.setValue("myTextArray", array); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); - - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(array); - expression.execute(context); - assertTrue(adapter.values.containsKey("my2DSparseTensor")); - var sparse2DTensor = (TensorFieldValue)adapter.values.get("my2DSparseTensor"); - assertEquals(Tensor.from( - tensorType, - "tensor(passage{},token{}):" + - "{{passage:0,token:97}:97.0, " + - "{passage:0,token:98}:98.0, " + - "{passage:0,token:99}:99.0, " + - "{passage:1,token:100}:100.0, " + - "{passage:1,token:101}:101.0, " + - "{passage:1,token:99}:99.0}"), - sparse2DTensor.getTensor().get()); - } - - - private void assertThrows(Runnable r, String expectedMessage) { - try { - r.run(); - fail(); - } catch (IllegalStateException e) { - assertEquals(expectedMessage, e.getMessage()); - } - } - - private static abstract class MockEmbedder implements Embedder { - - final String expectedDestination; - final int addition; - - public MockEmbedder(String expectedDestination, int addition) { - this.expectedDestination = expectedDestination; - this.addition = addition; - } - - @Override - public List embed(String text, Embedder.Context context) { - return null; - } - - void verifyDestination(Embedder.Context context) { - assertEquals(expectedDestination, context.getDestination()); - } - - } - - /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */ - private static class MockIndexedEmbedder extends MockEmbedder { - - public MockIndexedEmbedder(String expectedDestination) { - this(expectedDestination, 0); - } - - public MockIndexedEmbedder(String expectedDestination, int addition) { - super(expectedDestination, addition); - } - - @Override - public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { - verifyDestination(context); - var b = Tensor.Builder.of(tensorType); - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) - b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); - return b.build(); - } - - } - - /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */ - private static class MockMappedEmbedder extends MockEmbedder { - - public MockMappedEmbedder(String expectedDestination) { - this(expectedDestination, 0); - } - - public MockMappedEmbedder(String expectedDestination, int addition) { - super(expectedDestination, addition); - } - - @Override - public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { - verifyDestination(context); - context.putCachedValue("myCacheKey", "myCachedValue"); - var b = Tensor.Builder.of(tensorType); - for (int i = 0; i < text.length(); i++) - b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition); - return b.build(); - } - - } - - /** - * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input - * char becomes an indexed dimension containing input-1, input, input+1. - */ - private static class MockMixedEmbedder extends MockEmbedder { - - public MockMixedEmbedder(String expectedDestination) { - this(expectedDestination, 0); - } - - public MockMixedEmbedder(String expectedDestination, int addition) { - super(expectedDestination, addition); - } - - @Override - public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { - verifyDestination(context); - var b = Tensor.Builder.of(tensorType); - String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name(); - String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name(); - for (int i = 0; i < text.length(); i++) { - for (int j = 0; j < 3; j++) { - b.cell().label(mappedDimension, i) - .label(indexedDimension, j) - .value(text.charAt(i) + addition + j - 1); - } - } - return b.build(); - } - } - } From c402b9280616ee65fde7469d1649778a83e06354 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 2 Oct 2024 17:43:13 +0200 Subject: [PATCH 2/5] Formatting only --- .../EmbeddingScriptTestCase.java | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java index cb9ddcd799b4..74638141bd7d 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java @@ -42,27 +42,25 @@ public void testEmbed() throws ParseException { "emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor") ); testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); + "input text", "[105, 110, 112, 117]"); testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); + "input text", "[105, 110, 112, 117]"); testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); + "input text", "[105, 110, 112, 117]"); testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, - null, null); + null, null); - Map embedders = Map.of( - "emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor"), - "emb2", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", 1) - ); + Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor"), + "emb2", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", 1)); testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, - "my input", "[109.0, 121.0, 32.0, 105.0]"); + "my input", "[109.0, 121.0, 32.0, 105.0]"); testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, - "my input", "[110.0, 122.0, 33.0, 106.0]"); + "my input", "[110.0, 122.0, 33.0, 106.0]"); EmbeddingScriptTester.assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), - "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1, emb2"); + "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1, emb2"); EmbeddingScriptTester.assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), - "Can't find embedder 'emb3'. Valid embedders are emb1, emb2"); + "Can't find embedder 'emb3'. Valid embedders are emb1, emb2"); } private void testEmbedStatement(String expressionString, Map embedders, String input, String expected) { @@ -90,7 +88,7 @@ private void testEmbedStatement(String expressionString, Map e else { assertTrue(adapter.values.containsKey("myTensor")); assertEquals(Tensor.from(tensorType, expected), - ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); + ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); } } catch (ParseException e) { @@ -105,8 +103,7 @@ public void testArrayEmbed() throws ParseException { TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -139,8 +136,7 @@ public void testArrayEmbedWithConcatenation() throws ParseException { TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -169,7 +165,7 @@ public void testArrayEmbedWithConcatenation() throws ParseException { assertTrue(adapter.values.containsKey("mySparseTensor")); var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); assertEquals(Tensor.from(tensorType, "{ '0':[116.0, 105.0, 116.0, 108.0], 1:[116.0, 105.0, 116.0, 108.0]}"), - sparseTensor.getTensor().get()); + sparseTensor.getTensor().get()); } /** Multiple paragraphs */ @@ -179,8 +175,7 @@ public void testArrayEmbedTo2dMixedTensor() throws ParseException { TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -204,7 +199,7 @@ public void testArrayEmbedTo2dMixedTensor() throws ParseException { assertTrue(adapter.values.containsKey("mySparseTensor")); var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"), - sparseTensor.getTensor().get()); + sparseTensor.getTensor().get()); } /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ @@ -214,8 +209,7 @@ public void testArrayEmbedTo3dMixedTensor() throws ParseException { TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); assertEquals("input myTextArray | embed emb1 passage | attribute mySparseTensor", expression.toString()); SimpleTestAdapter adapter = new SimpleTestAdapter(); @@ -276,8 +270,7 @@ public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws Pars TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -289,7 +282,7 @@ public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws Pars } catch (VerificationException e) { assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...", - e.getMessage()); + e.getMessage()); } } @@ -300,8 +293,7 @@ public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseE TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -313,7 +305,7 @@ public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseE } catch (VerificationException e) { assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})", - e.getMessage()); + e.getMessage()); } } @@ -325,8 +317,7 @@ public void testEmbedToSparseTensor() throws ParseException { TensorType tensorType = TensorType.fromSpec("tensor(t{})"); var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("text", DataType.STRING)); @@ -348,9 +339,9 @@ public void testEmbedToSparseTensor() throws ParseException { assertTrue(adapter.values.containsKey("mySparseTensor")); var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"), - sparseTensor.getTensor().get()); + sparseTensor.getTensor().get()); assertEquals("Cached value always set by MockMappedEmbedder is present", - "myCachedValue", context.getCachedValue("myCacheKey")); + "myCachedValue", context.getCachedValue("myCacheKey")); } /** Multiple paragraphs with sparse encoding (splade style) */ @@ -360,8 +351,7 @@ public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws Parse TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); var expression = Expression.fromString("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -373,7 +363,7 @@ public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws Parse } catch (VerificationException e) { assertEquals("The dimension 'doh' given to embed is not a sparse dimension of the target type tensor(passage{},token{})", - e.getMessage()); + e.getMessage()); } } @@ -385,8 +375,7 @@ public void testArrayEmbedTo2MappedTensor() throws ParseException { TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), embedders); assertEquals("input myTextArray | embed emb1 passage | attribute my2DSparseTensor", expression.toString()); SimpleTestAdapter adapter = new SimpleTestAdapter(); From 2a062138646c47d6581392c6d183ca6455065056 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 4 Oct 2024 14:07:40 +0200 Subject: [PATCH 3/5] Add 'binarize' expression --- document/abi-spec.json | 1 + .../java/com/yahoo/document/DataType.java | 2 +- .../com/yahoo/document/TensorDataType.java | 10 +- .../expressions/BinarizeExpression.java | 74 +++++++++ .../expressions/HashExpression.java | 9 +- .../src/main/javacc/IndexingParser.jj | 27 ++- .../EmbeddingScriptTestCase.java | 155 ++++++------------ .../EmbeddingScriptTester.java | 61 ++++++- .../provider/IndexingLangaugeCompletion.java | 1 + 9 files changed, 219 insertions(+), 121 deletions(-) create mode 100644 indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java diff --git a/document/abi-spec.json b/document/abi-spec.json index 5096039a9414..accbe41bd4ef 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -1040,6 +1040,7 @@ "public com.yahoo.tensor.TensorType getTensorType()", "public boolean equals(java.lang.Object)", "public int hashCode()", + "public static com.yahoo.document.TensorDataType any()", "public bridge synthetic com.yahoo.document.DataType clone()", "public bridge synthetic com.yahoo.vespa.objects.Identifiable clone()", "public bridge synthetic java.lang.Object clone()" diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index 6898e404b64a..64bd462fb34d 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -27,7 +27,7 @@ /** * Enumeration of the possible types of fields. Since arrays and weighted sets may be defined for any types, including - * themselves, this enumeration is open ended. + * themselves, this enumeration is open-ended. * * @author bratseth */ diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java index e50d660df5be..2aa863343cad 100644 --- a/document/src/main/java/com/yahoo/document/TensorDataType.java +++ b/document/src/main/java/com/yahoo/document/TensorDataType.java @@ -15,11 +15,13 @@ */ public class TensorDataType extends DataType { - private final TensorType tensorType; - // The global class identifier shared with C++. public static int classId = registerClass(Ids.document + 59, TensorDataType.class); + private static final TensorDataType anyTensorDataType = new TensorDataType(null); + + private final TensorType tensorType; + public TensorDataType(TensorType tensorType) { super(tensorType == null ? "tensor" : tensorType.toString(), DataType.tensorDataTypeCode); this.tensorType = tensorType; @@ -43,6 +45,7 @@ public Class getValueClass() { @Override public boolean isValueCompatible(FieldValue value) { if (value == null) return false; + if (tensorType == null) return true; // any if ( ! TensorFieldValue.class.isAssignableFrom(value.getClass())) return false; TensorFieldValue tensorValue = (TensorFieldValue)value; return tensorType.isConvertibleTo(tensorValue.getDataType().getTensorType()); @@ -65,4 +68,7 @@ public int hashCode() { return Objects.hash(super.hashCode(), tensorType); } + /** Returns the tensor data type representing any tensor. */ + public static TensorDataType any() { return anyTensorDataType; } + } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java new file mode 100644 index 000000000000..44c1c8215103 --- /dev/null +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java @@ -0,0 +1,74 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.indexinglanguage.expressions; + +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.tensor.Tensor; + +import java.util.Objects; +import java.util.Optional; + +/** + * Converts a vector of any input type into a binarized vector. + * + * @author bratseth + */ +public class BinarizeExpression extends Expression { + + private final double threshold; + + //private DataType targetType; + + /** The type this bth consumes and produces. */ + private DataType type; + + /** + * Creates a binarize expression. + * + * @param threshold the value which the tensor cell value must be larger than to be set to 1 and not 0. + */ + public BinarizeExpression(double threshold) { + super(TensorDataType.any()); + this.threshold = threshold; + } + + @Override + public void setStatementOutput(DocumentType documentType, Field field) { +// if (! (field.getDataType() instanceof TensorDataType)) +// throw new IllegalArgumentException("The 'binarize' function requires that the output type is a tensor, " + +// "but it is " + field.getDataType()); +// targetType = field.getDataType(); + } + + @Override + protected void doExecute(ExecutionContext context) { + Optional tensor = ((TensorFieldValue)context.getValue()).getTensor(); + if (tensor.isEmpty()) return; + context.setValue(new TensorFieldValue(tensor.get().map(v -> v > threshold ? 1 : 0))); + } + + @Override + protected void doVerify(VerificationContext context) { + type = context.getValueType(); + if (! (type instanceof TensorDataType)) + throw new IllegalArgumentException("The 'binarize' function requires a tensor, but got " + type); + } + + @Override + public DataType createdOutputType() { return type; } + + @Override + public String toString() { + return "binarize" + (threshold == 0 ? "" : " " + threshold); + } + + @Override + public int hashCode() { return Objects.hash(threshold, toString().hashCode()); } + + @Override + public boolean equals(Object o) { return o instanceof BinarizeExpression; } + +} diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java index 61cfeb3b5db8..eb3e316ba816 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java @@ -36,7 +36,7 @@ public void setStatementOutput(DocumentType documentType, Field field) { field.getName() + ": The hash function can only be used when the target field " + "is int or long or an array of int or long, not " + field.getDataType()); - targetType = primitiveTypeOf(field.getDataType()); + targetType = field.getDataType().getPrimitiveType(); } @Override @@ -68,7 +68,7 @@ protected void doVerify(VerificationContext context) { if ( ! canStoreHash(outputFieldType)) throw new VerificationException(this, "The type of the output field " + outputField + " is not int or long but " + outputFieldType); - targetType = primitiveTypeOf(outputFieldType); + targetType = outputFieldType.getPrimitiveType(); context.setValueType(createdOutputType()); } @@ -79,11 +79,6 @@ private boolean canStoreHash(DataType type) { return false; } - private static DataType primitiveTypeOf(DataType type) { - if (type instanceof ArrayDataType) return ((ArrayDataType)type).getNestedType(); - return type; - } - @Override public DataType createdOutputType() { return targetType; } diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index b1eb6f7aeac9..94143bd8f7ec 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -69,18 +69,18 @@ public class IndexingParser { return this; } - private static FieldValue parseDouble(String str) { + private static DoubleFieldValue parseDouble(String str) { return new DoubleFieldValue(new BigDecimal(str).doubleValue()); } - private static FieldValue parseFloat(String str) { + private static FloatFieldValue parseFloat(String str) { if (str.endsWith("f") || str.endsWith("F")) { str = str.substring(0, str.length() - 1); } return new FloatFieldValue(new BigDecimal(str).floatValue()); } - private static FieldValue parseInteger(String str) { + private static IntegerFieldValue parseInteger(String str) { if (str.startsWith("0x")) { return new IntegerFieldValue(new BigInteger(str.substring(2), 16).intValue()); } else { @@ -88,7 +88,7 @@ public class IndexingParser { } } - private static FieldValue parseLong(String str) { + private static LongFieldValue parseLong(String str) { if (str.endsWith("l") || str.endsWith("L")) { str = str.substring(0, str.length() - 1); } @@ -208,6 +208,7 @@ TOKEN : | | | + | | } @@ -299,11 +300,13 @@ Expression value() : ( val = attributeExp() | val = base64DecodeExp() | val = base64EncodeExp() | + val = binarizeExp() | val = busy_waitExp() | val = clearStateExp() | val = echoExp() | val = embedExp() | val = exactExp() | + val = executionValueExp() | val = flattenExp() | val = forEachExp() | val = getFieldExp() | @@ -317,6 +320,7 @@ Expression value() : val = indexExp() | val = inputExp() | val = joinExp() | + val = literalBoolExp() | val = lowerCaseExp() | val = ngramExp() | val = normalizeExp() | @@ -348,9 +352,7 @@ Expression value() : val = toWsetExp() | val = toBoolExp() | val = trimExp() | - val = literalBoolExp() | val = zcurveExp() | - val = executionValueExp() | ( val = statement() { val = new ParenthesisExpression(val); } ) ) { return val; } } @@ -790,6 +792,15 @@ Expression zcurveExp() : { } { return new ZCurveExpression(); } } +Expression binarizeExp() : +{ + NumericFieldValue threshold = new DoubleFieldValue(0); +} +{ + ( ( threshold = numericValue() )? ) + { return new BinarizeExpression(threshold.getNumber().doubleValue()); } +} + Expression executionValueExp() : { } { ( ) @@ -886,9 +897,9 @@ FieldValue fieldValue() : { return val; } } -FieldValue numericValue() : +NumericFieldValue numericValue() : { - FieldValue val; + NumericFieldValue val; String pre = ""; } { diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java index 74638141bd7d..0b0b00498860 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java @@ -10,20 +10,16 @@ import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.language.process.Embedder; -import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext; -import com.yahoo.vespa.indexinglanguage.expressions.Expression; import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext; import com.yahoo.vespa.indexinglanguage.expressions.VerificationException; -import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; import java.util.Map; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -33,77 +29,43 @@ public class EmbeddingScriptTestCase { @Test - public void testEmbed() throws ParseException { - // Test parsing without knowledge of any embedders - String exp = "input myText | embed emb1 | attribute 'myTensor'"; - Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); - - Map embedder = Map.of( - "emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor") - ); - testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); - testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); - testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, - "input text", "[105, 110, 112, 117]"); - testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, - null, null); - - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor"), - "emb2", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", 1)); - testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, - "my input", "[109.0, 121.0, 32.0, 105.0]"); - testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, - "my input", "[110.0, 122.0, 33.0, 106.0]"); - - EmbeddingScriptTester.assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), - "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1, emb2"); - EmbeddingScriptTester.assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), - "Can't find embedder 'emb3'. Valid embedders are emb1, emb2"); + public void testEmbed() { + // No embedders - parsing only + var tester = new EmbeddingScriptTester(Embedder.throwsOnUse.asMap()); + tester.expressionFrom("input myText | embed emb1 | attribute 'myTensor'"); + + // One embedder + tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor"))); + tester.testStatement("input myText | embed | attribute 'myTensor'", "input text", "[105, 110, 112, 117]"); + tester.testStatement("input myText | embed emb1 | attribute 'myTensor'", "input text", "[105, 110, 112, 117]"); + tester.testStatement("input myText | embed 'emb1' | attribute 'myTensor'", "input text", "[105, 110, 112, 117]"); + tester.testStatement("input myText | embed 'emb1' | attribute 'myTensor'", null, null); + + // Two embedders + tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor"), + "emb2", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", 1))); + tester.testStatement("input myText | embed emb1 | attribute 'myTensor'", "my input", "[109.0, 121.0, 32.0, 105.0]"); + tester.testStatement("input myText | embed emb2 | attribute 'myTensor'", "my input", "[110.0, 122.0, 33.0, 106.0]"); + tester.testStatementThrows("input myText | embed | attribute 'myTensor'", "input text", + "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1, emb2"); + tester.testStatementThrows("input myText | embed emb3 | attribute 'myTensor'", "input text", + "Can't find embedder 'emb3'. Valid embedders are emb1, emb2"); } - private void testEmbedStatement(String expressionString, Map embedders, String input, String expected) { - try { - var expression = Expression.fromString(expressionString, new SimpleLinguistics(), embedders); - TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myText", DataType.STRING)); - var tensorField = new Field("myTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - if (input != null) - adapter.setValue("myText", new StringFieldValue(input)); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); - - ExecutionContext context = new ExecutionContext(adapter); - expression.execute(context); - if (input == null) { - assertFalse(adapter.values.containsKey("myTensor")); - } - else { - assertTrue(adapter.values.containsKey("myTensor")); - assertEquals(Tensor.from(tensorType, expected), - ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); - } - } - catch (ParseException e) { - throw new IllegalArgumentException(e); - } + @Test + public void testEmbedAndBinarize() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", -111))); + tester.testStatement("input myText | embed | binarize | attribute 'myTensor'", "input text", "[0, 0, 1, 1]"); + tester.testStatement("input myText | embed | binarize 3.0 | attribute 'myTensor'", "input text", "[0, 0, 0, 1]"); } @SuppressWarnings("unchecked") @Test - public void testArrayEmbed() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensorArray")); + public void testArrayEmbed() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensorArray"))); TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | for_each { embed } | attribute 'myTensorArray'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -131,12 +93,11 @@ public void testArrayEmbed() throws ParseException { } @Test - public void testArrayEmbedWithConcatenation() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedWithConcatenation() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.mySparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); - var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -170,12 +131,11 @@ public void testArrayEmbedWithConcatenation() throws ParseException { /** Multiple paragraphs */ @Test - public void testArrayEmbedTo2dMixedTensor() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedTo2dMixedTensor() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.mySparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); - var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | embed | attribute 'mySparseTensor'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -204,12 +164,11 @@ public void testArrayEmbedTo2dMixedTensor() throws ParseException { /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ @Test - public void testArrayEmbedTo3dMixedTensor() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedTo3dMixedTensor() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); - var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'"); assertEquals("input myTextArray | embed emb1 passage | attribute mySparseTensor", expression.toString()); SimpleTestAdapter adapter = new SimpleTestAdapter(); @@ -263,14 +222,13 @@ public void testArrayEmbedTo3dMixedTensor() throws ParseException { sparseTensor.getTensor().get()); } - /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBERT style) */ @Test - public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); - var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | embed emb1 | attribute 'mySparseTensor'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -288,12 +246,11 @@ public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws Pars /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ @Test - public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMixedEmbedder("myDocument.mySparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); - var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | embed emb1 d | attribute 'mySparseTensor'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -311,13 +268,11 @@ public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseE @SuppressWarnings("OptionalGetWithoutIsPresent") @Test - public void testEmbedToSparseTensor() throws ParseException { - Embedder mappedEmbedder = new EmbeddingScriptTester.MockMappedEmbedder("myDocument.mySparseTensor", 0); - Map embedders = Map.of("emb1",mappedEmbedder); + public void testEmbedToSparseTensor() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.mySparseTensor", 0))); TensorType tensorType = TensorType.fromSpec("tensor(t{})"); - var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input text | embed | attribute 'mySparseTensor'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("text", DataType.STRING)); @@ -346,12 +301,11 @@ public void testEmbedToSparseTensor() throws ParseException { /** Multiple paragraphs with sparse encoding (splade style) */ @Test - public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor")); + public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); - var expression = Expression.fromString("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -370,12 +324,11 @@ public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws Parse /** Multiple paragraphs with sparse encoding (splade style) */ @Test @SuppressWarnings("OptionalGetWithoutIsPresent") - public void testArrayEmbedTo2MappedTensor() throws ParseException { - Map embedders = Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor")); + public void testArrayEmbedTo2MappedTensor() { + var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor"))); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); - var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'", - new SimpleLinguistics(), embedders); + var expression = tester.expressionFrom("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'"); assertEquals("input myTextArray | embed emb1 passage | attribute my2DSparseTensor", expression.toString()); SimpleTestAdapter adapter = new SimpleTestAdapter(); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java index 2072e7c125a9..b75834f1a163 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java @@ -1,25 +1,82 @@ package com.yahoo.vespa.indexinglanguage; +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.language.process.Embedder; +import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext; +import com.yahoo.vespa.indexinglanguage.parser.ParseException; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class EmbeddingScriptTester { - public static void assertThrows(Runnable r, String expectedMessage) { + private final Map embedders; + + public EmbeddingScriptTester(Map embedders) { + this.embedders = embedders; + } + + public void testStatement(String expressionString, String input, String expected) { + var expression = expressionFrom(expressionString); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + expression.execute(context); + if (input == null) { + assertFalse(adapter.values.containsKey("myTensor")); + } + else { + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); + } + } + + public void testStatementThrows(String expressionString, String input, String expectedMessage) { try { - r.run(); + testStatement(expressionString, input, null); fail(); } catch (IllegalStateException e) { assertEquals(expectedMessage, e.getMessage()); } } + public Expression expressionFrom(String string) { + try { + return Expression.fromString(string, new SimpleLinguistics(), embedders); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + public static abstract class MockEmbedder implements Embedder { final String expectedDestination; diff --git a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/lsp/completion/provider/IndexingLangaugeCompletion.java b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/lsp/completion/provider/IndexingLangaugeCompletion.java index de41be76e3ca..0d2d42b0e228 100644 --- a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/lsp/completion/provider/IndexingLangaugeCompletion.java +++ b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/lsp/completion/provider/IndexingLangaugeCompletion.java @@ -39,6 +39,7 @@ public List getCompletionItems(EventCompletionContext context) { CompletionUtils.withSortingPrefix("b", CompletionUtils.constructBasic("input")), CompletionUtils.withSortingPrefix("b", CompletionUtils.constructBasic("set_language")), CompletionUtils.withSortingPrefix("c", CompletionUtils.constructBasic("embed")), + CompletionUtils.withSortingPrefix("c", CompletionUtils.constructBasic("binarize")), CompletionUtils.withSortingPrefix("c", CompletionUtils.constructBasic("hash")), CompletionUtils.withSortingPrefix("c", CompletionUtils.constructBasic("to_array")), CompletionUtils.withSortingPrefix("c", CompletionUtils.constructBasic("to_byte")), From c7a6e1da3905add10d847d78384f1aad9c4a1708 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 4 Oct 2024 14:51:12 +0200 Subject: [PATCH 4/5] Parse 'binarize' --- .../src/main/javacc/IndexingParser.jj | 11 ++++---- .../ccc/indexinglanguage/IndexingParser.ccc | 27 +++++++++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index 94143bd8f7ec..09aa3829ab91 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -816,7 +816,8 @@ String identifier() : ( | | | - | + | + | | | | @@ -831,15 +832,15 @@ String identifier() : | | | - | - | - | + | + | + | | | | | | - | + | | | | diff --git a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc index 23d3ced8cf8f..a398b9d9199f 100644 --- a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc +++ b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc @@ -79,18 +79,18 @@ INJECT IndexingParser: return this; } - private static FieldValue parseDouble(String str) { + private static DoubleFieldValue parseDouble(String str) { return new DoubleFieldValue(new BigDecimal(str).doubleValue()); } - private static FieldValue parseFloat(String str) { + private static FloatFieldValue parseFloat(String str) { if (str.endsWith("f") || str.endsWith("F")) { str = str.substring(0, str.length() - 1); } return new FloatFieldValue(new BigDecimal(str).floatValue()); } - private static FieldValue parseInteger(String str) { + private static IntegerFieldValue parseInteger(String str) { if (str.startsWith("0x")) { return new IntegerFieldValue(new BigInteger(str.substring(2), 16).intValue()); } else { @@ -98,7 +98,7 @@ INJECT IndexingParser: } } - private static FieldValue parseLong(String str) { + private static LongFieldValue parseLong(String str) { if (str.endsWith("l") || str.endsWith("L")) { str = str.substring(0, str.length() - 1); } @@ -174,6 +174,7 @@ TOKEN : | | | + | | | | @@ -330,6 +331,7 @@ Expression value() : ( val = attributeExp() | val = base64DecodeExp() | val = base64EncodeExp() | + val = binarizeExp() | val = busy_waitExp() | val = clearStateExp() | val = echoExp() | @@ -407,6 +409,14 @@ Expression base64EncodeExp() : { } { return new Base64EncodeExpression(); } ; +Expression binarizeExp() : +{ + NumericFieldValue threshold = new DoubleFieldValue(0); +} + ( [ threshold = numericValue() ] ) + { return new BinarizeExpression(threshold.getNumber().doubleValue()); } +; + Expression busy_waitExp() : { } ( ) @@ -835,11 +845,12 @@ String identifierStr() : String val; } - ( val = stringLiteral() | + ( val = stringLiteral() | ( | | | - | + | + | | | | @@ -920,9 +931,9 @@ FieldValue fieldValue() : { return val; } ; -FieldValue numericValue() : +NumericFieldValue numericValue() : { - FieldValue val; + NumericFieldValue val; String pre = ""; } From 450a46693f35a6010c805ebb6ebb77f72260cd29 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 4 Oct 2024 15:25:04 +0200 Subject: [PATCH 5/5] Cleanup --- .../expressions/BinarizeExpression.java | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java index 44c1c8215103..491860cd788e 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/BinarizeExpression.java @@ -2,8 +2,6 @@ package com.yahoo.vespa.indexinglanguage.expressions; import com.yahoo.document.DataType; -import com.yahoo.document.DocumentType; -import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.tensor.Tensor; @@ -12,7 +10,7 @@ import java.util.Optional; /** - * Converts a vector of any input type into a binarized vector. + * Converts a tensor of any input type into a binarized tensor: Each value is replaced by either 0 or 1. * * @author bratseth */ @@ -20,9 +18,7 @@ public class BinarizeExpression extends Expression { private final double threshold; - //private DataType targetType; - - /** The type this bth consumes and produces. */ + /** The type this consumes and produces. */ private DataType type; /** @@ -35,14 +31,6 @@ public BinarizeExpression(double threshold) { this.threshold = threshold; } - @Override - public void setStatementOutput(DocumentType documentType, Field field) { -// if (! (field.getDataType() instanceof TensorDataType)) -// throw new IllegalArgumentException("The 'binarize' function requires that the output type is a tensor, " + -// "but it is " + field.getDataType()); -// targetType = field.getDataType(); - } - @Override protected void doExecute(ExecutionContext context) { Optional tensor = ((TensorFieldValue)context.getValue()).getTensor();