Skip to content

Commit

Permalink
Rewrite AffineForOp transformation logic as OpOperand mutations (X…
Browse files Browse the repository at this point in the history
…ilinx#764)

* Rewrite as AffineForOp OpOperand mutations

* Rewrite scf.for transformation as OpOperand mutations
  • Loading branch information
erwei-xilinx authored Nov 8, 2024
1 parent 00414a8 commit 5aeb6c0
Showing 1 changed file with 17 additions and 66 deletions.
83 changes: 17 additions & 66 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1633,50 +1633,20 @@ struct CanonicalizeAIRExecute : public OpRewritePattern<air::ExecuteOp> {
private:
};

affine::AffineForOp updateAffineForBounds(OpBuilder builder, IRMapping &remap,
affine::AffineForOp loop_op, int lb,
int ub, int step) {
affine::AffineForOp new_loop_op = builder.create<affine::AffineForOp>(
builder.getUnknownLoc(), lb, ub, step);
remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar());
// remap.map(old_apply.getResult(), new_loop_op.getInductionVar());
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(new_loop_op.getBody());
for (Operation &child_op : loop_op.getBody()->getOperations()) {
if (&child_op == loop_op.getBody()->getTerminator()) { /*Skip*/
} else
builder.clone(child_op, remap);
}
builder.restoreInsertionPoint(insertionCheckpoint);
return new_loop_op;
// Overwrites the bounds of an existing affine.for in place.
//
// `loop_op` is mutated directly: its lower bound is set to the constant `lb`,
// its upper bound to the constant `ub`, and its step to `step`. No new loop is
// created and the loop body is left untouched, so the induction variable and
// everything nested in the loop remain valid after the call.
//
// NOTE(review): the call sites visible in this file invoke this helper from
// inside rewrite-pattern bodies; the mutation here is not routed through the
// rewriter (e.g. a modifyOpInPlace-style notification) — confirm that the
// pattern driver in use tolerates un-notified in-place updates.
void updateAffineForBounds(affine::AffineForOp loop_op, int lb, int ub,
int step) {
loop_op.setConstantLowerBound(lb);
loop_op.setConstantUpperBound(ub);
loop_op.setStep(step);
}

scf::ForOp updateScfForBounds(OpBuilder builder, IRMapping &remap,
scf::ForOp loop_op, int lb, int ub, int step) {
// Overwrites the bounds of an existing scf.for in place.
//
// `lb`, `ub` and `step` are materialized as arith::ConstantIndexOp values at
// `loop_op`'s location, with the builder positioned at `loop_op`, and are then
// installed as the loop's lower-bound, upper-bound and step operands.
// `builder` is received by value, so the insertion-point change made here is
// applied to the local copy only.
//
// NOTE(review): in this rendering the lines of the previous clone-based
// implementation (everything that mentions `deps`, `new_loop_op`, `remap` or
// `insertionCheckpoint`, including `return new_loop_op;`) are interleaved with
// the in-place implementation. Only `auto loc = ...` and the four trailing
// statements starting at `builder.setInsertionPoint(loop_op)` belong to the
// void, in-place version described above.
void updateScfForBounds(OpBuilder builder, scf::ForOp loop_op, int lb, int ub,
int step) {
auto loc = loop_op->getLoc();
SmallVector<Value, 1> deps =
loop_op.getOperands().drop_front(loop_op.getNumControlOperands());
scf::ForOp new_loop_op = builder.create<scf::ForOp>(
builder.getUnknownLoc(), builder.create<arith::ConstantIndexOp>(loc, lb),
builder.create<arith::ConstantIndexOp>(loc, ub),
builder.create<arith::ConstantIndexOp>(loc, step), deps);
remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar());
for (unsigned i = 0; i < loop_op.getRegionIterArgs().size(); i++)
remap.map(loop_op.getRegionIterArgs()[i],
new_loop_op.getRegionIterArgs()[i]);
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(new_loop_op.getBody());
for (Operation &child_op : loop_op.getBody()->getOperations()) {
if (&child_op == loop_op.getBody()->getTerminator()) {
if (!new_loop_op.getBody()->mightHaveTerminator())
builder.clone(child_op, remap);
} else
builder.clone(child_op, remap);
}
for (unsigned i = 0; i < loop_op->getNumResults(); i++)
loop_op->getResult(i).replaceAllUsesWith(new_loop_op->getResult(i));
builder.restoreInsertionPoint(insertionCheckpoint);
return new_loop_op;
// Position the builder at the loop so the new constants are created there.
builder.setInsertionPoint(loop_op);
// Swap each control operand for a freshly created index constant.
loop_op.setLowerBound(builder.create<arith::ConstantIndexOp>(loc, lb));
loop_op.setUpperBound(builder.create<arith::ConstantIndexOp>(loc, ub));
loop_op.setStep(builder.create<arith::ConstantIndexOp>(loc, step));
}

// Fold affine.apply op operating on loop induction variable into loop bounds.
Expand Down Expand Up @@ -1725,7 +1695,6 @@ struct CanonicalizeAffineApplyOnLoopInductionVar
int newStepInInt = llvm::divideCeilSigned(*new_ub - *new_lb, tripCount);
IRMapping remap;
if (auto exec = dyn_cast<air::ExecuteOp>(apply->getParentOp())) {
rewriter.setInsertionPoint(exec);
exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar());
if (sfo.getNumRegionIterArgs())
exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]);
Expand All @@ -1738,13 +1707,10 @@ struct CanonicalizeAffineApplyOnLoopInductionVar
}
rewriter.eraseOp(exec);
} else {
rewriter.setInsertionPoint(apply);
apply.getResult().replaceAllUsesWith(sfo.getInductionVar());
rewriter.eraseOp(apply);
}
rewriter.setInsertionPoint(sfo);
updateScfForBounds(rewriter, remap, sfo, *new_lb, *new_ub, newStepInInt);
rewriter.eraseOp(sfo);
updateScfForBounds(rewriter, sfo, *new_lb, *new_ub, newStepInInt);
} else if (auto afo = dyn_cast<affine::AffineForOp>(containingOp)) {
if (!afo.hasConstantBounds())
return failure();
Expand All @@ -1760,12 +1726,9 @@ struct CanonicalizeAffineApplyOnLoopInductionVar
assert(new_ub && new_lb);
int newStepInInt = llvm::divideCeilSigned(*new_ub - *new_lb, tripCount);
IRMapping remap;
rewriter.setInsertionPoint(afo);
apply.getResult().replaceAllUsesWith(afo.getInductionVar());
rewriter.eraseOp(apply);
updateAffineForBounds(rewriter, remap, afo, *new_lb, *new_ub,
newStepInInt);
rewriter.eraseOp(afo);
updateAffineForBounds(afo, *new_lb, *new_ub, newStepInInt);
} else
return failure();

