Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
GenAI: add support for compute-ai-embeddings function (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli authored Jun 30, 2023
1 parent bacd547 commit 24b5a65
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.datastax.oss.sga.api.runtime;

import com.datastax.oss.sga.api.model.ApplicationInstance;
import com.datastax.oss.sga.api.model.Connection;
import com.datastax.oss.sga.api.model.Module;

Expand All @@ -23,6 +24,8 @@
*/
public interface PhysicalApplicationInstance {

ApplicationInstance getApplicationInstance();

ConnectionImplementation getConnectionImplementation(Module module, Connection connection);

AgentImplementation getAgentImplementation(Module module, String id);
Expand Down
9 changes: 9 additions & 0 deletions examples/applications/app4/configuration.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
configuration:
resources:
- type: "open-ai-configuration"
name: "OpenAI Azure configuration"
id: "openai-configuration"
configuration:
url: "https://put-here-your-api-server"
access-key: "put-here-your-api-key"
provider: "azure"
20 changes: 20 additions & 0 deletions examples/applications/app4/pipeline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module: "module-1"
id: "pipeline-1"
topics:
- name: "input-topic"
creation-mode: create-if-not-exists
schema:
type: avro
schema: '{"type":"record","namespace":"examples","name":"Product","fields":[{"name":"id","type":"string"},{"name":"name","type":"string"},{"name":"description","type":"string"}]}'
- name: "output-topic"
creation-mode: create-if-not-exists
pipeline:
- name: "compute-embeddings"
id: "step1"
type: "compute-ai-embeddings"
input: "input-topic"
output: "output-topic"
configuration:
model: "text-embedding-ada-002"
embeddings-field: "value.embeddings"
text: "{{ value.name }} {{ value.description }}"
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public PulsarPhysicalApplicationInstance createImplementation(ApplicationInstanc
String tenant = (String) streamingCluster.configuration().getOrDefault("defaultTenant", "public");
String namespace = (String) streamingCluster.configuration().getOrDefault("defaultNamespace", "default");

PulsarPhysicalApplicationInstance result = new PulsarPhysicalApplicationInstance(tenant, namespace);
PulsarPhysicalApplicationInstance result = new PulsarPhysicalApplicationInstance(applicationInstance, tenant, namespace);

detectTopics(applicationInstance, result, tenant, namespace);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.datastax.oss.sga.pulsar;

import com.datastax.oss.sga.api.model.AgentConfiguration;
import com.datastax.oss.sga.api.model.ApplicationInstance;
import com.datastax.oss.sga.api.model.Connection;
import com.datastax.oss.sga.api.model.Module;
import com.datastax.oss.sga.api.model.SchemaDefinition;
Expand All @@ -21,6 +22,7 @@ public class PulsarPhysicalApplicationInstance implements PhysicalApplicationIns
private final Map<PulsarName, PulsarTopic> topics = new HashMap<>();
private final Map<String, AgentImplementation> agents = new HashMap<>();

private final ApplicationInstance applicationInstance;
private final String defaultTenant;
private final String defaultNamespace;

Expand All @@ -39,6 +41,11 @@ public PulsarTopic registerTopic(String tenant, String namespace, String name, S
return pulsarTopic;
}

/**
 * Returns the {@link ApplicationInstance} stored in this object's
 * {@code applicationInstance} field.
 *
 * @return the value of the {@code applicationInstance} field
 */
@Override
public ApplicationInstance getApplicationInstance() {
    return this.applicationInstance;
}

@Override
public ConnectionImplementation getConnectionImplementation(Module module, Connection connection) {
Connection.Connectable endpoint = connection.endpoint();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.datastax.oss.sga.pulsar.agents.ai;

import java.util.List;
import java.util.Map;

/**
 * Agent provider registered for the {@code compute-ai-embeddings} step type.
 *
 * <p>It contributes a single step description built from the agent configuration,
 * falling back to defaults for the model name, the target embeddings field and the
 * text template.
 */
public class ComputeEmbeddingsAgentProvider extends GenAIToolKitFunctionAgentProvider {

    /** Step type handled by this provider; also written as the {@code type} of the generated step. */
    private static final String STEP_TYPE = "compute-ai-embeddings";

    public ComputeEmbeddingsAgentProvider() {
        super(STEP_TYPE);
    }

    /**
     * Appends one {@code compute-ai-embeddings} step to {@code steps}.
     *
     * @param originalConfiguration the agent configuration to read {@code model},
     *                              {@code embeddings-field} and {@code text} from
     * @param steps                 the list the generated step is added to
     */
    @Override
    protected void generateSteps(Map<String, Object> originalConfiguration, List<Map<String, Object>> steps) {
        // Map.of rejects null values, so every entry must be resolved to a non-null value first.
        Map<String, Object> step = Map.of(
                "type", STEP_TYPE,
                "model", valueOrDefault(originalConfiguration, "model", "text-embedding-ada-002"),
                "embeddings-field", valueOrDefault(originalConfiguration, "embeddings-field", "embeddings"),
                "text", valueOrDefault(originalConfiguration, "text", "")
        );
        steps.add(step);
    }

    /**
     * Looks up {@code key}, returning {@code defaultValue} when the key is absent
     * or explicitly mapped to {@code null}.
     */
    private static Object valueOrDefault(Map<String, Object> configuration, String key, Object defaultValue) {
        Object value = configuration.get(key);
        return value != null ? value : defaultValue;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.datastax.oss.sga.pulsar.agents.ai;

import com.datastax.oss.sga.api.model.AgentConfiguration;
import com.datastax.oss.sga.api.model.ApplicationInstance;
import com.datastax.oss.sga.api.model.Module;
import com.datastax.oss.sga.api.model.Resource;
import com.datastax.oss.sga.api.runtime.ClusterRuntime;
import com.datastax.oss.sga.api.runtime.ConnectionImplementation;
import com.datastax.oss.sga.api.runtime.PhysicalApplicationInstance;
import com.datastax.oss.sga.pulsar.PulsarClusterRuntime;
import com.datastax.oss.sga.pulsar.PulsarTopic;
import com.datastax.oss.sga.pulsar.agents.AbstractPulsarFunctionAgentProvider;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Base provider for agents mapped to the {@code ai-tools} Pulsar function type.
 *
 * <p>The computed agent configuration contains a {@code steps} list, filled in by
 * {@link #generateSteps(Map, List)}, and, when the application declares a resource of
 * type {@code open-ai-configuration}, an {@code openai} section copied from that resource.
 */
public class GenAIToolKitFunctionAgentProvider extends AbstractPulsarFunctionAgentProvider {

    /** Resource type that carries the OpenAI connection settings. */
    private static final String OPEN_AI_RESOURCE_TYPE = "open-ai-configuration";

    /** Keys copied from the OpenAI resource into the {@code openai} section. */
    private static final List<String> OPEN_AI_KEYS = List.of("url", "access-key", "provider");

    /**
     * @param stepType the single step type this provider is registered for
     */
    public GenAIToolKitFunctionAgentProvider(String stepType) {
        super(List.of(stepType), List.of(PulsarClusterRuntime.CLUSTER_TYPE));
    }

    @Override
    protected String getFunctionType(AgentConfiguration agentConfiguration) {
        // https://github.com/datastax/pulsar-transformations/tree/master/pulsar-ai-tools
        return "ai-tools";
    }

    @Override
    protected String getFunctionClassname(AgentConfiguration agentConfiguration) {
        // NOTE(review): presumably no class name is needed because the function is
        // identified by its type - confirm against AbstractPulsarFunctionAgentProvider.
        return null;
    }

    /**
     * Hook for subclasses to append their step descriptions to {@code steps}.
     * The default implementation adds nothing.
     *
     * @param originalConfiguration the configuration computed by the parent provider
     * @param steps                 the mutable list that ends up under the {@code steps} key
     */
    protected void generateSteps(Map<String, Object> originalConfiguration, List<Map<String, Object>> steps) {
    }

    /**
     * Copies the settings of the first {@code open-ai-configuration} resource, if any,
     * into {@code configuration} under the {@code openai} key. Settings that are absent
     * from the resource are left out; when no such resource exists nothing is added.
     */
    private void generateOpenAIConfiguration(ApplicationInstance applicationInstance, Map<String, Object> configuration) {
        applicationInstance.getResources().values().stream()
                // Constant-first comparison: a resource declared without a type must not cause an NPE.
                .filter(resource -> OPEN_AI_RESOURCE_TYPE.equals(resource.type()))
                .findFirst()
                .ifPresent(resource -> {
                    Map<String, Object> openaiConfiguration = new HashMap<>();
                    for (String key : OPEN_AI_KEYS) {
                        String value = (String) resource.configuration().get(key);
                        if (value != null) {
                            openaiConfiguration.put(key, value);
                        }
                    }
                    configuration.put("openai", openaiConfiguration);
                });
    }

    /**
     * Builds the function configuration: an optional {@code openai} section plus the
     * {@code steps} list produced by {@link #generateSteps(Map, List)}.
     *
     * <p>The configuration returned by the parent implementation is only passed to
     * {@code generateSteps}; it is not merged into the result.
     */
    @Override
    protected Map<String, Object> computeAgentConfiguration(AgentConfiguration agentConfiguration, Module module,
                                                            PhysicalApplicationInstance physicalApplicationInstance,
                                                            ClusterRuntime clusterRuntime) {
        Map<String, Object> originalConfiguration =
                super.computeAgentConfiguration(agentConfiguration, module, physicalApplicationInstance, clusterRuntime);
        Map<String, Object> configuration = new HashMap<>();

        generateOpenAIConfiguration(physicalApplicationInstance.getApplicationInstance(), configuration);

        List<Map<String, Object>> steps = new ArrayList<>();
        configuration.put("steps", steps);
        generateSteps(originalConfiguration, steps);
        return configuration;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ com.datastax.oss.sga.pulsar.agents.GenericPulsarSourceAgentProvider

# Functions
com.datastax.oss.sga.pulsar.agents.GenericPulsarFunctionAgentProvider

# Streaming AI
com.datastax.oss.sga.pulsar.agents.ai.ComputeEmbeddingsAgentProvider
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -351,7 +353,6 @@ public void testMapGenericPulsarFunctionsChain() throws Exception {


@Test
@Disabled
public void testOpenAIComputeEmbeddingFunction() throws Exception {
ApplicationInstance applicationInstance = ModelBuilder
.buildApplicationInstance(Map.of("instance.yaml",
Expand All @@ -364,6 +365,17 @@ public void testOpenAIComputeEmbeddingFunction() throws Exception {
defaultTenant: "public"
defaultNamespace: "default"
""",
"configuration.yaml",
"""
configuration:
resources:
- name: open-ai
type: open-ai-configuration
configuration:
url: "http://something"
access-key: "xxcxcxc"
provider: "azure"
""",
"module.yaml", """
module: "module-1"
id: "pipeline-1"
Expand Down Expand Up @@ -401,6 +413,24 @@ public void testOpenAIComputeEmbeddingFunction() throws Exception {

AgentImplementation agentImplementation = implementation.getAgentImplementation(module, "step1");
assertNotNull(agentImplementation);
AbstractAgentProvider.DefaultAgentImplementation step = (AbstractAgentProvider.DefaultAgentImplementation) agentImplementation;
Map<String, Object> configuration = step.getConfiguration();
log.info("Configuration: {}", configuration);
Map<String, Object> openAIConfiguration = (Map<String, Object>) configuration.get("openai");
log.info("openAIConfiguration: {}", openAIConfiguration);
assertEquals("http://something", openAIConfiguration.get("url"));
assertEquals("xxcxcxc", openAIConfiguration.get("access-key"));
assertEquals("azure", openAIConfiguration.get("provider"));


List<Map<String, Object>> steps = (List<Map<String, Object>>) configuration.get("steps");
assertEquals(1, steps.size());
Map<String, Object> step1 = steps.get(0);
assertEquals("text-embedding-ada-002", step1.get("model"));
assertEquals("value.embeddings", step1.get("embeddings-field"));
assertEquals("{{ value.name }} {{ value.description }}", step1.get("text"));



}

Expand Down

0 comments on commit 24b5a65

Please sign in to comment.