Skip to content

Commit

Permalink
Merge pull request #32747 from vespa-engine/bratseth/indexing-type-in…
Browse files Browse the repository at this point in the history
…ference

Bratseth/indexing type inference
  • Loading branch information
bratseth authored Nov 3, 2024
2 parents b9be442 + 0ad806f commit 9b7e68b
Show file tree
Hide file tree
Showing 117 changed files with 1,103 additions and 384 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void process(boolean validate, boolean documentsOnly) {
converter.convert(exp); // TODO: stop doing this explicitly when visiting a script does not branch
}
} catch (VerificationException e) {
fail(schema, field, "For expression '" + e.getExpression() + "': " + Exceptions.toMessageString(e));
fail(schema, field, Exceptions.toMessageString(e));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void requireThatExtraFieldInputImplicitThrows() throws ParseException {
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'indexing_extra_field_input_implicit', field 'foo': " +
"For expression '{ tokenize normalize stem:\"BEST\" | index foo; }': Expected string input, but no input is specified",
"Invalid expression '{ tokenize normalize stem:\"BEST\" | index foo; }': Expected string input, but no input is specified",
Exceptions.toMessageString(e));
}
}
Expand Down Expand Up @@ -156,7 +156,7 @@ void testNoInputInDerivedField() throws ParseException {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'test', field 'derived1': For expression '{ attribute derived1; }': " +
assertEquals("For schema 'test', field 'derived1': Invalid expression '{ attribute derived1; }': " +
"Expected any input, but no input is specified",
Exceptions.toMessageString(e));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void requireThatOutputConflictThrows() throws ParseException {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'indexing_output_confict', field 'bar': For expression 'index bar': Attempting " +
assertEquals("For schema 'indexing_output_confict', field 'bar': Invalid expression 'index bar': Attempting " +
"to assign conflicting values to field 'bar'",
Exceptions.toMessageString(e));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void testAttributeChanged() throws ParseException {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'indexing_attribute_changed', field 'foo': For expression 'attribute foo': " +
assertEquals("For schema 'indexing_attribute_changed', field 'foo': Invalid expression 'attribute foo': " +
"Attempting to assign conflicting values to field 'foo'",
Exceptions.toMessageString(e));
}
Expand Down Expand Up @@ -79,7 +79,7 @@ void testIndexChanged() throws ParseException {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'indexing_index_changed', field 'foo': For expression 'index foo': " +
assertEquals("For schema 'indexing_index_changed', field 'foo': Invalid expression 'index foo': " +
"Attempting to assign conflicting values to field 'foo'",
Exceptions.toMessageString(e));
}
Expand Down Expand Up @@ -123,7 +123,7 @@ void testSummaryChanged() throws ParseException {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'indexing_summary_fail', field 'foo': For expression 'summary foo': Attempting " +
assertEquals("For schema 'indexing_summary_fail', field 'foo': Invalid expression 'summary foo': Attempting " +
"to assign conflicting values to field 'foo'",
Exceptions.toMessageString(e));
}
Expand Down Expand Up @@ -185,7 +185,7 @@ void requireThatMultilineOutputConflictThrows() throws ParseException {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("For schema 'indexing_multiline_output_confict', field 'cox': For expression 'index cox': " +
assertEquals("For schema 'indexing_multiline_output_confict', field 'cox': Invalid expression 'index cox': " +
"Attempting to assign conflicting values to field 'cox'",
Exceptions.toMessageString(e));
}
Expand Down
3 changes: 2 additions & 1 deletion document/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@
"public com.yahoo.document.datatypes.FieldValue createFieldValue(java.lang.Object)",
"public abstract java.lang.Class getValueClass()",
"public abstract boolean isValueCompatible(com.yahoo.document.datatypes.FieldValue)",
"public final boolean isAssignableFrom(com.yahoo.document.DataType)",
"public boolean isAssignableFrom(com.yahoo.document.DataType)",
"public boolean isAssignableTo(com.yahoo.document.DataType)",
"public static com.yahoo.document.ArrayDataType getArray(com.yahoo.document.DataType)",
"public static com.yahoo.document.MapDataType getMap(com.yahoo.document.DataType, com.yahoo.document.DataType)",
"public static com.yahoo.document.WeightedSetDataType getWeightedSet(com.yahoo.document.DataType)",
Expand Down
7 changes: 6 additions & 1 deletion document/src/main/java/com/yahoo/document/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,16 @@ public FieldValue createFieldValue(Object arg) {

public abstract boolean isValueCompatible(FieldValue value);

public final boolean isAssignableFrom(DataType dataType) {
public boolean isAssignableFrom(DataType dataType) {
// TODO: Reverse this so that isValueCompatible() uses this instead.
return isValueCompatible(dataType.createFieldValue());
}

/**
 * The reverse of isAssignableFrom: asks the given type whether it is assignable from this type.
 *
 * @param dataType the type to check assignability to
 * @return the result of {@code dataType.isAssignableFrom(this)}
 */
public boolean isAssignableTo(DataType dataType) {
return dataType.isAssignableFrom(this);
}

/**
* Returns an array datatype, where the array elements are of the given type
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ public class NumericDataType extends PrimitiveDataType {

// The global class identifier shared with C++.
public static int classId = registerClass(Ids.document + 52, NumericDataType.class);

/**
* Creates a datatype
*
* @param name the name of the type
* @param code the code (id) of the type
* @param type the field value used for this type
*/
protected NumericDataType(java.lang.String name, int code, Class<? extends FieldValue> type, Factory factory) {
protected NumericDataType(String name, int code, Class<? extends FieldValue> type, Factory factory) {
super(name, code, type, factory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,21 @@ private AnyDataType() {
}

@Override
public FieldValue createFieldValue() { throw new UnsupportedOperationException(); }
public boolean isAssignableFrom(DataType other) { return true; }

@Override
public Class<?> getValueClass() { throw new UnsupportedOperationException(); }
public boolean isAssignableTo(DataType other) {
return other instanceof AnyDataType;
}

/** Accepts every field value: always returns true regardless of the argument. */
@Override
public boolean isValueCompatible(FieldValue value) { return true; }

/** Not supported by this type: always throws UnsupportedOperationException. */
@Override
public FieldValue createFieldValue() { throw new UnsupportedOperationException(); }

/** Not supported by this type: always throws UnsupportedOperationException. */
@Override
public Class<?> getValueClass() { throw new UnsupportedOperationException(); }


}
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,48 @@ public ArithmeticExpression convertChildren(ExpressionConverter converter) {
return new ArithmeticExpression(converter.convert(left), op, converter.convert(right));
}

/**
 * Propagates the given input type to both operands and derives the type this expression produces.
 *
 * @param inputType the type of the value fed into this expression
 * @param context the verification context passed on to the superclass and both operands
 * @return the arithmetic result type of the two operand output types, see resultingType
 */
@Override
public DataType setInputType(DataType inputType, VerificationContext context) {
    super.setInputType(inputType, context);
    // Arguments are evaluated left to right, so the left operand is still typed before the right one
    return resultingType(left.setInputType(inputType, context),
                         right.setInputType(inputType, context));
}

/**
 * Propagates the given output type to both operands and derives the input type this expression requires.
 *
 * @param outputType the type this expression is required to produce
 * @param context the verification context passed on to the superclass and both operands
 * @return the input type both operands agree on, or otherwise whatever getInputType(context) returns
 */
@Override
public DataType setOutputType(DataType outputType, VerificationContext context) {
    super.setOutputType(outputType, context);
    DataType requiredByLeft = left.setOutputType(outputType, context);
    DataType requiredByRight = right.setOutputType(outputType, context);
    // TODO: Generalize
    return requiredByLeft == requiredByRight ? requiredByLeft : getInputType(context);
}

@Override
protected void doVerify(VerificationContext context) {
DataType input = context.getCurrentType();
context.setCurrentType(evaluate(context.setCurrentType(input).verify(left).getCurrentType(),
context.setCurrentType(input).verify(right).getCurrentType()));
context.setCurrentType(resultingType(context.setCurrentType(input).verify(left).getCurrentType(),
context.setCurrentType(input).verify(right).getCurrentType()));
}

/**
 * Computes the type resulting from arithmetic on operands of the given types.
 *
 * @param left the type of the first operand, or null
 * @param right the type of the second operand, or null
 * @return null if either operand type is null, otherwise the widest of the two types
 *         in the order double, float, long, int
 * @throws VerificationException if a non-null operand type is not a NumericDataType
 */
private DataType resultingType(DataType left, DataType right) {
    if (left == null || right == null) return null;

    if ( ! (left instanceof NumericDataType))
        throw new VerificationException(this, "The first argument must be a number, but has type " + left.getName());
    if ( ! (right instanceof NumericDataType))
        throw new VerificationException(this, "The second argument must be a number, but has type " + right.getName());

    // Widest type wins; anything numeric that is not double, float or long ends up as int
    if (left == DataType.DOUBLE || right == DataType.DOUBLE) return DataType.DOUBLE;
    if (left == DataType.FLOAT || right == DataType.FLOAT) return DataType.FLOAT;
    if (left == DataType.LONG || right == DataType.LONG) return DataType.LONG;
    return DataType.INT;
}

@Override
Expand All @@ -80,15 +117,15 @@ protected void doExecute(ExecutionContext context) {
context.setCurrentValue(input).execute(right).getCurrentValue()));
}

private static DataType requiredInputType(Expression lhs, Expression rhs) {
DataType lhsType = lhs.requiredInputType();
DataType rhsType = rhs.requiredInputType();
if (lhsType == null) return rhsType;
if (rhsType == null) return lhsType;
if (!lhsType.equals(rhsType))
private static DataType requiredInputType(Expression left, Expression right) {
DataType leftType = left.requiredInputType();
DataType rightType = right.requiredInputType();
if (leftType == null) return rightType;
if (rightType == null) return leftType;
if (!leftType.equals(rightType))
throw new VerificationException(ArithmeticExpression.class, "Operands require conflicting input types, " +
lhsType.getName() + " vs " + rhsType.getName());
return lhsType;
leftType.getName() + " vs " + rightType.getName());
return leftType;
}

@Override
Expand Down Expand Up @@ -116,23 +153,6 @@ public int hashCode() {
return getClass().hashCode() + left.hashCode() + op.hashCode() + right.hashCode();
}

private DataType evaluate(DataType lhs, DataType rhs) {
if (lhs == null || rhs == null)
throw new VerificationException(this, "Attempting to perform arithmetic on a null value");
if (!(lhs instanceof NumericDataType) || !(rhs instanceof NumericDataType))
throw new VerificationException(this, "Attempting to perform unsupported arithmetic: [" +
lhs.getName() + "] " + op + " [" + rhs.getName() + "]");

if (lhs == DataType.FLOAT || lhs == DataType.DOUBLE || rhs == DataType.FLOAT || rhs == DataType.DOUBLE) {
if (lhs == DataType.DOUBLE || rhs == DataType.DOUBLE)
return DataType.DOUBLE;
return DataType.FLOAT;
}
if (lhs == DataType.LONG || rhs == DataType.LONG)
return DataType.LONG;
return DataType.INT;
}

private FieldValue evaluate(FieldValue lhs, FieldValue rhs) {
if (lhs == null || rhs == null) return null;
if (!(lhs instanceof NumericFieldValue) || !(rhs instanceof NumericFieldValue))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public DataType setInputType(DataType inputType, VerificationContext context) {

@Override
public DataType setOutputType(DataType outputType, VerificationContext context) {
super.setOutputType(outputType, DataType.LONG, context);
super.setOutputType(DataType.LONG, outputType, null, context);
return DataType.STRING;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public DataType setInputType(DataType inputType, VerificationContext context) {

@Override
public DataType setOutputType(DataType outputType, VerificationContext context) {
super.setOutputType(outputType, DataType.STRING, context);
super.setOutputType(DataType.STRING, outputType, null, context);
return DataType.LONG;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ public DataType setInputType(DataType inputType, VerificationContext context) {

@Override
public DataType setOutputType(DataType outputType, VerificationContext context) {
return super.setOutputType(outputType, TensorDataType.any(), context);
return super.setOutputType(null, outputType, TensorDataType.any(), context);
}

@Override
protected void doVerify(VerificationContext context) {
type = context.getCurrentType();
if (! (type instanceof TensorDataType))
throw new IllegalArgumentException("The 'binarize' function requires a tensor, but got " + type);
throw new VerificationException(this, "Require a tensor, but got " + type);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ private static double nihlakanta(int i) {
// Rendered as the fixed string "busy_wait"
@Override public String toString() { return "busy_wait"; }
// Equal to any other BusyWaitExpression instance
@Override public boolean equals(Object obj) { return obj instanceof BusyWaitExpression; }
// Derived from the runtime class only
@Override public int hashCode() { return getClass().hashCode(); }

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.yahoo.document.ArrayDataType;
import com.yahoo.document.CollectionDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.WeightedSetDataType;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.FieldValue;
Expand All @@ -12,10 +13,12 @@
import com.yahoo.vespa.indexinglanguage.ExpressionConverter;

import java.util.*;
import java.util.List;

/**
* @author Simon Thoresen Hult
*/
// TODO: Support Map in addition to Array and Weighted Set (doc just says "collection type")
public final class CatExpression extends ExpressionList<Expression> {

public CatExpression(Expression... expressions) {
Expand All @@ -31,6 +34,32 @@ public CatExpression convertChildren(ExpressionConverter converter) {
return new CatExpression(convertChildList(converter));
}

/**
 * Feeds the given input type to every child expression and resolves the output type of the concatenation.
 *
 * @param inputType the type of the value fed into this expression
 * @param context the verification context passed on to the superclass and each child
 * @return the type resolved from the children's output types, or whatever getOutputType(context)
 *         returns when no type could be resolved
 */
@Override
public DataType setInputType(DataType inputType, VerificationContext context) {
    super.setInputType(inputType, context);

    List<DataType> childOutputs = new ArrayList<>(expressions().size());
    for (var child : expressions()) {
        DataType childOutput = child.setInputType(inputType, context);
        childOutputs.add(childOutput);
    }

    DataType resolved = resolveOutputType(childOutputs);
    if (resolved == null)
        return getOutputType(context);
    return resolved;
}

/**
 * Verifies that the required output type is a string or a collection, and passes an
 * any-type output requirement on to every child expression.
 *
 * @param outputType the type this expression is required to produce
 * @param context the verification context passed on to the superclass and each child
 * @return the output type itself when it is a collection type, otherwise whatever
 *         getInputType(context) returns
 * @throws VerificationException if the output type is neither DataType.STRING nor a CollectionDataType
 */
@Override
public DataType setOutputType(DataType outputType, VerificationContext context) {
    boolean collectionRequired = outputType instanceof CollectionDataType;
    // NOTE(review): a null outputType reaches getName() below and would fail there - confirm callers never pass null
    if ( ! collectionRequired && outputType != DataType.STRING)
        throw new VerificationException(this, "Required to produce " + outputType.getName() +
                                              ", but this produces a string or collection");
    super.setOutputType(outputType, context);

    // Any output is handled by converting to string
    for (var child : expressions())
        child.setOutputType(AnyDataType.instance, context);

    if (collectionRequired)
        return outputType;
    return getInputType(context); // Cannot infer input type since we take the string value
}

@Override
protected void doVerify(VerificationContext context) {
DataType input = context.getCurrentType();
Expand All @@ -52,8 +81,8 @@ protected void doExecute(ExecutionContext context) {
context.fillVariableTypes(verificationContext);
List<FieldValue> values = new LinkedList<>();
List<DataType> types = new LinkedList<>();
for (Expression exp : this) {
FieldValue val = context.setCurrentValue(input).execute(exp).getCurrentValue();
for (Expression expression : this) {
FieldValue val = context.setCurrentValue(input).execute(expression).getCurrentValue();
values.add(val);

DataType type;
Expand Down Expand Up @@ -110,6 +139,7 @@ public boolean equals(Object obj) {
private static DataType resolveOutputType(List<DataType> types) {
DataType resolved = null;
for (DataType type : types) {
if (type == null) return null;
if (!(type instanceof CollectionDataType)) return DataType.STRING;

if (resolved == null)
Expand Down
Loading

0 comments on commit 9b7e68b

Please sign in to comment.