Skip to content

Commit

Permalink
Fixup op dominance around cloning index_cast op (Xilinx#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Jul 23, 2024
1 parent 140d0be commit 6a8c497
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,14 +753,6 @@ specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder,
return failure();

// Fold for loops into channel op's wrap and stride fields
SmallVector<affine::AffineForOp> for_loops;
Operation *parent = memcpy_ops[0].getOperation();
while (parent != for_op.getOperation()) {
parent = parent->getParentOp();
if (auto for_op_in_nest = dyn_cast<affine::AffineForOp>(parent))
for_loops.push_back(for_op_in_nest);
}

auto memref = memcpy_ops[0]->getOperand(3);
auto memref_shape = xilinx::air::getTensorShape(memref.getType());
auto oper_begin = memcpy_ops[0].getOperands().begin();
Expand Down Expand Up @@ -850,14 +842,21 @@ specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder,
opers.insert(opers.end(), strides.begin(), strides.end());

// index_cast
IRMapping indexOperMap;
for (unsigned i = 0; i < opers.size(); i++) {
if (opers[i].getDefiningOp() &&
isa<arith::ConstantIndexOp>(opers[i].getDefiningOp())) {
opers[i] =
builder.clone(*opers[i].getDefiningOp(), indexOperMap)->getResult(0);
opers[i] = builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), opers[i]);
} else if (opers[i].getDefiningOp() &&
isa<arith::IndexCastOp>(opers[i].getDefiningOp())) {
opers[i] = builder.clone(*opers[i].getDefiningOp())->getResult(0);
auto castOp = dyn_cast<arith::IndexCastOp>(opers[i].getDefiningOp());
if (castOp.getOperand().getDefiningOp() &&
isa<arith::ConstantOp>(castOp.getOperand().getDefiningOp()))
builder.clone(*castOp.getOperand().getDefiningOp(), indexOperMap);
opers[i] = builder.clone(*castOp, indexOperMap)->getResult(0);
}
}
auto new_dma = builder.create<airrt::DmaMemcpyNdOp>(loc, tys, opers);
Expand Down

0 comments on commit 6a8c497

Please sign in to comment.