Skip to content

Commit

Permalink
Revert "Arnej/unify cell type conversion"
Browse files Browse the repository at this point in the history
  • Loading branch information
baldersheim authored Mar 12, 2023
1 parent b06d77b commit ce5d891
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,25 +180,12 @@ static TensorType toVespaType(ValueInfo valueInfo) {
}

static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
// NOTE:
// should match best_cell_type in onnx_wrapper.cpp
switch (onnxType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return TensorType.Value.INT8;

case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
return TensorType.Value.BFLOAT16;

case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return TensorType.Value.FLOAT;

case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
return TensorType.Value.DOUBLE;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
}
return TensorType.Value.DOUBLE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,21 @@ static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
tensor.getDimsList());
}

private static TensorType.Value toValueType(Onnx.TensorProto.DataType onnxType) {
// NOTE:
// should match best_cell_type in onnx_wrapper.cpp
switch (onnxType) {
case BOOL: // Imperfect conversion fallthrough
case INT8:
return TensorType.Value.INT8;
case BFLOAT16:
return TensorType.Value.BFLOAT16;
case UINT8: // Imperfect conversion fallthrough
case INT16: // Imperfect conversion fallthrough
case UINT16: // Imperfect conversion fallthrough
case FLOAT:
return TensorType.Value.FLOAT;
case INT32: // Imperfect conversion fallthrough
case INT64: // Imperfect conversion fallthrough
case UINT32: // Imperfect conversion fallthrough
case UINT64: // Imperfect conversion fallthrough
case DOUBLE:
return TensorType.Value.DOUBLE;
default: throw new IllegalArgumentException("A ONNX tensor with data type " + onnxType +
private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
switch (dataType) {
case FLOAT: return TensorType.Value.FLOAT;
case DOUBLE: return TensorType.Value.DOUBLE;
// Imperfect conversion, for now:
case BOOL: return TensorType.Value.FLOAT;
case INT8: return TensorType.Value.FLOAT;
case INT16: return TensorType.Value.FLOAT;
case INT32: return TensorType.Value.FLOAT;
case INT64: return TensorType.Value.FLOAT;
case UINT8: return TensorType.Value.FLOAT;
case UINT16: return TensorType.Value.FLOAT;
case UINT32: return TensorType.Value.FLOAT;
case UINT64: return TensorType.Value.FLOAT;
default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
" cannot be converted to a Vespa tensor type");
}
}
Expand Down
6 changes: 1 addition & 5 deletions model-integration/src/main/protobuf/onnx.proto
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,6 @@ message TensorProto {
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here.
}

Expand Down Expand Up @@ -465,4 +461,4 @@ message OperatorSetIdProto {
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
}
}

0 comments on commit ce5d891

Please sign in to comment.