[MLIR] Add fast path for lowering scatter (#1214)
**Context:** The semantics of
[`mhlo.scatter`](https://www.tensorflow.org/mlir/hlo_ops#mhloscatter_mhloscatterop)
and
[`stablehlo.scatter`](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter)
are the same. Currently we have a lowering from `mhlo.scatter` to
upstream MLIR dialects (a mixture of the func and scf dialects). The
current implementation lowers the `mhlo.scatter` operation into a loop
that moves values element by element into the result tensor using
`tensor.insert`.
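
For illustration, here is a rough sketch of the shape of that
loop-based lowering, with assumed shapes and value names (not the
literal pass output): scattering a 4-element `%update` into row `%row`
of an 8x4 `%input` takes one `tensor.insert` per element.

```
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%i0 = tensor.extract %indices[%c0] : tensor<1xi32>
%row = arith.index_cast %i0 : i32 to index
// One tensor.insert per element of %update.
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %input) -> (tensor<8x4xf64>) {
  %e = tensor.extract %update[%i] : tensor<4xf64>
  %next = tensor.insert %e into %acc[%row, %i] : tensor<8x4xf64>
  scf.yield %next : tensor<8x4xf64>
}
```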

There are some special cases where it is possible to use
[`tensor.insert_slice`](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorinsert_slice-tensorinsertsliceop)
instead of
[`tensor.insert`](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorinsert-tensorinsertop).
The difference between the two is that `tensor.insert` inserts a scalar
element into a tensor, while `tensor.insert_slice` can insert a whole
tensor into a larger tensor.
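
A minimal sketch contrasting the two (shapes and names assumed for
illustration):

```
// tensor.insert writes a single scalar element:
%t1 = tensor.insert %scalar into %dest[%i, %j] : tensor<8x4xf64>
// tensor.insert_slice writes an entire tensor in one op,
// here at offsets [%row, 0], sizes [1, 4], strides [1, 1]:
%t2 = tensor.insert_slice %update into %dest[%row, 0] [1, 4] [1, 1]
      : tensor<4xf64> into tensor<8x4xf64>
```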

This has implications for performance. While both lowerings of scatter
are semantically equivalent, `tensor.insert_slice` is lowered to
`memref.subview` and `memref.copy`, which in turn lowers to a single
`memcpy`. The element-by-element lowering is the root cause of the
performance issues described in #1153.
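
Roughly, the bufferized form of the fast path looks as follows (a
sketch with the same assumed shapes; the exact layout attributes depend
on the pipeline):

```
// The rank-reducing subview selects row %row of the destination buffer,
%sv = memref.subview %dest[%row, 0] [1, 4] [1, 1]
      : memref<8x4xf64> to memref<4xf64, strided<[1], offset: ?>>
// and copying one contiguous row lowers to a single memcpy.
memref.copy %update, %sv : memref<4xf64> to memref<4xf64, strided<[1], offset: ?>>
```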

**Description of the Change:** This PR adds an optimized lowering of
`mhlo.scatter` to `tensor.insert_slice` in cases where we can detect
that this lowering preserves the same semantics.

This detection can be generalized in the future. It currently makes the
following checks:

1. **`unique_indices` and `indices_are_sorted` are both true**: [The
semantics state the
following:](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter)

> If indices_are_sorted is true then the implementation can assume that
scatter_indices are sorted with respect to scatter_dims_to_operand_dims,
otherwise the behavior is undefined. More formally, for all i1 < i2 from
indices(result), full_start_index(i1) <= full_start_index(i2).
>
> If unique_indices is true then the implementation can assume that all
result_index indices being scattered to are unique. If unique_indices is
true but the indices being scattered to are not unique then the behavior
is undefined.

I believe these cases are left undefined to keep open the possibility
of implementing `stablehlo.scatter` in parallel.

2. **Only one `%input`, `%result`, and `%update` tensor**: The operands
`%input` and `%update` are variadic. To generalize the lowering to
`tensor.insert_slice`, one could run an earlier pass that canonicalizes
the `scatter` operation into a series of `scatter` operations, each
with a single `%input`, `%result`, and `%update` tensor.
3. **Restricting `update_computation` to assignment of the `%update`
tensor**: The `stablehlo.scatter` operation is more general than a
simple assignment. It allows the `%input` tensor to be updated with
`%update` tensor values by running the function `update_computation` on
the corresponding elements. We restrict `update_computation` to a plain
assignment of the `%update` values, i.e.,
```
({
      ^bb0(%input_element: tensor<T>, %update_element: tensor<T>):
        mhlo.return %update_element : tensor<T>
      })
```
This restriction may be relaxed by first computing a temporary tensor
that holds the result of `update_computation`, replacing the uses of
`%update` with this new tensor, and then using the assignment function
above.
4. **No batching**: Our current version of MLIR does not support the
batching attributes on this operation. Generalizing this requires more
investigation.
5. **Single full slice**: This means that we assign the whole `%update`
tensor into the `%input` tensor, rather than just a subset of it. We
could generalize this by using dynamic sizes when generating the
`tensor.insert_slice` operation.
6. **`rank(%scatter_indices) == 1`** and **`indexVectorDim ==
scatterIndicesTy.getRank() - 1`**: This implies that the scatter
indices form a single valid coordinate and do not need to be treated as
a tensor of coordinates. Generalizing this would mean looping over the
number of valid indices, depending on the shape of the scatter indices,
and generating one `tensor.insert_slice` operation per iteration. (A
sketch of an op that satisfies all six checks follows this list.)
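
As a concrete illustration, here is a sketch of an `mhlo.scatter` that
passes all six checks (shapes assumed): it assigns the whole `%update`
to the row of `%input` selected by `%indices`.

```
%result = "mhlo.scatter"(%input, %indices, %update) ({
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
  // update_computation is a plain assignment of the update value
  mhlo.return %arg1 : tensor<f64>
}) {
  scatter_dimension_numbers = #mhlo.scatter<
    update_window_dims = [0],       // rank(%update) == size(update_window_dims) == 1
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [0],
    index_vector_dim = 0            // == rank(%indices) - 1
  >,
  indices_are_sorted = true,
  unique_indices = true
} : (tensor<8x4xf64>, tensor<1xi32>, tensor<4xf64>) -> tensor<8x4xf64>
```

This is roughly the shape of scatter produced by `x.at[i].set(y)`-style
array updates, which is why the fast path covers the `vmap` case from
#1153.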

**Benefits:** Performance

**Possible Drawbacks:** I would feel more comfortable having more time
to upstream this to StableHLO. That way, we could get a review from the
StableHLO team to make sure the semantics are correct, since this is a
somewhat complex operation.

**Related GitHub Issues:** Fixes #1153

[sc-76025]

---------

Co-authored-by: Romain Moyard <[email protected]>
Co-authored-by: Haochen Wang <[email protected]>
3 people authored Oct 23, 2024
1 parent 22a900f commit af9f919
Showing 7 changed files with 408 additions and 15 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -326,6 +326,10 @@
 - Registers the func dialect as a requirement for running the scatter lowering pass.
 - Emits error if `%input`, `%update` and `%result` are not of length 1 instead of segfaulting.
 
+* Fixes a performance issue with vmap whose root cause is in the
+  lowering of the scatter operation.
+  [(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214)
+
 <h3>Internal changes</h3>
 
 * Remove deprecated pennylane code across the frontend.
16 changes: 6 additions & 10 deletions frontend/catalyst/api_extensions/function_maps.py
@@ -226,7 +226,10 @@ def __call__(self, *args, **kwargs):
         fn_args = tree_unflatten(args_tree, fn_args_flat)
 
         # Run 'fn' one time to get output-shape
-        init_result = self.fn(*fn_args, **kwargs)
+        _, shape = jax.make_jaxpr(self.fn, return_shape=True)(*fn_args, **kwargs)
+        shapes, init_result_tree = tree_flatten(shape)
+        init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape in shapes]
+        init_result = tree_unflatten(init_result_tree, init_result_flat)
 
         # Check the validity of the output w.r.t. out_axes
         out_axes_deep_struct = tree_structure(self.out_axes, is_leaf=lambda x: x is None)
@@ -238,8 +241,6 @@ def __call__(self, *args, **kwargs):
                 f"{out_axes_deep_struct} axis specifiers and {init_result_deep_struct} results."
             )
 
-        init_result_flat, init_result_tree = tree_flatten(init_result)
-
         num_axes_out = len(init_result_flat)
 
         if isinstance(self.out_axes, int):
@@ -255,16 +256,11 @@ def __call__(self, *args, **kwargs):
         # in the flatten format with respect to the 'init_result' shape
         batched_result_list = []
         for j in range(num_axes_out):
-            out_shape = (
-                (batch_size,)
-                if not init_result_flat[j].shape
-                else (batch_size, *init_result_flat[j].shape)
-            )
+            out_shape = (batch_size, *init_result_flat[j].shape)
             batched_result_list.append(jnp.zeros(shape=out_shape, dtype=init_result_flat[j].dtype))
-            batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])
 
         # Apply mapping batched_args[1:] ---> fn(args)
-        @for_loop(1, batch_size, 1)
+        @for_loop(0, batch_size, 1)
         def loop_fn(i, batched_result_list):
             fn_args_flat = args_flat
             for loc in batch_loc:
1 change: 0 additions & 1 deletion frontend/test/lit/test_mitigation.py
@@ -40,7 +40,6 @@ def circuit():
 # CHECK: mitigation.zne @one_shot_wrapper(%c) folding( global) numFolds(%2 : tensor<2xi64>) : (tensor<5xi1>) -> tensor<2xf64>
 
 # CHECK: func.func private @one_shot_wrapper(%arg0: tensor<5xi1>) -> tensor<f64>
-# CHECK: catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<f64>
 # CHECK: scf.for
 # CHECK: catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<f64>
 print(mcm_method_with_zne.mlir)
1 change: 1 addition & 0 deletions mlir/include/Catalyst/Transforms/Passes.td
@@ -74,6 +74,7 @@ def ScatterLoweringPass : Pass<"scatter-lowering"> {
         "mlir::func::FuncDialect",
         "index::IndexDialect",
         "mhlo::MhloDialect",
+        "tensor::TensorDialect",
         "scf::SCFDialect"
     ];

229 changes: 229 additions & 0 deletions mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
@@ -34,6 +34,8 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp>

mlir::LogicalResult onlyOneInputUpdateAndResult(mhlo::ScatterOp op) const
{
// Semantics of scatter:
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
// Assumption 1: only one input, one update, and one result
// * size(inputs) == 1
// * size(updates) == 1
@@ -50,9 +52,236 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp>
return op.getResults().size() == 1 ? success() : failure();
}

mlir::LogicalResult isAssignment(mhlo::ScatterOp op) const
{
// From:
// C23: update_computation has type
// (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ...,
// tensor<EN-1>) , where is_promotable(element_type(inputs[i]), Ei)
//
// On the description of the schedule:
// updated_values = update_computation(results...[result_index], updates_converted)
//
// It follows that:
// We are guaranteed that the update_computation
// function only has two parameters and one result.
// One parameter that corresponds to the
// result at the result_index
// and the single updates_converted_values
// This means that if the only operation inside the update_computation
// function is returning the second argument, then we are just assigning the update
// value to the result.
Region &region = op.getUpdateComputation();
Block &block = region.front();
bool oneOperation = block.begin() == --block.end();
if (!oneOperation) {
return failure();
}

mhlo::ReturnOp returnOp = dyn_cast<mhlo::ReturnOp>(block.getTerminator());
if (!returnOp) {
return failure();
}

return returnOp.getResults().front() == block.getArgument(1) ? success() : failure();
}

mlir::LogicalResult noBatching(mhlo::ScatterOp op) const
{
// Ok, now that we know it is an assignment, we need to worry about
// where exactly we are assigning and what we are assigning.
// First, let's worry about what we are assigning.
// It needs to be a proper slice, with no preprocessing of any kind.
// What kind of preprocessing exists?
// * Batching for input
// * Batching for indices
//
// From:
// (C13) 0 <= input_batching_dims < rank(inputs[0])).
// (C17) size(input_batching_dims) == size(scatter_indices_batching_dims)
// Implies:
// if input_batching_dims is empty, scatter_indices_batching_dims is empty too.
// TODO: This will always be success until we update our version of mlir-hlo.
// It looks like we are using an old version where getInputBatchingDims was not yet available.
// See here:
// https://github.com/tensorflow/mlir-hlo/commit/5ac7c579c52ef02b13c29886a98672c2ade7c9b0
return success();
// Until then, keep this code commented:
// auto scatterDimNumbers = op.getScatterDimensionNumbers();
// return scatterDimNumbers.getInputBatchingDims().empty() ? success() : failure();
}

mlir::LogicalResult singleFullSlices(mhlo::ScatterOp op) const
{
// From:
// More formally, for all update_index in index_space(updates[0]):
// * update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
// * update_scatter_index = update_index[update_scatter_dims...]
// we want update_scatter_index to be empty. This would mean that:
// scatter_indices points to a location in the input tensor and the corresponding
// update value is a full window that is inserted at that location.
// So we have a single update
auto update = op.getUpdates().front();
// And we need to make sure that all of its axes are in the update_window_dims.
// From:
// (C7) is_unique(update_window_dims) and is_sorted(update_window_dims)
// Implies
auto updateTy = cast<RankedTensorType>(update.getType());
auto scatterDimNumbers = op.getScatterDimensionNumbers();
size_t rank = updateTy.getRank();
return rank == scatterDimNumbers.getUpdateWindowDims().size() ? success() : failure();
}

mlir::LogicalResult canBeDoneWithSingleTensorInsertSlice(mhlo::ScatterOp op) const
{
return cast<RankedTensorType>(op.getScatterIndices().getType()).getRank() == 1 ? success()
: failure();
}

mlir::LogicalResult lowerToTensorInsertSlice(mhlo::ScatterOp op,
mlir::PatternRewriter &rewriter) const
{
// mhlo::ScatterOp is exactly the same as stablehlo::ScatterOp
// See https://www.tensorflow.org/mlir/hlo_ops#mhloscatter_mhloscatterop
// and https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
//
// From https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter:
//
// Semantics
//
// Produces results tensors which are equal to inputs tensors
// except that several slices specified by scatter_indices
// are updated with the values updates using update_computation.
//
// These simple semantics are obscured a bit by too many other details.
//
// Let's make some simplifying assumptions

// Add checks for supported cases (assumptions: no update windows dim, unique indices and
// sorted indices)
if (!op.getUniqueIndices() || !op.getIndicesAreSorted()) {
op.emitError() << "Indices are not unique and/or not sorted, unique boolean: "
<< op.getUniqueIndices()
<< ", sorted boolean :" << op.getIndicesAreSorted();
return failure();
}

// size(%result) == size(%update) == size(%input) == 1
if (failed(this->onlyOneInputUpdateAndResult(op))) {
return failure();
}
auto input = op.getInputs().front();
auto update = op.getUpdates().front();
auto scatterIndices = op.getScatterIndices();

// update_function =
// ^bb0(%arg0: T, %arg1: T):
// stablehlo.return %arg1 : T
// })
if (failed(this->isAssignment(op))) {
return failure();
}

// input_batching_dims = []
// scatter_indices_batching_dims = []
if (failed(this->noBatching(op))) {
return failure();
}

// rank(%update) == size(update_window_dims)
// => we are inserting the whole %update into a dimension of %input
if (failed(this->singleFullSlices(op))) {
return failure();
}

// Now, where are we going to insert this full slice?
// scatter_indices is typed as tensor of integer type
// So, normally we would need a loop around the scatter_indices.
// But let's assume that scatter_indices is a tensor of rank 1.
// If it is not, we would need to create a loop.
// rank(%scatter_indices) == 1
if (failed(this->canBeDoneWithSingleTensorInsertSlice(op))) {
return failure();
}

auto inputTy = cast<RankedTensorType>(input.getType());
auto updateTy = cast<RankedTensorType>(update.getType());
auto inputShape = inputTy.getShape();
auto updateShape = updateTy.getShape();
auto scatterIndicesTy = cast<RankedTensorType>(scatterIndices.getType());
// (C24) shape(%result) == shape(%input)

auto scatterDimNumbers = op.getScatterDimensionNumbers();
auto insertedWindowDims = scatterDimNumbers.getInsertedWindowDims();
auto scatterDimsToOperandDims = scatterDimNumbers.getScatterDimsToOperandDims();
auto indexVectorDim = scatterDimNumbers.getIndexVectorDim();

if (indexVectorDim != scatterIndicesTy.getRank() - 1) {
// TODO: I think indexVectorDim > 0
// implies a loop of insert_slices.
return failure();
}
// Because we said before
// rank(%scatter_indices) == 1
// => indexVectorDim = 0

SmallVector<Value> dynOffsets, dynSizes, dynStrides;
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
Location loc = op.getLoc();

// TODO: upstream to mlir-hlo and stablehlo
for (size_t i = 0, inputDim = 0, updateDim = 0; i < inputShape.size(); i++) {
if (llvm::is_contained(insertedWindowDims, i)) {
int scatterDimIndex = scatterDimsToOperandDims[inputDim];
Value scatterDimVal = rewriter.create<index::ConstantOp>(loc, scatterDimIndex);
auto extractOp =
rewriter.create<tensor::ExtractOp>(loc, scatterIndices, scatterDimVal)
.getResult();
auto indexCastOp =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), extractOp)
.getResult();
dynOffsets.push_back(indexCastOp);
staticOffsets.push_back(ShapedType::kDynamic);
staticSizes.push_back(1);
}
else if (updateDim == inputDim) {
int scatterDimIndex = scatterDimsToOperandDims[inputDim];
Value scatterDimVal = rewriter.create<index::ConstantOp>(loc, scatterDimIndex);
auto extractOp =
rewriter.create<tensor::ExtractOp>(loc, scatterIndices, scatterDimVal)
.getResult();
auto indexCastOp =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), extractOp)
.getResult();
dynOffsets.push_back(indexCastOp);
staticOffsets.push_back(ShapedType::kDynamic);
staticSizes.push_back(updateShape[updateDim]);
updateDim++;
}
else {
staticOffsets.push_back(0);
staticSizes.push_back(updateShape[updateDim]);
updateDim++;
}
inputDim++;
staticStrides.push_back(1);
}

rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(op, update, input, dynOffsets, dynSizes,
dynStrides, staticOffsets, staticSizes,
staticStrides);

return success();
}

mlir::LogicalResult matchAndRewrite(mhlo::ScatterOp op,
mlir::PatternRewriter &rewriter) const override
{
// FastPath
if (!failed(this->lowerToTensorInsertSlice(op, rewriter))) {
return success();
}

if (failed(onlyOneInputUpdateAndResult(op))) {
// Otherwise it will segfault.
op.emitError() << "Only one input, update, and result";
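
For reference, a sketch of the IR this rewrite emits for the qualifying
scatter example given earlier in the description (same assumed shapes):

```
// Offsets: a dynamic row index extracted from %indices, and a static 0
// for the row contents; sizes [1, 4] and strides [1, 1] are static.
%c0 = index.constant 0
%i0 = tensor.extract %indices[%c0] : tensor<1xi32>
%row = arith.index_cast %i0 : i32 to index
// A single bulk insert replaces the element-by-element loop.
%result = tensor.insert_slice %update into %input[%row, 0] [1, 4] [1, 1]
          : tensor<4xf64> into tensor<8x4xf64>
```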
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/scatter_lowering.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"