diff --git a/CHANGELOG.md b/CHANGELOG.md index 44f387533a..f7b8ed8ebc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.16...2.x) ### Features +* Integrate Lucene Vector field with native engines, to use KNNVectorFormat during segment creation. [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Enhancements ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) @@ -31,4 +32,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) -* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) \ No newline at end of file +* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 33c7ff410b..4ced38b38e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -82,6 +82,12 @@ public class KNNSettings { public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit"; public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold"; public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled"; + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled"; /** * Default setting values @@ -255,6 +261,17 @@ public class KNNSettings { NodeScope ); + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final Setting KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting( + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, + false, + NodeScope + ); + /** * Dynamic settings */ @@ -379,6 +396,10 @@ private Setting getSetting(String key) { return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; } + if (KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED.equals(key)) { + return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -397,7 +418,8 @@ public List> getSettings() { MODEL_CACHE_SIZE_LIMIT_SETTING, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, KNN_FAISS_AVX2_DISABLED_SETTING, - KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING + KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -443,6 +465,15 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static boolean getIsLuceneVectorFormatEnabled() { + return KNNSettings.state().getSettingValue(KNNSettings.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED); + } + public void initialize(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 3b94876454..2ce6a2a067 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -17,7 +17,8 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.IndexOptions; import org.opensearch.Version; import org.opensearch.common.Explicit; @@ -172,10 +173,13 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected Version indexCreatedVersion; - public Builder(String name, ModelDao modelDao, Version indexCreatedVersion) { + protected boolean isIndexKNN; + + public Builder(String name, ModelDao modelDao, Version indexCreatedVersion, boolean isIndexKNN) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; + this.isIndexKNN = isIndexKNN; } /** @@ -187,12 +191,13 @@ public Builder(String name, ModelDao modelDao, Version indexCreatedVersion) { * @param m m value of field * @param efConstruction efConstruction value of field */ - public Builder(String name, String spaceType, String m, String efConstruction, Version indexCreatedVersion) { + public Builder(String name, String spaceType, String m, String efConstruction, Version indexCreatedVersion, boolean isIndexKNN) { super(name); this.spaceType = spaceType; this.m = m; this.efConstruction = efConstruction; this.indexCreatedVersion = indexCreatedVersion; + this.isIndexKNN = isIndexKNN; } @Override @@ -253,6 +258,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { .hasDocValues(hasDocValues.get()) .vectorDataType(vectorDataType.getValue()) .knnMethodContext(knnMethodContext) + .isIndexKNN(isIndexKNN) .build(); return new LuceneFieldMapper(createLuceneFieldMapperInput); } @@ -265,7 +271,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - knnMethodContext + knnMethodContext, + isIndexKNN ); } @@ -286,7 +293,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { hasDocValues.get(), modelDao, modelIdAsString, - indexCreatedVersion + indexCreatedVersion, + isIndexKNN ); } @@ -325,7 +333,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { spaceType, m, efConstruction, - indexCreatedVersion + indexCreatedVersion, + isIndexKNN ); } @@ -430,7 +439,12 @@ public TypeParser(Supplier modelDaoSupplier) { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated()); + Builder builder = new KNNVectorFieldMapper.Builder( + name, + modelDaoSupplier.get(), + parserContext.indexVersionCreated(), + parserContext.getSettings().getAsBoolean(KNN_INDEX, false) + ); builder.parse(name, parserContext, node); // All parse(String name, Map node, ParserCont // subclass (if it is unique). protected KNNMethodContext knnMethod; protected String modelId; + protected boolean isIndexKNN; public KNNVectorFieldMapper( String simpleName, @@ -473,9 +488,11 @@ public KNNVectorFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - Version indexCreatedVersion + Version indexCreatedVersion, + boolean isIndexKNN ) { super(simpleName, mappedFieldType, multiFields, copyTo); + this.isIndexKNN = isIndexKNN; this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; @@ -516,12 +533,11 @@ private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMet * Function returns a list of fields to be indexed when the vector is float type. * * @param array array of floats - * @param fieldType {@link FieldType} * @return {@link List} of {@link Field} */ - protected List getFieldsForFloatVector(final float[] array, final FieldType fieldType) { + protected List getFieldsForFloatVector(final float[] array) { final List fields = new ArrayList<>(); - fields.add(new VectorField(name(), array, fieldType)); + fields.add(createVectorField(array)); if (this.stored) { fields.add(createStoredFieldForFloatVector(name(), array)); } @@ -532,18 +548,31 @@ protected List getFieldsForFloatVector(final float[] array, final FieldTy * Function returns a list of fields to be indexed when the vector is byte type. * * @param array array of bytes - * @param fieldType {@link FieldType} * @return {@link List} of {@link Field} */ - protected List getFieldsForByteVector(final byte[] array, final FieldType fieldType) { + protected List getFieldsForByteVector(final byte[] byteArray) { final List fields = new ArrayList<>(); - fields.add(new VectorField(name(), array, fieldType)); + fields.add(createVectorField(byteArray)); if (this.stored) { - fields.add(createStoredFieldForByteVector(name(), array)); + fields.add(createStoredFieldForByteVector(name(), byteArray)); } return fields; } + private Field createVectorField(float[] vectorValue) { + if (KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion, isIndexKNN)) { + return new KnnFloatVectorField(name(), vectorValue, fieldType); + } + return new VectorField(name(), vectorValue, fieldType); + } + + private Field createVectorField(byte[] vectorValue) { + if (KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion, isIndexKNN)) { + return new KnnByteVectorField(name(), vectorValue, fieldType); + } + return new VectorField(name(), vectorValue, fieldType); + } + protected void parseCreateField( ParseContext context, int dimension, @@ -564,7 +593,7 @@ protected void parseCreateField( } final byte[] array = bytesArrayOptional.get(); spaceType.validateVector(array); - context.doc().addAll(getFieldsForByteVector(array, fieldType)); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -573,7 +602,7 @@ protected void parseCreateField( } final byte[] array = bytesArrayOptional.get(); spaceType.validateVector(array); - context.doc().addAll(getFieldsForByteVector(array, fieldType)); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); @@ -582,7 +611,7 @@ protected void parseCreateField( } final float[] array = floatsArrayOptional.get(); spaceType.validateVector(array); - context.doc().addAll(getFieldsForFloatVector(array, fieldType)); + context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) @@ -746,7 +775,7 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this); + return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion, isIndexKNN).init(this); } @Override @@ -783,7 +812,6 @@ public static class Defaults { static { FIELD_TYPE.setTokenized(false); FIELD_TYPE.setIndexOptions(IndexOptions.NONE); - FIELD_TYPE.setDocValuesType(DocValuesType.BINARY); FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type FIELD_TYPE.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 2adbbb6953..099d5d8ac1 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -17,7 +17,9 @@ import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.util.BytesRef; +import org.opensearch.Version; import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; @@ -245,6 +247,22 @@ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFiel return VectorDataType.BINARY == knnVectorFieldType.getVectorDataType() ? expectedDimensions / 8 : expectedDimensions; } + /** + * We will use LuceneKNNVectorsFormat when these below condition satisfy: + *
    + *
  1. Index is created with Version of opensearch >= 2.17
  2. + *
  3. index.knn setting is marked as true
  4. + *
  5. Cluster setting is enabled to use Lucene KNNVectors format. This condition is temporary condition and will be + * removed before release.
  6. + *
