Skip to content

Commit

Permalink
Fix remote model with embedding input issue (#3289)
Browse files Browse the repository at this point in the history
* Fix remote model with embedding input issue

Signed-off-by: b4sjoo <[email protected]>

* Add UT

Signed-off-by: b4sjoo <[email protected]>

* Spotless

Signed-off-by: b4sjoo <[email protected]>

* Add UT for both embedding and remote cases for all remote embedding schema

Signed-off-by: b4sjoo <[email protected]>

* Remove hardcoded test schema

Signed-off-by: b4sjoo <[email protected]>

---------

Signed-off-by: b4sjoo <[email protected]>
  • Loading branch information
b4sjoo committed Jan 2, 2025
1 parent 65cf1eb commit a791a04
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ public class ModelInterfaceUtils {
+ " \"texts\"\n"
+ " ]\n"
+ " }\n"
+ " },\n"
+ " \"required\": [\n"
+ " \"parameters\"\n"
+ " ]\n"
+ " }\n"
+ "}";

private static final String TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n"
Expand All @@ -72,10 +69,7 @@ public class ModelInterfaceUtils {
+ " \"inputText\"\n"
+ " ]\n"
+ " }\n"
+ " },\n"
+ " \"required\": [\n"
+ " \"parameters\"\n"
+ " ]\n"
+ " }\n"
+ "}";

private static final String TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n"
Expand All @@ -92,10 +86,7 @@ public class ModelInterfaceUtils {
+ " }\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"required\": [\n"
+ " \"parameters\"\n"
+ " ]\n"
+ " }\n"
+ "}";

private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT = "{\n"
Expand Down
50 changes: 50 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
package org.opensearch.ml.utils;

import static java.util.Collections.emptyMap;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;

import java.io.IOException;
Expand Down Expand Up @@ -66,6 +69,53 @@ public void testValidateSchema() throws IOException {
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testValidateEmbeddingInputWithGeneralEmbeddingRemoteSchema() throws IOException {
String schema = BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE.get("input");
String json = "{\"text_docs\":[ \"today is sunny\", \"today is sunny\"]}";
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testValidateRemoteInputWithGeneralEmbeddingRemoteSchema() throws IOException {
String schema = BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE.get("input");
String json = "{\"parameters\": {\"texts\": [\"Hello\",\"world\"]}}";
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testValidateEmbeddingInputWithTitanTextRemoteSchema() throws IOException {
String schema = BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE.get("input");
String json = "{\"text_docs\":[ \"today is sunny\", \"today is sunny\"]}";
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testValidateRemoteInputWithTitanTextRemoteSchema() throws IOException {
String schema = BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE.get("input");
String json = "{\"parameters\": {\"inputText\": \"Say this is a test\"}}";
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testValidateEmbeddingInputWithTitanMultiModalRemoteSchema() throws IOException {
String schema = BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE.get("input");
String json = "{\"text_docs\":[ \"today is sunny\", \"today is sunny\"]}";
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testValidateRemoteInputWithTitanMultiModalRemoteSchema() throws IOException {
String schema = BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE.get("input");
String json = "{\n"
+ " \"parameters\": {\n"
+ " \"inputText\": \"Say this is a test\",\n"
+ " \"inputImage\": \"/9jk=\"\n"
+ " }\n"
+ "}";
MLNodeUtils.validateSchema(schema, json);
}

@Test
public void testProcessRemoteInferenceInputDataSetParametersValueNoParameters() throws IOException {
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true}";
Expand Down

0 comments on commit a791a04

Please sign in to comment.