From 4d83a5c520b29b422d7a00971b041559a3c644bd Mon Sep 17 00:00:00 2001 From: Jiahan Xie Date: Sun, 6 Oct 2024 13:52:03 -0400 Subject: [PATCH 1/2] support if op when its condition check is not combinational --- lib/Conversion/SCFToCalyx/SCFToCalyx.cpp | 151 ++++++++++++++++-- .../SCFToCalyx/convert_controlflow.mlir | 119 ++++++++++++++ 2 files changed, 256 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index 7411c3a4dd3b..5118d5c25d28 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -125,6 +125,20 @@ using Scheduleable = class IfLoweringStateInterface { public: + void setCondReg(scf::IfOp op, calyx::RegisterOp regOp) { + Operation *operation = op.getOperation(); + assert(condReg.count(operation) == 0 && + "A condition register was already set for this scf::IfOp!\n"); + condReg[operation] = regOp; + } + + calyx::RegisterOp getCondReg(scf::IfOp op) { + auto it = condReg.find(op.getOperation()); + if (it != condReg.end()) + return it->second; + return nullptr; + } + void setThenGroup(scf::IfOp op, calyx::GroupOp group) { Operation *operation = op.getOperation(); assert(thenGroup.count(operation) == 0 && @@ -172,6 +186,7 @@ class IfLoweringStateInterface { } private: + DenseMap condReg; DenseMap thenGroup; DenseMap elseGroup; DenseMap> resultRegs; @@ -240,6 +255,28 @@ class ForLoopLoweringStateInterface } }; +class PipeOpLoweringStateInterface { +public: + void setPipeResReg(Operation *op, calyx::RegisterOp reg) { + assert(isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op)); + assert(resultRegs.count(op) == 0 && + "A register was already set for this pipe operation!\n"); + resultRegs[op] = reg; + } + // Get the register for a specific pipe operation + calyx::RegisterOp getPipeResReg(Operation *op) { + auto it = resultRegs.find(op); + assert(it != resultRegs.end() && + "No register was set for this pipe operation!\n"); + return it->second; + } + +private: + DenseMap resultRegs; +}; + /// Handles the current state of lowering of a Calyx component. It is mainly /// used as a key/value store for recording information during partial lowering, /// which is required at later lowering passes. @@ -247,6 +284,7 @@ class ComponentLoweringState : public calyx::ComponentLoweringStateInterface, public WhileLoopLoweringStateInterface, public ForLoopLoweringStateInterface, public IfLoweringStateInterface, + public PipeOpLoweringStateInterface, public calyx::SchedulerInterface { public: ComponentLoweringState(calyx::ComponentOp component) @@ -339,7 +377,12 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { /// source operation TSrcOp. template LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, - TypeRange srcTypes, TypeRange dstTypes) const { + TypeRange srcTypes, TypeRange dstTypes, + calyx::RegisterOp srcReg = nullptr, + calyx::RegisterOp dstReg = nullptr) const { + assert((srcReg && dstReg) || (!srcReg && !dstReg)); + bool isSequential = srcReg && dstReg; + SmallVector types; llvm::append_range(types, srcTypes); llvm::append_range(types, dstTypes); @@ -365,26 +408,54 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { /// Create assignments to the inputs of the library op. auto group = createGroupForOp(rewriter, op); + + if (isSequential) { + auto groupOp = cast(group); + getState().addBlockScheduleable(op->getBlock(), + groupOp); + } + rewriter.setInsertionPointToEnd(group.getBodyBlock()); - for (auto dstOp : enumerate(opInputPorts)) - rewriter.create(op.getLoc(), dstOp.value(), - op->getOperand(dstOp.index())); + + for (auto dstOp : enumerate(opInputPorts)) { + if (isSequential) + rewriter.create(op.getLoc(), dstOp.value(), + srcReg.getOut()); + else + rewriter.create(op.getLoc(), dstOp.value(), + op->getOperand(dstOp.index())); + } /// Replace the result values of the source operator with the new operator. for (auto res : enumerate(opOutputPorts)) { getState().registerEvaluatingGroup(res.value(), group); - op->getResult(res.index()).replaceAllUsesWith(res.value()); + if (isSequential) + op->getResult(res.index()).replaceAllUsesWith(dstReg.getOut()); + else + op->getResult(res.index()).replaceAllUsesWith(res.value()); } + + if (isSequential) { + auto groupOp = cast(group); + buildAssignmentsForRegisterWrite( + rewriter, groupOp, + getState().getComponentOp(), dstReg, + calyxOp.getOut()); + } + return success(); } /// buildLibraryOp which provides in- and output types based on the operands /// and results of the op argument. template - LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const { + LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, + calyx::RegisterOp srcReg = nullptr, + calyx::RegisterOp dstReg = nullptr) const { return buildLibraryOp( - rewriter, op, op.getOperandTypes(), op->getResultTypes()); + rewriter, op, op.getOperandTypes(), op->getResultTypes(), srcReg, + dstReg); } /// Creates a group named by the basic block which the input op resides in. @@ -411,6 +482,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { auto reg = createRegister( op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(), getState().getUniqueName(opName)); + // Operation pipelines are not combinational, so a GroupOp is required. auto group = createGroupForOp(rewriter, op); OpBuilder builder(group->getRegion(0)); @@ -441,6 +513,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { getState().registerEvaluatingGroup( opPipe.getRight(), group); + getState().setPipeResReg(out.getDefiningOp(), reg); + return success(); } @@ -939,9 +1013,43 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, CmpIOp op) const { + auto isPipeLibOp = [](Value val) -> bool { + if (Operation *defOp = val.getDefiningOp()) { + return isa(defOp); + } + return false; + }; + switch (op.getPredicate()) { - case CmpIPredicate::eq: + case CmpIPredicate::eq: { + StringRef opName = op.getOperationName().split(".").second; + Type width = op.getResult().getType(); + auto condReg = createRegister( + op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(), + getState().getUniqueName(opName)); + + for (auto *user : op->getUsers()) { + if (auto ifOp = dyn_cast(user)) + getState().setCondReg(ifOp, condReg); + } + + bool isSequential = isPipeLibOp(op.getLhs()) || isPipeLibOp(op.getRhs()); + if (isSequential) { + calyx::RegisterOp pipeResReg; + if (isPipeLibOp(op.getLhs())) + pipeResReg = getState().getPipeResReg( + op.getLhs().getDefiningOp()); + else + pipeResReg = getState().getPipeResReg( + op.getRhs().getDefiningOp()); + + return buildLibraryOp( + rewriter, op, pipeResReg, condReg); + } return buildLibraryOp(rewriter, op); + } case CmpIPredicate::ne: return buildLibraryOp(rewriter, op); case CmpIPredicate::uge: @@ -1535,11 +1643,16 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { Location loc = ifOp->getLoc(); auto cond = ifOp.getCondition(); - auto condGroup = getState() - .getEvaluatingGroup(cond); - auto symbolAttr = FlatSymbolRefAttr::get( - StringAttr::get(getContext(), condGroup.getSymName())); + FlatSymbolRefAttr symbolAttr = nullptr; + auto condReg = getState().getCondReg(ifOp); + if (!condReg) { + auto condGroup = getState() + .getEvaluatingGroup(cond); + + symbolAttr = FlatSymbolRefAttr::get( + StringAttr::get(getContext(), condGroup.getSymName())); + } bool initElse = !ifOp.getElseRegion().empty(); auto ifCtrlOp = rewriter.create( @@ -1551,8 +1664,13 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { rewriter.create(ifOp.getThenRegion().getLoc()); auto *thenSeqOpBlock = thenSeqOp.getBodyBlock(); - rewriter.setInsertionPointToEnd(thenSeqOpBlock); + auto *thenBlock = &ifOp.getThenRegion().front(); + LogicalResult res = buildCFGControl(path, rewriter, thenSeqOpBlock, + /*preBlock=*/block, thenBlock); + if (res.failed()) + return res; + rewriter.setInsertionPointToEnd(thenSeqOpBlock); calyx::GroupOp thenGroup = getState().getThenGroup(ifOp); rewriter.create(thenGroup.getLoc(), @@ -1565,8 +1683,13 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { rewriter.create(ifOp.getElseRegion().getLoc()); auto *elseSeqOpBlock = elseSeqOp.getBodyBlock(); - rewriter.setInsertionPointToEnd(elseSeqOpBlock); + auto *elseBlock = &ifOp.getElseRegion().front(); + res = buildCFGControl(path, rewriter, elseSeqOpBlock, + /*preBlock=*/block, elseBlock); + if (res.failed()) + return res; + rewriter.setInsertionPointToEnd(elseSeqOpBlock); calyx::GroupOp elseGroup = getState().getElseGroup(ifOp); rewriter.create(elseGroup.getLoc(), diff --git a/test/Conversion/SCFToCalyx/convert_controlflow.mlir b/test/Conversion/SCFToCalyx/convert_controlflow.mlir index d4a87139f621..21618baa99f9 100644 --- a/test/Conversion/SCFToCalyx/convert_controlflow.mlir +++ b/test/Conversion/SCFToCalyx/convert_controlflow.mlir @@ -641,3 +641,122 @@ module { return %1 : i32 } } + +// ----- + +// Test if ops with sequential condition check. + +module { +// CHECK-LABEL: calyx.component @main( +// CHECK-SAME: %[[VAL_0:in0]]: i32, +// CHECK-SAME: %[[VAL_1:.*]]: i1 {clk}, +// CHECK-SAME: %[[VAL_2:.*]]: i1 {reset}, +// CHECK-SAME: %[[VAL_3:.*]]: i1 {go}) -> ( +// CHECK-SAME: %[[VAL_4:out0]]: i32, +// CHECK-SAME: %[[VAL_5:.*]]: i1 {done}) { +// CHECK: %[[VAL_6:.*]] = hw.constant true +// CHECK: %[[VAL_7:.*]] = hw.constant false +// CHECK: %[[VAL_8:.*]] = hw.constant 1 : i32 +// CHECK: %[[VAL_9:.*]] = hw.constant 2 : i32 +// CHECK: %[[VAL_10:.*]], %[[VAL_11:.*]] = calyx.std_slice @std_slice_1 : i32, i7 +// CHECK: %[[VAL_12:.*]], %[[VAL_13:.*]] = calyx.std_slice @std_slice_0 : i32, i7 +// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]], %[[VAL_16:.*]], %[[VAL_17:.*]], %[[VAL_18:.*]], %[[VAL_19:.*]] = calyx.register @load_1_reg : i32, i1, i1, i1, i32, i1 +// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]], %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]] = calyx.register @load_0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]] = calyx.std_eq @std_eq_0 : i32, i32, i1 +// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]], %[[VAL_31:.*]], %[[VAL_32:.*]], %[[VAL_33:.*]], %[[VAL_34:.*]] = calyx.register @cmpi_0_reg : i1, i1, i1, i1, i1, i1 +// CHECK: %[[VAL_35:.*]], %[[VAL_36:.*]], %[[VAL_37:.*]], %[[VAL_38:.*]], %[[VAL_39:.*]], %[[VAL_40:.*]] = calyx.register @remui_0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]], %[[VAL_43:.*]], %[[VAL_44:.*]], %[[VAL_45:.*]], %[[VAL_46:.*]], %[[VAL_47:.*]] = calyx.std_remu_pipe @std_remu_pipe_0 : i1, i1, i1, i32, i32, i32, i1 +// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]], %[[VAL_50:.*]], %[[VAL_51:.*]], %[[VAL_52:.*]], %[[VAL_53:.*]], %[[VAL_54:.*]], %[[VAL_55:.*]] = calyx.seq_mem @mem_0 <[120] x 32> [7] {external = true} : i7, i1, i1, i1, i1, i32, i32, i1 +// CHECK: %[[VAL_56:.*]], %[[VAL_57:.*]], %[[VAL_58:.*]], %[[VAL_59:.*]], %[[VAL_60:.*]], %[[VAL_61:.*]] = calyx.register @if_res_0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: %[[VAL_62:.*]], %[[VAL_63:.*]], %[[VAL_64:.*]], %[[VAL_65:.*]], %[[VAL_66:.*]], %[[VAL_67:.*]] = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: calyx.wires { +// CHECK: calyx.assign %[[VAL_4]] = %[[VAL_66]] : i32 +// CHECK: calyx.group @then_br_0 { +// CHECK: calyx.assign %[[VAL_56]] = %[[VAL_24]] : i32 +// CHECK: calyx.assign %[[VAL_57]] = %[[VAL_6]] : i1 +// CHECK: calyx.group_done %[[VAL_61]] : i1 +// CHECK: } +// CHECK: calyx.group @else_br_0 { +// CHECK: calyx.assign %[[VAL_56]] = %[[VAL_18]] : i32 +// CHECK: calyx.assign %[[VAL_57]] = %[[VAL_6]] : i1 +// CHECK: calyx.group_done %[[VAL_61]] : i1 +// CHECK: } +// CHECK: calyx.group @bb0_0 { +// CHECK: calyx.assign %[[VAL_44]] = %[[VAL_0]] : i32 +// CHECK: calyx.assign %[[VAL_45]] = %[[VAL_9]] : i32 +// CHECK: calyx.assign %[[VAL_35]] = %[[VAL_46]] : i32 +// CHECK: calyx.assign %[[VAL_36]] = %[[VAL_47]] : i1 +// CHECK: %[[VAL_68:.*]] = comb.xor %[[VAL_47]], %[[VAL_6]] : i1 +// CHECK: calyx.assign %[[VAL_43]] = %[[VAL_68]] ? %[[VAL_6]] : i1 +// CHECK: calyx.group_done %[[VAL_40]] : i1 +// CHECK: } +// CHECK: calyx.group @bb0_1 { +// CHECK: calyx.assign %[[VAL_26]] = %[[VAL_39]] : i32 +// CHECK: calyx.assign %[[VAL_27]] = %[[VAL_39]] : i32 +// CHECK: calyx.assign %[[VAL_29]] = %[[VAL_28]] : i1 +// CHECK: calyx.assign %[[VAL_30]] = %[[VAL_6]] : i1 +// CHECK: calyx.group_done %[[VAL_34]] : i1 +// CHECK: } +// CHECK: calyx.group @bb0_2 { +// CHECK: calyx.assign %[[VAL_10]] = %[[VAL_9]] : i32 +// CHECK: calyx.assign %[[VAL_48]] = %[[VAL_11]] : i7 +// CHECK: calyx.assign %[[VAL_51]] = %[[VAL_6]] : i1 +// CHECK: calyx.assign %[[VAL_52]] = %[[VAL_7]] : i1 +// CHECK: calyx.assign %[[VAL_20]] = %[[VAL_54]] : i32 +// CHECK: calyx.assign %[[VAL_21]] = %[[VAL_55]] : i1 +// CHECK: calyx.group_done %[[VAL_25]] : i1 +// CHECK: } +// CHECK: calyx.group @bb0_3 { +// CHECK: calyx.assign %[[VAL_12]] = %[[VAL_8]] : i32 +// CHECK: calyx.assign %[[VAL_48]] = %[[VAL_13]] : i7 +// CHECK: calyx.assign %[[VAL_51]] = %[[VAL_6]] : i1 +// CHECK: calyx.assign %[[VAL_52]] = %[[VAL_7]] : i1 +// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_54]] : i32 +// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_55]] : i1 +// CHECK: calyx.group_done %[[VAL_19]] : i1 +// CHECK: } +// CHECK: calyx.group @ret_assign_0 { +// CHECK: calyx.assign %[[VAL_62]] = %[[VAL_60]] : i32 +// CHECK: calyx.assign %[[VAL_63]] = %[[VAL_6]] : i1 +// CHECK: calyx.group_done %[[VAL_67]] : i1 +// CHECK: } +// CHECK: } +// CHECK: calyx.control { +// CHECK: calyx.seq { +// CHECK: calyx.enable @bb0_0 +// CHECK: calyx.enable @bb0_1 +// CHECK: calyx.if %[[VAL_33]] { +// CHECK: calyx.seq { +// CHECK: calyx.enable @bb0_2 +// CHECK: calyx.enable @then_br_0 +// CHECK: } +// CHECK: } else { +// CHECK: calyx.seq { +// CHECK: calyx.enable @bb0_3 +// CHECK: calyx.enable @else_br_0 +// CHECK: } +// CHECK: } +// CHECK: calyx.enable @ret_assign_0 +// CHECK: } +// CHECK: } +// CHECK: } {toplevel} + func.func @main(%arg0 : i32) -> i32 { + %1 = memref.alloc() : memref<120xi32> + %idx_one = arith.constant 1 : index + %two = arith.constant 2: i32 + %rem = arith.remui %arg0, %two : i32 + %cond = arith.cmpi eq, %arg0, %rem : i32 + + %res = scf.if %cond -> i32 { + %idx = arith.addi %idx_one, %idx_one : index + %then_res = memref.load %1[%idx] : memref<120xi32> + scf.yield %then_res : i32 + } else { + %idx = arith.muli %idx_one, %idx_one : index + %else_res = memref.load %1[%idx] : memref<120xi32> + scf.yield %else_res : i32 + } + + return %res : i32 + } +} From 7fc6c933c3a01def6f4718fec41c51eca138c3cb Mon Sep 17 00:00:00 2001 From: Jiahan Xie Date: Wed, 9 Oct 2024 18:06:10 -0400 Subject: [PATCH 2/2] small fix --- lib/Conversion/SCFToCalyx/SCFToCalyx.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index 5118d5c25d28..62e2717bbdba 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -1026,17 +1026,17 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, case CmpIPredicate::eq: { StringRef opName = op.getOperationName().split(".").second; Type width = op.getResult().getType(); - auto condReg = createRegister( - op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(), - getState().getUniqueName(opName)); - - for (auto *user : op->getUsers()) { - if (auto ifOp = dyn_cast(user)) - getState().setCondReg(ifOp, condReg); - } - bool isSequential = isPipeLibOp(op.getLhs()) || isPipeLibOp(op.getRhs()); if (isSequential) { + auto condReg = createRegister( + op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(), + getState().getUniqueName(opName)); + + for (auto *user : op->getUsers()) { + if (auto ifOp = dyn_cast(user)) + getState().setCondReg(ifOp, condReg); + } + calyx::RegisterOp pipeResReg; if (isPipeLibOp(op.getLhs())) pipeResReg = getState().getPipeResReg(