diff --git a/build.gradle b/build.gradle index 1026ce243f..29d2422bed 100644 --- a/build.gradle +++ b/build.gradle @@ -131,7 +131,7 @@ subprojects { ext { // NB: due to version.json generation by parsing this file, the next line must not have any if/then/else logic - neo4jVersion = "5.19.0" + neo4jVersion = "5.18.0" // instead we apply the override logic here neo4jVersionEffective = project.hasProperty("neo4jVersionOverride") ? project.getProperty("neo4jVersionOverride") : neo4jVersion testContainersVersion = '1.18.3' diff --git a/extended/src/main/java/apoc/ml/MLUtil.java b/extended/src/main/java/apoc/ml/MLUtil.java new file mode 100644 index 0000000000..d996e07216 --- /dev/null +++ b/extended/src/main/java/apoc/ml/MLUtil.java @@ -0,0 +1,5 @@ +package apoc.ml; + +public class MLUtil { + public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input"; +} diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index a28aae5e17..c08c3ef8e0 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -16,10 +16,13 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; import java.util.stream.Stream; import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE; import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY; +import static apoc.ml.MLUtil.ERROR_NULL_INPUT; @Extended @@ -111,13 +114,30 @@ public Stream getEmbedding(@Name("texts") List texts, @ "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ - Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data", apocConfig, urlAccessChecker); - return resultStream + if (texts == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } + + Map> collect = texts.stream() + .collect(Collectors.groupingBy(Objects::nonNull)); + + List nonNullTexts = collect.get(true); + + Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker); + Stream embeddingResultStream = resultStream .flatMap(v -> ((List>) v).stream()) .map(m -> { Long index = (Long) m.get("index"); - return new EmbeddingResult(index, texts.get(index.intValue()), (List) m.get("embedding")); + return new EmbeddingResult(index, nonNullTexts.get(index.intValue()), (List) m.get("embedding")); }); + + List nullTexts = collect.getOrDefault(false, List.of()); + Stream nullResultStream = nullTexts.stream() + .map(i -> { + // null text return index -1 to indicate that are not coming from `/embeddings` RestAPI + return new EmbeddingResult(-1, i, List.of()); + }); + return Stream.concat(embeddingResultStream, nullResultStream); } @@ -132,6 +152,9 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } } */ + if (prompt == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker) .map(v -> (Map)v).map(MapResult::new); } @@ -139,6 +162,9 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke @Procedure("apoc.ml.openai.chat") @Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API") public Stream chatCompletion(@Name("messages") List> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + if (messages == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$", apocConfig, urlAccessChecker) .map(v -> (Map)v).map(MapResult::new); // https://platform.openai.com/docs/api-reference/chat/create diff --git a/extended/src/main/java/apoc/ml/VertexAI.java b/extended/src/main/java/apoc/ml/VertexAI.java index c2c50fbbe4..86b6a45a56 100644 --- a/extended/src/main/java/apoc/ml/VertexAI.java +++ b/extended/src/main/java/apoc/ml/VertexAI.java @@ -20,9 +20,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; import java.util.stream.Stream; +import static apoc.ml.MLUtil.ERROR_NULL_INPUT; + @Extended public class VertexAI { @@ -94,16 +98,34 @@ public Stream getEmbedding(@Name("texts") List texts, @ ] } */ + + if (texts == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } + + Map> collect = texts.stream() + .collect(Collectors.groupingBy(Objects::nonNull)); + + List nonNullTexts = collect.get(true); + Object inputs = texts.stream().map(text -> Map.of("content", text)).toList(); Stream resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, List.of(), urlAccessChecker); AtomicInteger ai = new AtomicInteger(); - return resultStream + Stream embeddingResultStream = resultStream .flatMap(v -> ((List>) v).stream()) .map(m -> { - Map embeddings = (Map) ((Map)m).get("embeddings"); + Map embeddings = (Map) ((Map) m).get("embeddings"); int index = ai.getAndIncrement(); - return new EmbeddingResult(index, texts.get(index), (List) embeddings.get("values")); + return new EmbeddingResult(index, nonNullTexts.get(index), (List) embeddings.get("values")); + }); + + List nullTexts = collect.getOrDefault(false, List.of()); + Stream nullResultStream = nullTexts.stream() + .map(text -> { + // null text return index -1 to indicate that are not coming from `/embeddings` RestAPI + return new EmbeddingResult(-1, text, List.of()); }); + return Stream.concat(embeddingResultStream, nullResultStream); } @@ -153,6 +175,10 @@ public Stream completion(@Name("prompt") String prompt, @Name("access ] } */ + if (prompt == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } + Object input = List.of(Map.of("prompt",prompt)); var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens"); var resultStream = executeRequest(accessToken, project, configuration, "text-bison", input, parameterKeys, urlAccessChecker); @@ -192,6 +218,9 @@ public Stream chatCompletion(@Name("messages") List>> examples ) throws Exception { + if (messages == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } Object inputs = List.of(Map.of("context",context, "examples",examples, "messages", messages)); var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens"); return executeRequest(accessToken, project, configuration, "chat-bison", inputs, parameterKeys, urlAccessChecker) diff --git a/extended/src/main/java/apoc/ml/Watson.java b/extended/src/main/java/apoc/ml/Watson.java index 2b6a369623..75e55e57ff 100644 --- a/extended/src/main/java/apoc/ml/Watson.java +++ b/extended/src/main/java/apoc/ml/Watson.java @@ -39,6 +39,9 @@ public class Watson { @Procedure("apoc.ml.watson.chat") @Description("apoc.ml.watson.chat(messages, accessToken, $configuration) - prompts the completion API") public Stream chatCompletion(@Name("messages") List> messages, @Name("accessToken") String accessToken, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + if (messages == null) { + return Stream.of(new MapResult(null)); + } String prompt = messages.stream() .map(message -> { Object role = message.get("role"); @@ -56,6 +59,9 @@ public Stream chatCompletion(@Name("messages") List completion(@Name("prompt") String prompt, @Name("accessToken") String accessToken, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + if (prompt == null) { + return Stream.of(new MapResult(null)); + } return executeRequest(prompt, accessToken, configuration); } diff --git a/extended/src/main/java/apoc/ml/aws/Bedrock.java b/extended/src/main/java/apoc/ml/aws/Bedrock.java index 463c928f07..18fc96eaab 100644 --- a/extended/src/main/java/apoc/ml/aws/Bedrock.java +++ b/extended/src/main/java/apoc/ml/aws/Bedrock.java @@ -17,6 +17,7 @@ import org.neo4j.procedure.Procedure; import org.apache.commons.lang3.StringUtils; +import static apoc.ml.MLUtil.ERROR_NULL_INPUT; import static apoc.ml.aws.AWSConfig.JSON_PATH; import static apoc.ml.aws.BedrockInvokeConfig.MODEL; import static apoc.util.JsonUtil.OBJECT_MAPPER; @@ -60,7 +61,9 @@ public Stream custom(@Name(value = "body") Map body, public Stream chatCompletion( @Name("messages") List> messages, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - + if (messages == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } var config = new HashMap<>(configuration); config.putIfAbsent(MODEL, ANTHROPIC_CLAUDE_V2); @@ -94,7 +97,9 @@ private void transformOpenAiToBedrockRequestBody(Map message) { @Description("apoc.ml.bedrock.completion(prompt, $conf) - prompts the completion API") public Stream completion(@Name("prompt") String prompt, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - + if (prompt == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } var config = new HashMap<>(configuration); config.putIfAbsent(MODEL, JURASSIC_2_ULTRA); @@ -109,6 +114,9 @@ public Stream completion(@Name("prompt") String prompt, @Description("apoc.ml.bedrock.embedding([texts], $configuration) - returns the embeddings for a given text") public Stream embedding(@Name(value = "texts") List texts, @Name(value = "configuration", defaultValue = "{}") Map configuration) { + if (texts == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } var config = new HashMap<>(configuration); config.putIfAbsent(MODEL, TITAN_EMBED_TEXT); @@ -127,6 +135,9 @@ public Stream embedding(@Name(value = "texts") List texts, @Procedure("apoc.ml.bedrock.image") public Stream image(@Name(value = "body") Map body, @Name(value = "configuration", defaultValue = "{}") Map configuration) { + if (body == null) { + throw new RuntimeException(ERROR_NULL_INPUT); + } configuration.putIfAbsent(MODEL, STABILITY_STABLE_DIFFUSION_XL); configuration.putIfAbsent(JSON_PATH, "$.artifacts[0]"); diff --git a/extended/src/test/java/apoc/ml/MLTestUtil.java b/extended/src/test/java/apoc/ml/MLTestUtil.java new file mode 100644 index 0000000000..27e6946b36 --- /dev/null +++ b/extended/src/test/java/apoc/ml/MLTestUtil.java @@ -0,0 +1,25 @@ +package apoc.ml; + +import org.neo4j.graphdb.GraphDatabaseService; + +import java.util.Map; + +import static apoc.ml.MLUtil.ERROR_NULL_INPUT; +import static apoc.util.TestUtil.testCall; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class MLTestUtil { + public static void assertNullInputFails(GraphDatabaseService db, String query, Map params) { + try { + testCall(db, query, params, + (row) -> fail("Should fail due to null input") + ); + } catch (RuntimeException e) { + String message = e.getMessage(); + assertTrue("Current error message is: " + message, + message.contains(ERROR_NULL_INPUT) + ); + } + } +} diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java index c4092afff5..a45eecf7ce 100644 --- a/extended/src/test/java/apoc/ml/OpenAIIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -1,6 +1,7 @@ package apoc.ml; import apoc.util.TestUtil; +import apoc.util.collection.Iterators; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -8,12 +9,17 @@ import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; +import java.util.HashSet; import java.util.Map; +import java.util.Set; +import static apoc.ml.MLTestUtil.assertNullInputFails; import static apoc.ml.OpenAI.MODEL_CONF_KEY; import static apoc.ml.OpenAITestResultUtils.*; import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; import static java.util.Collections.emptyMap; +import static org.junit.Assert.assertEquals; public class OpenAIIT { @@ -68,6 +74,19 @@ public void getEmbedding3LargeWithDimensionsRequestParameter() { r -> assertEmbeddings(r, 256)); } + @Test + public void getEmbeddingNull() { + testResult(db, "CALL apoc.ml.openai.embedding([null, 'Some Text', null, 'Other Text'], $apiKey, $conf)", Map.of("apiKey",openaiKey, "conf", emptyMap()), + r -> { + Set actual = Iterators.asSet(r.columnAs("text")); + + Set expected = new HashSet<>() {{ + add(null); add(null); add("Some Text"); add("Other Text"); + }}; + assertEquals(expected, actual); + }); + } + @Test public void completion() { testCall(db, COMPLETION_QUERY, @@ -104,4 +123,25 @@ public void chatCompletion() { } */ } + + @Test + public void embeddingsNull() { + assertNullInputFails(db, "CALL apoc.ml.openai.embedding(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void completionNull() { + assertNullInputFails(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void chatCompletionNull() { + assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } } \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/VertexAIIT.java b/extended/src/test/java/apoc/ml/VertexAIIT.java index 103ccddd05..04d102369a 100644 --- a/extended/src/test/java/apoc/ml/VertexAIIT.java +++ b/extended/src/test/java/apoc/ml/VertexAIIT.java @@ -1,6 +1,7 @@ package apoc.ml; import apoc.util.TestUtil; +import apoc.util.collection.Iterators; import org.apache.commons.io.FileUtils; import org.junit.Assume; import org.junit.Before; @@ -13,15 +14,21 @@ import java.io.IOException; import java.util.Base64; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import static apoc.ml.MLTestUtil.assertNullInputFails; import static apoc.ml.VertexAIHandler.ENDPOINT_CONF_KEY; import static apoc.ml.VertexAIHandler.MODEL_CONF_KEY; import static apoc.ml.VertexAIHandler.PREDICT_RESOURCE; import static apoc.ml.VertexAIHandler.RESOURCE_CONF_KEY; import static apoc.ml.VertexAIHandler.STREAM_RESOURCE; import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static java.util.Collections.emptyMap; +import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -66,6 +73,20 @@ public void getEmbedding() { }); } + @Test + public void getEmbeddingNull() { + testResult(db, "CALL apoc.ml.vertexai.embedding([null, 'Some Text', null, 'Other Text'], $apiKey, $project)", + parameters, + r -> { + Set actual = Iterators.asSet(r.columnAs("text")); + + Set expected = new HashSet<>() {{ + add(null); add(null); add("Some Text"); add("Other Text"); + }}; + assertEquals(expected, actual); + }); + } + @Test public void completion() { testCall(db, "CALL apoc.ml.vertexai.completion('What color is the sky? Answer in one word: ', $apiKey, $project)", parameters,(row) -> { @@ -209,4 +230,25 @@ private void assertCorrectResponse(Map row, String expected) { assertTrue(stringRow.toLowerCase().contains(expected), "Actual result is: " + stringRow); } + + @Test + public void embeddingsNull() { + assertNullInputFails(db, "CALL apoc.ml.vertexai.embedding(null, $apiKey, $project)", + parameters + ); + } + + @Test + public void completionNull() { + assertNullInputFails(db, "CALL apoc.ml.vertexai.completion(null, $apiKey, $project)", + parameters + ); + } + + @Test + public void chatCompletionNull() { + assertNullInputFails(db, "CALL apoc.ml.vertexai.chat(null, $apiKey, $project)", + parameters + ); + } } \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/WatsonIT.java b/extended/src/test/java/apoc/ml/WatsonIT.java index 190611ff57..729091f1fd 100644 --- a/extended/src/test/java/apoc/ml/WatsonIT.java +++ b/extended/src/test/java/apoc/ml/WatsonIT.java @@ -13,7 +13,9 @@ import static apoc.ExtendedApocConfig.APOC_ML_WATSON_URL; import static apoc.ExtendedApocConfig.APOC_ML_WATSON_PROJECT_ID; +import static apoc.ml.MLTestUtil.assertNullInputFails; import static apoc.util.TestUtil.testCall; +import static java.util.Collections.emptyMap; import static org.junit.Assume.assumeNotNull; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -130,4 +132,18 @@ private static void commonAssertions(Map row, String text, long assertEquals(inputTokenCount, result.get("input_token_count")); assertEquals(stopReason, result.get("stop_reason")); } + + @Test + public void completionNull() { + assertNullInputFails(db, "CALL apoc.ml.watson.completion(null, $apiKey, $conf)", + Map.of("apiKey", accessToken, "conf", emptyMap()) + ); + } + + @Test + public void chatCompletionNull() { + assertNullInputFails(db, "CALL apoc.ml.watson.chat(null, $apiKey, $conf)", + Map.of("apiKey", accessToken, "conf", emptyMap()) + ); + } } diff --git a/extended/src/test/java/apoc/ml/aws/BedrockIT.java b/extended/src/test/java/apoc/ml/aws/BedrockIT.java index 31a8603e96..23906e532b 100644 --- a/extended/src/test/java/apoc/ml/aws/BedrockIT.java +++ b/extended/src/test/java/apoc/ml/aws/BedrockIT.java @@ -17,6 +17,7 @@ import static apoc.ApocConfig.apocConfig; import static apoc.ExtendedApocConfig.APOC_AWS_KEY_ID; import static apoc.ExtendedApocConfig.APOC_AWS_SECRET_KEY; +import static apoc.ml.MLTestUtil.assertNullInputFails; import static apoc.ml.aws.AWSConfig.KEY_ID; import static apoc.ml.aws.AWSConfig.METHOD_KEY; import static apoc.ml.aws.AWSConfig.SECRET_KEY; @@ -27,8 +28,10 @@ import static apoc.ml.aws.BedrockUtil.*; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; +import static java.util.Collections.emptyMap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeNotNull; @@ -291,4 +294,32 @@ private static void assertionsTitanEmbed(Map value) { assertNotNull(value.get("inputTextTokenCount")); assertNotNull(value.get("embedding")); } + + @Test + public void embeddingNull() { + assertNullInputFails(db, "CALL apoc.ml.bedrock.embedding(null)", + emptyMap() + ); + } + + @Test + public void completionNull() { + assertNullInputFails(db, "CALL apoc.ml.bedrock.completion(null)", + emptyMap() + ); + } + + @Test + public void chatCompletionNull() { + assertNullInputFails(db, "CALL apoc.ml.bedrock.chat(null)", + emptyMap() + ); + } + + @Test + public void imageNull() { + assertNullInputFails(db, "CALL apoc.ml.bedrock.image(null)", + emptyMap() + ); + } } diff --git a/extended/src/test/java/apoc/ml/sagemaker/SageMakerIT.java b/extended/src/test/java/apoc/ml/sagemaker/SageMakerIT.java index 3c97d663fa..4c2e3ea49a 100644 --- a/extended/src/test/java/apoc/ml/sagemaker/SageMakerIT.java +++ b/extended/src/test/java/apoc/ml/sagemaker/SageMakerIT.java @@ -22,8 +22,11 @@ import static apoc.ml.aws.AWSConfig.HEADERS_KEY; import static apoc.ml.aws.AWSConfig.REGION_KEY; import static apoc.ml.aws.SageMakerConfig.ENDPOINT_NAME_KEY; +import static apoc.util.TestUtil.testCall; +import static java.util.Collections.emptyMap; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeNotNull; @@ -160,4 +163,21 @@ private void assertEventually(Callable booleanCallable) { Assert.assertEventually(booleanCallable, val -> val, 30, TimeUnit.SECONDS); } + @Test + public void completionNull() { + testCall(db, "CALL apoc.ml.sagemaker.completion(null, $conf)", + Map.of("conf", emptyMap()), + (row) -> assertNull(row.get("value")) + ); + } + + @Test + public void chatCompletionNull() { + testCall(db, + "CALL apoc.ml.sagemaker.chat(null, $conf)", + Map.of("conf", emptyMap()), + (row) -> assertNull(row.get("value")) + ); + } + }