Skip to content

Commit

Permalink
[improvement][headless]Expression replacement logic supports more com…
Browse files Browse the repository at this point in the history
…plex sql.
  • Loading branch information
jerryjzhang committed Jan 5, 2025
1 parent 6fcfdc1 commit 4e653c1
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ public static String replaceExpression(String expr, Map<String, String> replace)
return expr;
}

public static String replaceSqlByExpression(String sql, Map<String, String> replace) {
public static String replaceSqlByExpression(String tableName, String sql,
Map<String, String> replace) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
List<PlainSelect> plainSelectList = new ArrayList<>();
if (selectStatement instanceof PlainSelect) {
Expand All @@ -636,9 +637,8 @@ public static String replaceSqlByExpression(String sql, Map<String, String> repl
selectStatement.getWithItemsList().forEach(withItem -> {
plainSelectList.add(withItem.getSelect().getPlainSelect());
});
} else {
plainSelectList.add((PlainSelect) selectStatement);
}
plainSelectList.add((PlainSelect) selectStatement);
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
Expand Down Expand Up @@ -672,9 +672,12 @@ public static String replaceSqlByExpression(String sql, Map<String, String> repl

List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList);
for (PlainSelect plainSelect : plainSelects) {
replacePlainSelectByExpr(plainSelect, replace);
if (SqlSelectHelper.hasAggregateFunction(plainSelect)) {
SqlSelectHelper.addMissingGroupby(plainSelect);
Table table = (Table) plainSelect.getFromItem();
if (table.getName().equals(tableName)) {
replacePlainSelectByExpr(plainSelect, replace);
if (SqlSelectHelper.hasAggregateFunction(plainSelect)) {
SqlSelectHelper.addMissingGroupby(plainSelect);
}
}
}
return selectStatement.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.OntologyQuery;
Expand Down Expand Up @@ -40,7 +41,9 @@ public void parse(QueryStatement queryStatement) throws Exception {

Map<String, String> bizName2Expr = getDimensionExpressions(semanticSchema, ontologyQuery);
if (!CollectionUtils.isEmpty(bizName2Expr)) {
String sql = SqlReplaceHelper.replaceSqlByExpression(sqlQuery.getSql(), bizName2Expr);
String sql = SqlReplaceHelper.replaceSqlByExpression(
Constants.TABLE_PREFIX + queryStatement.getDataSetId(), sqlQuery.getSql(),
bizName2Expr);
sqlQuery.setSql(sql);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
Expand Down Expand Up @@ -39,7 +40,9 @@ public void parse(QueryStatement queryStatement) throws Exception {

Map<String, String> bizName2Expr = getMetricExpressions(semanticSchema, ontologyQuery);
if (!CollectionUtils.isEmpty(bizName2Expr)) {
String sql = SqlReplaceHelper.replaceSqlByExpression(sqlQuery.getSql(), bizName2Expr);
String sql = SqlReplaceHelper.replaceSqlByExpression(
Constants.TABLE_PREFIX + queryStatement.getDataSetId(), sqlQuery.getSql(),
bizName2Expr);
sqlQuery.setSql(sql);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ public SemanticSchemaResp buildSemanticSchema(SchemaFilterReq schemaFilterReq) {
DataSetSchemaResp dataSetSchemaResp =
fetchDataSetSchema(schemaFilterReq.getDataSetId());
BeanUtils.copyProperties(dataSetSchemaResp, semanticSchemaResp);
semanticSchemaResp.setDataSetResp(dataSetSchemaResp);
List<Long> modelIds = dataSetSchemaResp.getAllModels();
MetaFilter metaFilter = new MetaFilter();
metaFilter.setIds(modelIds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,18 @@ public void testSql_with() throws Exception {
executeSql(explain.getQuerySQL());
}

@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testSql_subquery() throws Exception {
String sql = new String(
Files.readAllBytes(
Paths.get(ClassLoader.getSystemResource("sql/testSubquery.sql").toURI())),
StandardCharsets.UTF_8);
SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
executeSql(explain.getQuerySQL());
}

}
27 changes: 27 additions & 0 deletions launchers/standalone/src/test/resources/sql/testSubquery.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
WITH
_average_stay_duration_ AS (
SELECT
AVG(停留时长) AS _avg_duration_
FROM
超音数数据集
)
SELECT
用户名,
SUM(停留时长) AS _total_stay_duration_
FROM
超音数数据集
GROUP BY
用户名
HAVING
SUM(停留时长) > (
SELECT
_avg_duration_ * 1.5
FROM
_average_stay_duration_
)
OR SUM(停留时长) < (
SELECT
_avg_duration_ * 0.5
FROM
_average_stay_duration_
)

0 comments on commit 4e653c1

Please sign in to comment.