Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes 3611: ML API procedures handle null values being passed in better #439

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ subprojects {

ext {
// NB: due to version.json generation by parsing this file, the next line must not have any if/then/else logic
neo4jVersion = "5.19.0"
neo4jVersion = "5.18.0"
// instead we apply the override logic here
neo4jVersionEffective = project.hasProperty("neo4jVersionOverride") ? project.getProperty("neo4jVersionOverride") : neo4jVersion
testContainersVersion = '1.18.3'
Expand Down
27 changes: 24 additions & 3 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE;
Expand Down Expand Up @@ -111,13 +113,26 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
"model": "text-embedding-ada-002",
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
*/
Stream<Object> resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data", apocConfig, urlAccessChecker);
return resultStream
Map<Boolean, List<String>> collect = texts.stream()
.collect(Collectors.groupingBy(Objects::nonNull));

List<String> nonNullTexts = collect.get(true);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qui non servirebbe vedere se la collezione è nulla di modo da poter lanciare un eccezione prima di fare la richiesta?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sì in effetti è meglio, ho aggiunto il check con l'eccezione sopra.
Questa parte l'ho lasciata cosi perché è richiesto dalla issue:

But we should probably just filter those out in the embeddings call and don't forward them and return the result row with a null value for embedding.


Stream<Object> resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker);
Stream<EmbeddingResult> embeddingResultStream = resultStream
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(m -> {
Long index = (Long) m.get("index");
return new EmbeddingResult(index, texts.get(index.intValue()), (List<Double>) m.get("embedding"));
return new EmbeddingResult(index, nonNullTexts.get(index.intValue()), (List<Double>) m.get("embedding"));
});

List<String> nullTexts = collect.getOrDefault(false, List.of());
Stream<EmbeddingResult> nullResultStream = nullTexts.stream()
.map(i -> {
// null text return index -1 to indicate that are not coming from `/embeddings` RestAPI
return new EmbeddingResult(-1, i, List.of());
});
return Stream.concat(embeddingResultStream, nullResultStream);
}


Expand All @@ -132,13 +147,19 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("api_ke
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }
}
*/
if (prompt == null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qui ugualmente non sarebbe meglio fallire prima di fare la richiesta?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cambiato

return Stream.of(new MapResult(null));
}
return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
}

@Procedure("apoc.ml.openai.chat")
@Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API")
public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (messages == null) {
return Stream.of(new MapResult(null));
}
return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
// https://platform.openai.com/docs/api-reference/chat/create
Expand Down
29 changes: 26 additions & 3 deletions extended/src/main/java/apoc/ml/VertexAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;


Expand Down Expand Up @@ -94,16 +96,30 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
]
}
*/

Map<Boolean, List<String>> collect = texts.stream()
.collect(Collectors.groupingBy(Objects::nonNull));

List<String> nonNullTexts = collect.get(true);
vga91 marked this conversation as resolved.
Show resolved Hide resolved

Object inputs = texts.stream().map(text -> Map.of("content", text)).toList();
Stream<Object> resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, List.of(), urlAccessChecker);
AtomicInteger ai = new AtomicInteger();
return resultStream
Stream<EmbeddingResult> embeddingResultStream = resultStream
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(m -> {
Map<String,Object> embeddings = (Map<String, Object>) ((Map)m).get("embeddings");
Map<String, Object> embeddings = (Map<String, Object>) ((Map) m).get("embeddings");
int index = ai.getAndIncrement();
return new EmbeddingResult(index, texts.get(index), (List<Double>) embeddings.get("values"));
return new EmbeddingResult(index, nonNullTexts.get(index), (List<Double>) embeddings.get("values"));
});

List<String> nullTexts = collect.getOrDefault(false, List.of());
Stream<EmbeddingResult> nullResultStream = nullTexts.stream()
.map(text -> {
// null text return index -1 to indicate that are not coming from `/embeddings` RestAPI
return new EmbeddingResult(-1, text, List.of());
});
return Stream.concat(embeddingResultStream, nullResultStream);
}


Expand Down Expand Up @@ -153,6 +169,10 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("access
]
}
*/
if (prompt == null) {
return Stream.of(new MapResult(null));
vga91 marked this conversation as resolved.
Show resolved Hide resolved
}

