From 5aeb6c0c376c7e0ea339655fc067d2c01b9a2bf9 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Fri, 8 Nov 2024 10:39:50 +0800 Subject: [PATCH] Rewrite `AffineForOp` transformation logic as `OpOperand` mutations (#764) * Rewrite as AffineForOp OpOperand mutations * Rewrite scf.for transformation as opoperand mutations --- .../Transform/AIRDependencyScheduleOpt.cpp | 83 ++++--------------- 1 file changed, 17 insertions(+), 66 deletions(-) diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index e2604d8f6..59982f8d9 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -1633,50 +1633,20 @@ struct CanonicalizeAIRExecute : public OpRewritePattern { private: }; -affine::AffineForOp updateAffineForBounds(OpBuilder builder, IRMapping &remap, - affine::AffineForOp loop_op, int lb, - int ub, int step) { - affine::AffineForOp new_loop_op = builder.create( - builder.getUnknownLoc(), lb, ub, step); - remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar()); - // remap.map(old_apply.getResult(), new_loop_op.getInductionVar()); - auto insertionCheckpoint = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(new_loop_op.getBody()); - for (Operation &child_op : loop_op.getBody()->getOperations()) { - if (&child_op == loop_op.getBody()->getTerminator()) { /*Skip*/ - } else - builder.clone(child_op, remap); - } - builder.restoreInsertionPoint(insertionCheckpoint); - return new_loop_op; +void updateAffineForBounds(affine::AffineForOp loop_op, int lb, int ub, + int step) { + loop_op.setConstantLowerBound(lb); + loop_op.setConstantUpperBound(ub); + loop_op.setStep(step); } -scf::ForOp updateScfForBounds(OpBuilder builder, IRMapping &remap, - scf::ForOp loop_op, int lb, int ub, int step) { +void updateScfForBounds(OpBuilder builder, scf::ForOp loop_op, int lb, int ub, + int step) { auto loc = loop_op->getLoc(); - SmallVector deps = - loop_op.getOperands().drop_front(loop_op.getNumControlOperands()); - scf::ForOp new_loop_op = builder.create( - builder.getUnknownLoc(), builder.create(loc, lb), - builder.create(loc, ub), - builder.create(loc, step), deps); - remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar()); - for (unsigned i = 0; i < loop_op.getRegionIterArgs().size(); i++) - remap.map(loop_op.getRegionIterArgs()[i], - new_loop_op.getRegionIterArgs()[i]); - auto insertionCheckpoint = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(new_loop_op.getBody()); - for (Operation &child_op : loop_op.getBody()->getOperations()) { - if (&child_op == loop_op.getBody()->getTerminator()) { - if (!new_loop_op.getBody()->mightHaveTerminator()) - builder.clone(child_op, remap); - } else - builder.clone(child_op, remap); - } - for (unsigned i = 0; i < loop_op->getNumResults(); i++) - loop_op->getResult(i).replaceAllUsesWith(new_loop_op->getResult(i)); - builder.restoreInsertionPoint(insertionCheckpoint); - return new_loop_op; + builder.setInsertionPoint(loop_op); + loop_op.setLowerBound(builder.create(loc, lb)); + loop_op.setUpperBound(builder.create(loc, ub)); + loop_op.setStep(builder.create(loc, step)); } // Fold affine.apply op operating on loop induction variable into loop bounds. @@ -1725,7 +1695,6 @@ struct CanonicalizeAffineApplyOnLoopInductionVar int newStepInInt = llvm::divideCeilSigned(*new_ub - *new_lb, tripCount); IRMapping remap; if (auto exec = dyn_cast(apply->getParentOp())) { - rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); if (sfo.getNumRegionIterArgs()) exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); @@ -1738,13 +1707,10 @@ struct CanonicalizeAffineApplyOnLoopInductionVar } rewriter.eraseOp(exec); } else { - rewriter.setInsertionPoint(apply); apply.getResult().replaceAllUsesWith(sfo.getInductionVar()); rewriter.eraseOp(apply); } - rewriter.setInsertionPoint(sfo); - updateScfForBounds(rewriter, remap, sfo, *new_lb, *new_ub, newStepInInt); - rewriter.eraseOp(sfo); + updateScfForBounds(rewriter, sfo, *new_lb, *new_ub, newStepInInt); } else if (auto afo = dyn_cast(containingOp)) { if (!afo.hasConstantBounds()) return failure(); @@ -1760,12 +1726,9 @@ struct CanonicalizeAffineApplyOnLoopInductionVar assert(new_ub && new_lb); int newStepInInt = llvm::divideCeilSigned(*new_ub - *new_lb, tripCount); IRMapping remap; - rewriter.setInsertionPoint(afo); apply.getResult().replaceAllUsesWith(afo.getInductionVar()); rewriter.eraseOp(apply); - updateAffineForBounds(rewriter, remap, afo, *new_lb, *new_ub, - newStepInInt); - rewriter.eraseOp(afo); + updateAffineForBounds(afo, *new_lb, *new_ub, newStepInInt); } else return failure(); @@ -1830,7 +1793,6 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); IRMapping remap; if (auto exec = dyn_cast(op->getParentOp())) { - rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); if (sfo.getNumRegionIterArgs()) exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); @@ -1843,13 +1805,10 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar } rewriter.eraseOp(exec); } else { - rewriter.setInsertionPoint(op); op.getResult().replaceAllUsesWith(sfo.getInductionVar()); rewriter.eraseOp(op); } - rewriter.setInsertionPoint(sfo); - updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt); - rewriter.eraseOp(sfo); + updateScfForBounds(rewriter, sfo, new_lb, new_ub, newStepInInt); } else if (auto afo = dyn_cast(containingOp)) { if (!afo.hasConstantBounds()) return failure(); @@ -1858,11 +1817,9 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar int new_lb = afo.getConstantLowerBound() * muli_factor; int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); IRMapping remap; - rewriter.setInsertionPoint(afo); op.getResult().replaceAllUsesWith(afo.getInductionVar()); rewriter.eraseOp(op); - updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt); - rewriter.eraseOp(afo); + updateAffineForBounds(afo, new_lb, new_ub, newStepInInt); } else return failure(); @@ -1927,7 +1884,6 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); IRMapping remap; if (auto exec = dyn_cast(op->getParentOp())) { - rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); if (sfo.getNumRegionIterArgs()) exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); @@ -1940,13 +1896,10 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar } rewriter.eraseOp(exec); } else { - rewriter.setInsertionPoint(op); op.getResult().replaceAllUsesWith(sfo.getInductionVar()); rewriter.eraseOp(op); } - rewriter.setInsertionPoint(sfo); - updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt); - rewriter.eraseOp(sfo); + updateScfForBounds(rewriter, sfo, new_lb, new_ub, newStepInInt); } else if (auto afo = dyn_cast(containingOp)) { if (!afo.hasConstantBounds()) return failure(); @@ -1955,11 +1908,9 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar int new_lb = afo.getConstantLowerBound() + addi_operand; int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); IRMapping remap; - rewriter.setInsertionPoint(afo); op.getResult().replaceAllUsesWith(afo.getInductionVar()); rewriter.eraseOp(op); - updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt); - rewriter.eraseOp(afo); + updateAffineForBounds(afo, new_lb, new_ub, newStepInInt); } else return failure();