From cccd1c2d8947405b1e398449b1ce60a60df11f65 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Thu, 16 Jan 2025 17:17:52 +0000
Subject: [PATCH 1/2] [DispatchCreation] Bubble expand_shape for multi use
 producers

---
 .../DispatchCreation/BubbleUpExpandShapes.cpp | 10 ++-
 .../test/bubble_up_expand_shapes.mlir         | 73 +++++++++++++++++++
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 48ee4d4d80c1..085ebe412ef8 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -134,11 +134,9 @@ void BubbleUpExpandShapesPass::runOnOperation() {
           return false;
         }

-        // Do not fuse producer generic op if it has more than one user
-        // or any reduction iterators.
+        // Do not fuse producer generic op if it has any reduction iterators.
         if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
-          return producerGenericOp->hasOneUse() &&
-                 llvm::all_of(producerGenericOp.getIteratorTypesArray(),
+          return llvm::all_of(producerGenericOp.getIteratorTypesArray(),
                               linalg::isParallelIterator);
         }

@@ -206,6 +204,10 @@ void BubbleUpExpandShapesPass::runOnOperation() {
   bubbleExpandShapePatterns.insert<BubbleExpandThroughExtract>(context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
                                                      context);
+  tensor::DimOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                             context);
+  tensor::EmptyOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                               context);

   GreedyRewriteConfig rewriteConfig;
   rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir
index d654df337520..f0b6c34042e6 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir
@@ -69,3 +69,76 @@ util.func public @attention_v_reshape_propagation(%arg0: index,
 // CHECK-SAME: ins(%[[ATTENTION]]
 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
 // CHECK: return %[[COLLAPSE]]
+
+// -----
+
+#elementwise_trait = {
+  indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+// This test could actually be fused into a single operation by elementwise
+// fusion. It could also use ops with reduction iterators, expanding only the
+// outer parallel loops; elementwise operations are just easier to write.
+util.func public @diamond_propagate_expand_shape(%input : tensor<?x?xf16>)
+    -> tensor<2x?x?xf16> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+
+  %c1 = arith.constant 1.0 : f16
+  %dim = tensor.dim %input, %c0 : tensor<?x?xf16>
+  %empty = tensor.empty(%dim, %dim) : tensor<?x?xf16>
+
+  %A = linalg.generic #elementwise_trait
+    ins(%input : tensor<?x?xf16>) outs(%empty : tensor<?x?xf16>) {
+    ^bb0(%in : f16, %out : f16):
+      %add = arith.addf %in, %c1 : f16
+      linalg.yield %add : f16
+  } -> tensor<?x?xf16>
+
+  %B = linalg.generic #elementwise_trait
+    ins(%A : tensor<?x?xf16>) outs(%empty : tensor<?x?xf16>) {
+    ^bb0(%in : f16, %out : f16):
+      %add = arith.addf %in, %c1 : f16
+      linalg.yield %add : f16
+  } -> tensor<?x?xf16>
+
+  %C = linalg.generic #elementwise_trait
+    ins(%A : tensor<?x?xf16>) outs(%empty : tensor<?x?xf16>) {
+    ^bb0(%in : f16, %out : f16):
+      %add = arith.addf %in, %c1 : f16
+      linalg.yield %add : f16
+  } -> tensor<?x?xf16>
+
+  // The canonical form would be to pass both inputs as ins, but for a concise
+  // test, we pass one as outs so we can reuse the elementwise_trait.
+  %D = linalg.generic #elementwise_trait
+    ins(%B : tensor<?x?xf16>) outs(%C : tensor<?x?xf16>) {
+    ^bb0(%in : f16, %out : f16):
+      %add = arith.addf %in, %out : f16
+      linalg.yield %add : f16
+  } -> tensor<?x?xf16>
+
+  %dimA = arith.divui %dim, %c2 : index
+  %out = tensor.expand_shape %D [[0, 1], [2]] output_shape [2, %dimA, %dim] :
+      tensor<?x?xf16> into tensor<2x?x?xf16>
+
+  util.return %out : tensor<2x?x?xf16>
+}
+
+// Check that there is only one expand_shape, at the top of the function.
+// CHECK-LABEL: diamond_propagate_expand_shape
+// CHECK: tensor.expand_shape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.generic
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.generic
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.generic
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.generic
+// CHECK-NOT: tensor.expand_shape
+// CHECK: util.return

From 42ed10c0c8dd0a9036424a2aad4d0fbbb5faa2e5 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Fri, 17 Jan 2025 15:38:00 +0000
Subject: [PATCH 2/2] Add check for looping behavior

---
 .../DispatchCreation/BubbleUpExpandShapes.cpp | 35 ++++++++++++++++
 .../test/bubble_up_expand_shapes.mlir         | 42 +++++++++++++++++++
 2 files changed, 77 insertions(+)

diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 085ebe412ef8..5158ea0804a8 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -117,6 +117,37 @@ struct BubbleExpandThroughExtract final

 } // namespace

+/// If the domain of the operation is being expanded by unit dimensions, check
+/// whether that unit-dim expansion could keep propagating and cause the
+/// rewrite to loop infinitely.
+static bool canCauseReshapingLoopByExpansion(Operation *producer,
+                                             Operation *consumer) {
+  bool isExpandingToUnitDims = false;
+  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
+    // If the expand_shape is only expanding unit dimensions and the producer
+    // has multiple results, there is a possibility of an infinite loop.
+    ArrayRef<int64_t> outputShape = expandShapeOp.getStaticOutputShape();
+    for (auto [idx, indices] :
+         llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+      if (indices.size() == 1) {
+        continue;
+      }
+      // Check if the output shape at any of the reassociation indices is 1.
+      for (int64_t ind : indices) {
+        if (outputShape[ind] == 1) {
+          isExpandingToUnitDims = true;
+        }
+      }
+    }
+
+    if (isExpandingToUnitDims && producer->getNumResults() > 1) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 void BubbleUpExpandShapesPass::runOnOperation() {
   MLIRContext *context = &getContext();

@@ -129,6 +160,10 @@ void BubbleUpExpandShapesPass::runOnOperation() {
           return false;
         }

+        if (canCauseReshapingLoopByExpansion(producer, consumer)) {
+          return false;
+        }
+
         // Do not fuse by expand if consumer is dequant.
         if (IREE::LinalgExt::isBitExtendOp(consumer)) {
           return false;
         }
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir
index f0b6c34042e6..6f6e5024eb4f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir
@@ -142,3 +142,45 @@ util.func public @diamond_propagate_expand_shape(%input : tensor<?x?xf16>)
 // CHECK: linalg.generic
 // CHECK-NOT: tensor.expand_shape
 // CHECK: util.return
+
+// -----
+
+// Check that unit-dim expansion in a graph like this one, where the reshape
+// could keep propagating cyclically, does not cause an infinite loop.
+util.func public @test_no_infinite_loop_unit_dim_expansion(%arg0 : tensor<4xi64>, %arg1 : tensor<4xi64>, %arg3 : tensor<4xi64>) -> (tensor<4xi64>) {
+  %c2_i64 = arith.constant 2 : i64
+  %cst = arith.constant dense<[2, 1]> : tensor<2xi64>
+  %c4 = arith.constant 4 : index
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %__hoisted_tensor_4xi64 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
+  %1 = tensor.empty() : tensor<4xi64>
+  %9 = tensor.empty() : tensor<4xi64>
+  %10:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%9, %1 : tensor<4xi64>, tensor<4xi64>) {
+  ^bb0(%out: i64, %out_0: i64):
+    %16 = linalg.index 0 : index
+    %17 = arith.remsi %16, %c4 : index
+    %extracted = tensor.extract %arg0[%17] : tensor<4xi64>
+    %extracted_1 = tensor.extract %arg1[%17] : tensor<4xi64>
+    linalg.yield %extracted, %extracted_1 : i64, i64
+  } -> (tensor<4xi64>, tensor<4xi64>)
+  %expanded = tensor.expand_shape %10#0 [[0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64>
+  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "parallel"]}
+    ins(%10#1, %expanded: tensor<4xi64>, tensor<4x1xi64>) outs(%1 : tensor<4xi64>) {
+  ^bb0(%in: i64, %in0: i64, %out: i64):
+    %idx = linalg.index 1 : index
+    %cast = arith.index_cast %idx : index to i64
+    %add = arith.addi %in, %in0 : i64
+    %add1 = arith.addi %add, %cast : i64
+    linalg.yield %add1 : i64
+  } -> tensor<4xi64>
+
+  util.return %11 : tensor<4xi64>
+}
+
+// CHECK-LABEL: test_no_infinite_loop_unit_dim_expansion
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.generic
+// CHECK: tensor.expand_shape
+// CHECK: linalg.generic
+// CHECK-NOT: tensor.expand_shape
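A minimal, self-contained C++ sketch of the decision the new
canCauseReshapingLoopByExpansion() helper makes, evaluated on the values from
the test_no_infinite_loop_unit_dim_expansion test above. This is illustration
only, not part of the patch: wouldRejectFusion() and the std:: containers are
hypothetical stand-ins for the MLIR API (getReassociationIndices(),
getStaticOutputShape(), getNumResults()).

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the patch's check: report a possible reshaping loop when the
// consumer expand_shape only introduces unit dimensions and the producer has
// more than one result.
static bool wouldRejectFusion(
    const std::vector<std::vector<std::int64_t>> &reassociation,
    const std::vector<std::int64_t> &outputShape, int numProducerResults) {
  bool isExpandingToUnitDims = false;
  for (const auto &indices : reassociation) {
    if (indices.size() == 1)
      continue; // This dimension is not being expanded at all.
    for (std::int64_t ind : indices)
      if (outputShape[ind] == 1)
        isExpandingToUnitDims = true;
  }
  return isExpandingToUnitDims && numProducerResults > 1;
}

int main() {
  // tensor.expand_shape %10#0 [[0, 1]] output_shape [4, 1], where %10 is a
  // two-result linalg.generic: the group {0, 1} contains a unit dim, so the
  // check trips and fusion is rejected. Prints 1.
  std::cout << wouldRejectFusion({{0, 1}}, {4, 1}, 2) << "\n";
  // A non-unit expansion such as output_shape [2, 2] does not trip this
  // particular check (other control-function checks may still apply). Prints 0.
  std::cout << wouldRejectFusion({{0, 1}}, {2, 2}, 2) << "\n";
}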