Skip to content

Commit

Permalink
Implement execute statement and connect to the planner
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli committed Nov 22, 2023
1 parent d186bd0 commit 9d18f7b
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,37 @@
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import com.dtsx.astra.sdk.AstraDB;
import io.stargate.sdk.json.CollectionClient;
import io.stargate.sdk.json.domain.DeleteQuery;
import io.stargate.sdk.json.domain.DeleteQueryBuilder;
import io.stargate.sdk.json.domain.JsonDocument;
import io.stargate.sdk.json.domain.JsonResult;
import io.stargate.sdk.json.domain.JsonResultUpdate;
import io.stargate.sdk.json.domain.SelectQuery;
import io.stargate.sdk.json.domain.SelectQueryBuilder;
import io.stargate.sdk.json.domain.UpdateQuery;
import io.stargate.sdk.json.domain.UpdateQueryBuilder;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class AstraVectorDBDataSource implements QueryStepDataSource {

static final Field update;

static {
try {
update = UpdateQueryBuilder.class.getDeclaredField("update");
update.setAccessible(true);
} catch (Throwable t) {
throw new RuntimeException(t);
}
}

AstraDB astraDB;

@Override
Expand Down Expand Up @@ -73,7 +92,7 @@ public List<Map<String, Object>> fetchData(String query, List<Object> params) {
List<JsonResult> result;

float[] vector = JstlFunctions.toArrayOfFloat(queryMap.remove("vector"));
Integer max = (Integer) queryMap.remove("max");
Integer limit = (Integer) queryMap.remove("limit");

boolean includeSimilarity = vector != null;
Object includeSimilarityParam = queryMap.remove("include-similarity");
Expand Down Expand Up @@ -101,8 +120,8 @@ public List<Map<String, Object>> fetchData(String query, List<Object> params) {
if (filterMap != null) {
selectQueryBuilder.withJsonFilter(JstlFunctions.toJson(filterMap));
}
if (max != null) {
selectQueryBuilder.limit(max);
if (limit != null) {
selectQueryBuilder.limit(limit);
}

SelectQuery selectQuery = selectQueryBuilder.build();
Expand Down Expand Up @@ -143,7 +162,102 @@ public Map<String, Object> executeStatement(
.map(v -> v == null ? "null" : v.getClass().toString())
.collect(Collectors.joining(",")));
}
throw new UnsupportedOperationException();
try {
Map<String, Object> queryMap =
InterpolationUtils.buildObjectFromJson(query, Map.class, params);
if (queryMap.isEmpty()) {
throw new UnsupportedOperationException("Query is empty");
}
String collectionName = (String) queryMap.remove("collection-name");
if (collectionName == null) {
throw new UnsupportedOperationException("collection-name is not defined");
}
CollectionClient collection = this.getAstraDB().collection(collectionName);

String action = (String) queryMap.remove("action");

switch (action) {
case "findOneAndUpdate":
{
Map<String, Object> filterMap =
(Map<String, Object>) queryMap.remove("filter");
UpdateQueryBuilder builder = UpdateQuery.builder();
if (filterMap != null) {
builder.withJsonFilter(JstlFunctions.toJson(filterMap));
}
String returnDocument = (String) queryMap.remove("return-document");
if (returnDocument != null) {
builder.withReturnDocument(
UpdateQueryBuilder.ReturnDocument.valueOf(returnDocument));
}
Map<String, Object> updateMap =
(Map<String, Object>) queryMap.remove("update");
if (updateMap != null) {
update.set(builder, updateMap);
}

UpdateQuery updateQuery = builder.build();
log.info(
"doing findOneAndUpdate with UpdateQuery {}",
JstlFunctions.toJson(updateQuery));
JsonResultUpdate oneAndUpdate = collection.findOneAndUpdate(updateQuery);
return Map.of("count", oneAndUpdate.getUpdateStatus().getModifiedCount());
}
case "deleteOne":
{
Map<String, Object> filterMap =
(Map<String, Object>) queryMap.remove("filter");
DeleteQueryBuilder builder = DeleteQuery.builder();
if (filterMap != null) {
builder.withJsonFilter(JstlFunctions.toJson(filterMap));
}
DeleteQuery delete = builder.build();
log.info(
"doing deleteOne with DeleteQuery {}",
JstlFunctions.toJson(delete));
int count = collection.deleteOne(delete);
return Map.of("count", count);
}
case "insertOne":
{
Map<String, Object> documentData =
(Map<String, Object>) queryMap.remove("document");
JsonDocument document = new JsonDocument();
for (Map.Entry<String, Object> entry : documentData.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
switch (key) {
case "id":
document.id(value.toString());
break;
case "vector":
document.vector(JstlFunctions.toArrayOfFloat(value));
break;
case "data":
document.data(value);
break;
default:
document.put(key, value);
break;
}
}
if (document.getId() == null) {
document.setId(UUID.randomUUID().toString());
}

log.info(
"doing insertOne with JsonDocument {}",
JstlFunctions.toJson(document));
String id = collection.insertOne(document);
return Map.of("id", id);
}
default:
throw new UnsupportedOperationException("Unsupported action: " + action);
}

} catch (Exception err) {
throw new RuntimeException(err);
}
}

public AstraDB getAstraDB() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
@Slf4j
public class AstraVectorDBTest {

private static final String TOKEN = "";
private static final String ENDPOINT = "";
private static final String TOKEN =
"AstraCS:HQKZyFwTNcNQFPhsLHPHlyYq:0fd08e29b7e7c590e947ac8fa9a4d6d785a4661a8eb1b3c011e2a0d19c2ecd7c";
private static final String ENDPOINT =
"https://18bdf302-901f-4245-af09-061ebdb480d2-us-east1.apps.astra.datastax.com";

@Test
void testWriteAndRead() throws Exception {
Expand Down Expand Up @@ -305,6 +307,84 @@ void testWriteAndRead() throws Exception {
assertEquals(0, result.size());
});

Map<String, Object> executeInsertRes =
executeStatement(
datasource,
"""
{
"action": "insertOne",
"collection-name": "%s",
"document": {
"id": "some-id",
"name": ?,
"vector": ?,
"text": "Some text"
}
}
"""
.formatted(collectionName),
List.of("some", vector));
assertEquals("some-id", executeInsertRes.get("id"));

assertContents(
datasource,
queryWithFilterOnName,
List.of("some"),
result -> {
assertEquals(1, result.size());
assertEquals("Some text", result.get(0).get("text"));
});

executeStatement(
datasource,
"""
{
"action": "findOneAndUpdate",
"collection-name": "%s",
"filter": {
"_id": ?
},
"update": {
"$set": {
"text": ?
}
}
}
"""
.formatted(collectionName),
List.of("some-id", "new value"));

assertContents(
datasource,
queryWithFilterOnName,
List.of("some"),
result -> {
assertEquals(1, result.size());
assertEquals("new value", result.get(0).get("text"));
});

executeStatement(
datasource,
"""
{
"action": "deleteOne",
"collection-name": "%s",
"filter": {
"_id": ?
}
}
"""
.formatted(collectionName),
List.of("some-id"));

assertContents(
datasource,
queryWithFilterOnName,
List.of("some"),
result -> {
assertEquals(0, result.size());
});

// CLEANUP
assertTrue(tableManager.assetExists());
tableManager.deleteAssetIfExists();
Expand All @@ -325,4 +405,13 @@ private static List<Map<String, Object>> assertContents(
;
return results;
}

private static Map<String, Object> executeStatement(
QueryStepDataSource datasource, String query, List<Object> params) {
log.info("Query: {}", query);
log.info("Params: {}", params);
Map<String, Object> results = datasource.executeStatement(query, null, params);
log.info("Result: {}", results);
return results;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.impl.assets;

import ai.langstream.api.doc.AssetConfig;
import ai.langstream.api.doc.ConfigProperty;
import ai.langstream.api.model.AssetDefinition;
import ai.langstream.api.util.ConfigurationUtils;
import ai.langstream.impl.common.AbstractAssetProvider;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Map;
import java.util.Set;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class AstraVectorDBAssetsProvider extends AbstractAssetProvider {

protected static final String ASTRA_COLLECTION = "astra-collection";

public AstraVectorDBAssetsProvider() {
super(Set.of(ASTRA_COLLECTION));
}

@Override
protected Class getAssetConfigModelClass(String type) {
return switch (type) {
case ASTRA_COLLECTION -> AstraCollectionConfig.class;
default -> throw new IllegalArgumentException("Unknown asset type " + type);
};
}

@Override
protected void validateAsset(AssetDefinition assetDefinition, Map<String, Object> asset) {
Map<String, Object> configuration = ConfigurationUtils.getMap("config", null, asset);
final Map<String, Object> datasource =
ConfigurationUtils.getMap("datasource", Map.of(), configuration);
final Map<String, Object> datasourceConfiguration =
ConfigurationUtils.getMap("configuration", Map.of(), datasource);
switch (assetDefinition.getAssetType()) {
default -> {}
}
}

@Override
protected boolean lookupResource(String fieldName) {
return "datasource".equals(fieldName);
}

@AssetConfig(
name = "Astra Collection",
description =
"""
Manage a DataStax Astra Collection.
""")
@Data
public static class AstraCollectionConfig {

@ConfigProperty(
description =
"""
Reference to a datasource id configured in the application.
""",
required = true)
private String datasource;

@ConfigProperty(
description =
"""
Name of the collection to create.
""",
required = true)
@JsonProperty("collection-name")
private String collectionName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,22 @@ protected void validateAsset(AssetDefinition assetDefinition, Map<String, Object
}
case ASTRA_KEYSPACE -> {
if (!datasourceConfiguration.containsKey("secureBundle")
&& !datasourceConfiguration.containsKey("database")) {
&& !datasourceConfiguration.containsKey("database")
&& !datasourceConfiguration.containsKey("database-id")) {
throw new IllegalArgumentException(
"Use cassandra-keyspace for a standard Cassandra service (not AstraDB)");
}
// are we are using the AstraDB SDK we need also the AstraCS token and
// the name of the database
requiredNonEmptyField(datasourceConfiguration, "token", describe(assetDefinition));
requiredNonEmptyField(
datasourceConfiguration, "database", describe(assetDefinition));
if (!datasourceConfiguration.containsKey("database")) {
requiredNonEmptyField(
datasourceConfiguration, "database-id", describe(assetDefinition));
}
if (!datasourceConfiguration.containsKey("database-id")) {
requiredNonEmptyField(
datasourceConfiguration, "database", describe(assetDefinition));
}
}
default -> {}
}
Expand Down
Loading

0 comments on commit 9d18f7b

Please sign in to comment.