+ * @param indexCreatedVersion {@link Version} + * @param isIndexKNN boolean + * @return true if vector field should use KNNVectorsFormat + */ + static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion, final boolean isIndexKNN) { + return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && isIndexKNN == true && KNNSettings.getIsLuceneVectorFormatEnabled(); + } + private static boolean isModelBasedIndex(int expectedDimensions) { return expectedDimensions == -1; } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java index cf5ec933a1..742dd39334 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java @@ -7,11 +7,14 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.IndexHyperParametersUtil; import org.opensearch.knn.index.engine.KNNEngine; @@ -51,10 +54,10 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { String spaceType, String m, String efConstruction, - Version indexCreatedVersion + Version indexCreatedVersion, + boolean isIndexKNN ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, isIndexKNN); this.spaceType = spaceType; this.m = m; this.efConstruction = efConstruction; @@ -69,14 +72,31 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { this.fieldType.putAttribute(HNSW_ALGO_M, m); this.fieldType.putAttribute(HNSW_ALGO_EF_CONSTRUCTION, efConstruction); + if (KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion, isIndexKNN)) { + int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY ? dimension : dimension / 8; + VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT ? VectorEncoding.FLOAT32 : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } + this.fieldType.freeze(); } @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction, this.indexCreatedVersion).init( - this - ); + return new KNNVectorFieldMapper.Builder( + simpleName(), + this.spaceType, + this.m, + this.efConstruction, + this.indexCreatedVersion, + this.isIndexKNN + ).init(this); } static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index c82afb9e72..ff83dd56d4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -14,7 +14,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.common.Explicit; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -44,7 +44,8 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getKnnMethodContext().getMethodComponentContext().getIndexVersion() + input.getKnnMethodContext().getMethodComponentContext().getIndexVersion(), + input.isIndexKNN() ); vectorDataType = input.getVectorDataType(); @@ -76,9 +77,9 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { } @Override - protected List getFieldsForFloatVector(final float[] array, final FieldType fieldType) { + protected List getFieldsForFloatVector(final float[] array) { final List fieldsToBeAdded = new ArrayList<>(); - fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType)); + fieldsToBeAdded.add(new KnnFloatVectorField(name(), array, fieldType)); if (hasDocValues && vectorFieldType != null) { fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType)); @@ -91,7 +92,7 @@ protected List getFieldsForFloatVector(final float[] array, final FieldTy } @Override - protected List getFieldsForByteVector(final byte[] array, final FieldType fieldType) { + protected List getFieldsForByteVector(final byte[] array) { final List fieldsToBeAdded = new ArrayList<>(); fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType)); @@ -129,5 +130,6 @@ static class CreateLuceneFieldMapperInput { VectorDataType vectorDataType; @NonNull KNNMethodContext knnMethodContext; + boolean isIndexKNN; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index b15ab14894..787ce2fc73 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -6,8 +6,12 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.KNNEngine; @@ -33,7 +37,8 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - KNNMethodContext knnMethodContext + KNNMethodContext knnMethodContext, + boolean isIndexKNN ) { super( @@ -44,7 +49,8 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { ignoreMalformed, stored, hasDocValues, - knnMethodContext.getMethodComponentContext().getIndexVersion() + knnMethodContext.getMethodComponentContext().getIndexVersion(), + isIndexKNN ); this.knnMethod = knnMethodContext; @@ -62,9 +68,19 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { Map libParams = knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString()); } catch (IOException ioe) { - throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); + throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe), ioe); + } + if (KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion, isIndexKNN)) { + int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY ? dimension / 8 : dimension; + VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT ? VectorEncoding.FLOAT32 : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); } - this.fieldType.freeze(); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index adaaef28e6..8d3030e20b 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -6,9 +6,13 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -32,16 +36,16 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { boolean hasDocValues, ModelDao modelDao, String modelId, - Version indexCreatedVersion + Version indexCreatedVersion, + boolean isIndexKNN ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, isIndexKNN); this.modelId = modelId; this.modelDao = modelDao; this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); - this.fieldType.freeze(); } @Override @@ -61,7 +65,20 @@ protected void parseCreateField(ParseContext context) throws IOException { ) ); } + this.dimension = modelMetadata.getDimension(); + this.vectorDataType = modelMetadata.getVectorDataType(); + if (KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion, isIndexKNN)) { + int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY ? dimension : dimension / 8; + final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT ? VectorEncoding.FLOAT32 : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } parseCreateField( context, modelMetadata.getDimension(), diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index a0b9b32d0e..20e6e6cbae 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Query; @@ -84,11 +86,10 @@ * Test used for testing Codecs */ public class KNNCodecTestCase extends KNNTestCase { - - private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); private static FieldType sampleFieldType; static { sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + sampleFieldType.setDocValuesType(DocValuesType.BINARY); sampleFieldType.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); @@ -240,6 +241,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio iwc.setCodec(codec); FieldType fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + fieldType.setDocValuesType(DocValuesType.BINARY); fieldType.putAttribute(KNNConstants.MODEL_ID, modelId); fieldType.freeze(); @@ -326,9 +328,9 @@ public void testKnnVectorIndex( /** * Add doc with field "test_vector_one" */ - final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); + final FieldType luceneFieldType = KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); float[] array = { 1.0f, 3.0f, 4.0f }; - KnnVectorField vectorField = new KnnVectorField(FIELD_NAME_ONE, array, luceneFieldType); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME_ONE, array, luceneFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); doc.add(vectorField); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index c95568be22..9354a679cd 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -8,10 +8,11 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; +import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.Explicit; @@ -103,7 +104,7 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, true); assertEquals(7, builder.getParameters().size()); List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); @@ -114,7 +115,7 @@ public void testBuilder_getParameters() { public void testBuilder_build_fromKnnMethodContext() { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -151,7 +152,7 @@ public void testBuilder_build_fromKnnMethodContext() { public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -191,7 +192,7 @@ public void testBuilder_build_fromModel() { public void testBuilder_build_fromLegacy() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); int m = 17; int efConstruction = 17; @@ -214,7 +215,7 @@ public void testBuilder_build_fromLegacy() { public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -742,6 +743,129 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } + @SneakyThrows + public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() { + MockedStatic utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class); + for (VectorDataType dataType : VectorDataType.values()) { + final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + methodComponentContext.setIndexVersion(CURRENT); + SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT; + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext); + + final KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, + knnMethodContext, + dataType + ); + + ParseContext.Document document = new ParseContext.Document(); + ContentPath contentPath = new ContentPath(); + ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any(), Mockito.anyBoolean())) + .thenReturn(true); + MethodFieldMapper methodFieldMapper = Mockito.spy( + new MethodFieldMapper( + TEST_FIELD_NAME, + knnVectorFieldType, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + knnMethodContext, + true + ) + ); + + if (dataType == VectorDataType.BINARY) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper) + .getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType); + } else if (dataType == VectorDataType.BYTE) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType); + } else { + doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper) + .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); + } + + doNothing().when(methodFieldMapper).validateIfCircuitBreakerIsNotTriggered(); + doNothing().when(methodFieldMapper).validateIfKNNPluginEnabled(); + + methodFieldMapper.parseCreateField( + parseContext, + dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, + methodFieldMapper.fieldType().spaceType, + methodFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + dataType + ); + + List fields = document.getFields(); + assertEquals(1, fields.size()); + IndexableField field1 = fields.get(0); + if (dataType == VectorDataType.FLOAT) { + assertTrue(field1 instanceof KnnFloatVectorField); + assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32); + } else { + assertTrue(field1 instanceof KnnByteVectorField); + assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE); + } + + assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION); + assertEquals( + field1.fieldType().vectorSimilarityFunction(), + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any(), Mockito.anyBoolean())) + .thenReturn(false); + + document = new ParseContext.Document(); + contentPath = new ContentPath(); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + methodFieldMapper = Mockito.spy( + new MethodFieldMapper( + TEST_FIELD_NAME, + knnVectorFieldType, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + knnMethodContext, + true + ) + ); + + if (dataType == VectorDataType.FLOAT) { + doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper) + .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); + } else { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper) + .getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, dataType); + } + + methodFieldMapper.parseCreateField( + parseContext, + dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, + methodFieldMapper.fieldType().spaceType, + methodFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + dataType + ); + fields = document.getFields(); + assertEquals(1, fields.size()); + field1 = fields.get(0); + assertTrue(field1 instanceof VectorField); + } + // making sure to close the static mock to ensure that for tests running on this thread are not impacted by + // this mocking + utilMockedStatic.close(); + } + @SneakyThrows public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField @@ -774,15 +898,15 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { IndexableField field2 = fields.get(1); VectorField vectorField; - KnnVectorField knnVectorField; + KnnFloatVectorField knnVectorField; if (field1 instanceof VectorField) { - assertTrue(field2 instanceof KnnVectorField); + assertTrue(field2 instanceof KnnFloatVectorField); vectorField = (VectorField) field1; - knnVectorField = (KnnVectorField) field2; + knnVectorField = (KnnFloatVectorField) field2; } else { - assertTrue(field1 instanceof KnnVectorField); + assertTrue(field1 instanceof KnnFloatVectorField); assertTrue(field2 instanceof VectorField); - knnVectorField = (KnnVectorField) field1; + knnVectorField = (KnnFloatVectorField) field1; vectorField = (VectorField) field2; } @@ -816,8 +940,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { fields = document.getFields(); assertEquals(1, fields.size()); IndexableField field = fields.get(0); - assertTrue(field instanceof KnnVectorField); - knnVectorField = (KnnVectorField) field; + assertTrue(field instanceof KnnFloatVectorField); + knnVectorField = (KnnFloatVectorField) field; assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); } @@ -970,7 +1094,7 @@ private void testBuilderWithBinaryDataType( String expectedErrMsg ) { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); @@ -996,7 +1120,7 @@ private void testBuilderWithBinaryDataType( public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); @@ -1022,7 +1146,7 @@ public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1037,7 +1161,7 @@ public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, true); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 31da12d669..37139f3195 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -13,8 +13,13 @@ import org.apache.lucene.document.StoredField; import org.apache.lucene.util.BytesRef; +import org.junit.Assert; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -148,6 +153,26 @@ public void testValidateVectorDataType_whenFloat_thenValid() { validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); } + public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { + final KNNSettings knnSettings = mock(KNNSettings.class); + final MockedStatic mockedStatic = Mockito.mockStatic(KNNSettings.class); + mockedStatic.when(KNNSettings::state).thenReturn(knnSettings); + + mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(false); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_16_0, true)); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0, true)); + + mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(true); + Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0, true)); + Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0, true)); + + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0, false)); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0, false)); + // making sure to close the static mock to ensure that for tests running on this thread are not impacted by + // this mocking + mockedStatic.close(); + } + private void validateValidateVectorDataType( final KNNEngine knnEngine, final String methodName, diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java index dcd2557405..b99da2d705 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java @@ -30,7 +30,8 @@ public void testMethodFieldMapper_whenVectorDataTypeIsGiven_thenSetItInFieldType KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, true, true, - KNNMethodContext.getDefault() + KNNMethodContext.getDefault(), + true ); assertEquals(VectorDataType.BINARY, mappers.fieldType().vectorDataType); }