Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add binary index support for Lucene engine #2292

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public float compare(byte[] v1, byte[] v2) {

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
// For binary vectors using Lucene engine we instead implement a custom BinaryVectorScorer
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
}
};
Expand Down
22 changes: 12 additions & 10 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

/**
* Enum contains data_type of vectors
* Lucene supports byte and float data type
* Lucene supports binary, byte and float data type
* NMSLib supports only float data type
* Faiss supports binary and float data type
*/
Expand All @@ -39,8 +39,10 @@ public enum VectorDataType {
BINARY("binary") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
throw new IllegalStateException("Unsupported method");
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
// For binary vectors using Lucene engine we instead implement a custom BinaryVectorScorer so the VectorSimilarityFunction will
// not be used.
return KnnByteVectorField.createFieldType(dimension / Byte.SIZE, VectorSimilarityFunction.EUCLIDEAN);
}

@Override
Expand Down Expand Up @@ -68,8 +70,8 @@ public void freeNativeMemory(long memoryAddress) {
BYTE("byte") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction());
}

@Override
Expand Down Expand Up @@ -97,8 +99,8 @@ public void freeNativeMemory(long memoryAddress) {
FLOAT("float") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction());
}

@Override
Expand Down Expand Up @@ -129,11 +131,11 @@ public void freeNativeMemory(long memoryAddress) {
* Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and
* VectorSimilarityFunction.
*
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @param dimension Dimension of the vector
* @param knnVectorSimilarityFunction KNNVectorSimilarityFunction for a given spaceType
* @return FieldType
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);
public abstract FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction);

/**
* Deserializes float vector from BytesRef.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
}
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth,
knnMethodContext.getSpaceType()
);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.opensearch.knn.index.KNNVectorSimilarityFunction;

import java.io.IOException;

/**
* A FlatVectorsScorer to be used for scoring binary vectors. Meant to be used with {@link KNN9120BinaryVectorScorer}
*/
public class KNN9120BinaryVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] queryVector
) throws IOException {
throw new IllegalArgumentException("binary vectors do not support float[] targets");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] queryVector
) throws IOException {
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

static class BinaryRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final byte[] queryVector;

BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
this.queryVector = query;
this.vectorValues = vectorValues;
}

@Override
public float score(int node) throws IOException {
return KNNVectorSimilarityFunction.HAMMING.compare(queryVector, vectorValues.vectorValue(node));
}

@Override
public int maxOrd() {
return vectorValues.size();
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return vectorValues.getAcceptOrds(acceptDocs);
}
}

static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;

public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] queryVector = vectorValues1.vectorValue(ord);
return new BinaryRandomVectorScorer(vectorValues2, queryVector);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BinaryRandomVectorScorerSupplier(vectorValues.copy());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import static org.opensearch.knn.index.engine.KNNEngine.getMaxDimensionByEngine;

/**
* Custom KnnVectorsFormat implementation to support binary vectors. This class is mostly identical to
* {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat}, however we use the custom {@link KNN9120BinaryVectorScorer}
* to perform hamming bit scoring.
*/
public final class KNN9120HnswBinaryVectorsFormat extends KnnVectorsFormat {

private final int maxConn;
private final int beamWidth;
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN9120BinaryVectorScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

private static final String NAME = "KNN990HnswBinaryVectorsFormat";

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat()}
*/
public KNN9120HnswBinaryVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
}

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int)}
*/
public KNN9120HnswBinaryVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, 1, null);
}

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int, int, java.util.concurrent.ExecutorService)}
*/
public KNN9120HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
this.maxConn,
this.beamWidth,
flatVectorsFormat.fieldsWriter(state),
this.numMergeWorkers,
this.mergeExec
);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}

@Override
public int getMaxDimensions(String fieldName) {
return getMaxDimensionByEngine(KNNEngine.LUCENE);
}

@Override
public String toString() {
return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn="
+ this.maxConn
+ ", beamWidth="
+ this.beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
private static final int NUM_MERGE_WORKERS = 1;

public KNN9120PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> {
// There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that
// changes in the future.
if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) {
return new KNN9120HnswBinaryVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
);
} else {
return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth());
}
},
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
NUM_MERGE_WORKERS,
knnScalarQuantizedVectorsFormatParams.getBits(),
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
null
)
);
}

@Override
/**
* This method returns the maximum dimension allowed from KNNEngine for Lucene codec
*
* @param fieldName Name of the field, ignored
* @return Maximum constant dimension set by KNNEngine
*/
public int getMaxDimensions(String fieldName) {
return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
Expand Down
Loading
Loading