Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport: Support empty string for fields in text embedding processor (#1041) #1046

Open
wants to merge 1 commit into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey)));
}
}
return mapWithProcessorKeys;
Expand All @@ -333,19 +333,34 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
List<Object> listOfStrings = list.stream().map(x -> {
Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey());
return normalizeSourceValue(nestedSourceValue);
}).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
} else {
Object parentValue = sourceAndMetadataMap.get(parentKey);
String key = String.valueOf(processorKey);
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
treeRes.put(key, normalizeSourceValue(parentValue));
}
}

private boolean isBlankString(Object object) {
return object instanceof String && StringUtils.isBlank((String) object);
}

private Object normalizeSourceValue(Object value) {
if (isBlankString(value)) {
return null;
}
return value;
}

/**
* Process the nested key, such as "a.b.c" to "a", "b.c"
* @param nestedFieldMapEntry
Expand Down Expand Up @@ -376,7 +391,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
indexName,
clusterService,
environment,
false
true
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,28 +66,6 @@ public void test_batchExecute_emptyInput() {
verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecuteWithEmpty_allFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
Consumer resultHandler = mock(Consumer.class);
processor.batchExecute(wrapperList, resultHandler);
ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class);
verify(resultHandler).accept(captor.capture());
assertEquals(docCount, captor.getValue().size());
for (int i = 0; i < docCount; ++i) {
assertNotNull(captor.getValue().get(i).getException());
assertEquals(
"list type field [key1] has empty string, cannot process it",
captor.getValue().get(i).getException().getMessage()
);
assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument());
}
verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecuteWithNull_allFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
Expand All @@ -107,27 +85,6 @@ public void test_batchExecuteWithNull_allFailedValidation() {
verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecute_partialFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Consumer resultHandler = mock(Consumer.class);
processor.batchExecute(wrapperList, resultHandler);
ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class);
verify(resultHandler).accept(captor.capture());
assertEquals(docCount, captor.getValue().size());
assertNotNull(captor.getValue().get(0).getException());
assertNull(captor.getValue().get(1).getException());
for (int i = 0; i < docCount; ++i) {
assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument());
}
ArgumentCaptor<List<String>> inferenceTextCaptor = ArgumentCaptor.forClass(List.class);
verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any());
assertEquals(2, inferenceTextCaptor.getValue().size());
}

public void test_batchExecute_happyCase() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(6);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -240,31 +241,6 @@ public void testExecute_withListTypeInput_successful() {
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", " ");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig();

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() {
List<String> list1 = ImmutableList.of("", "test2", "test3");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", list1);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig();

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNonStringValue_throwIllegalArgumentException() {
List<Integer> list2 = ImmutableList.of(1, 2, 3);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down Expand Up @@ -549,20 +525,6 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, String> map2 = ImmutableMap.of("test3", " ");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", map1);
sourceAndMetadata.put("key2", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig();
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
Map<String, Object> ret = createMaxDepthLimitExceedMap(() -> 1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down Expand Up @@ -785,6 +747,79 @@ public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_s
assertNotNull(nestedObj.get(1).get("vectorField"));
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withPlainString_EmptyString_skipped() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
Map<String, Object> sourceAndMetadata = ingestDocument.getSourceAndMetadata();
sourceAndMetadata.put("oriKey1", StringUtils.EMPTY);

TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = processor.buildMapWithTargetKeys(ingestDocument);
List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(6, 100, 0.0f, 1.0f);
processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList);

/** IngestDocument
* "oriKey1": "",
* "oriKey2": "oriValue2",
* "oriKey3": "oriValue3",
* "oriKey4": "oriValue4",
* "oriKey5": "oriValue5",
* "oriKey6": [
* "oriValue6",
* "oriValue7"
* ]
*
*/
assertEquals(11, sourceAndMetadata.size());
assertFalse(sourceAndMetadata.containsKey("oriKey1_knn"));
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedField_EmptyString_skipped() {
Map<String, Object> config = createNestedMapConfiguration();
IngestDocument ingestDocument = createNestedMapIngestDocument();
Map<String, Object> favorites = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("favorites");
Map<String, Object> favorite = (Map<String, Object>) favorites.get("favorite");
favorite.put("movie", StringUtils.EMPTY);

TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = processor.buildMapWithTargetKeys(ingestDocument);
List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f);
processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata());

/**
* "favorites": {
* "favorite": {
* "movie": "",
* "actor": "Charlie Chaplin",
* "games" : {
* "adventure": {
* "action": "overwatch",
* "rpg": "elden ring"
* }
* }
* }
* }
*/
Map<String, Object> favoritesMap = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("favorites");
assertNotNull(favoritesMap);
Map<String, Object> favoriteMap = (Map<String, Object>) favoritesMap.get("favorite");
assertNotNull(favoriteMap);

Map<String, Object> favoriteGames = (Map<String, Object>) favoriteMap.get("games");
assertNotNull(favoriteGames);
Map<String, Object> adventure = (Map<String, Object>) favoriteGames.get("adventure");
List<Float> adventureKnnVector = (List<Float>) adventure.get("with_action_knn");
assertNotNull(adventureKnnVector);
assertEquals(100, adventureKnnVector.size());
for (float vector : adventureKnnVector) {
assertTrue(vector >= 0.0f && vector <= 1.0f);
}

assertFalse(favoriteMap.containsKey("favorite_movie_knn"));
}

public void test_updateDocument_appendVectorFieldsToDocument_successful() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
Expand Down
Loading