Skip to content

Commit

Permalink
Implement query
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli committed Nov 22, 2023
1 parent f22e181 commit d186bd0
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import com.dtsx.astra.sdk.AstraDB;
import io.stargate.sdk.json.CollectionClient;
import io.stargate.sdk.json.domain.Filter;
import io.stargate.sdk.json.domain.JsonResult;
import io.stargate.sdk.json.domain.SelectQuery;
import io.stargate.sdk.json.domain.SelectQueryBuilder;
Expand Down Expand Up @@ -60,60 +59,76 @@ public List<Map<String, Object>> fetchData(String query, List<Object> params) {
.map(v -> v == null ? "null" : v.getClass().toString())
.collect(Collectors.joining(",")));
}
Map<String, Object> queryMap =
InterpolationUtils.buildObjectFromJson(query, Map.class, params);
if (queryMap.isEmpty()) {
throw new UnsupportedOperationException("Query is empty");
}
String collectionName = (String) queryMap.get("collection-name");
if (collectionName == null) {
throw new UnsupportedOperationException("collection-name is not defined");
}
CollectionClient collection = this.getAstraDB().collection(collectionName);
List<JsonResult> result;
try {
Map<String, Object> queryMap =
InterpolationUtils.buildObjectFromJson(query, Map.class, params);
if (queryMap.isEmpty()) {
throw new UnsupportedOperationException("Query is empty");
}
String collectionName = (String) queryMap.remove("collection-name");
if (collectionName == null) {
throw new UnsupportedOperationException("collection-name is not defined");
}
CollectionClient collection = this.getAstraDB().collection(collectionName);
List<JsonResult> result;

float[] vector = JstlFunctions.toArrayOfFloat(queryMap.remove("vector"));
Integer max = (Integer) queryMap.remove("max");
float[] vector = JstlFunctions.toArrayOfFloat(queryMap.remove("vector"));
Integer max = (Integer) queryMap.remove("max");

if (max == null) {
max = Integer.MAX_VALUE;
}
if (vector != null) {
Filter filter = new Filter();
queryMap.forEach((k, v) -> filter.where(k).isEqualsTo(v));
log.info(
"doing similarity search with filter {} max {} and vector {}",
filter,
max,
vector);
result = collection.similaritySearch(vector, filter, max);
} else {
SelectQueryBuilder selectQueryBuilder =
SelectQuery.builder().includeSimilarity().select("*");
queryMap.forEach((k, v) -> selectQueryBuilder.where(k).isEqualsTo(v));
boolean includeSimilarity = vector != null;
Object includeSimilarityParam = queryMap.remove("include-similarity");
if (includeSimilarityParam != null) {
includeSimilarity = Boolean.parseBoolean(includeSimilarityParam.toString());
}
Map<String, Object> filterMap = (Map<String, Object>) queryMap.remove("filter");
SelectQueryBuilder selectQueryBuilder = SelectQuery.builder();
Object selectClause = queryMap.remove("select");
if (selectClause != null) {
if (selectClause instanceof List list) {
String[] arrayOfStrings = ((List<String>) list).toArray(new String[0]);
selectQueryBuilder.select(arrayOfStrings);
} else {
throw new IllegalArgumentException(
"select clause must be a list of strings, but found: " + selectClause);
}
}
if (includeSimilarity) {
selectQueryBuilder.includeSimilarity();
}
if (vector != null) {
selectQueryBuilder.orderByAnn(vector);
}
if (filterMap != null) {
selectQueryBuilder.withJsonFilter(JstlFunctions.toJson(filterMap));
}
if (max != null) {
selectQueryBuilder.limit(max);
}

SelectQuery selectQuery = selectQueryBuilder.build();
log.info("doing query {}", selectQuery);
log.info("doing query {}", JstlFunctions.toJson(selectQuery));

result = collection.query(selectQuery).toList();
}

return result.stream()
.map(
m -> {
Map<String, Object> r = new HashMap<>();
if (m.getData() != null) {
r.putAll(m.getData());
}
if (m.getSimilarity() != null) {
r.put("similarity", m.getSimilarity());
}
if (m.getVector() != null) {
r.put("vector", JstlFunctions.toListOfFloat(m.getVector()));
}
return r;
})
.collect(Collectors.toList());
return result.stream()
.map(
m -> {
Map<String, Object> r = new HashMap<>();
if (m.getData() != null) {
r.putAll(m.getData());
}
if (m.getSimilarity() != null) {
r.put("similarity", m.getSimilarity());
}
if (m.getVector() != null) {
r.put("vector", JstlFunctions.toListOfFloat(m.getVector()));
}
return r;
})
.collect(Collectors.toList());
} catch (Exception err) {
throw new RuntimeException(err);
}
}

@Override
Expand Down
Loading

0 comments on commit d186bd0

Please sign in to comment.