Add TensorReshapeOp canon & test
Signed-off-by: Ian Wood <[email protected]>
IanWood1 committed Oct 9, 2024
1 parent 0f28d44 commit e3a633f
Showing 2 changed files with 106 additions and 0 deletions.
74 changes: 74 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -1017,6 +1017,78 @@ struct ResolveShapedDim : public OpRewritePattern<tensor::DimOp> {
  }
};

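// Erases a pair of flow.tensor.reshape casts wrapped around a
// tensor.expand_shape/tensor.collapse_shape (OpTy) when both reshapes only
// cast between a fully static shape and an equivalent dynamic shape whose
// sizes are constant, rebuilding the shape op directly on the static types.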
template <typename OpTy>
struct EraseReshapesAroundOp : public OpRewritePattern<Flow::TensorReshapeOp> {
  using OpRewritePattern<Flow::TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(Flow::TensorReshapeOp lastReshapeOp,
                                PatternRewriter &rewriter) const override {
    auto expandOp = lastReshapeOp.getSource().getDefiningOp<OpTy>();
    if (!expandOp) {
      return failure();
    }

    auto firstReshapeOp =
        expandOp.getSrc().template getDefiningOp<Flow::TensorReshapeOp>();
    if (!firstReshapeOp) {
      return failure();
    }

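    // tensor.expand_shape carries explicit values for its dynamic output
    // sizes; they must all be constants for the op to be rebuilt with a
    // fully static result type.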
    if constexpr (std::is_same_v<OpTy, tensor::ExpandShapeOp>) {
      for (auto dim : expandOp.getOutputShape()) {
        APInt d;
        if (!matchPattern(dim, m_ConstantInt(&d))) {
          return failure();
        }
      }
    }

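    // The outer ends of the chain (the first reshape's source and the last
    // reshape's result) must be fully static types.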
    if (!firstReshapeOp.getSourceDims().empty() ||
        !lastReshapeOp.getResultDims().empty()) {
      return failure();
    }

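    // Returns true when `from` is fully static and `to` has the same shape,
    // with every dynamic dim in `to` backed by a constant size equal to the
    // corresponding static dim in `from`.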
    auto isStaticToDynamicCast = [](ShapedType from, ShapedType to,
                                    OperandRange dynamicSizes) {
      auto it = dynamicSizes.begin();
      for (auto [fromSize, toSize] :
           llvm::zip_equal(from.getShape(), to.getShape())) {
        if (ShapedType::isDynamic(fromSize)) {
          return false;
        }

        auto resolvedSize = toSize;
        if (ShapedType::isDynamic(resolvedSize)) {
          APInt d;
          if (!matchPattern(*it++, m_ConstantInt(&d))) {
            return false;
          }
          resolvedSize = d.getSExtValue();
        }
        if (fromSize != resolvedSize) {
          return false;
        }
      }

      return true;
    };
    if (!isStaticToDynamicCast(firstReshapeOp.getSource().getType(),
                               firstReshapeOp.getType(),
                               firstReshapeOp.getResultDims()) ||
        !isStaticToDynamicCast(lastReshapeOp.getType(),
                               lastReshapeOp.getSource().getType(),
                               lastReshapeOp.getSourceDims())) {
      return failure();
    }

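    // Both reshapes are pure casts, so the expand/collapse can be recreated
    // directly on the static types and both reshapes erased.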
    auto newExpand = rewriter.create<OpTy>(
        expandOp.getLoc(), lastReshapeOp.getType(), firstReshapeOp.getSource(),
        expandOp.getReassociationIndices());
    rewriter.replaceOp(lastReshapeOp, newExpand);

    return success();
  }
};

} // namespace

void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1029,6 +1101,8 @@ void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.insert<FlattenTensorCastLikeChain<TensorReshapeOp>>(context);
  results.insert<ResolveShapedRank>(context);
  results.insert<ResolveShapedDim>(context);
  results.insert<EraseReshapesAroundOp<tensor::ExpandShapeOp>>(context);
  results.insert<EraseReshapesAroundOp<tensor::CollapseShapeOp>>(context);
}

void TensorBitCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
32 changes: 32 additions & 0 deletions
@@ -751,3 +751,35 @@ util.func public @innermost_unit_dim(%4: !flow.dispatch.tensor<readonly:tensor<3
// CHECK-SAME: %[[DYNAMIC_DIM:[a-zA-Z0-9]+]]: index)
// CHECK: flow.dispatch.tensor.load
// CHECK-SAME: sizes = [1, 1, 16, %[[DYNAMIC_DIM]], 1]

// -----

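// A static->dynamic reshape, an expand_shape with constant sizes, and a
// dynamic->static reshape fold into a single static expand_shape.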
util.func public @canonicalizeReshapeExpand(%arg0: tensor<4x1x8192xf16>) -> tensor<4x1x256x32xf16> {
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = flow.tensor.reshape %arg0 : tensor<4x1x8192xf16> -> tensor<?x?x8192xf16>{%c4, %c1}
  %expanded_0 = tensor.expand_shape %0 [[0], [1], [2, 3]] output_shape [%c4, %c1, 256, 32] : tensor<?x?x8192xf16> into tensor<?x?x256x32xf16>
  %1 = flow.tensor.reshape %expanded_0 : tensor<?x?x256x32xf16>{%c4, %c1} -> tensor<4x1x256x32xf16>
  util.return %1 : tensor<4x1x256x32xf16>
}

// CHECK-LABEL: util.func public @canonicalizeReshapeExpand
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x1x8192xf16>
// CHECK: %[[VAL0:.+]] = tensor.expand_shape
// CHECK: util.return %[[VAL0]]

// -----

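// The same folding applies to the collapse_shape case.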
util.func public @canonicalizeReshapeCollapse(%arg0: tensor<4x1x256x32xf16>) -> tensor<4x1x8192xf16> {
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = flow.tensor.reshape %arg0 : tensor<4x1x256x32xf16> -> tensor<?x?x256x32xf16>{%c4, %c1}
  %collapsed = tensor.collapse_shape %0 [[0], [1], [2, 3]] : tensor<?x?x256x32xf16> into tensor<?x?x8192xf16>
  %1 = flow.tensor.reshape %collapsed : tensor<?x?x8192xf16>{%c4, %c1} -> tensor<4x1x8192xf16>
  util.return %1 : tensor<4x1x8192xf16>
}

// CHECK-LABEL: util.func public @canonicalizeReshapeCollapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x1x256x32xf16>
// CHECK: %[[VAL0:.+]] = tensor.collapse_shape
// CHECK: util.return %[[VAL0]]
