From 20a8aff32647be79e92ff0524ee7de8522e57004 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 1 Jan 2025 19:57:59 +0100 Subject: [PATCH] Accept any value in pack_bits --- .../src/main/java/com/yahoo/tensor/Tensors.java | 14 ++++++-------- .../java/com/yahoo/tensor/TensorsTestCase.java | 10 +--------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensors.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensors.java index a7bd9b1b4c6..ef15932781e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensors.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensors.java @@ -48,9 +48,10 @@ public static Tensor toSparse(Tensor tensor, String ... dimensions) { } /** - * Converts any tensor containing only ones and zeroes into one where each consecutive 8 values in the - * dense dimension are packed into a single byte. As a consequence the output type of this is a tensor - * where the dense dimension is 1/8th as large. + * Converts any tensor into one where each consecutive 8 values in the + * dense dimension are packed into a single byte, + * by setting a bit to 1 when the tensor has a positive value and 0 otherwise. + * As a consequence the output type of this is a tensor where the dense dimension is 1/8th as large. * * @throws IllegalArgumentException if the tensor has the wrong type or contains any other value than 0 or 1 */ @@ -94,13 +95,10 @@ else if (tensor instanceof MixedTensor mixed) { } private static int packInto(int packedValue, double value, int bitPosition, long sourcePosition) { - if (value == 0.0) + if (value <= 0.0) return packedValue; - else if (value == 1.0) - return packedValue | ( 1 << ( 7 - bitPosition )); else - throw new IllegalArgumentException("The tensor to be packed can only contain 0 or 1 values, " + - "but has " + value + " at position " + sourcePosition); + return packedValue | ( 1 << ( 7 - bitPosition )); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java index 364f377f984..2819fb0b073 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java @@ -30,7 +30,7 @@ void testToSparse() { @Test void testPackBits() { assertPacked("tensor(x[2]):[-127,14]", "tensor(x[16]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]"); - assertPacked("tensor(x[2]):[-127,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1]"); + assertPacked("tensor(x[2]):[-127,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,2,3]"); assertPacked("tensor(x[1]):[-128]", "tensor(x[1]):[1]"); assertPacked("tensor(key{},x[2]):{a:[-127,14], b:[12, 7]}", "tensor(key{},x[16]):{a:[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]," + @@ -47,14 +47,6 @@ void testPackBits() { assertEquals("packBits requires a tensor with one dense dimensions, but got tensor(x[1],y[1])", e.getMessage()); } - try { - Tensors.packBits(Tensor.from("tensor(x[3]):[0, 1, 2]")); - fail("Expected exception"); - } - catch (IllegalArgumentException e) { - assertEquals("The tensor to be packed can only contain 0 or 1 values, but has 2.0 at position 2", - e.getMessage()); - } } void assertConvertedToSparse(String inputType, String outputType, String tensorValue, String ... dimensions) {