Object input = List.of(Map.of("prompt",prompt));
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
var resultStream = executeRequest(accessToken, project, configuration, "text-bison", input, parameterKeys, urlAccessChecker);
Expand Down Expand Up @@ -192,6 +212,9 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Strin
@Name(value = "context",defaultValue = "") String context,
@Name(value = "examples", defaultValue = "[]") List<Map<String, Map<String,String>>> examples
) throws Exception {
if (messages == null) {
return Stream.of(new MapResult(null));
}
Object inputs = List.of(Map.of("context",context, "examples",examples, "messages", messages));
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
return executeRequest(accessToken, project, configuration, "chat-bison", inputs, parameterKeys, urlAccessChecker)
Expand Down
6 changes: 6 additions & 0 deletions extended/src/main/java/apoc/ml/Watson.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class Watson {
@Procedure("apoc.ml.watson.chat")
@Description("apoc.ml.watson.chat(messages, accessToken, $configuration) - prompts the completion API")
public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("accessToken") String accessToken, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (messages == null) {
return Stream.of(new MapResult(null));
}
String prompt = messages.stream()
.map(message -> {
Object role = message.get("role");
Expand All @@ -56,6 +59,9 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
@Procedure("apoc.ml.watson.completion")
@Description("apoc.ml.watson.completion(prompt, accessToken, $configuration) - prompts the completion API")
public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("accessToken") String accessToken, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (prompt == null) {
return Stream.of(new MapResult(null));
}
return executeRequest(prompt, accessToken, configuration);
}

Expand Down
9 changes: 7 additions & 2 deletions extended/src/main/java/apoc/ml/aws/Bedrock.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ public Stream<MapResult> custom(@Name(value = "body") Map<String, Object> body,
public Stream<MapResult> chatCompletion(
@Name("messages") List<Map<String, Object>> messages,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {

if (messages == null) {
return Stream.of(new MapResult(null));
}
var config = new HashMap<>(configuration);
config.putIfAbsent(MODEL, ANTHROPIC_CLAUDE_V2);

Expand Down Expand Up @@ -94,7 +96,9 @@ private void transformOpenAiToBedrockRequestBody(Map<String, Object> message) {
@Description("apoc.ml.bedrock.completion(prompt, $conf) - prompts the completion API")
public Stream<MapResult> completion(@Name("prompt") String prompt,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {

if (prompt == null) {
return Stream.of(new MapResult(null));
}
var config = new HashMap<>(configuration);
config.putIfAbsent(MODEL, JURASSIC_2_ULTRA);

Expand All @@ -109,6 +113,7 @@ public Stream<MapResult> completion(@Name("prompt") String prompt,
@Description("apoc.ml.bedrock.embedding([texts], $configuration) - returns the embeddings for a given text")
public Stream<Embedding> embedding(@Name(value = "texts") List<String> texts,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {

var config = new HashMap<>(configuration);
config.putIfAbsent(MODEL, TITAN_EMBED_TEXT);

Expand Down
36 changes: 36 additions & 0 deletions extended/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
package apoc.ml;

import apoc.util.TestUtil;
import apoc.util.collection.Iterators;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.neo4j.test.rule.DbmsRule;
import org.neo4j.test.rule.ImpermanentDbmsRule;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import static apoc.ml.OpenAI.MODEL_CONF_KEY;
import static apoc.ml.OpenAITestResultUtils.*;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
import static java.util.Collections.emptyMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

public class OpenAIIT {

Expand Down Expand Up @@ -68,6 +74,19 @@ public void getEmbedding3LargeWithDimensionsRequestParameter() {
r -> assertEmbeddings(r, 256));
}

@Test
public void getEmbeddingNull() {
testResult(db, "CALL apoc.ml.openai.embedding([null, 'Some Text', null, 'Other Text'], $apiKey, $conf)", Map.of("apiKey",openaiKey, "conf", emptyMap()),
r -> {
Set<String> actual = Iterators.asSet(r.columnAs("text"));

Set<String> expected = new HashSet<>() {{
add(null); add(null); add("Some Text"); add("Other Text");
}};
assertEquals(expected, actual);
});
}

@Test
public void completion() {
testCall(db, COMPLETION_QUERY,
Expand Down Expand Up @@ -104,4 +123,21 @@ public void chatCompletion() {
}
*/
}

@Test
public void completionNull() {
testCall(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap()),
(row) -> assertNull(row.get("value"))
);
}

@Test
public void chatCompletionNull() {
testCall(db,
"CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap()),
(row) -> assertNull(row.get("value"))
);
}
}
37 changes: 37 additions & 0 deletions extended/src/test/java/apoc/ml/VertexAIIT.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apoc.ml;

import apoc.util.TestUtil;
import apoc.util.collection.Iterators;
import org.apache.commons.io.FileUtils;
import org.junit.Assume;
import org.junit.Before;
Expand All @@ -13,15 +14,20 @@
import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static apoc.ml.VertexAIHandler.ENDPOINT_CONF_KEY;
import static apoc.ml.VertexAIHandler.MODEL_CONF_KEY;
import static apoc.ml.VertexAIHandler.PREDICT_RESOURCE;
import static apoc.ml.VertexAIHandler.RESOURCE_CONF_KEY;
import static apoc.ml.VertexAIHandler.STREAM_RESOURCE;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
import static java.util.Collections.emptyMap;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
Expand Down Expand Up @@ -66,6 +72,20 @@ public void getEmbedding() {
});
}

@Test
public void getEmbeddingNull() {
testResult(db, "CALL apoc.ml.vertexai.embedding([null, 'Some Text', null, 'Other Text'], $apiKey, $project)",
parameters,
r -> {
Set<String> actual = Iterators.asSet(r.columnAs("text"));

Set<String> expected = new HashSet<>() {{
add(null); add(null); add("Some Text"); add("Other Text");
}};
assertEquals(expected, actual);
});
}

@Test
public void completion() {
testCall(db, "CALL apoc.ml.vertexai.completion('What color is the sky? Answer in one word: ', $apiKey, $project)", parameters,(row) -> {
Expand Down Expand Up @@ -209,4 +229,21 @@ private void assertCorrectResponse(Map<String, Object> row, String expected) {
assertTrue(stringRow.toLowerCase().contains(expected),
"Actual result is: " + stringRow);
}

@Test
public void completionNull() {
testCall(db, "CALL apoc.ml.vertexai.completion(null, $apiKey, $project)",
parameters,
(row) -> assertNull(row.get("value"))
);
}

@Test
public void chatCompletionNull() {
testCall(db,
"CALL apoc.ml.vertexai.chat(null, $apiKey, $project)",
parameters,
(row) -> assertNull(row.get("value"))
);
}
}
19 changes: 19 additions & 0 deletions extended/src/test/java/apoc/ml/WatsonIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import static apoc.ExtendedApocConfig.APOC_ML_WATSON_URL;
import static apoc.ExtendedApocConfig.APOC_ML_WATSON_PROJECT_ID;
import static apoc.util.TestUtil.testCall;
import static java.util.Collections.emptyMap;
import static org.junit.Assert.assertNull;
import static org.junit.Assume.assumeNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand Down Expand Up @@ -130,4 +132,21 @@ private static void commonAssertions(Map<String, Object> row, String text, long
assertEquals(inputTokenCount, result.get("input_token_count"));
assertEquals(stopReason, result.get("stop_reason"));
}

@Test
public void completionNull() {
testCall(db, "CALL apoc.ml.watson.completion(null, $apiKey, $conf)",
Map.of("apiKey", accessToken, "conf", emptyMap()),
(row) -> assertNull(row.get("value"))
);
}

@Test
public void chatCompletionNull() {
testCall(db,
"CALL apoc.ml.watson.chat(null, $apiKey, $conf)",
Map.of("apiKey", accessToken, "conf", emptyMap()),
(row) -> assertNull(row.get("value"))
);
}
}
19 changes: 19 additions & 0 deletions extended/src/test/java/apoc/ml/aws/BedrockIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import static apoc.ml.aws.BedrockUtil.*;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
import static java.util.Collections.emptyMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeNotNull;
Expand Down Expand Up @@ -291,4 +293,21 @@ private static void assertionsTitanEmbed(Map value) {
assertNotNull(value.get("inputTextTokenCount"));
assertNotNull(value.get("embedding"));
}

@Test
public void completionNull() {
testCall(db, "CALL apoc.ml.bedrock.completion(null, $conf)",
Map.of("conf", emptyMap()),
(row) -> assertNull(row.get("value"))
);
}

@Test
public void chatCompletionNull() {
testCall(db,
"CALL apoc.ml.bedrock.chat(null, $conf)",
Map.of("conf", emptyMap()),
(row) -> assertNull(row.get("value"))
);
}
}
Loading
Loading