-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR] Add fast path for lowering scatter (#1214)
**Context:** The semantics of [`mhlo.scatter`](https://www.tensorflow.org/mlir/hlo_ops#mhloscatter_mhloscatterop) and [`stablehlo.scatter`](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter) are the same. Currently we have a lowering from `mhlo.scatter` to upstream MLIR dialects (a mixture of func and scf dialects). The current implementation will lower the `mhlo.scatter` operation into a loop and move element by element values into the result tensor using `tensor.insert`. There are some special cases where it is possible to use [`tensor.insert_slice`](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorinsert_slice-tensorinsertsliceop) instead of [`tensor.insert`. ](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorinsert-tensorinsertop)The difference between `tensor.insert` and `tensor.insert_slice` is that `tensor.insert` inserts scalar elements into tensors while `tensor.insert_slice` may insert a tensor into a larger tensor. This has implications on performance. While both lowerings of scatter should be equivalent, `tensor.insert_slice` will be lowered to `memref.subview` and `memref.copy` which lowers to a single `memcpy`. This is the root cause of the performance issues described in #1153. **Description of the Change:** This PR adds an optimized lowering to `mhlo.scatter` to `tensor.insert_slice` in cases we can detect this lowering preserves the same semantics. This detection can be generalized in the future. It currently makes the following checks: 1. **`unique_indices` and `indices_are_sorted` are both true**: [The semantics state the following:](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter) > If indices_are_sorted is true then the implementation can assume that scatter_indices are sorted with respect to scatter_dims_to_operand_dims, otherwise the behavior is undefined. More formally, for all i1 < i2 from indices(result), full_start_index(i1) <= full_start_index(i2). > > If unique_indices is true then the implementation can assume that all result_index indices being scattered to are unique. If unique_indices is true but the indices being scattered to are not unique then the behavior is undefined. I believe these are undefined to keep the possibility of implementing `stablehlo.scatter` in parallel. 2. **Only one `%input` , `%result` and `%update` tensor**: The operands `%input` and `%update` are of variable length. To generalize the lowering to `tensor.insert_slice` one could run a pass earlier that will canonicalize the `scatter` operation into a series of `scatter` operations with a single `%input` , `%result` and `%update` tensor specified in the operation. 3. **Restricting `update_computation` to assignment of `%update_tensor`**. The `stablehlo.scatter` operation is more general than a simple assignment. It allows for the `%input` tensor to be updated with `%update` tensor values by running the function `update_computation` on the corresponding elements. We restrict `update_computation` to be the assignment to the `%update` values: I.e., ``` ({ ^bb0(%input_element: tensor<T>, %update_element: tensor<T>): mhlo.return %update_element : tensor<T> }) ``` This restriction may be relaxed by first computing a temporary tensor that will hold the result of the `update_computation` and replacing the uses of `update` with this new tensor and using the assignment function above. 4. **No batching**: Our current version of MLIR does not support the batching attributes in the operation. To generalize this more investigation is required. 5. **Single full slice**: This means that we are going to assign the whole `update` tensor to the `input` tensor and not just a subset. We could generalize this by using dynamic sizes when generating the `tensor.insert_slice` operation. 6. **rank(%scatter_indices) == 1** and **indexVectorDim == scatterIndicesTy.getRank() - 1**: This implies that the scatter indices are valid coordinates and do not need to be treated as tensors. To generalize this would imply looping over the number of valid indices depending on the shape of the scatter indices and generating a single `tensor.insert_slice` operation for each iteration. **Benefits:** Performance **Possible Drawbacks:** I would feel more comfortable having more time and upstreaming this to stableHLO. That way, we can get a review from the StableHLO team to make sure the semantics are correct since this is a somewhat complex operation. **Related GitHub Issues:** Fixes #1153 [sc-76025] --------- Co-authored-by: Romain Moyard <[email protected]> Co-authored-by: Haochen Wang <[email protected]>
- Loading branch information
1 parent
22a900f
commit af9f919
Showing
7 changed files
with
408 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.