Expand Down Expand Up @@ -1830,7 +1793,6 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
if (auto exec = dyn_cast<air::ExecuteOp>(op->getParentOp())) {
rewriter.setInsertionPoint(exec);
exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar());
if (sfo.getNumRegionIterArgs())
exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]);
Expand All @@ -1843,13 +1805,10 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar
}
rewriter.eraseOp(exec);
} else {
rewriter.setInsertionPoint(op);
op.getResult().replaceAllUsesWith(sfo.getInductionVar());
rewriter.eraseOp(op);
}
rewriter.setInsertionPoint(sfo);
updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(sfo);
updateScfForBounds(rewriter, sfo, new_lb, new_ub, newStepInInt);
} else if (auto afo = dyn_cast<affine::AffineForOp>(containingOp)) {
if (!afo.hasConstantBounds())
return failure();
Expand All @@ -1858,11 +1817,9 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar
int new_lb = afo.getConstantLowerBound() * muli_factor;
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
rewriter.setInsertionPoint(afo);
op.getResult().replaceAllUsesWith(afo.getInductionVar());
rewriter.eraseOp(op);
updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(afo);
updateAffineForBounds(afo, new_lb, new_ub, newStepInInt);
} else
return failure();

Expand Down Expand Up @@ -1927,7 +1884,6 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
if (auto exec = dyn_cast<air::ExecuteOp>(op->getParentOp())) {
rewriter.setInsertionPoint(exec);
exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar());
if (sfo.getNumRegionIterArgs())
exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]);
Expand All @@ -1940,13 +1896,10 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar
}
rewriter.eraseOp(exec);
} else {
rewriter.setInsertionPoint(op);
op.getResult().replaceAllUsesWith(sfo.getInductionVar());
rewriter.eraseOp(op);
}
rewriter.setInsertionPoint(sfo);
updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(sfo);
updateScfForBounds(rewriter, sfo, new_lb, new_ub, newStepInInt);
} else if (auto afo = dyn_cast<affine::AffineForOp>(containingOp)) {
if (!afo.hasConstantBounds())
return failure();
Expand All @@ -1955,11 +1908,9 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar
int new_lb = afo.getConstantLowerBound() + addi_operand;
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
rewriter.setInsertionPoint(afo);
op.getResult().replaceAllUsesWith(afo.getInductionVar());
rewriter.eraseOp(op);
updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(afo);
updateAffineForBounds(afo, new_lb, new_ub, newStepInInt);
} else
return failure();

Expand Down

0 comments on commit 5aeb6c0

Please sign in to comment.