diff --git a/doc/tutorial/sqlTutorial.md b/doc/tutorial/sqlTutorial.md index 19d2e4c51..d6f2002dd 100644 --- a/doc/tutorial/sqlTutorial.md +++ b/doc/tutorial/sqlTutorial.md @@ -63,6 +63,7 @@ Other features we do and don't support right now can be found below. * Having Clauses * Order By Clauses * As +* Distinct ### Not Yet Supported Features @@ -71,7 +72,6 @@ Other features we do and don't support right now can be found below. * All Set Operations (Union, Except, Intersect) * Recursive SQL Queries * Limit -* Distinct ## Examples diff --git a/src/api/python/operator/nodes/frame.py b/src/api/python/operator/nodes/frame.py index 1c37f9526..0cbbcb4af 100644 --- a/src/api/python/operator/nodes/frame.py +++ b/src/api/python/operator/nodes/frame.py @@ -71,7 +71,7 @@ def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], named_inpu "schema": [ { "label": self._pd_dataframe.columns[i], - "valueType": self.getDType(self._pd_dataframe.dtypes[i]) + "valueType": self.getDType(self._pd_dataframe.dtypes.iloc[i]) } for i in range(self._pd_dataframe.shape[1]) ] diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index 4454aaec8..42633a965 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -215,6 +215,18 @@ namespace // of the variadic pack ops. Should be changed when reworking the lowering to kernels. 
if(llvm::dyn_cast(op) && idx >= operandTypes.size()) { callee << "__char_variadic__size_t"; + auto cvpOp = rewriter.create( + loc, + daphne::VariadicPackType::get( + rewriter.getContext(), + daphne::StringType::get(rewriter.getContext()) + ), + rewriter.getI64IntegerAttr(0) + ); + newOperands.push_back(cvpOp); + newOperands.push_back(rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)) + ); continue; } else { callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandTypes[idx], generalizeInputTypes); diff --git a/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp b/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp index 7aa07e05e..88f4bc9f7 100644 --- a/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp @@ -200,7 +200,15 @@ void daphne::GroupOp::inferFrameLabels() { std::vector aggFuncNames; for(Value t: getKeyCol()){ //Adopting keyCol Labels - newLabels->push_back(CompilerUtils::constantOrThrow(t)); + std::string keyLabel = CompilerUtils::constantOrThrow(t); + if(keyLabel == "*") { + daphne::FrameType arg = getFrame().getType().dyn_cast(); + for (std::string frameLabel : *arg.getLabels()) { + newLabels->push_back(frameLabel); + } + } else { + newLabels->push_back(keyLabel); + } } for(Value t: getAggCol()){ diff --git a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp index db88cc37e..1ba6f9123 100644 --- a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp @@ -235,7 +235,17 @@ std::vector daphne::GroupOp::inferTypes() { for(Value t : getKeyCol()){ //Key Types getting adopted for the new Frame - newColumnTypes.push_back(getFrameColumnTypeByLabel(arg, t)); + std::string labelStr = CompilerUtils::constantOrThrow( + t, "the specified label must be a constant of string type" + ); + if(labelStr == "*") { + auto allTypes = arg.getColumnTypes(); + for (Type type: allTypes) { + 
newColumnTypes.push_back(type); + } + } else { + newColumnTypes.push_back(getFrameColumnTypeByLabel(arg, t)); + } } // Values get collected in a easier to use Datastructure diff --git a/src/parser/sql/SQLGrammar.g4 b/src/parser/sql/SQLGrammar.g4 index b3bc5f8b6..023c3b7be 100644 --- a/src/parser/sql/SQLGrammar.g4 +++ b/src/parser/sql/SQLGrammar.g4 @@ -33,7 +33,7 @@ query: select ';'?; select: - SQL_SELECT selectExpr (',' selectExpr)* + SQL_SELECT distinctExpr? selectExpr (',' selectExpr)* SQL_FROM tableExpr whereClause? groupByClause? @@ -52,6 +52,9 @@ selectExpr: tableExpr: fromExpr joinExpr*; +distinctExpr: + SQL_DISTINCT; + fromExpr: var=tableReference #tableIdentifierExpr | lhs=tableReference ',' rhs=fromExpr #cartesianExpr diff --git a/src/parser/sql/SQLVisitor.cpp b/src/parser/sql/SQLVisitor.cpp index b494b8a13..86d8ef748 100644 --- a/src/parser/sql/SQLVisitor.cpp +++ b/src/parser/sql/SQLVisitor.cpp @@ -18,6 +18,7 @@ #include #include #include "antlr4-runtime.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/OpDefinition.h" #include @@ -54,6 +55,14 @@ void toggleBit(int64_t& flag, int64_t position){ setBit(flag, position, !isBitSet(flag, position)); } +/** + * @brief Creates a lower case version of a string + */ +std::string toLower(std::string str){ + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c){ return static_cast<char>(::tolower(c)); }); + return str; +} + // **************************************************************************** // Member Helper functions // **************************************************************************** @@ -504,6 +513,10 @@ antlrcpp::Any SQLVisitor::visitSelect( throw std::runtime_error(err_msg.str()); } } + currentFrame = res; + if(ctx->distinctExpr()) { + res = utils.valueOrError(visit(ctx->distinctExpr())); + } return res; } @@ -566,6 +579,28 @@ antlrcpp::Any SQLVisitor::visitTableExpr( return currentFrame; } +//distinctExpr +antlrcpp::Any SQLVisitor::visitDistinctExpr( + SQLGrammarParser::DistinctExprContext *ctx +) +{ + if 
(isBitSet(sqlFlag, (int64_t)SQLBit::group) // If group is active + != !columnName.empty()) // EXCLUSIVE OR there is an aggregation + return currentFrame; // due to earlier grouping XOR aggregation the result is already distinct + + mlir::Location loc = utils.getLoc(ctx->start); + mlir::Value starLiteral = createStringConstant("*"); + std::vector cols{starLiteral}; + std::vector aggs; + std::vector functions; + mlir::Type vt = utils.unknownType; + std::vector colTypes{vt}; + mlir::Type resType = + mlir::daphne::FrameType::get(builder.getContext(), colTypes); + return static_cast(builder.create( + loc, resType, currentFrame, cols, aggs, builder.getArrayAttr(functions))); +} + //fromExpr antlrcpp::Any SQLVisitor::visitTableIdentifierExpr( SQLGrammarParser::TableIdentifierExprContext *ctx @@ -867,8 +902,9 @@ antlrcpp::Any SQLVisitor::visitIdentifierExpr( SQLGrammarParser::IdentifierExprContext * ctx) { if( isBitSet(sqlFlag, (int64_t)SQLBit::group) //If group is active - && !isBitSet(sqlFlag, (int64_t)SQLBit::agg) //AND there isn't an aggreagtion - && grouped[ctx->selectIdent()->getText()] == 0) //AND the label is not in group expr + && !isBitSet(sqlFlag, (int64_t)SQLBit::agg) //AND there isn't an aggregation + && grouped.count(ctx->selectIdent()->getText()) == 0 //AND the label is not in group expr + && grouped.count("*") == 0) //AND there is no * in group expr { std::stringstream err_msg; err_msg << "Error during a generalExpr. 
\"" @@ -911,7 +947,7 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr( mlir::Type resTypeCol = col.getType().dyn_cast().getElementType(); - const std::string &func = ctx->func->getText(); + std::string func = toLower(ctx->func->getText()); mlir::Value result; if(func == "count"){ @@ -959,7 +995,7 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr( ); } - std::string newColumnNameAppended = getEnumLabelExt(ctx->func->getText()) + "(" + newColumnName + ")"; + std::string newColumnNameAppended = getEnumLabelExt(func) + "(" + newColumnName + ")"; return utils.castIf(utils.matrixOf(result), result); @@ -975,7 +1011,8 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr( //create Column pre Group for in group Aggregation if(!isBitSet(sqlFlag, (int64_t)SQLBit::codegen)){ columnName.push_back(createStringConstant(newColumnName)); - functionName.push_back(getGroupEnum(ctx->func->getText())); + const std::string &func = toLower(ctx->func->getText()); + functionName.push_back(getGroupEnum(func)); setBit(sqlFlag, (int64_t)SQLBit::agg, 1); setBit(sqlFlag, (int64_t)SQLBit::codegen, 1); @@ -987,7 +1024,8 @@ antlrcpp::Any SQLVisitor::visitGroupAggExpr( currentFrame = addMatrixToCurrentFrame(matrix, newColumnName); return nullptr; }else{ //Get Column after Group - std::string newColumnNameAppended = getEnumLabelExt(ctx->func->getText()) + "(" + newColumnName + ")"; + const std::string &func = toLower(ctx->func->getText()); + std::string newColumnNameAppended = getEnumLabelExt(func) + "(" + newColumnName + ")"; mlir::Value colname = utils.valueOrError(createStringConstant(newColumnNameAppended)); return extractMatrixFromFrame(currentFrame, colname); //returns Matrix } diff --git a/src/parser/sql/SQLVisitor.h b/src/parser/sql/SQLVisitor.h index e84c67e5a..1a85fed2f 100644 --- a/src/parser/sql/SQLVisitor.h +++ b/src/parser/sql/SQLVisitor.h @@ -183,6 +183,9 @@ class SQLVisitor : public SQLGrammarVisitor { //tableExpr antlrcpp::Any visitTableExpr(SQLGrammarParser::TableExprContext * ctx) override; 
+//distinctExpr + antlrcpp::Any visitDistinctExpr(SQLGrammarParser::DistinctExprContext * ctx) override; + //fromExpr antlrcpp::Any visitTableIdentifierExpr(SQLGrammarParser::TableIdentifierExprContext *ctx) override; diff --git a/src/runtime/local/kernels/Group.h b/src/runtime/local/kernels/Group.h index a2b05941a..d7499c403 100644 --- a/src/runtime/local/kernels/Group.h +++ b/src/runtime/local/kernels/Group.h @@ -125,14 +125,40 @@ template <> struct Group { if (arg == nullptr || (keyCols == nullptr && numKeyCols != 0) || (aggCols == nullptr && numAggCols != 0) || (aggFuncs == nullptr && numAggFuncs != 0)) { throw std::runtime_error("group-kernel called with invalid arguments"); } - + + // check if labels contain * + std::vector starLabels; + const std::string * argLabels = arg->getLabels(); + const size_t numColsArg = arg->getNumCols(); + std::vector aggColsVec; + for (size_t m = 0; m < numAggCols; m++) { + aggColsVec.push_back(aggCols[m]); + } + for (size_t i = 0; i < numKeyCols; i++) { + if (strcmp(keyCols[i], "*") == 0) { + for (size_t m = 0; m < numColsArg; m++) { + // check that we do not include columns in the result that are used for aggregations and would lead to duplicates + if(std::find(aggColsVec.begin(), aggColsVec.end(), argLabels[m]) == aggColsVec.end()) { + starLabels.push_back(argLabels[m]); + } + } + // we assume that other key columns are included in the * + // operator, otherwise they would not be in the argument frame + // and throw an error later on + numColsRes = starLabels.size() + numAggCols; + } + } + + // convert labels to indices auto idxs = std::shared_ptr(new size_t[numColsRes]); - bool * ascending = new bool[numKeyCols]; - for (size_t i = 0; i < numKeyCols; i++) { - idxs[i] = arg->getColumnIdx(keyCols[i]); - ascending[i] = true; - } + numKeyCols = starLabels.size()? starLabels.size() : numKeyCols; + bool * ascending = new bool[numKeyCols]; + for (size_t i = 0; i < numKeyCols; ++i) { + idxs[i] = starLabels.size() ? 
arg->getColumnIdx(starLabels[i]) + : arg->getColumnIdx(keyCols[i]); + ascending[i] = true; + } for (size_t i = numKeyCols; i < numColsRes; i++) { idxs[i] = arg->getColumnIdx(aggCols[i-numKeyCols]); } @@ -166,10 +192,16 @@ template <> struct Group { // create the result frame std::string * labels = new std::string[numColsRes]; ValueTypeCode * schema = new ValueTypeCode[numColsRes]; - - for (size_t i = 0; i < numKeyCols; i++) { - labels[i] = keyCols[i]; - schema[i] = ordered->getColumnType(idxs[i]); + if (starLabels.size()) { + for (size_t i = 0; i < numKeyCols; i++) { + labels[i] = starLabels[i]; + schema[i] = ordered->getColumnType(idxs[i]); + } + } else { + for (size_t i = 0; i < numKeyCols; i++) { + labels[i] = keyCols[i]; + schema[i] = ordered->getColumnType(idxs[i]); + } } using mlir::daphne::GroupEnum; for (size_t i = numKeyCols; i < numColsRes; i++) { @@ -199,4 +231,4 @@ template <> struct Group { } }; -#endif //SRC_RUNTIME_LOCAL_KERNELS_GROUP_H \ No newline at end of file +#endif //SRC_RUNTIME_LOCAL_KERNELS_GROUP_H diff --git a/test/api/cli/sql/SQLTest.cpp b/test/api/cli/sql/SQLTest.cpp index f9041caca..ed97f468e 100644 --- a/test/api/cli/sql/SQLTest.cpp +++ b/test/api/cli/sql/SQLTest.cpp @@ -103,5 +103,6 @@ MAKE_TEST_CASE("agg_sum", 1) MAKE_TEST_CASE("reuseString", 2) +MAKE_TEST_CASE("distinct", 5) // TODO Use the scripts testing failure cases. diff --git a/test/api/cli/sql/distinct_1.daphne b/test/api/cli/sql/distinct_1.daphne new file mode 100644 index 000000000..00d963b58 --- /dev/null +++ b/test/api/cli/sql/distinct_1.daphne @@ -0,0 +1,12 @@ +# DISTINCT over one column. 
+ +f = createFrame( + [ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2], + [ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1], + "a", "b"); + +registerView("f", f); + +res = sql("SELECT DISTINCT f.a FROM f;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/distinct_1.txt b/test/api/cli/sql/distinct_1.txt new file mode 100644 index 000000000..77b81f60b --- /dev/null +++ b/test/api/cli/sql/distinct_1.txt @@ -0,0 +1,7 @@ +Frame(6x1, [f.a:int64_t]) +0 +1 +2 +3 +6 +8 diff --git a/test/api/cli/sql/distinct_2.daphne b/test/api/cli/sql/distinct_2.daphne new file mode 100644 index 000000000..0588588e8 --- /dev/null +++ b/test/api/cli/sql/distinct_2.daphne @@ -0,0 +1,12 @@ +# DISTINCT over multiple columns. + +f = createFrame( + [ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2], + [ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1], + "a", "b"); + +registerView("f", f); + +res = sql("SELECT DISTINCT f.a, f.b FROM f;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/distinct_2.txt b/test/api/cli/sql/distinct_2.txt new file mode 100644 index 000000000..edbf17b4d --- /dev/null +++ b/test/api/cli/sql/distinct_2.txt @@ -0,0 +1,9 @@ +Frame(8x2, [f.a:int64_t, f.b:int64_t]) +0 1 +1 2 +2 1 +2 3 +3 3 +3 4 +6 5 +8 1 diff --git a/test/api/cli/sql/distinct_3.daphne b/test/api/cli/sql/distinct_3.daphne new file mode 100644 index 000000000..a9cc66b49 --- /dev/null +++ b/test/api/cli/sql/distinct_3.daphne @@ -0,0 +1,12 @@ +# DISTINCT over multiple columns with GROUP BY. 
+ +f = createFrame( + [ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2], + [ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1], + "a", "b"); + +registerView("f", f); + +res = sql("SELECT DISTINCT f.a, f.b FROM f GROUP BY f.a, f.b;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/distinct_3.txt b/test/api/cli/sql/distinct_3.txt new file mode 100644 index 000000000..edbf17b4d --- /dev/null +++ b/test/api/cli/sql/distinct_3.txt @@ -0,0 +1,9 @@ +Frame(8x2, [f.a:int64_t, f.b:int64_t]) +0 1 +1 2 +2 1 +2 3 +3 3 +3 4 +6 5 +8 1 diff --git a/test/api/cli/sql/distinct_4.daphne b/test/api/cli/sql/distinct_4.daphne new file mode 100644 index 000000000..3b9cc44ad --- /dev/null +++ b/test/api/cli/sql/distinct_4.daphne @@ -0,0 +1,11 @@ +# DISTINCT with aggregation. + +f = createFrame( + [ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2], + "a"); + +registerView("f", f); + +res = sql("SELECT DISTINCT SUM(f.a) FROM f;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/distinct_4.txt b/test/api/cli/sql/distinct_4.txt new file mode 100644 index 000000000..bdd61d119 --- /dev/null +++ b/test/api/cli/sql/distinct_4.txt @@ -0,0 +1,2 @@ +Frame(1x1, [SUM(f.a):int64_t]) +30 diff --git a/test/api/cli/sql/distinct_5.daphne b/test/api/cli/sql/distinct_5.daphne new file mode 100644 index 000000000..96ba754e4 --- /dev/null +++ b/test/api/cli/sql/distinct_5.daphne @@ -0,0 +1,12 @@ +# DISTINCT over one column with a group by aggregation. + +f = createFrame( + [ 0, 1, 2, 2, 3, 3, 6, 3, 8, 2], + [ 1, 2, 3, 3, 3, 4, 5, 4, 1, 1], + "a", "b"); + +registerView("f", f); + +res = sql("SELECT DISTINCT SUM(f.a) FROM f GROUP by f.a;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/distinct_5.txt b/test/api/cli/sql/distinct_5.txt new file mode 100644 index 000000000..525e73dd7 --- /dev/null +++ b/test/api/cli/sql/distinct_5.txt @@ -0,0 +1,6 @@ +Frame(5x1, [SUM(f.a):int64_t]) +0 +1 +6 +8 +9