diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 5b8f7a3b01..846599f369 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -34,10 +34,8 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; -import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.InvalidJsonException; import com.jayway.jsonpath.JsonPath; -import com.jayway.jsonpath.Option; import com.jayway.jsonpath.PathNotFoundException; import com.networknt.schema.JsonSchema; import com.networknt.schema.JsonSchemaFactory; @@ -389,53 +387,73 @@ public static boolean pathExists(Object json, String path) { /** * Prepares nested structures in a JSON object based on the given field path. * - * This method ensures that all intermediate nested objects exist in the JSON object + * This method ensures that all intermediate nested objects and arrays exist in the JSON object * for a given field path. If any part of the path doesn't exist, it creates new empty objects - * (HashMaps) for those parts. + * (HashMaps) or arrays (ArrayLists) for those parts. * - * @param jsonObject The JSON object to be updated. - * @param fieldPath The full path of the field, potentially including nested structures. - * @return The updated JSON object with necessary nested structures in place. + * The method can handle complex paths including both object properties and array indices. + * For example, it can process paths like "foo.bar[1].baz[0].qux". * - * @throws IllegalArgumentException If there's an issue with JSON parsing or path manipulation. + * @param jsonObject The JSON object to be updated. If this is not a Map, a new Map will be created. + * @param fieldPath The full path of the field, potentially including nested structures and array indices. + * The path can optionally start with "$." which will be ignored if present. + * @return The updated JSON object with necessary nested structures in place. + * If the input was not a Map, returns the newly created Map structure. * - * @implNote This method uses JsonPath for JSON manipulation and StringUtils for path existence checks. - * It handles paths both with and without a leading "$." notation. - * Each non-existent intermediate object in the path is created as an empty HashMap. + * @throws IllegalArgumentException If the field path is null or not a valid JSON path. * - * @see JsonPath - * @see StringUtils */ public static Object prepareNestedStructures(Object jsonObject, String fieldPath) { - if (fieldPath == null) { - throw new IllegalArgumentException("the field path is null"); + throw new IllegalArgumentException("The field path is null"); + } + if (jsonObject == null) { + throw new IllegalArgumentException("The object is null"); } if (!isValidJSONPath(fieldPath)) { - throw new IllegalArgumentException("the field path is not a valid json path: " + fieldPath); + throw new IllegalArgumentException("The field path is not a valid JSON path: " + fieldPath); } + String path = fieldPath.startsWith("$.") ? fieldPath.substring(2) : fieldPath; - String[] pathParts = path.split("\\."); - Configuration suppressExceptionConfiguration = Configuration - .builder() - .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) - .build(); - StringBuilder currentPath = new StringBuilder("$"); - - for (int i = 0; i < pathParts.length - 1; i++) { - currentPath.append(".").append(pathParts[i]); - if (!StringUtils.pathExists(jsonObject, currentPath.toString())) { - try { - jsonObject = JsonPath - .using(suppressExceptionConfiguration) - .parse(jsonObject) - .set(currentPath.toString(), new java.util.HashMap<>()) - .json(); - } catch (Exception e) { - throw new IllegalArgumentException("Error creating nested structure for path: " + currentPath, e); + String[] pathParts = path.split("(? current = (jsonObject instanceof Map) ? (Map) jsonObject : new HashMap<>(); + + for (String part : pathParts) { + if (part.contains("[")) { + // Handle array notation + String[] arrayParts = part.split("\\["); + String key = arrayParts[0]; + int index = Integer.parseInt(arrayParts[1].replaceAll("\\]", "")); + + if (!current.containsKey(key)) { + current.put(key, new ArrayList<>()); + } + if (!(current.get(key) instanceof List)) { + return jsonObject; + } + List list = (List) current.get(key); + if (index >= list.size()) { + while (list.size() <= index) { + list.add(null); + } + list.set(index, new HashMap<>()); } + if (!(list.get(index) instanceof Map)) { + return jsonObject; + } + current = (Map) list.get(index); + } else { + // Handle object notation + if (!current.containsKey(part)) { + current.put(part, new HashMap<>()); + } else if (!(current.get(part) instanceof Map)) { + return jsonObject; + } + current = (Map) current.get(part); } } + return jsonObject; } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java index 7f9d0feb96..c56b4b4885 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -149,7 +149,7 @@ public void processRequestAsync( if (request.source() == null) { throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request."); } - + setRequestContextFromExt(request, requestContext); String queryString = request.source().toString(); rewriteQueryString(request, queryString, requestListener, requestContext); diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index e39b7f4b74..67df00d6ca 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -162,6 +162,8 @@ public void processResponseAsync( return; } + setRequestContextFromExt(request, responseContext); + // if many to one, run rewriteResponseDocuments if (!oneToOne) { // use MLInferenceSearchResponseProcessor to allow writing to extension diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 8a25b985f9..bd0eff429b 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder.PARAMETER_NAME; import java.io.IOException; import java.util.ArrayList; @@ -19,6 +20,7 @@ import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -33,6 +35,10 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; @@ -335,4 +341,26 @@ default List writeNewDotPathForNestedObject(Object json, String dotPath) default String convertToDotPath(String path) { return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", ""); } + + default void setRequestContextFromExt(SearchRequest request, PipelineProcessingContext requestContext) { + + List extBuilderList = request.source().ext(); + for (SearchExtBuilder ext : extBuilderList) { + if (ext instanceof MLInferenceRequestParametersExtBuilder) { + MLInferenceRequestParametersExtBuilder mlExtBuilder = (MLInferenceRequestParametersExtBuilder) ext; + Map mlParams = mlExtBuilder.getRequestParameters().getParams(); + mlParams + .forEach( + (key, value) -> requestContext + .setAttribute(String.format("ext.%s.%s", MLInferenceRequestParametersExtBuilder.NAME, key), value) + ); + } + if (ext instanceof GenerativeQAParamExtBuilder) { + GenerativeQAParamExtBuilder qaParamExtBuilder = (GenerativeQAParamExtBuilder) ext; + Map mlParams = (Map) qaParamExtBuilder.getParams(); + mlParams.forEach((key, value) -> requestContext.setAttribute(String.format("ext.%s.%s", PARAMETER_NAME, key), value)); + } + } + + } } diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java index 771e882310..a8fb6a2b59 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java @@ -1169,6 +1169,109 @@ public void onFailure(Exception e) { } + /** + * Tests the successful rewriting of a complex nested array in query extension based on the model output. + * verify the pipelineConext is set from the extension + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "ext.ml_inference.question"; + String newQueryField = "ext.ml_inference.llm_response"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + // Test model return a complex nested array + Map nestedResponse = new HashMap<>(); + List> languageList = new ArrayList<>(); + languageList.add(Collections.singletonMap("eng", "0.95")); + languageList.add(Collections.singletonMap("es", "0.67")); + nestedResponse.put("language", languageList); + nestedResponse.put("type", "bert"); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", nestedResponse)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + + Map llmQuestion = new HashMap<>(); + llmQuestion.put("question", "what language is this text in?"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmQuestion); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder)); + + SearchRequest request = new SearchRequest().source(source); + + // Expecting new request with ml inference search extensions including the complex nested array + Map params = new HashMap<>(); + params.put("question", "what language is this text in?"); + params.put("llm_response", nestedResponse); + MLInferenceRequestParameters expectedRequestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder expectedMlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + expectedMlInferenceExtBuilder.setRequestParameters(expectedRequestParameters); + SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(expectedMlInferenceExtBuilder)); + SearchRequest expectRequest = new SearchRequest().source(expectedSource); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + assertEquals(expectRequest.toString(), newSearchRequest.toString()); + + // Additional checks for the complex nested array + MLInferenceRequestParametersExtBuilder actualExtBuilder = (MLInferenceRequestParametersExtBuilder) newSearchRequest + .source() + .ext() + .get(0); + MLInferenceRequestParameters actualParams = actualExtBuilder.getRequestParameters(); + Object actualResponse = actualParams.getParams().get("llm_response"); + + assertTrue(actualResponse instanceof Map); + Map actualNestedResponse = (Map) actualResponse; + + // Check the "language" field + assertTrue(actualNestedResponse.get("language") instanceof List); + List actualLanguageList = (List) actualNestedResponse.get("language"); + assertEquals(2, actualLanguageList.size()); + + Map engMap = (Map) actualLanguageList.get(0); + assertEquals("0.95", engMap.get("eng")); + + Map esMap = (Map) actualLanguageList.get(1); + assertEquals("0.67", esMap.get("es")); + + // Check the "type" field + assertEquals("bert", actualNestedResponse.get("type")); + verify(requestContext).setAttribute("ext.ml_inference.question", "what language is this text in?"); + verify(requestContext).setAttribute("ext.ml_inference.llm_response", nestedResponse); + + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + } + /** * Helper method to create an instance of the MLInferenceSearchRequestProcessor with the specified parameters. * diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 1ad9eee136..20c586be65 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -4003,6 +4003,105 @@ public void onFailure(Exception e) { } + @Test + public void testProcessResponseAsyncSetRequestContextFromExt() throws Exception { + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "ext.ml_inference.summary"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + SearchResponse response = getSearchResponse(5, true, documentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + Map role = new HashMap<>(); + role.put("role", "users"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(role); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder() + .query(incomingQuery) + .size(5) + .sort("text") + .ext(List.of(mlInferenceExtBuilder)); + SearchRequest request = new SearchRequest().source(source); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + Map newParams = new HashMap<>(); + newParams.put("llm_response", "answer"); + newParams.put("summary", "there is 1 value"); + assertEquals(responseAfterProcessor.getParams(), newParams); + verify(responseContext).setAttribute("ext.ml_inference.role", "users"); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + private static SearchRequest getSearchRequest() { QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).size(5).sort("text"); @@ -4443,4 +4542,5 @@ public void testWriteToExtensionAndOneToOne() throws Exception { } } + }