Skip to content

Commit

Permalink
Merge pull request #32522 from vespa-engine/bratseth/binarize
Browse files Browse the repository at this point in the history
Bratseth/binarize
  • Loading branch information
bratseth authored Oct 4, 2024
2 parents 0f5c5f5 + 450a466 commit 5ea7b94
Show file tree
Hide file tree
Showing 11 changed files with 660 additions and 525 deletions.
1 change: 1 addition & 0 deletions document/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,7 @@
"public com.yahoo.tensor.TensorType getTensorType()",
"public boolean equals(java.lang.Object)",
"public int hashCode()",
"public static com.yahoo.document.TensorDataType any()",
"public bridge synthetic com.yahoo.document.DataType clone()",
"public bridge synthetic com.yahoo.vespa.objects.Identifiable clone()",
"public bridge synthetic java.lang.Object clone()"
Expand Down
2 changes: 1 addition & 1 deletion document/src/main/java/com/yahoo/document/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

/**
* Enumeration of the possible types of fields. Since arrays and weighted sets may be defined for any types, including
* themselves, this enumeration is open ended.
* themselves, this enumeration is open-ended.
*
* @author bratseth
*/
Expand Down
10 changes: 8 additions & 2 deletions document/src/main/java/com/yahoo/document/TensorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
*/
public class TensorDataType extends DataType {

private final TensorType tensorType;

// The global class identifier shared with C++.
public static int classId = registerClass(Ids.document + 59, TensorDataType.class);

private static final TensorDataType anyTensorDataType = new TensorDataType(null);

private final TensorType tensorType;

public TensorDataType(TensorType tensorType) {
super(tensorType == null ? "tensor" : tensorType.toString(), DataType.tensorDataTypeCode);
this.tensorType = tensorType;
Expand All @@ -43,6 +45,7 @@ public Class<? extends TensorFieldValue> getValueClass() {
@Override
public boolean isValueCompatible(FieldValue value) {
if (value == null) return false;
if (tensorType == null) return true; // any
if ( ! TensorFieldValue.class.isAssignableFrom(value.getClass())) return false;
TensorFieldValue tensorValue = (TensorFieldValue)value;
return tensorType.isConvertibleTo(tensorValue.getDataType().getTensorType());
Expand All @@ -65,4 +68,7 @@ public int hashCode() {
return Objects.hash(super.hashCode(), tensorType);
}

/** Returns the tensor data type representing any tensor. */
public static TensorDataType any() { return anyTensorDataType; }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.DataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.tensor.Tensor;

import java.util.Objects;
import java.util.Optional;

/**
* Converts a tensor of any input type into a binarized tensor: Each value is replaced by either 0 or 1.
*
* @author bratseth
*/
public class BinarizeExpression extends Expression {

private final double threshold;

/** The type this consumes and produces. */
private DataType type;

/**
* Creates a binarize expression.
*
* @param threshold the value which the tensor cell value must be larger than to be set to 1 and not 0.
*/
public BinarizeExpression(double threshold) {
super(TensorDataType.any());
this.threshold = threshold;
}

@Override
protected void doExecute(ExecutionContext context) {
Optional<Tensor> tensor = ((TensorFieldValue)context.getValue()).getTensor();
if (tensor.isEmpty()) return;
context.setValue(new TensorFieldValue(tensor.get().map(v -> v > threshold ? 1 : 0)));
}

@Override
protected void doVerify(VerificationContext context) {
type = context.getValueType();
if (! (type instanceof TensorDataType))
throw new IllegalArgumentException("The 'binarize' function requires a tensor, but got " + type);
}

@Override
public DataType createdOutputType() { return type; }

@Override
public String toString() {
return "binarize" + (threshold == 0 ? "" : " " + threshold);
}

@Override
public int hashCode() { return Objects.hash(threshold, toString().hashCode()); }

@Override
public boolean equals(Object o) { return o instanceof BinarizeExpression; }

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void setStatementOutput(DocumentType documentType, Field field) {
field.getName() +
": The hash function can only be used when the target field " +
"is int or long or an array of int or long, not " + field.getDataType());
targetType = primitiveTypeOf(field.getDataType());
targetType = field.getDataType().getPrimitiveType();
}

@Override
Expand Down Expand Up @@ -68,7 +68,7 @@ protected void doVerify(VerificationContext context) {
if ( ! canStoreHash(outputFieldType))
throw new VerificationException(this, "The type of the output field " + outputField +
" is not int or long but " + outputFieldType);
targetType = primitiveTypeOf(outputFieldType);
targetType = outputFieldType.getPrimitiveType();
context.setValueType(createdOutputType());
}

Expand All @@ -79,11 +79,6 @@ private boolean canStoreHash(DataType type) {
return false;
}

private static DataType primitiveTypeOf(DataType type) {
if (type instanceof ArrayDataType) return ((ArrayDataType)type).getNestedType();
return type;
}

@Override
public DataType createdOutputType() { return targetType; }

Expand Down
38 changes: 25 additions & 13 deletions indexinglanguage/src/main/javacc/IndexingParser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,26 @@ public class IndexingParser {
return this;
}

private static FieldValue parseDouble(String str) {
private static DoubleFieldValue parseDouble(String str) {
return new DoubleFieldValue(new BigDecimal(str).doubleValue());
}

private static FieldValue parseFloat(String str) {
private static FloatFieldValue parseFloat(String str) {
if (str.endsWith("f") || str.endsWith("F")) {
str = str.substring(0, str.length() - 1);
}
return new FloatFieldValue(new BigDecimal(str).floatValue());
}

private static FieldValue parseInteger(String str) {
private static IntegerFieldValue parseInteger(String str) {
if (str.startsWith("0x")) {
return new IntegerFieldValue(new BigInteger(str.substring(2), 16).intValue());
} else {
return new IntegerFieldValue(new BigInteger(str).intValue());
}
}

private static FieldValue parseLong(String str) {
private static LongFieldValue parseLong(String str) {
if (str.endsWith("l") || str.endsWith("L")) {
str = str.substring(0, str.length() - 1);
}
Expand Down Expand Up @@ -208,6 +208,7 @@ TOKEN :
<ZCURVE: "zcurve"> |
<TRUE: "true" > |
<FALSE: "false" > |
<BINARIZE: "binarize" > |
<UNDERSCORE: "_"> |
<IDENTIFIER: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_","-"])*>
}
Expand Down Expand Up @@ -299,11 +300,13 @@ Expression value() :
( val = attributeExp() |
val = base64DecodeExp() |
val = base64EncodeExp() |
val = binarizeExp() |
val = busy_waitExp() |
val = clearStateExp() |
val = echoExp() |
val = embedExp() |
val = exactExp() |
val = executionValueExp() |
val = flattenExp() |
val = forEachExp() |
val = getFieldExp() |
Expand All @@ -317,6 +320,7 @@ Expression value() :
val = indexExp() |
val = inputExp() |
val = joinExp() |
val = literalBoolExp() |
val = lowerCaseExp() |
val = ngramExp() |
val = normalizeExp() |
Expand Down Expand Up @@ -348,9 +352,7 @@ Expression value() :
val = toWsetExp() |
val = toBoolExp() |
val = trimExp() |
val = literalBoolExp() |
val = zcurveExp() |
val = executionValueExp() |
( <LPAREN> val = statement() <RPAREN> { val = new ParenthesisExpression(val); } ) )
{ return val; }
}
Expand Down Expand Up @@ -790,6 +792,15 @@ Expression zcurveExp() : { }
{ return new ZCurveExpression(); }
}

Expression binarizeExp() :
{
NumericFieldValue threshold = new DoubleFieldValue(0);
}
{
( <BINARIZE> ( threshold = numericValue() )? )
{ return new BinarizeExpression(threshold.getNumber().doubleValue()); }
}

Expression executionValueExp() : { }
{
( <UNDERSCORE> )
Expand All @@ -805,7 +816,8 @@ String identifier() :
( <ATTRIBUTE> |
<BASE64_DECODE> |
<BASE64_ENCODE> |
<BUSY_WAIT> |
<BINARIZE> |
<BUSY_WAIT> |
<CASE> |
<CASE_DEFAULT> |
<CLEAR_STATE> |
Expand All @@ -820,15 +832,15 @@ String identifier() :
<GET_VAR> |
<GUARD> |
<HASH> |
<HEX_DECODE> |
<HEX_ENCODE> |
<HOST_NAME> |
<HEX_DECODE> |
<HEX_ENCODE> |
<HOST_NAME> |
<IDENTIFIER> |
<IF> |
<INDEX> |
<INPUT> |
<JOIN> |
<LOWER_CASE> |
<LOWER_CASE> |
<MAX_LENGTH> |
<NGRAM> |
<NORMALIZE> |
Expand Down Expand Up @@ -886,9 +898,9 @@ FieldValue fieldValue() :
{ return val; }
}

FieldValue numericValue() :
NumericFieldValue numericValue() :
{
FieldValue val;
NumericFieldValue val;
String pre = "";
}
{
Expand Down
Loading

0 comments on commit 5ea7b94

Please sign in to comment.