-
I am generally opposed to any remaining MIGraphX MLIR dialect after conversion to TOSA. One simpler change that would not affect bufferization, etc. would be to add the [...]
Please add an example of the MIGraphX IR input to the problem statement so we can explore all alternatives.
-
Results can also have attributes.
This is certainly preferred.
-
Sorry, I think this was discussed while I was out -- hence I missed this. A design choice to consider: I think it is fair to use the 'tensor' dialect alongside TOSA. If so, we should be able to do a more direct lowering to https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorextract_slice-tensorextractsliceop that preserves the semantics you are after. It would then be supported by https://github.com/ROCmSoftwarePlatform/rocMLIR-internal/issues/870. There is also https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorinsert_slice-tensorinsertsliceop
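For reference, a small sketch of the suggested upstream op (the sizes here are arbitrary, just to show the form):

```mlir
// Extract a 2x2 sub-tensor starting at offset (0, 1) with unit strides;
// tensor.insert_slice is the symmetric write-side op.
%slice = tensor.extract_slice %t[0, 1] [2, 2] [1, 1]
    : tensor<4x8xf32> to tensor<2x2xf32>
```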
This is going to break the clone-based tests.
-
Closing because we've implemented this.
-
Note: doing this depends on getting #1140, the general perf key ticket, done, because it'll be really ugly if we don't do that first.
The problem
Currently, the MLIR migraphx dialect maps MIGraphX's shapes (which are {sizes, ...}, {strides, ...}, type) to upstream's tensor type, which throws away stride information.
This hasn't been much of a problem historically, as MIGraphX would insert the appropriate transpose operations to reorder the shape to be in increasing-stride order ... on each input.
However, MIGraphX IR expects operations, like convolution, to preserve the strides. So, if MIGraphX wants an NHWC convolution (which is represented in IR with shape = {N, C, H, W}, stride = {CHW, 1, CW, C}), it expects the output layout to match the input layout. However, because our representation of MIGraphX IR doesn't include stride information, we unconditionally produce an NCHW output.
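For concreteness, a rough sketch of the issue (the convolution op spelling and the concrete sizes here are illustrative, not taken from the dialect definition): with N = 1, C = 64, H = W = 32, the NHWC request has sizes = {1, 64, 32, 32} and strides = {65536, 1, 2048, 64}, but only the sizes survive the mapping to tensor.

```mlir
// Illustrative sketch: an NHWC convolution request mapped to plain tensors.
// Nothing in these types records strides = {C*H*W, 1, C*W, C}, so the
// lowering unconditionally produces an NCHW-laid-out output.
%out = "migraphx.convolution"(%inp, %filt)
    : (tensor<1x64x32x32xf32>, tensor<64x64x1x1xf32>) -> tensor<1x64x32x32xf32>
```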
This problem isn't limited to convolutions, though, and would affect any MIGraphX kernel request that has an output with a shape that isn't in "stride order".
This is, as far as I'm concerned, a historical mistake on our part: we didn't design the MIGraphX dialect to properly represent MIGraphX IR.
Proposed solution
While a short-term fix (inserting more migraphx.transpose ops onto outputs) has been identified, this entire decision of using tensor<> for MIGraphX's shape is likely to have irritating long-term impacts.
Therefore, to substantially simplify our code and make these edge cases much easier to handle, I propose the following changes (after #1140):
- A #migraphx.shape<L1xL2...xLk, S1xS2...xSk, T> attribute/type that records the sizes, strides, and element type of a MIGraphX shape.
- All operations in the migraphx dialect will no longer operate on the tensor type, but will instead use the migraphx.shaped<Sizes, Strides, Type> type. This will be the type MIGraphX generates when it translates its IR into an MLIR module.
This will also allow us to ensure we're correctly preserving every detail of MIGraphX's broadcasting semantics.
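As a sketch of what this buys us (the textual syntax below is an assumption based on the migraphx.shaped<Sizes, Strides, Type> form described above, not a settled spelling), the NHWC convolution from the problem statement would carry its strides in every operand and result type:

```mlir
// Sketch only: type syntax assumed, following migraphx.shaped<Sizes, Strides, Type>.
// The NHWC layout is now explicit in the types instead of being discarded.
%out = "migraphx.convolution"(%inp, %filt)
    : (!migraphx.shaped<1x64x32x32, 65536x1x2048x64, f32>,
       !migraphx.shaped<64x64x1x1, 64x1x1x1, f32>)
    -> !migraphx.shaped<1x64x32x32, 65536x1x2048x64, f32>
```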
Function arguments
The one awkward thing here is function arguments, since those'll need to go through a generally unmodified MLIR bufferization flow. There, I propose a solution that'll also have useful effects for the underlying codegen:
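A rough sketch of the idea, using placeholder spellings for the mlir.arg_view / mlir.result_view ops described in the lowering notes below (the flat tensor argument types stand in for the raw buffers prior to bufferization):

```mlir
// Sketch under assumptions: op names and type syntax are placeholders.
// Arguments arrive as flat buffers; the views reattach {sizes, strides}.
func.func @mlir_kernel(%arg0: tensor<65536xf32>, %arg1: tensor<65536xf32>)
    -> tensor<65536xf32> {
  // Put the NHWC shape/stride information back onto the flat input.
  %inp = "mlir.arg_view"(%arg0)
      : (tensor<65536xf32>) -> !migraphx.shaped<1x64x32x32, 65536x1x2048x64, f32>
  // ... migraphx ops computing on !migraphx.shaped values ...
  // Write the result into the output buffer in its expected strided layout.
  %res = "mlir.result_view"(%inp, %arg1)
      : (!migraphx.shaped<1x64x32x32, 65536x1x2048x64, f32>, tensor<65536xf32>)
      -> tensor<65536xf32>
  return %res : tensor<65536xf32>
}
```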
That is, the arguments to the function become the underlying float *, half *, ... that you'll be passing in, and then, at the beginning of the function, we put the shape information back. Similarly, the "returned" float * and friends are represented as the underlying memory, which we then write in the expected pattern.
During lowering
During MIGraphX to Tosa conversion, we use MLIR's type conversion system to insert the appropriate conversions.
Specifically, mlir.arg_view and mlir.result_view become the relevant rock.transform[Embed{}] to produce a tensor whose shape is what Tosa expects (the logical shape) but which writes in the underlying strided form. (In a Rock-less lowering, this would be memref.reinterpret_cast once someone defined a way to sneak it through tensor.) Values of the migraphx.shaped type themselves become plain tensor<LxT> on the way to Tosa.
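For the Rock-less variant, a rough sketch of the buffer-level view with memref.reinterpret_cast (the buffer name and sizes are illustrative):

```mlir
// Illustrative: view a flat buffer as the logical 1x64x32x32 shape whose
// strides encode the NHWC layout from the problem statement.
%view = memref.reinterpret_cast %buf to
    offset: [0], sizes: [1, 64, 32, 32], strides: [65536, 1, 2048, 64]
    : memref<65536xf32> to memref<1x64x32x32xf32, strided<[65536, 1, 2048, 64], offset: 0>>
```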
@sjw36 @pfultz2 @kahmed10 @manupak