Skip to content

Commit

Permalink
Integrate Lucene Vector field with native engines to use KNNVectorFor…
Browse files Browse the repository at this point in the history
…mat during segment creation (#1945)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Aug 12, 2024
1 parent 2cd57e8 commit 5a5351f
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 41 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Add script_fields context to KNNAllowlist [#1917](https://github.com/opensearch-project/k-NN/pull/1917)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
### Infrastructure
### Documentation
### Maintenance
Expand All @@ -32,4 +33,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931)
* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877)
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
33 changes: 32 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ public class KNNSettings {
public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit";
public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold";
public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled";
/**
 * Key of the temporary setting that gates the use of KNNVectorsFormat for native engines.
 * <p>
 * TODO: This setting exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 */
public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled";

/**
* Default setting values
Expand Down Expand Up @@ -255,6 +261,17 @@ public class KNNSettings {
NodeScope
);

/**
 * Node-scoped boolean setting registered under {@link #KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED}; defaults to
 * {@code false}.
 * <p>
 * TODO: This setting exists only so that the main branch of the k-NN plugin does not break while the remaining
 * parts of the code are being completed. It will be removed once all changes related to the integration of
 * KNNVectorsFormat for native engines have been added.
 */
public static final Setting<Boolean> KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting(
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED,
false,
NodeScope
);

/**
* Dynamic settings
*/
Expand Down Expand Up @@ -379,6 +396,10 @@ private Setting<?> getSetting(String key) {
return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING;
}

if (KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED.equals(key)) {
return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -397,7 +418,8 @@ public List<Setting<?>> getSettings() {
MODEL_CACHE_SIZE_LIMIT_SETTING,
ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING,
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING,
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down Expand Up @@ -443,6 +465,15 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

/**
 * Looks up the current value of the {@code knn.use.format.enabled} setting.
 * <p>
 * TODO: The backing setting exists only so that the main branch of the k-NN plugin does not break while the
 * remaining parts of the code are being completed. It will be removed once all changes related to the
 * integration of KNNVectorsFormat for native engines have been added.
 *
 * @return the value currently held for {@link #KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED}
 */
public static boolean getIsLuceneVectorFormatEnabled() {
    final boolean luceneVectorFormatEnabled = state().getSettingValue(KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED);
    return luceneVectorFormatEnabled;
}

public void initialize(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.mapper;

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -57,8 +58,11 @@ private FlatVectorFieldMapper(
Version indexCreatedVersion
) {
super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null);
// setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created.
this.useLuceneBasedVectorField = false;
this.perDimensionValidator = selectPerDimensionValidator(vectorDataType);
this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
this.fieldType.setDocValuesType(DocValuesType.BINARY);
this.fieldType.freeze();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexOptions;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
Expand Down Expand Up @@ -456,6 +457,7 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
protected boolean hasDocValues;
protected VectorDataType vectorDataType;
protected ModelDao modelDao;
protected boolean useLuceneBasedVectorField;

// We need to ensure that the original KNNMethodContext as parsed is stored to initialize the
// Builder for serialization. So, we need to store it here. This is mainly to ensure that the legacy field mapper
Expand Down Expand Up @@ -497,16 +499,29 @@ protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getKnnMappingConfig().getDimension(), fieldType().getVectorDataType());
}

// Builds the indexable field for a float vector: the plugin's own VectorField by default, or Lucene's
// KnnFloatVectorField when this mapper is flagged to use Lucene-based vector fields.
private Field createVectorField(float[] vectorValue) {
    if (!useLuceneBasedVectorField) {
        return new VectorField(name(), vectorValue, fieldType);
    }
    return new KnnFloatVectorField(name(), vectorValue, fieldType);
}

// Builds the indexable field for a byte vector: the plugin's own VectorField by default, or Lucene's
// KnnByteVectorField when this mapper is flagged to use Lucene-based vector fields.
private Field createVectorField(byte[] vectorValue) {
    if (!useLuceneBasedVectorField) {
        return new VectorField(name(), vectorValue, fieldType);
    }
    return new KnnByteVectorField(name(), vectorValue, fieldType);
}

/**
* Function returns a list of fields to be indexed when the vector is float type.
*
* @param array array of floats
* @param fieldType {@link FieldType}
* @return {@link List} of {@link Field}
*/
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
protected List<Field> getFieldsForFloatVector(final float[] array) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
fields.add(createVectorField(array));
if (this.stored) {
fields.add(createStoredFieldForFloatVector(name(), array));
}
Expand All @@ -517,12 +532,11 @@ protected List<Field> getFieldsForFloatVector(final float[] array, final FieldTy
* Function returns a list of fields to be indexed when the vector is byte type.
*
* @param array array of bytes
* @param fieldType {@link FieldType}
* @return {@link List} of {@link Field}
*/
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
protected List<Field> getFieldsForByteVector(final byte[] array) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
fields.add(createVectorField(array));
if (this.stored) {
fields.add(createStoredFieldForByteVector(name(), array));
}
Expand Down Expand Up @@ -561,24 +575,14 @@ protected void validatePreparse() {
protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
validatePreparse();

if (VectorDataType.BINARY == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType);

if (bytesArrayOptional.isEmpty()) {
return;
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForByteVector(array, fieldType));
} else if (VectorDataType.BYTE == vectorDataType) {
if (VectorDataType.BINARY == vectorDataType || VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType);

if (bytesArrayOptional.isEmpty()) {
return;
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForByteVector(array, fieldType));
context.doc().addAll(getFieldsForByteVector(array));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

Expand All @@ -587,7 +591,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForFloatVector(array, fieldType));
context.doc().addAll(getFieldsForFloatVector(array));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down Expand Up @@ -714,7 +718,6 @@ public static class Defaults {
static {
FIELD_TYPE.setTokenized(false);
FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
FIELD_TYPE.setDocValuesType(DocValuesType.BINARY);
FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type
FIELD_TYPE.freeze();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,22 @@ static void validateIfKNNPluginEnabled() {
}
}

/**
 * Decides whether a vector field should be written using the Lucene KNNVectorsFormat.
 * <p>
 * Prerequisite: the index must be a k-NN index (as indicated by the index.knn index setting). This function
 * does not check that itself and assumes the caller has already validated it.
 * <p>
 * The format is used only when both of the following hold:
 * <ol>
 * <li>The index was created with an OpenSearch version &gt;= 2.17.</li>
 * <li>The cluster setting that enables the Lucene KNNVectors format is turned on. This condition is temporary
 * and will be removed before release.</li>
 * </ol>
 *
 * @param indexCreatedVersion {@link Version} the index was created with
 * @return true if the vector field should use KNNVectorsFormat
 */
static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) {
    // Older indices never use the format; the setting is consulted only for 2.17+ indices.
    if (!indexCreatedVersion.onOrAfter(Version.V_2_17_0)) {
        return false;
    }
    return KNNSettings.getIsLuceneVectorFormatEnabled();
}

private static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) {
String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey());
if (spaceType == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
Expand Down Expand Up @@ -112,9 +112,9 @@ private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final Create
}

@Override
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
protected List<Field> getFieldsForFloatVector(final float[] array) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType));
fieldsToBeAdded.add(new KnnFloatVectorField(name(), array, fieldType));

