From 00328ecd414ff765c71eddc77b385b528a54fd5a Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 13:57:25 -0700 Subject: [PATCH] Fix bug where quantization framework does not work with training (#2100) (#2102) * Initial implementation Signed-off-by: Ryan Bogan * Modify integration test and fix bugs in jni Signed-off-by: Ryan Bogan * Fix unit test Signed-off-by: Ryan Bogan * Fix integration test after merge Signed-off-by: Ryan Bogan * Add changelog (release notes) Signed-off-by: Ryan Bogan * Add unit test Signed-off-by: Ryan Bogan * Remove entry for release notes Signed-off-by: Ryan Bogan * Add null checks Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan (cherry picked from commit 5d10d64c80f71f383d3e8690d8502d133859ffbf) Co-authored-by: Ryan Bogan --- jni/src/faiss_wrapper.cpp | 25 +++++--- .../knn/index/mapper/CompressionLevel.java | 2 +- .../index/memory/NativeMemoryAllocation.java | 5 ++ .../memory/NativeMemoryEntryContext.java | 8 ++- .../memory/NativeMemoryLoadStrategy.java | 4 ++ .../TrainingModelTransportAction.java | 15 ++++- .../training/FloatTrainingDataConsumer.java | 60 +++++++++++++++++-- .../memory/NativeMemoryEntryContextTests.java | 19 ++++-- .../memory/NativeMemoryLoadStrategyTests.java | 4 +- .../knn/integ/ModeAndCompressionIT.java | 43 ++++--------- .../FloatTrainingDataConsumerTests.java | 44 ++++++++++++-- 11 files changed, 168 insertions(+), 61 deletions(-) diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 227fcb477..45548e0f7 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -684,20 +684,29 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti } else { faiss::SearchParameters *searchParameters = nullptr; faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; std::unique_ptr idGrouper; std::vector idGrouperBitmap; - auto hnswReader = dynamic_cast(indexReader->index); + auto ivfReader = dynamic_cast(indexReader->index); // TODO currently, search parameter is not supported in binary index // To avoid test failure, we skip setting ef search when methodPramsJ is null temporary - if(hnswReader!= nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) { - // Query param efsearch supersedes ef_search provided during index setting. - hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); - if (parentIdsJ != nullptr) { - idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); - hnswParams.grp = idGrouper.get(); + if (ivfReader) { + int indexNprobe = ivfReader->nprobe; + ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe); + searchParameters = &ivfParams; + } else { + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader != nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) { + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } + searchParameters = &hnswParams; } - searchParameters = &hnswParams; } + try { indexReader->search(1, reinterpret_cast(rawQueryvector), kJ, dis.data(), ids.data(), searchParameters); } catch (...) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index 222e042b6..0709239cf 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -30,7 +30,7 @@ public enum CompressionLevel { x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)); // Internally, an empty string is easier to deal with them null. However, from the mapping, - // we do not want users to pass in the empty string and instead want null. So we make the conversion herex + // we do not want users to pass in the empty string and instead want null. So we make the conversion here public static final String[] NAMES_ARRAY = new String[] { NOT_CONFIGURED.getName(), x1.getName(), diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 02b480ed4..0bb8a556f 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -12,9 +12,11 @@ package org.opensearch.knn.index.memory; import lombok.Getter; +import lombok.Setter; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNEngine; @@ -252,6 +254,9 @@ class TrainingDataAllocation implements NativeMemoryAllocation { private volatile boolean closed; private long memoryAddress; private final int size; + @Getter + @Setter + private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY; // Implement reader/writer with semaphores to deal with passing lock conditions between threads private int readCount; diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 2dfc5fafb..dd219593d 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -11,8 +11,10 @@ package org.opensearch.knn.index.memory; +import lombok.Getter; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.VectorDataType; @@ -171,6 +173,8 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext listener) { + KNNMethodContext knnMethodContext = request.getKnnMethodContext(); + KNNMethodConfigContext knnMethodConfigContext = request.getKnnMethodConfigContext(); + QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY; + + if (knnMethodContext != null && request.getKnnMethodConfigContext() != null) { + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + quantizationConfig = knnLibraryIndexingContext.getQuantizationConfig(); + } NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( request.getTrainingDataSizeInKB(), @@ -54,7 +66,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener clusterService, request.getMaximumVectorCount(), request.getSearchSize(), - request.getVectorDataType() + request.getVectorDataType(), + quantizationConfig ); // Allocation representing size model will occupy in memory during training diff --git a/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java index d742a9184..292752945 100644 --- a/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java +++ b/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java @@ -13,10 +13,19 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.action.search.SearchResponse; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.knn.quantization.factory.QuantizerFactory; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.quantizer.Quantizer; import org.opensearch.search.SearchHit; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -25,6 +34,8 @@ */ public class FloatTrainingDataConsumer extends TrainingDataConsumer { + private final QuantizationConfig quantizationConfig; + /** * Constructor * @@ -32,16 +43,28 @@ public class FloatTrainingDataConsumer extends TrainingDataConsumer { */ public FloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { super(trainingDataAllocation); + this.quantizationConfig = trainingDataAllocation.getQuantizationConfig(); } @Override public void accept(List floats) { - trainingDataAllocation.setMemoryAddress( - JNIService.transferVectors( - trainingDataAllocation.getMemoryAddress(), - floats.stream().map(v -> ArrayUtils.toPrimitive((Float[]) v)).toArray(float[][]::new) - ) - ); + if (isValidFloatsAndQuantizationConfig(floats)) { + try { + List byteVectors = quantizeVectors(floats); + long memoryAddress = trainingDataAllocation.getMemoryAddress(); + memoryAddress = JNICommons.storeBinaryVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); + trainingDataAllocation.setMemoryAddress(memoryAddress); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else { + trainingDataAllocation.setMemoryAddress( + JNIService.transferVectors( + trainingDataAllocation.getMemoryAddress(), + floats.stream().map(v -> ArrayUtils.toPrimitive((Float[]) v)).toArray(float[][]::new) + ) + ); + } } @Override @@ -64,4 +87,29 @@ public void processTrainingVectors(SearchResponse searchResponse, int vectorsToA accept(vectors); } + + private List quantizeVectors(List vectors) throws IOException { + List bytes = new ArrayList<>(); + ScalarQuantizationParams quantizationParams = new ScalarQuantizationParams(quantizationConfig.getQuantizationType()); + Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); + // Create training request + TrainingRequest trainingRequest = new TrainingRequest(vectors.size()) { + @Override + public float[] getVectorAtThePosition(int position) { + return ArrayUtils.toPrimitive((Float[]) vectors.get(position)); + } + }; + QuantizationState quantizationState = quantizer.train(trainingRequest); + BinaryQuantizationOutput binaryQuantizationOutput = new BinaryQuantizationOutput(quantizationConfig.getQuantizationType().getId()); + for (int i = 0; i < vectors.size(); i++) { + quantizer.quantize(ArrayUtils.toPrimitive((Float[]) vectors.get(i)), quantizationState, binaryQuantizationOutput); + bytes.add(binaryQuantizationOutput.getQuantizedVectorCopy()); + } + + return bytes; + } + + private boolean isValidFloatsAndQuantizationConfig(List floats) { + return floats != null && floats.isEmpty() == false && quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY; + } } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java index 385572cb4..1720da1ed 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java @@ -14,6 +14,7 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -124,7 +125,8 @@ public void testTrainingDataEntryContext_load() { null, 0, 0, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + QuantizationConfig.EMPTY ); NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( @@ -149,7 +151,8 @@ public void testTrainingDataEntryContext_getTrainIndexName() { null, 0, 0, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + QuantizationConfig.EMPTY ); assertEquals(trainIndexName, trainingDataEntryContext.getTrainIndexName()); @@ -165,7 +168,8 @@ public void testTrainingDataEntryContext_getTrainFieldName() { null, 0, 0, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + QuantizationConfig.EMPTY ); assertEquals(trainFieldName, trainingDataEntryContext.getTrainFieldName()); @@ -181,7 +185,8 @@ public void testTrainingDataEntryContext_getMaxVectorCount() { null, maxVectorCount, 0, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + QuantizationConfig.EMPTY ); assertEquals(maxVectorCount, trainingDataEntryContext.getMaxVectorCount()); @@ -197,7 +202,8 @@ public void testTrainingDataEntryContext_getSearchSize() { null, 0, searchSize, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + QuantizationConfig.EMPTY ); assertEquals(searchSize, trainingDataEntryContext.getSearchSize()); @@ -213,7 +219,8 @@ public void testTrainingDataEntryContext_getIndicesService() { clusterService, 0, 0, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + QuantizationConfig.EMPTY ); assertEquals(clusterService, trainingDataEntryContext.getClusterService()); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 29fbdb978..bdd8d7e45 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -18,6 +18,7 @@ import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQueryResult; @@ -180,7 +181,8 @@ public void testTrainingLoadStrategy_load() { null, 0, 0, - VectorDataType.FLOAT + VectorDataType.FLOAT, + QuantizationConfig.EMPTY ); // Load the allocation. Initially, the memory address should be 0. However, after the readlock is obtained, diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index ea9203196..59c435e2c 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -8,7 +8,6 @@ import lombok.SneakyThrows; import org.apache.http.util.EntityUtils; import org.junit.Assert; -import org.junit.Ignore; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -252,12 +251,14 @@ public void testTraining_whenInvalid_thenFail() { // Training isnt currently supported for mode and compression because quantization framework does not quantize // the training vectors. So, commenting out for now. - @Ignore @SneakyThrows public void testTraining_whenValid_thenSucceed() { setupTrainingIndex(); XContentBuilder builder; for (String compressionLevel : CompressionLevel.NAMES_ARRAY) { + if (compressionLevel.equals("4x")) { + continue; + } String indexName = INDEX_NAME + compressionLevel; String modelId = indexName; builder = XContentFactory.jsonBuilder() @@ -287,38 +288,13 @@ public void testTraining_whenValid_thenSucceed() { compressionLevel, Mode.NOT_CONFIGURED.getName() ); + deleteKNNIndex(indexName); } - - for (String compressionLevel : CompressionLevel.NAMES_ARRAY) { - for (String mode : Mode.NAMES_ARRAY) { - String indexName = INDEX_NAME + compressionLevel + "_" + mode; - String modelId = indexName; - builder = XContentFactory.jsonBuilder() - .startObject() - .field(TRAIN_INDEX_PARAMETER, TRAINING_INDEX_NAME) - .field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME) - .field(KNNConstants.DIMENSION, DIMENSION) - .field(MODEL_DESCRIPTION, "") - .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) - .field(MODE_PARAMETER, mode) - .endObject(); - validateTraining(modelId, builder); - builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", "knn_vector") - .field("model_id", modelId) - .endObject() - .endObject() - .endObject(); - String mapping = builder.toString(); - validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NLIST_DEFAULT, compressionLevel, mode); - } - } - for (String mode : Mode.NAMES_ARRAY) { + if (mode == null) { + continue; + } + mode = mode.toLowerCase(); String indexName = INDEX_NAME + mode; String modelId = indexName; builder = XContentFactory.jsonBuilder() @@ -348,8 +324,8 @@ public void testTraining_whenValid_thenSucceed() { CompressionLevel.NOT_CONFIGURED.getName(), mode ); + deleteKNNIndex(indexName); } - } @SneakyThrows @@ -459,6 +435,7 @@ private void validateSearch( String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity()); List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); assertEquals(NUM_DOCS, exactSearchKnnResults.size()); + if (CompressionLevel.x4.getName().equals(compressionLevelString) == false && Mode.ON_DISK.getName().equals(mode)) { Assert.assertEquals(exactSearchKnnResults, knnResults); } diff --git a/src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java b/src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java index 27e02b46b..6e0410853 100644 --- a/src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java +++ b/src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java @@ -13,7 +13,9 @@ import org.mockito.ArgumentCaptor; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.util.ArrayList; import java.util.Arrays; @@ -29,12 +31,46 @@ public void testAccept() { // Mock the training data allocation int dimension = 128; - NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = mock(NativeMemoryAllocation.TrainingDataAllocation.class); // new - // NativeMemoryAllocation.TrainingDataAllocation(0, - // numVectors*dimension* - // Float.BYTES); + NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = mock(NativeMemoryAllocation.TrainingDataAllocation.class); + when(trainingDataAllocation.getMemoryAddress()).thenReturn(0L); + when(trainingDataAllocation.getQuantizationConfig()).thenReturn(QuantizationConfig.EMPTY); + + // Capture argument passed to set pointer + ArgumentCaptor valueCapture = ArgumentCaptor.forClass(Long.class); + + FloatTrainingDataConsumer floatTrainingDataConsumer = new FloatTrainingDataConsumer(trainingDataAllocation); + + List vectorSet1 = new ArrayList<>(3); + for (int i = 0; i < 3; i++) { + Float[] vector = new Float[dimension]; + Arrays.fill(vector, (float) i); + vectorSet1.add(vector); + } + + // Transfer vectors + floatTrainingDataConsumer.accept(vectorSet1); + + // Ensure that the pointer captured has been updated + verify(trainingDataAllocation).setMemoryAddress(valueCapture.capture()); + when(trainingDataAllocation.getMemoryAddress()).thenReturn(valueCapture.getValue()); + + assertNotEquals(0, trainingDataAllocation.getMemoryAddress()); + } + + public void testAccept_withQuantizationConfig() { + + // Mock the training data allocation + int dimension = 128; + NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = mock(NativeMemoryAllocation.TrainingDataAllocation.class); + + when(trainingDataAllocation.getMemoryAddress()).thenReturn(0L); + + QuantizationConfig quantizationConfig = mock(QuantizationConfig.class); + when(quantizationConfig.getQuantizationType()).thenReturn(ScalarQuantizationType.ONE_BIT); + when(trainingDataAllocation.getQuantizationConfig()).thenReturn(QuantizationConfig.EMPTY); + // Capture argument passed to set pointer ArgumentCaptor valueCapture = ArgumentCaptor.forClass(Long.class);