Skip to content

Commit

Permalink
Revert "Bubble expand shapes through AttentionOps (#18074)"
Browse files Browse the repository at this point in the history
This reverts commit 8dd1db3.
  • Loading branch information
nirvedhmeshram committed Aug 22, 2024
1 parent 5170872 commit 018e137
Show file tree
Hide file tree
Showing 21 changed files with 48 additions and 778 deletions.
20 changes: 10 additions & 10 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,13 @@ jobs:
run: |
source ${VENV_DIR}/bin/activate
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
--goldentime-rocm-e2e-ms 1616.0 \
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-e2e-ms 1450.0 \
--goldentime-rocm-unet-ms 370.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1551 \
--goldentime-rocm-vae-ms 315.0 \
--goldendispatch-rocm-unet 1691 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-vae 247 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand All @@ -338,13 +338,13 @@ jobs:
run: |
source ${VENV_DIR}/bin/activate
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
--goldentime-rocm-e2e-ms 372.0 \
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-e2e-ms 325.0 \
--goldentime-rocm-unet-ms 77.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1551 \
--goldentime-rocm-vae-ms 74.0 \
--goldendispatch-rocm-unet 1691 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-vae 247 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -84,9 +82,6 @@ void BubbleUpExpandShapesPass::runOnOperation() {
};
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
bubbleUpExpansionControlFn);
LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
bubbleExpandShapePatterns, bubbleUpExpansionControlFn);

// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -744,11 +744,6 @@ isFusableWithProducer(OpOperand &operand,
return true;
}

// Don't fuse attention with it's producer
if (isa<LinalgExt::AttentionOp>(consumer)) {
return false;
}

if (isPackLikeOp(consumer)) {
return TypeSwitch<Operation *, bool>(producer)
.Case<tensor::PadOp>([&](auto padOp) { return true; })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
Expand All @@ -42,36 +41,23 @@ namespace {
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//

// If possible, interchange indexing maps to make input maps all identity.
struct ElementwiseOpInterchangePattern
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
genericOp.getNumDpsInputs() == 0)
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1)
return failure();

// All input maps must be equal and non-identity. All maps, including
// output, must be be permutations. Permutation maps are checked by
// isElementwise but may be removed.
AffineMap inputMap = genericOp.getIndexingMapsArray().front();
auto *initOperand = genericOp.getDpsInitOperand(0);
if (inputMap.isIdentity() || !inputMap.isPermutation() ||
!genericOp.getMatchingIndexingMap(initOperand).isPermutation()) {
AffineMap indexingMap = genericOp.getIndexingMapsArray().back();
if (indexingMap.isIdentity())
return failure();
}
for (auto *operand : genericOp.getDpsInputOperands()) {
if (genericOp.getMatchingIndexingMap(operand) != inputMap) {
return failure();
}
}

// Make all inputs identity.
ArrayRef<AffineExpr> exprs = inputMap.getResults();
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
auto perm = llvm::map_to_vector(exprs, [](AffineExpr e) -> unsigned {
return cast<AffineDimExpr>(e).getPosition();
});

return linalg::interchangeGenericOp(rewriter, genericOp, perm);
}
};
Expand Down Expand Up @@ -224,7 +210,6 @@ struct FusionPreprocessingPass
// operand shapes.
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"annotate_dispatches.mlir",
"attention_fuse_by_expansion.mlir",
"capture_dispatch_dynamic_dims.mlir",
"capture_scf_for_dynamic_dims.mlir",
"cleanup_tensor_shapes.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ iree_lit_test_suite(
lit
SRCS
"annotate_dispatches.mlir"
"attention_fuse_by_expansion.mlir"
"capture_dispatch_dynamic_dims.mlir"
"capture_scf_for_dynamic_dims.mlir"
"cleanup_tensor_shapes.mlir"
Expand Down

This file was deleted.

Loading

0 comments on commit 018e137

Please sign in to comment.