[DispatchCreation] Bubble expand_shape for multi use producers #19718
Conversation
This makes sense to me, but a similar PR was reverted (at 7d21c5d), so let me quickly check that there are no regressions in llama performance.
Edit: Didn't see any regressions ✅
Yeah, I think that one also allows reductions, which is why we were having regressions.
static bool canCauseReshapingLoopByExpansion(Operation *producer,
                                             Operation *consumer) {
  bool isExpandingToUnitDims = false;
  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
I think we need the same check for collapse_shape, right?
The collapse_shape will eventually propagate up as an expand_shape, so it will stop the loop from ever happening.
The bigger reason why I only implemented it for expand_shape is that it's less common for an operation to have multiple results than to have multiple inputs. So I'd rather have a conservative check for multiple results, since that case occurs less often.
  if (isExpandingToUnitDims && producer->getNumResults() > 1) {
    return true;
  }
You need to check the number of users (my bad, I think I suggested looking at the number of results). The example below has just one result but still causes infinite looping:
util.func public @test_no_infinite_loop_unit_dim_expansion(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>, %arg2: tensor<4xi64>) -> tensor<4xi64> {
  %c4 = arith.constant 4 : index
  %0 = tensor.empty() : tensor<4xi64>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%0 : tensor<4xi64>) {
  ^bb0(%out: i64):
    %3 = linalg.index 0 : index
    %4 = arith.remsi %3, %c4 : index
    %extracted = tensor.extract %arg0[%4] : tensor<4xi64>
    linalg.yield %extracted : i64
  } -> tensor<4xi64>
  %expanded = tensor.expand_shape %1 [[0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "parallel"]} ins(%1, %expanded : tensor<4xi64>, tensor<4x1xi64>) outs(%0 : tensor<4xi64>) {
  ^bb0(%in: i64, %in_0: i64, %out: i64):
    %3 = linalg.index 1 : index
    %4 = arith.index_cast %3 : index to i64
    %5 = arith.addi %in, %in_0 : i64
    %6 = arith.addi %5, %4 : i64
    linalg.yield %6 : i64
  } -> tensor<4xi64>
  util.return %2 : tensor<4xi64>
}
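For reference, a minimal sketch of what a user-count-based guard could look like (this is an assumption about the fix, not the PR's final code, and the unit-dim detection shown here is a simplified stand-in):

static bool canCauseReshapingLoopByExpansion(Operation *producer,
                                             Operation *consumer) {
  bool isExpandingToUnitDims = false;
  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
    // Simplified stand-in: treat any unit dimension in the expanded result
    // type as a unit-dim expansion.
    isExpandingToUnitDims =
        llvm::is_contained(expandShapeOp.getResultType().getShape(), 1);
  }
  // Guard on the number of users rather than the number of results: a
  // single-result producer with several users (as in the IR above) can
  // still recreate the unit-dim expand_shape and loop forever.
  return isExpandingToUnitDims && !producer->hasOneUse();
}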
This PR enables bubbling up expand_shapes through producers with multiple users, which allows an expand_shape to propagate through a diamond-like graph. Previously, propagation could get stuck once a producer had multiple expand_shape users; with this change the expand_shape is propagated completely. An illustration of the diamond case is sketched below.
This PR also adds a check for when the diamond-like expand_shape pattern can cause cycles while expanding unit dimensions.
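As a rough illustration of the diamond shape (hand-written for this description, not a test from the PR; the names and shapes are made up): the producer %1 below has two expand_shape users, so before this change the expand_shapes could not be bubbled past it, and with this change they can.

util.func public @diamond_example(%arg0: tensor<32xf32>) -> tensor<4x8xf32> {
  %0 = tensor.empty() : tensor<32xf32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0 : tensor<32xf32>) outs(%0 : tensor<32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %s = arith.addf %in, %in : f32
    linalg.yield %s : f32
  } -> tensor<32xf32>
  // Two expand_shape users of the same producer form the diamond.
  %lhs = tensor.expand_shape %1 [[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32>
  %rhs = tensor.expand_shape %1 [[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32>
  %2 = tensor.empty() : tensor<4x8xf32>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<4x8xf32>, tensor<4x8xf32>) outs(%2 : tensor<4x8xf32>) {
  ^bb0(%l: f32, %r: f32, %out: f32):
    %m = arith.mulf %l, %r : f32
    linalg.yield %m : f32
  } -> tensor<4x8xf32>
  util.return %3 : tensor<4x8xf32>
}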