From 9d18f7b1b98e051ab7b7a5fed176835de42eb83e Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Wed, 22 Nov 2023 15:00:06 +0100 Subject: [PATCH] Implement execute statement and connect to the planner --- .../vector/astra/AstraVectorDBDataSource.java | 122 +++++++++++++++++- .../datasource/impl/AstraVectorDBTest.java | 93 ++++++++++++- .../assets/AstraVectorDBAssetsProvider.java | 89 +++++++++++++ .../impl/assets/CassandraAssetsProvider.java | 13 +- .../resources/DataSourceResourceProvider.java | 6 +- .../datasource/AstraDatasourceConfig.java | 3 +- .../AstraVectorDBDatasourceConfig.java | 62 +++++++++ ...i.langstream.api.runtime.AssetNodeProvider | 3 +- 8 files changed, 379 insertions(+), 12 deletions(-) create mode 100644 langstream-core/src/main/java/ai/langstream/impl/assets/AstraVectorDBAssetsProvider.java create mode 100644 langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraVectorDBDatasourceConfig.java diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java index 74f79cc37..5ca6e69cb 100644 --- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java +++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java @@ -21,18 +21,37 @@ import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource; import com.dtsx.astra.sdk.AstraDB; import io.stargate.sdk.json.CollectionClient; +import io.stargate.sdk.json.domain.DeleteQuery; +import io.stargate.sdk.json.domain.DeleteQueryBuilder; +import io.stargate.sdk.json.domain.JsonDocument; import io.stargate.sdk.json.domain.JsonResult; +import io.stargate.sdk.json.domain.JsonResultUpdate; import io.stargate.sdk.json.domain.SelectQuery; import io.stargate.sdk.json.domain.SelectQueryBuilder; +import io.stargate.sdk.json.domain.UpdateQuery; +import io.stargate.sdk.json.domain.UpdateQueryBuilder; +import java.lang.reflect.Field; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; @Slf4j public class AstraVectorDBDataSource implements QueryStepDataSource { + static final Field update; + + static { + try { + update = UpdateQueryBuilder.class.getDeclaredField("update"); + update.setAccessible(true); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + AstraDB astraDB; @Override @@ -73,7 +92,7 @@ public List> fetchData(String query, List params) { List result; float[] vector = JstlFunctions.toArrayOfFloat(queryMap.remove("vector")); - Integer max = (Integer) queryMap.remove("max"); + Integer limit = (Integer) queryMap.remove("limit"); boolean includeSimilarity = vector != null; Object includeSimilarityParam = queryMap.remove("include-similarity"); @@ -101,8 +120,8 @@ public List> fetchData(String query, List params) { if (filterMap != null) { selectQueryBuilder.withJsonFilter(JstlFunctions.toJson(filterMap)); } - if (max != null) { - selectQueryBuilder.limit(max); + if (limit != null) { + selectQueryBuilder.limit(limit); } SelectQuery selectQuery = selectQueryBuilder.build(); @@ -143,7 +162,102 @@ public Map executeStatement( .map(v -> v == null ? "null" : v.getClass().toString()) .collect(Collectors.joining(","))); } - throw new UnsupportedOperationException(); + try { + Map queryMap = + InterpolationUtils.buildObjectFromJson(query, Map.class, params); + if (queryMap.isEmpty()) { + throw new UnsupportedOperationException("Query is empty"); + } + String collectionName = (String) queryMap.remove("collection-name"); + if (collectionName == null) { + throw new UnsupportedOperationException("collection-name is not defined"); + } + CollectionClient collection = this.getAstraDB().collection(collectionName); + + String action = (String) queryMap.remove("action"); + + switch (action) { + case "findOneAndUpdate": + { + Map filterMap = + (Map) queryMap.remove("filter"); + UpdateQueryBuilder builder = UpdateQuery.builder(); + if (filterMap != null) { + builder.withJsonFilter(JstlFunctions.toJson(filterMap)); + } + String returnDocument = (String) queryMap.remove("return-document"); + if (returnDocument != null) { + builder.withReturnDocument( + UpdateQueryBuilder.ReturnDocument.valueOf(returnDocument)); + } + Map updateMap = + (Map) queryMap.remove("update"); + if (updateMap != null) { + update.set(builder, updateMap); + } + + UpdateQuery updateQuery = builder.build(); + log.info( + "doing findOneAndUpdate with UpdateQuery {}", + JstlFunctions.toJson(updateQuery)); + JsonResultUpdate oneAndUpdate = collection.findOneAndUpdate(updateQuery); + return Map.of("count", oneAndUpdate.getUpdateStatus().getModifiedCount()); + } + case "deleteOne": + { + Map filterMap = + (Map) queryMap.remove("filter"); + DeleteQueryBuilder builder = DeleteQuery.builder(); + if (filterMap != null) { + builder.withJsonFilter(JstlFunctions.toJson(filterMap)); + } + DeleteQuery delete = builder.build(); + log.info( + "doing deleteOne with DeleteQuery {}", + JstlFunctions.toJson(delete)); + int count = collection.deleteOne(delete); + return Map.of("count", count); + } + case "insertOne": + { + Map documentData = + (Map) queryMap.remove("document"); + JsonDocument document = new JsonDocument(); + for (Map.Entry entry : documentData.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + switch (key) { + case "id": + document.id(value.toString()); + break; + case "vector": + document.vector(JstlFunctions.toArrayOfFloat(value)); + break; + case "data": + document.data(value); + break; + default: + document.put(key, value); + break; + } + } + if (document.getId() == null) { + document.setId(UUID.randomUUID().toString()); + } + + log.info( + "doing insertOne with JsonDocument {}", + JstlFunctions.toJson(document)); + String id = collection.insertOne(document); + return Map.of("id", id); + } + default: + throw new UnsupportedOperationException("Unsupported action: " + action); + } + + } catch (Exception err) { + throw new RuntimeException(err); + } } public AstraDB getAstraDB() { diff --git a/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java b/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java index e6b212393..4b9e2b8cf 100644 --- a/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java +++ b/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java @@ -42,8 +42,10 @@ @Slf4j public class AstraVectorDBTest { - private static final String TOKEN = ""; - private static final String ENDPOINT = ""; + private static final String TOKEN = + "AstraCS:HQKZyFwTNcNQFPhsLHPHlyYq:0fd08e29b7e7c590e947ac8fa9a4d6d785a4661a8eb1b3c011e2a0d19c2ecd7c"; + private static final String ENDPOINT = + "https://18bdf302-901f-4245-af09-061ebdb480d2-us-east1.apps.astra.datastax.com"; @Test void testWriteAndRead() throws Exception { @@ -305,6 +307,84 @@ void testWriteAndRead() throws Exception { assertEquals(0, result.size()); }); + Map executeInsertRes = + executeStatement( + datasource, + """ + { + "action": "insertOne", + "collection-name": "%s", + "document": { + "id": "some-id", + "name": ?, + "vector": ?, + "text": "Some text" + } + } + """ + .formatted(collectionName), + List.of("some", vector)); + assertEquals("some-id", executeInsertRes.get("id")); + + assertContents( + datasource, + queryWithFilterOnName, + List.of("some"), + result -> { + assertEquals(1, result.size()); + assertEquals("Some text", result.get(0).get("text")); + }); + + executeStatement( + datasource, + """ + { + "action": "findOneAndUpdate", + "collection-name": "%s", + "filter": { + "_id": ? + }, + "update": { + "$set": { + "text": ? + } + } + } + """ + .formatted(collectionName), + List.of("some-id", "new value")); + + assertContents( + datasource, + queryWithFilterOnName, + List.of("some"), + result -> { + assertEquals(1, result.size()); + assertEquals("new value", result.get(0).get("text")); + }); + + executeStatement( + datasource, + """ + { + "action": "deleteOne", + "collection-name": "%s", + "filter": { + "_id": ? + } + } + """ + .formatted(collectionName), + List.of("some-id")); + + assertContents( + datasource, + queryWithFilterOnName, + List.of("some"), + result -> { + assertEquals(0, result.size()); + }); + // CLEANUP assertTrue(tableManager.assetExists()); tableManager.deleteAssetIfExists(); @@ -325,4 +405,13 @@ private static List> assertContents( ; return results; } + + private static Map executeStatement( + QueryStepDataSource datasource, String query, List params) { + log.info("Query: {}", query); + log.info("Params: {}", params); + Map results = datasource.executeStatement(query, null, params); + log.info("Result: {}", results); + return results; + } } diff --git a/langstream-core/src/main/java/ai/langstream/impl/assets/AstraVectorDBAssetsProvider.java b/langstream-core/src/main/java/ai/langstream/impl/assets/AstraVectorDBAssetsProvider.java new file mode 100644 index 000000000..ee9499a9a --- /dev/null +++ b/langstream-core/src/main/java/ai/langstream/impl/assets/AstraVectorDBAssetsProvider.java @@ -0,0 +1,89 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.impl.assets; + +import ai.langstream.api.doc.AssetConfig; +import ai.langstream.api.doc.ConfigProperty; +import ai.langstream.api.model.AssetDefinition; +import ai.langstream.api.util.ConfigurationUtils; +import ai.langstream.impl.common.AbstractAssetProvider; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; +import java.util.Set; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class AstraVectorDBAssetsProvider extends AbstractAssetProvider { + + protected static final String ASTRA_COLLECTION = "astra-collection"; + + public AstraVectorDBAssetsProvider() { + super(Set.of(ASTRA_COLLECTION)); + } + + @Override + protected Class getAssetConfigModelClass(String type) { + return switch (type) { + case ASTRA_COLLECTION -> AstraCollectionConfig.class; + default -> throw new IllegalArgumentException("Unknown asset type " + type); + }; + } + + @Override + protected void validateAsset(AssetDefinition assetDefinition, Map asset) { + Map configuration = ConfigurationUtils.getMap("config", null, asset); + final Map datasource = + ConfigurationUtils.getMap("datasource", Map.of(), configuration); + final Map datasourceConfiguration = + ConfigurationUtils.getMap("configuration", Map.of(), datasource); + switch (assetDefinition.getAssetType()) { + default -> {} + } + } + + @Override + protected boolean lookupResource(String fieldName) { + return "datasource".equals(fieldName); + } + + @AssetConfig( + name = "Astra Collection", + description = + """ + Manage a DataStax Astra Collection. + """) + @Data + public static class AstraCollectionConfig { + + @ConfigProperty( + description = + """ + Reference to a datasource id configured in the application. + """, + required = true) + private String datasource; + + @ConfigProperty( + description = + """ + Name of the collection to create. + """, + required = true) + @JsonProperty("collection-name") + private String collectionName; + } +} diff --git a/langstream-core/src/main/java/ai/langstream/impl/assets/CassandraAssetsProvider.java b/langstream-core/src/main/java/ai/langstream/impl/assets/CassandraAssetsProvider.java index de34e8195..0681e2e1c 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/assets/CassandraAssetsProvider.java +++ b/langstream-core/src/main/java/ai/langstream/impl/assets/CassandraAssetsProvider.java @@ -71,15 +71,22 @@ protected void validateAsset(AssetDefinition assetDefinition, Map { if (!datasourceConfiguration.containsKey("secureBundle") - && !datasourceConfiguration.containsKey("database")) { + && !datasourceConfiguration.containsKey("database") + && !datasourceConfiguration.containsKey("database-id")) { throw new IllegalArgumentException( "Use cassandra-keyspace for a standard Cassandra service (not AstraDB)"); } // are we are using the AstraDB SDK we need also the AstraCS token and // the name of the database requiredNonEmptyField(datasourceConfiguration, "token", describe(assetDefinition)); - requiredNonEmptyField( - datasourceConfiguration, "database", describe(assetDefinition)); + if (!datasourceConfiguration.containsKey("database")) { + requiredNonEmptyField( + datasourceConfiguration, "database-id", describe(assetDefinition)); + } + if (!datasourceConfiguration.containsKey("database-id")) { + requiredNonEmptyField( + datasourceConfiguration, "database", describe(assetDefinition)); + } } default -> {} } diff --git a/langstream-core/src/main/java/ai/langstream/impl/resources/DataSourceResourceProvider.java b/langstream-core/src/main/java/ai/langstream/impl/resources/DataSourceResourceProvider.java index 20d9aa633..7091cdfbf 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/resources/DataSourceResourceProvider.java +++ b/langstream-core/src/main/java/ai/langstream/impl/resources/DataSourceResourceProvider.java @@ -16,6 +16,7 @@ package ai.langstream.impl.resources; import ai.langstream.impl.resources.datasource.AstraDatasourceConfig; +import ai.langstream.impl.resources.datasource.AstraVectorDBDatasourceConfig; import ai.langstream.impl.resources.datasource.CassandraDatasourceConfig; import ai.langstream.impl.resources.datasource.JDBCDatasourceConfig; import ai.langstream.impl.resources.datasource.OpenSearchDatasourceConfig; @@ -24,6 +25,8 @@ public class DataSourceResourceProvider extends BaseDataSourceResourceProvider { protected static final String SERVICE_ASTRA = "astra"; + + protected static final String SERVICE_ASTRA_VECTOR_DB = "astra-vector-db"; protected static final String SERVICE_CASSANDRA = "cassandra"; protected static final String SERVICE_JDBC = "jdbc"; protected static final String SERVICE_OPENSEARCH = "opensearch"; @@ -35,6 +38,7 @@ public DataSourceResourceProvider() { SERVICE_ASTRA, AstraDatasourceConfig.CONFIG, SERVICE_CASSANDRA, CassandraDatasourceConfig.CONFIG, SERVICE_JDBC, JDBCDatasourceConfig.CONFIG, - SERVICE_OPENSEARCH, OpenSearchDatasourceConfig.CONFIG)); + SERVICE_OPENSEARCH, OpenSearchDatasourceConfig.CONFIG, + SERVICE_ASTRA_VECTOR_DB, AstraVectorDBDatasourceConfig.CONFIG)); } } diff --git a/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java b/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java index 93b2bf318..f96e258de 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java +++ b/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java @@ -48,7 +48,8 @@ public void validate(Resource resource) { ConfigurationUtils.getString("secureBundle", "", configuration); if (secureBundle.isEmpty()) { if (configuration.get("token") == null - || configuration.get("database") == null) { + || (configuration.get("database") == null + && configuration.get("database-id") == null)) { throw new IllegalArgumentException( ClassConfigValidator.formatErrString( new ClassConfigValidator.ResourceEntityRef(resource), diff --git a/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraVectorDBDatasourceConfig.java b/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraVectorDBDatasourceConfig.java new file mode 100644 index 000000000..7c1089fca --- /dev/null +++ b/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraVectorDBDatasourceConfig.java @@ -0,0 +1,62 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.impl.resources.datasource; + +import ai.langstream.api.doc.ConfigProperty; +import ai.langstream.api.doc.ResourceConfig; +import ai.langstream.api.model.Resource; +import ai.langstream.impl.resources.BaseDataSourceResourceProvider; +import ai.langstream.impl.uti.ClassConfigValidator; +import lombok.Data; + +@Data +@ResourceConfig( + name = "Astra Vector DB", + description = "Connect to DataStax Astra Vector DB service.") +public class AstraVectorDBDatasourceConfig extends BaseDatasourceConfig { + + public static final BaseDataSourceResourceProvider.DatasourceConfig CONFIG = + new BaseDataSourceResourceProvider.DatasourceConfig() { + + @Override + public Class getResourceConfigModelClass() { + return AstraVectorDBDatasourceConfig.class; + } + + @Override + public void validate(Resource resource) { + ClassConfigValidator.validateResourceModelFromClass( + resource, + AstraVectorDBDatasourceConfig.class, + resource.configuration(), + false); + } + }; + + @ConfigProperty( + description = + """ + API Endpoint. + """) + private String endpoint; + + @ConfigProperty( + description = + """ + Astra Token (AstraCS:xxx) for connecting to the database. + """) + private String token; +} diff --git a/langstream-core/src/main/resources/META-INF/services/ai.langstream.api.runtime.AssetNodeProvider b/langstream-core/src/main/resources/META-INF/services/ai.langstream.api.runtime.AssetNodeProvider index f47198f79..3a5d6610a 100644 --- a/langstream-core/src/main/resources/META-INF/services/ai.langstream.api.runtime.AssetNodeProvider +++ b/langstream-core/src/main/resources/META-INF/services/ai.langstream.api.runtime.AssetNodeProvider @@ -2,4 +2,5 @@ ai.langstream.impl.assets.CassandraAssetsProvider ai.langstream.impl.assets.MilvusAssetsProvider ai.langstream.impl.assets.JdbcAssetsProvider ai.langstream.impl.assets.SolrAssetsProvider -ai.langstream.impl.assets.OpenSearchAssetsProvider \ No newline at end of file +ai.langstream.impl.assets.OpenSearchAssetsProvider +ai.langstream.impl.assets.AstraVectorDBAssetsProvider \ No newline at end of file