Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLParser: Introduce DISTINCT clause #564

Merged
merged 7 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/tutorial/sqlTutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Other features we do and don't support right now can be found below.
* Having Clauses
* Order By Clauses
* As
* Distinct

### Not Yet Supported Features

Expand All @@ -71,7 +72,6 @@ Other features we do and don't support right now can be found below.
* All Set Operations (Union, Except, Intersect)
* Recursive SQL Queries
* Limit
* Distinct

## Examples

Expand Down
2 changes: 1 addition & 1 deletion src/api/python/operator/nodes/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], named_inpu
"schema": [
{
"label": self._pd_dataframe.columns[i],
"valueType": self.getDType(self._pd_dataframe.dtypes[i])
"valueType": self.getDType(self._pd_dataframe.dtypes.iloc[i])
}
for i in range(self._pd_dataframe.shape[1])
]
Expand Down
12 changes: 12 additions & 0 deletions src/compiler/lowering/RewriteToCallKernelOpPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ namespace
// of the variadic pack ops. Should be changed when reworking the lowering to kernels.
if(llvm::dyn_cast<daphne::GroupOp>(op) && idx >= operandTypes.size()) {
callee << "__char_variadic__size_t";
auto cvpOp = rewriter.create<daphne::CreateVariadicPackOp>(
loc,
daphne::VariadicPackType::get(
rewriter.getContext(),
daphne::StringType::get(rewriter.getContext())
),
rewriter.getI64IntegerAttr(0)
);
newOperands.push_back(cvpOp);
newOperands.push_back(rewriter.create<daphne::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0))
);
continue;
} else {
callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandTypes[idx], generalizeInputTypes);
Expand Down
10 changes: 9 additions & 1 deletion src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,15 @@ void daphne::GroupOp::inferFrameLabels() {
std::vector<std::string> aggFuncNames;

for(Value t: getKeyCol()){ //Adopting keyCol Labels
newLabels->push_back(CompilerUtils::constantOrThrow<std::string>(t));
std::string keyLabel = CompilerUtils::constantOrThrow<std::string>(t);
if(keyLabel == "*") {
daphne::FrameType arg = getFrame().getType().dyn_cast<daphne::FrameType>();
for (std::string frameLabel : *arg.getLabels()) {
newLabels->push_back(frameLabel);
}
} else {
newLabels->push_back(keyLabel);
}
}

for(Value t: getAggCol()){
Expand Down
12 changes: 11 additions & 1 deletion src/ir/daphneir/DaphneInferTypesOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,17 @@ std::vector<Type> daphne::GroupOp::inferTypes() {

for(Value t : getKeyCol()){
//Key Types getting adopted for the new Frame
newColumnTypes.push_back(getFrameColumnTypeByLabel(arg, t));
std::string labelStr = CompilerUtils::constantOrThrow<std::string>(
t, "the specified label must be a constant of string type"
);
if(labelStr == "*") {
auto allTypes = arg.getColumnTypes();
for (Type type: allTypes) {
newColumnTypes.push_back(type);
}
} else {
newColumnTypes.push_back(getFrameColumnTypeByLabel(arg, t));
}
}

// Values get collected in a easier to use Datastructure
Expand Down
5 changes: 4 additions & 1 deletion src/parser/sql/SQLGrammar.g4
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ query:
select ';'?;

select:
SQL_SELECT selectExpr (',' selectExpr)*
SQL_SELECT distinctExpr? selectExpr (',' selectExpr)*
SQL_FROM tableExpr
whereClause?
groupByClause?
Expand All @@ -52,6 +52,9 @@ selectExpr:
tableExpr:
fromExpr joinExpr*;

distinctExpr:
SQL_DISTINCT;

fromExpr:
var=tableReference #tableIdentifierExpr
| lhs=tableReference ',' rhs=fromExpr #cartesianExpr
Expand Down
50 changes: 44 additions & 6 deletions src/parser/sql/SQLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <ir/daphneir/Daphne.h>
#include <parser/sql/SQLVisitor.h>
#include "antlr4-runtime.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpDefinition.h"

#include <stdexcept>
Expand Down Expand Up @@ -54,6 +55,14 @@ void toggleBit(int64_t& flag, int64_t position){
setBit(flag, position, !isBitSet(flag, position));
}

/**
* @brief Creates a lower cast version of a string
*/
std::string toLower(std::string str){
std::transform(str.begin(), str.end(), str.begin(), ::tolower);
return str;
}

// ****************************************************************************
// Member Helper functions
// ****************************************************************************
Expand Down Expand Up @@ -504,6 +513,10 @@ antlrcpp::Any SQLVisitor::visitSelect(
throw std::runtime_error(err_msg.str());
}
}
currentFrame = res;
if(ctx->distinctExpr()) {
res = utils.valueOrError(visit(ctx->distinctExpr()));
}
return res;
}

Expand Down Expand Up @@ -566,6 +579,28 @@ antlrcpp::Any SQLVisitor::visitTableExpr(
return currentFrame;
}

//distinctExpr
antlrcpp::Any SQLVisitor::visitDistinctExpr(
SQLGrammarParser::DistinctExprContext *ctx
)
{
if (isBitSet(sqlFlag, (int64_t)SQLBit::group) // If group is active
!= columnName.size()) // EXCLUSIVE OR there is an aggregation
return currentFrame; // due to earlier grouping XOR aggregation the result is already distinct

mlir::Location loc = utils.getLoc(ctx->start);
mlir::Value starLiteral = createStringConstant("*");
std::vector<mlir::Value> cols{starLiteral};
std::vector<mlir::Value> aggs;
std::vector<mlir::Attribute> functions;
mlir::Type vt = utils.unknownType;
std::vector<mlir::Type> colTypes{vt};
mlir::Type resType =
mlir::daphne::FrameType::get(builder.getContext(), colTypes);
return static_cast<mlir::Value>(builder.create<mlir::daphne::GroupOp>(
loc, resType, currentFrame, cols, aggs, builder.getArrayAttr(functions)));
}

//fromExpr
antlrcpp::Any SQLVisitor::visitTableIdentifierExpr(
SQLGrammarParser::TableIdentifierExprContext *ctx
Expand Down Expand Up @@ -867,8 +902,9 @@ antlrcpp::Any SQLVisitor::visitIdentifierExpr(
SQLGrammarParser::IdentifierExprContext * ctx)
{
if( isBitSet(sqlFlag, (int64_t)SQLBit::group) //If group is active
&& !isBitSet(sqlFlag, (int64_t)SQLBit::agg) //AND there isn't an aggreagtion
&& grouped[ctx->selectIdent()->getText()] == 0) //AND the label is not in group expr
&& !isBitSet(sqlFlag, (int64_t)SQLBit::agg) //AND there isn't an aggregation
&& grouped.count(ctx->selectIdent()->getText()) == 0 //AND the label is not in group expr
&& grouped.count("*") == 0) //AND there is no * in group expr
{
std::stringstream err_msg;
err_msg << "Error during a generalExpr. \""
Expand Down Expand Up @@ -911,7 +947,7 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr(

mlir::Type resTypeCol = col.getType().dyn_cast<mlir::daphne::MatrixType>().getElementType();

const std::string &func = ctx->func->getText();
std::string func = toLower(ctx->func->getText());

mlir::Value result;
if(func == "count"){
Expand Down Expand Up @@ -959,7 +995,7 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr(
);
}

std::string newColumnNameAppended = getEnumLabelExt(ctx->func->getText()) + "(" + newColumnName + ")";
std::string newColumnNameAppended = getEnumLabelExt(func) + "(" + newColumnName + ")";

return utils.castIf(utils.matrixOf(result), result);

Expand All @@ -975,7 +1011,8 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr(
//create Column pre Group for in group Aggregation
if(!isBitSet(sqlFlag, (int64_t)SQLBit::codegen)){
columnName.push_back(createStringConstant(newColumnName));
functionName.push_back(getGroupEnum(ctx->func->getText()));
const std::string &func = toLower(ctx->func->getText());
functionName.push_back(getGroupEnum(func));

setBit(sqlFlag, (int64_t)SQLBit::agg, 1);
setBit(sqlFlag, (int64_t)SQLBit::codegen, 1);
Expand All @@ -987,7 +1024,8 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr(
currentFrame = addMatrixToCurrentFrame(matrix, newColumnName);
return nullptr;
}else{ //Get Column after Group
std::string newColumnNameAppended = getEnumLabelExt(ctx->func->getText()) + "(" + newColumnName + ")";
const std::string &func = toLower(ctx->func->getText());
std::string newColumnNameAppended = getEnumLabelExt(func) + "(" + newColumnName + ")";
mlir::Value colname = utils.valueOrError(createStringConstant(newColumnNameAppended));
return extractMatrixFromFrame(currentFrame, colname); //returns Matrix
}
Expand Down
3 changes: 3 additions & 0 deletions src/parser/sql/SQLVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ class SQLVisitor : public SQLGrammarVisitor {
//tableExpr
antlrcpp::Any visitTableExpr(SQLGrammarParser::TableExprContext * ctx) override;

//distinctExpr
antlrcpp::Any visitDistinctExpr(SQLGrammarParser::DistinctExprContext * ctx) override;

//fromExpr
antlrcpp::Any visitTableIdentifierExpr(SQLGrammarParser::TableIdentifierExprContext *ctx) override;

Expand Down
54 changes: 43 additions & 11 deletions src/runtime/local/kernels/Group.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,40 @@ template <> struct Group<Frame> {
if (arg == nullptr || (keyCols == nullptr && numKeyCols != 0) || (aggCols == nullptr && numAggCols != 0) || (aggFuncs == nullptr && numAggFuncs != 0)) {
throw std::runtime_error("group-kernel called with invalid arguments");
}


// check if labels contain *
std::vector<std::string> starLabels;
const std::string * argLabels = arg->getLabels();
const size_t numColsArg = arg->getNumCols();
std::vector<std::string> aggColsVec;
for (size_t m = 0; m < numAggCols; m++) {
aggColsVec.push_back(aggCols[m]);
}
for (size_t i = 0; i < numKeyCols; i++) {
if (strcmp(keyCols[i], "*") == 0) {
for (size_t m = 0; m < numColsArg; m++) {
// check that we do not include columns in the result that are used for aggregations and would lead to duplicates
if(std::find(aggColsVec.begin(), aggColsVec.end(), argLabels[m]) == aggColsVec.end()) {
starLabels.push_back(argLabels[m]);
}
}
// we assume that other key columns are included in the *
// operator, otherwise they would not be in the argument frame
// and throw a error later on
numColsRes = starLabels.size() + numAggCols;
}
}


// convert labels to indices
auto idxs = std::shared_ptr<size_t[]>(new size_t[numColsRes]);
bool * ascending = new bool[numKeyCols];
for (size_t i = 0; i < numKeyCols; i++) {
idxs[i] = arg->getColumnIdx(keyCols[i]);
ascending[i] = true;
}
numKeyCols = starLabels.size()? starLabels.size() : numKeyCols;
bool * ascending = new bool[starLabels.size()];
for (size_t i = 0; i < numKeyCols; ++i) {
idxs[i] = starLabels.size() ? arg->getColumnIdx(starLabels[i])
: arg->getColumnIdx(keyCols[i]);
ascending[i] = true;
}
for (size_t i = numKeyCols; i < numColsRes; i++) {
idxs[i] = arg->getColumnIdx(aggCols[i-numKeyCols]);
}
Expand Down Expand Up @@ -166,10 +192,16 @@ template <> struct Group<Frame> {
// create the result frame
std::string * labels = new std::string[numColsRes];
ValueTypeCode * schema = new ValueTypeCode[numColsRes];

for (size_t i = 0; i < numKeyCols; i++) {
labels[i] = keyCols[i];
schema[i] = ordered->getColumnType(idxs[i]);
if (starLabels.size()) {
for (size_t i = 0; i < numKeyCols; i++) {
labels[i] = starLabels[i];
schema[i] = ordered->getColumnType(idxs[i]);
}
} else {
for (size_t i = 0; i < numKeyCols; i++) {
labels[i] = keyCols[i];
schema[i] = ordered->getColumnType(idxs[i]);
}
}
using mlir::daphne::GroupEnum;
for (size_t i = numKeyCols; i < numColsRes; i++) {
Expand Down Expand Up @@ -199,4 +231,4 @@ template <> struct Group<Frame> {
}
};

#endif //SRC_RUNTIME_LOCAL_KERNELS_GROUP_H
#endif //SRC_RUNTIME_LOCAL_KERNELS_GROUP_H
1 change: 1 addition & 0 deletions test/api/cli/sql/SQLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ MAKE_TEST_CASE("agg_sum", 1)

MAKE_TEST_CASE("reuseString", 2)

MAKE_TEST_CASE("distinct", 5)

// TODO Use the scripts testing failure cases.
12 changes: 12 additions & 0 deletions test/api/cli/sql/distinct_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# DISTINCT over one column.

f = createFrame(
[ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2],
[ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1],
"a", "b");

registerView("f", f);

res = sql("SELECT DISTINCT f.a FROM f;");

print(res);
7 changes: 7 additions & 0 deletions test/api/cli/sql/distinct_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Frame(6x1, [f.a:int64_t])
0
1
2
3
6
8
12 changes: 12 additions & 0 deletions test/api/cli/sql/distinct_2.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# DISTINCT over multiple columns.

f = createFrame(
[ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2],
[ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1],
"a", "b");

registerView("f", f);

res = sql("SELECT DISTINCT f.a, f.b FROM f;");

print(res);
9 changes: 9 additions & 0 deletions test/api/cli/sql/distinct_2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Frame(8x2, [f.a:int64_t, f.b:int64_t])
0 1
1 2
2 1
2 3
3 3
3 4
6 5
8 1
12 changes: 12 additions & 0 deletions test/api/cli/sql/distinct_3.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# DISTINCT over multiple columns with GROUP BY.

f = createFrame(
[ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2],
[ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1],
"a", "b");

registerView("f", f);

res = sql("SELECT DISTINCT f.a, f.b FROM f GROUP BY f.a, f.b;");

print(res);
9 changes: 9 additions & 0 deletions test/api/cli/sql/distinct_3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Frame(8x2, [f.a:int64_t, f.b:int64_t])
0 1
1 2
2 1
2 3
3 3
3 4
6 5
8 1
11 changes: 11 additions & 0 deletions test/api/cli/sql/distinct_4.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# DISTINCT with aggregation.

f = createFrame(
[ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2],
"a");

registerView("f", f);

res = sql("SELECT DISTINCT SUM(f.a) FROM f;");

print(res);
2 changes: 2 additions & 0 deletions test/api/cli/sql/distinct_4.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Frame(1x1, [SUM(f.a):int64_t])
30
Loading
Loading