diff --git a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java
index c20c382e3..c8360288c 100644
--- a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java
+++ b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java
@@ -88,9 +88,7 @@ private void registerFunctions() {
         this.expressionContext
                 .getFunctionMapper()
                 .mapFunction(
-                        "fn",
-                        "concat",
-                        JstlFunctions.class.getMethod("concat", Object.class, Object.class));
+                        "fn", "concat", JstlFunctions.class.getMethod("concat", Object[].class));
         this.expressionContext
                 .getFunctionMapper()
                 .mapFunction(
@@ -127,6 +125,10 @@ private void registerFunctions() {
                         "fn",
                         "addAll",
                         JstlFunctions.class.getMethod("addAll", Object.class, Object.class));
+        this.expressionContext
+                .getFunctionMapper()
+                .mapFunction(
+                        "fn", "listOf", JstlFunctions.class.getMethod("listOf", Object[].class));
         this.expressionContext
                 .getFunctionMapper()
                 .mapFunction("fn", "emptyList", JstlFunctions.class.getMethod("emptyList"));
@@ -153,6 +155,24 @@ private void registerFunctions() {
         this.expressionContext
                 .getFunctionMapper()
                 .mapFunction("fn", "emptyMap", JstlFunctions.class.getMethod("emptyMap"));
+        this.expressionContext
+                .getFunctionMapper()
+                .mapFunction(
+                        "fn",
+                        "mapPut",
+                        JstlFunctions.class.getMethod(
+                                "mapPut", Object.class, Object.class, Object.class));
+        this.expressionContext
+                .getFunctionMapper()
+                .mapFunction("fn", "mapOf", JstlFunctions.class.getMethod("mapOf", Object[].class));
+
+        this.expressionContext
+                .getFunctionMapper()
+                .mapFunction(
+                        "fn",
+                        "mapRemove",
+                        JstlFunctions.class.getMethod("mapRemove", Object.class, Object.class));
+
         this.expressionContext
                 .getFunctionMapper()
                 .mapFunction("fn", "toInt", JstlFunctions.class.getMethod("toInt", Object.class));
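A minimal sketch of what the re-registered variadic signatures enable on the expression side (illustrative only, not part of the diff; all literal values here are invented for the example):

    import ai.langstream.ai.agents.commons.jstl.JstlFunctions;
    import java.util.List;
    import java.util.Map;

    public class VariadicFunctionsSketch {
        public static void main(String[] args) {
            // concat now accepts any number of arguments and skips nulls
            String s = JstlFunctions.concat("a", null, "b", "c"); // "abc"

            // listOf wraps its arguments in a mutable ArrayList
            List<Object> l = JstlFunctions.listOf(0, 1, 2, 3, 4);

            // mapOf consumes varargs as alternating key/value pairs; keys are
            // stringified, and an odd argument count throws
            // ArrayIndexOutOfBoundsException on the last unpaired key
            Map<String, Object> m = JstlFunctions.mapOf("filename", "doc.txt", "chunk_id", 3);
            System.out.println(s + " " + l + " " + m);
        }
    }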
diff --git a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java
index 038f0889c..cedf455d7 100644
--- a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java
+++ b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java
@@ -26,6 +26,7 @@
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
@@ -164,6 +165,52 @@ public static Map<String, Object> emptyMap() {
         return Map.of();
     }
 
+    public static Map<String, Object> mapOf(Object... field) {
+        Map<String, Object> result = new HashMap<>();
+        for (int i = 0; i < field.length; i += 2) {
+            result.put(field[i].toString(), field[i + 1]);
+        }
+        return result;
+    }
+
+    public static List<Object> listOf(Object... field) {
+        List<Object> result = new ArrayList<>();
+        result.addAll(Arrays.asList(field));
+        return result;
+    }
+
+    public static Map<String, Object> mapPut(Object map, Object field, Object value) {
+        Map<String, Object> result = new HashMap<>();
+        if (map != null) {
+            if (map instanceof Map m) {
+                result.putAll(m);
+            } else {
+                throw new IllegalArgumentException("mapPut doesn't allow a non-map value");
+            }
+        }
+        if (field == null || field.toString().isEmpty()) {
+            throw new IllegalArgumentException("mapPut doesn't allow a null field");
+        }
+        result.put(field.toString(), value);
+        return result;
+    }
+
+    public static Map<String, Object> mapRemove(Object map, Object field) {
+        Map<String, Object> result = new HashMap<>();
+        if (map != null) {
+            if (map instanceof Map m) {
+                result.putAll(m);
+            } else {
+                throw new IllegalArgumentException("mapRemove doesn't allow a non-map value");
+            }
+        }
+        if (field == null || field.toString().isEmpty()) {
+            throw new IllegalArgumentException("mapRemove doesn't allow a null field");
+        }
+        result.remove(field.toString());
+        return result;
+    }
+
     public static List<Map<String, Object>> mapToListOfStructs(Object object, String fields) {
         if (object == null) {
             throw new IllegalArgumentException("listOf doesn't allow a null value");
@@ -282,12 +329,18 @@ public static String trim(Object input) {
         return input == null ? null : toString(input).trim();
     }
 
-    public static String concat(Object first, Object second) {
-        return toString(first) + toString(second);
+    public static String concat(Object... elements) {
+        StringBuilder sb = new StringBuilder();
+        for (Object o : elements) {
+            if (o != null) {
+                sb.append(toString(o));
+            }
+        }
+        return sb.toString();
     }
 
     public static String concat3(Object first, Object second, Object third) {
-        return toString(first) + toString(second) + toString(third);
+        return concat(first, second, third);
     }
 
     public static Object coalesce(Object value, Object valueIfNull) {
diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java
index 5ea87c7a6..bc02a8d52 100644
--- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java
+++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java
@@ -319,6 +319,9 @@ private Schema getAvroSchema(ComputeFieldType type, Object value) {
             case BYTES:
                 schemaType = Schema.Type.BYTES;
                 break;
+            case MAP:
+                schemaType = Schema.Type.MAP;
+                break;
             case ARRAY:
                 schemaType = Schema.Type.ARRAY;
                 break;
@@ -342,6 +345,10 @@ private Schema getAvroSchema(ComputeFieldType type, Object value) {
             return Schema.createArray(
                     Schema.createMap(Schema.create(Schema.Type.STRING)));
         }
+        if (schemaType == Schema.Type.MAP) {
+            // we don't know the value type of the map, so we default to string values
+            return Schema.createMap(Schema.create(Schema.Type.STRING));
+        }
 
         // Handle logical types:
         // https://avro.apache.org/docs/1.10.2/spec.html#Logical+Types
@@ -517,6 +524,9 @@ private ComputeFieldType getFieldType(Object value) {
         if (List.class.isAssignableFrom(value.getClass())) {
             return ComputeFieldType.ARRAY;
         }
+        if (Map.class.isAssignableFrom(value.getClass())) {
+            return ComputeFieldType.MAP;
+        }
         throw new UnsupportedOperationException("Got an unsupported type: " + value.getClass());
     }
 }
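Two behaviors worth calling out, sketched below with invented values: mapPut/mapRemove are copy-on-write rather than mutating, and a computed MAP field is given an Avro map-of-string schema because the value type is unknown at schema-build time:

    import ai.langstream.ai.agents.commons.jstl.JstlFunctions;
    import java.util.Map;
    import org.apache.avro.Schema;

    public class MapSemanticsSketch {
        public static void main(String[] args) {
            // Copy-on-write: the input map is copied into a fresh HashMap,
            // so `base` is unchanged by both calls.
            Map<String, Object> base = JstlFunctions.mapOf("a", 1, "b", 2);
            Map<String, Object> grown = JstlFunctions.mapPut(base, "c", 3);   // {a=1, b=2, c=3}
            Map<String, Object> shrunk = JstlFunctions.mapRemove(grown, "a"); // {b=2, c=3}
            System.out.println(base + " -> " + shrunk);

            // The schema ComputeStep builds for a MAP field: string keys
            // (implicit in Avro) and string values by default.
            Schema mapSchema = Schema.createMap(Schema.create(Schema.Type.STRING));
            System.out.println(mapSchema); // {"type":"map","values":"string"}
        }
    }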
diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeFieldType.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeFieldType.java
index 54c1285a1..4f9d70604 100644
--- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeFieldType.java
+++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeFieldType.java
@@ -39,5 +39,6 @@ public enum ComputeFieldType {
     DATETIME,
     BYTES,
     DECIMAL,
-    ARRAY
+    ARRAY,
+    MAP
 }
diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java
index bcd578564..d3e41252c 100644
--- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java
+++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java
@@ -34,7 +34,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
@@ -176,15 +176,15 @@ public void initialise(Map<String, Object> agentConfiguration) {
         processor.start(configuration);
     }
 
-    private final AtomicReference<CompletableFuture<?>> currentRecordStatus =
-            new AtomicReference<>();
+    private final Map<Record, CompletableFuture<?>> currentRecordStatus =
+            new ConcurrentHashMap<>();
 
     @Override
     public CompletableFuture<?> upsert(Record record, Map<String, Object> context) {
         // we must handle one record at a time
         // so we block until the record is processed
         CompletableFuture<?> handle = new CompletableFuture<>();
-        currentRecordStatus.set(handle);
+        currentRecordStatus.put(record, handle);
         processor.put(List.of(new LangStreamSinkRecordAdapter(record)));
         return handle;
     }
@@ -208,7 +208,8 @@ public String applicationName() {
     @Override
     protected void handleSuccess(AbstractSinkRecord abstractRecord) {
         Record record = ((LangStreamSinkRecordAdapter) abstractRecord).getRecord();
-        currentRecordStatus.get().complete(null);
+        CompletableFuture<?> remove = currentRecordStatus.remove(record);
+        remove.complete(null);
     }
 
     @Override
@@ -242,12 +243,14 @@ protected void handleFailure(
             log.warn("Error decoding/mapping Kafka record {}: {}", record, e.getMessage());
         }
 
+        CompletableFuture<?> remove = currentRecordStatus.remove(record);
+
         if (ignoreErrors == CassandraSinkConfig.IgnoreErrorsPolicy.NONE
                 || (ignoreErrors == CassandraSinkConfig.IgnoreErrorsPolicy.DRIVER
                         && !driverFailure)) {
-            currentRecordStatus.get().completeExceptionally(e);
+            remove.completeExceptionally(e);
         } else {
-            currentRecordStatus.get().complete(null);
+            remove.complete(null);
        }
 
         failCounter.run();
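The writer change replaces one shared AtomicReference with a per-record map, so each completion is matched to the record that produced it even with several upserts in flight. A simplified sketch of the pattern (hypothetical key type; the real code keys by the LangStream Record):

    import java.util.Map;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentHashMap;

    // With a single AtomicReference, a second submit() could overwrite the
    // first record's pending future before the sink acknowledged it; keying
    // by record keeps every in-flight future individually addressable.
    public class PerRecordStatusSketch<K> {
        private final Map<K, CompletableFuture<?>> inFlight = new ConcurrentHashMap<>();

        public CompletableFuture<?> submit(K record) {
            CompletableFuture<?> handle = new CompletableFuture<>();
            inFlight.put(record, handle);
            return handle; // caller awaits this while the sink works
        }

        public void onSuccess(K record) {
            // assumes exactly one success/failure callback per submitted record
            inFlight.remove(record).complete(null);
        }

        public void onFailure(K record, Throwable error) {
            inFlight.remove(record).completeExceptionally(error);
        }
    }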
diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/CassandraVectorAssetQueryWriteIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/CassandraVectorAssetQueryWriteIT.java
index 6f462b5b3..2c7122197 100644
--- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/CassandraVectorAssetQueryWriteIT.java
+++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/CassandraVectorAssetQueryWriteIT.java
@@ -21,6 +21,7 @@
 import com.datastax.oss.driver.api.core.CqlSession;
 import com.datastax.oss.driver.api.core.CqlSessionBuilder;
+import com.datastax.oss.driver.api.core.cql.ColumnDefinitions;
 import com.datastax.oss.driver.api.core.cql.ResultSet;
 import com.datastax.oss.driver.api.core.cql.Row;
 import com.datastax.oss.driver.api.core.data.CqlVector;
@@ -32,6 +33,7 @@
 import lombok.extern.slf4j.Slf4j;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.common.header.internals.RecordHeader;
 import org.junit.jupiter.api.Test;
 import org.testcontainers.containers.CassandraContainer;
 import org.testcontainers.junit.jupiter.Container;
@@ -192,4 +194,191 @@ public void testCassandra() throws Exception {
             }
         }
     }
+
+    @Test
+    public void testCassandraCassioSchema() throws Exception {
+        String tenant = "tenant";
+        String[] expectedAgents = {"app-step1"};
+
+        Map<String, String> application =
+                Map.of(
+                        "configuration.yaml",
+                        """
+                        configuration:
+                          resources:
+                            - type: "datasource"
+                              name: "CassandraDatasource"
+                              configuration:
+                                service: "cassandra"
+                                contact-points: "%s"
+                                loadBalancing-localDc: "%s"
+                                port: %d
+                        """
+                                .formatted(
+                                        cassandra.getContactPoint().getHostString(),
+                                        cassandra.getLocalDatacenter(),
+                                        cassandra.getContactPoint().getPort()),
+                        "pipeline.yaml",
+                        """
+                        assets:
+                          - name: "cassio-keyspace"
+                            asset-type: "cassandra-keyspace"
+                            creation-mode: create-if-not-exists
+                            deletion-mode: delete
+                            config:
+                              keyspace: "cassio"
+                              datasource: "CassandraDatasource"
+                              create-statements:
+                                - "CREATE KEYSPACE IF NOT EXISTS cassio WITH REPLICATION = {'class' : 'SimpleStrategy','replication_factor' : 1};"
+                              delete-statements:
+                                - "DROP KEYSPACE IF EXISTS cassio;"
+                          - name: "documents-table"
+                            asset-type: "cassandra-table"
+                            creation-mode: create-if-not-exists
+                            config:
+                              table-name: "documents"
+                              keyspace: "cassio"
+                              datasource: "CassandraDatasource"
+                              create-statements:
+                                - |
+                                  CREATE TABLE IF NOT EXISTS cassio.documents (
+                                      row_id text PRIMARY KEY,
+                                      attributes_blob text,
+                                      body_blob text,
+                                      metadata_s map<text, text>,
+                                      vector vector<float, 5>
+                                  );
+                                - |
+                                  CREATE INDEX IF NOT EXISTS documents_metadata ON cassio.documents (ENTRIES(metadata_s));
+                        topics:
+                          - name: "input-topic-cassio"
+                            creation-mode: create-if-not-exists
+                        pipeline:
+                          - id: step1
+                            name: "Split into chunks"
+                            type: "text-splitter"
+                            input: "input-topic-cassio"
+                            configuration:
+                              chunk_size: 10
+                              chunk_overlap: 0
+                              keep_separator: true
+                              length_function: "length"
+                          - name: "Convert to structured data"
+                            type: "document-to-json"
+                            configuration:
+                              text-field: text
+                              copy-properties: true
+                          - name: "Find old chunks"
+                            type: "query"
+                            configuration:
+                              datasource: "CassandraDatasource"
+                              when: "fn:toInt(properties.text_num_chunks) == (fn:toInt(properties.chunk_id) + 1)"
+                              mode: "query"
+                              query: "SELECT row_id, metadata_s['chunk_id'] as chunk_id FROM cassio.documents WHERE metadata_s['filename'] = ?"
+                              output-field: "value.stale_chunks"
+                              fields:
+                                - "properties.filename"
+                          - name: "Delete stale chunks"
+                            type: "query"
+                            configuration:
+                              datasource: "CassandraDatasource"
+                              when: "fn:toInt(properties.text_num_chunks) == (fn:toInt(properties.chunk_id) + 1)"
+                              loop-over: "value.stale_chunks"
+                              mode: "execute"
+                              query: "DELETE FROM cassio.documents WHERE row_id = ?"
+                              output-field: "value.delete-results"
+                              fields:
+                                - "record.row_id"
+                          - type: compute
+                            name: "Compute metadata"
+                            configuration:
+                              fields:
+                                - name: "value.metadata_s"
+                                  expression: "fn:mapOf('filename', properties.filename, 'chunk_id', properties.chunk_id)"
+                                - name: "value.row_id"
+                                  expression: "fn:uuid()"
+                                - name: "value.vector"
+                                  expression: "fn:listOf(0,1,2,3,4)"
+                                - name: "value.attributes_blob"
+                                  expression: "fn:str('')"
+                          - type: "log-event"
+                            name: "Log event"
+                          - name: "Write a new record to Cassandra"
+                            type: "vector-db-sink"
+                            configuration:
+                              datasource: "CassandraDatasource"
+                              table-name: "documents"
+                              keyspace: "cassio"
+                              mapping: "row_id=value.row_id,attributes_blob=value.attributes_blob,body_blob=value.text,metadata_s=value.metadata_s,vector=value.vector"
+                        """);
+
+        try (ApplicationRuntime applicationRuntime =
+                deployApplication(
+                        tenant, "app", application, buildInstanceYaml(), expectedAgents)) {
+            try (KafkaProducer<String, String> producer = createProducer()) {
+
+                String filename = "doc.txt";
+                sendMessage(
+                        "input-topic-cassio",
+                        filename,
+                        """
+                        This is some very long long long long long long long long long long long long text""",
+                        List.of(new RecordHeader("filename", filename.getBytes())),
+                        producer);
+
+                executeAgentRunners(applicationRuntime);
+
+                CqlSessionBuilder builder = new CqlSessionBuilder();
+                builder.addContactPoint(cassandra.getContactPoint());
+                builder.withLocalDatacenter(cassandra.getLocalDatacenter());
+
+                try (CqlSession cqlSession = builder.build()) {
+                    ResultSet execute = cqlSession.execute("SELECT * FROM cassio.documents");
+                    List<Row> all = execute.all();
+                    all.forEach(
+                            row -> {
+                                log.info("row id {}", row.get("row_id", String.class));
+                                ColumnDefinitions columnDefinitions = row.getColumnDefinitions();
+                                for (int i = 0; i < columnDefinitions.size(); i++) {
+                                    log.info(
+                                            "column {} value {}",
+                                            columnDefinitions.get(i).getName(),
+                                            row.getObject(i));
+                                }
+                            });
+                    assertEquals(9, all.size());
+                }
+
+                sendMessage(
+                        "input-topic-cassio",
+                        filename,
+                        """
+                        Now the text is shorter""",
+                        List.of(new RecordHeader("filename", filename.getBytes())),
+                        producer);
+
+                executeAgentRunners(applicationRuntime);
+
+                try (CqlSession cqlSession = builder.build()) {
+                    ResultSet execute = cqlSession.execute("SELECT * FROM cassio.documents");
+                    List<Row> all = execute.all();
+                    log.info("final records {}", all);
+                    all.forEach(
+                            row -> {
+                                log.info("row id {}", row.get("row_id", String.class));
+                                ColumnDefinitions columnDefinitions = row.getColumnDefinitions();
+                                for (int i = 0; i < columnDefinitions.size(); i++) {
+                                    log.info(
+                                            "column {} value {}",
+                                            columnDefinitions.get(i).getName(),
+                                            row.getObject(i));
+                                }
+                            });
+                    assertEquals(3, all.size());
+                }
+
+                applicationDeployer.cleanup(tenant, applicationRuntime.implementation());
+            }
+        }
+    }
 }
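A rough accounting behind the two assertions (assuming the splitter's length_function "length" counts characters, with chunk_overlap 0): the first message is about 82 characters, so a chunk_size of 10 yields the 9 rows asserted after the first run. The second message, "Now the text is shorter" (23 characters), splits into 3 chunks; its final chunk satisfies the `when` condition on both query steps, so "Find old chunks" looks up the previous run's rows by metadata_s['filename'] and "Delete stale chunks" loops over them, issuing one DELETE per row_id. The intent is that the first run's 9 rows are gone once the second run completes, leaving only the 3 new rows.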