Skip to content

Commit

Permalink
[DAPHNE-#903] Add 'groupSum()' built-in function to DaphneDSL. (#921)
Browse files Browse the repository at this point in the history
This commit introduces a new `groupSum()` built-in function to DaphneDSL, enabling the creation of a `GroupOp` in DaphneIR and closes #903.

Changes Implemented
1. `groupSum()` Built-in Function in DaphneDSL:
   - Interface: `group(arg:frame, groupCols:str, ..., sumCol:str)`
   - Accepts:
     - A frame as input.
     - At least one column to group on.
     - A single column to compute the sum.
   - Aggregation Support:
     - Only supports `SUM` as the aggregation function.
3. **Test Cases**:
   - Added script-level tests to validate the functionality of the `group()` function in DaphneDSL.
  • Loading branch information
saminbassiri authored Dec 3, 2024
1 parent 7136e0d commit 44ff77b
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 0 deletions.
8 changes: 8 additions & 0 deletions doc/DaphneDSL/Builtins.md
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,14 @@ We will support set operations such as **`intersect`**, **`merge`**, and **`exce

We will support more variants of joins, including (left/right) outer joins, theta joins, anti-joins, etc.

### Grouping and aggregation

- **`groupSum`**`(arg:frame, grpColNames:str[, grpColNames, ...], sumColName:str)`

Groups the rows in the given frame `arg` by the specified columns `grpColNames` (at least one column) and calculates the per-group sum of the column denoted by `sumColName`.

*This built-in function is currently limited in terms of functionality (aggregation only on a single column, sum as the only aggregation function). It will be extended in the future. Meanwhile, consider using DAPHNE's `sql()` built-in function for more comprehensive grouping and aggregation support.*

### Frame label manipulation

- **`setColLabels`**`(arg:frame, labels:str, ...)`
Expand Down
39 changes: 39 additions & 0 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,45 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
.getResults();
}

// --------------------------------------------------------------------
// Grouping and aggregation
// --------------------------------------------------------------------

if (func == "groupSum") {
// Arbitrary number of columns to group on.
// A single column to calculate the sum on.
checkNumArgsMin(loc, func, numArgs, 3);
mlir::Value currentFrame = args[0];
mlir::Value aggCol = args[numArgs - 1];
std::vector<mlir::Value> groupName;
std::vector<mlir::Value> columnName;
std::vector<mlir::Type> colTypes;

// set aggregaton function to SUM
auto aggFunc = static_cast<mlir::Attribute>(
mlir::daphne::GroupEnumAttr::get(builder.getContext(), mlir::daphne::GroupEnum::SUM));
std::vector<mlir::Attribute> functionName;
functionName.push_back(aggFunc);

// get group columns
for (size_t i = 1; i < numArgs - 1; i++) {
groupName.push_back(args[i]);
}

// get agg column
columnName.push_back(aggCol);

// result column types
mlir::Type vt = utils.unknownType;
for (size_t i = 0; i < groupName.size() + columnName.size(); i++) {
colTypes.push_back(vt);
}

return static_cast<mlir::Value>(builder.create<GroupOp>(loc, FrameType::get(builder.getContext(), colTypes),
currentFrame, groupName, columnName,
builder.getArrayAttr(functionName)));
}

// ********************************************************************
// Frame label manipulation
// ********************************************************************
Expand Down
1 change: 1 addition & 0 deletions test/api/cli/operations/OperationsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ MAKE_TEST_CASE("createFrame", 1)
MAKE_TEST_CASE("ctable", 1)
MAKE_TEST_CASE("fill", 1)
MAKE_TEST_CASE("gemv", 1)
MAKE_TEST_CASE("groupSum", 1)
MAKE_TEST_CASE("idxMax", 1)
MAKE_TEST_CASE("idxMin", 1)
MAKE_TEST_CASE("innerJoin", 1)
Expand Down
8 changes: 8 additions & 0 deletions test/api/cli/operations/groupSum_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Grouping and sum aggregation (two grouping columns).

fr = createFrame([ 11, 11, 22, 33, 33 ], [ 1, 1, 2, 3, 4 ],
[ 100, 200, 300, 400, 500 ], "a", "b", "c");

res = groupSum(fr, "a", "b", "c");

print(res);
5 changes: 5 additions & 0 deletions test/api/cli/operations/groupSum_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Frame(4x3, [a:int64_t, b:int64_t, SUM(c):int64_t])
11 1 300
22 2 300
33 3 400
33 4 500

0 comments on commit 44ff77b

Please sign in to comment.