[DispatchCreation] Bubble expand_shape for multi use producers #19718

Open · wants to merge 2 commits into main

Conversation

@Groverkss (Contributor) commented on Jan 16, 2025:

This PR enables bubbling up expand_shape ops through producers with multiple users. This allows propagation of an expand_shape through a diamond-like graph:

A = op1(in)
B = op2(A)
C = op2(A)
D = op3(B, C)
out = expand_shape D

Previously, propagation would get stuck once a producer had multiple expand_shape users:

A = op1(in)
A1 = expand_shape A
A2 = expand_shape A
B = op2(A1)
C = op2(A2)
D = op3(B, C)

This PR enables the expand_shape to be propagated completely:

in_A = expand_shape in
A = op1(in_A)
B = op2(A)
C = op2(A)
D = op3(B, C)

This PR also adds a check for cases where this diamond-like expand_shape propagation can cause cycles when expanding unit dimensions.
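Putting the pieces from the review excerpts below together, a minimal sketch of that cycle guard might look like the following. The unit-dimension detection here is an assumption about how isExpandingToUnitDims could be computed, not the PR's exact code:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Sketch: return true when propagating the consumer expand_shape above
// `producer` could cycle -- the unit-dimension expansion case described in
// the PR description.
static bool canCauseReshapingLoopByExpansion(Operation *producer,
                                             Operation *consumer) {
  bool isExpandingToUnitDims = false;
  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
    // Assumption: "expanding to unit dims" is detected by the result type
    // gaining static size-1 dimensions; the PR's actual detection may differ.
    isExpandingToUnitDims =
        expandShapeOp.getResultType().getRank() >
            expandShapeOp.getSrcType().getRank() &&
        llvm::any_of(expandShapeOp.getResultType().getShape(),
                     [](int64_t dim) { return dim == 1; });
  }
  // Conservative check from the PR (lines +143 to +145): bail out when the
  // producer has multiple results.
  if (isExpandingToUnitDims && producer->getNumResults() > 1) {
    return true;
  }
  return false;
}
```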

@IanWood1 (Contributor) left a comment:

This makes sense to me, but a similar PR was reverted at 7d21c5d, so let me quickly check that there are no regressions in llama performance.

Edit: Didn't see any regressions ✅

@Groverkss (Contributor, Author) replied:

> This makes sense to me, but a similar PR was reverted at 7d21c5d, so let me quickly check that there are no regressions in llama performance.
>
> Edit: Didn't see any regressions ✅

Yeah, I think that one also allows reductions, which is why we were having regressions.

static bool canCauseReshapingLoopByExpansion(Operation *producer,
                                             Operation *consumer) {
  bool isExpandingToUnitDims = false;
  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
Contributor:

I think we need the same check for collapse_shape, right?

Contributor Author:

The collapse_shape will eventually propagate up as an expand_shape, so it will stop the loop from ever happening.

The bigger reason I only implemented it for expand_shape is that it's less common for an operation to have multiple results than to have multiple inputs, so I'd rather have a conservative check on multiple results since that case occurs less often.

@Groverkss requested a review from @qedawkins on January 17, 2025.
Comment on lines +143 to +145
  if (isExpandingToUnitDims && producer->getNumResults() > 1) {
    return true;
  }
Contributor:

You need to check the number of users (my bad, I think I suggested looking at the number of results). The example below has just one result but causes infinite looping:

  util.func public @test_no_infinite_loop_unit_dim_expansion(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>, %arg2: tensor<4xi64>) -> tensor<4xi64> {
    %c4 = arith.constant 4 : index
    %0 = tensor.empty() : tensor<4xi64>
    %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%0 : tensor<4xi64>) {
    ^bb0(%out: i64):
      %3 = linalg.index 0 : index
      %4 = arith.remsi %3, %c4 : index
      %extracted = tensor.extract %arg0[%4] : tensor<4xi64>
      linalg.yield %extracted : i64
    } -> tensor<4xi64>
    %expanded = tensor.expand_shape %1 [[0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64>
    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "parallel"]} ins(%1, %expanded : tensor<4xi64>, tensor<4x1xi64>) outs(%0 : tensor<4xi64>) {
    ^bb0(%in: i64, %in_0: i64, %out: i64):
      %3 = linalg.index 1 : index
      %4 = arith.index_cast %3 : index to i64
      %5 = arith.addi %in, %in_0 : i64
      %6 = arith.addi %5, %4 : i64
      linalg.yield %6 : i64
    } -> tensor<4xi64>
    util.return %2 : tensor<4xi64>
  }
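
For illustration, a minimal sketch of the adjustment being suggested here, assuming the guard switches from counting the producer's results to counting its uses. The helper name producerHasMultipleUsers and the use of Operation::hasOneUse() are illustrative, not necessarily what the PR ends up doing:

```cpp
#include "mlir/IR/Operation.h"

using namespace mlir;

// Illustrative variant of the check on lines +143 to +145: look at how many
// users the producer has instead of how many results. In the example above,
// %1 has a single result but two users (the expand_shape and the second
// linalg.generic), so a result-count check misses it.
static bool producerHasMultipleUsers(Operation *producer) {
  // Assumption: counting uses across all results (Operation::hasOneUse) is a
  // good enough proxy for counting distinct users here.
  return !producer->hasOneUse();
}

// Inside canCauseReshapingLoopByExpansion, the guard would then become:
//   if (isExpandingToUnitDims && producerHasMultipleUsers(producer)) {
//     return true;
//   }
```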