if (hasDocValues && vectorFieldType != null) {
fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
Expand All @@ -127,7 +127,7 @@ protected List<Field> getFieldsForFloatVector(final float[] array, final FieldTy
}

@Override
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
protected List<Field> getFieldsForByteVector(final byte[] array) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
package org.opensearch.knn.index.mapper;

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.VectorEncoding;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand Down Expand Up @@ -99,6 +102,7 @@ private MethodFieldMapper(
indexVerision,
originalKNNMethodContext
);
this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion);
KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig();
KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext()
.orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));
Expand All @@ -118,6 +122,22 @@ private MethodFieldMapper(
throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe));
}

if (useLuceneBasedVectorField) {
int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY
? annConfig.getDimension() / 8
: annConfig.getDimension();
final VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT
? VectorEncoding.FLOAT32
: VectorEncoding.BYTE;
fieldType.setVectorAttributes(
adjustedDimension,
encoding,
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
);
} else {
fieldType.setDocValuesType(DocValuesType.BINARY);
}

this.fieldType.freeze();
initValidatorsAndProcessors(knnMethodContext);
knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
package org.opensearch.knn.index.mapper;

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.VectorEncoding;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.indices.ModelDao;
Expand Down Expand Up @@ -102,7 +105,7 @@ private ModelFieldMapper(

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
this.fieldType.putAttribute(MODEL_ID, modelId);
this.fieldType.freeze();
this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion);
}

@Override
Expand Down Expand Up @@ -193,6 +196,21 @@ private void initPerDimensionProcessor() {
protected void parseCreateField(ParseContext context) throws IOException {
    // Purpose: validate the mapper state, resolve the model's metadata, configure the Lucene field type
    // from that metadata, and then delegate to the shared parsing routine with the model's dimension and
    // vector data type.
    validatePreparse();
    ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
    if (useLuceneBasedVectorField) {
        // Only BINARY vectors have their dimension divided by 8; every other data type keeps the model's
        // dimension unchanged. This matches the adjustment MethodFieldMapper applies in its constructor.
        // (The previous code had the two branches swapped.)
        final boolean isBinary = modelMetadata.getVectorDataType() == VectorDataType.BINARY;
        final int adjustedDimension = isBinary ? modelMetadata.getDimension() / 8 : modelMetadata.getDimension();
        // FLOAT vectors are encoded as FLOAT32; all remaining data types use the BYTE encoding.
        final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT
            ? VectorEncoding.FLOAT32
            : VectorEncoding.BYTE;
        fieldType.setVectorAttributes(
            adjustedDimension,
            encoding,
            SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
        );
    } else {
        // Without the Lucene vector format the vector is stored as binary doc values.
        fieldType.setDocValuesType(DocValuesType.BINARY);
    }
    parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType());
}

Expand Down
Loading

0 comments on commit 5a5351f

Please sign in to comment.