diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 065e0ec371..e3097e25db 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -1062,10 +1062,12 @@ public void loadExtensions(ExtensionLoader loader) { @Override public Map getProcessors(org.opensearch.ingest.Processor.Parameters parameters) { Map processors = new HashMap<>(); + NamedXContentRegistry contentRegistry = new NamedXContentRegistry(getNamedXContent()); + processors .put( MLInferenceIngestProcessor.TYPE, - new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry) + new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, contentRegistry) ); return Collections.unmodifiableMap(processors); } diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index 3ff5d957f3..9a514c34a9 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -31,6 +31,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.IngestDocument; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.MLResultDataType; @@ -138,6 +139,60 @@ public void testExecute_Exception() throws Exception { } + /** + * Models that use the parameters field need to have a valid NamedXContentRegistry object to create valid MLInputs. For example + *
+     * PUT   /_plugins/_ml/_predict/text_embedding/model_id
+     *  {
+     *     "parameters": {
+     *         "content_type" : "query"
+     *     },
+     *     "text_docs" : ["what day is it today?"],
+     *     "target_response" : ["sentence_embedding"]
+     *   }
+     * 
+ * These types of models like Local Asymmetric embedding models use the parameters field. + * And as such we need to test that having the contentRegistry throws an exception as it can not + * properly create a valid MLInput to perform prediction + * + * @implNote If you check the stack trace of the test you will see it tells you that it's a direct consequence of xContentRegistry being null + */ + public void testExecute_xContentRegistryNullWithLocalModel_throwsException() throws Exception { + // Set the registry to null and reset after exiting the test + xContentRegistry = null; + + String localModelInput = + "{ \"text_docs\": [\"What day is it today?\"],\"target_response\": [\"sentence_embedding\"], \"parameters\": { \"contentType\" : \"query\"} }"; + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "local_model_id", + null, + null, + null, + false, + FunctionName.TEXT_EMBEDDING.toString(), + false, + false, + false, + localModelInput + ); + try { + String npeMessage = + "Cannot invoke \"org.opensearch.ml.common.input.MLInput.setAlgorithm(org.opensearch.ml.common.FunctionName)\" because \"mlInput\" is null"; + + processor.execute(ingestDocument, handler); + verify(handler) + .accept( + isNull(), + argThat(exception -> exception instanceof NullPointerException && exception.getMessage().equals(npeMessage)) + ); + } catch (Exception e) { + assertEquals("this catch block should not get executed.", e.getMessage()); + } + // reset to mocked object + xContentRegistry = mock(NamedXContentRegistry.class); + } + /** * test nested object document with array of Map */ diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java index 54ab526dee..853bcfb0dd 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_HASH_VALUE; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestHelper.makeRequest; @@ -28,6 +29,7 @@ import org.opensearch.ml.utils.TestHelper; import com.google.common.collect.ImmutableList; +import com.jayway.jsonpath.DocumentContext; import com.jayway.jsonpath.JsonPath; public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase { @@ -434,6 +436,110 @@ public void testMLInferenceProcessorLocalModelObjectField() throws Exception { Assert.assertEquals(0.49191704, (Double) embedding2.get(0), 0.005); } + public void testMLInferenceIngestProcessor_simulatesIngestPipelineSuccessfully_withAsymmetricEmbeddingModelUsingPassageContentType() + throws Exception { + String taskId = registerModel(TestHelper.toJsonString(registerAsymmetricEmbeddingModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String asymmetricPipelineName = "asymmetric_embedding_pipeline"; + String createPipelineRequestBody = "{\n" + + " \"description\": \"ingest PASSAGE text and generate a embedding using an asymmetric model\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + "\n" + + " \"model_input\": \"{\\\"text_docs\\\":[\\\"${input_map.text_docs}\\\"],\\\"target_response\\\":[\\\"sentence_embedding\\\"],\\\"parameters\\\":{\\\"content_type\\\":\\\"passage\\\"}}\",\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"description\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + "\n" + + " " + + " \"fact_embedding\": \"$.inference_results[0].output[0].data\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + createPipelineProcessor(createPipelineRequestBody, asymmetricPipelineName); + String sampleDocuments = "{\n" + + " \"docs\": [\n" + + " {\n" + + " \"_index\": \"my-index\",\n" + + " \"_id\": \"1\",\n" + + " \"_source\": {\n" + + " \"title\": \"Central Park\",\n" + + " \"description\": \"A large public park in the heart of New York City, offering a wide range of recreational activities.\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"_index\": \"my-index\",\n" + + " \"_id\": \"2\",\n" + + " \"_source\": {\n" + + " \"title\": \"Empire State Building\",\n" + + " \"description\": \"An iconic skyscraper in New York City offering breathtaking views from its observation deck.\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}\n"; + + Map simulateResponseDocuments = simulateIngestPipeline(asymmetricPipelineName, sampleDocuments); + + DocumentContext documents = JsonPath.parse(simulateResponseDocuments); + + List centralParkFactEmbedding = documents.read("docs.[0].*._source.fact_embedding.*"); + assertEquals(768, centralParkFactEmbedding.size()); + Assert.assertEquals(0.5137818, (Double) centralParkFactEmbedding.get(0), 0.005); + + List empireStateBuildingFactEmbedding = documents.read("docs.[1].*._source.fact_embedding.*"); + assertEquals(768, empireStateBuildingFactEmbedding.size()); + Assert.assertEquals(0.4493039, (Double) empireStateBuildingFactEmbedding.get(0), 0.005); + } + + private MLRegisterModelInput registerAsymmetricEmbeddingModelInput() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(768) + .queryPrefix("query >>") + .passagePrefix("passage >> ") + .build(); + + return MLRegisterModelInput + .builder() + .modelName("test_model_name") + .version("1.0.0") + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url(SENTENCE_TRANSFORMER_MODEL_URL) + .deployModel(false) + .hashValue(SENTENCE_TRANSFORMER_MODEL_HASH_VALUE) + .build(); + } + // TODO: add tests for other local model types such as sparse/cross encoders public void testMLInferenceProcessorLocalModelNestedField() throws Exception { @@ -560,6 +666,21 @@ protected void createPipelineProcessor(String requestBody, final String pipeline } + protected Map simulateIngestPipeline(String pipelineName, String sampleDocuments) throws IOException { + Response ingestionResponse = TestHelper + .makeRequest( + client(), + "POST", + "/_ingest/pipeline/" + pipelineName + "/_simulate", + null, + sampleDocuments, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, ingestionResponse.getStatusLine().getStatusCode()); + + return parseResponseToMap(ingestionResponse); + } + protected void createIndex(String indexName, String requestBody) throws Exception { Response response = makeRequest( client(), @@ -602,7 +723,7 @@ protected MLRegisterModelInput registerModelInput() throws IOException, Interrup .modelConfig(modelConfig) .url(SENTENCE_TRANSFORMER_MODEL_URL) .deployModel(false) - .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021") + .hashValue(SENTENCE_TRANSFORMER_MODEL_HASH_VALUE) .build(); } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java index ab7acf38a0..563efa3b79 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java @@ -30,6 +30,7 @@ public class TestData { "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_huggingface.zip?raw=true"; public static final String SENTENCE_TRANSFORMER_MODEL_URL = "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true"; + public static final String SENTENCE_TRANSFORMER_MODEL_HASH_VALUE = "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021"; public static final String TIME_FIELD = "timestamp"; public static final String HUGGINGFACE_TRANSFORMER_MODEL_HASH_VALUE = "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021";