Skip to content

Commit

Permalink
Accept any value in pack_bits
Browse files Browse the repository at this point in the history
  • Loading branch information
bratseth committed Jan 1, 2025
1 parent f9380e6 commit 20a8aff
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
14 changes: 6 additions & 8 deletions vespajlib/src/main/java/com/yahoo/tensor/Tensors.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ public static Tensor toSparse(Tensor tensor, String ... dimensions) {
}

/**
* Converts any tensor containing only ones and zeroes into one where each consecutive 8 values in the
* dense dimension are packed into a single byte. As a consequence the output type of this is a tensor
* where the dense dimension is 1/8th as large.
* Converts any tensor into one where each consecutive 8 values in the
* dense dimension are packed into a single byte,
* by setting a bit to 1 when the tensor has a positive value and 0 otherwise.
* As a consequence the output type of this is a tensor where the dense dimension is 1/8th as large.
*
* @throws IllegalArgumentException if the tensor has the wrong type or contains any other value than 0 or 1
*/
Expand Down Expand Up @@ -94,13 +95,10 @@ else if (tensor instanceof MixedTensor mixed) {
}

private static int packInto(int packedValue, double value, int bitPosition, long sourcePosition) {
if (value == 0.0)
if (value <= 0.0)
return packedValue;
else if (value == 1.0)
return packedValue | ( 1 << ( 7 - bitPosition ));
else
throw new IllegalArgumentException("The tensor to be packed can only contain 0 or 1 values, " +
"but has " + value + " at position " + sourcePosition);
return packedValue | ( 1 << ( 7 - bitPosition ));
}

}
10 changes: 1 addition & 9 deletions vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void testToSparse() {
@Test
void testPackBits() {
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[16]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]");
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1]");
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,2,3]");
assertPacked("tensor<int8>(x[1]):[-128]", "tensor(x[1]):[1]");
assertPacked("tensor<int8>(key{},x[2]):{a:[-127,14], b:[12, 7]}",
"tensor(key{},x[16]):{a:[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]," +
Expand All @@ -47,14 +47,6 @@ void testPackBits() {
assertEquals("packBits requires a tensor with one dense dimensions, but got tensor(x[1],y[1])",
e.getMessage());
}
try {
Tensors.packBits(Tensor.from("tensor(x[3]):[0, 1, 2]"));
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("The tensor to be packed can only contain 0 or 1 values, but has 2.0 at position 2",
e.getMessage());
}
}

void assertConvertedToSparse(String inputType, String outputType, String tensorValue, String ... dimensions) {
Expand Down

0 comments on commit 20a8aff

Please sign in to comment.