diff --git a/.github/ci_script/file_guard.py b/.github/ci_script/file_guard.py index 574284a..52d7ee8 100644 --- a/.github/ci_script/file_guard.py +++ b/.github/ci_script/file_guard.py @@ -3,9 +3,10 @@ import os import argparse + def file_guard(guard_status_file, guard_log_file): # where stores the last position that pointer pointed to. - where= 0 + where = 0 while True: file = open(guard_log_file, "r") file.seek(where) @@ -28,11 +29,18 @@ def file_guard(guard_status_file, guard_log_file): exit(-1) # sleep for a while time.sleep(2) + + if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Monitor a log file and echo lines, check status to stop.") - parser.add_argument('guard_status_file', type=str, help='The path to the status file.') - parser.add_argument('guard_log_file', type=str, help='The path to the log file.') + parser = argparse.ArgumentParser( + description="Monitor a log file and echo lines, check status to stop.") + parser.add_argument('guard_status_file', + type=str, + help='The path to the status file.') + parser.add_argument('guard_log_file', + type=str, + help='The path to the log file.') args = parser.parse_args() - + file_guard(args.guard_status_file, args.guard_log_file) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index f759bb4..a3c8c1e 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -5,7 +5,6 @@ add_llvm_executable(triton-linalg-opt triton-linalg-opt.cpp PARTIAL_SOURCES_INTE llvm_update_compile_flags(triton-linalg-opt) target_link_libraries(triton-linalg-opt PRIVATE - ArithTransforms AuxiliaryTransforms LinalgExtTransforms TritonLinalgAnalysis diff --git a/bin/RegisterTritonLinalgDialects.h b/bin/RegisterTritonLinalgDialects.h index 4990b8d..7bfe49a 100644 --- a/bin/RegisterTritonLinalgDialects.h +++ b/bin/RegisterTritonLinalgDialects.h @@ -4,11 +4,11 @@ #include "triton-linalg/Dialect/Auxiliary/Transforms/AuxOpTilingInterface.h" #include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "triton-linalg/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.h" +#include "triton-linalg/Dialect/MathExt/IR/MathExt.h" #include "triton-linalg/Dialect/Triton/Transforms/InferAxisInfoInterfaceImpl.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton-linalg/Conversion/Passes.h" -#include "triton-linalg/Dialect/Arith/Transforms/Passes.h" #include "triton-linalg/Dialect/Triton/Transforms/Passes.h" inline void registerTritonLinalgDialects(mlir::DialectRegistry ®istry) { @@ -17,6 +17,7 @@ inline void registerTritonLinalgDialects(mlir::DialectRegistry ®istry) { // TritonLinalg. 
registry.insert(); registry.insert(); + registry.insert(); mlir::triton::aux::registerTilingInterfaceExternalModels(registry); mlir::triton::linalg_ext::registerTilingInterfaceExternalModels(registry); @@ -26,7 +27,6 @@ inline void registerTritonLinalgDialects(mlir::DialectRegistry ®istry) { } inline void registerTritonLinalgPasses() { - ::mlir::triton::arith_ext::registerArithExtPasses(); ::mlir::triton::registerTritonLinalgConversionPasses(); ::mlir::triton::registerTritonTransformsExtendPasses(); } diff --git a/include/triton-linalg/Analysis/AxisInfoAnalysis.h b/include/triton-linalg/Analysis/AxisInfoAnalysis.h index 9f7754b..7206aad 100644 --- a/include/triton-linalg/Analysis/AxisInfoAnalysis.h +++ b/include/triton-linalg/Analysis/AxisInfoAnalysis.h @@ -30,6 +30,12 @@ namespace triton { class AxisInfoLattice : public mlir::dataflow::Lattice { public: using Lattice::Lattice; + ChangeResult join(const AxisInfoExt &rhs); + bool isInitialized() { return initialized; } + +private: + bool initialized = false; + using mlir::dataflow::Lattice::join; }; //===--------------------------------------------------------------------===// diff --git a/include/triton-linalg/CMakeLists.txt b/include/triton-linalg/CMakeLists.txt index 0cfbd52..629c08a 100644 --- a/include/triton-linalg/CMakeLists.txt +++ b/include/triton-linalg/CMakeLists.txt @@ -1,3 +1,2 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) -add_subdirectory(Interfaces) diff --git a/include/triton-linalg/Conversion/LinalgCommon/Pattern.h b/include/triton-linalg/Conversion/LinalgCommon/Pattern.h index a7e293d..4c7291f 100644 --- a/include/triton-linalg/Conversion/LinalgCommon/Pattern.h +++ b/include/triton-linalg/Conversion/LinalgCommon/Pattern.h @@ -34,12 +34,12 @@ template class GenericOpPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // Remain unchanged if one of operands is scalar. if (!llvm::all_of(adaptor.getOperands(), - [&](Value v) { return v.getType().isa(); })) { + [&](Value v) { return isa(v.getType()); })) { return failure(); } // Apply only if all operands are not scalar. 
auto loc = op.getLoc(); - auto resType = op.getType().template cast(); + auto resType = cast(op.getType()); auto initDims = getDims(rewriter, loc, op->getOperand(0)); Value initTensor = rewriter.create( loc, initDims, resType.getElementType()); diff --git a/include/triton-linalg/Conversion/Passes.h b/include/triton-linalg/Conversion/Passes.h index f955bbc..0ab201a 100644 --- a/include/triton-linalg/Conversion/Passes.h +++ b/include/triton-linalg/Conversion/Passes.h @@ -12,7 +12,6 @@ #include "triton-linalg/Conversion/ArithToLinalg/ArithToLinalg.h" #include "triton-linalg/Conversion/MathToLinalg/MathToLinalg.h" #include "triton-linalg/Conversion/TritonToLinalg/TritonToLinalg.h" -#include "triton-linalg/Conversion/TritonToTensor/TritonToTensor.h" namespace mlir { class Pass; diff --git a/include/triton-linalg/Conversion/Passes.td b/include/triton-linalg/Conversion/Passes.td index 2973c3c..5cf1142 100644 --- a/include/triton-linalg/Conversion/Passes.td +++ b/include/triton-linalg/Conversion/Passes.td @@ -24,11 +24,4 @@ def MathToLinalgPass : Pass<"convert-math-to-linalg"> { ]; } -def TritonToTensorPass : Pass<"convert-triton-to-tensor", "ModuleOp"> { - let summary = "Convert the operations from the Triton dialect into the Tensor dialect"; - let constructor = "mlir::triton::createTritonToTensorPass()"; - let dependentDialects = [ - "triton::TritonDialect", "tensor::TensorDialect", - ]; -} #endif // TRITON_LINALG_CONVERSION_PASSES_TD diff --git a/include/triton-linalg/Conversion/TritonToLinalg/TritonPointerConversion.h b/include/triton-linalg/Conversion/TritonToLinalg/TritonPointerConversion.h index 0bbebfd..1399069 100644 --- a/include/triton-linalg/Conversion/TritonToLinalg/TritonPointerConversion.h +++ b/include/triton-linalg/Conversion/TritonToLinalg/TritonPointerConversion.h @@ -152,7 +152,9 @@ class DimInfo { int64_t getContigSize() const { return contigSize; } int64_t getBroadcastSize() const { return broadcastSize; } int64_t getDimSize() const { return dimSize; } - bool isBroadcastDim() const { return getContigSize() == 1; } + bool isBroadcastDim() const { + return getContigSize() == 1 && getDimSize() != 1; + } private: int64_t contigSize = -1; @@ -331,13 +333,46 @@ class TritonPtrScatterConversionBase class TritonTensorPtrLoadStoreOpConversionBase : public TritonPtrConversionBase { protected: - /// Get the actual size of each dim needed to be load, if boundaryCheck is - /// true, return min(tensorShape[dim], dimSize[dim] - offset[dim]). - SmallVector - getActualSizes(Location loc, std::optional> boundaryCheck, - ArrayRef tensorShape, - const TensorPointerMetaInfoTracker &tracker, - ConversionPatternRewriter &rewriter) const; + /// Actual offsets, padLeftSizes and sizes. + /// + /// For example, in a certain dimension, there are several quantities to + /// describe the actual data range. [0, `shape`) represents the valid data + /// range, `offset` represents the offset value, and `blockShape` represents + /// the size of the data block being retrieved. + /// + /// There are 3 cases for the position of `offset`. + /// + /// Case1: 0 <= offset < shape + /// offset = offset + /// padLeftSize = 0 + /// size = min(shape - offset, blockShape) + /// + /// Case2: offset < 0 + /// offset = 0 + /// padLeftSize = min(-offset, blockShape) + /// size = min(shape, blockShape - padLeftSize) + /// + /// Case3: offset >= shape + /// offset = offset + /// padLeftSize = shape + /// size = min(0, blockShape) = 0 + /// + /// These cases can be summarized by the following formula. 
+ /// originOffset = offset + /// offset = max(offset, 0) + /// padLeftSize = min(offset - originOffset, blockShape) + /// size = min(max(shape - offset, 0), blockShape - padLeftSize) + struct PtrInfo { + SmallVector offsets; + SmallVector padLeftSizes; + SmallVector sizes; + }; + + /// Get the actual PtrInfo of each dim that needs to be loaded. + PtrInfo getPtrInfo(Location loc, std::optional> boundaryCheck, + ArrayRef tensorShape, + const TensorPointerMetaInfoTracker &tracker, + ConversionPatternRewriter &rewriter) const; SmallVector getDimInfos(ArrayRef strides, ArrayRef tensorShape) const; diff --git a/include/triton-linalg/Conversion/TritonToTensor/TritonToTensor.h b/include/triton-linalg/Conversion/TritonToTensor/TritonToTensor.h deleted file mode 100644 index cced54a..0000000 --- a/include/triton-linalg/Conversion/TritonToTensor/TritonToTensor.h +++ /dev/null @@ -1,20 +0,0 @@ -//===- TritonToTensor.h - Triton to Tensor dialect convension ---*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_LINALG_CONVERSION_TRITONTOTENSOR_TRITONTOTENSOR_H -#define TRITON_LINALG_CONVERSION_TRITONTOTENSOR_TRITONTOTENSOR_H - -#include - -namespace mlir { -class Pass; -namespace triton { -/// Create a pass to convert a subset of Triton ops to Tensor ops. -std::unique_ptr createTritonToTensorPass(); -} // namespace triton -} // namespace mlir - -#endif // TRITON_LINALG_CONVERSION_TRITONTOTENSOR_TRITONTOTENSOR_H diff --git a/include/triton-linalg/Dialect/Arith/CMakeLists.txt b/include/triton-linalg/Dialect/Arith/CMakeLists.txt deleted file mode 100644 index e31af32..0000000 --- a/include/triton-linalg/Dialect/Arith/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Transforms) diff --git a/include/triton-linalg/Dialect/Arith/Transforms/CMakeLists.txt b/include/triton-linalg/Dialect/Arith/Transforms/CMakeLists.txt deleted file mode 100644 index 28d9b27..0000000 --- a/include/triton-linalg/Dialect/Arith/Transforms/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArithExt) -add_public_tablegen_target(ArithTransformsIncGen) diff --git a/include/triton-linalg/Dialect/Arith/Transforms/PassDetail.h b/include/triton-linalg/Dialect/Arith/Transforms/PassDetail.h deleted file mode 100644 index 68964e2..0000000 --- a/include/triton-linalg/Dialect/Arith/Transforms/PassDetail.h +++ /dev/null @@ -1,33 +0,0 @@ -//===- PassDetail.h - Details for arith transforms --------------*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon.
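To make the boundary-clamping formula above concrete, here is a minimal standalone sketch in plain C++ (illustrative names only; the actual getPtrInfo materializes these quantities as MLIR arith ops on Values rather than computing scalars):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Scalar model of the per-dimension clamping summarized above.
struct DimAccess {
  int64_t offset;      // clamped start position inside the valid range [0, shape)
  int64_t padLeftSize; // number of elements padded on the left of the block
  int64_t size;        // number of elements actually read from valid data
};

DimAccess clampDim(int64_t offset, int64_t shape, int64_t blockShape) {
  int64_t originOffset = offset;
  offset = std::max<int64_t>(offset, 0);
  int64_t padLeftSize = std::min(offset - originOffset, blockShape);
  int64_t size =
      std::min(std::max<int64_t>(shape - offset, 0), blockShape - padLeftSize);
  return {offset, padLeftSize, size};
}

int main() {
  DimAccess a = clampDim(2, 8, 4);  // Case 1: {offset=2, padLeftSize=0, size=4}
  DimAccess b = clampDim(-2, 8, 4); // Case 2: {offset=0, padLeftSize=2, size=2}
  DimAccess c = clampDim(9, 8, 4);  // Case 3: size == 0, nothing is read
  std::printf("%lld %lld %lld\n", (long long)a.offset,
              (long long)b.padLeftSize, (long long)c.size);
  return 0;
}
```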
-// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSDETAIL_H -#define TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSDETAIL_H -// IWYU pragma: begin_keep -#include "mlir/IR/DialectRegistry.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { - -// Forward declaration from Dialect.h -template -void registerDialect(DialectRegistry ®istry); - -namespace arith { -class ArithDialect; -} // namespace arith - -namespace triton { -namespace arith_ext { -// IWYU pragma: end_keep -#define GEN_PASS_CLASSES -#include "triton-linalg/Dialect/Arith/Transforms/Passes.h.inc" - -} // namespace arith_ext -} // namespace triton -} // namespace mlir - -#endif // TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSDETAIL_H diff --git a/include/triton-linalg/Dialect/Arith/Transforms/Passes.h b/include/triton-linalg/Dialect/Arith/Transforms/Passes.h deleted file mode 100644 index 5622813..0000000 --- a/include/triton-linalg/Dialect/Arith/Transforms/Passes.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- Passes.h - Passes for arith ------------------------------*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSES_H -#define TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSES_H - -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "triton-linalg/Dialect/Arith/Transforms/PassDetail.h" - -namespace mlir { -namespace triton { -namespace arith_ext { - -std::unique_ptr createArithCanonicalizerPass(); - -#define GEN_PASS_REGISTRATION -#include "triton-linalg/Dialect/Arith/Transforms/Passes.h.inc" - -} // namespace arith_ext -} // namespace triton -} // namespace mlir - -#endif // TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSES_H diff --git a/include/triton-linalg/Dialect/Arith/Transforms/Passes.td b/include/triton-linalg/Dialect/Arith/Transforms/Passes.td deleted file mode 100644 index 7ecd2d8..0000000 --- a/include/triton-linalg/Dialect/Arith/Transforms/Passes.td +++ /dev/null @@ -1,18 +0,0 @@ -//===- Passes.td - Passes for arith ------------------------*- tablegen -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. 
-// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSES_TD -#define TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSES_TD - -include "mlir/Pass/PassBase.td" - -def ArithCanonicalizer: Pass<"arith-canonicalize"> { - let summary = "Register extra canonicalize patterns for arith ops."; - let constructor = "mlir::triton::arith_ext::createArithCanonicalizerPass()"; - let dependentDialects = ["arith::ArithDialect"]; -} - -#endif // TRITON_LINALG_DIALECT_ARITH_TRANSFORMS_PASSES_TD diff --git a/include/triton-linalg/Dialect/Auxiliary/IR/AuxiliaryOps.td b/include/triton-linalg/Dialect/Auxiliary/IR/AuxiliaryOps.td index 8cf908a..0a7bec4 100644 --- a/include/triton-linalg/Dialect/Auxiliary/IR/AuxiliaryOps.td +++ b/include/triton-linalg/Dialect/Auxiliary/IR/AuxiliaryOps.td @@ -47,8 +47,8 @@ def StoreResourceOp : Aux_Op<"store"> { }]; let arguments = (ins - AnyType:$to, - AnyType:$from + Arg:$to, + Arg:$from ); let results = (outs); @@ -61,14 +61,14 @@ def StoreResourceOp : Aux_Op<"store"> { let extraClassDeclaration = [{ bool isScalar(const Value& value) { - return !value.getType().isa(); + return !isa(value.getType()); } bool hasPureBufferSemantics() { return ::llvm::all_of(getOperands(), [&](const Value& opOperand) { return isScalar(opOperand) || - opOperand.getType().isa<::mlir::MemRefType>(); + isa<::mlir::MemRefType>(opOperand.getType()); }); } @@ -76,7 +76,7 @@ def StoreResourceOp : Aux_Op<"store"> { return ::llvm::all_of(getOperands(), [&](const Value& opOperand) { return isScalar(opOperand) || - opOperand.getType().isa<::mlir::TensorType>(); + isa<::mlir::TensorType>(opOperand.getType()); }); } @@ -190,7 +190,7 @@ def ViewOp : } // The result of the op is always a ranked memref. - MemRefType getType() { return getResult().getType().cast(); } + MemRefType getType() { return cast(getResult().getType()); } Value getViewSource() { return getPtr(); } Value getOffset() { return getOffsets().empty() ? nullptr : getOffsets()[0]; @@ -199,7 +199,7 @@ def ViewOp : /// Return the expected rank of each of the`static_offsets`, `static_sizes` /// and `static_strides` attributes. 
std::array getArrayAttrMaxRanks() { - unsigned resultRank = getResult().getType().cast().getRank(); + unsigned resultRank = cast(getResult().getType()).getRank(); return {1, resultRank, resultRank}; } @@ -265,12 +265,12 @@ def PrintOp : Aux_Op<"print", [DeclareOpInterfaceMethods(); + return isa<::mlir::MemRefType>(opOperand.getType()); }); } ShapedType getInitType() { - return getOperands()[0].getType().cast();; + return cast(getOperands()[0].getType());; } MutableOperandRange getDpsInitsMutable() { return getValuesMutable(); } diff --git a/include/triton-linalg/Dialect/CMakeLists.txt b/include/triton-linalg/Dialect/CMakeLists.txt index a0b6a77..6d11f43 100644 --- a/include/triton-linalg/Dialect/CMakeLists.txt +++ b/include/triton-linalg/Dialect/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(Arith) add_subdirectory(Auxiliary) add_subdirectory(LinalgExt) add_subdirectory(MathExt) diff --git a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtBase.td b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtBase.td index d6e2fbf..49dcdfe 100644 --- a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtBase.td +++ b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtBase.td @@ -9,7 +9,6 @@ include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtEnums.td" include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.td" -include "triton-linalg/Interfaces/InferResultTypeOpInterface.td" include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" diff --git a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.td b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.td index c1f4504..7ec2064 100644 --- a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.td +++ b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.td @@ -24,67 +24,67 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> { //========================================================================// int64_t getNumDpsInputs() { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getNumDpsInputs(); } int64_t getNumDpsInits() { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getNumDpsInits(); } ::llvm::SmallVector<::mlir::OpOperand *> getDpsInputOperands() { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getDpsInputOperands(); } OpOperand *getDpsInputOperand(int64_t i) { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getDpsInputOperand(i); } void setDpsInitOperand(int64_t i, Value value) { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .setDpsInitOperand(i, value); } MutableOperandRange getDpsInitsMutable() { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getDpsInitsMutable(); } OpOperand *getDpsInitOperand(int64_t i) { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getDpsInitOperand(i); } bool isDpsInput(OpOperand *opOperand) { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .isDpsInput(opOperand); } bool isDpsInit(OpOperand *opOperand) { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .isDpsInit(opOperand); } bool isScalar(OpOperand *opOperand) { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .isScalar(opOperand); } OpResult getTiedOpResult(OpOperand *opOperand) { - return 
cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .getTiedOpResult(opOperand); } bool hasPureBufferSemantics() { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .hasPureBufferSemantics(); } bool hasPureTensorSemantics() { - return cast(*this->getOperation()) + return mlir::cast(*this->getOperation()) .hasPureTensorSemantics(); } @@ -102,8 +102,8 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> { private: void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); + auto attr = mlir::cast( + (*this)->getAttr("operand_segment_sizes")); unsigned i = 0; auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); diff --git a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h index b18f4b7..48bcf8d 100644 --- a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h +++ b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h @@ -14,7 +14,6 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.h" -#include "triton-linalg/Interfaces/InferResultTypeOpInterface.h" // IWYU pragma: end_keep //===----------------------------------------------------------------------===// // LinalgExt Dialect diff --git a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.td b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.td index a6e3e77..f309152 100644 --- a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -107,17 +107,17 @@ def ScatterOp : LinalgExtBase_Op<"scatter", []> { return Value(); } ShapedType getUpdateType() { - return update().getType().cast(); + return mlir::cast(update().getType()); } ShapedType getIndiceType() { - return indice().getType().cast(); + return mlir::cast(indice().getType()); } ShapedType getInitType() { - return getInit().getType().cast(); + return mlir::cast(getInit().getType()); } ShapedType getMaskType() { if (mask()) { - return mask().getType().cast(); + return mlir::cast(mask().getType()); } return {}; } @@ -131,7 +131,7 @@ def ScatterOp : LinalgExtBase_Op<"scatter", []> { } LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } }]; @@ -222,14 +222,14 @@ def GatherOp : LinalgExtBase_Op<"gather", []> { return Value(); } ShapedType getInputType() { - return input().getType().cast(); + return mlir::cast(input().getType()); } ShapedType getIndiceType() { - return indice().getType().cast(); + return mlir::cast(indice().getType()); } ShapedType getMaskType() { if (mask()) { - return mask().getType().cast(); + return mlir::cast(mask().getType()); } return {}; } @@ -239,14 +239,14 @@ def GatherOp : LinalgExtBase_Op<"gather", []> { .back(); } ShapedType getInitType() { - return getInit().getType().cast(); + return mlir::cast(getInit().getType()); } int64_t getBatchDimNum() { return getIndiceType().getRank() - 1; } LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, 
reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } }]; @@ -306,14 +306,14 @@ def AtomicRMWOp : LinalgExtBase_Op<"atomic_rmw", [AttrSizedOperandSegments, Same return getDpsInitOperand(1)->get(); } ShapedType getInputType() { - return input().getType().cast(); + return mlir::cast(input().getType()); } ShapedType getSrcType() { - return src().getType().cast(); + return mlir::cast(src().getType()); } LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); } }]; @@ -388,17 +388,17 @@ def GatherAtomicRMWOp : LinalgExtBase_Op<"gather_atomic_rmw", [AttrSizedOperandS return getDpsInitOperand(1)->get(); } ShapedType getInputType() { - return input().getType().cast(); + return mlir::cast(input().getType()); } ShapedType getSrcType() { - return src().getType().cast(); + return mlir::cast(src().getType()); } ShapedType getIndiceType() { - return indice().getType().cast(); + return mlir::cast(indice().getType()); } ShapedType getMaskType() { if (mask()) { - return mask().getType().cast(); + return mlir::cast(mask().getType()); } return {}; } @@ -407,7 +407,7 @@ def GatherAtomicRMWOp : LinalgExtBase_Op<"gather_atomic_rmw", [AttrSizedOperandS } LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); } }]; @@ -423,9 +423,14 @@ def AtomicCASOp : Op { let summary = [{ LinalgExt atomic CAS operation for continuous buffer. }]; let description = [{ - AtomicCASOp has three inputs: input, cmp and val. - Compares cmp with input. if cmp == input, store val to input, - else store the original value of input to init. + AtomicCASOp has three inputs: ``input``, ``cmp`` and ``val``. + + AtomicCASOp has one init: ``init``. + + Stores the original value of ``input`` to ``init``. + + Compares ``cmp`` with ``input``. If ``input`` == ``cmp``, stores ``val`` to + ``input``, else keeps the value of ``input``. }]; let arguments = (ins @@ -457,24 +462,24 @@ def AtomicCASOp : Op(); + return mlir::cast(input().getType()); } ShapedType getCmpType() { - return cmp().getType().cast(); + return mlir::cast(cmp().getType()); } ShapedType getValType() { - return val().getType().cast(); + return mlir::cast(val().getType()); } ShapedType getInitType() { - return getInit().getType().cast();; + return mlir::cast(getInit().getType()); } LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } @@ -491,11 +496,17 @@ def GatherAtomicCASOp : Op { let summary = [{ LinalgExt atomic CAS operation for discrete buffer. }]; let description = [{ - AtomicCASOp has four inputs: input, cmp, val and indice. - Compares cmp with input. if cmp == input, store val to input, - else store the original value of input to init.
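The CAS semantics spelled out above for AtomicCASOp (and for GatherAtomicCASOp below) reduce to a few lines of scalar logic. A minimal sketch, assuming plain int64_t memory cells (illustrative only; the real ops are destination-style ops on tensors/buffers, and the gather variant additionally indirects through ``indice``):

```cpp
#include <cstdint>

// Scalar model of the CAS semantics described above: `init` always receives
// the original value of `input`; the exchange happens only on a match.
int64_t atomicCas(int64_t &input, int64_t cmp, int64_t val, int64_t &init) {
  init = input;  // Store the original value of input to init.
  if (input == cmp)
    input = val; // If input == cmp, store val to input.
  return init;   // Otherwise input keeps its value.
}
```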
- Note that the input must point to discrete data, so we add extra - indice like ``GatherOp`` to deal with this situation. + GatherAtomicCASOp has four inputs: ``input``, ``cmp``, ``val`` and ``indice``. + + GatherAtomicCASOp has one init: ``init``. + + Stores the original value of ``input`` to ``init``. + + Compares ``cmp`` with ``input``. If ``input`` == ``cmp``, stores ``val`` + to ``input``, else keeps the value of ``input``. + + Note that the ``input`` must point to discrete data, so we add extra + ``indice`` like ``GatherOp`` to deal with this situation. }]; let arguments = (ins @@ -530,23 +541,23 @@ def GatherAtomicCASOp : Op(); + return mlir::cast(input().getType()); } ShapedType getCmpType() { - return cmp().getType().cast(); + return mlir::cast(cmp().getType()); } ShapedType getValType() { - return val().getType().cast(); + return mlir::cast(val().getType()); } ShapedType getIndiceType() { - return indice().getType().cast(); + return mlir::cast(indice().getType()); } ShapedType getInitType() { - return getInit().getType().cast();; + return mlir::cast(getInit().getType()); } int64_t getIndexDepth() { @@ -555,7 +566,7 @@ def GatherAtomicCASOp : Op(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } @@ -663,10 +674,10 @@ def PadOp : LinalgExtBase_Op<"pad", [AttrSizedOperandSegments]> { return getDpsInputOperand(0)->get(); } ShapedType getInputType() { - return input().getType().cast(); + return mlir::cast(input().getType()); } ShapedType getInitType() { - return getInit().getType().cast(); + return mlir::cast(getInit().getType()); } Type getPaddingValueType() { return getPvalue().getType(); } @@ -723,7 +734,7 @@ def PadOp : LinalgExtBase_Op<"pad", [AttrSizedOperandSegments]> { //===----------------------------------------------------------------------===// LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } }]; @@ -733,12 +744,13 @@ def PadOp : LinalgExtBase_Op<"pad", [AttrSizedOperandSegments]> { // Op definition for AssertOp //===----------------------------------------------------------------------===// def AssertOp : Op, DestinationStyleOpInterface, LinalgExtInterface, ReifyRankedShapedTypeOpInterface]> { let summary = "lianlg assert operation"; let description = [{ - 'linalg_ext.assert' takes a condition tensor, a message string. + `linalg_ext.assert` takes a condition tensor and a message string. If the condition is false, the message is printed, and the program is aborted. }]; let arguments = (ins @@ -757,11 +769,11 @@ def AssertOp : Op();; + return mlir::cast(getCondition().getType()); } LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } // Method to implement for specifying output range for // DestinationStyleOpInterface @@ -776,6 +788,26 @@ def AssertOp : Op]> { + let summary = "linalg scalar assert operation"; + let description = [{ + `linalg_ext.scalar_assert` takes a condition scalar and a message string.
+ If the condition is false, the message is printed, and the program is aborted. + }]; + + let arguments = (ins + I1:$condition, + StrAttr:$msg + ); + + let assemblyFormat = [{ + attr-dict `ins` `(` $condition `:` type($condition) `)` + }]; +} + //===----------------------------------------------------------------------===// // Scan Op. //===----------------------------------------------------------------------===// @@ -801,7 +833,7 @@ def ScanOp : LinalgExtBase_Op<"scan", [ outs(%output, %init: tensor<16x32x64xf32>, tensor<16x64xf32>) dimension = [1] { - ^bb0(%in: f32, %out: f32, init: f32): + ^bb0(%in: f32, %out: f32, %init: f32): %0 = arith.addf %init, %in: f32 linalg_ext.yield %0, %0: f32, f32 } @@ -840,12 +872,12 @@ def ScanOp : LinalgExtBase_Op<"scan", [ Block::BlockArgListType getRegionInputArgs() { return getBlock()->getArguments().take_front( - cast(*this->getOperation()) + mlir::cast(*this->getOperation()) .getNumDpsInputs()); } Block::BlockArgListType getRegionOutputArgs() { return getBlock()->getArguments().take_back( - cast(*this->getOperation()) + mlir::cast(*this->getOperation()) .getNumDpsInits()); } llvm::SmallVector inputs() { @@ -870,7 +902,7 @@ def ScanOp : LinalgExtBase_Op<"scan", [ return outputs; } ShapedType getOperandType() { - return inputs()[0].getType().cast(); + return mlir::cast(inputs()[0].getType()); } int64_t getOperandRank() { return getOperandType().getRank(); @@ -878,12 +910,20 @@ def ScanOp : LinalgExtBase_Op<"scan", [ LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } SmallVector getOpOperandsMatchingBBargs() { - // Latest llvm use getDpsInputs, change it when upgrade. - return getDpsInputOperands(); + // LinalgExtOps default implementation. + // This interface is different from + // LinalgExtStructuredOps getOpOperandsMatchingBBargs. + llvm::SmallVector result; + result.reserve(this->getNumOperands()); + llvm::transform( + this->getOperation()->getOpOperands(), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; } Block* getBlock() { return getBody(); } @@ -950,10 +990,10 @@ def LibdeviceCallOp : Opget(); })); } ShapedType getOperandType() { - return getInputs()[0].getType().cast(); + return mlir::cast(getInputs()[0].getType()); } ShapedType getInitType() { - return getInit().getType().cast();; + return mlir::cast(getInit().getType()); } int64_t getOperandRank() { return getOperandType().getRank(); @@ -961,7 +1001,7 @@ def LibdeviceCallOp : Op(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } @@ -1003,11 +1043,59 @@ def ScalarLibdeviceCallOp : LinalgExt_PureOp<"scalar_libdevice_call", [NoMemoryE `->` type($result) }]; let builders = [ - OpBuilder<(ins "mlir::Type":$resultType, "ValueRange":$inputs, + OpBuilder<(ins "mlir::Type":$resultType, "ValueRange":$inputs, "StringAttr":$symbol, CArg<"ArrayRef", "{}">:$attributes)>, ]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } + +//===----------------------------------------------------------------------===// +// Histogram op. 
+//===----------------------------------------------------------------------===// +def HistogramOp : LinalgExtBase_Op<"histogram", []> { + let summary = "return a histogram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bin has a width of 1, and bins + start at 0. + + If the input data contains integer values, the output will represent the + frequency of occurrences of each integer value in the input tensor. + For example, if the input tensor has values `[0, 0, 2, 1, 2, 1, 0, 1]`, and the + output tensor has 3 bins, the resulting histogram is [3, 3, 2], showing the count of + each value (0, 1, and 2). + + Example: + ``` + %hist = linalg_ext.histogram + ins(%input:tensor<8xi32>) + outs(%init:tensor<3xi32>) -> tensor<3xi32> + ``` + }]; + + let arguments = (ins + Variadic:$src, + TensorOrMemref:$init + ); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + attr-dict `ins` `(` $src `:` type($src) `)` + `outs` `(` $init `:` type($init) `)` + (`->` type($result)^)? + }]; + + let extraClassDeclaration = [{ + ShapedType getInitType() { + return mlir::cast(getInit().getType()); + } + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return mlir::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); + } + MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } + }]; +} #endif // TRITON_LINALG_DIALECT_LINALGEXT_IR_LINALGEXTOPS_TD diff --git a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtStructedOps.td b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtStructedOps.td index 8de9abb..30fd461 100644 --- a/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtStructedOps.td +++ b/include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtStructedOps.td @@ -166,7 +166,7 @@ def MakeRangeOp : LinalgStructuredBase_Op<"make_range", } MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } ShapedType getOutputOperandType() { - return getOutputs().front().getType().cast(); + return mlir::cast(getOutputs().front().getType()); } }]; } @@ -278,6 +278,131 @@ def Im2ColOp : LinalgStructuredBase_Op<"im2col", [AttrSizedOperandSegments]> { }]; } +//===----------------------------------------------------------------------===// +// ArgMaxMinBase op. +//===----------------------------------------------------------------------===// + +class ArgMaxMinBaseOp props> + : LinalgStructuredBase_Op, + SameVariadicOperandSize], props)> { + let summary = "Argmax/Argmin base operator"; + let description = [{ + Executes `combiner` on the `dimensions` of `inputs` and returns the + argmax/min result. The `dimensions` attribute needs to list the reduction + dimensions in increasing order. + }]; + + let arguments = (ins + // Input arg + Variadic:$inputs, + // Output arg + Variadic:$inits, + + ConfinedAttr]>:$dimensions + ); + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$combiner); + + let builders = [ + OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits, + "ArrayRef":$dimensions, + "function_ref", + CArg<"ArrayRef", "{}">:$attributes)> + ]; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + // Declare functions necessary for LinalgStructuredInterface.
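As a cross-check of the bin semantics described above (width-1 bins starting at 0), here is a minimal scalar sketch; treating values outside [0, numBins) as ignored is an assumption of this sketch, since the op description does not specify out-of-range behavior:

```cpp
#include <cstdint>
#include <vector>

// Width-1 bins starting at 0: bin i counts occurrences of value i in src.
// Ignoring values outside [0, numBins) is an assumption of this sketch.
std::vector<int64_t> histogram(const std::vector<int64_t> &src,
                               int64_t numBins) {
  std::vector<int64_t> bins(numBins, 0);
  for (int64_t v : src)
    if (v >= 0 && v < numBins)
      ++bins[v];
  return bins;
}

// histogram({0, 0, 2, 1, 2, 1, 0, 1}, 3) returns {3, 3, 2}, matching the
// example in the op description.
```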
+ SmallVector getIteratorTypesArray() { + int64_t inputRank = mlir::cast(getInputs()[0].getType()).getRank(); + SmallVector iteratorTypes(inputRank, + mlir::utils::IteratorType::parallel); + for (int64_t reductionDim : getDimensions()) + iteratorTypes[reductionDim] = mlir::utils::IteratorType::reduction; + return iteratorTypes; + } + void getAsmBlockArgumentNames(Region ®ion, + OpAsmSetValueNameFn setNameFn) { + for (Value v : getRegionInputArgs()) + setNameFn(v, "in"); + for (Value v : getRegionOutputArgs()) + setNameFn(v, "init"); + } + ArrayAttr getIndexingMaps() { + int64_t inputRank = mlir::cast(getInputs()[0].getType()).getRank(); + SmallVector affineMaps( + getNumDpsInputs(), + AffineMap::getMultiDimIdentityMap(inputRank, getContext())); + AffineMap resultMap = + AffineMap::getMultiDimIdentityMap(inputRank, getContext()) + .dropResults(getDimensions()); + for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i) + affineMaps.push_back(resultMap); + return Builder(getContext()).getAffineMapArrayAttr(affineMaps); + } + std::string getLibraryCallName() { + return "op_has_no_registered_library_name"; + } + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs); + // Implement functions necessary for DestinationStyleOpInterface. + static std::function)> + getRegionBuilder() { + return regionBuilder; + } + MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); } + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def ArgMaxOp : ArgMaxMinBaseOp<"argmax", []> { + let description = [{ + Example: + ``` + %argmax:2 = linalg_ext.argmax + ins(%input, %index : tensor<16x32x64xf32>, tensor<16x32x64xi32>) + outs(%out1, %out2 : tensor<16x64xf32>, tensor<16x64xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + ``` + }]; +} + +def ArgMinOp : ArgMaxMinBaseOp<"argmin", []> { + let description = [{ + Example: + ``` + %argmin:2 = linalg_ext.argmin + ins(%input, %index : tensor<16x32x64xf32>, tensor<16x32x64xi32>) + outs(%out1, %out2 : tensor<16x64xf32>, tensor<16x64xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + ``` + }]; +} + //===----------------------------------------------------------------------===// // Named LinalgExt ops. 
//===----------------------------------------------------------------------===// diff --git a/include/triton-linalg/Dialect/LinalgExt/Utils/Utils.h b/include/triton-linalg/Dialect/LinalgExt/Utils/Utils.h index 3f8c96a..5eb66cb 100644 --- a/include/triton-linalg/Dialect/LinalgExt/Utils/Utils.h +++ b/include/triton-linalg/Dialect/LinalgExt/Utils/Utils.h @@ -6,10 +6,30 @@ #ifndef TRITON_LINALG_DIALECT_LINALGEXT_UTILS_UTILS_H #define TRITON_LINALG_DIALECT_LINALGEXT_UTILS_UTILS_H +#include +#include +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include namespace mlir { class Block; +class MLIRContext; +class RewriterBase; class Operation; + +namespace linalg { +class ReduceOp; +} // namespace linalg + namespace triton { namespace linalg_ext { /// Retrieve the operation from the body, if it is the only one (except @@ -18,6 +38,207 @@ namespace linalg_ext { /// operands of payload. Operation *findPayloadOp(Block *body, bool initFirst = false); } // namespace linalg_ext + +constexpr static int ANY_INDEX = -1; +template struct UpstreamMatcher; + +template struct UpstreamMatcher { + + /// Matches a line-shaped net. Multiple inputs are allowed, but multiple + /// outputs are not. + /// + /// example: + /// OP1 /|\ (where Values flow out) + /// | | + /// OP2 | + /// / | \ | + /// OP5 OP4 OP3 | (where Values flow in) + /// + /// Function output: + /// - `lineOut`: If the match is successful, this will contain the matched + /// operations (oplist). + /// + /// Inputs: + /// - `inputOp`: One of the operations in the graph to be matched. Since the + /// matching is performed from back to front, this is the last operation + /// in the graph. + /// - `inputIndex`: Specifies the operand indices, indicating which operand + /// of `OP2` corresponds to `OP1` and which operand of `OP3` corresponds + /// to `OP2`. + /// - `originLength`: The number of operations to be matched in this pass. + /// + /// Example code for match {OP1, OP2, OP3}: + /// SmallVector lineOut; + /// SmallVector inputIndex = {0, 2}; + /// Operation *result = UpstreamMatcher::matchLine( + /// lineOut, OP1Operation, inputIndex, inputIndex.size(), false); + /// if (result == nullptr) { + /// std::cout << "Check failed" << std::endl; + /// } + /// + /// Example code for match {OP1, OP2, OP4}: + /// SmallVector lineOut; + /// SmallVector inputIndex = {0, 1}; + /// Operation *result = UpstreamMatcher::matchLine( + /// lineOut, OP1Operation, inputIndex, inputIndex.size(), false); + /// if (result == nullptr) { + /// std::cout << "Check failed" << std::endl; + /// } + /// + /// Example code for match {OP1, OP2, OP5}: + /// SmallVector lineOut; + /// SmallVector inputIndex = {0, 0}; + /// Operation *result = UpstreamMatcher::matchLine( + /// lineOut, OP1Operation, inputIndex, inputIndex.size(), false); + /// if (result == nullptr) { + /// std::cout << "Check failed" << std::endl; + /// } + + static Operation *matchLine(SmallVector &lineOut, + Operation *inputOp, SmallVector &inputIndex, + int originLength, bool ifCheckLast = true) { + if (sizeof...(T2) != inputIndex.size()) { + return nullptr; + } + auto op = llvm::dyn_cast_or_null(inputOp); + if (op == nullptr) { + return nullptr; + } + lineOut.push_back(inputOp); + + // Check if the operation is the first one. 
+ // Check if the operation has only one output. + // Check if the output of the operation has only one user. + if (ifCheckLast == true && inputIndex.size() == 0 && + originLength != (inputIndex.size() + 1) && + (inputOp->getNumResults() != 1 || !inputOp->getResult(0).hasOneUse())) { + return nullptr; + } + + // Get next op. + if (inputIndex.begin() == inputIndex.end()) + return inputOp; + int order = inputIndex.front(); + if (order >= (int)inputOp->getNumOperands()) { + return nullptr; + } + inputIndex.erase(inputIndex.begin()); + // An order equal to ANY_INDEX means this op match is order-irrelevant: all + // operands can possibly be matched. + if (order == ANY_INDEX) { + for (int operandIdx = 0; operandIdx < inputOp->getNumOperands(); + operandIdx++) { + Value opNextValue = inputOp->getOperand(operandIdx); + Operation *opNext = opNextValue.getDefiningOp(); + auto returnOp = + UpstreamMatcher::matchLine(lineOut, opNext, ifCheckLast); + if (returnOp != nullptr) { + return returnOp; + } + } + lineOut.pop_back(); + return nullptr; + } else { + Value opNextValue = inputOp->getOperand(order); + Operation *opNext = opNextValue.getDefiningOp(); + return UpstreamMatcher::matchLine(lineOut, opNext, inputIndex, + originLength, ifCheckLast); + } + } + + static Operation *matchLine(SmallVector &lineOut, + Operation *inputOp, bool ifCheckLast = true) { + auto op = llvm::dyn_cast_or_null(inputOp); + if (op == nullptr) { + return nullptr; + } + lineOut.push_back(inputOp); + + // Check if the operation is the first one. + // Check if the operation has only one output. + // Check if the output of the operation has only one user. + if (sizeof...(T2) == 0) { + if ((ifCheckLast == true) && ((inputOp->getNumResults() != 1) || + (!inputOp->getResult(0).hasOneUse()))) { + return nullptr; + } + return inputOp; + } + + // Get next op. + for (int operandIdx = 0; operandIdx < inputOp->getNumOperands(); + operandIdx++) { + Value opNextValue = inputOp->getOperand(operandIdx); + Operation *opNext = opNextValue.getDefiningOp(); + auto returnOp = + UpstreamMatcher::matchLine(lineOut, opNext, ifCheckLast); + if (returnOp != nullptr) { + return returnOp; + } + } + lineOut.pop_back(); + return nullptr; + } +}; + +template <> struct UpstreamMatcher<> { + static Operation *matchLine(SmallVector &lineOut, + Operation *inputOp, SmallVector &inputOrder, + int originLength, bool ifCheckLast = true) { + return nullptr; + } + static Operation *matchLine(SmallVector &lineOut, + Operation *inputOp, bool ifCheckLast = true) { + return nullptr; + } +}; + +Operation *upstreamMatcher(SmallVector> &lineOut, + Operation *inputOp, bool ifCheckLast = true); + +/// An enum class representing the reduction mode. +enum class ReductionMode { + SUM, + MAX, + UMAX, + NAN_MAX, + MIN, + UMIN, + NAN_MIN, + PROD, + AND, + OR, + XOR, + ARGMAX, + ARGMIN +}; + +/// Check whether the reduce op is supported and get the reduction mode +/// if supported. +std::optional getReductionMode(triton::ReduceOp op); + +/// Check whether the reduce op can convert to argmax/min operation. +std::optional matchArgMaxMinPattern(Region *region); + +/// Identify the pattern of the reduce operator. +std::optional reducePatternRecognition(triton::ReduceOp op); + +/// Check whether the reduce operation is constructed from a single +/// statement of type `OpTy`, where the statement takes its two arguments +/// from the block arguments and the operand of the yield operation +/// is the result of that statement.
+template +static bool isSingleStatementReduceOpWithType(ReduceTy op) { + // Block *block = op.getBlock(); + Block *block = &op.getRegion().front(); + Operation *initFirstPayloadOp = + triton::linalg_ext::findPayloadOp(block, true); + Operation *initBackPayloadOp = + triton::linalg_ext::findPayloadOp(block, false); + return (isa_and_nonnull(initFirstPayloadOp)) || + (isa_and_nonnull(initBackPayloadOp)); +} + } // namespace triton } // namespace mlir #endif // TRITON_LINALG_DIALECT_LINALGEXT_UTILS_UTILS_H diff --git a/include/triton-linalg/Dialect/MathExt/IR/CMakeLists.txt b/include/triton-linalg/Dialect/MathExt/IR/CMakeLists.txt index 6755808..7de8cf4 100644 --- a/include/triton-linalg/Dialect/MathExt/IR/CMakeLists.txt +++ b/include/triton-linalg/Dialect/MathExt/IR/CMakeLists.txt @@ -1,10 +1,10 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) -set(LLVM_TARGET_DEFINITIONS MathBase.td) +set(LLVM_TARGET_DEFINITIONS MathExtBase.td) mlir_tablegen(MathExtOpsDialect.h.inc -gen-dialect-decls -dialect=math_ext) mlir_tablegen(MathExtOpsDialect.cpp.inc -gen-dialect-defs -dialect=math_ext) -set(LLVM_TARGET_DEFINITIONS MathOps.td) +set(LLVM_TARGET_DEFINITIONS MathExtOps.td) mlir_tablegen(MathExtOps.h.inc -gen-op-decls) mlir_tablegen(MathExtOps.cpp.inc -gen-op-defs) diff --git a/include/triton-linalg/Dialect/MathExt/IR/Math.h b/include/triton-linalg/Dialect/MathExt/IR/MathExt.h similarity index 82% rename from include/triton-linalg/Dialect/MathExt/IR/Math.h rename to include/triton-linalg/Dialect/MathExt/IR/MathExt.h index d00fdc7..3292826 100644 --- a/include/triton-linalg/Dialect/MathExt/IR/Math.h +++ b/include/triton-linalg/Dialect/MathExt/IR/MathExt.h @@ -1,11 +1,11 @@ -//===- Math.h - Math dialect ------------------------------------*- C++ -*-===// +//===- MathExt.h - MathExt dialect ------------------------------*- C++ -*-===// // // Copyright (C) [2022-2025] by Cambricon. // //===----------------------------------------------------------------------===// -#ifndef TRITON_LINALG_DIALECT_MATHEXT_IR_MATH_H -#define TRITON_LINALG_DIALECT_MATHEXT_IR_MATH_H +#ifndef TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXT_H +#define TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXT_H #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -30,4 +30,4 @@ #define GET_OP_CLASSES #include "triton-linalg/Dialect/MathExt/IR/MathExtOps.h.inc" -#endif // TRITON_LINALG_DIALECT_MATHEXT_IR_MATH_H +#endif // TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXT_H diff --git a/include/triton-linalg/Dialect/MathExt/IR/MathBase.td b/include/triton-linalg/Dialect/MathExt/IR/MathExtBase.td similarity index 67% rename from include/triton-linalg/Dialect/MathExt/IR/MathBase.td rename to include/triton-linalg/Dialect/MathExt/IR/MathExtBase.td index aded814..98d4548 100644 --- a/include/triton-linalg/Dialect/MathExt/IR/MathBase.td +++ b/include/triton-linalg/Dialect/MathExt/IR/MathExtBase.td @@ -1,10 +1,10 @@ -//===- MathBase.td - Base definitions for math dialect -----*- tablegen -*-===// +//===- MathExtBase.td - Base definitions of MathExt dialect *- tablegen -*-===// // // Copyright (C) [2022-2025] by Cambricon. 
// //===----------------------------------------------------------------------===// -#ifndef TRITON_LINALG_DIALECT_MATHEXT_IR_MATHBASE_TD -#define TRITON_LINALG_DIALECT_MATHEXT_IR_MATHBASE_TD +#ifndef TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXTBASE_TD +#define TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXTBASE_TD include "mlir/IR/OpBase.td" @@ -20,4 +20,4 @@ def MathExt_Dialect : Dialect { "::mlir::arith::ArithDialect" ]; } -#endif // TRITON_LINALG_DIALECT_MATHEXT_IR_MATHBASE_TD +#endif // TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXTBASE_TD diff --git a/include/triton-linalg/Dialect/MathExt/IR/MathOps.td b/include/triton-linalg/Dialect/MathExt/IR/MathExtOps.td similarity index 84% rename from include/triton-linalg/Dialect/MathExt/IR/MathOps.td rename to include/triton-linalg/Dialect/MathExt/IR/MathExtOps.td index 9ebe524..f630685 100644 --- a/include/triton-linalg/Dialect/MathExt/IR/MathOps.td +++ b/include/triton-linalg/Dialect/MathExt/IR/MathExtOps.td @@ -1,15 +1,15 @@ -//===- MathOps.td - Math op definitions --------------------*- tablegen -*-===// +//===- MathExtOps.td - MathExt op definitions --------------*- tablegen -*-===// // // Copyright (C) [2022-2025] by Cambricon. // //===----------------------------------------------------------------------===// -#ifndef TRITON_LINALG_DIALECT_MATHEXT_IR_MATHOPS_TD -#define TRITON_LINALG_DIALECT_MATHEXT_IR_MATHOPS_TD +#ifndef TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXTOPS_TD +#define TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXTOPS_TD include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" -include "triton-linalg/Dialect/MathExt/IR/MathBase.td" +include "triton-linalg/Dialect/MathExt/IR/MathExtBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -42,4 +42,4 @@ def MathExt_MulhiUIOp : MathExt_IntegerBinaryOp<"mulhiui"> { }]; } -#endif // TRITON_LINALG_DIALECT_MATHEXT_IR_MATHOPS_TD +#endif // TRITON_LINALG_DIALECT_MATHEXT_IR_MATHEXTOPS_TD diff --git a/include/triton-linalg/Dialect/Triton/CMakeLists.txt b/include/triton-linalg/Dialect/Triton/CMakeLists.txt index e273bd0..b45d3c4 100644 --- a/include/triton-linalg/Dialect/Triton/CMakeLists.txt +++ b/include/triton-linalg/Dialect/Triton/CMakeLists.txt @@ -1,2 +1,2 @@ -add_subdirectory(Interfaces) add_subdirectory(Transforms) +add_subdirectory(Interfaces) diff --git a/include/triton-linalg/Dialect/Triton/Interfaces/InferAxisInfoInterface.h b/include/triton-linalg/Dialect/Triton/Interfaces/InferAxisInfoInterface.h index e2d3825..4874465 100644 --- a/include/triton-linalg/Dialect/Triton/Interfaces/InferAxisInfoInterface.h +++ b/include/triton-linalg/Dialect/Triton/Interfaces/InferAxisInfoInterface.h @@ -88,29 +88,53 @@ class AxisInfoExt { std::optional getConstantValue() const { return constantValue; } - bool isContiguousDim(ArrayRef shape, int dim) const { + bool isFullContiguousDim(ArrayRef shape, int dim) const { return getContiguity(dim) == shape[dim]; } - bool isStrideDim(ArrayRef shape, int dim) const { + bool isStridedContiguousDim(ArrayRef shape, int dim) const { + if (strideValue.size() < 1 || stride.size() < 1) + return false; + if (shape[dim] == 1) + return true; + return getStrideValue(dim) == 1 && getStride(dim) > 1 && + shape[dim] % getContiguity(dim) == 0; + } + + bool isFullConstantDim(ArrayRef shape, int dim) const { + return getConstancy(dim) == shape[dim]; + } + + bool isStridedConstantDim(ArrayRef shape, int dim) const { + if (strideValue.size() < 1 || 
stride.size() < 1) + return false; + if (shape[dim] == 1) + return true; + return getStrideValue(dim) == 0 && getStride(dim) > 1 && + shape[dim] % getConstancy(dim) == 0; + } + + bool isFullStrideDim(ArrayRef shape, int dim) const { return getStride(dim) == shape[dim]; } - bool isNonContiguousNonConstantStrideDim(ArrayRef shape, - int dim) const { - return getStride(dim) == shape[dim] && !isContiguousDim(shape, dim) && - !isConstantDim(shape, dim); + bool isNonConstantFullStrideDim(ArrayRef shape, int dim) const { + return isFullStrideDim(shape, dim) && !isFullConstantDim(shape, dim); } - bool isConstantDim(ArrayRef shape, int dim) const { - return getConstancy(dim) == shape[dim]; + bool isNonContiguousNonConstantFullStrideDim(ArrayRef shape, + int dim) const { + return isFullStrideDim(shape, dim) && !isFullContiguousDim(shape, dim) && + !isFullConstantDim(shape, dim); } - bool isConstantStrideDim(ArrayRef shape, int dim) const { + bool isNonStridedConstantStrideDim(ArrayRef shape, int dim) const { if (strideValue.size() < 1 || stride.size() < 1) return false; - return getStrideValue(dim) == 0 && shape[dim] % getConstancy(dim) == 0 && - shape[dim] / getConstancy(dim) >= 1; + if (shape[dim] == 1) + return true; + return getStrideValue(dim) != 0 && getStride(dim) > 1 && + shape[dim] % getContiguity(dim) == 0; } /// Comparison. diff --git a/include/triton-linalg/Dialect/Triton/Utils/PointerInfo.h b/include/triton-linalg/Dialect/Triton/Utils/PointerInfo.h new file mode 100644 index 0000000..0fe6591 --- /dev/null +++ b/include/triton-linalg/Dialect/Triton/Utils/PointerInfo.h @@ -0,0 +1,65 @@ +//===- PointerInfo.h - Triton pointer info ----------------------*- C++ -*-===// +// +// Copyright (C) [2022-2025] by Cambricon. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_LINALG_DIALECT_TRITON_UTILS_POINTERINFO_H +#define TRITON_LINALG_DIALECT_TRITON_UTILS_POINTERINFO_H + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include + +namespace mlir { +namespace triton { + +/// Structure info representation for pointer in triton. +class PtrInfo { +public: + PtrInfo() = delete; + PtrInfo(Value ptr, ArrayRef offsets) + : pointer(ptr), tensorPtrOffsets(offsets) {} + + PtrInfo(Value ptr, const SmallVector &sizes, + const SmallVector &strides, const SmallVector &offsets, + const ArrayRef &order) + : pointer(ptr), tensorPtrSizes(sizes), tensorPtrStrides(strides), + tensorPtrOffsets(offsets), tensorPtrOrder(order) {} + + PtrInfo(Value ptr, Value offset) : pointer(ptr) { + tensorPtrOffsets.push_back(offset); + isRawPtrInfo = true; + } + + Value ptr() const { return pointer; } + + ArrayRef offsets() const { return tensorPtrOffsets; } + Value offset(unsigned idx) const { return tensorPtrOffsets[idx]; } + Value offset() const { return tensorPtrOffsets[0]; } + void setOffsets(ValueRange vals) { + for (unsigned i = 0; i < vals.size(); i++) { + tensorPtrOffsets[i] = vals[i]; + } + } + unsigned offsetSize() { return tensorPtrOffsets.size(); } + + bool isBlockPtr() const { return !isRawPtrInfo; } + + ArrayRef sizes() const { return tensorPtrSizes; } + ArrayRef strides() const { return tensorPtrStrides; } + ArrayRef order() const { return tensorPtrOrder; } + +private: + bool isRawPtrInfo{false}; + Value pointer; + // Basic info for reconstruction of MakeTensorPtrOp. 
+ SmallVector tensorPtrSizes; + SmallVector tensorPtrStrides; + SmallVector tensorPtrOffsets; + SmallVector tensorPtrOrder; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_LINALG_DIALECT_TRITON_UTILS_POINTERINFO_H diff --git a/include/triton-linalg/Dialect/Triton/Utils/PointerMetaInfoTracker.h b/include/triton-linalg/Dialect/Triton/Utils/PointerMetaInfoTracker.h index e9fe2a0..d883ce2 100644 --- a/include/triton-linalg/Dialect/Triton/Utils/PointerMetaInfoTracker.h +++ b/include/triton-linalg/Dialect/Triton/Utils/PointerMetaInfoTracker.h @@ -64,7 +64,6 @@ class PointerMetaInfoTracker { FailureOr parse(Value operand, Location loc, ConversionPatternRewriter &rewriter); - private: template LogicalResult parseOp(OpTy op, Location loc, diff --git a/include/triton-linalg/Dialect/Utils/ArithUtils.h b/include/triton-linalg/Dialect/Utils/ArithUtils.h index e9a7eac..294fd12 100644 --- a/include/triton-linalg/Dialect/Utils/ArithUtils.h +++ b/include/triton-linalg/Dialect/Utils/ArithUtils.h @@ -34,17 +34,6 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, /// Get splat value from the arith.constant op. FailureOr getSplatValue(OpBuilder &builder, arith::ConstantOp op); - -/// Derive the specific max/min semantics based on the type of compare and the -/// operand relationship between compare and select. -std::optional getCmpSelectResult(OpBuilder &builder, Location loc, - arith::CmpFOp op, - bool operandsSwapped); -std::optional getCmpSelectResult(OpBuilder &builder, Location loc, - arith::CmpIOp op, - bool operandsSwapped); -std::optional -getCmpSelectResult(OpBuilder &builder, Operation *cmpOp, arith::SelectOp op); } // namespace triton } // namespace mlir diff --git a/include/triton-linalg/Dialect/Utils/ShapeUtils.h b/include/triton-linalg/Dialect/Utils/ShapeUtils.h index 717fcce..4d1ee81 100644 --- a/include/triton-linalg/Dialect/Utils/ShapeUtils.h +++ b/include/triton-linalg/Dialect/Utils/ShapeUtils.h @@ -64,6 +64,18 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc, /// Return whether value is a scalar. bool isScalar(Value val); +/// Return "true" if the last N dimensions of the given type are contiguous. +/// +/// Examples: +/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when +/// considering both _all_ and _only_ the trailing 3 dims, +/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when +/// considering the trailing 3 dims. +/// +/// Note: This function is from LLVM. +/// Replace it with the latest version of LLVM after we update. +bool trailingNDimsContiguous(MemRefType type, int64_t n); + /// Add a unit dimension to the last dim. /// /// Example 1: @@ -165,11 +177,24 @@ Value dropUnitFirstDim(OpBuilder &b, Location loc, Value value); /// is collapsed to (sn-1' = sn-1xsn): /// ```mlir /// %0 = foo ... : tensor -/// %1 = tensor.collapse_shape %n [[0] [1] ... [sn-1, sn] +/// %1 = tensor.collapse_shape %n [[0] [1] ... [sn-1, sn]] /// : tensor to tensor /// ``` +/// +/// When exceptLastDim is true, the behavior of collapse becomes as follows: +/// +/// Example 4 (n = 2): +/// ```mlir +/// %0 = foo ... : tensor +/// ``` +/// is collapsed to (sn-2xsn-1xsn = sn-2'xsn): +/// ```mlir +/// %0 = foo ... : tensor +/// %1 = tensor.collapse_shape %n [[0] [1] ... 
[sn-2, sn-1] [sn]] +/// : tensor to tensor +/// ``` Value collapseLastNDimsToOneDim(OpBuilder &b, Location loc, Value value, - int64_t n); + int64_t n, bool exceptLastDim = false); } // namespace triton } // namespace mlir diff --git a/include/triton-linalg/Interfaces/CMakeLists.txt b/include/triton-linalg/Interfaces/CMakeLists.txt deleted file mode 100644 index 0b0fc72..0000000 --- a/include/triton-linalg/Interfaces/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -set(LLVM_TARGET_DEFINITIONS InferResultTypeOpInterface.td) -mlir_tablegen(InferResultTypeOpInterface.h.inc -gen-op-interface-decls) -mlir_tablegen(InferResultTypeOpInterface.cpp.inc -gen-op-interface-defs) - -add_public_tablegen_target(TritonLinalgInterfacesTableGen) diff --git a/include/triton-linalg/Interfaces/InferResultTypeOpInterface.h b/include/triton-linalg/Interfaces/InferResultTypeOpInterface.h deleted file mode 100644 index 4b3c904..0000000 --- a/include/triton-linalg/Interfaces/InferResultTypeOpInterface.h +++ /dev/null @@ -1,18 +0,0 @@ -//===- InferResultTypeOpInterface.h - Infer result type----------*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. -// -//===----------------------------------------------------------------------===// -// -// This file implements the operation interface infers result type. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_LINALG_INTERFACES_INFERRESULTTYPEOPINTERFACE_H -#define TRITON_LINALG_INTERFACES_INFERRESULTTYPEOPINTERFACE_H -#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep - -/// Include the generated interface declarations. -#include "triton-linalg/Interfaces/InferResultTypeOpInterface.h.inc" // IWYU pragma: export - -#endif // TRITON_LINALG_INTERFACES_INFERRESULTTYPEOPINTERFACE_H diff --git a/include/triton-linalg/Interfaces/InferResultTypeOpInterface.td b/include/triton-linalg/Interfaces/InferResultTypeOpInterface.td deleted file mode 100644 index 73a3ce0..0000000 --- a/include/triton-linalg/Interfaces/InferResultTypeOpInterface.td +++ /dev/null @@ -1,54 +0,0 @@ -//===- InferResultTypeOpInterface.td - infer type interface-*- tablegen -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. -// -//===----------------------------------------------------------------------===// -// -// This file contains a set of interfaces that can infer result types by input -// values and attributes. -//===----------------------------------------------------------------------===// - -#ifndef TRITON_LINALG_INTERFACES_INFERRESULTTYPEOPINTERFACE_TD -#define TRITON_LINALG_INTERFACES_INFERRESULTTYPEOPINTERFACE_TD - -include "mlir/IR/OpBase.td" - -def InferResultTypeOpInterface : OpInterface<"InferResultTypeOpInterface"> { - let description = [{ - Op with this interface can infer result type by input values and attributes. - }]; - let cppNamespace = "::mlir::triton"; - let methods = [ - InterfaceMethod< - /*desc=*/[{ - This method infers result type on operation, return the infered type. - }], - /*retType=*/"llvm::SmallVector<::mlir::Type>", - /*methodName=*/"inferResultTypes", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/"" - >, - InterfaceMethod< - /*desc=*/[{ - This method infers result type on operation, update result value - with inferred type inplace. 
- }], - /*retType=*/"void", - /*methodName=*/"updateResultType", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto &&resultTypes = cast( - $_op.getOperation()).inferResultTypes(); - for (const auto &en : llvm::enumerate($_op.getOperation()->getResults())) { - en.value().setType(resultTypes[en.index()]); - } - }] - > - ]; - let extraClassDeclaration = [{ - }]; -} - -#endif // TRITON_LINALG_INTERFACES_INFERRESULTTYPEOPINTERFACE_TD diff --git a/lib/Analysis/AxisInfoAnalysis.cpp b/lib/Analysis/AxisInfoAnalysis.cpp index e0e68fe..bd1f381 100644 --- a/lib/Analysis/AxisInfoAnalysis.cpp +++ b/lib/Analysis/AxisInfoAnalysis.cpp @@ -50,6 +50,22 @@ using namespace mlir::triton; // lib/Analysis/AxisInfo.cpp in the triton repo. //===--------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// AxisInfoLattice +//===----------------------------------------------------------------------===// + +ChangeResult AxisInfoLattice::join(const AxisInfoExt &rhs) { + if (!initialized) { + initialized = true; + auto &kepval = getValue(); + if (kepval == rhs) + return ChangeResult::NoChange; + kepval = rhs; + return ChangeResult::Change; + } + return mlir::dataflow::Lattice::join(rhs); +} + //===----------------------------------------------------------------------===// // AxisInfoAnalysisExt //===----------------------------------------------------------------------===// @@ -68,7 +84,7 @@ void AxisInfoAnalysisExt::visitOperation( } auto joinCallback = [op, results, this](Value v, const AxisInfoExt &info) { - auto result = v.dyn_cast(); + auto result = dyn_cast(v); if (!result) return; assert(llvm::is_contained(op->getResults(), result)); @@ -100,7 +116,7 @@ void AxisInfoAnalysisExt::visitNonControlFlowArguments( auto getRank = [](Type type) { auto rank = 1; - if (TensorType ty = type.dyn_cast()) + if (TensorType ty = dyn_cast(type)) rank = ty.getRank(); return rank; }; @@ -126,9 +142,8 @@ void AxisInfoAnalysisExt::visitNonControlFlowArguments( } auto lowerBoundVal = - lowerBound.getValue().cast().getValue().getZExtValue(); - auto stepVal = - step.getValue().cast().getValue().getZExtValue(); + cast(lowerBound.getValue()).getValue().getZExtValue(); + auto stepVal = cast(step.getValue()).getValue().getZExtValue(); auto divHint = AxisInfoExt::kInitValue; auto k = std::gcd(lowerBoundVal, stepVal); if (k != 0) diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index d6c197f..df966b5 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,6 +1,5 @@ add_subdirectory(Analysis) add_subdirectory(Conversion) add_subdirectory(Dialect) -add_subdirectory(Interfaces) add_subdirectory(Pipelines) add_subdirectory(Utils) diff --git a/lib/Conversion/ArithToLinalg/ArithToLinalg.cpp b/lib/Conversion/ArithToLinalg/ArithToLinalg.cpp index c06404b..a10b89a 100644 --- a/lib/Conversion/ArithToLinalg/ArithToLinalg.cpp +++ b/lib/Conversion/ArithToLinalg/ArithToLinalg.cpp @@ -58,7 +58,7 @@ class ArithConstantPattern : public OpRewritePattern { return failure(); auto loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); Value init = rewriter.create(loc, resType.getShape(), resType.getElementType()); @@ -78,8 +78,8 @@ class ArithSelectPattern : public OpRewritePattern { auto trueValue = op.getTrueValue(); auto falseValue = op.getFalseValue(); - if (!trueValue.getType().isa() || - !falseValue.getType().isa()) + if (!isa(trueValue.getType()) || + !isa(falseValue.getType())) 
return failure(); auto initDims = getDims(rewriter, loc, trueValue); @@ -136,7 +136,8 @@ void mlir::triton::populateArithToLinalgPatterns(RewritePatternSet &patterns) { GenericOpPattern, GenericOpPattern, GenericOpPattern, // Cast ops. - GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, @@ -151,14 +152,13 @@ void mlir::triton::populateArithToLinalgPatterns(RewritePatternSet &patterns) { namespace { struct ArithToLinalgPass : public ArithToLinalgPassBase { ArithToLinalgPass() = default; - void runOnOperation() override { MLIRContext &ctx = getContext(); ConversionTarget target(ctx); target.addLegalDialect(); target.addDynamicallyLegalDialect([&](Operation *op) { - return !op->getResultTypes().front().isa(); + return !isa(op->getResultTypes().front()); }); // Setup conversion patterns. RewritePatternSet patterns(&ctx); diff --git a/lib/Conversion/ArithToLinalg/CMakeLists.txt b/lib/Conversion/ArithToLinalg/CMakeLists.txt index ec9e3c2..c352a5b 100644 --- a/lib/Conversion/ArithToLinalg/CMakeLists.txt +++ b/lib/Conversion/ArithToLinalg/CMakeLists.txt @@ -6,4 +6,5 @@ add_triton_library(ArithToLinalg LINK_LIBS PUBLIC MLIRIR + MathExtDialect ) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index cdccab6..3563a6b 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,3 @@ add_subdirectory(ArithToLinalg) add_subdirectory(MathToLinalg) add_subdirectory(TritonToLinalg) -add_subdirectory(TritonToTensor) diff --git a/lib/Conversion/MathToLinalg/MathToLinalg.cpp b/lib/Conversion/MathToLinalg/MathToLinalg.cpp index a2acc2c..0cd665f 100644 --- a/lib/Conversion/MathToLinalg/MathToLinalg.cpp +++ b/lib/Conversion/MathToLinalg/MathToLinalg.cpp @@ -24,7 +24,7 @@ #include "triton-linalg/Conversion/MathToLinalg/MathToLinalg.h" #include "triton-linalg/Conversion/PassDetail.h" #include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h" // IWYU pragma: keep -#include "triton-linalg/Dialect/MathExt/IR/Math.h" // IWYU pragma: keep +#include "triton-linalg/Dialect/MathExt/IR/MathExt.h" // IWYU pragma: keep #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringRef.h" @@ -45,6 +45,7 @@ void mlir::triton::populateMathToLinalgPatterns(RewritePatternSet &patterns) { GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>(context); } @@ -60,7 +61,7 @@ struct MathToLinalgPass : public MathToLinalgPassBase { target.addDynamicallyLegalDialect( [&](Operation *op) { - return !op->getResultTypes().front().isa(); + return !isa(op->getResultTypes().front()); }); // Setup conversion patterns. 
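// Illustrative note (a sketch, not part of the patch): under the legality rule above, a math op whose first result is a shaped type, e.g. math.exp producing tensor<4xf32>, is dynamically illegal and gets rewritten by the GenericOpPattern instances registered in populateMathToLinalgPatterns, while the scalar form producing f32 stays legal and untouched.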
RewritePatternSet patterns(&ctx); diff --git a/lib/Conversion/TritonToLinalg/AtomicCASConversion.cpp b/lib/Conversion/TritonToLinalg/AtomicCASConversion.cpp index 4cc9483..48827be 100644 --- a/lib/Conversion/TritonToLinalg/AtomicCASConversion.cpp +++ b/lib/Conversion/TritonToLinalg/AtomicCASConversion.cpp @@ -55,7 +55,7 @@ class TritonScalarAtomicCASOpConversion matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType resultTy = - op.getResult().getType().dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); if (resultTy) return failure(); @@ -67,7 +67,7 @@ class TritonScalarAtomicCASOpConversion auto zero = rewriter.create(loc, 0); RankedTensorType originTensorTy = - originTensor.getType().cast(); + mlir::cast(originTensor.getType()); auto cmpInit = rewriter.create( loc, originTensorTy.getShape(), op.getCmp().getType()); @@ -116,7 +116,7 @@ class TritonAtomicCASPattern ConversionPatternRewriter &rewriter) const override { // If atomic_cas on scalar type. RankedTensorType resultTy = - op.getResult().getType().dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); if (!resultTy) return failure(); @@ -169,7 +169,7 @@ class TritonGatherAtomicCASPattern ConversionPatternRewriter &rewriter) const override { // If atomic_cas on scalar type. RankedTensorType resultTy = - op.getResult().getType().dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); if (!resultTy) { return failure(); } diff --git a/lib/Conversion/TritonToLinalg/AtomicRmwConversion.cpp b/lib/Conversion/TritonToLinalg/AtomicRmwConversion.cpp index ccca326..8d0e722 100644 --- a/lib/Conversion/TritonToLinalg/AtomicRmwConversion.cpp +++ b/lib/Conversion/TritonToLinalg/AtomicRmwConversion.cpp @@ -99,7 +99,7 @@ class TritonScalarAtomicRMWOpConversion ConversionPatternRewriter &rewriter) const override { Type resultTy = op.getResult().getType(); // FIXME: lower to llvm.atom directly. - if (resultTy.isa()) + if (isa(resultTy)) return failure(); auto loc = op.getLoc(); @@ -110,7 +110,7 @@ class TritonScalarAtomicRMWOpConversion auto zero = rewriter.create(loc, 0); RankedTensorType originTensorTy = - originTensor.getType().cast(); + cast(originTensor.getType()); SmallVector shape = {1, 1}; auto val = @@ -162,7 +162,7 @@ class TritonScalarAtomicRMWOpConversion // Yield 0 if mask is false, align with GPU behaviour. rewriter.setInsertionPointToStart(ifOp.elseBlock()); Value zeroConst = rewriter.create( - loc, rewriter.getIntegerAttr(op.getResult().getType(), 0)); + loc, rewriter.getZeroAttr(op.getResult().getType())); rewriter.create(loc, ValueRange(zeroConst)); rewriter.replaceOp(op, ifOp.getResult(0)); return success(); @@ -187,7 +187,7 @@ class TritonContiguousAtomicRmwOpConversion matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType resultTy = - op.getResult().getType().dyn_cast(); + dyn_cast(op.getResult().getType()); if (!resultTy) return failure(); @@ -211,8 +211,6 @@ class TritonContiguousAtomicRmwOpConversion loc, ptrInfo->memref, true, true); // Create atomic_rmw here. - // Get linalg_ext.atomic_rmw input operands. - SmallVector atomicInputs{op.getVal()}; // Init atomic output. 
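+ // Sketch of the flow below (as this patch reads): the result init and the stored value are both narrowed with tensor.extract_slice to the window [offsets, offsets + sizes) described by ptrInfo, linalg_ext.atomic_rmw runs on that window only, and the result is padded back to the full tensor shape with zeros afterwards.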
Type resultEltType = resultTy.getElementType(); Value atomicResultInit = rewriter.create( @@ -221,7 +219,19 @@ class TritonContiguousAtomicRmwOpConversion if (!maybeKind) return failure(); + Value input = op.getVal(); + auto rank = cast(atomicResultInit.getType()).getRank(); + atomicResultInit = rewriter.create( + loc, atomicResultInit, ptrInfo->offsets, ptrInfo->sizes, + SmallVector(rank, rewriter.getIndexAttr(1))); + + input = rewriter.create( + loc, input, ptrInfo->offsets, ptrInfo->sizes, + SmallVector(rank, rewriter.getIndexAttr(1))); + SmallVector atomicInits{originalTensor, atomicResultInit}; + // Get linalg_ext.atomic_rmw input operands. + SmallVector atomicInputs{input}; auto maybeMemoryOrder = getLinalgExtAtomicMemoryOrder(op.getSem()); if (failed(maybeMemoryOrder)) @@ -232,6 +242,13 @@ class TritonContiguousAtomicRmwOpConversion .create( loc, atomicInputs, atomicInits, *maybeKind, *maybeMemoryOrder) ->getResult(1); + // Pad value is set to 0 to align with GPU. + Value c0 = rewriter.create( + loc, resultEltType, rewriter.getZeroAttr(resultEltType)); + sliceTensor = getPadOrInsertOpWithOther( + loc, c0, + RankedTensorType::get(resultTy.getShape(), resultTy.getElementType()), + sliceTensor, ptrInfo->offsets, ptrInfo->sizes, rewriter); rewriter.replaceOp(op, sliceTensor); return success(); @@ -255,7 +272,7 @@ class TritonScatteredAtomicRMWOpConversion mlir::LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getResult().getType().dyn_cast(); + auto resultTy = dyn_cast(op.getResult().getType()); if (!resultTy) return failure(); @@ -280,18 +297,6 @@ class TritonScatteredAtomicRMWOpConversion if (op.getMask()) { auto atomicMask = triton::flattenValueToMatchGatherScatter( rewriter, op.getMask(), false); - Value maskInit = rewriter.create( - loc, atomicMask.getType().cast().getShape(), - rewriter.getI8Type()); - atomicMask = rewriter - .create( - loc, ValueRange{atomicMask}, maskInit, - [](OpBuilder &b, Location loc, ValueRange args) { - Value ret = b.create( - loc, b.getI8Type(), args[0]); - b.create(loc, ret); - }) - .getResult()[0]; atomicInputs.push_back(atomicMask); } diff --git a/lib/Conversion/TritonToLinalg/CMakeLists.txt b/lib/Conversion/TritonToLinalg/CMakeLists.txt index 77e5e11..f25e8cc 100644 --- a/lib/Conversion/TritonToLinalg/CMakeLists.txt +++ b/lib/Conversion/TritonToLinalg/CMakeLists.txt @@ -9,8 +9,10 @@ add_triton_library(TritonToLinalg DEPENDS TritonLinalgConverisonIncGen + TritonInterfacesExtendTableGen LINK_LIBS PUBLIC LinalgExtDialectUtils + TritonInterfaceExtend MLIRIR ) diff --git a/lib/Conversion/TritonToLinalg/LoadStoreConversion.cpp b/lib/Conversion/TritonToLinalg/LoadStoreConversion.cpp index dff48b2..05468fe 100644 --- a/lib/Conversion/TritonToLinalg/LoadStoreConversion.cpp +++ b/lib/Conversion/TritonToLinalg/LoadStoreConversion.cpp @@ -79,7 +79,7 @@ class TritonContiguousLoadOpConversion return failure(); RankedTensorType resultTy = - op.getResult().getType().dyn_cast(); + dyn_cast(op.getResult().getType()); if (!resultTy) return failure(); @@ -96,7 +96,7 @@ class TritonContiguousLoadOpConversion Value sliceTensor = rewriter.create( loc, ptrInfo->memref, true, true); - auto tensorType = sliceTensor.getType().cast(); + auto tensorType = cast(sliceTensor.getType()); Value emptyTensor = rewriter.create( loc, tensorType.getShape(), tensorType.getElementType(), getDynamicDimsValue(rewriter, loc, sliceTensor)); @@ -143,7 +143,7 @@ class 
TritonContiguousStoreOpConversion return failure(); RankedTensorType valueTy = - op.getValue().getType().dyn_cast(); + dyn_cast(op.getValue().getType()); if (!valueTy) return failure(); @@ -161,7 +161,7 @@ class TritonContiguousStoreOpConversion ptrInfo->dimInfos, defaultOffsets, rewriter); if (op.getMask()) { - auto rank = value.getType().cast().getRank(); + auto rank = cast(value.getType()).getRank(); value = rewriter.create( loc, value, permutateAndRemoveBroadcastDims( @@ -193,7 +193,7 @@ class TritonScalarLoadOpConversion : public OpConversionPattern, LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (op.getResult().getType().dyn_cast()) + if (dyn_cast(op.getResult().getType())) return failure(); auto loc = op.getLoc(); @@ -242,7 +242,7 @@ class TritonScalarStoreOpConversion LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (op.getValue().getType().dyn_cast()) + if (dyn_cast(op.getValue().getType())) return failure(); auto loc = op.getLoc(); @@ -273,7 +273,7 @@ class TritonScatteredLoadOpConversion ConversionPatternRewriter &rewriter) const override { if (triton::isTensorPointerType(op.getPtr().getType())) return failure(); - auto resultTy = op.getResult().getType().dyn_cast(); + auto resultTy = dyn_cast(op.getResult().getType()); if (!resultTy) return failure(); @@ -341,7 +341,7 @@ class TritonScatteredStoreOpConversion if (triton::isTensorPointerType(op.getPtr().getType())) return failure(); - auto valueTy = op.getValue().getType().dyn_cast(); + auto valueTy = dyn_cast(op.getValue().getType()); if (!valueTy) return failure(); @@ -435,7 +435,7 @@ class TritonTensorPtrLoadOpConversion if (!triton::isTensorPointerType(op.getPtr().getType())) return failure(); RankedTensorType resultTy = - op.getResult().getType().cast(); + cast(op.getResult().getType()); auto loc = op.getLoc(); if (op.getMask() || op.getOther()) return rewriter.notifyMatchFailure( @@ -447,17 +447,17 @@ class TritonTensorPtrLoadOpConversion SmallVector permutations = getPermutationFromOrder(tracker.getOrder()); auto dimInfos = getDimInfos(tracker.getStrides(), resultTy.getShape()); - auto sizes = getActualSizes(loc, op.getBoundaryCheck(), resultTy.getShape(), - tracker, rewriter); + auto ptrInfo = getPtrInfo(loc, op.getBoundaryCheck(), resultTy.getShape(), + tracker, rewriter); auto originalMemRef = - getMemRef(rewriter.getRemappedValue(tracker.getBase()), - tracker.getOffsets(), sizes, tracker.getStrides(), - permutations, dimInfos, resultTy.getElementType(), rewriter, + getMemRef(rewriter.getRemappedValue(tracker.getBase()), ptrInfo.offsets, + ptrInfo.sizes, tracker.getStrides(), permutations, dimInfos, + resultTy.getElementType(), rewriter, getCacheModeAttr(op.getContext(), op.getCache())); Value sliceTensor = rewriter.create( loc, originalMemRef, true, true); - auto tensorType = sliceTensor.getType().cast(); + auto tensorType = cast(sliceTensor.getType()); Value emptyTensor = rewriter.create( loc, tensorType.getShape(), tensorType.getElementType(), getDynamicDimsValue(rewriter, loc, sliceTensor)); @@ -466,7 +466,7 @@ class TritonTensorPtrLoadOpConversion if (op.getBoundaryCheck().empty()) { sliceTensor = transformResultWithTransposeAndDimInfo( - sliceTensor, permutations, dimInfos, sizes, rewriter); + sliceTensor, permutations, dimInfos, ptrInfo.sizes, rewriter); rewriter.replaceOp(op, sliceTensor); return success(); } @@ -478,25 +478,24 @@ class 
TritonTensorPtrLoadOpConversion // Set zero padding value. TypedAttr attr = elementType.isIntOrIndex() - ? rewriter.getIntegerAttr(elementType, 0).cast() - : rewriter.getFloatAttr(elementType, 0).cast(); + ? cast(rewriter.getIntegerAttr(elementType, 0)) + : cast(rewriter.getFloatAttr(elementType, 0)); // Float NaN padding case. if (op.getPadding().value() == triton::PaddingOption::PAD_NAN) { assert(!elementType.isIntOrIndex()); auto apNaN = llvm::APFloat::getNaN( - attr.cast().getValue().getSemantics()); + cast(attr).getValue().getSemantics()); attr = rewriter.getFloatAttr(elementType, apNaN); } other = rewriter.create(loc, attr); } auto value = transformResultWithTransposeAndDimInfo( - sliceTensor, permutations, dimInfos, sizes, rewriter); - value = getPadOrInsertOpWithOther( - loc, other, resultTy, value, - SmallVector(resultTy.getRank(), rewriter.getIndexAttr(0)), - sizes, rewriter); + sliceTensor, permutations, dimInfos, ptrInfo.sizes, rewriter); + value = getPadOrInsertOpWithOther(loc, other, resultTy, value, + ptrInfo.padLeftSizes, ptrInfo.sizes, + rewriter); rewriter.replaceOp(op, value); return success(); } @@ -513,7 +512,7 @@ class TritonTensorPtrStoreOpConversion ConversionPatternRewriter &rewriter) const override { if (!triton::isTensorPointerType(op.getPtr().getType())) return failure(); - RankedTensorType valueTy = op.getValue().getType().cast(); + RankedTensorType valueTy = cast(op.getValue().getType()); auto loc = op.getLoc(); if (op.getMask()) return rewriter.notifyMatchFailure( @@ -526,12 +525,12 @@ class TritonTensorPtrStoreOpConversion SmallVector permutations = getPermutationFromOrder(tracker.getOrder()); auto dimInfos = getDimInfos(tracker.getStrides(), valueTy.getShape()); - auto sizes = getActualSizes(loc, op.getBoundaryCheck(), valueTy.getShape(), - tracker, rewriter); + auto ptrInfo = getPtrInfo(loc, op.getBoundaryCheck(), valueTy.getShape(), + tracker, rewriter); auto originalMemRef = - getMemRef(rewriter.getRemappedValue(tracker.getBase()), - tracker.getOffsets(), sizes, tracker.getStrides(), - permutations, dimInfos, valueTy.getElementType(), rewriter, + getMemRef(rewriter.getRemappedValue(tracker.getBase()), ptrInfo.offsets, + ptrInfo.sizes, tracker.getStrides(), permutations, dimInfos, + valueTy.getElementType(), rewriter, getCacheModeAttr(op.getContext(), op.getCache())); auto value = op.getValue(); auto zeroAttr = rewriter.getIndexAttr(0); @@ -539,11 +538,13 @@ class TritonTensorPtrStoreOpConversion value = transformInputWithTransposeAndDimInfo(value, permutations, dimInfos, defaultOffsets, rewriter); if (!op.getBoundaryCheck().empty()) { - auto rank = value.getType().cast().getRank(); + auto rank = cast(value.getType()).getRank(); value = rewriter.create( - loc, value, SmallVector(rank, rewriter.getIndexAttr(0)), - permutateAndRemoveBroadcastDims(sizes, permutations, - dimInfos), + loc, value, + permutateAndRemoveBroadcastDims(ptrInfo.padLeftSizes, + permutations, dimInfos), + permutateAndRemoveBroadcastDims(ptrInfo.sizes, + permutations, dimInfos), SmallVector(rank, rewriter.getIndexAttr(1))); } auto materializeOp = diff --git a/lib/Conversion/TritonToLinalg/TritonPointerConversion.cpp b/lib/Conversion/TritonToLinalg/TritonPointerConversion.cpp index 7bf2e8c..ddd1c60 100644 --- a/lib/Conversion/TritonToLinalg/TritonPointerConversion.cpp +++ b/lib/Conversion/TritonToLinalg/TritonPointerConversion.cpp @@ -63,19 +63,19 @@ Value triton::selectByMask(Location loc, Value mask, Value trueVal, Value falseVal, ConversionPatternRewriter &rewriter) { assert(trueVal 
&& "Get true value failed."); - auto trueType = trueVal.getType().dyn_cast(); + auto trueType = dyn_cast(trueVal.getType()); if (!mask || !falseVal || !trueType) return trueVal; auto falseType = falseVal.getType(); - if (!falseType.isa()) { + if (!isa(falseType)) { Value falseValInit = rewriter.create(loc, trueType.getShape(), falseType); falseVal = rewriter.create(loc, falseVal, ValueRange{falseValInit}) .getResult(0); } - auto resType = falseType.template cast(); + auto resType = cast(falseType); auto initDims = triton::getDims(rewriter, loc, falseVal); Value initTensor = rewriter.create(loc, initDims, resType.getElementType()); @@ -94,7 +94,7 @@ Value triton::flattenValueToMatchGatherScatter( if (!value) return value; - auto valueTy = value.getType().cast(); + auto valueTy = cast(value.getType()); auto loc = value.getLoc(); auto rank = valueTy.getRank(); @@ -132,10 +132,10 @@ Value triton::reshapeGatherScatterValueTo(Value value, RankedTensorType resultTy, ConversionPatternRewriter &rewriter) { assert(value); - auto valueTy = value.getType().cast(); + auto valueTy = cast(value.getType()); auto loc = value.getLoc(); auto dstRank = resultTy.getRank(); - auto srcRank = value.getType().cast().getRank(); + auto srcRank = cast(value.getType()).getRank(); if (dstRank == 0) { // Zero rank. @@ -225,7 +225,7 @@ Value TritonPtrConversionBase::transformResultWithTransposeAndDimInfo( Value value, ArrayRef permutations, ArrayRef dimInfos, ArrayRef actualSizes, ConversionPatternRewriter &rewriter) const { - auto valueTy = value.getType().cast(); + auto valueTy = cast(value.getType()); auto loc = value.getLoc(); // As the shape of value has been transposed before broadcasted by dimInfos, @@ -286,7 +286,7 @@ Value TritonPtrConversionBase::transformResultWithTransposeAndDimInfo( Value TritonPtrConversionBase::transformInputWithTransposeAndDimInfo( Value value, ArrayRef permutations, ArrayRef dimInfos, ArrayRef offsets, ConversionPatternRewriter &rewriter) const { - auto valueTy = value.getType().cast(); + auto valueTy = cast(value.getType()); assert((!ShapedType::isDynamicShape(valueTy.getShape())) && "value shape should be static"); auto loc = value.getLoc(); @@ -308,7 +308,7 @@ Value TritonPtrConversionBase::transformInputWithTransposeAndDimInfo( if (!isConsecutive(permutations)) { Value init = rewriter.create( loc, - getValuesByPerms(ret.getType().cast().getShape(), + getValuesByPerms(cast(ret.getType()).getShape(), permutations), valueTy.getElementType()); ret = rewriter.create(loc, ret, init, permutations) @@ -342,7 +342,7 @@ Value TritonPtrConversionBase::transformInputWithTransposeAndDimInfo( /// Deduce the type of the result to use for the canonicalized operation. 
RankedTensorType resultType = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - desiredResultRank, ret.getType().cast(), + desiredResultRank, cast(ret.getType()), transposedOffsets, newSizes, strides); ret = rewriter .create( @@ -381,9 +381,9 @@ SmallVector TritonPtrLoadStoreOpConversionBase::getDimInfos( SmallVector dimInfos; dimInfos.reserve(tensorShape.size()); for (const auto &dim : llvm::enumerate(tensorShape)) { - if (axisInfo->isConstantDim(tensorShape, dim.index())) { + if (axisInfo->isFullConstantDim(tensorShape, dim.index())) { dimInfos.push_back({1, dim.value(), DimInfo::Kind::BROADCAST}); - } else if (axisInfo->isStrideDim(tensorShape, dim.index())) { + } else if (axisInfo->isFullStrideDim(tensorShape, dim.index())) { dimInfos.push_back({dim.value(), 1, DimInfo::Kind::CONTIG}); } else { dimInfos.push_back({dim.value()}); @@ -401,8 +401,8 @@ SmallVector TritonPtrLoadStoreOpConversionBase::getPermutations( llvm::to_vector<2>(llvm::seq(0, rank)); for (int64_t i = rank - 2; i >= 0; i--) { - if (!axisInfo->isContiguousDim(tensorShape, rank - 1) && - axisInfo->isContiguousDim(tensorShape, i) && tensorShape[i] != 1) { + if (!axisInfo->isFullContiguousDim(tensorShape, rank - 1) && + axisInfo->isFullContiguousDim(tensorShape, i) && tensorShape[i] != 1) { std::swap(permutations[i], permutations[rank - 1]); break; } @@ -590,8 +590,8 @@ Value TritonPtrScatterConversionBase::getDynamicMemRef( //===----------------------------------------------------------------------===// // TritonTensorPtrLoadStoreOpConversionBase //===----------------------------------------------------------------------===// -SmallVector -TritonTensorPtrLoadStoreOpConversionBase::getActualSizes( +TritonTensorPtrLoadStoreOpConversionBase::PtrInfo +TritonTensorPtrLoadStoreOpConversionBase::getPtrInfo( Location loc, std::optional> boundaryCheck, ArrayRef tensorShape, const TensorPointerMetaInfoTracker &tracker, ConversionPatternRewriter &rewriter) const { @@ -599,14 +599,27 @@ TritonTensorPtrLoadStoreOpConversionBase::getActualSizes( llvm::map_range(tensorShape, [&rewriter](int64_t dim) -> OpFoldResult { return rewriter.getIndexAttr(dim); })); + SmallVector offsets(tracker.getOffsets().begin(), + tracker.getOffsets().end()); + SmallVector padLeftSizes(tensorShape.size(), + rewriter.getIndexAttr(0)); if (boundaryCheck) { for (auto i : boundaryCheck.value()) { - OpFoldResult remainSize = subOFRs(tracker.getSizes()[i], - tracker.getOffsets()[i], loc, rewriter); - blockSizes[i] = minOFRs(remainSize, blockSizes[i], loc, rewriter); + auto originOffset = tracker.getOffsets()[i]; + offsets[i] = + maxOFRs(originOffset, rewriter.getIndexAttr(0), loc, rewriter); + padLeftSizes[i] = + minOFRs(subOFRs(offsets[i], originOffset, loc, rewriter), + blockSizes[i], loc, rewriter); + OpFoldResult remainSize = + subOFRs(tracker.getSizes()[i], offsets[i], loc, rewriter); + remainSize = maxOFRs(remainSize, rewriter.getIndexAttr(0), loc, rewriter); + blockSizes[i] = minOFRs( + remainSize, subOFRs(blockSizes[i], padLeftSizes[i], loc, rewriter), + loc, rewriter); } } - return blockSizes; + return {offsets, padLeftSizes, blockSizes}; } SmallVector TritonTensorPtrLoadStoreOpConversionBase::getDimInfos( diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp index 8038fb1..a1dd75b 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -67,10 +67,11 @@ #include "triton-linalg/Dialect/Auxiliary/IR/AuxiliaryDialect.h" 
#include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "triton-linalg/Dialect/LinalgExt/Utils/Utils.h" -#include "triton-linalg/Dialect/MathExt/IR/Math.h" // IWYU pragma: keep +#include "triton-linalg/Dialect/MathExt/IR/MathExt.h" // IWYU pragma: keep #include "triton-linalg/Dialect/Triton/Utils/MaskTracker.h" #include "triton-linalg/Dialect/Utils/ArithUtils.h" #include "triton-linalg/Dialect/Utils/Conventions.h" +#include "triton-linalg/Dialect/Utils/ShapeUtils.h" #include "triton-linalg/Utils/Utils.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -127,7 +128,7 @@ getBroadcastDimensions(ArrayRef srcShape, ArrayRef dstShape) { static Value sliceFirst(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t dim, bool reverse = false) { - ShapedType inputType = input.getType().cast(); + ShapedType inputType = cast(input.getType()); auto sizes = llvm::to_vector(llvm::map_range(inputType.getShape(), [&](int64_t t) { return OpFoldResult(rewriter.getI64IntegerAttr(t)); @@ -148,7 +149,7 @@ static Value sliceFirst(ConversionPatternRewriter &rewriter, Location loc, static Value sliceRemaining(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t dim, bool reverse = false) { - ShapedType inputType = input.getType().cast(); + ShapedType inputType = cast(input.getType()); auto sizes = llvm::to_vector(llvm::map_range(inputType.getShape(), [&](int64_t t) { return OpFoldResult(rewriter.getI64IntegerAttr(t)); @@ -167,12 +168,40 @@ static Value sliceRemaining(ConversionPatternRewriter &rewriter, Location loc, strides); } +static Value getInitConstValue(ReductionMode mode, ShapedType type, + Location &loc, PatternRewriter &rewriter) { + auto elementType = type.getElementType(); + switch (mode) { + case ReductionMode::ARGMAX: + if (isa(elementType)) { + return arith::getIdentityValue(arith::AtomicRMWKind::maximumf, + elementType, rewriter, loc); + } else if (elementType.isIntOrIndex()) { + return rewriter.create( + loc, elementType, rewriter.getIntegerAttr(elementType, -1)); + } + break; + case ReductionMode::ARGMIN: + if (isa(elementType)) { + return arith::getIdentityValue(arith::AtomicRMWKind::minimumf, + elementType, rewriter, loc); + } else if (elementType.isIntOrIndex()) { + return rewriter.create( + loc, elementType, rewriter.getIntegerAttr(elementType, -1)); + } + break; + default: + break; + } + return nullptr; +} + /// Create PrefixAttr for PrintOp. FailureOr createPrefixAttr(StringAttr prefixAttr, Value operand, bool hex, triton::PrintOp op, PatternRewriter &rewriter) { auto oriOperandType = getElementTypeOrSelf(operand.getType()); - if (oriOperandType.isa()) { + if (isa(oriOperandType)) { return rewriter.getStringAttr(prefixAttr.getValue() + Twine("%p")); } @@ -234,16 +263,16 @@ struct TritonBroadcastPattern if (!type) return failure(); - auto resultTy = type.cast(); + auto resultTy = cast(type); auto loc = op.getLoc(); // tt.broadcast with input of scalar has been converted to tt.splat, // no need to deal with scalar case here, just return. 
- if (!op.getSrc().getType().isa()) { + if (!isa(op.getSrc().getType())) { return failure(); } - ShapedType operandTy = op.getSrc().getType().cast(); + ShapedType operandTy = cast(op.getSrc().getType()); assert(operandTy.getRank() == resultTy.getRank() && "rank of source and destination should match"); @@ -282,7 +311,7 @@ struct TritonSplatPattern : public OpConversionPattern { auto type = typeConverter->convertType(op.getResult().getType()); if (!type) return failure(); - auto resultTy = type.cast(); + auto resultTy = cast(type); auto initOp = rewriter.create( op.getLoc(), resultTy.getShape(), resultTy.getElementType()); @@ -302,9 +331,9 @@ struct TritonExpandDimPattern auto type = typeConverter->convertType(op.getResult().getType()); if (!type) return failure(); - auto resultTy = type.cast(); + auto resultTy = cast(type); - ShapedType operandTy = op.getSrc().getType().cast(); + ShapedType operandTy = cast(op.getSrc().getType()); SmallVector reassociationMap; if (!createReassociationMaps(rewriter, resultTy.getShape(), @@ -327,9 +356,9 @@ struct TritonViewPattern : public OpConversionPattern { matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operand = adaptor.getOperands()[0]; - auto operandType = operand.getType().cast(); + auto operandType = cast(operand.getType()); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); // Special case where the result is a 0-d tensor. if (resultType.getRank() == 0) { @@ -459,7 +488,7 @@ struct TritonAddPtrPattern : public OpConversionPattern { }; Location loc = op.getLoc(); - auto resultTy = type.dyn_cast(); + auto resultTy = dyn_cast(type); if (!resultTy) { // Handle addptr for scalar. auto ret = createAdd(adaptor.getPtr(), adaptor.getOffset(), loc, type); @@ -493,7 +522,7 @@ struct TritonMakeRangePattern auto type = typeConverter->convertType(op.getResult().getType()); if (!type) return failure(); - auto resultTy = type.cast(); + auto resultTy = cast(type); auto initOp = rewriter.create(loc, resultTy.getShape(), resultTy.getElementType()); @@ -542,7 +571,7 @@ struct TritonBitcastPattern : public OpConversionPattern { auto type = typeConverter->convertType(op.getResult().getType()); if (!type) return failure(); - auto resultTy = type.dyn_cast(); + auto resultTy = dyn_cast(type); // Scalar case. if (!resultTy) { rewriter.replaceOpWithNewOp(op, type, adaptor.getSrc()); @@ -577,14 +606,14 @@ struct TritonReducePattern : public OpConversionPattern { // types. auto convertedInputTensorTypes = llvm::map_range(adaptor.getOperands().getTypes(), - [](Type t) { return t.cast(); }); + [](Type t) { return cast(t); }); assert(llvm::all_equal(llvm::map_range( convertedInputTensorTypes, [](TensorType t) { return t.getShape(); }))); static_cast(convertedInputTensorTypes); auto originalResultTensorTypes = llvm::map_range(op.getResultTypes(), [](Type t) -> TensorType { - if (auto tensorType = t.dyn_cast()) + if (auto tensorType = dyn_cast(t)) return tensorType; return RankedTensorType::get({}, t); }); @@ -599,31 +628,59 @@ struct TritonReducePattern : public OpConversionPattern { llvm::SmallVector initVals; // As we need to analyze the body of the reduce op to get the init value, - // currently we only support single paylod op. Otherwise, We use a portion - // of the input as the initial value for the output. + // currently we only support a single payload op and argmax/argmin ops. + // Otherwise, we use a portion of the input as the initial value for + // the output. + auto mode = reducePatternRecognition(op); do { - if (op.getNumResults() == 1) { - Operation *payloadOp = - triton::linalg_ext::findPayloadOp(&op.getCombineOp().front()); - if (!payloadOp) - break; - std::optional fillValAttr = - arith::getNeutralElement(payloadOp); - // When the requirements are not met, go to the later general - // implementation. - if (!fillValAttr.has_value()) - break; - Value fillVal = - rewriter.create(loc, fillValAttr.value()); - // Create empty vectors as init values. - for (TensorType t : convertedResultTensorTypes) { - auto initOp = rewriter.create(loc, t.getShape(), - t.getElementType()); - auto fillOp = - rewriter.create(loc, fillVal, initOp.getResult()); - initVals.push_back(fillOp.getResult(0)); + if (mode.has_value()) { + // Handle the single-payload-op case. + if (op.getNumResults() == 1) { + Operation *payloadOp = + triton::linalg_ext::findPayloadOp(&op.getCombineOp().front()); + if (!payloadOp) + break; + std::optional fillValAttr = + arith::getNeutralElement(payloadOp); + // When the requirements are not met, go to the later general + // implementation. + if (!fillValAttr.has_value()) + break; + Value initVal = + rewriter.create(loc, fillValAttr.value()); + // Create empty tensors as init values. + for (TensorType t : convertedResultTensorTypes) { + auto initOp = rewriter.create(loc, t.getShape(), + t.getElementType()); + Value fillInitValue = + rewriter + .create(loc, initVal, initOp.getResult()) + .getResult(0); + initVals.push_back(fillInitValue); + } + } else if (mode == ReductionMode::ARGMAX || + mode == ReductionMode::ARGMIN) { + // Handle the argmax/argmin case. + bool canGetInit = true; + for (auto initTy : convertedResultTensorTypes) { + auto initVal = getInitConstValue(*mode, initTy, loc, rewriter); + if (!initVal) { + canGetInit = false; + break; + } + auto initOp = rewriter.create( + loc, initTy.getShape(), initTy.getElementType()); + Value fillInitVal = + rewriter + .create(loc, initVal, initOp.getResult()) + .getResult(0); + initVals.push_back(fillInitVal); + } + if (!canGetInit) { + initVals.clear(); + break; + } } - // Create a linalg.reduce on the same input and move the combine region // there. (ReduceReturnOpConversion will take care of the terminator.) auto reduceOp = rewriter.create( @@ -642,14 +699,12 @@ struct TritonReducePattern : public OpConversionPattern { // Otherwise, the result has to be a scalar, so we need to extract the // scalar from the 0-ranked result tensor.
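// For example (illustrative): reducing tensor<8xi32> over its only axis yields a 0-ranked tensor<i32>, and `tensor.extract %t[] : tensor<i32>` then recovers the scalar i32 each result expects.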
SmallVector results; - Value scalar = rewriter.create( - loc, - SmallVector(convertedResultTensorTypes) - .begin() - ->dyn_cast() - .getElementType(), - reduceOp->getResults()[0], /*indices=*/ValueRange{}); - results.push_back(scalar); + for (auto [tensor, type] : + llvm::zip(reduceOp->getResults(), convertedResultTensorTypes)) { + Value scalar = rewriter.create( + loc, type.getElementType(), tensor, /*indices=*/ValueRange{}); + results.push_back(scalar); + } rewriter.replaceOp(op, results); return success(); @@ -667,7 +722,7 @@ struct TritonReducePattern : public OpConversionPattern { "tt.reduce requires the same input number and init number"); for (auto [inputVal, initTy] : llvm::zip(adaptor.getOperands(), convertedResultTensorTypes)) { - ShapedType inputTy = inputVal.getType().cast(); + ShapedType inputTy = cast(inputVal.getType()); ArrayRef inputShape = inputTy.getShape(); // If the size of reduce axis is 1, we will replace init operands by input @@ -690,7 +745,7 @@ struct TritonReducePattern : public OpConversionPattern { // operands' init value. { Value slice = sliceFirst(rewriter, loc, inputVal, op.getAxis()); - auto sliceShape = slice.getType().cast().getShape(); + auto sliceShape = cast(slice.getType()).getShape(); // Resize slice value's shape by init operand. SmallVector reassociationMap; @@ -767,7 +822,7 @@ struct TritonPureExternElementwisePattern ConversionPatternRewriter &rewriter) const override { assert(op.getPure()); Location loc = op.getLoc(); - if (auto resultTy = op.getType().dyn_cast()) { + if (auto resultTy = dyn_cast(op.getType())) { auto initOp = rewriter.create(loc, resultTy.getShape(), resultTy.getElementType()); rewriter.replaceOpWithNewOp( @@ -833,6 +888,78 @@ class PtrSelectOpPattern : public OpConversionPattern { } }; +class PtrExtractOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + PtrExtractOpPattern(triton::TritonLinalgTypeConverter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context) {} + + LogicalResult + matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getTensor(), + op.getIndices()); + return success(); + } +}; + +class PtrExtractSliceOpPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + PtrExtractSliceOpPattern(triton::TritonLinalgTypeConverter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context) {} + + LogicalResult + matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.getSource(), op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()); + return success(); + } +}; + +class PtrExpandShapeOpPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + PtrExpandShapeOpPattern(triton::TritonLinalgTypeConverter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context) {} + + LogicalResult + matchAndRewrite(tensor::ExpandShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getType(), adaptor.getSrc(), op.getReassociationIndices()); + return success(); + } +}; + +class PtrCollapseShapeOpPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + 
PtrCollapseShapeOpPattern(triton::TritonLinalgTypeConverter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context) {} + + LogicalResult + matchAndRewrite(tensor::CollapseShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.getSrc(), op.getReassociationIndices()); + return success(); + } +}; + struct GPUBarrierOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -851,19 +978,16 @@ struct TritonTransPattern : public OpConversionPattern { matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - RankedTensorType srcTy = - adaptor.getSrc().getType().cast(); - auto rank = srcTy.getRank(); + RankedTensorType resTy = cast(op.getResult().getType()); + auto rank = resTy.getRank(); if (rank <= 1) { rewriter.replaceOp(op, adaptor.getSrc()); return success(); } SmallVector permutation(op.getOrder()); - SmallVector retShape(srcTy.getShape().rbegin(), - srcTy.getShape().rend()); - auto initOp = - rewriter.create(loc, retShape, srcTy.getElementType()); + auto initOp = rewriter.create(loc, resTy.getShape(), + resTy.getElementType()); rewriter.replaceOpWithNewOp(op, adaptor.getSrc(), initOp, permutation); @@ -917,7 +1041,7 @@ class TritonFuncOpPattern : public OpConversionPattern { LogicalResult matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - FunctionType type = op.getFunctionType().cast(); + FunctionType type = cast(op.getFunctionType()); auto *converter = getTypeConverter(); // Convert the original function types. TypeConverter::SignatureConversion result(type.getNumInputs()); @@ -957,24 +1081,32 @@ class ArithCmpIToFillPattern : public OpConversionPattern { return failure(); auto tensorType = - op.getResult().getType().dyn_cast_or_null(); + dyn_cast_or_null(op.getResult().getType()); if (!tensorType) return failure(); - Value init = rewriter.create(loc, tracker.getSizes(), - tensorType.getElementType()); - Value trueVal = - rewriter.create(loc, rewriter.getBoolAttr(true)); - Value one = - rewriter.create(loc, trueVal, init).getResult(0); - + auto hasZeroSize = llvm::any_of(tracker.getSizes(), [](const auto &size) { + return isConstantIntValue(size, 0); + }); + Value value; Value falseVal = rewriter.create(loc, rewriter.getBoolAttr(false)); - - SmallVector offsets = llvm::to_vector(tracker.getStarts()); - // Replace with pad op. - auto value = getPadOrInsertOpWithOther( - loc, falseVal, tensorType, one, offsets, tracker.getSizes(), rewriter); + if (!hasZeroSize) { + Value init = rewriter.create( + loc, tracker.getSizes(), tensorType.getElementType()); + Value trueVal = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value one = + rewriter.create(loc, trueVal, init).getResult(0); + SmallVector offsets = llvm::to_vector(tracker.getStarts()); + // Replace with pad op. 
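+ // Put concretely (an illustrative sketch): if the tracker derives start 2 and size 3 for a tensor<8xi1> result, the fill above produces a tensor<3xi1> of true, and the pad below embeds it at offset 2 into an otherwise all-false tensor<8xi1>.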
+ value = getPadOrInsertOpWithOther(loc, falseVal, tensorType, one, offsets, + tracker.getSizes(), rewriter); + } else { + Value init = rewriter.create( + loc, tensorType.getShape(), tensorType.getElementType()); + value = rewriter.create(loc, falseVal, init).getResult(0); + } rewriter.replaceOp(op, value); return success(); } @@ -994,7 +1126,7 @@ class ArithSelectConversionPattern auto cond = op.getCondition(); Value trueValue = op.getTrueValue(); Value falseValue = op.getFalseValue(); - if (!cond.getType().dyn_cast_or_null()) + if (!dyn_cast_or_null(cond.getType())) return failure(); triton::MaskTracker tracker; @@ -1017,7 +1149,7 @@ class ArithSelectConversionPattern tracker = operandTracker; } - auto srcType = trueValue.getType().dyn_cast_or_null(); + auto srcType = dyn_cast_or_null(trueValue.getType()); if (!srcType) return failure(); auto rank = srcType.getRank(); @@ -1091,7 +1223,7 @@ class TritonPrintPattern : public OpConversionPattern { if (!operandType) return failure(); - auto resultTy = operandType.dyn_cast(); + auto resultTy = dyn_cast(operandType); if (!resultTy) { rewriter.create(loc, operand, *prefixAttr); } else { @@ -1124,15 +1256,13 @@ class TritonAssertOpPattern : public OpConversionPattern { auto assertMessage = llvm::formatv("{0}:{1}: {2} Assertion `{3}` failed", op.getFile(), op.getLine(), op.getFunc(), op.getMessage()); - auto resultTy = valType.cast(); + auto rankType = cast(valType); // Only supports int type. - // follow: - // http://gitlab.software.cambricon.com/neuware/triton/-/blob/main-llvm-17/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp#L268 - assert(resultTy.getElementType().isa() && + assert(isa(rankType.getElementType()) && "Only support int tensor for assert"); - rewriter.create(op.getLoc(), resultTy, + rewriter.create(op.getLoc(), rankType, condVal, assertMessage.str()); rewriter.eraseOp(op); @@ -1150,9 +1280,7 @@ struct TritonScanPattern : public OpConversionPattern { // If the size of the scan axis is 1, we just replace the op with its // input operands. - if (adaptor.getOperands()[0] - .getType() - .cast() + if (cast(adaptor.getOperands()[0].getType()) .getShape()[op.getAxis()] <= 1) { rewriter.replaceOp(op, adaptor.getOperands()); return success(); @@ -1160,7 +1288,7 @@ struct TritonScanPattern : public OpConversionPattern { auto convertedInputTensorTypes = llvm::map_range(adaptor.getOperands().getTypes(), - [](Type t) { return t.cast(); }); + [](Type t) { return cast(t); }); assert(llvm::all_equal(llvm::map_range( convertedInputTensorTypes, [](TensorType t) { return t.getShape(); }))); static_cast(convertedInputTensorTypes); @@ -1172,7 +1300,7 @@ struct TritonScanPattern : public OpConversionPattern { assert(adaptor.getOperands().size() == 1 && "tt.scan only supports a single input now"); for (auto inputVal : adaptor.getOperands()) { - RankedTensorType inputTy = inputVal.getType().cast(); + RankedTensorType inputTy = cast(inputVal.getType()); int64_t rank = inputTy.getRank(); // 1. Slice the remaining elements of input operands.
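// Roughly, as this pattern appears to work (a hedged reading): for axis 0 of a tensor<4x8xf32> input with reverse = false, the first row becomes the scan init, the remaining rows are scanned by linalg_ext.scan, and the scan result is inserted back next to the first row to rebuild the full tensor<4x8xf32>.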
@@ -1180,7 +1308,7 @@ struct TritonScanPattern : public OpConversionPattern { Value slice = sliceRemaining(rewriter, loc, inputVal, op.getAxis(), op.getReverse()); // Create output tensor - auto sliceShape = slice.getType().cast().getShape(); + auto sliceShape = cast(slice.getType()).getShape(); Value empty = rewriter.create( loc, sliceShape, inputTy.getElementType()); inputVals.push_back(slice); @@ -1193,7 +1321,7 @@ struct TritonScanPattern : public OpConversionPattern { Value slice = sliceFirst(rewriter, loc, inputVal, op.getAxis(), op.getReverse()); SmallVector collapseDstShape; - ShapedType sliceTy = slice.getType().cast(); + ShapedType sliceTy = cast(slice.getType()); for (int64_t i = 0; i < rank; ++i) { if (i != op.getAxis()) { collapseDstShape.push_back(sliceTy.getShape()[i]); @@ -1212,7 +1340,7 @@ struct TritonScanPattern : public OpConversionPattern { // Create a linalg_ext.scan on the same input and move the combine region // there. (ScanReturnOpConversion will take care of the terminator.) auto resultTypes = llvm::map_range( - initVals, [](Value t) { return t.getType().cast(); }); + initVals, [](Value t) { return cast(t.getType()); }); auto scanOp = rewriter.create( loc, /*resultTypes=*/SmallVector(resultTypes), @@ -1227,7 +1355,7 @@ struct TritonScanPattern : public OpConversionPattern { // Insert linalg_ext.scan result into input operand. // Retrieve insert sizes of result tensor. - RankedTensorType initType = initVals[0].getType().cast(); + RankedTensorType initType = cast(initVals[0].getType()); ArrayRef initShape = initType.getShape(); int64_t rank = initType.getRank(); auto insertSizes = @@ -1269,6 +1397,72 @@ struct TritonScanReturnPattern } }; +/// Convert a `tt.cat` operation to `tensor.insert_slice` +/// operations. +/// +/// Concatenate two tensors along the highest dimension. +/// The two input tensors must have the same shape. +/// +/// ```mlir +/// %0 = tt.cat %arg0, %arg1 : tensor<32xf32> -> tensor<64xf32> +/// ``` +/// +/// converts to: +/// +/// ```mlir +/// %0 = tensor.empty() : tensor<64xf32> +/// %1 = tensor.insert_slice %arg0 into %0[0] [32] [1] : tensor<32xf32> into +/// tensor<64xf32> +/// %c32_0 = arith.constant 32 : index +/// %2 = tensor.insert_slice %arg1 into %1[%c32_0] [32] [1] : +/// tensor<32xf32> into tensor<64xf32> +/// ``` +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = typeConverter->convertType(op.getResult().getType()); + if (!type) + return failure(); + + auto resultTy = cast(type); + + Location loc = op.getLoc(); + Value init = rewriter.create(loc, resultTy.getShape(), + resultTy.getElementType()); + + auto rank = resultTy.getRank(); + // Insert slice params. + auto zero = rewriter.getIndexAttr(0); + auto one = rewriter.getIndexAttr(1); + SmallVector offsets(rank, zero); + SmallVector strides(rank, one); + SmallVector sizes; + + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + // Consider a 0-rank tensor as a tensor with one element. + if (cast(lhs.getType()).getRank() == 0) { + sizes = {one}; + } else { + sizes = getDims(rewriter, loc, lhs); + } + + auto firstInsert = rewriter.createOrFold( + loc, lhs, init, offsets, sizes, strides); + // The tt.cat op always concatenates two tensors along the highest dimension.
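+ // Concretely (a sketch): for two tensor<32xf32> operands, sizes[0] is 32, + // so the second insert below lands at offset 32 of the tensor<64xf32> init, + // completing the concatenation.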
+ offsets[0] = rewriter.createOrFold( + loc, materializeOpFoldResult(rewriter, loc, offsets[0]), + materializeOpFoldResult(rewriter, loc, sizes[0])); + auto secondInsert = rewriter.createOrFold( + loc, rhs, firstInsert, offsets, sizes, strides); + rewriter.replaceOp(op, secondInsert); + return success(); + } +}; + /// Convert an `tt.join` operation to `tensor.insert_slice` /// operation. /// @@ -1284,8 +1478,9 @@ struct TritonScanReturnPattern /// ```mlir /// %0 = tensor.empty() : tensor<2xf32> /// %1 = tensor.insert_slice %arg0 into %0[0] [1] [1] : tensor into -/// tensor<2xf32> %2 = tensor.insert_slice %arg1 into %1[1] [1] [1] : -/// tensor into tensor<2xf32> +/// tensor<2xf32> +/// %2 = tensor.insert_slice %arg1 into %1[1] [1] [1] : tensor into +/// tensor<2xf32> /// ``` struct TritonJoinPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1295,9 +1490,9 @@ struct TritonJoinPattern : public OpConversionPattern { Value lhs = op.getOperand(0); Value rhs = op.getOperand(1); - auto lhsType = lhs.getType().cast(); + auto lhsType = cast(lhs.getType()); auto shape = lhsType.getShape(); - auto resultType = op.getResult().getType().cast(); + auto resultType = cast(op.getResult().getType()); Value emptyOp = rewriter.create( op.getLoc(), resultType.getShape(), lhsType.getElementType()); @@ -1343,9 +1538,9 @@ struct TritonSplitPattern : public OpConversionPattern { matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value inputs = op.getSrc(); - auto inputType = inputs.getType().cast(); + auto inputType = cast(inputs.getType()); auto outLhs = op.getOutLHS(); - auto outLhsType = outLhs.getType().cast(); + auto outLhsType = cast(outLhs.getType()); int64_t rank = inputType.getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); @@ -1427,54 +1622,6 @@ struct TritonMulhiuiPattern : public OpConversionPattern { } }; -/// Convert an `tt.histogram` operation to `arith.subi`, `arith.cmpi`, -/// `arith.andi`, `scf.for`, `scf.if` operation. -/// -/// Compute a histogram based on the input tensor with num_bins bins. -/// The process consists of the following steps: -/// 1. Set the minimum value (min_val) to 0 and the maximum value (max_val) to -/// num_bins - 1. -/// 2. Create a zero tensor of length num_bins to store the count for each bin. -/// 3. Compute the histogram: -/// 1) Iterate through each value in the input tensor. -/// 2) If the value is between min_val and max_val (inclusive), -/// calculate its corresponding bin index and increment the count for that -/// bin. -/// -/// ```mlir -/// %1 = tt.histogram %0 : tensor<8xi32> -> tensor<2xi32> -/// ``` -/// -/// converts to: -/// -/// ```mlir -/// %c0_i32 = arith.constant 0 : i32 -/// %c1_i32 = arith.constant 1 : i32 -/// %c2_i32 = arith.constant 2 : i32 -/// %0 = arith.subi %c2_i32, %c1_i32 : i32 -/// %1 = tensor.empty() : tensor<2xi32> -/// %c0_i32_0 = arith.constant 0 : i32 -/// %2 = linalg.fill ins(%c0_i32_0 : i32) -/// outs(%1 : tensor<2xi32>) -> tensor<2xi32> -/// %3 = scf.for ... 
{ -/// %extracted = tensor.extract %arg0[%arg1] : tensor<8xi32> -/// %4 = arith.cmpi sle, %c0_i32, %extracted : i32 -/// %5 = arith.cmpi sge, %0, %extracted : i32 -/// %6 = arith.andi %4, %5 : i1 -/// %7 = scf.if %6 -> (tensor<2xi32>) { -/// %8 = arith.subi %extracted, %c0_i32 : i32 -/// %9 = arith.index_cast %8 : i32 to index -/// %extracted_1 = tensor.extract %arg2[%9] : tensor<2xi32> -/// %c1_i32_2 = arith.constant 1 : i32 -/// %10 = arith.addi %extracted_1, %c1_i32_2 : i32 -/// %inserted = tensor.insert %10 into %arg2[%9] : tensor<2xi32> -/// scf.yield %inserted : tensor<2xi32> -/// } else { -/// scf.yield %arg2 : tensor<2xi32> -/// } -/// scf.yield %7 : tensor<2xi32> -/// } -/// ``` struct TritonHistogramPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1484,109 +1631,31 @@ struct TritonHistogramPattern ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = adaptor.getSrc(); + Value result = op.getResult(); - auto inputTy = input.getType().cast(); - auto resultTy = op.getResult().getType().dyn_cast(); - if (!resultTy || !inputTy) + auto resultTy = dyn_cast(result.getType()); + if (!resultTy) return failure(); - auto inputEleTy = inputTy.getElementType(); - assert(inputEleTy.isa() && "expected integer type"); - - // Get the number of bins from the first dimension size of the result - // tensor. assert(!resultTy.isDynamicDim(0) && "expected static dim"); - int numBins = resultTy.getDimSize(0); - - // Create a constant operation representing the minimum value (0). - Value minVal = rewriter.create( - loc, rewriter.getZeroAttr(inputEleTy)); - - // Compute the maximum value (numBins - 1). - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(inputEleTy, 1)); - Value numBinsConstant = rewriter.create( - loc, rewriter.getIntegerAttr(inputEleTy, numBins)); - Value maxVal = rewriter.create(loc, numBinsConstant, one); - - // Initialize the histogram tensor with zeros. - auto histoInit = rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType()); - auto zeroElem = rewriter.create( - loc, rewriter.getZeroAttr(resultTy.getElementType())); - Value histo = rewriter - .create(loc, ValueRange{zeroElem}, - ValueRange{histoInit}) - .result(); - - // Create a loop to iterate over each element in the input tensor. - auto inputSize = - rewriter.create(loc, inputTy.getShape()[0]); - auto zeroIndex = rewriter.create(loc, 0); - auto oneIndex = rewriter.create(loc, 1); - auto loop = - rewriter - .create( - loc, zeroIndex, inputSize, oneIndex, ValueRange{histo}, - [&](OpBuilder &b, Location nestedLoc, Value iv, - ValueRange iterArgs) { - // Extract the current value from the input tensor at the loop - // index. - Value currentIndexValue = - b.create(nestedLoc, input, iv); - - // Compare the current value with the min and max values. - Value cmpMin = b.create( - nestedLoc, arith::CmpIPredicate::sle, minVal, - currentIndexValue); - Value cmpMax = b.create( - nestedLoc, arith::CmpIPredicate::sge, maxVal, - currentIndexValue); - // Check if the current value is within the range [minVal, - // maxVal]. - Value cond = - b.create(nestedLoc, cmpMin, cmpMax); - - // Create an if-else block to update the histogram if the - // condition is met. - auto ifOp = rewriter.create( - loc, cond, - [&](OpBuilder &builder, Location ifLoc) { - // Calculate the histogram bin index for the current - // value. 
-                    Value idx = builder.create<arith::SubIOp>(
-                        ifLoc, currentIndexValue, minVal);
-                    idx = b.create<arith::IndexCastOp>(
-                        ifLoc, b.getIndexType(), idx);
-                    // Extract the current histogram value at the calculated
-                    // index.
-                    Value histoValue = builder.create<tensor::ExtractOp>(
-                        ifLoc, iterArgs[0], idx);
-                    // Increment the histogram value by 1.
-                    Value one = rewriter.create<arith::ConstantOp>(
-                        ifLoc, rewriter.getIntegerAttr(
-                                   resultTy.getElementType(), 1));
-                    Value updateHistVal = builder.create<arith::AddIOp>(
-                        ifLoc, histoValue, one);
-                    // Insert the updated value back into the histogram
-                    // tensor.
-                    Value updatedHisto = builder.create<tensor::InsertOp>(
-                        ifLoc, updateHistVal, iterArgs[0], idx);
-                    builder.create<scf::YieldOp>(ifLoc, updatedHisto);
-                  },
-                  [&](OpBuilder &builder, Location elseLoc) {
-                    builder.create<scf::YieldOp>(elseLoc, iterArgs[0]);
-                  });
-              b.create<scf::YieldOp>(nestedLoc, ifOp.getResults());
-            })
-            .getResult(0);
-
-    rewriter.replaceOp(op, loop);
+
+    auto initOp = rewriter.create<tensor::EmptyOp>(loc, resultTy.getShape(),
+                                                   resultTy.getElementType());
+    rewriter.replaceOpWithNewOp<linalg_ext::HistogramOp>(
+        op, TypeRange{result.getType()}, ValueRange{input}, initOp);
+
     return success();
   }
 };
 } // namespace
 
+static void
+populateTritonFuncToFuncPatterns(RewritePatternSet &patterns,
+                                 triton::TritonLinalgTypeConverter &converter) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<TritonFuncOpPattern>(converter, context);
+}
+
 static void
 populateTritonToLinalgPatterns(RewritePatternSet &patterns,
                                triton::TritonLinalgTypeConverter &converter) {
@@ -1597,9 +1666,9 @@ populateTritonToLinalgPatterns(RewritePatternSet &patterns,
       TritonBitcastPattern, TritonReducePattern, TritonReduceReturnPattern,
       TritonPureExternElementwisePattern, TritonPtrToIntPattern,
       TritonIntToPtrPattern, TritonTransPattern, TritonReturnOpConversion,
-      TritonCallOpPattern, TritonFuncOpPattern, TritonViewPattern,
-      TritonPrintPattern, TritonAssertOpPattern, TritonScanPattern,
-      TritonScanReturnPattern, TritonJoinPattern, TritonMulhiuiPattern,
+      TritonCallOpPattern, TritonViewPattern, TritonPrintPattern,
+      TritonAssertOpPattern, TritonScanPattern, TritonScanReturnPattern,
+      TritonCatPattern, TritonJoinPattern, TritonMulhiuiPattern,
       TritonSplitPattern, TritonClampFOpPattern, TritonPreciseSqrtOpPattern,
       TritonPreciseDivFOpPattern, TritonHistogramPattern>(converter, context);
 }
@@ -1626,6 +1695,23 @@ void triton::TritonToLinalgPass::runOnOperation() {
   MLIRContext *context = &getContext();
   triton::TritonLinalgTypeConverter converter;
 
+  // Step 1: convert tt.func to func.func.
+  // FIXME: Starting from LLVM 19, if the parent op of an op is converted by
+  // the same set of conversion patterns, accessing the parent op from within
+  // the nested op may be invalid while the conversion is in flight. Since
+  // `tt.func` nests other `tt` dialect ops, the conversion of `tt.func` has
+  // to be separated from that of the other ops.
+  // We need to find a way to merge the two conversions.
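+  //
+  // A rough sketch of the resulting two-step flow (the kernel and body ops
+  // below are illustrative assumptions, not taken from this patch): after
+  // Step 1 only the function shell is rewritten and the body is untouched,
+  //
+  //   tt.func @kernel(...) { ... tt.load ... tt.store ... }
+  //
+  // becomes
+  //
+  //   func.func @kernel(...) { ... tt.load ... tt.store ... }
+  //
+  // and Step 2 then lowers the remaining `tt` body ops onto linalg.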
+ ConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + funcTarget.addIllegalOp(); + funcTarget.addLegalDialect(); + populateTritonFuncToFuncPatterns(funcPatterns, converter); + if (failed(applyPartialConversion(getOperation(), funcTarget, + std::move(funcPatterns)))) + return signalPassFailure(); + + // Step2: convert other ttir to linalgir ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalDialect(); @@ -1644,13 +1730,13 @@ void triton::TritonToLinalgPass::runOnOperation() { target.addDynamicallyLegalDialect( [&](Operation *op) { - return !op->getResultTypes().front().isa(); + return !isa(op->getResultTypes().front()); }); target.addDynamicallyLegalOp( [&](Operation *op) { return converter.isLegal(op); }); target.addDynamicallyLegalOp([&](Operation *op) { auto resType = op->getResultTypes().front(); - return !resType.isa() && converter.isLegal(op); + return !isa(resType) && converter.isLegal(op); }); target.addLegalOp(); @@ -1678,6 +1764,10 @@ void triton::TritonToLinalgPass::populatePatterns( patterns.add(converter, context); patterns.add(converter, context, 0); + patterns.add(converter, context, 0); + patterns.add(converter, context, 0); + patterns.add(converter, context, 0); + patterns.add(converter, context, 0); populateTritonToLinalgPatterns(patterns, converter); populateArithToLinalgPatterns(patterns); populateArithConversionPatterns(patterns); diff --git a/lib/Conversion/TritonToLinalg/Utils.cpp b/lib/Conversion/TritonToLinalg/Utils.cpp index 3ad1b40..28e82cb 100644 --- a/lib/Conversion/TritonToLinalg/Utils.cpp +++ b/lib/Conversion/TritonToLinalg/Utils.cpp @@ -40,7 +40,7 @@ Value mlir::triton::getPadOrInsertOpWithOther(Location loc, Value other, ArrayRef offsets, ArrayRef sizes, OpBuilder &rewriter) { - auto otherShapedType = otherType.cast(); + auto otherShapedType = cast(otherType); assert(otherShapedType.hasStaticShape() && "other val shape must be static."); Type elementType = otherShapedType.getElementType(); auto rank = otherShapedType.getRank(); diff --git a/lib/Conversion/TritonToTensor/CMakeLists.txt b/lib/Conversion/TritonToTensor/CMakeLists.txt deleted file mode 100644 index 41c1eff..0000000 --- a/lib/Conversion/TritonToTensor/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -add_triton_library(TritonToTensor - TritonToTensor.cpp - - DEPENDS - TritonLinalgConverisonIncGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/lib/Conversion/TritonToTensor/TritonToTensor.cpp b/lib/Conversion/TritonToTensor/TritonToTensor.cpp deleted file mode 100644 index c45664f..0000000 --- a/lib/Conversion/TritonToTensor/TritonToTensor.cpp +++ /dev/null @@ -1,105 +0,0 @@ -//===- TritonToTensor.cpp - Triton to Tensor dialect convension -*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. 
-// -//===----------------------------------------------------------------------===// -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinDialect.h" // IWYU pragma: keep -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton-linalg/Conversion/PassDetail.h" -#include "triton-linalg/Conversion/TritonToTensor/TritonToTensor.h" -#include "triton-linalg/Dialect/Utils/ShapeUtils.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/ADT/iterator.h" -#include "llvm/Support/Casting.h" - -namespace mlir { -class MLIRContext; -} // namespace mlir - -#define DEBUG_TYPE "triton-to-tensor" - -using namespace mlir; -using namespace triton; - -namespace { -struct ConvertCatToinsertSlicePattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getResult().getType().cast(); - - Location loc = op.getLoc(); - Value result = rewriter.create(loc, resultTy.getShape(), - resultTy.getElementType()); - - auto rank = resultTy.getRank(); - auto operands = adaptor.getOperands(); - // Insert slice params. - auto zero = rewriter.getIndexAttr(0); - auto one = rewriter.getIndexAttr(1); - SmallVector offsets(rank, zero); - SmallVector strides(rank, one); - SmallVector sizes; - - for (auto operand : operands) { - sizes = getDims(rewriter, loc, operand); - result = rewriter.createOrFold( - loc, operand, result, offsets, sizes, strides); - // Triton's cat op always concat the 1st axis. - offsets[0] = rewriter.createOrFold( - loc, materializeOpFoldResult(rewriter, loc, offsets[0]), - materializeOpFoldResult(rewriter, loc, sizes[0])); - } - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct TritonToTensorPass : public TritonToTensorPassBase { - void runOnOperation() override; -}; -} // namespace - -void TritonToTensorPass::runOnOperation() { - MLIRContext *context = &getContext(); - ConversionTarget target(*context); - - target.markUnknownOpDynamicallyLegal( - [](Operation *op) { return !isa(op); }); - - // Rewrite patterns. 
- RewritePatternSet patterns(context); - patterns.add(context); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) - return signalPassFailure(); -} - -std::unique_ptr mlir::triton::createTritonToTensorPass() { - return std::make_unique(); -} diff --git a/lib/Dialect/Arith/CMakeLists.txt b/lib/Dialect/Arith/CMakeLists.txt deleted file mode 100644 index e31af32..0000000 --- a/lib/Dialect/Arith/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Transforms) diff --git a/lib/Dialect/Arith/Transforms/ArithExtraCanonicalizer.cpp b/lib/Dialect/Arith/Transforms/ArithExtraCanonicalizer.cpp deleted file mode 100644 index ec3d08c..0000000 --- a/lib/Dialect/Arith/Transforms/ArithExtraCanonicalizer.cpp +++ /dev/null @@ -1,250 +0,0 @@ -//===- ArithExtraCanonicalizer.cpp -----------------------------*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. -// -//===----------------------------------------------------------------------===// -// -// \Note: it's a supplement for original linalg canonicalization defined in -// mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp. -// -//===----------------------------------------------------------------------===// -#include -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributeInterfaces.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton-linalg/Dialect/Arith/Transforms/Passes.h" -#include "triton-linalg/Dialect/Utils/ArithUtils.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Casting.h" -using namespace mlir; -using namespace mlir::triton; -namespace mlir { -class MLIRContext; -} // namespace mlir - -namespace { -/// -/// Convert "cst->div" to "1/cst->mul". -/// -/// Example: -/// ``` -/// %cst = arith.constant 4.0 : f32 -/// %0 = arith.divf %arg0, %cst : f32 -/// ``` -/// -/// transformed into: -/// -/// ``` -/// %cst = arith.constant 2.500000e-01 : f32 -/// %0 = arith.mulf %arg0, %cst : f32 -/// ``` -struct ScalarDivToMul final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(arith::DivFOp op, - PatternRewriter &rewriter) const override { - auto divisor = op.getRhs(); - auto loc = op.getLoc(); - - // Case1: 'rhs' is a scalar. - FloatAttr divisorAttr; - if (matchPattern(divisor, m_Constant(&divisorAttr))) { - auto divisorVal = divisorAttr.getValue().convertToDouble(); - Value multiplier = rewriter.create( - loc, FloatAttr::get(divisor.getType(), 1.0 / divisorVal)); - rewriter.replaceOpWithNewOp(op, op.getLhs(), multiplier); - return success(); - } - - // Case2: 'rhs' is a const float tensor. 
- auto constDivisor = divisor.getDefiningOp(); - auto divisorType = divisor.getType().dyn_cast_or_null(); - if (!constDivisor || !divisorType || - !divisorType.getElementType().isa()) { - return failure(); - } - auto constAttr = constDivisor.getValue().dyn_cast(); - // Take the reciprocal element by element. - auto multiplierVal = llvm::to_vector(llvm::map_range( - constAttr.getValues(), [&](const APFloat &value) -> Attribute { - auto divisorVal = value.convertToDouble(); - return FloatAttr::get(divisorType.getElementType(), 1.0 / divisorVal); - })); - auto multiplierAttr = DenseElementsAttr::get(divisorType, multiplierVal); - auto multiplier = rewriter.create(loc, multiplierAttr); - - rewriter.replaceOpWithNewOp(op, op.getLhs(), multiplier); - return success(); - } -}; - -/// Canonicalize arith cmp and select to arith max/min pattern. -/// -/// Example: -/// ``` -/// %0 = arith.cmpf ogt, %arg0, %arg1 : f32 -/// %1 = arith.select %0, %arg0, %arg1 : f32 -/// ``` -/// -/// transformed into: -/// -/// ``` -/// %0 = arith.maximumf %arg0, %arg1 : f32 -/// ``` -/// FIXME: wrong conversion in float type -struct CanonicalizeCmpSelectToMinMax final - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(arith::SelectOp op, - PatternRewriter &rewriter) const override { - // Check cmp op. - auto *cmpOp = op.getCondition().getDefiningOp(); - if (!cmpOp || !isa(cmpOp)) { - return failure(); - } - auto maxMinOp = getCmpSelectResult(rewriter, cmpOp, op); - if (!maxMinOp) { - return mlir::failure(); - } - - rewriter.replaceOp(op, maxMinOp.value()->getResults()); - return success(); - } -}; - -template -class CanonicalizeArithI1Pattern : public OpRewritePattern { -public: - explicit CanonicalizeArithI1Pattern(MLIRContext *ctx) - : OpRewritePattern(ctx) {} - - LogicalResult matchAndRewrite(Op op, - PatternRewriter &rewriter) const override { - Type eleType = getElementTypeOrSelf(op.getLhs().getType()); - if (!eleType.isInteger(1)) { - return failure(); - } - rewriter.replaceOpWithNewOp(op, op.getLhs(), op.getRhs()); - return success(); - } -}; - -/// Check whether the graph is constructed by nan statement -/// operations and the cmp-select pattern can be optimized to `OpTy`. -/// The nan structure region as follows: -/// bArg0 bArg1 -/// || \ / -/// || cmpf -/// || | -/// cmpf | -/// \ / -/// ori bArg0 bArg1 -/// | / / -/// select -template -class CanonicalizeNanStatement : public OpRewritePattern { -public: - explicit CanonicalizeNanStatement(MLIRContext *ctx) - : OpRewritePattern(ctx) {} - - LogicalResult matchAndRewrite(arith::SelectOp selectOp, - PatternRewriter &rewriter) const override { - using mlir::matchers::m_Any; - // Nan pattern match. - auto nanStatementPattern = m_Op( - m_Op(m_Op(m_Any(), m_Any()), - m_Op(m_Any(), m_Any())), - m_Any(), m_Any()); - if (!nanStatementPattern.match(selectOp)) { - return failure(); - } - auto trueVal = selectOp.getTrueValue(); - auto falseVal = selectOp.getFalseValue(); - // Collect block arg cmp users. - llvm::SmallSetVector cmpOps; - for (auto user = trueVal.getUsers().begin(); - user != trueVal.getUsers().end(); user++) { - if (isa(*user)) { - cmpOps.insert(*user); - } - } - for (auto user = falseVal.getUsers().begin(); - user != falseVal.getUsers().end(); user++) { - if (isa(*user)) { - cmpOps.insert(*user); - } - } - // Must be two cmp ops. - if (cmpOps.size() != 2) { - return failure(); - } - // One of cmp op must be une predicate and has same operands. 
- if (llvm::count_if(cmpOps, [](Operation *op) { - auto nanCmp = cast(op); - return (nanCmp.getPredicate() == arith::CmpFPredicate::UNE) && - (nanCmp.getLhs() == nanCmp.getRhs()); - }) != 1) { - return failure(); - } - // Find non-une cmp op. - auto *minMaxCmp = llvm::find_if(cmpOps, [](Operation *op) { - auto cmp = cast(op); - return cmp.getPredicate() != arith::CmpFPredicate::UNE; - }); - // Check if cmp and select can be optimized to min/max. - auto minMaxOp = getCmpSelectResult(rewriter, *minMaxCmp, selectOp); - if (!minMaxOp.has_value() || !isa_and_nonnull(minMaxOp.value())) { - return failure(); - } - rewriter.replaceOpWithNewOp(selectOp, trueVal, falseVal); - return success(); - } -}; - -struct ArithCanonicalizerPass - : public arith_ext::ArithCanonicalizerBase { - ArithCanonicalizerPass() = default; - ArithCanonicalizerPass(const ArithCanonicalizerPass &) = default; - - void runOnOperation() override { - Operation *op = getOperation(); - auto *ctx = op->getContext(); - RewritePatternSet patterns(ctx); - patterns.add, - CanonicalizeArithI1Pattern, - CanonicalizeNanStatement, - CanonicalizeNanStatement>( - patterns.getContext()); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } -}; -} // namespace - -std::unique_ptr mlir::triton::arith_ext::createArithCanonicalizerPass() { - return std::make_unique(); -} diff --git a/lib/Dialect/Arith/Transforms/CMakeLists.txt b/lib/Dialect/Arith/Transforms/CMakeLists.txt deleted file mode 100644 index 893542d..0000000 --- a/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -add_triton_library(ArithTransforms - ArithExtraCanonicalizer.cpp - - DEPENDS - ArithTransformsIncGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/lib/Dialect/Auxiliary/IR/AuxiliaryDialect.cpp b/lib/Dialect/Auxiliary/IR/AuxiliaryDialect.cpp index bbb5060..7407d5f 100644 --- a/lib/Dialect/Auxiliary/IR/AuxiliaryDialect.cpp +++ b/lib/Dialect/Auxiliary/IR/AuxiliaryDialect.cpp @@ -131,7 +131,7 @@ LogicalResult StoreResourceOp::verify() { } auto from = getFrom(); auto to = getTo(); - if (isScalar(from) || from.getType().isa<::mlir::TensorType>()) { + if (isScalar(from) || isa<::mlir::TensorType>(from.getType())) { if (from.getType() != to.getType()) { return emitOpError() << "failed to verify that all of {from, to} have same type"; @@ -171,7 +171,7 @@ void ViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto sourceType = source.getType().cast(); + auto sourceType = cast(source.getType()); if (!resultType) { resultType = MemRefType::get( staticSizes, elementType, diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index c8713f8..850d05f 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(Arith) add_subdirectory(Auxiliary) add_subdirectory(LinalgExt) add_subdirectory(MathExt) diff --git a/lib/Dialect/LinalgExt/IR/CMakeLists.txt b/lib/Dialect/LinalgExt/IR/CMakeLists.txt index b3a0814..3600acc 100644 --- a/lib/Dialect/LinalgExt/IR/CMakeLists.txt +++ b/lib/Dialect/LinalgExt/IR/CMakeLists.txt @@ -2,7 +2,7 @@ add_triton_library(LinalgExtInterface LinalgExtInterface.cpp DEPENDS - TritonLinalgInterfacesTableGen + LinalgExtTableGen LINK_LIBS PUBLIC MLIRIR @@ -14,7 +14,6 @@ add_triton_library(LinalgExtDialect DEPENDS 
LinalgExtTableGen - TritonLinalgInterfacesTableGen LINK_LIBS PUBLIC DialectUtils diff --git a/lib/Dialect/LinalgExt/IR/LinalgExtInterface.cpp b/lib/Dialect/LinalgExt/IR/LinalgExtInterface.cpp index cbf49c2..0dc3178 100644 --- a/lib/Dialect/LinalgExt/IR/LinalgExtInterface.cpp +++ b/lib/Dialect/LinalgExt/IR/LinalgExtInterface.cpp @@ -37,6 +37,8 @@ LogicalResult triton::detail::verifyLinalgExtOpInterface(Operation *op) { LogicalResult LinalgExtOp::reifyResultShapes( OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return llvm::cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); + if (auto linalgOp = dyn_cast(getOperation())) { + return linalgOp.reifyResultShapes(b, reifiedReturnShapes); + } + return failure(); } diff --git a/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp index b7a2fa0..d760046 100644 --- a/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -22,6 +22,7 @@ #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -111,6 +112,7 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc, b.create(loc, args[0]); }); } + // Helper function for getEffect impl from LinalgOps.cpp. static void getGenericEffectsImpl( SmallVectorImpl> @@ -120,15 +122,15 @@ static void getGenericEffectsImpl( for (auto operand : inputOperands) { if (!llvm::isa(operand.getType())) continue; - effects.emplace_back(MemoryEffects::Read::get(), operand, + effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); } for (auto operand : outputOperands) { if (!llvm::isa(operand.getType())) continue; - effects.emplace_back(MemoryEffects::Read::get(), operand, - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), operand, + effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/1, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); } } @@ -227,6 +229,139 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; } +//===----------------------------------------------------------------------===// +// BEGIN copied from llvm-project mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +//===----------------------------------------------------------------------===// +static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, + const OperationName &payloadOpName, + const NamedAttrList &payloadOpAttrs, + ArrayRef operands, + bool initFirst = false) { + OpBuilder b(parser.getContext()); + Region *body = result.addRegion(); + Block &block = body->emplaceBlock(); + b.setInsertionPointToStart(&block); + SmallVector bbArgs; + for (auto &operand : operands) { + block.addArgument( + mlir::cast(operand.getType()).getElementType(), + b.getUnknownLoc()); + } + SmallVector payloadOpOperands; + // If initFirst flag is enabled, we consider init as the first position of + // payload operands. 
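+  // For example (an illustrative sketch assuming one input and one init):
+  // the block arguments are (%in, %init), and with initFirst=true the
+  // payload is built as payloadOp(%init, %in), e.g. `arith.addf %init, %in`
+  // for an additive reduction, instead of payloadOp(%in, %init).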
+ if (initFirst) { + payloadOpOperands.push_back(block.getArguments().back()); + for (const auto &arg : block.getArguments().drop_back()) + payloadOpOperands.push_back(arg); + } else { + payloadOpOperands = {block.getArguments().begin(), + block.getArguments().end()}; + } + + Operation *payloadOp = b.create( + result.location, b.getStringAttr(payloadOpName.getStringRef()), + payloadOpOperands, + TypeRange{mlir::cast(result.operands.back().getType()) + .getElementType()}, + payloadOpAttrs); + b.create(result.location, payloadOp->getResults()); +} + +// Retrieve the operation from the body, if it is the only one (except +// yield) and if it gets the same amount of arguments as the body does. +// If initFirst flag is enabled, we check that init takes the first position in +// operands of payload. +static Operation *findPayloadOp(Block *body, bool initFirst = false) { + if (body->getOperations().size() != 2) + return nullptr; + Operation &payload = body->getOperations().front(); + assert(isa(body->getOperations().back())); + + if (payload.getNumOperands() == 0 || + payload.getNumOperands() != body->getNumArguments()) + return nullptr; + if (initFirst) { + // check init + if (payload.getOperands().back() != body->getArgument(0)) + return nullptr; + // check rest + for (const auto &[operand, bbArg] : + llvm::zip(payload.getOperands(), body->getArguments().drop_front())) { + if (bbArg != operand) + return nullptr; + } + } else { + for (const auto &[operand, bbArg] : + llvm::zip(payload.getOperands(), body->getArguments())) { + if (bbArg != operand) + return nullptr; + } + } + return &payload; +} + +void printShortFormReduce(OpAsmPrinter &p, Operation *payloadOp) { + SmallVector elidedAttrs; + std::string attrToElide; + p << " { " << payloadOp->getName().getStringRef(); + for (const auto &attr : payloadOp->getAttrs()) { + auto fastAttr = + mlir::dyn_cast(attr.getValue()); + if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { + attrToElide = attr.getName().str(); + elidedAttrs.push_back(attrToElide); + break; + } + } + p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs); + p << " }"; +} + +static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, + NamedAttrList &attributes, + StringRef attributeName) { + if (parser.parseKeyword(attributeName) || parser.parseEqual()) + return failure(); + + attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{})); + return success(); +} + +static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, + ArrayRef attributeValue) { + p << ' ' << attributeName << " = [" << attributeValue << "] "; +} + +static ParseResult parseDstStyleOp( + OpAsmParser &parser, OperationState &result, + function_ref parseAttrsFn = + nullptr) { + // Parse `ins` and `outs`. + SmallVector inputTypes, outputTypes; + if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes, + /*addOperandSegmentSizes=*/false)) + return failure(); + + // Add result types. + for (Type outputType : outputTypes) { + if (llvm::isa(outputType)) + result.addTypes(outputType); + } + + // Parse required attributes. + if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes))) + return failure(); + + // Parse optional attributes. 
+ if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + return success(); +} +//===----------------------------------------------------------------------===// +// END copied from llvm-project mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +//===----------------------------------------------------------------------===// + //===----------------------------------------------------------------------===// // Helper functions for named Linalg ops defined in ods-gen from LinalgOps.cpp. //===----------------------------------------------------------------------===// @@ -276,7 +411,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state, resultTensorTypes.value_or(TypeRange()); if (!resultTensorTypes) llvm::copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes), - [](Type type) { return type.isa(); }); + [](Type type) { return mlir::isa(type); }); state.addOperands(inputs); state.addOperands(outputs); @@ -401,16 +536,15 @@ namespace { class RegionBuilderHelper { public: - RegionBuilderHelper(MLIRContext *context, Block &block) - : context(context), block(block) {} RegionBuilderHelper(OpBuilder &builder, Block &block) - : context(builder.getContext()), block(block) {} + : builder(builder), block(block) {} // Build the unary functions defined by OpDSL. Value buildUnaryFn(UnaryFn unaryFn, Value arg) { if (!isFloatingPoint(arg)) llvm_unreachable("unsupported non numeric type"); - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); switch (unaryFn) { case UnaryFn::exp: return builder.create(arg.getLoc(), arg); @@ -424,6 +558,24 @@ class RegionBuilderHelper { return builder.create(arg.getLoc(), arg); case UnaryFn::negf: return builder.create(arg.getLoc(), arg); + case UnaryFn::reciprocal: { + Attribute oneAttr = builder.getOneAttr(arg.getType()); + auto one = builder.create(arg.getLoc(), + ::cast(oneAttr)); + return builder.create(arg.getLoc(), one, arg); + } + case UnaryFn::round: + return builder.create(arg.getLoc(), arg); + case UnaryFn::sqrt: + return builder.create(arg.getLoc(), arg); + case UnaryFn::rsqrt: + return builder.create(arg.getLoc(), arg); + case UnaryFn::square: + return builder.create(arg.getLoc(), arg, arg); + case UnaryFn::tanh: + return builder.create(arg.getLoc(), arg); + case UnaryFn::erf: + return builder.create(arg.getLoc(), arg); } llvm_unreachable("unsupported unary function"); } @@ -437,7 +589,8 @@ class RegionBuilderHelper { arg1.getType().getIntOrFloatBitWidth() == 1; if (!allComplex && !allFloatingPoint && !allInteger) llvm_unreachable("unsupported non numeric type"); - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); switch (binaryFn) { case BinaryFn::add: if (allComplex) @@ -463,6 +616,18 @@ class RegionBuilderHelper { if (allBool) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::div: + if (allComplex) + return builder.create(arg0.getLoc(), arg0, arg1); + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + if (allBool) + llvm_unreachable("unsupported operation: div with bools"); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::div_unsigned: + if (!allInteger || allBool) + llvm_unreachable("unsupported operation: unsigned div not on uint"); + return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::max_signed: assert(!allComplex); if (allFloatingPoint) @@ -483,13 +648,32 @@ class 
RegionBuilderHelper { if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); - case BinaryFn::div: - case BinaryFn::div_unsigned: - llvm_unreachable("unsupported binary function"); + case BinaryFn::powf: + assert(allFloatingPoint); + return builder.create(arg0.getLoc(), arg0, arg1); } llvm_unreachable("unsupported binary function"); } + // Build the ternary functions defined by OpDSL. + Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, + Value arg2) { + bool headBool = + isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1; + bool tailFloatingPoint = + isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2); + bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + switch (ternaryFn) { + case TernaryFn::select: + if (!headBool && !(tailFloatingPoint || tailInteger)) + llvm_unreachable("unsupported non numeric type"); + return builder.create(arg0.getLoc(), arg0, arg1, arg2); + } + llvm_unreachable("unsupported ternary function"); + } + // Build the type functions defined by OpDSL. Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { switch (typeFn) { @@ -502,31 +686,32 @@ class RegionBuilderHelper { } void yieldOutputs(ValueRange values) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); builder.create(loc, values); } Value constant(const std::string &value) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); - auto typedAttr = valueAttr.dyn_cast(); - return builder.create(loc, typedAttr.getType(), - typedAttr); + return builder.create(loc, ::cast(valueAttr)); } Value index(int64_t dim) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); return builder.create(builder.getUnknownLoc(), dim); } Type getIntegerType(unsigned width) { - return IntegerType::get(context, width); + return IntegerType::get(builder.getContext(), width); } - Type getFloat32Type() { return Float32Type::get(context); } - Type getFloat64Type() { return Float64Type::get(context); } + Type getFloat32Type() { return Float32Type::get(builder.getContext()); } + Type getFloat64Type() { return Float64Type::get(builder.getContext()); } private: // Generates operations to cast the given operand to a specified type. @@ -534,63 +719,23 @@ class RegionBuilderHelper { // operand returned as-is (which will presumably yield a verification // issue downstream). Value cast(Type toType, Value operand, bool isUnsignedCast) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); auto loc = operand.getLoc(); - - if (operand.getType() == toType) - return operand; - if (auto toIntType = toType.dyn_cast()) { - // If operand is floating point, cast directly to the int type. - if (operand.getType().isa()) { - if (isUnsignedCast) - return builder.create(loc, toType, operand); - return builder.create(loc, toType, operand); - } - // Cast index operands directly to the int type. 
- if (operand.getType().isIndex()) - return builder.create(loc, toType, operand); - if (auto fromIntType = operand.getType().dyn_cast()) { - // Either extend or truncate. - if (toIntType.getWidth() > fromIntType.getWidth()) { - if (isUnsignedCast) - return builder.create(loc, toType, operand); - return builder.create(loc, toType, operand); - } - if (toIntType.getWidth() < fromIntType.getWidth()) - return builder.create(loc, toType, operand); - } - } else if (auto toFloatType = toType.dyn_cast()) { - // If operand is integer, cast directly to the float type. - // Note that it is unclear how to cast from BF16<->FP16. - if (operand.getType().isa()) { - if (isUnsignedCast) - return builder.create(loc, toFloatType, operand); - return builder.create(loc, toFloatType, operand); - } - if (auto fromFloatType = operand.getType().dyn_cast()) { - if (toFloatType.getWidth() > fromFloatType.getWidth()) - return builder.create(loc, toFloatType, operand); - if (toFloatType.getWidth() < fromFloatType.getWidth()) - return builder.create(loc, toFloatType, operand); - } - } - - emitWarning(operand.getLoc()) << "could not cast operand of type " - << operand.getType() << " to " << toType; - return operand; + return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); } - bool isComplex(Value value) { return value.getType().isa(); } - bool isFloatingPoint(Value value) { return value.getType().isa(); } - bool isInteger(Value value) { return value.getType().isa(); } - - OpBuilder getBuilder() { - OpBuilder builder(context); - builder.setInsertionPointToEnd(&block); - return builder; + bool isComplex(Value value) { + return llvm::isa(value.getType()); + } + bool isFloatingPoint(Value value) { + return llvm::isa(value.getType()); + } + bool isInteger(Value value) { + return llvm::isa(value.getType()); } - MLIRContext *context; + OpBuilder &builder; Block █ }; @@ -610,7 +755,7 @@ void LibdeviceCallOp::build(::mlir::OpBuilder &builder, // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (mlir::isa(initType)) result.addTypes(initType); } @@ -627,11 +772,11 @@ LogicalResult LibdeviceCallOp::verify() { return success(); } LogicalResult ScalarLibdeviceCallOp::verify() { // The inputs of ScalarLibdeviceCallOp should be scalar type. for (auto v : getInputs()) { - if (v.getType().isa()) + if (mlir::isa(v.getType())) return emitOpError() << "expects all input types are scalar type."; } // The result type should be scalar type. 
- if (getResult().getType().isa()) + if (mlir::isa(getResult().getType())) return emitOpError() << "expects the result type is scalar type."; return success(); } @@ -690,31 +835,32 @@ ArrayAttr BatchConv2DNhwcFhwcOp::getIndexingMaps() { MLIRContext *context = getContext(); auto symbolBindings = getBatchConv2DSymbolBindings(*this); SmallVector maps; - maps.push_back(mlir::parseAttribute( - "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, " - "s3, s4, s5, s6, s7, s8, s9, s10, s11] -> (d0, d1, d2 * " - "s3 + d5 * s5, d3 * s7 + d6 * s9, d7)>", - context) - .cast() - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 8, 0)); maps.push_back( - mlir::parseAttribute( - "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, " - "s6, s7, s8, s9, s10, s11] -> (d0, d4, d5, d6, d7)>", - context) - .cast() + mlir::cast( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, " + "s3, s4, s5, s6, s7, s8, s9, s10, s11] -> (d0, d1, d2 * " + "s3 + d5 * s5, d3 * s7 + d6 * s9, d7)>", + context)) .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 8, 0)); - maps.push_back( - mlir::parseAttribute( - "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, " - "s6, s7, s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>", - context) - .cast() - .getValue()); + maps.push_back(mlir::cast( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, " + "s2, s3, s4, s5, " + "s6, s7, s8, s9, s10, s11] -> (d0, d4, d5, d6, d7)>", + context)) + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 8, 0)); + maps.push_back(mlir::cast( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, " + "s2, s3, s4, s5, " + "s6, s7, s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>", + context)) + .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 8, 0)); cached = Builder(context).getAffineMapArrayAttr(maps); @@ -749,7 +895,7 @@ void BatchConv2DNhwcFhwcOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) { assert(3 > 0 && block.getNumArguments() == 3 && "BatchConv2DNhwcFhwcOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(block.getArgument(0).getContext(), block); + RegionBuilderHelper helper(b, block); SmallVector yields; Value value1 = @@ -834,7 +980,7 @@ void MakeRangeOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) { assert(block.getNumArguments() == 3 && "MakeRangeOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(block.getArgument(0).getContext(), block); + RegionBuilderHelper helper(b, block); SmallVector yields; Value zero = helper.index(0); Value value0 = @@ -902,8 +1048,9 @@ LogicalResult MakeRangeOp::verify() { if (auto start = getStart().getDefiningOp()) { if (auto end = getEnd().getDefiningOp()) { - int64_t startConstantInt = start.getValue().cast().getInt(); - int64_t endConstantInt = end.getValue().cast().getInt(); + int64_t startConstantInt = + mlir::cast(start.getValue()).getInt(); + int64_t endConstantInt = mlir::cast(end.getValue()).getInt(); if (endConstantInt <= startConstantInt) return emitOpError() << "input argument end must greater than input arguments start " @@ -962,19 +1109,19 @@ ArrayAttr Im2ColOp::getIndexingMaps() { auto symbolBindings = getSymbolBindings(*this); SmallVector maps; maps.push_back( - mlir::parseAttribute( - 
"affine_map<(d0, d1, d2, d4, d5, d6)[s0, s1, s2, s3, s5, s6, s7, " - "s9] -> (d0, d1 * s2 + d4, d2 * s6 + d5, d6)>", - context) - .cast() + mlir::cast( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d4, d5, d6)[s0, s1, s2, s3, s5, s6, s7, " + "s9] -> (d0, d1 * s2 + d4, d2 * s6 + d5, d6)>", + context)) .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 6, 0)); - maps.push_back(mlir::parseAttribute( - "affine_map<(d0, d1, d2, d4, d5, d6)[s0, s1, s2, s3, " - "s5, s6, s7, s9] -> (d0, d1, d2, d4, d5, d6)>", - context) - .cast() + maps.push_back(mlir::cast( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d4, d5, d6)[s0, s1, s2, s3, " + "s5, s6, s7, s9] -> (d0, d1, d2, d4, d5, d6)>", + context)) .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 6, 0)); @@ -1009,7 +1156,7 @@ void Im2ColOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) { assert(2 > 0 && block.getNumArguments() == 2 && "Im2ColOp regionBuilder expects 2 (>=0) args"); - RegionBuilderHelper helper(block.getArgument(0).getContext(), block); + RegionBuilderHelper helper(b, block); SmallVector yields; Value value1 = @@ -1085,7 +1232,7 @@ void ScatterOp::build( // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (mlir::isa(initType)) result.addTypes(initType); if (bodyBuild) { @@ -1215,8 +1362,27 @@ LogicalResult ScatterOp::fold(FoldAdaptor, SmallVectorImpl &) { void ScatterOp::getEffects( SmallVectorImpl> &effects) { - getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), - getDpsInits()); + if (!hasPureBufferSemantics()) + return; + + if (mask()) { + effects.emplace_back(MemoryEffects::Read::get(), mask(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), update(), /*stage=*/0, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + } else { + effects.emplace_back(MemoryEffects::Read::get(), update(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), indice(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), getInit(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// @@ -1244,7 +1410,7 @@ void ScanOp::build( // Add output types for `RankedTensorType` output arguments. 
for (Value init : inits) { Type initType = init.getType(); - if (initType.isa()) + if (mlir::isa(initType)) result.addTypes(initType); } @@ -1258,6 +1424,14 @@ void ScanOp::getEffects( &effects) { getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), getDpsInits()); + + for (auto operand : getDpsInits()) { + if (!llvm::isa(operand.getType())) + continue; + effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } } LogicalResult ScanOp::fold(FoldAdaptor, SmallVectorImpl &) { @@ -1273,8 +1447,8 @@ LogicalResult ScanOp::verify() { << "num inputs is " << getNumDpsInputs() << "."; } - if (getInits()[0].getType().cast().getShape() != - getInputs()[0].getType().cast().getShape()) { + if (mlir::cast(getInits()[0].getType()).getShape() != + mlir::cast(getInputs()[0].getType()).getShape()) { return emitOpError() << "expects inputs and outputs have the same shapes. " "Shape at input-index 0 is not equal to" " the shape at output-index 0."; @@ -1285,7 +1459,7 @@ LogicalResult ScanOp::verify() { } int64_t dimension = getDimensions()[0]; - auto inputType = getInputs()[0].getType().cast(); + auto inputType = mlir::cast(getInputs()[0].getType()); if (dimension < 0 || dimension >= inputType.getRank()) { return emitOpError() << "dimension for scan should be in the range [0, " << inputType.getRank() - 1 << "]."; @@ -1298,8 +1472,8 @@ LogicalResult ScanOp::verify() { } for (int64_t i = 1; i < numInputs; ++i) { - if (getInputs()[i].getType().cast().getShape() != - getInputs()[0].getType().cast().getShape()) { + if (mlir::cast(getInputs()[i].getType()).getShape() != + mlir::cast(getInputs()[0].getType()).getShape()) { return emitOpError() << "expects all inputs have the same shapes. " "Shape at input-index " << i @@ -1309,15 +1483,16 @@ LogicalResult ScanOp::verify() { auto numOutputs = getNumDpsInits() / 2; for (int64_t i = 1; i < numOutputs; ++i) { - if (getInits()[i].getType().cast().getShape() != - getInits()[0].getType().cast().getShape()) { + if (mlir::cast(getInits()[i].getType()).getShape() != + mlir::cast(getInits()[0].getType()).getShape()) { return emitOpError() << "expects all outputs have the same shapes. " "Shape at output-index " << i << " is not equal to the shape at output-index 0."; } - if (getInits()[i + numOutputs].getType().cast().getShape() != - getInits()[numOutputs].getType().cast().getShape()) { + if (mlir::cast(getInits()[i + numOutputs].getType()) + .getShape() != + mlir::cast(getInits()[numOutputs].getType()).getShape()) { return emitOpError() << "expects all inits have the same shapes. " "Shape at init-index " << i + numOutputs @@ -1327,7 +1502,7 @@ LogicalResult ScanOp::verify() { } if (expectedInitShape != - getInits()[numInputs].getType().cast().getShape()) { + mlir::cast(getInits()[numOutputs].getType()).getShape()) { return emitOpError() << "inits shape is not equal to the expected shape " << expectedInitShape << "."; } @@ -1339,7 +1514,8 @@ LogicalResult ScanOp::verify() { // Check that the first block arguments match the element type of the inputs. 
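+  // Illustrative example (assumed shapes): a single-input f32 cumulative
+  // scan carries block arguments (%in: f32, %out: f32, %init: f32) -- one
+  // per input followed by one per DPS init, in declaration order.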
for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) { - Type inputElementType = input.getType().cast().getElementType(); + Type inputElementType = + mlir::cast(input.getType()).getElementType(); if (inputElementType != bbArg.getType()) return emitOpError() << "input element type " << inputElementType @@ -1351,7 +1527,7 @@ LogicalResult ScanOp::verify() { for (auto [output, bbArg] : llvm::zip( getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) { Type outputElementType = - output.getType().cast().getElementType(); + mlir::cast(output.getType()).getElementType(); if (outputElementType != bbArg.getType()) return emitOpError() << "output element type " << outputElementType @@ -1376,7 +1552,7 @@ void GatherOp::build( // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (mlir::isa(initType)) result.addTypes(initType); if (bodyBuild) { @@ -1505,8 +1681,27 @@ LogicalResult GatherOp::fold(FoldAdaptor, SmallVectorImpl &) { void GatherOp::getEffects( SmallVectorImpl> &effects) { - getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), - getDpsInits()); + if (!hasPureBufferSemantics()) + return; + + effects.emplace_back(MemoryEffects::Read::get(), input(), /*stage=*/0, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), indice(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + if (mask()) { + effects.emplace_back(MemoryEffects::Read::get(), mask(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), getInit(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + } else { + effects.emplace_back(MemoryEffects::Write::get(), getInit(), /*stage=*/1, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } } //===----------------------------------------------------------------------===// @@ -1540,20 +1735,28 @@ LogicalResult AtomicRMWOp::fold(FoldAdaptor, SmallVectorImpl &) { void AtomicRMWOp::getEffects( SmallVectorImpl> &effects) { - for (auto *operand : getDpsInputOperands()) { - if (!operand->get().getType().isa()) - continue; - effects.emplace_back(MemoryEffects::Read::get(), operand->get(), + // FIXME: When atomic ops support memref input, we should remove the effects + // of tensor. 
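+  // Reading of the side-effect arguments used below (added for clarity,
+  // based on the upstream MLIR MemoryEffects API): /*stage=*/0 marks
+  // effects that occur before the op's /*stage=*/1 writes, and
+  // /*effectOnFullRegion=*/true means the effect covers every element of
+  // the operand rather than a subset.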
+ if (!hasPureBufferSemantics()) { + effects.emplace_back(MemoryEffects::Read::get(), src(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), src(), SideEffects::DefaultResource::get()); + return; } - effects.emplace_back(MemoryEffects::Read::get(), src(), + + effects.emplace_back(MemoryEffects::Read::get(), input(), /*stage=*/0, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), src(), + effects.emplace_back(MemoryEffects::Read::get(), src(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), src(), /*stage=*/1, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), dst(), /*stage=*/1, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); - if (dst().getType().isa()) { - effects.emplace_back(MemoryEffects::Write::get(), dst(), - SideEffects::DefaultResource::get()); - } } //===----------------------------------------------------------------------===// @@ -1595,18 +1798,47 @@ LogicalResult GatherAtomicRMWOp::fold(FoldAdaptor, void GatherAtomicRMWOp::getEffects( SmallVectorImpl> &effects) { - for (auto *operand : getDpsInputOperands()) { - if (!operand->get().getType().isa()) - continue; - effects.emplace_back(MemoryEffects::Read::get(), operand->get(), + // FIXME: When atomic ops support memref input, we should remove the effects + // of tensor. + if (!hasPureBufferSemantics()) { + effects.emplace_back(MemoryEffects::Read::get(), src(), SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), src(), + SideEffects::DefaultResource::get()); + return; } - effects.emplace_back(MemoryEffects::Read::get(), src(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), src(), + + effects.emplace_back(MemoryEffects::Read::get(), indice(), /*stage=*/0, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); - if (window().getType().isa()) { - effects.emplace_back(MemoryEffects::Write::get(), window(), + if (mask()) { + effects.emplace_back(MemoryEffects::Read::get(), mask(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), input(), /*stage=*/0, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), src(), /*stage=*/0, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), src(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), window(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + } else { + effects.emplace_back(MemoryEffects::Read::get(), input(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), src(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), src(), /*stage=*/1, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), window(), /*stage=*/1, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); } } @@ -1621,14 +1853,31 @@ LogicalResult 
AtomicCASOp::fold(FoldAdaptor, SmallVectorImpl &) { void AtomicCASOp::getEffects( SmallVectorImpl> &effects) { + // FIXME: When atomic ops support memref input, we should remove the effects + // of tensor. if (!hasPureBufferSemantics()) { effects.emplace_back(MemoryEffects::Read::get(), input(), SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), input(), SideEffects::DefaultResource::get()); + return; } - getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), - getDpsInits()); + + effects.emplace_back(MemoryEffects::Read::get(), input(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), input(), /*stage=*/1, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), cmp(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), val(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), getInit(), /*stage=*/1, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// @@ -1642,16 +1891,343 @@ LogicalResult GatherAtomicCASOp::fold(FoldAdaptor, void GatherAtomicCASOp::getEffects( SmallVectorImpl> &effects) { + // FIXME: When atomic ops support memref input, we should remove the effects + // of tensor. if (!hasPureBufferSemantics()) { effects.emplace_back(MemoryEffects::Read::get(), input(), SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), input(), SideEffects::DefaultResource::get()); + return; } - getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), - getDpsInits()); + + effects.emplace_back(MemoryEffects::Read::get(), input(), /*stage=*/0, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), input(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), cmp(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), val(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), indice(), /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), getInit(), /*stage=*/1, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); } +//===----------------------------------------------------------------------===// +// BEGIN refers to linalg.reduce +//===----------------------------------------------------------------------===// + +template static LogicalResult verifyArgMaxMinOp(T op) { + ArrayRef dimensionsRef = op.getDimensions(); + + for (int64_t i = 1; i < op.getNumDpsInputs(); ++i) { + if (llvm::cast(op.getInputs()[i].getType()).getShape() != + llvm::cast(op.getInputs()[0].getType()).getShape()) { + return op->emitOpError() + << "expects all inputs to have the same shapes. 
" + "Shape at input-index " + << i << " is not equal to the shape at input-index 0."; + } + } + for (int64_t i = 1; i < op.getNumDpsInits(); ++i) { + if (llvm::cast(op.getInits()[i].getType()).getShape() != + llvm::cast(op.getInits()[0].getType()).getShape()) { + return op->emitOpError() + << "expects all outputs to have the same shapes. " + "Shape at output-index " + << i << " is not equal to the shape at output-index 0."; + } + } + auto inputType = llvm::cast(op.getInputs()[0].getType()); + auto initType = llvm::cast(op.getInits()[0].getType()); + + DenseSet dimensionsToReduce; + for (int64_t dimension : dimensionsRef) { + if (dimension < 0 || dimension >= inputType.getRank()) { + return op->emitOpError() + << "dimensions for reduction should be in the range [0, " + << inputType.getRank() - 1 << "]."; + } + dimensionsToReduce.insert(dimension); + } + + auto inputDims = inputType.getShape(); + auto initDims = initType.getShape(); + + // Input dimensions that will be left after the reduction. + SmallVector reducedInputDims; + for (const auto &en : llvm::enumerate(inputDims)) { + if (!dimensionsToReduce.count(en.index())) + reducedInputDims.push_back(en.value()); + } + + if (reducedInputDims.size() != static_cast(initType.getRank())) { + return op->emitOpError() + << "number of dimensions after reduction " << reducedInputDims.size() + << " doesn't match the init rank " << initType.getRank(); + } + + if (reducedInputDims != initDims) + return op->emitOpError() + << "init dimensions [" << initDims + << "] doesn't match input dimensions after reduction [" + << reducedInputDims << "]"; + + Block *block = op.getBody(); + if (block->getNumArguments() != op->getNumOperands()) + return op->emitOpError() + << "mismatching number of operands and block arguments"; + + // Check that the first block arguments match the element type of the inputs. + for (auto [input, bbArg] : llvm::zip(op.getInputs(), block->getArguments())) { + Type inputElementType = + llvm::cast(input.getType()).getElementType(); + if (inputElementType != bbArg.getType()) + return op->emitOpError() + << "input element type " << inputElementType + << " does not match corresponding block argument type " + << bbArg.getType(); + } + + // Check that the last block arguments match the element type of the outputs. 
+ for (auto [output, bbArg] : + llvm::zip(op.getDpsInits(), + block->getArguments().take_back(op.getNumDpsInits()))) { + auto outputElementType = + llvm::cast(output.getType()).getElementType(); + if (outputElementType != bbArg.getType()) + return op->emitOpError() + << "output element type " << outputElementType + << " does not match corresponding block argument type " + << bbArg.getType(); + } + return success(); +} + +static ParseResult parseArgMaxMin(OpAsmParser &parser, OperationState &result) { + std::optional payloadOpName; + NamedAttrList payloadOpAttrs; + if (succeeded(parser.parseOptionalLBrace())) { + FailureOr operationName = parser.parseCustomOperationName(); + if (failed(operationName)) + return failure(); + if (parser.parseOptionalAttrDict(payloadOpAttrs)) + return failure(); + payloadOpName = operationName.value(); + if (parser.parseRBrace()) + return failure(); + } + + if (parseDstStyleOp( + parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); + })) + return failure(); + + if (payloadOpName.has_value()) { + addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, + ArrayRef(result.operands), /*initFirst=*/true); + } else { + SmallVector regionArgs; + if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, /*allowAttrs=*/true)) { + return failure(); + } + + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return failure(); + } + + return success(); +} + +template static void printArgMaxMin(OpAsmPrinter &p, T op) { + Block *mapper = op.getBody(); + Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); + if (payloadOp) { + printShortFormReduce(p, payloadOp); + } + + printCommonStructuredOpParts(p, op.getDpsInputs(), op.getDpsInits()); + printDenseI64ArrayAttr(p, op.getDimensionsAttrName(), op.getDimensions()); + p.printOptionalAttrDict(op->getAttrs(), {op.getDimensionsAttrName()}); + if (!payloadOp) { + // Print region if the payload op was not detected. + p.increaseIndent(); + p.printNewline(); + p << "("; + llvm::interleaveComma(mapper->getArguments(), p, + [&](auto arg) { p.printRegionArgument(arg); }); + p << ") "; + + p.printRegion(op.getCombiner(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); + } +} + +//===----------------------------------------------------------------------===// +// ArgMaxOp +//===----------------------------------------------------------------------===// + +void ArgMaxOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange inits, ArrayRef dimensions, + function_ref bodyBuild, + ArrayRef attributes) { + build(builder, result, TypeRange{}, inputs, inits, dimensions); + result.addAttributes(attributes); + + // Add output types for `RankedTensorType` output arguments. 
+  for (Value init : inits) {
+    Type initType = init.getType();
+    if (mlir::isa<RankedTensorType>(initType))
+      result.addTypes(initType);
+  }
+
+  if (bodyBuild)
+    buildGenericRegion(builder, result.location, *result.regions.front(),
+                       inputs, inits, bodyBuild);
+}
+
+void ArgMaxOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  RegionBuilderHelper helper(b, block);
+  assert(block.getNumArguments() == 4 &&
+         "ArgMaxOp regionBuilder expects 4 (>=0) args");
+  SmallVector<Value> yields;
+  auto loc = block.getArgument(0).getLoc();
+  Value cmpfResOeq =
+      b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
+                              block.getArgument(0), block.getArgument(2));
+  Value cmpiRes =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              block.getArgument(1), block.getArgument(3));
+  Value andiRes = b.create<arith::AndIOp>(loc, cmpfResOeq, cmpiRes);
+  Value cmpfResOgt =
+      b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
+                              block.getArgument(0), block.getArgument(2));
+  Value oriRes = b.create<arith::OrIOp>(loc, cmpfResOgt, andiRes);
+  Value selectRes1 = b.create<arith::SelectOp>(
+      loc, oriRes, block.getArgument(0), block.getArgument(2));
+  Value selectRes2 = b.create<arith::SelectOp>(
+      loc, oriRes, block.getArgument(1), block.getArgument(3));
+  yields.push_back(selectRes1);
+  yields.push_back(selectRes2);
+  helper.yieldOutputs(yields);
+}
+
+void ArgMaxOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseArgMaxMin(parser, result);
+}
+
+void ArgMaxOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "argmax");
+}
+
+void ArgMaxOp::print(OpAsmPrinter &p) { printArgMaxMin(p, *this); }
+
+LogicalResult ArgMaxOp::verify() { return verifyArgMaxMinOp(*this); }
+
+LogicalResult ArgMaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// ArgMinOp
+//===----------------------------------------------------------------------===//
+
+void ArgMinOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange inits, ArrayRef<int64_t> dimensions,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, inits, dimensions);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
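+  // As for ArgMaxOp above, only the tensor-typed inits contribute result
+  // types.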
+  for (Value init : inits) {
+    Type initType = init.getType();
+    if (mlir::isa<RankedTensorType>(initType))
+      result.addTypes(initType);
+  }
+
+  if (bodyBuild)
+    buildGenericRegion(builder, result.location, *result.regions.front(),
+                       inputs, inits, bodyBuild);
+}
+
+void ArgMinOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  RegionBuilderHelper helper(b, block);
+  assert(block.getNumArguments() == 4 &&
+         "ArgMinOp regionBuilder expects 4 (>=0) args");
+  SmallVector<Value> yields;
+  auto loc = block.getArgument(0).getLoc();
+  Value cmpfResOeq =
+      b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
+                              block.getArgument(0), block.getArgument(2));
+  Value cmpiRes =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              block.getArgument(1), block.getArgument(3));
+  Value andiRes = b.create<arith::AndIOp>(loc, cmpfResOeq, cmpiRes);
+  Value cmpfResOlt =
+      b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+                              block.getArgument(0), block.getArgument(2));
+  Value oriRes = b.create<arith::OrIOp>(loc, cmpfResOlt, andiRes);
+  Value selectRes1 = b.create<arith::SelectOp>(
+      loc, oriRes, block.getArgument(0), block.getArgument(2));
+  Value selectRes2 = b.create<arith::SelectOp>(
+      loc, oriRes, block.getArgument(1), block.getArgument(3));
+  yields.push_back(selectRes1);
+  yields.push_back(selectRes2);
+  helper.yieldOutputs(yields);
+}
+
+void ArgMinOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+ParseResult ArgMinOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseArgMaxMin(parser, result);
+}
+
+void ArgMinOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "argmin");
+}
+
+void ArgMinOp::print(OpAsmPrinter &p) { printArgMaxMin(p, *this); }
+
+LogicalResult ArgMinOp::verify() { return verifyArgMaxMinOp(*this); }
+
+LogicalResult ArgMinOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// END refers to linalg.reduce
+//===----------------------------------------------------------------------===//
+
 //===----------------------------------------------------------------------===//
 // Implementation of PadOp
 //===----------------------------------------------------------------------===//
@@ -1807,7 +2383,7 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
       auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
          loc, outerSliceOp.getSource(), newOffsets, newSizes,
          innerSliceOp.getMixedStrides());
-      auto resultTy = padOp->getResultTypes().front().cast<RankedTensorType>();
+      auto resultTy = mlir::cast<RankedTensorType>(padOp->getResultTypes().front());
       llvm::ArrayRef<int64_t> staticShapes = resultTy.getShape();
       SmallVector dynamicShapes;
       SmallVector dynShape;
@@ -1819,8 +2395,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
          rewriter.create<tensor::DimOp>(loc, padOp.input(), cstIndex);
       auto lowV = low[i];
       Value lowValue;
-      if (auto attr = lowV.dyn_cast<Attribute>()) {
-        if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) {
+      if (auto attr = mlir::dyn_cast<Attribute>(lowV)) {
+        if (auto intAttr = mlir::dyn_cast_or_null<IntegerAttr>(attr)) {
          lowValue = rewriter.create<arith::ConstantIndexOp>(
              loc, intAttr.getValue().getSExtValue());
        }
@@ -1829,8 +2405,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
       auto highV = newHighPad[i];
       Value highValue;
-      if (auto attr = highV.dyn_cast<Attribute>()) {
-        if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) {
+      if (auto attr = mlir::dyn_cast<Attribute>(highV)) {
+        if (auto intAttr = mlir::dyn_cast_or_null<IntegerAttr>(attr)) {
          highValue = rewriter.create<arith::ConstantIndexOp>(
              loc, intAttr.getValue().getSExtValue());
        }
@@ -1993,7 +2569,7 @@ void PadOp::build(OpBuilder
&builder, OperationState &result, Value input, dispatchIndexOpFoldResults(lows, dynamicLows, staticLows); dispatchIndexOpFoldResults(highs, dynamicHighs, staticHighs); auto resultType = init.getType(); - assert(resultType.isa()); + assert(mlir::isa(resultType)); result.addOperands(input); result.addOperands(init); result.addOperands(pvalue); @@ -2074,10 +2650,12 @@ LogicalResult PadOp::verify() { "expected same type of padding value and input elements"); } RankedTensorType sourceType, resultType; - if (inputType.isa() && initType.isa()) { - sourceType = inputType.cast(); - resultType = initType.cast(); - } else if (inputType.isa() && initType.isa()) { + if (mlir::isa(inputType) && + mlir::isa(initType)) { + sourceType = mlir::cast(inputType); + resultType = mlir::cast(initType); + } else if (mlir::isa(initType) && + mlir::isa(inputType)) { ArrayRef inputShape = inputType.getShape(); sourceType = RankedTensorType::get(inputShape, inputType.getElementType()); ArrayRef initShape = initType.getShape(); @@ -2099,6 +2677,74 @@ LogicalResult PadOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Implementation of AssertOp +//===----------------------------------------------------------------------===// +void AssertOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + + for (auto operand : getDpsInputs()) { + if (!llvm::isa(operand.getType())) + continue; + effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); + } +} + +//===----------------------------------------------------------------------===// +// Implementation of HistogramOp +//===----------------------------------------------------------------------===// + +LogicalResult HistogramOp::fold(FoldAdaptor, SmallVectorImpl &) { + return memref::foldMemRefCast(*this); +} + +void HistogramOp::getEffects( + SmallVectorImpl> + &effects) { + getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), + getDpsInits()); +} + +LogicalResult HistogramOp::verify() { + if (getNumDpsInputs() != 1) { + return emitOpError("only supports 1 input operand!"); + } + + auto inputType = mlir::cast(getSrc()[0].getType()); + if (inputType.getRank() != 1) { + return emitOpError("only supports 1D input!"); + } + if (!mlir::isa((inputType.getElementType()))) { + return emitOpError("only supports integer input!"); + } + + if (getInitType().getRank() != 1) { + return emitOpError("only supports 1D output!"); + } + if (!mlir::isa(getInitType().getElementType())) { + return emitOpError("only supports integer output!"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ScalarAssertOp +//===----------------------------------------------------------------------===// +void ScalarAssertOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), /*stage=*/1, + /*effectOnFullRegion=*/false, + SideEffects::DefaultResource::get()); +} + /////// Operations corresponding to library calls defined with Tablegen //////// #include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtNamedStructuredOps.yamlgen.cpp.inc" diff --git a/lib/Dialect/LinalgExt/Transforms/LinalgExtOpTilingInterface.cpp b/lib/Dialect/LinalgExt/Transforms/LinalgExtOpTilingInterface.cpp index 8e17e15..ef2c111 100644 
--- a/lib/Dialect/LinalgExt/Transforms/LinalgExtOpTilingInterface.cpp
+++ b/lib/Dialect/LinalgExt/Transforms/LinalgExtOpTilingInterface.cpp
@@ -77,7 +77,7 @@ static Value getSimpliedSlice(OpBuilder &b, Location loc, Value source,
 static Value addBroadcast(OpBuilder &builder, Location loc, Value input,
                           ArrayRef<OpFoldResult> shapeOperands,
                           ArrayRef<int64_t> broadcastDim) {
-  ShapedType shapeTy = input.getType().cast<ShapedType>();
+  ShapedType shapeTy = cast<ShapedType>(input.getType());
   Value init = builder.create<tensor::EmptyOp>(loc, shapeOperands,
                                                shapeTy.getElementType());
   return builder.create<linalg::BroadcastOp>(loc, input, init, broadcastDim)
@@ -89,7 +89,7 @@ using RegionFn = function_ref;
 static Value addMap(OpBuilder &builder, Location loc, Value lhs, Value rhs,
                     RegionFn regionFn) {
   auto shapeOperands = getDims(builder, loc, lhs);
-  ShapedType shapeTy = lhs.getType().cast<ShapedType>();
+  ShapedType shapeTy = cast<ShapedType>(lhs.getType());
   Value init = builder.create<tensor::EmptyOp>(loc, shapeOperands,
                                                shapeTy.getElementType());
   return builder
@@ -100,7 +100,7 @@ static Value addMap(OpBuilder &builder, Location loc, Value lhs, Value rhs,
 /// Pad the tiled indices to `paddedLength`.
 static Value padTiledIndice(OpBuilder &builder, Location loc, Value input,
                             int64_t paddedLength) {
-  ShapedType inputTy = input.getType().cast<ShapedType>();
+  ShapedType inputTy = cast<ShapedType>(input.getType());
   Type eleTy = inputTy.getElementType();
   Value zero = builder.create<arith::ConstantOp>(
       loc, eleTy, builder.getIntegerAttr(eleTy, 0));
@@ -183,13 +183,13 @@ tileByWindowSlice(OpBuilder &b, Location loc, Value data, Value window,
   auto oneAttr = b.getI64IntegerAttr(1);
 
   // Slice of init.
-  auto windowRank = window.getType().cast<ShapedType>().getRank();
+  auto windowRank = cast<ShapedType>(window.getType()).getRank();
   SmallVector<OpFoldResult> windowStrides(windowRank, oneAttr);
   Value tiledWindow = getSimpliedSlice(b, loc, window, windowOffsets,
                                        windowSizes, windowStrides);
   assert(tiledWindow && "failed to get slice of window");
   // Slice of indices.
-  auto indicesTy = indices.getType().cast<ShapedType>();
+  auto indicesTy = cast<ShapedType>(indices.getType());
   auto indicesRank = indicesTy.getRank();
   SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr);
   SmallVector<OpFoldResult> indicesSizes(indicesRank);
@@ -217,7 +217,7 @@ tileByWindowSlice(OpBuilder &b, Location loc, Value data, Value window,
     assert(tiledMask && "failed to get slice of mask");
   }
   // Update indice value using update tiling offset.
-  auto dataRank = data.getType().cast<ShapedType>().getRank();
+  auto dataRank = cast<ShapedType>(data.getType()).getRank();
   ArrayRef<OpFoldResult> curOffsetArray(windowOffsets.begin() + batchNum,
                                         windowOffsets.end());
   bool hasNonZeroVal = llvm::any_of(curOffsetArray, [](OpFoldResult val) {
@@ -722,10 +722,10 @@ struct LinalgExtOpTilingInterface
       Value currVal = genericOp.getCurrentValue();
       Value cmp;
-      if (cmpVal.getType().isa<IntegerType>()) {
+      if (isa<IntegerType>(cmpVal.getType())) {
         cmp = bodyBuilder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 currVal, cmpVal);
-      } else if (cmpVal.getType().isa<FloatType>()) {
+      } else if (isa<FloatType>(cmpVal.getType())) {
         cmp = bodyBuilder.create<arith::CmpFOp>(
             loc, arith::CmpFPredicate::OEQ, currVal, cmpVal);
       } else {
@@ -853,10 +853,10 @@ struct LinalgExtOpTilingInterface
       Value currVal = genericOp.getCurrentValue();
       Value cmp;
-      if (cmpVal.getType().isa<IntegerType>()) {
+      if (isa<IntegerType>(cmpVal.getType())) {
         cmp = bodyBuilder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 currVal, cmpVal);
-      } else if (cmpVal.getType().isa<FloatType>()) {
+      } else if (isa<FloatType>(cmpVal.getType())) {
         cmp = bodyBuilder.create<arith::CmpFOp>(
             loc, arith::CmpFPredicate::OEQ, currVal, cmpVal);
       } else {
@@ -926,7 +926,7 @@ struct LinalgExtOpTilingInterface
     tiledInits.push_back(atomicRMWOp.src());
     // Slice input and dst.
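+    // Both are sliced with the current tile's offsets and sizes so the
+    // atomic update of each tile targets the matching destination region.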
Value input = atomicRMWOp.input(); - auto inputRank = input.getType().cast().getRank(); + auto inputRank = cast(input.getType()).getRank(); SmallVector inputStrides(inputRank, oneAttr); Value tiledInput = getSimpliedSlice(b, loc, input, offsets, sizes, inputStrides); @@ -957,8 +957,21 @@ struct LinalgExtOpTilingInterface triton::linalg_ext::AtomicRMWOp atomicRMWOp = cast(op); // Result 0 is src, we keep it unchanged. + if (resultNumber == 0) { - return failure(); + auto zeroAttr = b.getI64IntegerAttr(0); + auto initRank = atomicRMWOp.getSrcType().getRank(); + auto initShape = atomicRMWOp.getSrcType().getShape(); + for (unsigned r = 0; r < initRank; ++r) { + if (!isNoTile(sizes[r], offsets[r], initShape, r)) { + return failure(); + } + } + resultOffsets.clear(); + resultOffsets.append(offsets.begin(), offsets.end()); + resultSizes.clear(); + resultSizes.append(sizes.begin(), sizes.end()); + return success(); } resultOffsets.assign(offsets.begin(), offsets.end()); resultSizes.assign(sizes.begin(), sizes.end()); @@ -1065,7 +1078,7 @@ struct LinalgExtOpTilingInterface tiledInits.push_back(atomicRMWOp.src()); // Slice input and window batch. Value input = atomicRMWOp.input(); - auto inputRank = input.getType().cast().getRank(); + auto inputRank = cast(input.getType()).getRank(); SmallVector inputStrides(inputRank, oneAttr); Value tiledInput = getSimpliedSlice(b, loc, input, offsets, sizes, inputStrides); @@ -1078,7 +1091,7 @@ struct LinalgExtOpTilingInterface tiledInits.push_back(tiledWindow); // Slice indice. auto indice = atomicRMWOp.indice(); - auto indiceRank = indice.getType().cast().getRank(); + auto indiceRank = cast(indice.getType()).getRank(); SmallVector indiceOffsets(indiceRank, zeroAttr); SmallVector indiceSizes(indiceRank); SmallVector indiceStrides(indiceRank, oneAttr); @@ -1398,6 +1411,19 @@ struct LinalgExtOpTilingInterface highs.push_back(getAsOpFoldResult(originHighIndex)); continue; } + + // If the tiled dimension has no pad, the 'low' and 'high' should be + // constant zero, the 'offset' and the 'size' of src should be the same + // with the dst. + if (matchPattern(originLowIndex, m_Zero()) && + matchPattern(originHighIndex, m_Zero())) { + inputOffsets.push_back(initOffsets[r]); + inputSizes.push_back(initSizes[r]); + lows.push_back(b.getIndexAttr(0)); + highs.push_back(b.getIndexAttr(0)); + continue; + } + Value inputDimSize = getDimValue(b, loc, input, r); Value srcStart = originLowIndex; Value srcEnd = b.create(loc, srcStart, inputDimSize); @@ -1479,27 +1505,42 @@ struct LinalgExtOpTilingInterface curSrcSize = b.create(loc, curSrcSize, tileSize); inputSizes.push_back(getAsOpFoldResult(curSrcSize)); - Value low1 = b.create(loc, srcStart, curDstStart); - // This is case 2, 3: (curDstStart < srcStart) && (curDstEnd >= srcStart). - Value cond = b.create(loc, curDstStartBeforeSrc, - curDstEndGESrcStart); - // curLow = case2-3? srcStart - curDstStart : 0. - curLow = b.create(loc, cond, low1, zeroIndex); - lows.push_back(getAsOpFoldResult(curLow)); - - Value high1 = b.create(loc, curDstEnd, srcEnd); - // curHigh = case1 ? size : 0. - curHigh = - b.create(loc, curDstBeforeSrc, tileSize, zeroIndex); - // curHigh = case2 ? curDstEnd - srcEnd : curHigh. - curHigh = b.create(loc, srcInCurDst, high1, curHigh); - // curHigh = case4 ? size : curHigh. - curHigh = b.create(loc, curDstStartAfterSrc, tileSize, - curHigh); - // curHigh = case6 ? curDstEnd - srcEnd : curHigh. 
- curHigh = b.create( - loc, curDstStartInSrcAndcurDstEndAfterSrc, high1, curHigh); - highs.push_back(getAsOpFoldResult(curHigh)); + if (matchPattern(originLowIndex, m_Zero())) { + lows.push_back(b.getIndexAttr(0)); + } else { + Value low1 = b.create(loc, srcStart, curDstStart); + // This is case 2, 3: (curDstStart < srcStart) && (curDstEnd >= + // srcStart). + Value cond = b.create(loc, curDstStartBeforeSrc, + curDstEndGESrcStart); + // curLow = case2-3? srcStart - curDstStart : 0. + curLow = b.create(loc, cond, low1, zeroIndex); + // curLow = min (curLow, originLowIndex). + curLow = b.create(loc, curLow, originLowIndex); + lows.push_back(getAsOpFoldResult(curLow)); + } + + if (matchPattern(originHighIndex, m_Zero())) { + highs.push_back(b.getIndexAttr(0)); + } else { + Value high1 = b.create(loc, curDstEnd, srcEnd); + // curHigh = case1 ? size : 0. + curHigh = b.create(loc, curDstBeforeSrc, tileSize, + zeroIndex); + // curHigh = case2 ? curDstEnd - srcEnd : curHigh. + curHigh = b.create(loc, srcInCurDst, high1, curHigh); + // curHigh = case4 ? size : curHigh. + curHigh = b.create(loc, curDstStartAfterSrc, tileSize, + curHigh); + // curHigh = case6 ? curDstEnd - srcEnd : curHigh. + curHigh = b.create( + loc, curDstStartInSrcAndcurDstEndAfterSrc, high1, curHigh); + Value maxPad = + b.create(loc, originHighIndex, originLowIndex); + // curHigh = min (curHigh, maxPad). + curHigh = b.create(loc, curHigh, maxPad); + highs.push_back(getAsOpFoldResult(curHigh)); + } } Value tiledInput = getSlice(b, loc, input, inputOffsets, inputSizes, inputStrides); @@ -1699,7 +1740,7 @@ struct LinalgExtOpTilingInterface SmallVector strides(rank, oneAttr); SmallVector inputsSlice; for (auto input : libdeviceCallOp.inputs()) { - if (input.getType().isa()) { + if (isa(input.getType())) { auto inputSlice = b.create( loc, input, inputOffsets, inputSizes, strides); inputsSlice.push_back(inputSlice); @@ -1901,7 +1942,6 @@ struct LinalgExtOpTilingInterface if (i != scanDim) accIndices.push_back(indices[i]); } - scanBlkArgs.push_back( b.create(loc, concreteOp.inputs()[0], indices)); scanBlkArgs.push_back( @@ -1926,6 +1966,18 @@ struct LinalgExtOpTilingInterface } }; +template <> +struct LinalgExtOpTilingInterface + : public TilingInterface::ExternalModel< + LinalgExtOpTilingInterface, + triton::linalg_ext::HistogramOp> { + + SmallVector getDestinationOperands(Operation *op, + OpBuilder &builder) const { + return llvm::cast(op).getDpsInits(); + } +}; + } // namespace template static void registerOne(MLIRContext *ctx) { @@ -1946,5 +1998,6 @@ void mlir::triton::linalg_ext::registerExtOpTilingInterfaceExternalModels( registerOne(ctx); registerOne(ctx); registerOne(ctx); + registerOne(ctx); }); } diff --git a/lib/Dialect/LinalgExt/Utils/CMakeLists.txt b/lib/Dialect/LinalgExt/Utils/CMakeLists.txt index f760616..384babc 100644 --- a/lib/Dialect/LinalgExt/Utils/CMakeLists.txt +++ b/lib/Dialect/LinalgExt/Utils/CMakeLists.txt @@ -1,6 +1,10 @@ add_triton_library(LinalgExtDialectUtils Utils.cpp + DEPENDS + TritonTableGen + LINK_LIBS PUBLIC MLIRIR + TritonIR ) diff --git a/lib/Dialect/LinalgExt/Utils/Utils.cpp b/lib/Dialect/LinalgExt/Utils/Utils.cpp index dd66c76..c8a7f48 100644 --- a/lib/Dialect/LinalgExt/Utils/Utils.cpp +++ b/lib/Dialect/LinalgExt/Utils/Utils.cpp @@ -11,15 +11,20 @@ #include #include +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include 
"mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "triton-linalg/Dialect/LinalgExt/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -27,6 +32,12 @@ #include "llvm/ADT/iterator.h" using namespace mlir; +using namespace triton; + +namespace mlir { +class Block; +class Operation; +} // namespace mlir Operation *triton::linalg_ext::findPayloadOp(Block *body, bool initFirst) { if (body->getOperations().size() != 2) @@ -56,3 +67,184 @@ Operation *triton::linalg_ext::findPayloadOp(Block *body, bool initFirst) { } return &payload; } + +/// Check whether the reduce op is supported and get the reduction mode +/// if supported. +std::optional triton::getReductionMode(triton::ReduceOp op) { + if (isSingleStatementReduceOpWithType(op) || + isSingleStatementReduceOpWithType(op)) + return ReductionMode::SUM; + + if (isSingleStatementReduceOpWithType( + op) || + isSingleStatementReduceOpWithType( + op) || + isSingleStatementReduceOpWithType(op)) + return ReductionMode::MAX; + + if (isSingleStatementReduceOpWithType(op)) + return ReductionMode::UMAX; + + if (isSingleStatementReduceOpWithType( + op) || + isSingleStatementReduceOpWithType( + op) || + isSingleStatementReduceOpWithType(op)) + return ReductionMode::MIN; + + if (isSingleStatementReduceOpWithType(op)) + return ReductionMode::UMIN; + + if (isSingleStatementReduceOpWithType(op)) + return ReductionMode::PROD; + + if (isSingleStatementReduceOpWithType(op)) + return ReductionMode::AND; + + if (isSingleStatementReduceOpWithType(op)) + return ReductionMode::OR; + + if (isSingleStatementReduceOpWithType(op)) + return ReductionMode::XOR; + // Unsupport reduce op mode. + return std::nullopt; +} + +/// Check whether the reduce op can convert to argmax/min operation. +std::optional triton::matchArgMaxMinPattern(Region *region) { + // We're looking for an op that looks like this: + // + // %9:2 = "tt.reduce"(%8, %3) <{axis = 0 : i32}> ({ + // ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + // ------------------------------------------------- + // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + // %12 = arith.cmpi slt, %arg10, %arg12 : i32 + // %13 = arith.andi %11, %12 : i1 + // ------------------------------------------------- + // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 + // ------------------------------------------------- + // %15 = arith.ori %14, %13 : i1 + // ------------------------------------------------- + // %16 = arith.select %15, %arg9, %arg11 : f32 + // %17 = arith.select %15, %arg10, %arg12 : i32 + // tt.reduce.return %16, %17 : f32, i32 + // }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) + if (region->getNumArguments() != 4) { + return std::nullopt; + } + + Block &block = region->front(); + // There are 8 fixed operations within the argmaxmin region. 
+  if (block.getOperations().size() != 8) {
+    return std::nullopt;
+  }
+  Operation *terminatorOp = block.getTerminator();
+
+  // %15 = arith.ori %14, %13 : i1
+  // %16 = arith.select %15, %arg9, %arg11 : f32
+  // linalg.yield %16, %17 : f32, i32
+  SmallVector lineOut0;
+  SmallVector inputIndex0 = {0, 0};
+  Operation *result0 =
+      UpstreamMatcher::matchLine(
+          lineOut0, terminatorOp, inputIndex0, inputIndex0.size(), false);
+  if (result0 == nullptr ||
+      cast<arith::SelectOp>(lineOut0[1]).getTrueValue() !=
+          block.getArgument(0) ||
+      cast<arith::SelectOp>(lineOut0[1]).getFalseValue() !=
+          block.getArgument(2)) {
+    return std::nullopt;
+  }
+
+  // %15 = arith.ori %14, %13 : i1
+  // %17 = arith.select %15, %arg10, %arg12 : i32
+  // linalg.yield %16, %17 : f32, i32
+  SmallVector lineOut1;
+  SmallVector inputIndex1 = {1, 0};
+  Operation *result1 =
+      UpstreamMatcher::matchLine(
+          lineOut1, terminatorOp, inputIndex1, inputIndex1.size(), false);
+  if (result1 == nullptr || lineOut1[2] != lineOut0[2] ||
+      cast<arith::SelectOp>(lineOut1[1]).getTrueValue() !=
+          block.getArgument(1) ||
+      cast<arith::SelectOp>(lineOut1[1]).getFalseValue() !=
+          block.getArgument(3)) {
+    return std::nullopt;
+  }
+
+  auto oriOp = lineOut0[2];
+  // %14 = arith.cmpf ogt, %arg9, %arg11 : f32
+  // %15 = arith.ori %14, %13 : i1
+  SmallVector lineOut2;
+  SmallVector inputIndex2 = {0};
+  Operation *result2 = UpstreamMatcher::matchLine(
+      lineOut2, oriOp, inputIndex2, inputIndex2.size(), false);
+  if (result2 == nullptr ||
+      cast<arith::CmpFOp>(lineOut2[1]).getLhs() != block.getArgument(0) ||
+      cast<arith::CmpFOp>(lineOut2[1]).getRhs() != block.getArgument(2)) {
+    return std::nullopt;
+  }
+
+  // %13 = arith.andi %11, %12 : i1
+  // %15 = arith.ori %14, %13 : i1
+  SmallVector lineOut3;
+  SmallVector inputIndex3 = {1};
+  Operation *result3 = UpstreamMatcher::matchLine(
+      lineOut3, oriOp, inputIndex3, inputIndex3.size(), false);
+  if (result3 == nullptr) {
+    return std::nullopt;
+  }
+
+  auto andiOp = lineOut3[1];
+  // %11 = arith.cmpf oeq, %arg9, %arg11 : f32
+  // %13 = arith.andi %11, %12 : i1
+  SmallVector lineOut4;
+  SmallVector inputIndex4 = {0};
+  Operation *result4 = UpstreamMatcher::matchLine(
+      lineOut4, andiOp, inputIndex4, inputIndex4.size(), false);
+  if (result4 == nullptr ||
+      cast<arith::CmpFOp>(lineOut4[1]).getPredicate() !=
+          arith::CmpFPredicate::OEQ ||
+      cast<arith::CmpFOp>(lineOut4[1]).getLhs() != block.getArgument(0) ||
+      cast<arith::CmpFOp>(lineOut4[1]).getRhs() != block.getArgument(2)) {
+    return std::nullopt;
+  }
+
+  // %12 = arith.cmpi slt, %arg10, %arg12 : i32
+  // %13 = arith.andi %11, %12 : i1
+  SmallVector lineOut5;
+  SmallVector inputIndex5 = {1};
+  Operation *result5 = UpstreamMatcher::matchLine(
+      lineOut5, andiOp, inputIndex5, inputIndex5.size(), false);
+  if (result5 == nullptr ||
+      cast<arith::CmpIOp>(lineOut5[1]).getPredicate() !=
+          arith::CmpIPredicate::slt ||
+      cast<arith::CmpIOp>(lineOut5[1]).getLhs() != block.getArgument(1) ||
+      cast<arith::CmpIOp>(lineOut5[1]).getRhs() != block.getArgument(3)) {
+    return std::nullopt;
+  }
+
+  auto cmpfOp = cast<arith::CmpFOp>(lineOut2[1]);
+  if (cmpfOp.getPredicate() == arith::CmpFPredicate::OGT) {
+    return ReductionMode::ARGMAX;
+  } else if (cmpfOp.getPredicate() == arith::CmpFPredicate::OLT) {
+    return ReductionMode::ARGMIN;
+  }
+
+  return std::nullopt;
+}
+
+/// Identify the pattern of the reduce operator.
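+/// Plain single-statement reductions are tried first, then the fixed
+/// argmax/argmin region form above. For example, a region such as
+///
+///   ^bb0(%a: f32, %b: f32):
+///     %s = arith.addf %a, %b : f32
+///     tt.reduce.return %s : f32
+///
+/// maps to ReductionMode::SUM.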
+std::optional +triton::reducePatternRecognition(triton::ReduceOp op) { + auto mode = getReductionMode(op); + if (mode.has_value()) { + return mode; + } + mode = matchArgMaxMinPattern(&op.getRegion()); + if (mode.has_value()) { + return mode; + } + + return std::nullopt; +} diff --git a/lib/Dialect/MathExt/IR/CMakeLists.txt b/lib/Dialect/MathExt/IR/CMakeLists.txt index 7157037..9bd4779 100644 --- a/lib/Dialect/MathExt/IR/CMakeLists.txt +++ b/lib/Dialect/MathExt/IR/CMakeLists.txt @@ -1,6 +1,6 @@ add_triton_library(MathExtDialect - MathOps.cpp - MathDialect.cpp + MathExtOps.cpp + MathExtDialect.cpp DEPENDS MathExtTableGen diff --git a/lib/Dialect/MathExt/IR/MathDialect.cpp b/lib/Dialect/MathExt/IR/MathExtDialect.cpp similarity index 88% rename from lib/Dialect/MathExt/IR/MathDialect.cpp rename to lib/Dialect/MathExt/IR/MathExtDialect.cpp index 850603d..d3a0b11 100644 --- a/lib/Dialect/MathExt/IR/MathDialect.cpp +++ b/lib/Dialect/MathExt/IR/MathExtDialect.cpp @@ -1,4 +1,4 @@ -//===- MathDialect.cpp - MLIR dialect for Math implementation ---*- C++ -*-===// +//===- MathExtDialect.cpp - Dialect for MathExt implementation --*- C++ -*-===// // // Copyright (C) [2022-2025] by Cambricon. // @@ -6,7 +6,7 @@ #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Transforms/InliningUtils.h" -#include "triton-linalg/Dialect/MathExt/IR/Math.h" +#include "triton-linalg/Dialect/MathExt/IR/MathExt.h" using namespace mlir; using namespace mlir::math_ext; diff --git a/lib/Dialect/MathExt/IR/MathOps.cpp b/lib/Dialect/MathExt/IR/MathExtOps.cpp similarity index 92% rename from lib/Dialect/MathExt/IR/MathOps.cpp rename to lib/Dialect/MathExt/IR/MathExtOps.cpp index 64ac1d3..a79d2e8 100644 --- a/lib/Dialect/MathExt/IR/MathOps.cpp +++ b/lib/Dialect/MathExt/IR/MathExtOps.cpp @@ -1,4 +1,4 @@ -//===- MathOps.cpp - MLIR operations for math implementation ----*- C++ -*-===// +//===- MathExtOps.cpp - Operations for MathExt implementation ---*- C++ -*-===// // // Copyright (C) [2022-2025] by Cambricon. 
//
@@ -8,7 +8,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
-#include "triton-linalg/Dialect/MathExt/IR/Math.h"
+#include "triton-linalg/Dialect/MathExt/IR/MathExt.h"
 #include <optional>
 
 using namespace mlir;
diff --git a/lib/Dialect/Triton/CMakeLists.txt b/lib/Dialect/Triton/CMakeLists.txt
index 5e95f61..7ef4547 100644
--- a/lib/Dialect/Triton/CMakeLists.txt
+++ b/lib/Dialect/Triton/CMakeLists.txt
@@ -1,3 +1,3 @@
-add_subdirectory(Interfaces)
 add_subdirectory(Transforms)
+add_subdirectory(Interfaces)
 add_subdirectory(Utils)
diff --git a/lib/Dialect/Triton/Interfaces/InferAxisInfoInterface.cpp b/lib/Dialect/Triton/Interfaces/InferAxisInfoInterface.cpp
index 0a24828..2098a78 100644
--- a/lib/Dialect/Triton/Interfaces/InferAxisInfoInterface.cpp
+++ b/lib/Dialect/Triton/Interfaces/InferAxisInfoInterface.cpp
@@ -32,6 +32,8 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "triton/Dialect/Triton/IR/Types.h"
+
 #include "triton-linalg/Dialect/Triton/Interfaces/InferAxisInfoInterface.cpp.inc"
 
 using namespace mlir;
@@ -57,12 +59,16 @@ AxisInfoExt AxisInfoExt::overrideByHint(Operation *op) const {
 AxisInfoExt AxisInfoExt::getPessimisticValueState(Value value) {
   auto rank = 1;
-  if (TensorType ty = value.getType().dyn_cast<TensorType>())
+  if (TensorType ty = dyn_cast<TensorType>(value.getType()))
     rank = ty.getRank();
 
+  if (triton::PointerType ty = dyn_cast<triton::PointerType>(value.getType()))
+    if (TensorType elemTy = dyn_cast<TensorType>(ty.getPointeeType()))
+      rank = elemTy.getRank();
+
   AxisInfoExt ret(DimVectorT(rank, kInitValue), DimVectorT(rank, kInitValue),
                   DimVectorT(rank, kStrideValueInitValue));
-  BlockArgument blockArg = value.dyn_cast<BlockArgument>();
+  BlockArgument blockArg = dyn_cast<BlockArgument>(value);
   if (!blockArg || !blockArg.getOwner()->isEntryBlock()) {
     return ret;
   }
@@ -82,9 +88,9 @@ AxisInfoExt AxisInfoExt::getPessimisticValueState(Value value) {
   // Initialize attributes one by one.
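+  // Each hint is either a single integer, broadcast to every dimension, or a
+  // dense array with one value per dimension (e.g. a kernel argument
+  // annotated with tt.divisibility = 16).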
for (auto [vec, attrName] : retVecs) { Attribute attr = func.getArgAttr(blockArg.getArgNumber(), attrName); - if (auto intAttr = attr.dyn_cast_or_null()) + if (auto intAttr = dyn_cast_or_null(attr)) *vec = AxisInfoExt::DimVectorT(rank, intAttr.getValue().getZExtValue()); - if (auto denseAttr = attr.dyn_cast_or_null()) { + if (auto denseAttr = dyn_cast_or_null(attr)) { auto vals = denseAttr.getValues(); *vec = AxisInfoExt::DimVectorT(vals.begin(), vals.end()); } @@ -121,9 +127,8 @@ AxisInfoExt AxisInfoExt::join(const AxisInfoExt &lhs, const AxisInfoExt &rhs) { DimVectorT stride(lhsRank, kInitValue); DimVectorT strideValue(lhsRank, kStrideValueInitValue); for (auto d = 0; d < lhsRank; ++d) { - divisibility[d] = - leastCommonMultiple(lhs.getDivisibility(d), rhs.getDivisibility(d)); - stride[d] = leastCommonMultiple(lhs.getStride(d), rhs.getStride(d)); + divisibility[d] = std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)); + stride[d] = std::gcd(lhs.getStride(d), rhs.getStride(d)); if (lhs.strideValue[d] != kStrideValueInitValue && rhs.strideValue[d] != kStrideValueInitValue && lhs.strideValue[d] == rhs.strideValue[d]) { @@ -163,18 +168,18 @@ triton::overrideAxisInfoByHint(Operation *op, AxisInfoExt::DimVectorT divisibility = knownDivisibility, stride = knownStride, strideValue = knownStrideValue; if (Attribute attr = op->getAttr("tt.divisibility")) { - auto vals = attr.cast().getValues(); + auto vals = cast(attr).getValues(); divisibility = AxisInfoExt::DimVectorT(vals.begin(), vals.end()); } if (Attribute attr = op->getAttr("tt.contiguity")) { - auto vals = attr.cast().getValues(); + auto vals = cast(attr).getValues(); stride = AxisInfoExt::DimVectorT(vals.begin(), vals.end()); strideValue = AxisInfoExt::DimVectorT(vals.size(), 1); } if (Attribute attr = op->getAttr("tt.constancy")) { assert(!op->getAttr("tt.contiguity") && "Get tt.constancy and tt.contiguity attribute at the same op"); - auto vals = attr.cast().getValues(); + auto vals = cast(attr).getValues(); stride = AxisInfoExt::DimVectorT(vals.begin(), vals.end()); strideValue = AxisInfoExt::DimVectorT(vals.size(), 0); } diff --git a/lib/Dialect/Triton/Transforms/CanonicalizeTriton.cpp b/lib/Dialect/Triton/Transforms/CanonicalizeTriton.cpp index 90fdb12..2805f45 100644 --- a/lib/Dialect/Triton/Transforms/CanonicalizeTriton.cpp +++ b/lib/Dialect/Triton/Transforms/CanonicalizeTriton.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "triton-linalg/Dialect/Triton/Transforms/PassDetail.h" // IWYU pragma: keep #include "triton-linalg/Dialect/Triton/Transforms/Passes.h" #include "triton-linalg/Dialect/Triton/Utils/MaskTracker.h" @@ -50,6 +52,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace triton; @@ -58,51 +61,168 @@ namespace mlir { class MLIRContext; } // namespace mlir -bool isAllOneSizeType(ShapedType inputType) { +static bool isAllOneSizeType(ShapedType inputType) { return inputType.getRank() == 0 || llvm::all_of(inputType.getShape(), [](int64_t size) { return size == int64_t(1); }); } -/// Get base mask value before broadcasting. 
-static std::optional<Value> getBaseMaskVal(Value maskVal) {
-  Value baseMaskVal = nullptr;
-  if (maskVal) {
-    // If the mask value is a 0-rank tensor or a tensor with all dimensions of
-    // one-size, return it.
-    // For example:
-    // - %mask : tensor<i1>
-    // - %mask = tensor<1x1xi1>
-    if (auto maskTy = maskVal.getType().dyn_cast<ShapedType>()) {
-      if (isAllOneSizeType(maskTy))
-        return maskVal;
+// Extract the scalar cond value from a mask.
+static Value extractCondFromShapeMask(Value mask, PatternRewriter &rewriter) {
+  if (!mask)
+    return mask;
+  auto loc = mask.getLoc();
+  if (auto maskValType = dyn_cast<ShapedType>(mask.getType())) {
+    assert(isAllOneSizeType(maskValType) ||
+           isa_and_nonnull<arith::ConstantOp>(mask.getDefiningOp()));
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> indices;
+    for (int64_t i = 0; i < maskValType.getRank(); ++i)
+      indices.push_back(zero);
+    return rewriter.create<tensor::ExtractOp>(loc, mask, ValueRange(indices));
+  }
+  return mask;
+}
+
+/// Get the cond value from a mask.
+static Value getCondVal(Value mask, PatternRewriter &rewriter) {
+  if (!mask)
+    return nullptr;
+
+  // If the mask value is a 0-rank tensor or a tensor with all dimensions of
+  // one-size, return it.
+  // For example:
+  // - %mask : tensor<i1>
+  // - %mask = tensor<1x1xi1>
+  if (auto maskTy = dyn_cast<ShapedType>(mask.getType())) {
+    if (isAllOneSizeType(maskTy))
+      return extractCondFromShapeMask(mask, rewriter);
+  }
+
+  // The origin mask must be obtained by broadcasting an i1.
+  auto defineOp = mask.getDefiningOp();
+  Value scalarMask = nullptr;
+  if (defineOp) {
+    llvm::TypeSwitch<Operation *>(defineOp)
+        // %mask = tt.splat i1 -> tensor<...xi1>
+        .Case(
+            [&](triton::SplatOp splatOp) { scalarMask = splatOp.getSrc(); })
+        // %mask = tt.broadcast tensor<i1> -> tensor<...xi1>
+        // %mask = tt.broadcast tensor<1x1xi1> -> tensor<...xi1>
+        .Case([&](triton::BroadcastOp broadcastOp) {
+          Value src = broadcastOp.getSrc();
+          if (isAllOneSizeType(cast<ShapedType>(src.getType()))) {
+            scalarMask = src;
+          } else {
+            scalarMask = getCondVal(src, rewriter);
+          }
+        })
+        // %mask = arith.constant dense<true> : tensor<...xi1>
+        .Case([&](arith::ConstantOp constantOp) {
+          auto value = dyn_cast<DenseElementsAttr>(constantOp.getValue());
+          if (value && value.isSplat())
+            scalarMask = constantOp.getResult();
+        });
+  }
+  return extractCondFromShapeMask(scalarMask, rewriter);
+}
+
+/// Describes the mask info extracted from a mask value.
+struct MaskInfo {
+  enum LogicalType { AND, OR, NONE };
+  Value cond;              ///< Scalar i1 condition recovered from the parts
+                           ///< of the mask that are splat/broadcast from i1.
+  Value mask;              ///< The remaining mask (tensor or i1).
+  LogicalType logicalType; ///< How cond and mask are combined.
+};
+
+/// Build the MaskInfo for a mask, separating the parts that come from a splat
+/// scalar i1 (or a broadcast tensor<1xi1>) from the remaining parts.
+///
+/// Example:
+///
+/// ```mlir
+/// %mask0 = ...
+/// %mask1 = tt.splat %scalar
+/// %mask2 = tt.broadcast %tensor0d
+/// %mask01 = arith.andi %mask0, %mask1
+/// %mask = arith.andi %mask01, %mask2
+/// ```
+///
+/// Given %mask, the returned MaskInfo holds the %cond and %mask below:
+///
+/// ```mlir
+/// %extract = tensor.extract %tensor0d
+/// %cond = arith.andi %scalar, %extract
+/// %mask = %mask0
+/// ```
+MaskInfo getMaskInfo(Value mask, PatternRewriter &rewriter) {
+  MaskInfo maskInfo = {nullptr, nullptr, MaskInfo::LogicalType::NONE};
+  if (!mask)
+    return maskInfo;
+
+  auto loc = mask.getLoc();
+
+  using LogicalType = MaskInfo::LogicalType;
+  auto createMasks = [&rewriter, loc](ArrayRef<Value> masks,
+                                      LogicalType type) -> Value {
+    if (masks.empty())
+      return nullptr;
+    auto size = masks.size();
+    Value ret = masks[0];
+    for (size_t i = 1; i < size; i++) {
+      if (type == LogicalType::AND) {
+        ret = rewriter.create<arith::AndIOp>(loc, ret, masks[i]);
+      } else {
+        assert(type == LogicalType::OR);
+        ret = rewriter.create<arith::OrIOp>(loc, ret, masks[i]);
+      }
+    }
+
+    return ret;
+  };
+
+  SmallVector<Value> condMasks, otherMasks;
+  std::queue<Value> worklist;
+  worklist.push(mask);
+  LogicalType logicalType = LogicalType::NONE;
+  while (!worklist.empty()) {
+    auto item = worklist.front();
+    worklist.pop();
+    auto defOp = item.getDefiningOp();
+    if (llvm::isa_and_nonnull<arith::AndIOp, arith::OrIOp>(defOp)) {
+      worklist.push(defOp->getOperand(0));
+      worklist.push(defOp->getOperand(1));
+      LogicalType currentLogicalType =
+          llvm::isa<arith::AndIOp>(defOp) ? LogicalType::AND : LogicalType::OR;
+      if (logicalType == LogicalType::NONE) {
+        logicalType = currentLogicalType;
+      } else if (logicalType != currentLogicalType) {
+        // TODO: add support for mixed logical types.
+        return maskInfo;
+      }
+      continue;
+    }
-      // The origin mask must be obtained by broadcasting an i1.
-      auto defineOp = maskVal.getDefiningOp();
-      if (defineOp) {
-        llvm::TypeSwitch<Operation *>(defineOp)
-            // %mask = tt.splat i1 -> tensor<...xi1>
-            .Case(
-                [&](triton::SplatOp splatOp) { baseMaskVal = splatOp.getSrc(); })
-            // %mask = tt.broadcast tensor<i1> -> tensor<...xi1>
-            // %mask = tt.broadcast tensor<1x1xi1> -> tensor<...xi1>
-            .Case([&](triton::BroadcastOp broadcastOp) {
-              Value src = broadcastOp.getSrc();
-              if (isAllOneSizeType(src.getType().cast<ShapedType>()))
-                baseMaskVal = src;
-            })
-            // %mask = arith.constant dense<true> : tensor<...xi1>
-            .Case([&](arith::ConstantOp constantOp) {
-              auto value = constantOp.getValue().dyn_cast<DenseElementsAttr>();
-              if (value && value.isSplat())
-                baseMaskVal = constantOp.getResult();
-            });
+    Value cond = getCondVal(item, rewriter);
+    if (cond) {
+      condMasks.push_back(cond);
+      continue;
+    }
+
+    otherMasks.push_back(item);
   }
-  if (baseMaskVal)
-    return baseMaskVal;
-  return std::nullopt;
+  if (condMasks.empty())
+    return maskInfo;
+
+  // Only has cond, so set the logical type to AND.
+  if (otherMasks.empty())
+    logicalType = LogicalType::AND;
+
+  maskInfo.logicalType = logicalType;
+  maskInfo.cond = createMasks(condMasks, logicalType);
+  maskInfo.mask = createMasks(otherMasks, logicalType);
+
+  return maskInfo;
 }
 
 namespace {
@@ -252,8 +372,8 @@ class CanonicalizeTtBroadCastPattern
       return success();
     }
 
-    ShapedType inputShape = op.getSrc().getType().cast<ShapedType>();
-    ShapedType outputShape = op.getResult().getType().cast<ShapedType>();
+    ShapedType inputShape = cast<ShapedType>(op.getSrc().getType());
+    ShapedType outputShape = cast<ShapedType>(op.getResult().getType());
 
     // Meet case 3 described above, so do nothing but return.
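+    // Equal ranks mean there are no unit dims to expand, so this pattern does
+    // not apply.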
     if (inputShape.getRank() == outputShape.getRank()) {
       return failure();
     }
@@ -268,7 +388,7 @@ class CanonicalizeTtBroadCastPattern
       expanded = rewriter.create<triton::ExpandDimsOp>(op.getLoc(), expanded, 0);
     }
 
-    if (llvm::equal(expanded.getType().cast<ShapedType>().getShape(),
+    if (llvm::equal(cast<ShapedType>(expanded.getType()).getShape(),
                     outputShapeArr)) {
       rewriter.replaceOp(op, expanded);
     } else {
@@ -280,8 +400,8 @@ class CanonicalizeTtBroadCastPattern
 
 private:
   bool isScalarOr0dTensor(const Value &val) const {
-    return !val.getType().isa<ShapedType>() ||
-           val.getType().cast<ShapedType>().getRank() == 0;
+    return !isa<ShapedType>(val.getType()) ||
+           cast<ShapedType>(val.getType()).getRank() == 0;
   }
 
   /// Returns how many axes should be expanded to input
@@ -395,24 +515,44 @@ class CanonicalizeTtLoadPattern : public OpRewritePattern {
 };
 
 /// This pattern applies to memory access operations like `tt.load` and
-/// `tt.store`. If their mask is created through broadcasting an `i1`, it
-/// results in subsequent conversion to scalar memory access operations. This
-/// pattern converts memory access operations with this specific mask into
-/// `scf.if` + the memory access operation(without mask).
+/// `tt.store`. If their mask is created through a scalar splat of `i1`, or
+/// feeds an `arith.andi` or `arith.ori`, it would otherwise be lowered to
+/// discrete memory access operations. This pattern converts memory access
+/// operations with such a mask into `scf.if` + the memory access operation
+/// (without the splat `i1` part of the mask).
+///
+/// Case 1: if the mask is created by `arith.andi`, it converts to:
+/// ```
+/// if (%cond) {
+///   yield tt.load
+/// } else {
+///   yield other operand from tt.load op
+/// }
+/// ```
+///
+/// Case 2: if the mask is created by `arith.ori`, it converts to:
+/// ```
+/// if (%cond) {
+///   yield other operand from tt.load op
+/// } else {
+///   yield tt.load
+/// }
+/// ```
 ///
 /// Example 1:
 ///
 /// ``` mlir
-/// %other = arith.constant dense<0.000000e+00> : tensor<1x1024xf32>
-/// %mask = tt.splat %bool : i1 -> tensor<1x1024xi1>
-/// %res = tt.load %ptr, %mask, %other : tensor<1x1024x!tt.ptr<f32>>
+/// %mask0 = tt.splat %bool : i1 -> tensor<1x1024xi1>
+/// %mask1 = ...
+/// %mask = arith.andi %mask0, %mask1
+/// %res = tt.load %ptr, %mask : tensor<1x1024x!tt.ptr<f32>>
 /// tt.return %res : tensor<1x1024xf32>
 /// ```
 /// is converted to:
 /// ``` mlir
 /// %constant = arith.constant dense<0.000000e+00> : tensor<1x1024xf32>
 /// %res = scf.if (%bool) -> tensor<1x1024xf32> {
-///   %load = tt.load %ptr : tensor<1x1024x!tt.ptr<f32>>
+///   %load = tt.load %ptr, %mask1 : tensor<1x1024x!tt.ptr<f32>>
 ///   scf.yield %load : tensor<1x1024xf32>
 /// } else {
 ///   scf.yield %constant : tensor<1x1024xf32>
@@ -423,6 +563,28 @@ class CanonicalizeTtLoadPattern : public OpRewritePattern {
 /// Example 2:
 ///
 /// ``` mlir
+/// %other = arith.constant dense<1.000000e+00> : tensor<1x1024xf32>
+/// %mask0 = tt.splat %bool : i1 -> tensor<1x1024xi1>
+/// %mask1 = ...
+/// %mask = arith.andi %mask0, %mask1
+/// %res = tt.load %ptr, %mask, %other : tensor<1x1024x!tt.ptr<f32>>
+/// tt.return %res : tensor<1x1024xf32>
+/// ```
+/// is converted to:
+/// ``` mlir
+/// %other = arith.constant dense<1.000000e+00> : tensor<1x1024xf32>
+/// %res = scf.if (%bool) -> tensor<1x1024xf32> {
+///   %load = tt.load %ptr, %mask1, %other : tensor<1x1024x!tt.ptr<f32>>
+///   scf.yield %load : tensor<1x1024xf32>
+/// } else {
+///   scf.yield %other : tensor<1x1024xf32>
+/// }
+/// tt.return %res : tensor<1x1024xf32>
+/// ```
+///
+/// Example 3:
+///
+/// ``` mlir
 /// %mask = tt.splat %bool : i1 -> tensor<1x1024xi1>
 /// tt.store %ptr, %val, %mask : tensor<1x1024x!tt.ptr<f32>>
 /// ```
 /// is converted to:
 /// ``` mlir
 /// scf.if (%bool) {
 ///   tt.store %ptr, %val : tensor<1x1024x!tt.ptr<f32>>
 /// }
 /// ```
 ///
-/// Example 3:
+/// Example 4:
 ///
 /// ``` mlir
 /// %mask = tt.splat %bool : i1 -> tensor<64x64xi1>
@@ -464,46 +626,61 @@ class CanonicalizeTtMaskAccessPattern : public OpRewritePattern {
           llvm::isa<triton::StoreOp>(op)))
       return failure();
 
-    std::optional<Value> baseMaskVal = getBaseMaskVal(op.getMask());
-    if (!baseMaskVal.has_value())
+    MaskInfo maskInfo = getMaskInfo(op.getMask(), rewriter);
+
+    if (!maskInfo.cond)
       return failure();
+    assert(maskInfo.logicalType != MaskInfo::LogicalType::NONE);
 
     auto loc = op->getLoc();
     auto resTypes = op->getResultTypes();
-    // Create scf.if op.
-    Value condVal = baseMaskVal.value();
-    if (auto condValType = condVal.getType().dyn_cast<ShapedType>()) {
-      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-      SmallVector<Value> indices;
-      for (int64_t i = 0; i < condValType.getRank(); ++i)
-        indices.push_back(zero);
-      condVal =
-          rewriter.create<tensor::ExtractOp>(loc, condVal, ValueRange(indices));
-    }
     scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc,
                                                 /*resultTypes*/ resTypes,
-                                                /*cond*/ condVal,
+                                                /*cond*/ maskInfo.cond,
                                                 /*addElseBlock*/ true);
-    // Create new op in then region.
-    // Then region: new op(without mask) + yield(if new op has result)
-    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    op.getMaskMutable().clear();
+    // Create the new op in the then region if the logical type is AND, in the
+    // else region otherwise.
+    // Then region: new op(new mask) + yield (if the new op has results).
+    if (maskInfo.logicalType == MaskInfo::LogicalType::AND) {
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    } else {
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    }
+    if (maskInfo.mask) {
+      op.getMaskMutable().assign(maskInfo.mask);
+    } else {
+      op.getMaskMutable().clear();
+    }
     auto newOp = rewriter.clone(*op);
-    if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(*newOp))
-      loadOp.getOtherMutable().clear();
+    if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(*newOp)) {
+      // If the load op no longer has a mask, drop its other operand too.
+      if (op.getMaskMutable().empty())
+        loadOp.getOtherMutable().clear();
+    }
 
     if (resTypes.empty()) {
      rewriter.eraseOp(op);
    } else {
      rewriter.create<scf::YieldOp>(loc, newOp->getResults());
      // If the else region is empty, it will be folded during canonicalization.
-      // Else region: constant(0) + yield.
-      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      // Else region: constant(0) or the other value + yield, if the logical
+      // type is AND.
+      if (maskInfo.logicalType == MaskInfo::LogicalType::AND) {
+        rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      } else {
+        rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      }
      // If the processed op has results, it will have only one result.
-      Value zeroVal = rewriter.create<arith::ConstantOp>(
-          loc, resTypes.front(), rewriter.getZeroAttr(resTypes.front()));
-      rewriter.create<scf::YieldOp>(loc, zeroVal);
+      Value yieldVal = nullptr;
+      // If the load op has its other operand set, yield that value.
+      auto loadOp = llvm::dyn_cast<triton::LoadOp>(op.getOperation());
+      if (loadOp && loadOp.getOther()) {
+        yieldVal = loadOp.getOther();
+      } else {
+        yieldVal = rewriter.create<arith::ConstantOp>(
+            loc, resTypes.front(), rewriter.getZeroAttr(resTypes.front()));
+      }
+
+      rewriter.create<scf::YieldOp>(loc, yieldVal);
       rewriter.replaceOp(op, ifOp);
     }
 
@@ -511,9 +688,58 @@
   }
 };
 
+class CanonicalizeTtAssertPattern : public OpRewritePattern<triton::AssertOp> {
+public:
+  explicit CanonicalizeTtAssertPattern(MLIRContext *ctx)
+      : OpRewritePattern<triton::AssertOp>(ctx) {}
+
+  LogicalResult matchAndRewrite(triton::AssertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto condVal = op.getCondition();
+    auto valType = condVal.getType();
+    auto rank = valType.getRank();
+    auto assertMessage =
+        llvm::formatv("{0}:{1}: {2} Assertion `{3}` failed", op.getFile(),
+                      op.getLine(), op.getFunc(), op.getMessage());
+    assert(isa<IntegerType>(valType.getElementType()) &&
+           "Only support int tensor for assert");
+    // If the AssertOp input is rank 0, or rank 1 with its 0th dimension equal
+    // to 1, it is converted to a ScalarAssertOp.
+    if ((rank != 1 && rank != 0) || (rank > 0 && valType.getShape()[0] != 1)) {
+      return failure();
+    }
+    auto rankType = cast<RankedTensorType>(valType);
+    auto elemType = rankType.getElementType();
+    Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    Value cond;
+    if (rank == 0) {
+      cond =
+          rewriter.create<tensor::ExtractOp>(op.getLoc(), condVal).getResult();
+    } else {
+      cond = rewriter
+                 .create<tensor::ExtractOp>(op.getLoc(), condVal, zeroIndex)
+                 .getResult();
+    }
+    // If the input data type is not i1, cast it to i1.
+    if (!elemType.isInteger(1)) {
+      Value zero = rewriter.create<arith::ConstantOp>(
+          op.getLoc(), rewriter.getIntegerAttr(elemType, 0));
+      cond = rewriter.create<arith::CmpIOp>(
+          op.getLoc(), arith::CmpIPredicate::ne, cond, zero);
+    }
+    rewriter.create<triton::linalg_ext::ScalarAssertOp>(op.getLoc(), cond,
+                                                        assertMessage.str());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct CanonicalizeTritonPass
     : public CanonicalizeTritonBase<CanonicalizeTritonPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<triton::linalg_ext::LinalgExtDialect>();
+  }
   void runOnOperation() override {
     MLIRContext &ctx = getContext();
     // Canonicalize mask-related ops and their connections.
@@ -530,7 +756,8 @@ struct CanonicalizeTritonPass
     patterns.insert<CanonicalizeTtBroadCastPattern,
         CanonicalizeTtMaskAccessPattern<triton::LoadOp>,
-        CanonicalizeTtMaskAccessPattern<triton::StoreOp>>(&ctx);
+        CanonicalizeTtMaskAccessPattern<triton::StoreOp>,
+        CanonicalizeTtAssertPattern>(&ctx);
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
diff --git a/lib/Dialect/Triton/Transforms/ExtractMoveBackward.cpp b/lib/Dialect/Triton/Transforms/ExtractMoveBackward.cpp
index de0a22e..10c77e9 100644
--- a/lib/Dialect/Triton/Transforms/ExtractMoveBackward.cpp
+++ b/lib/Dialect/Triton/Transforms/ExtractMoveBackward.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
@@ -105,7 +106,7 @@ static inline bool isOutputRankReduced(tensor::ExtractSliceOp op) {
 /// 2. ofr is a value and defined by tensor.dim, with the index `dim` and
 ///    source `value`.
 static bool hasSameSizeWithDim(OpFoldResult ofr, Value value, int64_t dim) {
-  auto type = value.getType().dyn_cast<ShapedType>();
+  auto type = mlir::dyn_cast<ShapedType>(value.getType());
   assert(type && dim >= 0 && dim < type.getRank() &&
          "Expected value with ShapedType and dim is a valid axis index.");
@@ -116,7 +117,7 @@ static bool hasSameSizeWithDim(OpFoldResult ofr, Value value, int64_t dim) {
 
   // Check whether ofr is defined by a tensor.dim, with the index `dim` and
   // source `value`.
-  auto ofrValue = ofr.dyn_cast<Value>();
+  auto ofrValue = mlir::dyn_cast<Value>(ofr);
   auto dimOp = ofrValue ? ofrValue.getDefiningOp<tensor::DimOp>() : nullptr;
   if (!dimOp || dimOp.getSource() != value)
     return false;
@@ -124,7 +125,7 @@ static bool hasSameSizeWithDim(OpFoldResult ofr, Value value, int64_t dim) {
   Value index = dimOp.getIndex();
   auto constantOp = index.getDefiningOp<arith::ConstantOp>();
   return constantOp &&
-         constantOp.getValue().cast<IntegerAttr>().getInt() == dim;
+         mlir::cast<IntegerAttr>(constantOp.getValue()).getInt() == dim;
 }
 
 /// Reshape input to resultType by adding unit dims.
@@ -139,7 +140,7 @@ static bool hasSameSizeWithDim(OpFoldResult ofr, Value value, int64_t dim) {
 static Value expandShapeToResultTypeByAddUnitDims(OpBuilder &b, Location loc,
                                                   ShapedType resultType,
                                                   Value value) {
-  auto sourceType = value.getType().template cast<ShapedType>();
+  auto sourceType = mlir::cast<ShapedType>(value.getType());
   int64_t dstRank = resultType.getRank();
   int64_t srcRank = sourceType.getRank();
   int64_t rankDiff = dstRank - srcRank;
@@ -194,7 +195,7 @@ static Value expandShapeToResultTypeByAddUnitDims(OpBuilder &b, Location loc,
 static Value reshapeToResultTypeByDropUnitDims(OpBuilder &b, Location loc,
                                                ShapedType resultType,
                                                Value value) {
-  auto sourceType = value.getType().template cast<ShapedType>();
+  auto sourceType = mlir::cast<ShapedType>(value.getType());
   int64_t dstRank = resultType.getRank();
   int64_t srcRank = sourceType.getRank();
   int64_t rankDiff = srcRank - dstRank;
@@ -233,7 +234,8 @@ static Value reshapeToResultTypeByDropUnitDims(OpBuilder &b, Location loc,
   reassociation.front().insert(reassociation.front().begin(),
                                leadingInconsistentDims.begin(),
                                leadingInconsistentDims.end());
-
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfterValue(value);
   return b.create<tensor::CollapseShapeOp>(loc, value, reassociation);
 }
 
@@ -355,10 +357,73 @@ ExtractState::ExtractState(tensor::ExtractSliceOp op)
 
 } // anonymous namespace
 
+/// Get the insertion point by analysing the values that will be used.
+static OpBuilder::InsertPoint getInsertionPoint(ArrayRef<Value> values,
+                                                ExtractState &state,
+                                                PatternRewriter &rewriter) {
+  // Collect all related values.
+  SmallVector<OpFoldResult> opFoldResults;
+  opFoldResults.append(state.offsets);
+  opFoldResults.append(state.sizes);
+  opFoldResults.append(state.strides);
+  std::pair<SmallVector<int64_t>, SmallVector<Value>> attrOrVals =
+      decomposeMixedValues(opFoldResults);
+  SmallVector<Value> dependentVals = attrOrVals.second;
+  dependentVals.append(SmallVector<Value>(values));
+
+  // Create dominance info for the current func op.
+  auto funcOp = dependentVals.front()
+                    .getParentRegion()
+                    ->getParentOfType();
+  DominanceInfo domInfo(funcOp);
+  // Divide the values by their blocks.
+  DenseMap<Block *, SmallVector<Value>> blockVals;
+  for (auto val : dependentVals) {
+    Block *block = val.getParentBlock();
+    if (blockVals.count(block)) {
+      blockVals[block].emplace_back(val);
+    } else {
+      blockVals.insert({block, SmallVector<Value>{val}});
+    }
+  }
+  // Find the innermost block according to the dominance info. Note: all the
+  // related values are used to create operations, so their defining blocks
+  // lie on the same branch of the dominance tree; global values are not
+  // considered.
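+  // The blocks form a chain in the dominator tree, so a pairwise scan is
+  // enough to find the block that is dominated by all the others.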
+ Block *innerBlock = blockVals.begin()->first; + for (auto &[key, value] : blockVals) { + if (domInfo.dominates(innerBlock, key)) { + innerBlock = key; + } + } + // Find the last value in the innermost block. + SmallVector innerVals = blockVals[innerBlock]; + SmallVector ops; + SmallVector args; + for (auto val : innerVals) { + auto op = val.getDefiningOp(); + if (op) { + ops.emplace_back(op); + } else { + args.emplace_back(val); + } + } + // If op exists, find the last one as the insertion point. + if (!ops.empty()) { + llvm::sort( + ops, [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); }); + return OpBuilder::InsertPoint(innerBlock, ++Block::iterator(ops.back())); + } + // Otherwise, return one argument as the insertion point. + return OpBuilder::InsertPoint( + innerBlock, mlir::cast(args.front()).getOwner()->begin()); +} + static void getExtractedValueFrom(Value value, ExtractState &state, Location loc, PatternRewriter &rewriter) { if (state.extractedVal) return; + rewriter.restoreInsertionPoint(getInsertionPoint({value}, state, rewriter)); // Get value by tensor.extract. if (state.type == ExtractType::EXTRACT) { state.extractedVal = rewriter.create( @@ -481,6 +546,10 @@ void ExtractAnalysis::visitOperandFromOp(linalg::MapOp op, ExtractState &state, // Retrieve extracted operands. SmallVector foldOperands(llvm::map_range( operandStates, [](const ExtractState &s) { return s.extractedVal; })); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint(foldOperands, state, rewriter)); if (state.type == ExtractType::EXTRACT) { // Map operands to extracted operands. IRMapping bvm; @@ -493,7 +562,7 @@ void ExtractAnalysis::visitOperandFromOp(linalg::MapOp op, ExtractState &state, // Clone map body to generate extracted map result. 
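+    // Clone every payload op except the terminator, remapping operands to the
+    // extracted scalars recorded in `bvm`.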
for (auto &payload : body->getOperations()) { - if (!isa(payload)) + if (!mlir::isa(payload)) rewriter.clone(payload, bvm); } auto &yieldOp = body->getOperations().back(); @@ -516,13 +585,16 @@ template <> void ExtractAnalysis::visitOperandFromOp(linalg::FillOp op, ExtractState &state, Location loc, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op.getResult(0)}, state, rewriter)); if (state.type == ExtractType::EXTRACT) { state.extractedVal = op.getOperand(0); return; } Value output = rewriter.create( loc, state.sizes, - op.getResult(0).getType().template cast().getElementType()); + mlir::cast(op.getResult(0).getType()).getElementType()); state.extractedVal = rewriter.create(loc, op.getOperand(0), output) .getResult(0); @@ -534,6 +606,9 @@ void ExtractAnalysis::visitOperandFromOp(linalg_ext::MakeRangeOp op, PatternRewriter &rewriter) { assert(state.offsets.size() == 1 && "Offset size must be 1 for make_range op."); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op.getResult(0)}, state, rewriter)); if (state.type == ExtractType::EXTRACT) { Value offset = rewriter.create( loc, op.getStart().getType(), @@ -575,6 +650,9 @@ void ExtractAnalysis::visitOperandFromOp(linalg::BroadcastOp op, visitOperand(op.getInput(), operandState, op, loc, rewriter); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({operandState.extractedVal}, state, rewriter)); if (state.type == ExtractType::EXTRACT) state.extractedVal = operandState.extractedVal; else { @@ -591,6 +669,9 @@ void ExtractAnalysis::visitOperandFromOp(linalg::BroadcastOp op, static void extractFromCollapseShapeOp(tensor::CollapseShapeOp op, ExtractState &state, Location loc, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op->getResult(0)}, state, rewriter)); Value collapseSrc = op.getSrc(); auto reassociationIndices = op.getReassociationIndices(); int64_t resRank = op.getResultType().getRank(); @@ -618,7 +699,7 @@ static void extractFromCollapseShapeOp(tensor::CollapseShapeOp op, if (resRank == 0) { Value c0 = rewriter.createOrFold(loc, rewriter.getIndexAttr(0)); - auto srcRank = collapseSrc.getType().cast().getRank(); + auto srcRank = mlir::cast(collapseSrc.getType()).getRank(); operandState.offsets.append(srcRank, c0); } @@ -634,6 +715,9 @@ static void extractFromCollapseShapeOp(tensor::CollapseShapeOp op, static void extractSliceFromCollapseShapeOp(tensor::CollapseShapeOp op, ExtractState &state, Location loc, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op->getResult(0)}, state, rewriter)); Value collapseSrc = op.getSrc(); Value collapseDst = op.getResult(); auto reassociationIndices = op.getReassociationIndices(); @@ -687,10 +771,11 @@ static void extractSliceFromCollapseShapeOp(tensor::CollapseShapeOp op, operandState.strides = SmallVector(srcRank, indexAttrOne); } - auto srcTy = op.getResultType().cast(); + auto srcTy = mlir::cast(op.getResultType()); auto resultTy = tensor::ExtractSliceOp::inferResultType( srcTy, state.offsets, state.sizes, state.strides); ExtractAnalysis::visitOperand(collapseSrc, operandState, op, loc, rewriter); + rewriter.setInsertionPointAfterValue(operandState.extractedVal); state.extractedVal = rewriter.create( loc, resultTy, operandState.extractedVal, 
reassociationIndices); } @@ -708,6 +793,9 @@ void ExtractAnalysis::visitOperandFromOp(tensor::CollapseShapeOp op, static void extractFromExpandShapeOp(tensor::ExpandShapeOp op, ExtractState &state, Location loc, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op->getResult(0)}, state, rewriter)); Value expandSrc = op.getSrc(); Value expandDst = op.getResult(); int64_t srcRank = op.getSrcType().getRank(); @@ -735,7 +823,6 @@ static void extractFromExpandShapeOp(tensor::ExpandShapeOp op, } dstDimIdx += reassociationIndices[srcDimIdx].size(); } - ExtractAnalysis::visitOperand(expandSrc, operandState, op, loc, rewriter); state.extractedVal = operandState.extractedVal; } @@ -748,13 +835,16 @@ static void extractFromExpandShapeOp(tensor::ExpandShapeOp op, static void extractSliceFromExpandShapeOp(tensor::ExpandShapeOp op, ExtractState &state, Location loc, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op->getResult(0)}, state, rewriter)); Value expandSrc = op.getSrc(); Value expandDst = op.getResult(); int64_t srcRank = op.getSrcType().getRank(); auto reassociationIndices = op.getReassociationIndices(); ExtractState operandState; operandState.type = state.type; - auto dstType = expandDst.getType().cast(); + auto dstType = mlir::cast(expandDst.getType()); auto indexAttrZero = rewriter.getIndexAttr(0); auto indexAttrOne = rewriter.getIndexAttr(1); for (auto srcDimIdx : llvm::seq(0, srcRank)) { @@ -805,11 +895,11 @@ static void extractSliceFromExpandShapeOp(tensor::ExpandShapeOp op, getExtractedValueFrom(expandDst, state, loc, rewriter); return; } - - auto srcTy = op.getResultType().cast(); + auto srcTy = mlir::cast(op.getResultType()); auto resultTy = tensor::ExtractSliceOp::inferResultType( srcTy, state.offsets, state.sizes, state.strides); ExtractAnalysis::visitOperand(expandSrc, operandState, op, loc, rewriter); + rewriter.setInsertionPointAfterValue(operandState.extractedVal); state.extractedVal = rewriter.create( loc, resultTy, operandState.extractedVal, reassociationIndices); } @@ -830,9 +920,9 @@ void ExtractAnalysis::visitOperandFromOp(Operation *op, ExtractState &state, PatternRewriter &rewriter) { auto *dialect = op->getDialect(); (void)dialect; - assert( - (isa(dialect) || isa(dialect)) && - "unregister operations in extact analysis for now."); + assert((mlir::isa(dialect) || + mlir::isa(dialect)) && + "unregistered operations in extract analysis for now."); SmallVector operandStates; for (Value v : op->getOperands()) { @@ -841,16 +931,20 @@ void ExtractAnalysis::visitOperandFromOp(Operation *op, ExtractState &state, operandStates.push_back(operandState); } + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + getInsertionPoint({op->getResult(0)}, state, rewriter)); + SmallVector foldOperands(llvm::map_range( operandStates, [](const ExtractState &s) { return s.extractedVal; })); // Since `arith.constant` uses attribute to represent value, we can not // use operation identifier to update it directly. Here, we utilize the // fold methods of extract-like operations to eliminate constant operations. - if (isa(op)) + if (mlir::isa(op)) return getExtractedValueFrom(op->getResult(0), state, loc, rewriter); - auto resultTy = op->getResult(0).getType().template cast(); + auto resultTy = mlir::cast(op->getResult(0).getType()); Type newResTy = (state.type == ExtractType::EXTRACT) ?
resultTy.getElementType() @@ -872,8 +966,8 @@ void ExtractAnalysis::visitOperand(Value operand, ExtractState &state, return getExtractedValueFrom(operand, state, loc, rewriter); auto *dialect = opInst->getDialect(); - if (isa(dialect) || - isa(dialect)) { + if (mlir::isa(dialect) || + mlir::isa(dialect)) { return visitOperandFromOp(opInst, state, loc, rewriter); } @@ -907,10 +1001,6 @@ LogicalResult ExtractAnalysis::rewriteExtractLikeOp(OpTy op, if (!isOutputRankReduced(op)) return failure(); - OpBuilder::InsertionGuard guard(rewriter); - // Any inserted instruction should be before this extract operation. - rewriter.setInsertionPoint(op); - // Set as the flag for the original tensor.extract operation. ExtractState state{op}; state.extractedVal = op.getResult(); @@ -943,7 +1033,7 @@ static void eliminateDeadExpressionsFrom(Value value, auto val = candidates.front(); candidates.pop(); - if (val.isa()) + if (mlir::isa(val)) continue; auto *defOp = val.getDefiningOp(); @@ -965,6 +1055,7 @@ static SetVector getCandidateExtractLikeOps(scf::ForOp forOp, Block *loopBody = &forOp.getRegion().front(); auto arg = loopBody->getArgument(forOp.getNumInductionVars() + iterIndex); auto loopLikeOpInterface = cast(forOp.getOperation()); + moveLoopInvariantCode(loopLikeOpInterface); SetVector candidates; for (auto *op : arg.getUsers()) { @@ -1005,7 +1096,8 @@ static LogicalResult extractIterArgPrecondition(scf::ForOp forOp, unsigned iterIndex, PatternRewriter &rewriter) { // Move loop invariant code ahead. - auto loopLikeOpInterface = cast(forOp.getOperation()); + auto loopLikeOpInterface = + mlir::cast(forOp.getOperation()); moveLoopInvariantCode(loopLikeOpInterface); Block *loopBody = &forOp.getRegion().front(); @@ -1043,7 +1135,7 @@ static LogicalResult extractIterArgPrecondition(scf::ForOp forOp, SetVector forwardSlice; ForwardSliceOptions forwardSliceOptions; forwardSliceOptions.filter = [&forOp](Operation *op) { - return !isa(op) && !isa(op) && + return !mlir::isa(op) && !mlir::isa(op) && op->getParentOp() == forOp.getOperation() && forOp->isProperAncestor(op); }; @@ -1070,7 +1162,7 @@ static LogicalResult extractIterArgPrecondition(scf::ForOp forOp, SmallVector states; states.reserve(extractNum); for (size_t index = 0; index < extractNum; ++index) { - auto extractLikeOp = cast(extractCandidates[index]); + auto extractLikeOp = mlir::cast(extractCandidates[index]); auto operand = yieldOp->getOperand(iterIndex); ExtractState state{extractLikeOp}; state.extractedVal = nullptr; @@ -1091,15 +1183,16 @@ static LogicalResult extractIterArgPrecondition(scf::ForOp forOp, if (loopLikeOpInterface.isDefinedOutsideOfLoop(currVal)) continue; - if (currVal.isa() && - currVal.dyn_cast().getOwner()->getParentOp() == forOp) + if (mlir::isa(currVal) && + mlir::dyn_cast(currVal).getOwner()->getParentOp() == + forOp) return failure(); auto *op = currVal.getDefiningOp(); - if (isa(op)) { + if (mlir::isa(op)) { if (op->getOperand(0) != loopBody->getArgument( iterIndex + forOp.getNumInductionVars()) || - !states[index].isSameExceptVal(cast(op))) { + !states[index].isSameExceptVal(mlir::cast(op))) { // Remove new inserted operations. 
eliminateDeadExpressionsFrom(states[index].extractedVal, rewriter); return failure(); @@ -1142,7 +1235,7 @@ static LogicalResult tryExtractIterArgPrecondition(scf::ForOp op, rewriter.setInsertionPoint(op); auto exeOp = rewriter.create(op->getLoc(), TypeRange{}); rewriter.setInsertionPointToStart(&exeOp.getRegion().emplaceBlock()); - auto forOp = cast(rewriter.clone(*op)); + auto forOp = mlir::cast(rewriter.clone(*op)); auto ret = extractIterArgPrecondition(forOp, iterIndex, rewriter); rewriter.eraseOp(exeOp); return ret; @@ -1323,7 +1416,7 @@ struct SCFRearrangementPattern : public OpRewritePattern { // Clone operations from old loop body to the new one. llvm::MapVector oldAndNewOpMap; for (auto &op : *oldLoopBody) { - if (extractLikeOpsToMove.contains(&op) || isa(&op)) + if (extractLikeOpsToMove.contains(&op) || mlir::isa(&op)) continue; auto *newOp = rewriter.clone(op, bvm); oldAndNewOpMap[&op] = newOp; @@ -1364,12 +1457,12 @@ struct SCFRearrangementPattern : public OpRewritePattern { loopBody ->getArgument(iterOperandNumber + newForOp.getNumInductionVars()) .getUsers()) { - if (!isa(op)) + if (!mlir::isa(op)) continue; Value substitute; - ExtractState targetState{cast(op)}; + ExtractState targetState{mlir::cast(op)}; for (const auto &en : llvm::enumerate(extractLikeOpsToMove)) { - if (!targetState.isSameExceptVal(cast(en.value()))) + if (!targetState.isSameExceptVal(mlir::cast(en.value()))) continue; substitute = loopBody->getArgument(en.index() + loopBody->getNumArguments() - @@ -1377,13 +1470,16 @@ struct SCFRearrangementPattern : public OpRewritePattern { break; } - assert(substitute && "Fail to find new iter argument."); + if (!substitute) + return rewriter.notifyMatchFailure(forOp->getLoc(), + "Fail to find new iter argument."); replacePairs.push_back({op, substitute}); } for (auto p : replacePairs) { rewriter.setInsertionPointToStart(loopBody); - if (auto extractSliceOp = dyn_cast(p.first)) { + if (auto extractSliceOp = + mlir::dyn_cast(p.first)) { auto resultType = extractSliceOp.getResultType(); p.second = expandShapeToResultTypeByAddUnitDims( rewriter, forOp.getLoc(), resultType, p.second); @@ -1414,14 +1510,35 @@ struct ExtractLikeMoveBackwardPass void runOnOperation() override { auto *context = &getContext(); - RewritePatternSet patterns(context); - patterns.add, - ExtractRearrangementPattern, - SCFRearrangementPattern, - SCFRearrangementPattern>(context); - func::FuncOp func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); + GreedyRewriteConfig config; + bool changed = false; + + // FIXME: Starting from LLVM19, during conversion, if the ParentOp of + // an Op is also in the same conversion pattern, accessing the ParentOp from + // within the Op may be an invalid behavior. 
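+ // Editorial elaboration of the FIXME above: as a workaround, the
+ // extract-like patterns and the scf.for patterns run in two separate
+ // greedy-driver invocations below, and the pair is repeated while both
+ // runs keep changing the IR.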
+ do { + RewritePatternSet extractPatterns(context); + extractPatterns.add, + ExtractRearrangementPattern>( + context); + + RewritePatternSet scfPatterns(context); + scfPatterns.add, + SCFRearrangementPattern>(context); + + bool extractChanged = false; + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(extractPatterns), + config, &extractChanged))) + return signalPassFailure(); + + bool scfChanged = false; + if (failed(applyPatternsAndFoldGreedily( + getOperation(), std::move(scfPatterns), config, &scfChanged))) + return signalPassFailure(); + + changed = extractChanged && scfChanged; + } while (changed); } }; } // anonymous namespace diff --git a/lib/Dialect/Triton/Transforms/InferAxisInfoInterfaceImpl.cpp b/lib/Dialect/Triton/Transforms/InferAxisInfoInterfaceImpl.cpp index a89db5c..69c2449 100644 --- a/lib/Dialect/Triton/Transforms/InferAxisInfoInterfaceImpl.cpp +++ b/lib/Dialect/Triton/Transforms/InferAxisInfoInterfaceImpl.cpp @@ -5,11 +5,11 @@ //===----------------------------------------------------------------------===// #include #include -#include #include #include #include #include +#include #include #include "mlir/Dialect/Arith/IR/Arith.h" @@ -60,6 +68,14 @@ template static inline T highestPowOf2Divisor(T n) { return (n & (~(n - 1))); } +/// If lhs * rhs overflows, return the maximum possible value for the type. +static int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + static constexpr int log2Int(int64_t num) { return (num > 1) ? 1 + log2Int(num / 2) : 0; } @@ -85,7 +93,7 @@ struct ConstantOpInferAxisInfoOpInterface void inferAxisInfos(Operation *op, ArrayRef argInfos, SetAxisInfoFn setResultAxisInfo) const { arith::ConstantOp constantOp = cast(op); - auto intAttr = constantOp.getValue().dyn_cast(); + auto intAttr = dyn_cast(constantOp.getValue()); if (intAttr) { int64_t value = intAttr.getValue().getZExtValue(); @@ -96,10 +104,10 @@ } // TODO: generalize to dense attr.
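+ // Editorial note: a splat attribute repeats one element, so every lane
+ // holds the same constant and divisibility follows from that single value;
+ // a general (non-splat) dense attribute would need per-element analysis,
+ // hence the TODO above.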
- auto splatAttr = constantOp.getValue().dyn_cast(); + auto splatAttr = dyn_cast(constantOp.getValue()); if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { int64_t value = splatAttr.getSplatValue().getZExtValue(); - TensorType ty = splatAttr.getType().cast(); + TensorType ty = cast(splatAttr.getType()); return setResultAxisInfo( constantOp.getResult(), AxisInfoExt(AxisInfoExt::DimVectorT(ty.getRank(), @@ -263,7 +271,7 @@ struct MulOpInferAxisInfoOpInterface // lhs = k * d_lhs // rhs = p * d_rhs // lhs * rhs = k * d_lhs * p * d_rhs = k * p * d_lhs * d_rhs - return lhs.getDivisibility(dim) * rhs.getDivisibility(dim); + return multiplyDivisor(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); } std::optional getConstantValue(arith::MulIOp op, @@ -285,8 +293,7 @@ struct DivOpInferAxisInfoOpInterface OpTy divOp = cast(op); assert(argInfos.size() == 2 && "Expected two operands"); - auto resTy = - divOp.getResult().getType().template dyn_cast(); + auto resTy = dyn_cast(divOp.getResult().getType()); if (!resTy) return setResultAxisInfo(divOp.getResult(), AxisInfoExt{}); auto shape = resTy.getShape(); @@ -314,8 +321,9 @@ struct DivOpInferAxisInfoOpInterface constantValue = {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; divisibility.push_back(highestPowOf2Divisor(constantValue.value())); - } else if (!lhs.isConstantDim(shape, d) && lhs.isStrideDim(shape, d) && - rhs.isConstantDim(shape, d) && + } else if (!lhs.isFullConstantDim(shape, d) && + lhs.isFullStrideDim(shape, d) && + rhs.isFullConstantDim(shape, d) && rhs.getConstantValue().has_value() && llvm::isPowerOf2_64(lhs.getStrideValue(d))) { // Case 3: lhs stride(stride_val is power of 2), rhs constant. @@ -328,21 +336,23 @@ struct DivOpInferAxisInfoOpInterface // minStride = max(gcd(d_lhs, d_rhs) / strideVal, 1). // Since minStride maybe > len(lhs), // we need to use another gcd to get the actual constancy. - assert(lhs.getStrideValue(d) != 0 && "Stride value should not be zero"); - stride.push_back( - std::gcd(lhs.getStride(d), - std::max(std::gcd(lhs.getDivisibility(d), - rhs.getDivisibility(d)) / - lhs.getStrideValue(d), - 1))); + int64_t divisibilityGCD = + std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)); + bool isFullStrided = + lhs.getStrideValue(d) % rhs.getConstantValue().value() == 0; + int64_t newStride = + isFullStrided + ? 
lhs.getStride(d) + : std::max(divisibilityGCD / lhs.getStrideValue(d), 1); + stride.push_back(std::gcd(lhs.getStride(d), newStride)); strideValue.push_back(lhs.getStrideValue(d) / rhs.getConstantValue().value()); divisibility.push_back(std::max( lhs.getDivisibility(d) / rhs.getConstantValue().value(), 1)); - } else if (lhs.isConstantStrideDim(shape, d) && - rhs.isConstantStrideDim(shape, d)) { - divisibility.push_back( - std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + } else if (lhs.isStridedConstantDim(shape, d) && + rhs.getConstantValue().has_value()) { + divisibility.push_back(std::max( + lhs.getDivisibility(d) / rhs.getConstantValue().value(), 1)); stride.push_back(std::gcd(lhs.getStride(d), rhs.getStride(d))); strideValue.push_back(0); } else { @@ -367,8 +377,7 @@ struct RemOpInferAxisInfoOpInterface OpTy remOp = cast(op); assert(argInfos.size() == 2 && "Expected two operands"); - auto resTy = - remOp.getResult().getType().template dyn_cast(); + auto resTy = dyn_cast(remOp.getResult().getType()); if (!resTy) return setResultAxisInfo(remOp.getResult(), AxisInfoExt{}); auto shape = resTy.getShape(); @@ -394,7 +403,8 @@ struct RemOpInferAxisInfoOpInterface constantValue = {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; divisibility.push_back(highestPowOf2Divisor(constantValue.value())); - } else if (lhs.isContiguousDim(shape, d) && rhs.isConstantDim(shape, d)) { + } else if (lhs.isFullContiguousDim(shape, d) && + rhs.isFullConstantDim(shape, d)) { // Case3: lhs contiguous, rhs constant. // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' @@ -415,6 +425,23 @@ struct RemOpInferAxisInfoOpInterface std::gcd(lhs.getContiguity(d), std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)))); strideValue.push_back(1); + } else if (lhs.isStridedContiguousDim(shape, d) && + rhs.getConstantValue().has_value()) { + // Case4: lhs strided contiguous, rhs constant value. + divisibility.push_back( + std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + stride.push_back( + std::gcd(lhs.getContiguity(d), + std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)))); + strideValue.push_back(lhs.getStrideValue(d) % + rhs.getConstantValue().value()); + } else if (lhs.isStridedConstantDim(shape, d) && + rhs.getConstantValue().has_value()) { + // Case5: lhs strided constant, rhs constant value. + divisibility.push_back( + std::gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + stride.push_back(lhs.getConstancy(d)); + strideValue.push_back(0); } else { divisibility.push_back(AxisInfoExt::kInitValue); stride.push_back(AxisInfoExt::kInitValue); @@ -436,7 +463,7 @@ struct CmpOpInferAxisInfoOpInterface arith::CmpIOp cmpOp = cast(op); assert(argInfos.size() == 2 && "Expected two operands"); - auto resTy = cmpOp.getResult().getType().dyn_cast(); + auto resTy = dyn_cast(cmpOp.getResult().getType()); if (!resTy) return setResultAxisInfo(cmpOp.getResult(), AxisInfoExt{}); auto shape = resTy.getShape(); @@ -466,8 +493,8 @@ struct CmpOpInferAxisInfoOpInterface auto commonDivisor = std::gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)); - if (lhsInfo.isConstantDim(shape, d) && - rhsInfo.isContiguousDim(shape, d)) { + if (lhsInfo.isFullConstantDim(shape, d) && + rhsInfo.isFullContiguousDim(shape, d)) { // Case 2: lhs all constant, rhs all contiguous // NOTE: // lhs: k0 * d, k0 * d, ... 
@@ -478,9 +505,8 @@ struct CmpOpInferAxisInfoOpInterface // lhs gt rhs: 1, 1, 1, 1 (minimal len: d if k0 > k1) constancyHint = std::max( constancyHint, std::gcd(rhsInfo.getContiguity(d), commonDivisor)); - } else if (lhsInfo.isContiguousDim(shape, d) && - rhsInfo.isConstantDim(shape, d)) { - + } else if (lhsInfo.isFullContiguousDim(shape, d) && + rhsInfo.isFullConstantDim(shape, d)) { // Case 3: lhs all contiguous, rhs all constant // NOTE // lhs: k0 * d, k0 * d + 1, ... @@ -491,6 +517,10 @@ struct CmpOpInferAxisInfoOpInterface // lhs lt rhs: 1, 1, 1, 1 (minimal len: d if k0 < k1) constancyHint = std::max( constancyHint, std::gcd(lhsInfo.getContiguity(d), commonDivisor)); + } else if (lhsInfo.isFullConstantDim(shape, d) && + rhsInfo.isFullConstantDim(shape, d)) { + // Case 4: lhs all constant, rhs all constant + strideValueHint = 0; } } @@ -599,8 +629,7 @@ struct SelectOpInferAxisInfoOpInterface arith::SelectOp selectOp = cast(op); assert(argInfos.size() == 3 && "Expected three operands"); - auto resTy = - selectOp.getResult().getType().template dyn_cast(); + auto resTy = dyn_cast(selectOp.getResult().getType()); if (!resTy) return setResultAxisInfo(selectOp.getResult(), AxisInfoExt()); auto shape = resTy.getShape(); @@ -623,10 +652,16 @@ struct SelectOpInferAxisInfoOpInterface constantValue = lhsInfo.getConstantValue(); } } else { + bool i1Cond = isa(op->getOperand(0).getType()); for (auto d = 0; d < rank; ++d) { - stride.push_back(std::gcd( - std::gcd(lhsInfo.getStride(d), argInfos[0].getConstancy(d)), - std::gcd(rhsInfo.getStride(d), argInfos[0].getConstancy(d)))); + if (i1Cond) { + stride.push_back( + std::gcd(lhsInfo.getStride(d), rhsInfo.getStride(d))); + } else { + stride.push_back(std::gcd( + std::gcd(lhsInfo.getStride(d), argInfos[0].getConstancy(d)), + std::gcd(rhsInfo.getStride(d), argInfos[0].getConstancy(d)))); + } strideValue.push_back(AxisInfoExt::kStrideValueInitValue); divisibility.push_back( std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); @@ -756,10 +791,10 @@ struct BroadcastOpInferAxisInfoOpInterface SetAxisInfoFn setResultAxisInfo) const { triton::BroadcastOp broadcastOp = cast(op); assert(argInfos.size() == 1 && "Expected one operand"); - TensorType retTy = broadcastOp.getResult().getType().cast(); + TensorType retTy = cast(broadcastOp.getResult().getType()); ArrayRef retShape = retTy.getShape(); TensorType opTy = - broadcastOp.getOperand().getType().dyn_cast_or_null(); + dyn_cast_or_null(broadcastOp.getOperand().getType()); ArrayRef opShape; SmallVector scalarAsShapeOne(retTy.getRank(), 1); if (opTy) { @@ -797,7 +832,7 @@ struct SplatOpInferAxisInfoOpInterface SetAxisInfoFn setResultAxisInfo) const { triton::SplatOp splatOp = cast(op); assert(argInfos.size() == 1 && "Expected one operand"); - TensorType retTy = splatOp.getResult().getType().cast(); + TensorType retTy = cast(splatOp.getResult().getType()); AxisInfoExt::DimVectorT divisibility, stride, strideValue; for (int d = 0; d < retTy.getRank(); ++d) { divisibility.push_back(argInfos[0].getDivisibility(0)); @@ -823,23 +858,251 @@ struct ExpandDimsOpInferAxisInfoOpInterface AxisInfoExt::DimVectorT divisibility = opInfo.getDivisibility(); AxisInfoExt::DimVectorT stride = opInfo.getStride(); AxisInfoExt::DimVectorT strideValue = opInfo.getStrideValue(); + ArrayRef srcShape = expandDimsOp.getSrc().getType().getShape(); + int64_t expandedDim = + std::max(static_cast(expandDimsOp.getAxis()) - 1, 0); + int64_t expandedDivisibility = + opInfo.isFullConstantDim(srcShape, expandedDim) + ? 
divisibility[expandedDim] + : AxisInfoExt::kInitValue; divisibility.insert(divisibility.begin() + expandDimsOp.getAxis(), - AxisInfoExt::kInitValue); + expandedDivisibility); stride.insert(stride.begin() + expandDimsOp.getAxis(), AxisInfoExt::kInitValue); - strideValue.insert(strideValue.begin() + expandDimsOp.getAxis(), 0); + strideValue.insert(strideValue.begin() + expandDimsOp.getAxis(), + AxisInfoExt::kStrideValueInitValue); return setResultAxisInfo(expandDimsOp->getResult(0), AxisInfoExt(divisibility, stride, strideValue, opInfo.getConstantValue())); } }; +struct ExpandShapeOpInferAxisInfoOpInterface + : public InferAxisInfoInterface::ExternalModel< + ExpandShapeOpInferAxisInfoOpInterface, tensor::ExpandShapeOp> { + + void inferAxisInfos(Operation *op, ArrayRef argInfos, + SetAxisInfoFn setResultAxisInfo) const { + tensor::ExpandShapeOp expandShapeOp = cast(op); + assert(argInfos.size() == 1 && "Expected one operand"); + + AxisInfoExt opInfo = argInfos[0]; + ArrayRef srcShape = expandShapeOp.getSrcType().getShape(); + ArrayRef resShape = expandShapeOp.getResultType().getShape(); + AxisInfoExt::DimVectorT divisibility, stride, strideValue; + for (auto [srcDim, indice] : + llvm::enumerate(expandShapeOp.getReassociationIndices())) { + // Init expanded axisinfo by source axisinfo. + int64_t srcDivisibility = opInfo.getDivisibility()[srcDim]; + int64_t srcStride = opInfo.getStride()[srcDim]; + int64_t srcStrideValue = opInfo.getStrideValue()[srcDim]; + AxisInfoExt::DimVectorT initStride(indice.size(), srcStride); + AxisInfoExt::DimVectorT initStrideValue(indice.size(), srcStrideValue); + AxisInfoExt::DimVectorT initDivisibility(indice.size(), srcDivisibility); + stride.insert(stride.end(), initStride.begin(), initStride.end()); + strideValue.insert(strideValue.end(), initStrideValue.begin(), + initStrideValue.end()); + divisibility.insert(divisibility.end(), initDivisibility.begin(), + initDivisibility.end()); + if (indice.size() == 1) + continue; + + // Calculate axisinfo of expanded dimension. + int64_t nextStride = srcStride; + int64_t nextStrideValue = srcStrideValue; + int64_t nextDivisibility = srcDivisibility; + for (auto resDim : llvm::reverse(indice)) { + if (nextStride >= resShape[resDim] && + nextStride % resShape[resDim] == 0) { + strideValue[resDim] = nextStrideValue; + nextStrideValue = nextStride == resShape[resDim] + ? AxisInfoExt::kStrideValueInitValue + : nextStrideValue * resShape[resDim]; + stride[resDim] = resShape[resDim]; + nextStride /= resShape[resDim]; + if (opInfo.isFullConstantDim(srcShape, srcDim)) { + divisibility[resDim] = nextDivisibility; + } else if (opInfo.isNonConstantFullStrideDim(srcShape, srcDim)) { + divisibility[resDim] = std::gcd( + nextDivisibility, resShape[resDim] * strideValue[resDim]); + nextDivisibility = srcStrideValue == 1 + ? 
AxisInfoExt::kInitValue + : highestPowOf2Divisor(strideValue[resDim]); + } else { + divisibility[resDim] = AxisInfoExt::kInitValue; + nextDivisibility = AxisInfoExt::kInitValue; + } + } else if (resShape[resDim] > nextStride && + resShape[resDim] % nextStride == 0) { + strideValue[resDim] = nextStrideValue; + nextStrideValue = AxisInfoExt::kStrideValueInitValue; + stride[resDim] = nextStride; + nextStride = AxisInfoExt::kInitValue; + divisibility[resDim] = AxisInfoExt::kInitValue; + nextDivisibility = AxisInfoExt::kInitValue; + } else { + strideValue[resDim] = AxisInfoExt::kStrideValueInitValue; + nextStrideValue = AxisInfoExt::kStrideValueInitValue; + stride[resDim] = AxisInfoExt::kInitValue; + nextStride = AxisInfoExt::kInitValue; + divisibility[resDim] = AxisInfoExt::kInitValue; + nextDivisibility = AxisInfoExt::kInitValue; + } + } + } + + return setResultAxisInfo(expandShapeOp->getResult(0), + AxisInfoExt(divisibility, stride, strideValue, + opInfo.getConstantValue())); + } +}; + +struct CollapseShapeOpInferAxisInfoOpInterface + : public InferAxisInfoInterface::ExternalModel< + CollapseShapeOpInferAxisInfoOpInterface, tensor::CollapseShapeOp> { + + void inferAxisInfos(Operation *op, ArrayRef argInfos, + SetAxisInfoFn setResultAxisInfo) const { + tensor::CollapseShapeOp collapseShapeOp = cast(op); + assert(argInfos.size() == 1 && "Expected one operand"); + + AxisInfoExt opInfo = argInfos[0]; + ArrayRef srcShape = collapseShapeOp.getSrcType().getShape(); + AxisInfoExt::DimVectorT divisibility, stride, strideValue; + for (const auto &indices : + llvm::enumerate(collapseShapeOp.getReassociationIndices())) { + int64_t resDim = indices.value().back(); + int64_t resDivisibility = opInfo.getDivisibility()[resDim]; + int64_t resStride = opInfo.getStride()[resDim]; + int64_t resStrideValue = opInfo.getStrideValue()[resDim]; + for (const auto &indice : + llvm::enumerate(llvm::reverse(indices.value()))) { + if (indices.value().size() == 1 || indice.index() == 0) + continue; + int64_t srcDim = indice.value(); + int64_t srcStride = opInfo.getStride()[srcDim]; + int64_t srcStrideValue = opInfo.getStrideValue()[srcDim]; + int64_t srcDivisibility = opInfo.getDivisibility()[srcDim]; + bool isLastDimFullStrided = + opInfo.getStride()[srcDim + 1] == srcShape[srcDim + 1]; + if (resStride == 1) { + resStride = srcStride; + resStrideValue = srcStrideValue; + } else if (srcStrideValue == resStride * resStrideValue && + isLastDimFullStrided) { + resStride *= srcStride; + } + resDivisibility = std::max(resDivisibility, srcDivisibility); + } + divisibility.push_back(resDivisibility); + stride.push_back(resStride); + strideValue.push_back(resStrideValue); + } + + return setResultAxisInfo(collapseShapeOp->getResult(0), + AxisInfoExt(divisibility, stride, strideValue, + opInfo.getConstantValue())); + } +}; + +struct ExtractSliceOpInferAxisInfoOpInterface + : public InferAxisInfoInterface::ExternalModel< + ExtractSliceOpInferAxisInfoOpInterface, tensor::ExtractSliceOp> { + + void inferAxisInfos(Operation *op, ArrayRef argInfos, + SetAxisInfoFn setResultAxisInfo) const { + tensor::ExtractSliceOp extractSliceOp = cast(op); + assert(argInfos.size() == 1 && "Expected one operand"); + + AxisInfoExt opInfo = argInfos[0]; + ArrayRef extractSliceOffsets = extractSliceOp.getStaticOffsets(); + ArrayRef extractSliceSizes = extractSliceOp.getStaticSizes(); + ArrayRef extractSliceStrides = extractSliceOp.getStaticStrides(); + ArrayRef srcShape = extractSliceOp.getSourceType().getShape(); + AxisInfoExt::DimVectorT divisibility, 
stride, strideValue; + int64_t extractSliceRank = extractSliceOp.getResultType().getRank(); + for (int64_t d = 0; d < extractSliceRank; ++d) { + bool isSliceInsideFirstStride = + (extractSliceOffsets[d] * extractSliceStrides[d] + + (extractSliceSizes[d] - 1) * extractSliceStrides[d]) <= + opInfo.getStride()[d]; + bool isSliceInsideOtherStride = + extractSliceOffsets[d] * extractSliceStrides[d] >= + opInfo.getStride()[d] && + (extractSliceOffsets[d] * extractSliceStrides[d] % + opInfo.getStride()[d] + + (extractSliceSizes[d] - 1) * extractSliceStrides[d]) <= + opInfo.getStride()[d]; + bool isSliceMultipleStride = + extractSliceStrides[d] == 1 && + extractSliceOffsets[d] % opInfo.getStride()[d] == 0 && + extractSliceSizes[d] % opInfo.getStride()[d] == 0 && + extractSliceSizes[d] / opInfo.getStride()[d] >= 1; + if (opInfo.isStridedConstantDim(srcShape, d)) { + if (opInfo.isFullStrideDim(srcShape, d) || isSliceInsideFirstStride) { + stride.push_back(extractSliceSizes[d]); + strideValue.push_back(opInfo.getStrideValue()[d]); + divisibility.push_back(opInfo.getDivisibility()[d]); + } else if (isSliceInsideOtherStride) { + stride.push_back(extractSliceSizes[d]); + strideValue.push_back(opInfo.getStrideValue()[d]); + divisibility.push_back(AxisInfoExt::kInitValue); + } else if (isSliceMultipleStride) { + stride.push_back(opInfo.getStride()[d]); + strideValue.push_back(opInfo.getStrideValue()[d]); + divisibility.push_back(extractSliceOffsets[d] == 0 + ? opInfo.getDivisibility()[d] + : AxisInfoExt::kInitValue); + } else { + stride.push_back(AxisInfoExt::kInitValue); + strideValue.push_back(AxisInfoExt::kStrideValueInitValue); + divisibility.push_back(extractSliceOffsets[d] == 0 + ? opInfo.getDivisibility()[d] + : AxisInfoExt::kInitValue); + } + } else if (opInfo.isNonStridedConstantStrideDim(srcShape, d)) { + if (opInfo.isFullStrideDim(srcShape, d) || isSliceInsideFirstStride) { + stride.push_back(extractSliceSizes[d]); + strideValue.push_back(opInfo.getStrideValue()[d] * + extractSliceStrides[d]); + divisibility.push_back(extractSliceOffsets[d] == 0 + ? opInfo.getDivisibility()[d] + : AxisInfoExt::kInitValue); + } else if (isSliceInsideOtherStride) { + stride.push_back(extractSliceSizes[d]); + strideValue.push_back(opInfo.getStrideValue()[d] * + extractSliceStrides[d]); + divisibility.push_back(AxisInfoExt::kInitValue); + } else if (isSliceMultipleStride) { + stride.push_back(opInfo.getStride()[d]); + strideValue.push_back(opInfo.getStrideValue()[d]); + divisibility.push_back(extractSliceOffsets[d] == 0 + ? opInfo.getDivisibility()[d] + : AxisInfoExt::kInitValue); + } else { + stride.push_back(AxisInfoExt::kInitValue); + strideValue.push_back(AxisInfoExt::kStrideValueInitValue); + divisibility.push_back(AxisInfoExt::kInitValue); + } + } else { + stride.push_back(AxisInfoExt::kInitValue); + strideValue.push_back(AxisInfoExt::kStrideValueInitValue); + divisibility.push_back(AxisInfoExt::kInitValue); + } + } + return setResultAxisInfo(extractSliceOp.getResult(), + AxisInfoExt(divisibility, stride, strideValue, + opInfo.getConstantValue())); + } +}; + } // anonymous namespace void mlir::triton::registerInferAxisInfoInterfaceExternalModels( DialectRegistry ®istry) { // Must ensure that any dependent dialects are registered. 
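+ // Editorial elaboration: extensions added via addExtension only fire once
+ // the matching dialect is loaded into the MLIRContext, so the dialects the
+ // external models attach to are declared as dependencies first.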
- registry.insert(); + registry.insert(); registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { arith::ConstantOp::attachInterface( @@ -885,4 +1148,13 @@ void mlir::triton::registerInferAxisInfoInterfaceExternalModels( triton::ExpandDimsOp::attachInterface( *ctx); }); + + registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { + tensor::ExtractSliceOp::attachInterface< + ExtractSliceOpInferAxisInfoOpInterface>(*ctx); + tensor::ExpandShapeOp::attachInterface< + ExpandShapeOpInferAxisInfoOpInterface>(*ctx); + tensor::CollapseShapeOp::attachInterface< + CollapseShapeOpInferAxisInfoOpInterface>(*ctx); + }); } diff --git a/lib/Dialect/Triton/Transforms/PointerStrengthReduction.cpp b/lib/Dialect/Triton/Transforms/PointerStrengthReduction.cpp index 5c731b8..365b279 100644 --- a/lib/Dialect/Triton/Transforms/PointerStrengthReduction.cpp +++ b/lib/Dialect/Triton/Transforms/PointerStrengthReduction.cpp @@ -35,6 +35,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton-linalg/Dialect/Triton/Transforms/PassDetail.h" // IWYU pragma: keep #include "triton-linalg/Dialect/Triton/Transforms/Passes.h" +#include "triton-linalg/Dialect/Triton/Utils/PointerInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "llvm/ADT/ArrayRef.h" @@ -61,51 +62,6 @@ class MLIRContext; namespace { -/// Structure info representation for pointer in triton. -class PtrInfo { -public: - PtrInfo() = delete; - PtrInfo(Value ptr, ArrayRef offsets) - : pointer(ptr), tensorPtrOffsets(offsets) {} - - PtrInfo(Value ptr, ArrayRef sizes, ArrayRef strides, - ArrayRef offsets, ArrayRef order) - : pointer(ptr), tensorPtrSizes(sizes), tensorPtrStrides(strides), - tensorPtrOffsets(offsets), tensorPtrOrder(order) {} - - PtrInfo(Value ptr, Value offset) : pointer(ptr) { - tensorPtrOffsets.push_back(offset); - isRawPtrInfo = true; - } - - Value ptr() const { return pointer; } - - ArrayRef offsets() const { return tensorPtrOffsets; } - Value offset(unsigned idx) const { return tensorPtrOffsets[idx]; } - Value offset() const { return tensorPtrOffsets[0]; } - void setOffsets(ValueRange vals) { - for (unsigned i = 0; i < vals.size(); i++) { - tensorPtrOffsets[i] = vals[i]; - } - } - unsigned offsetSize() { return tensorPtrOffsets.size(); } - - bool isBlockPtr() const { return !isRawPtrInfo; } - - ArrayRef sizes() const { return tensorPtrSizes; } - ArrayRef strides() const { return tensorPtrStrides; } - ArrayRef order() const { return tensorPtrOrder; } - -private: - bool isRawPtrInfo{false}; - Value pointer; - // Basic info for reconstruction of MakeTensorPtrOp. - SmallVector tensorPtrSizes; - SmallVector tensorPtrStrides; - SmallVector tensorPtrOffsets; - SmallVector tensorPtrOrder; -}; - /// Check if there are repeat arguments in block inputs and terminatorOp. 
static bool verifyArgsMatchTerminatorInputsInBlock(Block *block, unsigned blockArgStart, @@ -334,17 +290,9 @@ class PtrStrengthReductionPattern : public OpRewritePattern { return failure(); } IRRewriter rewriter(ptr.getContext()); - SmallVector sizes; - SmallVector strides; - SmallVector offsets; - SmallVector orders; - for (int i = 0; i < makeTensorPtrOp.getOffsets().size(); ++i) { - sizes.push_back(makeTensorPtrOp.getShape()[i]); - strides.push_back(makeTensorPtrOp.getStrides()[i]); - offsets.push_back(makeTensorPtrOp.getOffsets()[i]); - orders.push_back(makeTensorPtrOp.getOrder()[i]); - } - return PtrInfo(makeTensorPtrOp.getBase(), sizes, strides, offsets, orders); + return PtrInfo(makeTensorPtrOp.getBase(), makeTensorPtrOp.getShape(), + makeTensorPtrOp.getStrides(), makeTensorPtrOp.getOffsets(), + makeTensorPtrOp.getOrder()); } template @@ -444,7 +392,7 @@ LogicalResult PtrStrengthReductionPattern::matchAndRewrite( auto newOffset = rewriter.create( loc, RankedTensorType::get( - op.getResult().getType().cast().getShape(), + cast(op.getResult().getType()).getShape(), getElementTypeOrSelf(info->offset().getType())), info->offset()); auto newPtr = rewriter.create( @@ -465,7 +413,7 @@ LogicalResult PtrStrengthReductionPattern::matchAndRewrite( auto newOffset = rewriter.create( loc, RankedTensorType::get( - op.getResult().getType().cast().getShape(), + cast(op.getResult().getType()).getShape(), getElementTypeOrSelf(info->offset().getType())), info->offset(), false); auto newPtr = rewriter.create( @@ -486,7 +434,7 @@ LogicalResult PtrStrengthReductionPattern::matchAndRewrite( auto newOffset = rewriter.create( loc, RankedTensorType::get( - op.getResult().getType().cast().getShape(), + cast(op.getResult().getType()).getShape(), getElementTypeOrSelf(info->offset().getType())), info->offset(), op.getOrder()); auto newPtr = rewriter.create(loc, op.getResult().getType(), @@ -578,11 +526,11 @@ class PtrWithCFGStrengthReductionPattern private: bool isTritonPtrWithTensor(Type type) const { - if (auto ptrType = type.dyn_cast()) { - return triton::getPointeeType(type).isa(); + if (auto ptrType = dyn_cast(type)) { + return isa(triton::getPointeeType(type)); } - if (auto tensorTy = type.dyn_cast()) { - return tensorTy.getElementType().isa(); + if (auto tensorTy = dyn_cast(type)) { + return isa(tensorTy.getElementType()); } return false; } @@ -600,7 +548,7 @@ class PtrWithCFGStrengthReductionPattern rewriter.setInsertionPointAfter(splatOp); Type offsetType = rewriter.getIntegerType(32); assert(isTritonPtrWithTensor(splatOp.getResult().getType())); - if (auto type = splatOp.getResult().getType().dyn_cast()) { + if (auto type = dyn_cast(splatOp.getResult().getType())) { offsetType = RankedTensorType::get(type.getShape(), rewriter.getIntegerType(32)); } @@ -610,18 +558,9 @@ class PtrWithCFGStrengthReductionPattern } // Get previous ptr of tensor and infos. 
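+ // Editorial illustration (schematic, hypothetical IR): given
+ //   %p = tt.make_tensor_ptr %base, [%d0, %d1], [%s0, %s1], [%o0, %o1]
+ //        {order = array<i32: 1, 0>}
+ // the recovered PtrInfo carries base = %base, sizes = [%d0, %d1],
+ // strides = [%s0, %s1], offsets = [%o0, %o1], and order = [1, 0].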
if (auto makeTensorPtrOp = ptr.getDefiningOp()) { - SmallVector sizes; - SmallVector strides; - SmallVector offsets; - SmallVector orders; - for (int i = 0; i < makeTensorPtrOp.getOffsets().size(); ++i) { - sizes.push_back(makeTensorPtrOp.getShape()[i]); - strides.push_back(makeTensorPtrOp.getStrides()[i]); - offsets.push_back(makeTensorPtrOp.getOffsets()[i]); - orders.push_back(makeTensorPtrOp.getOrder()[i]); - } - return PtrInfo(makeTensorPtrOp.getBase(), sizes, strides, offsets, - orders); + return PtrInfo(makeTensorPtrOp.getBase(), makeTensorPtrOp.getShape(), + makeTensorPtrOp.getStrides(), makeTensorPtrOp.getOffsets(), + makeTensorPtrOp.getOrder()); } return failure(); } @@ -678,7 +617,7 @@ class PtrWithCFGStrengthReductionPattern auto offset = info.offset(); if (isIntType(offset, 32)) { Type targetOffsetType = rewriter.getIntegerType(64); - if (auto type = offset.getType().dyn_cast()) { + if (auto type = dyn_cast(offset.getType())) { targetOffsetType = RankedTensorType::get(type.getShape(), rewriter.getIntegerType(64)); } diff --git a/lib/Dialect/Triton/Transforms/WrapFuncBodyWithSingleBlock.cpp b/lib/Dialect/Triton/Transforms/WrapFuncBodyWithSingleBlock.cpp index dfc0778..059a437 100644 --- a/lib/Dialect/Triton/Transforms/WrapFuncBodyWithSingleBlock.cpp +++ b/lib/Dialect/Triton/Transforms/WrapFuncBodyWithSingleBlock.cpp @@ -58,7 +58,7 @@ static void encapsulateMultiBlock(FunctionOpInterface funcOp) { // Add scf.execute_region to the entry block. builder.setInsertionPointToStart(newBlock); - FunctionType funcType = funcOp.getFunctionType().cast(); + FunctionType funcType = cast(funcOp.getFunctionType()); auto containerOp = builder.create(loc, funcType.getResults()); auto &containerRegion = containerOp.getRegion(); diff --git a/lib/Dialect/Triton/Utils/MaskTracker.cpp b/lib/Dialect/Triton/Utils/MaskTracker.cpp index 981abb0..89c9758 100644 --- a/lib/Dialect/Triton/Utils/MaskTracker.cpp +++ b/lib/Dialect/Triton/Utils/MaskTracker.cpp @@ -136,6 +136,10 @@ inline raw_ostream &operator<<(raw_ostream &os, const Mask &s) { } using Result = std::variant; +inline raw_ostream &operator<<(raw_ostream &os, const Result &s) { + std::visit([&os](const auto &val) { os << val; }, s); + return os; +} /// A visitor(std::visit functor) used to wrapper calculations between different /// of results. @@ -235,15 +239,44 @@ struct CmpVisitor : public VisitorBase { : VisitorBase(loc, rewriter), cmpTy(cmpTy) {} template FailureOr operator()(const T1 &lhs, const T2 &rhs) { + // Compare range and scalar. if constexpr (std::is_same_v && std::is_same_v) { return compareSimpleRange(lhs, rhs, cmpTy); } + // Compare scalar and range. + if constexpr (std::is_same_v && + std::is_same_v) { + // The function compareSimpleRange handles comparisons between + // a range and a scalar. When the input is in the form + // of (scalar, range), it is necessary to swap the order of the + // arguments and map the comparison operation to its equivalent operation. 
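+ // For example (editorial): a mask written as `scalar slt range` is
+ // evaluated as `range sgt scalar`, i.e. slt is mapped to sgt before
+ // delegating to compareSimpleRange.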
+ return compareSimpleRange(rhs, lhs, reversePredicate(cmpTy)); + } return rewriter.notifyMatchFailure(loc, "Unsupported cmpi scenario"); } private: arith::CmpIPredicate cmpTy; + + arith::CmpIPredicate reversePredicate(mlir::arith::CmpIPredicate cmpTy) { + /* + * (range < scalar) <=> (scalar > range) + * (range > scalar) <=> (scalar < range) + * (range <= scalar) <=> (scalar >= range) + * (range >= scalar) <=> (scalar <= range) + */ + static const llvm::DenseMap + map = { + {arith::CmpIPredicate::slt, arith::CmpIPredicate::sgt}, + {arith::CmpIPredicate::sgt, arith::CmpIPredicate::slt}, + {arith::CmpIPredicate::sle, arith::CmpIPredicate::sge}, + {arith::CmpIPredicate::sge, arith::CmpIPredicate::sle}, + }; + auto it = map.find(cmpTy); + assert(it != map.end()); + return it->second; + } FailureOr compareSimpleRange(const SimpleRange &lhs, const Scalar &rhs, mlir::arith::CmpIPredicate cmpTy) { @@ -263,12 +296,18 @@ struct CmpVisitor : public VisitorBase { case arith::CmpIPredicate::sle: { auto openedUpperBound = addOFRs(rhs.scalar, rewriter.getIndexAttr(1), loc, rewriter); + // The value of `rhs.scalar` might be exactly `INT64_MAX`. + // We need to prevent overflow after adding one. + openedUpperBound = + maxOFRs(openedUpperBound, rhs.scalar, loc, rewriter); newDim = cmpSlt(ret, lhs, openedUpperBound, i); break; } case arith::CmpIPredicate::sgt: { auto closedLowerBound = addOFRs(rhs.scalar, rewriter.getIndexAttr(1), loc, rewriter); + closedLowerBound = + maxOFRs(closedLowerBound, rhs.scalar, loc, rewriter); newDim = cmpSgt(ret, lhs, closedLowerBound, i); break; } @@ -520,9 +559,9 @@ class MaskParser { /// Get the value of the constant and assign it to scalar. FailureOr parseOp(arith::ConstantOp constOp) { // Scalar constant will be processed in func parseIntScalar. - auto attr = constOp.getValue().cast(); + auto attr = cast(constOp.getValue()); - if (!attr.isSplat() || !attr.getElementType().isa()) { + if (!attr.isSplat() || !isa(attr.getElementType())) { return rewriter.notifyMatchFailure( loc, "All elements must share a single integer constant value"); } @@ -619,7 +658,7 @@ class MaskParser { /// Operand is the result of make_range. /// Set start and end accordingly; step size must be 1. FailureOr parseOp(triton::MakeRangeOp rangeOp) { - auto shape = rangeOp.getType().cast().getShape(); + auto shape = cast(rangeOp.getType()).getShape(); auto start = rangeOp.getStart(); auto end = rangeOp.getEnd(); assert(((end - start + shape[0] - 1) / shape[0] == 1) && @@ -642,8 +681,8 @@ class MaskParser { auto dst = broadcastOp.getResult(); // We canonicalize tt.broadcast in triton canonicalization pass, // so no scalar case here.
- auto dstShape = dst.getType().cast().getShape(); - auto srcShape = src.getType().cast().getShape(); + auto dstShape = cast(dst.getType()).getShape(); + auto srcShape = cast(src.getType()).getShape(); assert(srcShape.size() == dstShape.size() && "rank of source and destination should match"); @@ -664,7 +703,7 @@ class MaskParser { FailureOr parseOp(triton::SplatOp splatOp) { auto src = splatOp.getSrc(); auto dst = splatOp.getResult(); - auto dstShape = dst.getType().cast().getShape(); + auto dstShape = cast(dst.getType()).getShape(); auto ret = parse(src); if (failed(ret)) @@ -685,11 +724,10 @@ class MaskParser { return failure(); auto axis = expandDimsOp.getAxis(); - assert(expandDimsOp.getResult() - .getType() - .cast() - .getShape()[axis] == 1 && - "expect changed dimension to be 1 in expand_dims"); + assert( + cast(expandDimsOp.getResult().getType()).getShape()[axis] == + 1 && + "expect changed dimension to be 1 in expand_dims"); if (failed(std::visit(ExpandDimVisitor(loc, rewriter, axis), *ret))) return failure(); @@ -717,12 +755,12 @@ class MaskParser { } FailureOr parseUnknownValue(Value operand) { - auto type = operand.getType().dyn_cast(); + auto type = dyn_cast(operand.getType()); if (!type) return rewriter.notifyMatchFailure( loc, "only support track shaped type value"); - assert((type.getElementType().isa() && + assert((isa(type.getElementType()) && "unsupport unknown value type")); Result ret; @@ -736,7 +774,7 @@ class MaskParser { } FailureOr parse(Value operand) { - if (operand.getType().isa()) { + if (isa(operand.getType())) { return parseIntScalar(operand); } @@ -761,7 +799,7 @@ class MaskParser { } // namespace void MaskTracker::parse(Value operand, Location loc, RewriterBase &rewriter) { - auto shapeTy = operand.getType().dyn_cast(); + auto shapeTy = dyn_cast(operand.getType()); if (!shapeTy) return; int64_t rank = shapeTy.getRank(); diff --git a/lib/Dialect/Triton/Utils/PointerMetaInfoTracker.cpp b/lib/Dialect/Triton/Utils/PointerMetaInfoTracker.cpp index fd94792..93b533f 100644 --- a/lib/Dialect/Triton/Utils/PointerMetaInfoTracker.cpp +++ b/lib/Dialect/Triton/Utils/PointerMetaInfoTracker.cpp @@ -146,7 +146,7 @@ LogicalResult PointerMetaInfoTracker::parseOp( this->offset = rewriter.create( loc, RankedTensorType::get( - op.getResult().getType().cast().getShape(), + mlir::cast(op.getResult().getType()).getShape(), this->offset.getType()), this->offset); return success(); @@ -171,12 +171,67 @@ LogicalResult PointerMetaInfoTracker::parseOp( this->offset = rewriter.create( loc, RankedTensorType::get( - op.getResult().getType().cast().getShape(), + mlir::cast(op.getResult().getType()).getShape(), getElementTypeOrSelf(this->offset.getType())), this->offset); return success(); } +template <> +LogicalResult PointerMetaInfoTracker::parseOp( + tensor::ExtractOp op, Location loc, ConversionPatternRewriter &rewriter) { + if (failed(parse(op.getTensor(), loc, rewriter))) + return failure(); + this->offset = + rewriter.create(loc, this->offset, op.getIndices()); + return success(); +} + +template <> +LogicalResult PointerMetaInfoTracker::parseOp( + tensor::ExtractSliceOp op, Location loc, + ConversionPatternRewriter &rewriter) { + if (failed(parse(op.getSource(), loc, rewriter))) + return failure(); + SmallVector offsets = op.getMixedOffsets(); + SmallVector sizes = op.getMixedSizes(); + SmallVector strides = op.getMixedStrides(); + this->offset = rewriter.create( + loc, this->offset, offsets, sizes, strides); + return success(); +} + +template <> +LogicalResult 
PointerMetaInfoTracker::parseOp( + tensor::ExpandShapeOp op, Location loc, + ConversionPatternRewriter &rewriter) { + if (failed(parse(op.getSrc(), loc, rewriter))) + return failure(); + this->offset = rewriter.create( + loc, + RankedTensorType::get( + mlir::cast(op.getResult().getType()).getShape(), + getElementTypeOrSelf(this->offset.getType())), + this->offset, op.getReassociation(), op.getOutputShape(), + op.getStaticOutputShape()); + return success(); +} + +template <> +LogicalResult PointerMetaInfoTracker::parseOp( + tensor::CollapseShapeOp op, Location loc, + ConversionPatternRewriter &rewriter) { + if (failed(parse(op.getOperand(), loc, rewriter))) + return failure(); + this->offset = rewriter.create( + loc, + RankedTensorType::get( + mlir::cast(op.getResult().getType()).getShape(), + getElementTypeOrSelf(this->offset.getType())), + this->offset, op.getReassociation()); + return success(); +} + FailureOr PointerMetaInfoTracker::parse(Value operand, Location loc, ConversionPatternRewriter &rewriter) { @@ -203,7 +258,9 @@ PointerMetaInfoTracker::parse(Value operand, Location loc, auto res = llvm::TypeSwitch(defOp) .Case([&](auto op) { + triton::BroadcastOp, triton::ExpandDimsOp, + tensor::ExpandShapeOp, tensor::CollapseShapeOp, + tensor::ExtractOp, tensor::ExtractSliceOp>([&](auto op) { auto ret = parseOp(op, loc, rewriter); isProcessedSuccessfully = ret.succeeded(); return ret; @@ -214,7 +271,7 @@ PointerMetaInfoTracker::parse(Value operand, Location loc, if (res.failed() && !isProcessedSuccessfully) return failure(); // res } - if (!operand.getType().isa()) + if (!mlir::isa(operand.getType())) return rewriter.notifyMatchFailure( loc, "only support base ptr of triton scalar pointer"); this->base = operand; diff --git a/lib/Dialect/Utils/ArithUtils.cpp b/lib/Dialect/Utils/ArithUtils.cpp index 20dbbb5..ec27b01 100644 --- a/lib/Dialect/Utils/ArithUtils.cpp +++ b/lib/Dialect/Utils/ArithUtils.cpp @@ -66,6 +66,7 @@ Value mlir::triton::createScalarOrSplatConstant(OpBuilder &builder, //===----------------------------------------------------------------------===// // END copied from mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp //===----------------------------------------------------------------------===// + FailureOr mlir::triton::getSplatValue(OpBuilder &builder, arith::ConstantOp op) { auto loc = op.getLoc(); @@ -74,10 +75,10 @@ FailureOr mlir::triton::getSplatValue(OpBuilder &builder, return op.getResult(); } Type retType = op.getType(); - auto tensorType = retType.dyn_cast_or_null(); + auto tensorType = dyn_cast_or_null(retType); if (!tensorType) return failure(); - auto value = op.getValue().dyn_cast(); + auto value = dyn_cast(op.getValue()); if (!value || !value.isSplat()) return failure(); @@ -103,88 +104,3 @@ FailureOr mlir::triton::getSplatValue(OpBuilder &builder, return failure(); return fillVal; } - -std::optional -mlir::triton::getCmpSelectResult(OpBuilder &builder, Location loc, - arith::CmpFOp op, bool operandsSwapped) { - auto predicate = op.getPredicate(); - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - switch (predicate) { - case arith::CmpFPredicate::OGT: - case arith::CmpFPredicate::UGT: - case arith::CmpFPredicate::OGE: - case arith::CmpFPredicate::UGE: - return operandsSwapped ? builder.create(loc, lhs, rhs) - : builder.create(loc, lhs, rhs); - case arith::CmpFPredicate::OLT: - case arith::CmpFPredicate::ULT: - case arith::CmpFPredicate::OLE: - case arith::CmpFPredicate::ULE: - return operandsSwapped ? 
builder.create(loc, lhs, rhs) - : builder.create(loc, lhs, rhs); - default: - return std::nullopt; - } -} - -std::optional -mlir::triton::getCmpSelectResult(OpBuilder &builder, Location loc, - arith::CmpIOp op, bool operandsSwapped) { - auto predicate = op.getPredicate(); - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - switch (predicate) { - case arith::CmpIPredicate::sgt: - case arith::CmpIPredicate::sge: - return operandsSwapped ? builder.create(loc, lhs, rhs) - : builder.create(loc, lhs, rhs); - case arith::CmpIPredicate::ugt: - case arith::CmpIPredicate::uge: - return operandsSwapped ? builder.create(loc, lhs, rhs) - : builder.create(loc, lhs, rhs); - case arith::CmpIPredicate::slt: - case arith::CmpIPredicate::sle: - return operandsSwapped ? builder.create(loc, lhs, rhs) - : builder.create(loc, lhs, rhs); - case arith::CmpIPredicate::ult: - case arith::CmpIPredicate::ule: - return operandsSwapped ? builder.create(loc, lhs, rhs) - : builder.create(loc, lhs, rhs); - default: - return std::nullopt; - } -} - -std::optional -mlir::triton::getCmpSelectResult(OpBuilder &builder, Operation *cmpOp, - arith::SelectOp op) { - // Get cmp op mode. - std::optional cmpFOp; - std::optional cmpIOp; - if (isa(cmpOp)) { - cmpFOp = cast(cmpOp); - } else if (isa(cmpOp)) { - cmpIOp = cast(cmpOp); - } else { - return std::nullopt; - } - // Get specific max/min semantics. - auto loc = op.getLoc(); - if (op->getOperand(1) == cmpOp->getOperand(0) && - op->getOperand(2) == cmpOp->getOperand(1)) { - if (cmpFOp) { - return getCmpSelectResult(builder, loc, *cmpFOp, false); - } else if (cmpIOp) { - return getCmpSelectResult(builder, loc, *cmpIOp, false); - } - } else if (op->getOperand(1) == cmpOp->getOperand(1) && - op->getOperand(2) == cmpOp->getOperand(0)) { - if (cmpFOp) { - return getCmpSelectResult(builder, loc, *cmpFOp, true); - } else if (cmpIOp) { - return getCmpSelectResult(builder, loc, *cmpIOp, true); - } - } - return std::nullopt; -} diff --git a/lib/Dialect/Utils/Conventions.cpp b/lib/Dialect/Utils/Conventions.cpp index 61f44b4..116966b 100644 --- a/lib/Dialect/Utils/Conventions.cpp +++ b/lib/Dialect/Utils/Conventions.cpp @@ -17,8 +17,8 @@ using namespace mlir::triton; bool mlir::triton::isLinearMemory(::mlir::ModuleOp op) { if (auto attr = op->getAttr(getIsLinearMemoryAttrKey())) { - assert(attr.dyn_cast() && "Invalid linear attribute type"); - return attr.cast().getValue(); + assert(dyn_cast(attr) && "Invalid linear attribute type"); + return cast(attr).getValue(); } // The default value for missing linear attribute is false. diff --git a/lib/Dialect/Utils/ShapeUtils.cpp b/lib/Dialect/Utils/ShapeUtils.cpp index 6844e75..550ea3e 100644 --- a/lib/Dialect/Utils/ShapeUtils.cpp +++ b/lib/Dialect/Utils/ShapeUtils.cpp @@ -43,6 +43,35 @@ bool mlir::triton::isConsecutive(llvm::ArrayRef array) { }); } +bool mlir::triton::trailingNDimsContiguous(MemRefType type, int64_t n) { + if (canonicalizeStridedLayout(type).getLayout().isIdentity()) + return true; + + auto memrefShape = type.getShape().take_back(n); + if (ShapedType::isDynamicShape(memrefShape.drop_front())) + return false; + + int64_t offset; + SmallVector stridesFull; + if (!succeeded(getStridesAndOffset(type, stridesFull, offset))) + return false; + auto strides = ArrayRef(stridesFull).take_back(n); + + if (strides.empty()) + return true; + + // Check whether strides match "flattened" dims. 
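+ // Editorial worked example: for memref<2x3x4xf32, strided<[24, 4, 1]>> and
+ // n = 2, memrefShape = [3, 4] and strides = [4, 1]; flattenedDims collects
+ // the suffix products of [4] in reverse, giving [4], which matches
+ // strides.drop_back(1) = [4], so the trailing two dims are contiguous even
+ // though the outer stride (24) leaves a gap.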
+ SmallVector flattenedDims; + auto dimProduct = 1; + for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { + dimProduct *= dim; + flattenedDims.push_back(dimProduct); + } + + strides = strides.drop_back(1); + return llvm::equal(strides, llvm::reverse(flattenedDims)); +} + /// Returns a memref.subview or a tensor.extract_slice based on the type of the /// `source`. Value mlir::triton::getSlice(OpBuilder &b, Location loc, Value source, @@ -87,7 +116,7 @@ mlir::triton::canonicalizeOpFoldResult(ArrayRef in) { Value mlir::triton::getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) { - ShapedType type = v.getType().cast(); + ShapedType type = cast(v.getType()); if (!type.isDynamicDim(dim)) { return builder.create(loc, type.getDimSize(dim)); } @@ -102,7 +131,7 @@ Value mlir::triton::getDimValue(OpBuilder &builder, Location loc, Value v, OpFoldResult mlir::triton::getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) { - auto t = v.getType().cast(); + auto t = cast(v.getType()); if (t.isDynamicDim(dim)) { return getDimValue(builder, loc, v, dim); } @@ -113,7 +142,7 @@ SmallVector mlir::triton::getDims(OpBuilder &builder, Location loc, Value shapedTypeValue) { SmallVector ret; for (auto i : llvm::seq( - 0, shapedTypeValue.getType().cast().getRank())) { + 0, cast(shapedTypeValue.getType()).getRank())) { ret.push_back(getDim(builder, loc, shapedTypeValue, i)); } return ret; @@ -123,7 +152,7 @@ SmallVector mlir::triton::getDimsValue(OpBuilder &builder, Location loc, Value shapedTypeValue) { SmallVector ret; for (auto i : llvm::seq( - 0, shapedTypeValue.getType().cast().getRank())) { + 0, cast(shapedTypeValue.getType()).getRank())) { ret.push_back(getDimValue(builder, loc, shapedTypeValue, i)); } return ret; @@ -132,7 +161,7 @@ SmallVector mlir::triton::getDimsValue(OpBuilder &builder, Location loc, SmallVector mlir::triton::getDynamicDimsValue(OpBuilder &builder, Location loc, Value val) { SmallVector dynamicDims; - auto type = val.getType().cast(); + auto type = cast(val.getType()); for (auto dimIdx : llvm::seq(0, type.getRank())) { if (type.isDynamicDim(dimIdx)) { dynamicDims.push_back(getDimValue(builder, loc, val, dimIdx)); @@ -143,15 +172,15 @@ SmallVector mlir::triton::getDynamicDimsValue(OpBuilder &builder, Value mlir::triton::materializeOpFoldResult(OpBuilder &builder, Location loc, OpFoldResult opFoldResult) { - if (auto value = opFoldResult.dyn_cast()) + if (auto value = dyn_cast(opFoldResult)) return value; - auto attr = opFoldResult.get().cast(); + auto attr = cast(opFoldResult.get()); return builder.create(loc, attr.getValue().getSExtValue()); } Value mlir::triton::prependUnitDim(OpBuilder &b, Location loc, Value value) { - auto valTy = value.getType().cast(); + auto valTy = cast(value.getType()); int64_t rank = valTy.getRank(); SmallVector shape(valTy.getShape()); shape.insert(shape.begin(), 1); @@ -177,7 +206,7 @@ Value mlir::triton::prependUnitDim(OpBuilder &b, Location loc, Value value) { } Value mlir::triton::dropUnitFirstDim(OpBuilder &b, Location loc, Value value) { - auto valTy = value.getType().cast(); + auto valTy = cast(value.getType()); int64_t rank = valTy.getRank(); assert(rank > 0 && valTy.getShape().front() == 1); @@ -198,7 +227,7 @@ Value mlir::triton::dropUnitFirstDim(OpBuilder &b, Location loc, Value value) { } Value mlir::triton::appendUnitDim(OpBuilder &b, Location loc, Value value) { - auto valTy = value.getType().cast(); + auto valTy = cast(value.getType()); int64_t rank = valTy.getRank(); SmallVector shape(valTy.getShape()); 
shape.push_back(1); @@ -223,17 +252,35 @@ Value mlir::triton::appendUnitDim(OpBuilder &b, Location loc, Value value) { .Default([](auto) -> Value { llvm_unreachable("unsupport value type"); }); } +static bool DetermineLastNDContiguous(MemRefType type, int64_t n, + bool exceptLastDim) { + int64_t idx = type.getRank(); + for (; idx > 0; idx--) { + if (mlir::triton::trailingNDimsContiguous(type, idx)) + break; + } + return idx >= n + static_cast(exceptLastDim); +} + Value mlir::triton::collapseLastNDimsToOneDim(OpBuilder &b, Location loc, - Value value, int64_t n) { + Value value, int64_t n, + bool exceptLastDim) { if (!value || n == 1) return value; - auto valueTy = value.getType().cast(); + if (isa(value.getType())) { + assert(DetermineLastNDContiguous(cast(value.getType()), n, + exceptLastDim) && + "The dimensions that require collapse need to be contiguous."); + } + auto valueTy = cast(value.getType()); auto rank = valueTy.getRank(); + if (exceptLastDim) + rank -= 1; assert(rank >= n && "Dim number to collapse is larger than rank."); - // Add a unit dim to the last. - if (n == 0) + // When exceptLastDim is false and n == 0, add a unit dim to the last. + if (n == 0 && !exceptLastDim) return appendUnitDim(b, loc, value); // Collapse the last n(n > 1) dims to one dim. @@ -241,6 +288,8 @@ Value mlir::triton::collapseLastNDimsToOneDim(OpBuilder &b, Location loc, for (int64_t i = 0; i < rank - n; ++i) reassociation.push_back({i}); reassociation.push_back(llvm::to_vector(llvm::seq(rank - n, rank))); + if (exceptLastDim) + reassociation.push_back({rank}); return TypeSwitch(valueTy) .Case([&](auto) { return b.create(loc, value, reassociation); @@ -251,5 +300,5 @@ } bool mlir::triton::isScalar(const Value val) { - return !val.getType().isa(); + return !isa(val.getType()); } diff --git a/lib/Interfaces/CMakeLists.txt b/lib/Interfaces/CMakeLists.txt deleted file mode 100644 index cfd5f7b..0000000 --- a/lib/Interfaces/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -add_triton_library(TritonLinalgInterface - InferResultTypeOpInterface.cpp - - DEPENDS - TritonLinalgInterfacesTableGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/lib/Interfaces/InferResultTypeOpInterface.cpp b/lib/Interfaces/InferResultTypeOpInterface.cpp deleted file mode 100644 index 6d79454..0000000 --- a/lib/Interfaces/InferResultTypeOpInterface.cpp +++ /dev/null @@ -1,13 +0,0 @@ -//===- InferResultTypeOpInterface.cpp - Infer result type -------*- C++ -*-===// -// -// Copyright (C) [2022-2025] by Cambricon. -// -//===----------------------------------------------------------------------===// -// -// This file implements the operation interface infers result type.
diff --git a/lib/Interfaces/InferResultTypeOpInterface.cpp b/lib/Interfaces/InferResultTypeOpInterface.cpp
deleted file mode 100644
index 6d79454..0000000
--- a/lib/Interfaces/InferResultTypeOpInterface.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===- InferResultTypeOpInterface.cpp - Infer result type -------*- C++ -*-===//
-//
-// Copyright (C) [2022-2025] by Cambricon.
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the operation interface infers result type.
-//
-//===----------------------------------------------------------------------===//
-#include "triton-linalg/Interfaces/InferResultTypeOpInterface.h"
-#include "llvm/ADT/STLExtras.h" // IWYU pragma: keep
-
-#include "triton-linalg/Interfaces/InferResultTypeOpInterface.cpp.inc"
diff --git a/lib/Pipelines/CMakeLists.txt b/lib/Pipelines/CMakeLists.txt
index 0ead1d6..5c1ac28 100644
--- a/lib/Pipelines/CMakeLists.txt
+++ b/lib/Pipelines/CMakeLists.txt
@@ -6,6 +6,5 @@ add_triton_library(TritonLinalgPipelines
   MathToLinalg
   MLIRIR
   TritonToLinalg
-  TritonToTensor
   TritonTransformsExtend
 )
diff --git a/lib/Pipelines/Pipelines.cpp b/lib/Pipelines/Pipelines.cpp
index c9749e1..12eb39a 100644
--- a/lib/Pipelines/Pipelines.cpp
+++ b/lib/Pipelines/Pipelines.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
 #include "triton-linalg/Conversion/Passes.h"
-#include "triton-linalg/Dialect/Arith/Transforms/Passes.h"
 #include "triton-linalg/Dialect/Triton/Transforms/Passes.h"
 #include "llvm/ADT/StringRef.h"
 #include
@@ -24,11 +23,9 @@ void buildTritonToLinalgPipeline(mlir::OpPassManager &pm) {
   pm.addPass(mlir::createInlinerPass({}, nullptr));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::triton::createCanonicalizeTritonPass());
-  pm.addPass(mlir::triton::arith_ext::createArithCanonicalizerPass());
   pm.addPass(mlir::triton::createPointerStrengthReductionPass());
   // Since canonicalizer pass may convert single block function to multi-blocks,
   // we rerun this pass here.
-  pm.addPass(mlir::triton::createTritonToTensorPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::triton::createTritonToLinalgPass());
   pm.addNestedPass<func::FuncOp>(
@@ -39,7 +36,6 @@
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createLoopInvariantCodeMotionPass());
   pm.addPass(mlir::triton::createWrapFuncBodyWithSingleBlockPass());
-  pm.addPass(mlir::triton::arith_ext::createArithCanonicalizerPass());
 }
 } // namespace
diff --git a/lib/Utils/Utils.cpp b/lib/Utils/Utils.cpp
index 3323134..c8e453f 100644
--- a/lib/Utils/Utils.cpp
+++ b/lib/Utils/Utils.cpp
@@ -80,12 +80,12 @@ bool mlir::triton::createReassociationMaps(
 }
 
 Value triton::castToIndexType(OpBuilder &b, Location loc, OpFoldResult ofr) {
-  if (auto value = ofr.dyn_cast<Value>()) {
-    if (!value.getType().isa<IndexType>())
+  if (auto value = dyn_cast<Value>(ofr)) {
+    if (!isa<IndexType>(value.getType()))
       return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), value);
     return value;
   }
-  auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
+  auto attr = dyn_cast<IntegerAttr>(dyn_cast<Attribute>(ofr));
   assert(attr && "expect the op fold result casts to an integer attribute");
   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue())
       .getResult();
diff --git a/test/Conversion/arith-to-linalg.mlir b/test/Conversion/arith-to-linalg.mlir
index 474f521..a04e257 100644
--- a/test/Conversion/arith-to-linalg.mlir
+++ b/test/Conversion/arith-to-linalg.mlir
@@ -1,5 +1,7 @@
 // RUN: triton-linalg-opt -convert-arith-to-linalg -split-input-file %s | FileCheck %s
 
+// -----
+
 func.func @const_valid_float(%arg0: tensor<1x16x128x128xf32>) -> tensor<1x16x128x128xf32> {
   // CHECK: %cst = arith.constant 0.000000e+00 : f32
   // CHECK: tensor.empty
diff --git a/test/Conversion/math-to-linalg.mlir b/test/Conversion/math-to-linalg.mlir
index 6e1bc99..305f4b4 100644
--- a/test/Conversion/math-to-linalg.mlir
+++ b/test/Conversion/math-to-linalg.mlir
@@ -1,5 +1,6 @@
 // RUN: triton-linalg-opt -convert-math-to-linalg -split-input-file %s | FileCheck %s
 
+// -----
 func.func @math_log(%arg0:
tensor<128xf32>) { // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xf32> // CHECK: %[[MAPPED:.*]] = linalg.map { math.log } ins(%arg0 : tensor<128xf32>) outs(%[[INIT]] : tensor<128xf32>) @@ -322,3 +323,44 @@ func.func @math_fma_tensor_staic(%arg0: tensor<100x10xf32>, %arg1: tensor<100x10 return %0 : tensor<100x10xf32> } +// ----- +func.func @math_rsqrt_tensor_staic(%arg0: tensor<64x128xf16>) -> tensor<64x128xf16> { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<64x128xf16> + // CHECK: %mapped = linalg.map { math.rsqrt } ins(%arg0 : tensor<64x128xf16>) outs(%[[INIT]] : tensor<64x128xf16>) + %0 = math.rsqrt %arg0: tensor<64x128xf16> + return %0 : tensor<64x128xf16> +} + +// ----- +tt.func @tt_mulhiui_vector_i32(%arg0: tensor<16x16xi32>, %arg1: tensor<16x16xi32>) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<16x16xi32> + // CHECK: %mapped = linalg.map { math_ext.mulhiui } ins(%arg0, %arg1 : tensor<16x16xi32>, tensor<16x16xi32>) outs(%[[INIT]] : tensor<16x16xi32>) + %0 = math_ext.mulhiui %arg0, %arg1 : tensor<16x16xi32> + tt.return +} + +// ----- +func.func @math_tanh_scalar(%arg0: f32) { + // CHECK: math.tanh %arg0 : f32 + // CHECK-NOT: linalg.map + %0 = math.tanh %arg0 : f32 + return +} + +// ----- +func.func @math_tanh_tensor_staic(%arg0: tensor<128xf32>) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xf32> + // CHECK: %[[MAPPED:.*]] = linalg.map { math.tanh } ins(%arg0 : tensor<128xf32>) outs(%[[INIT]] : tensor<128xf32>) + %0 = math.tanh %arg0 : tensor<128xf32> + return +} + +// ----- +func.func @math_tanh_tensor_partial_static(%arg0: tensor<128x?xf32>) { + // CHECK: %[[CST:.*]] = arith.constant 1 : index + // CHECK: %[[DYNAMIC_DIM:.*]] = tensor.dim %arg0, %[[CST]] : tensor<128x?xf32> + // CHECK: %[[INIT:.*]] = tensor.empty(%[[DYNAMIC_DIM]]) : tensor<128x?xf32> + // CHECK: %[[MAPPED:.*]] = linalg.map { math.tanh } ins(%arg0 : tensor<128x?xf32>) outs(%[[INIT]] : tensor<128x?xf32>) + %0 = math.tanh %arg0 : tensor<128x?xf32> + return +} diff --git a/test/Conversion/triton-to-linalg.mlir b/test/Conversion/triton-to-linalg.mlir index dc4b1eb..e07d6bc 100644 --- a/test/Conversion/triton-to-linalg.mlir +++ b/test/Conversion/triton-to-linalg.mlir @@ -156,7 +156,7 @@ func.func @view_1d_0d_ptr(%arg0: tensor<1x!tt.ptr>) { // CHECK-LABEL: @add_ptr // CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: tensor<256xi32> tt.func @add_ptr(%arg0: !tt.ptr, %arg1: tensor<256xi32>) { - // CHECK-NEXT: %[[INIT:.*]] = tensor.empty + // CHECK: %[[INIT:.*]] = tensor.empty // CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ARG0]] : i64) outs(%[[INIT]] : tensor<256xi64>) // CHECK-NEXT: %[[INIT2:.*]] = tensor.empty // CHECK-NEXT: %[[OUT:.*]] = linalg.map ins(%[[FILL]], %[[ARG1]] : tensor<256xi64>, tensor<256xi32>) outs(%[[INIT2]] : tensor<256xi64>) @@ -175,7 +175,7 @@ tt.func @add_ptr(%arg0: !tt.ptr, %arg1: tensor<256xi32>) { // CHECK-LABEL: @add_ptr_for_scalar // CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i32 tt.func @add_ptr_for_scalar(%arg0: !tt.ptr, %arg1: i32) { - // CHECK-NEXT: %[[EXT:.*]] = arith.extsi %[[ARG1]] : i32 to i64 + // CHECK: %[[EXT:.*]] = arith.extsi %[[ARG1]] : i32 to i64 // CHECK-NEXT: %[[C4:.*]] = arith.constant 4 // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[EXT]], %[[C4]] // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[MUL]] @@ -188,7 +188,7 @@ tt.func @add_ptr_for_scalar(%arg0: !tt.ptr, %arg1: i32) { // CHECK-LABEL: @add_ptr_for_scalar_i64 // CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64 tt.func @add_ptr_for_scalar_i64(%arg0: !tt.ptr, %arg1: i64) { - // CHECK-NEXT: 
%[[C4:.*]] = arith.constant 4 + // CHECK: %[[C4:.*]] = arith.constant 4 // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG1]], %[[C4]] // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[MUL]] %0 = tt.addptr %arg0, %arg1: !tt.ptr, i64 @@ -657,37 +657,35 @@ tt.func @reduce_multi_statement_argmin_f32(%arg0: tensor<1x256xf32>, %arg1: tens // CHECK-LABEL: func.func @reduce_multi_statement_argmin_f32( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256xi32>) { - // CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [1, 1] [1, 1] : tensor<1x256xf32> to tensor<1x1xf32> - // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0, 1]] : tensor<1x1xf32> into tensor<1xf32> - // CHECK: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_0]][0, 1] [1, 255] [1, 1] : tensor<1x256xf32> to tensor<1x255xf32> - // CHECK: %[[VAL_5:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [1, 1] [1, 1] : tensor<1x256xi32> to tensor<1x1xi32> - // CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : tensor<1x1xi32> into tensor<1xi32> - // CHECK: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_1]][0, 1] [1, 255] [1, 1] : tensor<1x256xi32> to tensor<1x255xi32> - // CHECK: %[[VAL_8:.*]]:2 = linalg.reduce ins(%[[VAL_4]], %[[VAL_7]] : tensor<1x255xf32>, tensor<1x255xi32>) outs(%[[VAL_3]], %[[VAL_6]] : tensor<1xf32>, tensor<1xi32>) dimensions = [1] + // CHECK: %[[VAL_2:.*]] = arith.constant 0x7F800000 : f32 + // CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1xf32> + // CHECK: %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_3]] : tensor<1xf32>) -> tensor<1xf32> + // CHECK: %[[VAL_5:.*]] = arith.constant -1 : i32 + // CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<1xi32> + // CHECK: %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_5]] : i32) outs(%[[VAL_6]] : tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[VAL_8:.*]]:2 = linalg.reduce ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x256xf32>, tensor<1x256xi32>) outs(%[[VAL_4]], %[[VAL_7]] : tensor<1xf32>, tensor<1xi32>) dimensions = [1] // CHECK: (%[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: i32) { - // CHECK: %[[VAL_13:.*]] = arith.cmpf olt, %[[VAL_9]], %[[VAL_11]] : f32 - // CHECK: %[[VAL_14:.*]] = arith.cmpf ogt, %[[VAL_9]], %[[VAL_11]] : f32 - // CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_10]], %[[VAL_12]] : i32 - // CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_10]], %[[VAL_12]] : i32 - // CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_14]], %[[VAL_12]], %[[VAL_16]] : i32 - // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_13]], %[[VAL_10]], %[[VAL_17]] : i32 - // CHECK: %[[VAL_19:.*]] = arith.cmpf olt, %[[VAL_9]], %[[VAL_11]] : f32 - // CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_9]], %[[VAL_11]] : f32 - // CHECK: linalg.yield %[[VAL_20]], %[[VAL_18]] : f32, i32 + // CHECK: %[[VAL_13:.*]] = arith.cmpf oeq, %[[VAL_9]], %[[VAL_11]] : f32 + // CHECK: %[[VAL_14:.*]] = arith.cmpi slt, %[[VAL_10]], %[[VAL_12]] : i32 + // CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 + // CHECK: %[[VAL_16:.*]] = arith.cmpf olt, %[[VAL_9]], %[[VAL_11]] : f32 + // CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_16]], %[[VAL_15]] : i1 + // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_11]] : f32 + // CHECK: %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_10]], %[[VAL_12]] : i32 + // CHECK: linalg.yield %[[VAL_18]], %[[VAL_19]] : f32, i32 // CHECK: } // CHECK: return // CHECK: } %9:2 = "tt.reduce"(%arg0, %arg1) ({ - ^bb0(%arg4: f32, %arg5: i32, %arg6: 
f32, %arg7: i32): - %14 = arith.cmpf olt, %arg4, %arg6 : f32 - %15 = arith.cmpf ogt, %arg4, %arg6 : f32 - %16 = arith.cmpi slt, %arg5, %arg7 : i32 - %17 = arith.select %16, %arg5, %arg7 : i32 - %18 = arith.select %15, %arg7, %17 : i32 - %19 = arith.select %14, %arg5, %18 : i32 - %20 = arith.cmpf olt, %arg4, %arg6 : f32 - %21 = arith.select %20, %arg4, %arg6 : f32 - tt.reduce.return %21, %19 : f32, i32 + ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + %12 = arith.cmpi slt, %arg10, %arg12 : i32 + %13 = arith.andi %11, %12 : i1 + %14 = arith.cmpf olt, %arg9, %arg11 : f32 + %15 = arith.ori %14, %13 : i1 + %16 = arith.select %15, %arg9, %arg11 : f32 + %17 = arith.select %15, %arg10, %arg12 : i32 + tt.reduce.return %16, %17 : f32, i32 }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>) tt.return } @@ -697,27 +695,35 @@ tt.func @reduce_multi_statement_argmax_f32(%arg0: tensor<2x2x256xf32>, %arg1: te // CHECK-LABEL: func.func @reduce_multi_statement_argmax_f32( // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x256xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2x256xi32>) { - // CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_0]][0, 0, 0] [2, 2, 1] [1, 1, 1] : tensor<2x2x256xf32> to tensor<2x2x1xf32> - // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : tensor<2x2x1xf32> into tensor<2x2xf32> - // CHECK: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_0]][0, 0, 1] [2, 2, 255] [1, 1, 1] : tensor<2x2x256xf32> to tensor<2x2x255xf32> - // CHECK: %[[VAL_5:.*]] = tensor.extract_slice %[[VAL_1]][0, 0, 0] [2, 2, 1] [1, 1, 1] : tensor<2x2x256xi32> to tensor<2x2x1xi32> - // CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_5]] {{\[\[}}0], [1, 2]] : tensor<2x2x1xi32> into tensor<2x2xi32> - // CHECK: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_1]][0, 0, 1] [2, 2, 255] [1, 1, 1] : tensor<2x2x256xi32> to tensor<2x2x255xi32> - // CHECK: %[[VAL_8:.*]]:2 = linalg.reduce ins(%[[VAL_4]], %[[VAL_7]] : tensor<2x2x255xf32>, tensor<2x2x255xi32>) outs(%[[VAL_3]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>) dimensions = [2] + // CHECK: %[[VAL_2:.*]] = arith.constant 0xFF800000 : f32 + // CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<2x2xf32> + // CHECK: %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_3]] : tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: %[[VAL_5:.*]] = arith.constant -1 : i32 + // CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<2x2xi32> + // CHECK: %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_5]] : i32) outs(%[[VAL_6]] : tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: %[[VAL_8:.*]]:2 = linalg.reduce ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x2x256xf32>, tensor<2x2x256xi32>) outs(%[[VAL_4]], %[[VAL_7]] : tensor<2x2xf32>, tensor<2x2xi32>) dimensions = [2] // CHECK: (%[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: i32) { - // CHECK: %[[VAL_13:.*]] = arith.cmpf ogt, %[[VAL_9]], %[[VAL_11]] : f32 - // CHECK: %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_9]], %[[VAL_11]] : f32 - // CHECK: %[[VAL_15:.*]] = arith.select %[[VAL_13]], %[[VAL_10]], %[[VAL_12]] : i32 - // CHECK: linalg.yield %[[VAL_14]], %[[VAL_15]] : f32, i32 + // CHECK: %[[VAL_13:.*]] = arith.cmpf oeq, %[[VAL_9]], %[[VAL_11]] : f32 + // CHECK: %[[VAL_14:.*]] = arith.cmpi slt, %[[VAL_10]], %[[VAL_12]] : i32 + // CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 + // CHECK: %[[VAL_16:.*]] = arith.cmpf ogt, %[[VAL_9]], %[[VAL_11]] : f32 + // CHECK: %[[VAL_17:.*]] = arith.ori 
%[[VAL_16]], %[[VAL_15]] : i1 + // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_11]] : f32 + // CHECK: %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_10]], %[[VAL_12]] : i32 + // CHECK: linalg.yield %[[VAL_18]], %[[VAL_19]] : f32, i32 // CHECK: } // CHECK: return // CHECK: } %9:2 = "tt.reduce"(%arg0, %arg1) ({ - ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): - %1 = arith.cmpf ogt, %arg4, %arg6 : f32 - %2 = arith.select %1, %arg4, %arg6 : f32 - %3 = arith.select %1, %arg5, %arg7 : i32 - tt.reduce.return %2, %3 : f32, i32 + ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + %12 = arith.cmpi slt, %arg10, %arg12 : i32 + %13 = arith.andi %11, %12 : i1 + %14 = arith.cmpf ogt, %arg9, %arg11 : f32 + %15 = arith.ori %14, %13 : i1 + %16 = arith.select %15, %arg9, %arg11 : f32 + %17 = arith.select %15, %arg10, %arg12 : i32 + tt.reduce.return %16, %17 : f32, i32 }) {axis = 2 : i32} : (tensor<2x2x256xf32>, tensor<2x2x256xi32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) tt.return } @@ -765,6 +771,7 @@ tt.func @ext_elemwise_2(%arg0: tensor<16x16xi32>) { // ----- // CHECK-LABEL: @cast_ptr_and_int_scalar tt.func @cast_ptr_and_int_scalar(%arg0: !tt.ptr) { + // CHECK-NEXT: builtin.unrealized_conversion_cast // CHECK-NEXT: return %0 = tt.ptr_to_int %arg0 : !tt.ptr -> i64 %1 = tt.int_to_ptr %0 : i64 -> !tt.ptr @@ -774,6 +781,7 @@ tt.func @cast_ptr_and_int_scalar(%arg0: !tt.ptr) { // ----- // CHECK-LABEL: @cast_ptr_and_int_1D tt.func @cast_ptr_and_int_1D(%arg0: tensor<16x!tt.ptr>) { + // CHECK-NEXT: builtin.unrealized_conversion_cast // CHECK-NEXT: return %0 = tt.ptr_to_int %arg0 : tensor<16x!tt.ptr> -> tensor<16xi64> %1 = tt.int_to_ptr %0 : tensor<16xi64> -> tensor<16x!tt.ptr> @@ -783,6 +791,7 @@ tt.func @cast_ptr_and_int_1D(%arg0: tensor<16x!tt.ptr>) { // ----- // CHECK-LABEL: @cast_ptr_and_int_2D tt.func @cast_ptr_and_int_2D(%arg0: tensor<2x16x!tt.ptr>) { + // CHECK-NEXT: builtin.unrealized_conversion_cast // CHECK-NEXT: return %0 = tt.ptr_to_int %arg0 : tensor<2x16x!tt.ptr> -> tensor<2x16xi64> %1 = tt.int_to_ptr %0 : tensor<2x16xi64> -> tensor<2x16x!tt.ptr> @@ -810,6 +819,16 @@ tt.func @optimize_barrier(%arg0: !tt.ptr) { tt.return } +// ----- +// CHECK-LABEL: @trans_3d +// CHECK-SAME: %[[ARG0:.*]]: tensor<16x256x4xf32> +tt.func @trans_3d(%arg0: tensor<16x256x4xf32>) { + // CHECK: %[[INIT_OUT:.*]] = tensor.empty() : tensor<16x4x256xf32> + // CHECK: %[[TRANS_OUT:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<16x256x4xf32>) outs(%[[INIT_OUT]] : tensor<16x4x256xf32>) permutation = [0, 2, 1] + %out = tt.trans %arg0 {order = array} : tensor<16x256x4xf32> -> tensor<16x4x256xf32> + tt.return +} + // ----- // CHECK-LABEL: @trans_2d // CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf32> @@ -918,9 +937,9 @@ func.func @cmpi_to_fill(%arg0: i32) { // CHECK: %[[ARG1:.*]] = arith.maxsi %[[ARG]], %[[C0]] : index // CHECK: %[[C128:.*]] = arith.constant 128 : index // CHECK: %[[SIZE:.*]] = arith.minsi %[[C128]], %[[ARG1]] + // CHECK: %[[FALSE:.*]] = arith.constant false // CHECK: %[[INIT1:.*]] = tensor.empty(%[[SIZE]]) // CHECK: %[[TRUE:.*]] = linalg.fill ins(%true : i1) outs(%[[INIT1]] : tensor) - // CHECK: %[[FALSE:.*]] = arith.constant false // CHECK: %[[PAD_INIT:.*]] = tensor.empty() : tensor<128xi1> // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[PAD_INIT]], %[[C0_0]] : tensor<128xi1> @@ -941,16 +960,17 @@ func.func @cmpi_to_fill(%arg0: i32) { // CHECK: %[[ARG:.*]] = arith.index_cast %[[ARG_I32]] : i32 to 
index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[ARG1:.*]] = arith.addi %[[ARG]], %[[C1]] : index +// CHECK: %[[ARG2:.*]] = arith.maxsi %[[ARG1]], %[[ARG]] : index // CHECK: %[[C128:.*]] = arith.constant 128 : index -// CHECK: %[[LB0:.*]] = arith.minsi %[[C128]], %[[ARG1]] +// CHECK: %[[LB0:.*]] = arith.minsi %[[C128]], %[[ARG2]] // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[LB1:.*]] = arith.maxsi %[[C0]], %[[LB0]] : index // CHECK: %[[C128_0:.*]] = arith.constant 128 : index // CHECK: %[[SIZE:.*]] = arith.subi %[[C128_0]], %[[LB1]] : index +// CHECK: %[[FALSE:.*]] = arith.constant false // CHECK: %[[INIT1:.*]] = tensor.empty(%[[SIZE]]) : tensor // CHECK: %[[C_TRUE:.*]] = arith.constant true // CHECK: %[[TRUE:.*]] = linalg.fill ins(%[[C_TRUE]] : i1) outs(%[[INIT1]] : tensor) -// CHECK: %[[FALSE:.*]] = arith.constant false // CHECK: %[[PAD_INIT:.*]] = tensor.empty() : tensor<128xi1> // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[PAD_INIT]], %[[C0_0]] : tensor<128xi1> @@ -965,7 +985,29 @@ func.func @cmp2fill_ub(%arg: i32) -> tensor<128xi1> { return %2 : tensor<128xi1> } - +// ----- +// CHECK-LABEL: @cmpi_to_fill_false +func.func @cmpi_to_fill_false() { + // CHECK: %[[C_MAX:.*]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xi64> + // CHECK: %[[C_MAX_TENSOR:.*]] = linalg.fill ins(%[[C_MAX]] : i64) outs(%[[INIT]] : tensor<128xi64>) -> tensor<128xi64> + // CHECK: %[[RANGE_INIT:.*]] = tensor.empty() : tensor<128xi32> + // CHECK: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK: %[[C128:.*]] = arith.constant 128 : i32 + // CHECK: %[[RANGE:.*]] = linalg_ext.make_range {operandSegmentSizes = array} ins(%[[C0]], %[[C128]] : i32, i32) outs(%[[RANGE_INIT]] : tensor<128xi32>) -> tensor<128xi32> + // CHECK: %[[EXT_INIT:.*]] = tensor.empty() : tensor<128xi64> + // CHECK: %[[EXT_RANGE:.*]] = linalg.map { arith.extsi } ins(%[[RANGE]] : tensor<128xi32>) outs(%[[EXT_INIT]] : tensor<128xi64>) + // CHECK: %[[C_MAX_INDEX:.*]] = arith.constant 9223372036854775807 : index + // CHECK: %[[FALSE:.*]] = arith.constant false + // CHECK: %[[CMP_RES_INIT:.*]] = tensor.empty() : tensor<128xi1> + // CHECK: %[[CMP_RES:.*]] = linalg.fill ins(%[[FALSE]] : i1) outs(%[[CMP_RES_INIT]] : tensor<128xi1>) -> tensor<128xi1> + // CHECK-NOT: arith.cmpi + %c_max = arith.constant dense<9223372036854775807> : tensor<128xi64> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = arith.extsi %0 : tensor<128xi32> to tensor<128xi64> + %mask = arith.cmpi sgt, %1, %c_max : tensor<128xi64> + return +} // ----- // CHECK-LABEL: @select_conversion @@ -1068,7 +1110,8 @@ func.func @select_pad_conversion(%arg0: tensor<64x64xf32>, %arg1: i32) -> tensor // CHECK: linalg_ext.pad // CHECK: %[[C32_INDEX:.*]] = arith.index_cast %{{.*}} : i32 to index // CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C33:.*]] = arith.addi %[[C32_INDEX]], %[[C1]] : index +// CHECK: %[[C33_0:.*]] = arith.addi %[[C32_INDEX]], %[[C1]] : index +// CHECK: %[[C33:.*]] = arith.maxsi %[[C33_0]], %[[C32_INDEX]] : index // CHECK: %[[C128:.*]] = arith.constant 128 : index // CHECK: %[[C33_1:.*]] = arith.minsi %[[C128]], %[[C33]] : index // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -1757,7 +1800,7 @@ tt.func @scan_add_1d_size1_f32_reverse(%arg0: tensor<1xf32>) -> tensor<1xf32> { // ----- tt.func @tt_mulhiui_scalar_i32(%arg0: i32, %arg1: i32) { - // CHECK: math_ext.mulhiui + // CHECK: math_ext.mulhiui %0 = tt.mulhiui 
%arg0, %arg1 : i32 tt.return } @@ -1769,6 +1812,39 @@ tt.func @tt_mulhiui_vector_i32(%arg0: tensor<16x16xi32>, %arg1: tensor<16x16xi32 tt.return } +// ----- +// CHECK-LABEL: @cat_tensor +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32> +tt.func public @cat_tensor(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<64xf32> + // CHECK: %[[INSET1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0] [32] [1] : tensor<32xf32> into tensor<64xf32> + // CHECK: %[[INSET2:.*]] = tensor.insert_slice %[[ARG1]] into %[[INSET1]][%c32_0] [32] [1] : tensor<32xf32> into tensor<64xf32> + %0 = tt.cat %arg0, %arg1 : tensor<32xf32> -> tensor<64xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @cat_0rank +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor +tt.func public @cat_0rank(%arg0: tensor, %arg1: tensor) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2xf32> + // CHECK: %[[INSET1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0] [1] [1] : tensor into tensor<2xf32> + // CHECK: %[[INSET2:.*]] = tensor.insert_slice %[[ARG1]] into %[[INSET1]][%c1_0] [1] [1] : tensor into tensor<2xf32> + %0 = tt.cat %arg0, %arg1 : tensor -> tensor<2xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @cat_3rank +// CHECK-SAME: %[[ARG0:.*]]: tensor<32x16x8xf32>, %[[ARG1:.*]]: tensor<32x16x8xf32> +tt.func public @cat_3rank(%arg0: tensor<32x16x8xf32>, %arg1: tensor<32x16x8xf32>) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<64x16x8xf32> + // CHECK: %[[INSET1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0, 0] [32, 16, 8] [1, 1, 1] : tensor<32x16x8xf32> into tensor<64x16x8xf32> + // CHECK: %[[INSET2:.*]] = tensor.insert_slice %[[ARG1]] into %[[INSET1]][%c32_0, 0, 0] [32, 16, 8] [1, 1, 1] : tensor<32x16x8xf32> into tensor<64x16x8xf32> + %0 = tt.cat %arg0, %arg1 : tensor<32x16x8xf32> -> tensor<64x16x8xf32> + tt.return +} + // ----- // CHECK-LABEL: @join_int8 // CHECK-SAME: %[[ARG0:.*]]: tensor<2x8xi8>, %[[ARG1:.*]]: tensor<2x8xi8> @@ -1870,7 +1946,7 @@ tt.func @tt_precise_divf_vector_f32(%arg0: tensor<128xf32>, %arg1: tensor<128xf3 // ----- // CHECK-LABEL: @clampf_propagateNan_all_f32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32>, %[[ARG2:.*]]: tensor<32xf32> +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32>, %[[ARG2:.*]]: tensor<32xf32> tt.func @clampf_propagateNan_all_f32(%x: tensor<32xf32>, %min: tensor<32xf32>, %max: tensor<32xf32>) -> tensor<32xf32> { // CHECK: %[[MAPPED:.*]] = linalg.map { arith.maximumf } ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>) // CHECK: linalg.map { arith.minimumf } ins(%[[MAPPED]], %[[ARG2]] : tensor<32xf32>, tensor<32xf32>) @@ -1911,35 +1987,8 @@ tt.func @clampf_propagateNan_none_f16(%x: tensor<32xf16>, %min: tensor<32xf16>, // ----- // CHECK-LABEL: @histogram_i32 // CHECK-SAME: %[[ARG0:.*]]: tensor<8xi32>) -// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 -// CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32 -// CHECK: %[[C2_I32:.*]] = arith.constant 2 : i32 -// CHECK: %[[ARG1:.*]] = arith.subi %[[C2_I32]], %[[C1_I32]] : i32 -// CHECK: %[[ARG2:.*]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32 -// CHECK: %[[ARG3:.*]] = linalg.fill ins(%[[C0_I32_0]] : i32) outs(%[[ARG2]] : tensor<2xi32>) -> tensor<2xi32> -// CHECK: %[[C8:.*]] = arith.constant 8 : index -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[ARG4:.*]] = scf.for 
%[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG3]]) -> (tensor<2xi32>) { -// CHECK: %[[ARG7:.*]] = tensor.extract %[[ARG0]][%[[ARG5]]] : tensor<8xi32> -// CHECK: %[[ARG8:.*]] = arith.cmpi sle, %[[C0_I32]], %[[ARG7]] : i32 -// CHECK: %[[ARG9:.*]] = arith.cmpi sge, %[[ARG1]], %[[ARG7]] : i32 -// CHECK: %[[ARG10:.*]] = arith.andi %[[ARG8]], %[[ARG9]] : i1 -// CHECK: %[[ARG11:.*]] = scf.if %[[ARG10]] -> (tensor<2xi32>) { -// CHECK: %[[ARG12:.*]] = arith.subi %[[ARG7]], %[[C0_I32]] : i32 -// CHECK: %[[ARG13:.*]] = arith.index_cast %[[ARG12]] : i32 to index -// CHECK: %[[ARG14:.*]] = tensor.extract %[[ARG6]][%[[ARG13]]] : tensor<2xi32> -// CHECK: %[[C1_I32_2:.*]] = arith.constant 1 : i32 -// CHECK: %[[ARG15:.*]] = arith.addi %[[ARG14]], %[[C1_I32_2]] : i32 -// CHECK: %[[ARG16:.*]] = tensor.insert %[[ARG15]] into %[[ARG6]][%[[ARG13]]] : tensor<2xi32> -// CHECK: scf.yield %[[ARG16]] : tensor<2xi32> -// CHECK: } else { -// CHECK: scf.yield %[[ARG6]] : tensor<2xi32> -// CHECK: } -// CHECK: scf.yield %[[ARG11]] : tensor<2xi32> -// CHECK: } -// CHECK: return +// CHECK: %[[ARG1:.*]] = tensor.empty() : tensor<2xi32> +// CHECK: %[[ARG2:.*]] = linalg_ext.histogram ins(%[[ARG0]] : tensor<8xi32>) outs(%[[ARG1]] : tensor<2xi32>) -> tensor<2xi32> tt.func @histogram_i32(%0: tensor<8xi32>) { %1 = tt.histogram %0 : tensor<8xi32> -> tensor<2xi32> tt.return @@ -1948,36 +1997,44 @@ tt.func @histogram_i32(%0: tensor<8xi32>) { // ----- // CHECK-LABEL: @histogram_i64 // CHECK-SAME: %[[ARG0:.*]]: tensor<128xi64>) -// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64 -// CHECK: %[[ARG1:.*]] = arith.subi %[[C32_I64]], %[[C1_I64]] : i64 -// CHECK: %[[ARG2:.*]] = tensor.empty() : tensor<32xi64> -// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 -// CHECK: %[[ARG3:.*]] = linalg.fill ins(%[[C0_I64_0]] : i64) outs(%[[ARG2]] : tensor<32xi64>) -> tensor<32xi64> -// CHECK: %[[C128:.*]] = arith.constant 128 : index -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[ARG4:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG3]]) -> (tensor<32xi64>) { -// CHECK: %[[ARG7:.*]] = tensor.extract %[[ARG0]][%[[ARG5]]] : tensor<128xi64> -// CHECK: %[[ARG8:.*]] = arith.cmpi sle, %[[C0_I64]], %[[ARG7]] : i64 -// CHECK: %[[ARG9:.*]] = arith.cmpi sge, %[[ARG1]], %[[ARG7]] : i64 -// CHECK: %[[ARG10:.*]] = arith.andi %[[ARG8]], %[[ARG9]] : i1 -// CHECK: %[[ARG11:.*]] = scf.if %[[ARG10]] -> (tensor<32xi64>) { -// CHECK: %[[ARG12:.*]] = arith.subi %[[ARG7]], %[[C0_I64]] : i64 -// CHECK: %[[ARG13:.*]] = arith.index_cast %[[ARG12]] : i64 to index -// CHECK: %[[ARG14:.*]] = tensor.extract %[[ARG6]][%[[ARG13]]] : tensor<32xi64> -// CHECK: %[[C1_I64_2:.*]] = arith.constant 1 : i64 -// CHECK: %[[ARG15:.*]] = arith.addi %[[ARG14]], %[[C1_I64_2]] : i64 -// CHECK: %[[ARG16:.*]] = tensor.insert %[[ARG15]] into %[[ARG6]][%[[ARG13]]] : tensor<32xi64> -// CHECK: scf.yield %[[ARG16]] : tensor<32xi64> -// CHECK: } else { -// CHECK: scf.yield %[[ARG6]] : tensor<32xi64> -// CHECK: } -// CHECK: scf.yield %[[ARG11]] : tensor<32xi64> -// CHECK: } -// CHECK: return +// CHECK: %[[ARG1:.*]] = tensor.empty() : tensor<32xi64> +// CHECK: %[[ARG2:.*]] = linalg_ext.histogram ins(%[[ARG0]] : tensor<128xi64>) outs(%[[ARG1]] : tensor<32xi64>) -> tensor<32xi64> tt.func @histogram_i64(%0: tensor<128xi64>) { %1 = tt.histogram %0 : tensor<128xi64> 
-> tensor<32xi64>
   tt.return
 }
+
+// -----
+// CHECK-LABEL: @arith_select_scalar_cond
+// CHECK: linalg.map { arith.select }
+func.func @arith_select_scalar_cond(%arg0: i1, %arg1: tensor<128x128xf32>) {
+  %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+  %0 = arith.select %arg0, %cst_1, %arg1 : tensor<128x128xf32>
+  return
+}
+
+// -----
+// CHECK-LABEL: @arith_mul_scalar
+// CHECK: arith.muli
+func.func @arith_mul_scalar(%arg0: i32) {
+  %cst_0 = arith.constant 0 : i32
+  %cst_128 = arith.constant 128 : i32
+  %0 = arith.muli %arg0, %cst_0 : i32
+  %1 = arith.muli %0, %cst_128 : i32
+  return
+}
+
+// -----
+// CHECK-LABEL: @atomic_rmw_zero_dtype
+func.func @atomic_rmw_zero_dtype(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: f16) {
+  %true = arith.constant true
+  %0 = tt.addptr %arg0, %arg1 : !tt.ptr<f32>, i32
+  %1 = arith.extf %arg2 : f16 to f32
+  // CHECK-NOT: expected integer or index type
+  // CHECK: scf.if
+  // CHECK: else
+  // CHECK-NEXT: %[[ARG:.*]] = arith.constant 0.000000e+00
+  // CHECK: scf.yield %[[ARG]]
+  %2 = tt.atomic_rmw max, acq_rel, gpu, %0, %1, %true : (!tt.ptr<f32>, f32, i1) -> f32
+  return
+}
diff --git a/test/Conversion/triton-to-tensor.mlir b/test/Conversion/triton-to-tensor.mlir
deleted file mode 100644
index d5f8415..0000000
--- a/test/Conversion/triton-to-tensor.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: triton-linalg-opt --convert-triton-to-tensor %s -split-input-file | FileCheck %s
-
-// CHECK: tensor.insert_slice %arg0
-// CHECK: tensor.insert_slice %arg1
-tt.func public @cat(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) {
-  %0 = tt.cat %arg0, %arg1 : tensor<32xf32> -> tensor<64xf32>
-  tt.return
-}
diff --git a/test/Dialect/LinalgExt/invalid.mlir b/test/Dialect/LinalgExt/invalid.mlir
index 3945044..9b22222 100644
--- a/test/Dialect/LinalgExt/invalid.mlir
+++ b/test/Dialect/LinalgExt/invalid.mlir
@@ -137,6 +137,7 @@ func.func @batch_conv_2d_nhwc_fhwc_invalid_dtype_in_strides(%input: tensor
 }
 
 // -----
+// CHECK: linalg_ext.make_range
 func.func @make_range_output_rank_invalid(%arg0: tensor<2x64xi32>) -> tensor<2x64xi32> {
   %c0 = arith.constant 0 : i32
   %c128 = arith.constant 128 : i32
@@ -146,6 +147,7 @@ func.func @make_range_output_rank_invalid(%arg0: tensor<2x64xi32>) -> tensor<2x6
 }
 
 // -----
+// CHECK: linalg_ext.make_range
 func.func @make_range_start_end_invalid(%arg0: tensor<128xi32>) -> tensor<128xi32> {
   %c0 = arith.constant 0 : i32
   %c128 = arith.constant 128 : i32
@@ -155,6 +157,7 @@ func.func @make_range_start_end_invalid(%arg0: tensor<128xi3
 }
 
 // -----
+// CHECK: linalg_ext.make_range
 func.func @make_range_output_shape_mismatch(%arg0: tensor<129xi32>) -> tensor<129xi32> {
   %c0 = arith.constant 0 : i32
   %c128 = arith.constant 128 : i32
@@ -164,6 +167,7 @@ func.func @make_range_output_shape_mismatch(%arg0: tensor<12
 }
 
 // -----
+// CHECK: linalg_ext.make_range
 func.func @make_range_result_type_invalid(%arg0: tensor<128xf32>) -> tensor<128xf32> {
   %c2 = arith.constant 2 : i32
   %c130 = arith.constant 130 : i32
@@ -960,6 +964,131 @@ func.func @pad_pvalue_type_mismatch(%input : tensor<4x4xf32>, %init : tensor<6x8
   return %pad : tensor<6x8xf32>
 }
 
+// -----
+func.func @argmax_unsorted_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<16xf32>, %output_index: tensor<128xi32>) {
+  // expected-error @+1 {{'linalg_ext.argmax' op expects all outputs to have the same shapes.
Shape at output-index 1 is not equal to the shape at output-index 0.}} + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<16xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmax_unmatched_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmax' op init dimensions [128] doesn't match input dimensions after reduction [16]}} + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmax_out_of_range(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmax' op dimensions for reduction should be in the range [0, 1].}} + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [4] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmin_unsorted_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<16xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmin' op expects all outputs to have the same shapes. 
Shape at output-index 1 is not equal to the shape at output-index 0.}} + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<16xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmin_unmatched_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmin' op init dimensions [128] doesn't match input dimensions after reduction [16]}} + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmin_out_of_range(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmin' op dimensions for reduction should be in the range [0, 1].}} + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [4] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + // ----- func.func @scan_unmatched_output_and_init_num(%input: tensor<16x32x64xf32>, %output0: tensor<16x32x64xf32>, @@ -1200,3 +1329,178 @@ func.func @scalar_libdevice_call_result_invalid(%arg0: f32) -> tensor { symbol = "__cn_scalar_add_f32" -> tensor return %libdevicecall : tensor } + +// ----- +func.func @histogram_wrong_num(%input1: tensor<128xi32>, + %input2: tensor<64xi32>, %init: tensor<16xi32>) -> tensor<16xi32> { + // expected-error @+1 {{'linalg_ext.histogram' op only supports 1 input operand!}} + %1 = linalg_ext.histogram + ins(%input1, %input2 : tensor<128xi32>, tensor<64xi32>) + outs(%init:tensor<16xi32>) -> tensor<16xi32> + func.return %1 : tensor<16xi32> +} + +// ----- +func.func @histogram_wrong_input_rank(%input: tensor<128x64xi32>, + %init: tensor<16x64xi32>) -> tensor<16x64xi32> { + // expected-error @+1 {{'linalg_ext.histogram' op only supports 1D input!}} + %1 = linalg_ext.histogram + ins(%input : tensor<128x64xi32>) + outs(%init:tensor<16x64xi32>) -> tensor<16x64xi32> + func.return %1 : tensor<16x64xi32> +} + +// ----- +func.func @histogram_wrong_input_type(%input: tensor<64xf32>, + %init: 
tensor<8xf32>) -> tensor<8xf32> { + // expected-error @+1 {{'linalg_ext.histogram' op only supports integer input!}} + %1 = linalg_ext.histogram + ins(%input : tensor<64xf32>) + outs(%init:tensor<8xf32>) -> tensor<8xf32> + func.return %1 : tensor<8xf32> +} + +// ----- +func.func @histogram_wrong_output_rank(%input: tensor<128xi32>, + %init: tensor<16x64xi32>) -> tensor<16x64xi32> { + // expected-error @+1 {{'linalg_ext.histogram' op only supports 1D output!}} + %1 = linalg_ext.histogram + ins(%input : tensor<128xi32>) + outs(%init:tensor<16x64xi32>) -> tensor<16x64xi32> + func.return %1 : tensor<16x64xi32> +} + +// ----- +func.func @histogram_wrong_output_type(%input: tensor<64xi32>, + %init: tensor<8xf32>) -> tensor<8xf32> { + // expected-error @+1 {{'linalg_ext.histogram' op only supports integer output!}} + %1 = linalg_ext.histogram + ins(%input : tensor<64xi32>) + outs(%init:tensor<8xf32>) -> tensor<8xf32> + func.return %1 : tensor<8xf32> +} + +// ----- +func.func @argmax_unsorted_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<16xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmax' op expects all outputs to have the same shapes. Shape at output-index 1 is not equal to the shape at output-index 0.}} + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<16xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmax_unmatched_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmax' op init dimensions [128] doesn't match input dimensions after reduction [16]}} + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmax_out_of_range(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmax' op dimensions for reduction should be in the range [0, 1].}} + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [4] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = 
arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmin_unsorted_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<16xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmin' op expects all outputs to have the same shapes. Shape at output-index 1 is not equal to the shape at output-index 0.}} + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<16xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmin_unmatched_dim(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmin' op init dimensions [128] doesn't match input dimensions after reduction [16]}} + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [1] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- + +func.func @argmin_out_of_range(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + // expected-error @+1 {{'linalg_ext.argmin' op dimensions for reduction should be in the range [0, 1].}} + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [4] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} diff --git a/test/Dialect/LinalgExt/ops.mlir b/test/Dialect/LinalgExt/ops.mlir index a8a5eb0..4293a88 100644 --- a/test/Dialect/LinalgExt/ops.mlir +++ b/test/Dialect/LinalgExt/ops.mlir @@ -577,3 +577,63 @@ func.func @scan_memref(%input: memref<16x32x64xf32>, } -> memref<16x32x64xf32>, memref<16x64xf32> func.return } + +// ----- +// CHECK: histogram_tensor +func.func @histogram_tensor(%input: tensor<128xi32>, + %init: tensor<16xi32>) -> tensor<16xi32> { + %1 = linalg_ext.histogram + ins(%input:tensor<128xi32>) + outs(%init:tensor<16xi32>) -> tensor<16xi32> + func.return %1 : tensor<16xi32> +} + +// ----- +// CHECK: histogram_memref +func.func @histogram_memref(%input: memref<128xi32>, + %init: memref<16xi32>) { + linalg_ext.histogram + ins(%input:memref<128xi32>) + 
outs(%init:memref<16xi32>) + func.return +} + +// ----- +// CHECK: linalg_ext.argmax +func.func @ext_argmax(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + %argmax:2 = linalg_ext.argmax + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [0] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "ogt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} + +// ----- +// CHECK: linalg_ext.argmin +func.func @ext_argmin(%input_value: tensor<16x128xf32>, %input_index: tensor<16x128xi32>, %output_value: tensor<128xf32>, %output_index: tensor<128xi32>) { + %argmin:2 = linalg_ext.argmin + ins(%input_value, %input_index : tensor<16x128xf32>, tensor<16x128xi32>) + outs(%output_value, %output_index : tensor<128xf32>, tensor<128xi32>) + dimensions = [0] + (%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32) { + %0 = arith.cmpf "oeq", %arg2, %arg4 : f32 + %1 = arith.cmpi "slt", %arg3, %arg5 : i32 + %2 = arith.andi %0, %1 : i1 + %3 = arith.cmpf "olt", %arg2, %arg4 : f32 + %4 = arith.ori %3, %2 : i1 + %5 = arith.select %4, %arg2, %arg4 : f32 + %6 = arith.select %4, %arg3, %arg5 : i32 + linalg.yield %5, %6 : f32, i32 + } + func.return +} diff --git a/test/Dialect/Triton/extract-move-backward.mlir b/test/Dialect/Triton/extract-move-backward.mlir index 508146f..e301a5f 100644 --- a/test/Dialect/Triton/extract-move-backward.mlir +++ b/test/Dialect/Triton/extract-move-backward.mlir @@ -32,9 +32,9 @@ func.func @extract_element_from_make_range(%arg0: i64) -> i32 { // CHECK-LABEL: @extract_element_from_arith_addi func.func @extract_element_from_arith_addi(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> i32 { // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index - // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xi32> - // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xi32> - // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : i32 + // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xi32> + // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xi32> + // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : i32 // CHECK-NOT: linalg.* %c0 = arith.constant 0 : index %0 = arith.addi %arg0, %arg1 : tensor<128xi32> @@ -59,9 +59,9 @@ func.func @extract_element_from_math_erf(%arg0: tensor<128xf32>) -> f32 { // CHECK-LABEL: @extract_element_from_arith_addf func.func @extract_element_from_arith_addf(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> f32 { // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index - // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xf32> - // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xf32> - // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : f32 + // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xf32> + // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xf32> + // CHECK-NEXT: %[[ADD:.*]] = arith.addf 
%[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : f32 // CHECK-NOT: linalg.* %c0 = arith.constant 0 : index %0 = arith.addf %arg0, %arg1 : tensor<128xf32> @@ -73,9 +73,9 @@ func.func @extract_element_from_arith_addf(%arg0: tensor<128xf32>, %arg1: tensor // CHECK-LABEL: @extract_element_from_arith_cmpi func.func @extract_element_from_arith_cmpi(%arg0: tensor<128xindex>, %arg1: tensor<128xindex>) -> i1 { // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index - // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xindex> - // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xindex> - // CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi slt, %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : index + // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xindex> + // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xindex> + // CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi slt, %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : index // CHECK-NOT: linalg.* %c0 = arith.constant 0 : index %0 = arith.cmpi slt, %arg0, %arg1 : tensor<128xindex> @@ -87,9 +87,9 @@ func.func @extract_element_from_arith_cmpi(%arg0: tensor<128xindex>, %arg1: tens // CHECK-LABEL: @extract_element_from_arith_cmpf func.func @extract_element_from_arith_cmpf(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> i1 { // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index - // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xf32> - // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xf32> - // CHECK-NEXT: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : f32 + // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xf32> + // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xf32> + // CHECK-NEXT: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : f32 // CHECK-NOT: linalg.* %c0 = arith.constant 0 : index %0 = arith.cmpf olt, %arg0, %arg1 : tensor<128xf32> @@ -153,9 +153,9 @@ func.func @extract_element_from_arith_truncf(%arg0: tensor<128xf32>) -> bf16 { // CHECK-LABEL: @extract_element_from_arith_and func.func @extract_element_from_arith_and(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> i32 { // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index - // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xi32> - // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xi32> - // CHECK-NEXT: %[[AND:.*]] = arith.andi %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : i32 + // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<128xi32> + // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<128xi32> + // CHECK-NEXT: %[[AND:.*]] = arith.andi %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : i32 // CHECK-NOT: linalg.* %c0 = arith.constant 0 : index %0 = arith.andi %arg0, %arg1 : tensor<128xi32> @@ -168,13 +168,13 @@ func.func @extract_element_from_arith_and(%arg0: tensor<128xi32>, %arg1: tensor< func.func @extract_element_from_map_arith_add_with_two_indices( %arg0: tensor<128x16xi32>, %arg1: tensor<128x16xi32>) -> i32 { // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index - // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]], %[[C0_INDEX]]] : tensor<128x16xi32> - // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]], %[[C0_INDEX]]] : tensor<128x16xi32> - // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA0]], 
%[[ARG_EXTRA1]] : i32 + // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]], %[[C0_INDEX]]] : tensor<128x16xi32> + // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]], %[[C0_INDEX]]] : tensor<128x16xi32> + // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : i32 // CHECK-NOT: linalg.* %c0 = arith.constant 0 : index %0 = tensor.empty() : tensor<128x16xi32> - %1 = linalg.map { arith.addi } ins(%arg0, %arg1 : tensor<128x16xi32>, tensor<128x16xi32>) outs(%0 : tensor<128x16xi32>) + %1 = linalg.map { arith.addi {overflowFlags = #arith.overflow} } ins(%arg0, %arg1 : tensor<128x16xi32>, tensor<128x16xi32>) outs(%0 : tensor<128x16xi32>) %2 = tensor.extract %1[%c0, %c0] : tensor<128x16xi32> return %2 : i32 } @@ -182,14 +182,14 @@ func.func @extract_element_from_map_arith_add_with_two_indices( // ----- // CHECK-LABEL: @extract_element_from_map_with_two_payloads func.func @extract_element_from_map_with_two_payloads(%arg0: tensor<32xi64>, %arg1: tensor<32xi32>) -> i64 { - // CHECK-DAG: %[[C0_INDEX:.*]] = arith.constant 0 : index + // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 + // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 %c1 = arith.constant 1 : i64 - // CHECK: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<32xi64> - // CHECK: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<32xi32> - // CHECK: %[[EXT:.*]] = arith.extsi %[[ARG_EXTRA1]] : i32 to i64 - // CHECK: %[[ADD_0:.*]] = arith.addi %[[ARG_EXTRA0]], %[[EXT]] : i64 + // CHECK: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<32xi32> + // CHECK: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<32xi64> + // CHECK: %[[EXT:.*]] = arith.extsi %[[ARG_EXTRA0]] : i32 to i64 + // CHECK: %[[ADD_0:.*]] = arith.addi %[[ARG_EXTRA1]], %[[EXT]] : i64 // CHECK: %[[ADD_1:.*]] = arith.addi %[[ADD_0]], %[[C1_I64]] : i64 %0 = tensor.empty() : tensor<32xi64> %1 = linalg.map ins(%arg0, %arg1 : tensor<32xi64>, tensor<32xi32>) outs(%0 : tensor<32xi64>) @@ -260,7 +260,7 @@ func.func @extract_element_from_expand_shape_op_with_src_dimension_size_1(%arg0: %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index - %0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<2x12xi32> into tensor<2x3x4xi32> + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32> %1 = tensor.extract %0[%c0, %c2, %c3] : tensor<2x3x4xi32> return %1 : i32 } @@ -276,7 +276,7 @@ func.func @extract_element_from_expand_shape_op_with_static_shape(%arg0: tensor< %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c10 = arith.constant 10 : index - %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<128x64xi32> into tensor<4x32x4x16xi32> + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 32, 4, 16] : tensor<128x64xi32> into tensor<4x32x4x16xi32> %1 = tensor.extract %0[%c2, %arg1, %c3, %c10] : tensor<4x32x4x16xi32> return %1 : i32 } @@ -294,9 +294,17 @@ func.func @extract_element_from_expand_shape_op_with_dynamic_shape(%arg0: tensor // CHECK-DAG: %[[VAL1:.*]] = arith.addi %[[VAL0]], %[[C3_INDEX]] : index // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C22_INDEX]], %[[VAL1]]] : tensor // CHECK-NOT: tensor.collapse_shape + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index - %0 = 
tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x10x6x?xi32>
+  %c6 = arith.constant 6 : index
+  %c10 = arith.constant 10 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %new_dim0 = arith.divui %dim0, %c10 : index
+  %new_dim1 = arith.divui %dim1, %c6 : index
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%new_dim0, 10, 6, %new_dim1] : tensor<?x?xi32> into tensor<?x10x6x?xi32>
   %1 = tensor.extract %0[%c2, %c2, %c3, %c3] : tensor<?x10x6x?xi32>
   return %1 : i32
 }
@@ -339,26 +347,69 @@ func.func @extract_from_for_iter_args(%arg0: i64, %arg1: tensor<64x64xf32>, %arg
     %4 = tensor.extract_slice %3[%2, %2] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
     %mapped = linalg.map { math.absf } ins(%4 : tensor<64x64xf32>) outs(%arg5 : tensor<64x64xf32>)
     %5 = tensor.empty() : tensor<64x64xi64>
-    %mapped_0 = linalg.map { arith.addi } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
+    %mapped_0 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
     scf.yield %mapped, %mapped_0 : tensor<64x64xf32>, tensor<64x64xi64>
   }
   return %0#0 : tensor<64x64xf32>
 }
 // CHECK-LABEL: @extract_from_for_iter_args
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[PTR:.*]] = llvm.inttoptr %[[ARG0:.*]] : i64 to !llvm.ptr<1>
+// CHECK-DAG: %[[EXTRACT0:.*]] = tensor.extract %[[ARG3:.*]][%[[C0]], %[[C0]]] : tensor<64x64xi64>
+// CHECK-DAG: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2:.*]][%[[C0]], %[[C0]]] : tensor<64x64xi64>
+// CHECK: %[[VAL1:.*]]:2 = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1:.*]], %[[ARG6:.*]] = %[[EXTRACT1]]) -> (tensor<64x64xf32>, i64) {
+// CHECK: %[[ADDI:.*]] = arith.addi %[[ARG6]], %[[EXTRACT0]] : i64
+// CHECK: %[[VAL2:.*]] = arith.index_cast %[[ARG6]] : i64 to index
+// CHECK: %[[MEMREF:.*]] = aux.view %[[PTR]] to offset: [0], sizes: [64, %[[VAL2]]]
+// CHECK: %[[VAL3:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: %[[VAL4:.*]] = tensor.extract_slice %[[VAL3]][%[[VAL2]], %[[VAL2]]] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
+// CHECK: %[[VAL_MAP:.*]] = linalg.map { math.absf } ins(%[[VAL4]] : tensor<64x64xf32>) outs(%[[ARG5]] : tensor<64x64xf32>)
+// CHECK: scf.yield %[[VAL_MAP]], %[[ADDI]] : tensor<64x64xf32>, i64
+// CHECK: }
+// CHECK: return %[[VAL1]]#0 : tensor<64x64xf32>
+
+// -----
+func.func @extract_from_for_iter_args_with_few_compute(%arg0: i64, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xi64>, %arg3: tensor<64x64xi64>, %arg4: index, %arg5: index) -> tensor<64x64xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0:2 = scf.for %arg6 = %c0 to %c2 step %c1 iter_args(%arg7 = %arg1, %arg8 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xi64>) {
+    %1 = arith.addi %arg4, %arg5 : index
+    %2 = arith.muli %arg4, %1 : index
+    %3 = arith.subi %2, %arg5 : index
+    %extracted = tensor.extract %arg8[%1, %3] : tensor<64x64xi64>
+    %4 = arith.index_cast %extracted : i64 to index
+    %5 = llvm.inttoptr %arg0 : i64 to !llvm.ptr<1>
+    %view_memref = aux.view %5 to offset: [0], sizes: [64, %4], strides: [1, 1] : <1> to memref<64x?xf32, 1>
+    %6 = bufferization.to_tensor %view_memref : memref<64x?xf32, 1>
+    %extracted_slice = tensor.extract_slice %6[%4, %4] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
+    %mapped = linalg.map { math.absf } ins(%extracted_slice : tensor<64x64xf32>) outs(%arg7 : tensor<64x64xf32>)
+    %7 = tensor.empty() : tensor<64x64xi64>
+    %mapped_0 = linalg.map { arith.addi } ins(%arg8, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%7 : tensor<64x64xi64>)
+    scf.yield %mapped, %mapped_0 : tensor<64x64xf32>, tensor<64x64xi64>
+  }
+  return %0#0 : tensor<64x64xf32>
+}
+// CHECK-LABEL: @extract_from_for_iter_args_with_few_compute
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL0:.*]] = tensor.extract %[[ARG2:.*]][%[[C0]], %[[C0]]] : tensor<64x64xi64>
-// CHECK: %[[VAL1:.*]]:2 = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1:.*]], %[[ARG6:.*]] = %[[VAL0]]) -> (tensor<64x64xf32>, i64) {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[T0:.*]] = arith.addi %arg4, %arg5 : index
+// CHECK-DAG: %[[T1:.*]] = arith.muli %arg4, %0 : index
+// CHECK-DAG: %[[T2:.*]] = arith.subi %1, %arg5 : index
+// CHECK-DAG: %[[PTR:.*]] = llvm.inttoptr %[[ARG0:.*]] : i64 to !llvm.ptr<1>
+// CHECK-DAG: %[[EXTRACT0:.*]] = tensor.extract %[[ARG3:.*]][%[[T0]], %[[T2]]] : tensor<64x64xi64>
+// CHECK-DAG: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2:.*]][%[[T0]], %[[T2]]] : tensor<64x64xi64>
+// CHECK: %[[VAL1:.*]]:2 = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1:.*]], %[[ARG6:.*]] = %[[EXTRACT1]]) -> (tensor<64x64xf32>, i64) {
+// CHECK: %[[ADDI:.*]] = arith.addi %[[ARG6]], %[[EXTRACT0]] : i64
 // CHECK: %[[VAL2:.*]] = arith.index_cast %[[ARG6]] : i64 to index
-// CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ARG0:.*]] : i64 to !llvm.ptr<1>
 // CHECK: %[[MEMREF:.*]] = aux.view %[[PTR]] to offset: [0], sizes: [64, %[[VAL2]]]
 // CHECK: %[[VAL3:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK: %[[VAL4:.*]] = tensor.extract_slice %[[VAL3]][%[[VAL2]], %[[VAL2]]] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
 // CHECK: %[[VAL_MAP:.*]] = linalg.map { math.absf } ins(%[[VAL4]] : tensor<64x64xf32>) outs(%[[ARG5]] : tensor<64x64xf32>)
-// CHECK: %[[VAL5:.*]] = tensor.extract %[[ARG3:.*]][%[[C0]], %[[C0]]] : tensor<64x64xi64>
-// CHECK: %[[VAL6:.*]] = arith.addi %[[ARG6]], %[[VAL5]] : i64
-// CHECK: scf.yield %[[VAL_MAP]], %[[VAL6]] : tensor<64x64xf32>, i64
+// CHECK: scf.yield %[[VAL_MAP]], %[[ADDI]] : tensor<64x64xf32>, i64
 // CHECK: }
 // CHECK: return %[[VAL1]]#0 : tensor<64x64xf32>
@@ -378,7 +429,7 @@ func.func @extract_from_for_iter_args_fail_case0(%arg0: i64, %arg1: tensor<64x64
     %4 = tensor.extract_slice %3[%2, %2] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
     %mapped = linalg.map { math.absf } ins(%4 : tensor<64x64xf32>) outs(%arg5 : tensor<64x64xf32>)
     %5 = tensor.empty() : tensor<64x64xi64>
-    %mapped_0 = linalg.map { arith.addi } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
+    %mapped_0 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
     %6 = arith.addi %arg7, %c1 : index
     scf.yield %mapped, %mapped_0, %6 : tensor<64x64xf32>, tensor<64x64xi64>, index
   }
@@ -400,7 +451,7 @@ func.func @extract_from_for_iter_args_fail_case1(%arg0: i64, %arg1: tensor<64x64
     %3 = bufferization.to_tensor %memref : memref<64x?xf32, 1>
     %4 = tensor.extract_slice %3[%2, %2] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
     %5 = tensor.empty() : tensor<64x64xi64>
-    %mapped_0 = linalg.map { arith.addi } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
+    %mapped_0 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
     %6 = arith.sitofp %arg6 : tensor<64x64xi64> to tensor<64x64xf32>
     %mapped = linalg.map { arith.addf } ins(%4, %6 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%arg5 : tensor<64x64xf32>)
     scf.yield %mapped, %mapped_0 : tensor<64x64xf32>, tensor<64x64xi64>
@@ -424,10 +475,10 @@ func.func @extract_from_for_iter_args_fail_case2(%arg0: i64, %arg1: tensor<64x64
     %4 = tensor.extract_slice %3[%2, %2] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
     %mapped = linalg.map { math.absf } ins(%4 : tensor<64x64xf32>) outs(%arg5 : tensor<64x64xf32>)
     %5 = tensor.empty() : tensor<64x64xi64>
-    %mapped_0 = linalg.map { arith.addi } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
+    %mapped_0 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
     %6 = arith.fptosi %mapped : tensor<64x64xf32> to tensor<64x64xi64>
     %7 = tensor.empty() : tensor<64x64xi64>
-    %mapped_1 = linalg.map { arith.addi } ins(%mapped_0, %6 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%7 : tensor<64x64xi64>)
+    %mapped_1 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%mapped_0, %6 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%7 : tensor<64x64xi64>)
     scf.yield %mapped, %mapped_1 : tensor<64x64xf32>, tensor<64x64xi64>
   }
   return %0#0 : tensor<64x64xf32>
@@ -449,15 +500,15 @@ func.func @extract_element_from_collapse_shape_op_with_single_element(%arg0: tensor
 // COM: %0 %1 %2 %3 will be bufferized to the same memref. So, we can not move
 // tensor.extract before %2.
 // CHECK-LABEL: func.func @test_destination_style_op_for_result_chain(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, f32) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, f32) {
 // CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[VAL_3:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_2]] : tensor<64xf32>)
-// CHECK: %[[VAL_4:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_3]] : tensor<64xf32>)
-// CHECK: %[[VAL_5:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_1]]] : tensor<64xf32>
-// CHECK: %[[VAL_6:.*]] = math.absf %[[VAL_5]] : f32
-// CHECK: %[[VAL_7:.*]] = math.exp %[[VAL_6]] : f32
-// CHECK: return %[[VAL_4]], %[[VAL_7]] : tensor<64xf32>, f32
+// CHECK: %[[VAL_2:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_1]]] : tensor<64xf32>
+// CHECK: %[[VAL_3:.*]] = math.absf %[[VAL_2]] : f32
+// CHECK: %[[VAL_4:.*]] = math.exp %[[VAL_3]] : f32
+// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[VAL_6:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_5]] : tensor<64xf32>)
+// CHECK: %[[VAL_7:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_6]] : tensor<64xf32>)
+// CHECK: return %[[VAL_7]], %[[VAL_4]] : tensor<64xf32>, f32
 // CHECK: }
 func.func @test_destination_style_op_for_result_chain(%arg0: tensor<64xf32>) -> (tensor<64xf32>, f32) {
   %c0 = arith.constant 0 : index
@@ -473,13 +524,13 @@ func.func @test_destination_style_op_cross_block(%arg0: tensor<64xf32>) -> (tens
 // CHECK-LABEL: func.func @test_destination_style_op_cross_block(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, f32) {
 // CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[VAL_3:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_2]] : tensor<64xf32>)
-// CHECK: %[[VAL_4:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_3]] : tensor<64xf32>)
-// CHECK: %[[VAL_5:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_1]]] : tensor<64xf32>
-// CHECK: %[[VAL_6:.*]] = math.absf %[[VAL_5]] : f32
-// CHECK: %[[VAL_7:.*]] = math.exp %[[VAL_6]] : f32
-// CHECK: return %[[VAL_4]], %[[VAL_7]] : tensor<64xf32>, f32
+// CHECK: %[[VAL_2:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_1]]] : tensor<64xf32>
+// CHECK: %[[VAL_3:.*]] = math.absf %[[VAL_2]] : f32
+// CHECK: %[[VAL_4:.*]] = math.exp %[[VAL_3]] : f32
+// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[VAL_6:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_5]] : tensor<64xf32>)
+// CHECK: %[[VAL_7:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_6]] : tensor<64xf32>)
+// CHECK: return %[[VAL_7]], %[[VAL_4]] : tensor<64xf32>, f32
 // CHECK: }
 func.func @test_destination_style_op_cross_block(%arg0: tensor<64xf32>) -> (tensor<64xf32>, f32) {
   %c0 = arith.constant 0 : index
@@ -497,12 +548,12 @@ func.func @test_destination_style_op_cross_block(%arg0: tensor<64xf32>) -> (tens
 // CHECK-LABEL: func.func @test_destination_style_op_for_init_chain(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, f32) {
 // CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[VAL_3:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_2]] : tensor<64xf32>)
-// CHECK: %[[VAL_4:.*]] = linalg.map { math.exp } ins(%[[VAL_3]] : tensor<64xf32>) outs(%[[VAL_3]] : tensor<64xf32>)
-// CHECK: %[[VAL_5:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_1]]] : tensor<64xf32>
-// CHECK: %[[VAL_6:.*]] = math.atan %[[VAL_5]] : f32
-// CHECK: return %[[VAL_4]], %[[VAL_6]] : tensor<64xf32>, f32
+// CHECK: %[[VAL_2:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_1]]] : tensor<64xf32>
+// CHECK: %[[VAL_3:.*]] = math.atan %[[VAL_2]] : f32
+// CHECK: %[[VAL_4:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[VAL_5:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_4]] : tensor<64xf32>)
+// CHECK: %[[VAL_6:.*]] = linalg.map { math.exp } ins(%[[VAL_5]] : tensor<64xf32>) outs(%[[VAL_5]] : tensor<64xf32>)
+// CHECK: return %[[VAL_6]], %[[VAL_3]] : tensor<64xf32>, f32
 // CHECK: }
 func.func @test_destination_style_op_for_init_chain(%arg0: tensor<64xf32>) -> (tensor<64xf32>, f32) {
   %c0 = arith.constant 0 : index
@@ -530,9 +581,9 @@ func.func @fill_0d(%arg0: i64) -> i64 {
 // -----
 // CHECK-LABEL: @arith_addi_0d
 func.func @arith_addi_0d(%arg0: tensor<i32>, %arg1: tensor<i32>) -> i32 {
-  // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %[[ARG0:.*]][] : tensor<i32>
-  // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %[[ARG1:.*]][] : tensor<i32>
-  // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : i32
+  // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %[[ARG1:.*]][] : tensor<i32>
+  // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %[[ARG0:.*]][] : tensor<i32>
+  // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : i32
   // CHECK-NOT: linalg.*
   %0 = arith.addi %arg0, %arg1 : tensor<i32>
   %1 = tensor.extract %0[] : tensor<i32>
   return %1 : i32
@@ -553,12 +604,12 @@ func.func @math_erf_0d(%arg0: tensor<f32>) -> f32 {
 // -----
 // CHECK-LABEL: @map_0d
 func.func @map_0d(%arg0: tensor<i32>, %arg1: tensor<i32>) -> i32 {
-  // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[] : tensor<i32>
-  // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[] : tensor<i32>
-  // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA0]], %[[ARG_EXTRA1]] : i32
+  // CHECK-NEXT: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[] : tensor<i32>
+  // CHECK-NEXT: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[] : tensor<i32>
+  // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG_EXTRA1]], %[[ARG_EXTRA0]] : i32
   // CHECK-NOT: linalg.*
   %0 = tensor.empty() : tensor<i32>
-  %1 = linalg.map { arith.addi } ins(%arg0, %arg1 : tensor<i32>, tensor<i32>) outs(%0 : tensor<i32>)
+  %1 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<i32>, tensor<i32>) outs(%0 : tensor<i32>)
   %2 = tensor.extract %1[] : tensor<i32>
   return %2 : i32
 }
@@ -568,10 +619,10 @@ func.func @map_0d(%arg0: tensor<i32>, %arg1: tensor<i32>) -> i32 {
 func.func @map_0d_with_two_payloads(%arg0: tensor<i64>, %arg1: tensor<i32>) -> i64 {
   // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
   %c1 = arith.constant 1 : i64
-  // CHECK: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[] : tensor<i64>
-  // CHECK: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[] : tensor<i32>
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[ARG_EXTRA1]] : i32 to i64
-  // CHECK: %[[ADD_0:.*]] = arith.addi %[[ARG_EXTRA0]], %[[EXT]] : i64
+  // CHECK: %[[ARG_EXTRA0:.*]] = tensor.extract %arg1[] : tensor<i32>
+  // CHECK: %[[ARG_EXTRA1:.*]] = tensor.extract %arg0[] : tensor<i64>
+  // CHECK: %[[EXT:.*]] = arith.extsi %[[ARG_EXTRA0]] : i32 to i64
+  // CHECK: %[[ADD_0:.*]] = arith.addi %[[ARG_EXTRA1]], %[[EXT]] : i64
   // CHECK: %[[ADD_1:.*]] = arith.addi %[[ADD_0]], %[[C1_I64]] : i64
   %0 = tensor.empty() : tensor<i64>
   %1 = linalg.map ins(%arg0, %arg1 : tensor<i64>, tensor<i32>) outs(%0 : tensor<i64>)
@@ -614,7 +665,7 @@ func.func @expand_shape_op_0d(%arg0: tensor<i32>) -> i32 {
 // CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG0:.*]][] : tensor<i32>
 // CHECK-NOT: tensor.collapse_shape
   %c0 = arith.constant 0 : index
-  %0 = tensor.expand_shape %arg0 [] : tensor<i32> into tensor<1x1xi32>
+  %0 = tensor.expand_shape %arg0 [] output_shape [1, 1] : tensor<i32> into tensor<1x1xi32>
   %1 = tensor.extract %0[%c0, %c0] : tensor<1x1xi32>
   return %1 : i32
 }
@@ -633,26 +684,26 @@ func.func @for_op_with_0d_iter_args(%arg0: i64, %arg1: tensor, %arg2: tenso
     %4 = tensor.extract_slice %3[%2] [1] [1] : tensor<?xf32> to tensor<1xf32>
     %mapped = linalg.map { math.absf } ins(%4 : tensor<1xf32>) outs(%arg5 : tensor<1xf32>)
     %5 = tensor.empty() : tensor<i64>
-    %mapped_0 = linalg.map { arith.addi } ins(%arg6, %arg3 : tensor<i64>, tensor<i64>) outs(%5 : tensor<i64>)
+    %mapped_0 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg6, %arg3 : tensor<i64>, tensor<i64>) outs(%5 : tensor<i64>)
     scf.yield %mapped, %mapped_0 : tensor<1xf32>, tensor<i64>
   }
   return %0#0 : tensor<1xf32>
 }
 // CHECK-LABEL: @for_op_with_0d_iter_args
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL0:.*]] = tensor.extract %[[ARG2:.*]][] : tensor<i64>
-// CHECK: %[[VAL1:.*]]:2 = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1:.*]], %[[ARG6:.*]] = %[[VAL0]]) -> (tensor<1xf32>, i64) {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[PTR:.*]] = llvm.inttoptr %[[ARG0:.*]] : i64 to !llvm.ptr<1>
+// CHECK-DAG: %[[EXTRACT0:.*]] = tensor.extract %[[ARG3:.*]][] : tensor<i64>
+// CHECK-DAG: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2:.*]][] : tensor<i64>
+// CHECK: %[[VAL1:.*]]:2 = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1:.*]], %[[ARG6:.*]] = %[[EXTRACT1]]) -> (tensor<1xf32>, i64) {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG6]], %[[EXTRACT0]] : i64
 // CHECK: %[[VAL2:.*]] = arith.index_cast %[[ARG6]] : i64 to index
-// CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ARG0:.*]] : i64 to !llvm.ptr<1>
 // CHECK: %[[MEMREF:.*]] = aux.view %[[PTR]] to offset: [0], sizes: [%[[VAL2]]]
 // CHECK: %[[VAL3:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK: %[[VAL4:.*]] = tensor.extract_slice %[[VAL3]][%[[VAL2]]] [1] [1] : tensor<?xf32> to tensor<1xf32>
 // CHECK: %[[VAL_MAP:.*]] = linalg.map { math.absf } ins(%[[VAL4]] : tensor<1xf32>) outs(%[[ARG5]] : tensor<1xf32>)
-// CHECK: %[[VAL5:.*]] = tensor.extract %[[ARG3:.*]][] : tensor<i64>
-// CHECK: %[[VAL6:.*]] = arith.addi %[[ARG6]], %[[VAL5]] : i64
-// CHECK: scf.yield %[[VAL_MAP]], %[[VAL6]] : tensor<1xf32>, i64
+// CHECK: scf.yield %[[VAL_MAP]], %[[ADD]] : tensor<1xf32>, i64
 // CHECK: }

 // -----
@@ -669,7 +720,7 @@ func.func @extract_from_for_iter_args_failed(%arg0: i64, %arg1: tensor<64x64xf32
     %4 = tensor.extract_slice %3[%2, %2] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32>
     %mapped = linalg.map { math.absf } ins(%4 : tensor<64x64xf32>) outs(%arg5 : tensor<64x64xf32>)
     %5 = tensor.empty() : tensor<64x64xi64>
-    %mapped_0 = linalg.map { arith.addi } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
+    %mapped_0 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg6, %arg3 : tensor<64x64xi64>, tensor<64x64xi64>) outs(%5 : tensor<64x64xi64>)
     %6 = tensor.empty() : tensor<64x64xi64>
     %transpose = linalg.transpose ins(%mapped_0 : tensor<64x64xi64>) outs(%6 : tensor<64x64xi64>) permutation = [1, 0]
     scf.yield %mapped, %transpose : tensor<64x64xf32>, tensor<64x64xi64>
@@ -693,7 +744,7 @@ func.func @extract_from_for_iter_args_failed(%arg0: i64, %arg1: tensor<64x64xf32
 // CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<64x64xi64>
 // CHECK: %[[VAL_9:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[ARG6]], %[[ARG3]] : tensor<64x64xi64>, tensor<64x64xi64>) outs(%[[VAL_8]] : tensor<64x64xi64>)
 // CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<64x64xi64>
-// CHECK: %[[VAL_11:.*]] = linalg.transpose ins(%[[VAL_9]] : tensor<64x64xi64>) outs(%[[VAL_10]] : tensor<64x64xi64>) permutation = [1, 0]
+// CHECK: %[[VAL_11:.*]] = linalg.transpose ins(%[[VAL_9]] : tensor<64x64xi64>) outs(%[[VAL_10]] : tensor<64x64xi64>) permutation = [1, 0]
 // CHECK: scf.yield %[[VAL_7]], %[[VAL_11]] : tensor<64x64xf32>, tensor<64x64xi64>
 // CHECK: }
 // CHECK: return %[[VAL_0]]#0 : tensor<64x64xf32>
diff --git a/test/Dialect/Triton/extractslice-move-backward.mlir b/test/Dialect/Triton/extractslice-move-backward.mlir
index 90f4daf..600c7c1 100644
--- a/test/Dialect/Triton/extractslice-move-backward.mlir
+++ b/test/Dialect/Triton/extractslice-move-backward.mlir
@@ -39,7 +39,7 @@ func.func @extract_slice_from_broadcast_op(%arg0: tensor<128x4xi32>) -> tensor<1
 // CHECK: return %[[VAL_2]] : tensor<3x2xi32>
 // CHECK: }
 func.func @extract_slice_from_expand_shape_op1(%arg0: tensor<5x6x7xi32>) -> tensor<3x2xi32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<5x6x7xi32> into tensor<5x6x1x7xi32>
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [5, 6, 1, 7] : tensor<5x6x7xi32> into tensor<5x6x1x7xi32>
   %1 = tensor.extract_slice %0[1, 0, 0, 2] [3, 1, 1, 2] [1, 1, 1, 1] : tensor<5x6x1x7xi32> to tensor<3x2xi32>
   return %1 : tensor<3x2xi32>
 }
@@ -48,12 +48,12 @@ func.func @extract_slice_from_expand_shape_op1(%arg0: tensor<5x6x7xi32>) -> tens
 // CHECK-LABEL: func.func @extract_slice_from_expand_shape_op2(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x6x7xi32>) -> tensor<3x2x3xi32> {
 // CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 2] [3, 6, 1] [1, 1, 1] : tensor<5x6x7xi32> to tensor<3x6x1xi32>
-// CHECK: %[[VAL_2:.*]] = tensor.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3]] : tensor<3x6x1xi32> into tensor<3x2x3x1xi32>
+// CHECK: %[[VAL_2:.*]] = tensor.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3]] output_shape [3, 2, 3, 1] : tensor<3x6x1xi32> into tensor<3x2x3x1xi32>
 // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<3x2x3x1xi32> into tensor<3x2x3xi32>
 // CHECK: return %[[VAL_3]] : tensor<3x2x3xi32>
 // CHECK: }
 func.func @extract_slice_from_expand_shape_op2(%arg0: tensor<5x6x7xi32>) -> tensor<3x2x3xi32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<5x6x7xi32> into tensor<5x2x3x7xi32>
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [5, 2, 3, 7] : tensor<5x6x7xi32> into tensor<5x2x3x7xi32>
   %1 = tensor.extract_slice %0[1, 0, 0, 2] [3, 2, 3, 1] [1, 1, 1, 1] : tensor<5x2x3x7xi32> to tensor<3x2x3xi32>
   return %1 : tensor<3x2x3xi32>
 }
@@ -61,12 +61,12 @@ func.func @extract_slice_from_expand_shape_op2(%arg0: tensor<5x6x7xi32>) -> tens
 // -----
 // CHECK-LABEL: func.func @extract_slice_from_expand_shape_op3(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x6x7xi32>) -> tensor<3x2x2xi32> {
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0], [1, 2], [3]] : tensor<5x6x7xi32> into tensor<5x2x3x7xi32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0], [1, 2], [3]] output_shape [5, 2, 3, 7] : tensor<5x6x7xi32> into tensor<5x2x3x7xi32>
 // CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_1]][1, 0, 0, 2] [3, 1, 2, 2] [1, 1, 1, 1] : tensor<5x2x3x7xi32> to tensor<3x2x2xi32>
 // CHECK: return %[[VAL_2]] : tensor<3x2x2xi32>
 // CHECK: }
 func.func @extract_slice_from_expand_shape_op3(%arg0: tensor<5x6x7xi32>) -> tensor<3x2x2xi32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<5x6x7xi32> into tensor<5x2x3x7xi32>
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [5, 2, 3, 7] : tensor<5x6x7xi32> into tensor<5x2x3x7xi32>
   %1 = tensor.extract_slice %0[1, 0, 0, 2] [3, 1, 2, 2] [1, 1, 1, 1] : tensor<5x2x3x7xi32> to tensor<3x2x2xi32>
   return %1 : tensor<3x2x2xi32>
 }
@@ -149,17 +149,17 @@ func.func @extract_slice_from_collapse_shape_op_with_0_rank(%arg0 : tensor<1x1xf
 // CHECK-LABEL: func.func @extract_slice_from_map_arith_add(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x16xi32>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<128x16xi32>) -> tensor<16xi32> {
-// CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16x1xi32>
-// CHECK: %[[VAL_3:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16x1xi32>
+// CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16x1xi32>
+// CHECK: %[[VAL_3:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16x1xi32>
 // CHECK: %[[VAL_4:.*]] = tensor.empty() : tensor<16x1xi32>
-// CHECK: %[[VAL_5:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[VAL_2]], %[[VAL_3]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_4]] : tensor<16x1xi32>)
+// CHECK: %[[VAL_5:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[VAL_3]], %[[VAL_2]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_4]] : tensor<16x1xi32>)
 // CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : tensor<16x1xi32> into tensor<16xi32>
 // CHECK: return %[[VAL_6]] : tensor<16xi32>
 // CHECK: }
 func.func @extract_slice_from_map_arith_add(
     %arg0: tensor<128x16xi32>, %arg1: tensor<128x16xi32>) -> tensor<16xi32> {
   %0 = tensor.empty() : tensor<128x16xi32>
-  %1 = linalg.map { arith.addi } ins(%arg0, %arg1 : tensor<128x16xi32>, tensor<128x16xi32>) outs(%0 : tensor<128x16xi32>)
+  %1 = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<128x16xi32>, tensor<128x16xi32>) outs(%0 : tensor<128x16xi32>)
   %2 = tensor.extract_slice %1[0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16xi32>
   return %2 : tensor<16xi32>
 }
@@ -169,10 +169,10 @@ func.func @extract_slice_from_map_arith_add(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x16xi1>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<128x16xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: tensor<128x16xf32>) -> tensor<100xf32> {
-// CHECK: %[[VAL_3:.*]] = tensor.extract_slice %[[VAL_0]][10, 5] [100, 1] [1, 2] : tensor<128x16xi1> to tensor<100x1xi1>
+// CHECK: %[[VAL_3:.*]] = tensor.extract_slice %[[VAL_2]][10, 5] [100, 1] [1, 2] : tensor<128x16xf32> to tensor<100x1xf32>
 // CHECK: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_1]][10, 5] [100, 1] [1, 2] : tensor<128x16xf32> to tensor<100x1xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.extract_slice %[[VAL_2]][10, 5] [100, 1] [1, 2] : tensor<128x16xf32> to tensor<100x1xf32>
-// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : tensor<100x1xi1>, tensor<100x1xf32>
+// CHECK: %[[VAL_5:.*]] = tensor.extract_slice %[[VAL_0]][10, 5] [100, 1] [1, 2] : tensor<128x16xi1> to tensor<100x1xi1>
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_4]], %[[VAL_3]] : tensor<100x1xi1>, tensor<100x1xf32>
 // CHECK: %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_6]] {{\[\[}}0, 1]] : tensor<100x1xf32> into tensor<100xf32>
 // CHECK: return %[[VAL_7]] : tensor<100xf32>
 // CHECK: }
@@ -250,16 +250,16 @@ func.func @extract_slice_from_broadcast_op_with_0d_input(%arg0: tensor) ->
 // -----
 // CHECK-LABEL: func.func @test_destination_style_op_for_result_chain(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, tensor<f32>) {
-// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[VAL_2:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_1]] : tensor<64xf32>)
-// CHECK: %[[VAL_3:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_2]] : tensor<64xf32>)
-// CHECK: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_0]][0] [1] [1] : tensor<64xf32> to tensor<1xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<1xf32>
-// CHECK: %[[VAL_6:.*]] = linalg.map { math.absf } ins(%[[VAL_4]] : tensor<1xf32>) outs(%[[VAL_5]] : tensor<1xf32>)
-// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1xf32>
-// CHECK: %[[VAL_8:.*]] = linalg.map { math.exp } ins(%[[VAL_6]] : tensor<1xf32>) outs(%[[VAL_7]] : tensor<1xf32>)
-// CHECK: %[[VAL_9:.*]] = tensor.collapse_shape %[[VAL_8]] [] : tensor<1xf32> into tensor<f32>
-// CHECK: return %[[VAL_3]], %[[VAL_9]] : tensor<64xf32>, tensor<f32>
+// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][0] [1] [1] : tensor<64xf32> to tensor<1xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<1xf32>
+// CHECK: %[[VAL_3:.*]] = linalg.map { math.absf } ins(%[[VAL_1]] : tensor<1xf32>) outs(%[[VAL_2]] : tensor<1xf32>)
+// CHECK: %[[VAL_4:.*]] = tensor.empty() : tensor<1xf32>
+// CHECK: %[[VAL_5:.*]] = linalg.map { math.exp } ins(%[[VAL_3]] : tensor<1xf32>) outs(%[[VAL_4]] : tensor<1xf32>)
+// CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_5]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<64xf32>)
+// CHECK: %[[VAL_9:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_8]] : tensor<64xf32>)
+// CHECK: return %[[VAL_9]], %[[VAL_6]] : tensor<64xf32>, tensor<f32>
 // CHECK: }
 func.func @test_destination_style_op_for_result_chain(%arg0: tensor<64xf32>) -> (tensor<64xf32>, tensor<f32>) {
   %0 = tensor.empty() : tensor<64xf32>
@@ -273,14 +273,14 @@ func.func @test_destination_style_op_for_result_chain(%arg0: tensor<64xf32>) -> (t
 // -----
 // CHECK-LABEL: func.func @test_destination_style_op_for_init_chain(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, tensor<f32>) {
-// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[VAL_2:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_1]] : tensor<64xf32>)
-// CHECK: %[[VAL_3:.*]] = linalg.map { math.exp } ins(%[[VAL_2]] : tensor<64xf32>) outs(%[[VAL_2]] : tensor<64xf32>)
-// CHECK: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_0]][0] [1] [1] : tensor<64xf32> to tensor<1xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<1xf32>
-// CHECK: %[[VAL_6:.*]] = linalg.map { math.atan } ins(%[[VAL_4]] : tensor<1xf32>) outs(%[[VAL_5]] : tensor<1xf32>)
-// CHECK: %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_6]] [] : tensor<1xf32> into tensor<f32>
-// CHECK: return %[[VAL_3]], %[[VAL_7]] : tensor<64xf32>, tensor<f32>
+// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][0] [1] [1] : tensor<64xf32> to tensor<1xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<1xf32>
+// CHECK: %[[VAL_3:.*]] = linalg.map { math.atan } ins(%[[VAL_1]] : tensor<1xf32>) outs(%[[VAL_2]] : tensor<1xf32>)
+// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_3]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[VAL_6:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_5]] : tensor<64xf32>)
+// CHECK: %[[VAL_7:.*]] = linalg.map { math.exp } ins(%[[VAL_6]] : tensor<64xf32>) outs(%[[VAL_6]] : tensor<64xf32>)
+// CHECK: return %[[VAL_7]], %[[VAL_4]] : tensor<64xf32>, tensor<f32>
 // CHECK: }
 func.func @test_destination_style_op_for_init_chain(%arg0: tensor<64xf32>) -> (tensor<64xf32>, tensor<f32>) {
   %0 = tensor.empty() : tensor<64xf32>
@@ -294,16 +294,16 @@ func.func @test_destination_style_op_for_init_chain(%arg0: tensor<64xf32>) -> (t
 // -----
 // CHECK-LABEL: func.func @test_destination_style_op_cross_block(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>) -> (tensor<64xf32>, tensor<f32>) {
-// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[VAL_2:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_1]] : tensor<64xf32>)
-// CHECK: %[[VAL_3:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_2]] : tensor<64xf32>)
-// CHECK: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_0]][0] [1] [1] : tensor<64xf32> to tensor<1xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<1xf32>
-// CHECK: %[[VAL_6:.*]] = linalg.map { math.absf } ins(%[[VAL_4]] : tensor<1xf32>) outs(%[[VAL_5]] : tensor<1xf32>)
-// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1xf32>
-// CHECK: %[[VAL_8:.*]] = linalg.map { math.exp } ins(%[[VAL_6]] : tensor<1xf32>) outs(%[[VAL_7]] : tensor<1xf32>)
-// CHECK: %[[VAL_9:.*]] = tensor.collapse_shape %[[VAL_8]] [] : tensor<1xf32> into tensor<f32>
-// CHECK: return %[[VAL_3]], %[[VAL_9]] : tensor<64xf32>, tensor<f32>
+// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][0] [1] [1] : tensor<64xf32> to tensor<1xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<1xf32>
+// CHECK: %[[VAL_3:.*]] = linalg.map { math.absf } ins(%[[VAL_1]] : tensor<1xf32>) outs(%[[VAL_2]] : tensor<1xf32>)
+// CHECK: %[[VAL_4:.*]] = tensor.empty() : tensor<1xf32>
+// CHECK: %[[VAL_5:.*]] = linalg.map { math.exp } ins(%[[VAL_3]] : tensor<1xf32>) outs(%[[VAL_4]] : tensor<1xf32>)
+// CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_5]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.map { math.absf } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<64xf32>)
+// CHECK: %[[VAL_9:.*]] = linalg.map { math.atan } ins(%[[VAL_0]] : tensor<64xf32>) outs(%[[VAL_8]] : tensor<64xf32>)
+// CHECK: return %[[VAL_9]], %[[VAL_6]] : tensor<64xf32>, tensor<f32>
 // CHECK: }
 func.func @test_destination_style_op_cross_block(%arg0: tensor<64xf32>) -> (tensor<64xf32>, tensor<f32>) {
   %0 = tensor.empty() : tensor<64xf32>
@@ -385,17 +385,17 @@ func.func @extractslice_outside_failed(%arg0: i64, %arg1: tensor<64x64xf32>, %ar
 // CHECK-LABEL: func.func @extractslice_cross_iter_args(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xi32>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<128x64xi32>) {
-// CHECK: %[[VAL_2:.*]] = arith.constant 8 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
 // CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_5:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [128, 1] [1, 1] : tensor<128x64xi32> to tensor<128xi32>
-// CHECK: %[[VAL_6:.*]] = scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_8:.*]] = %[[VAL_5]]) -> (tensor<128xi32>) : i32 {
-// CHECK: %[[VAL_9:.*]] = tensor.expand_shape %[[VAL_8]] {{\[\[}}0, 1]] : tensor<128xi32> into tensor<128x1xi32>
-// CHECK: "test.foo"(%[[VAL_8]]) : (tensor<128xi32>) -> ()
-// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [128, 1] [1, 1] : tensor<128x64xi32> to tensor<128x1xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 8 : i32
+// CHECK: %[[VAL_5:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [128, 1] [1, 1] : tensor<128x64xi32> to tensor<128x1xi32>
+// CHECK: %[[VAL_6:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [128, 1] [1, 1] : tensor<128x64xi32> to tensor<128xi32>
+// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (tensor<128xi32>) : i32 {
+// CHECK: %[[VAL_10:.*]] = tensor.expand_shape %[[VAL_9]] {{\[\[}}0, 1]] output_shape [128, 1] : tensor<128xi32> into tensor<128x1xi32>
 // CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<128x1xi32>
-// CHECK: %[[VAL_12:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[VAL_9]], %[[VAL_10]] : tensor<128x1xi32>, tensor<128x1xi32>) outs(%[[VAL_11]] : tensor<128x1xi32>)
+// CHECK: %[[VAL_12:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[VAL_10]], %[[VAL_5]] : tensor<128x1xi32>, tensor<128x1xi32>) outs(%[[VAL_11]] : tensor<128x1xi32>)
 // CHECK: %[[VAL_13:.*]] = tensor.collapse_shape %[[VAL_12]] {{\[\[}}0, 1]] : tensor<128x1xi32> into tensor<128xi32>
+// CHECK: "test.foo"(%[[VAL_9]]) : (tensor<128xi32>) -> ()
 // CHECK: scf.yield %[[VAL_13]] : tensor<128xi32>
 // CHECK: }
 // CHECK: return
@@ -408,8 +408,45 @@ func.func @extractslice_cross_iter_args(%arg0: tensor<128x64xi32>, %arg1: tensor
     %extracted_slice = tensor.extract_slice %arg3[0, 0] [128, 1] [1, 1] : tensor<128x64xi32> to tensor<128xi32>
     "test.foo"(%extracted_slice) : (tensor<128xi32>) -> ()
     %1 = tensor.empty() : tensor<128x64xi32>
-    %mapped = linalg.map { arith.addi } ins(%arg3, %arg0 : tensor<128x64xi32>, tensor<128x64xi32>) outs(%1 : tensor<128x64xi32>)
+    %mapped = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg3, %arg0 : tensor<128x64xi32>, tensor<128x64xi32>) outs(%1 : tensor<128x64xi32>)
     scf.yield %mapped : tensor<128x64xi32>
   }
   return
 }
+
+// -----
+func.func @extract_slice_from_out_of_if_body(%arg0: i1, %arg1: tensor<12800x1xf32>) -> tensor<2800xf16> {
+  %0 = tensor.empty() : tensor<12800x1xf32>
+  %mapped = linalg.map { math.absf } ins(%arg1 : tensor<12800x1xf32>) outs(%0 : tensor<12800x1xf32>)
+  %1 = tensor.empty() : tensor<12800x1xf16>
+  %2 = tensor.empty() : tensor<2800xf16>
+  %mapped_0 = linalg.map { arith.truncf } ins(%mapped : tensor<12800x1xf32>) outs(%1 : tensor<12800x1xf16>)
+  %3 = scf.if %arg0 -> (tensor<2800xf16>) {
+    %extracted_slice = tensor.extract_slice %mapped_0[0, 0] [2800, 1] [1, 1] : tensor<12800x1xf16> to tensor<2800xf16>
+    %4 = tensor.empty() : tensor<2800xf16>
+    %mapped_1 = linalg.map { math.absf } ins(%extracted_slice : tensor<2800xf16>) outs(%4 : tensor<2800xf16>)
+    scf.yield %mapped_1 : tensor<2800xf16>
+  } else {
+    scf.yield %2 : tensor<2800xf16>
+  }
+  return %3 : tensor<2800xf16>
+}
+// CHECK-LABEL: func.func @extract_slice_from_out_of_if_body(
+// CHECK-SAME: %[[VAL_0:.*]]: i1,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<12800x1xf32>) -> tensor<2800xf16> {
+// CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [2800, 1] [1, 1] : tensor<12800x1xf32> to tensor<2800x1xf32>
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<2800x1xf32>
+// CHECK: %[[VAL_4:.*]] = linalg.map { math.absf } ins(%[[VAL_2]] : tensor<2800x1xf32>) outs(%[[VAL_3]] : tensor<2800x1xf32>)
+// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<2800x1xf16>
+// CHECK: %[[VAL_6:.*]] = linalg.map { arith.truncf } ins(%[[VAL_4]] : tensor<2800x1xf32>) outs(%[[VAL_5]] : tensor<2800x1xf16>)
+// CHECK: %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_6]] {{\[\[}}0, 1]] : tensor<2800x1xf16> into tensor<2800xf16>
+// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<2800xf16>
+// CHECK: %[[VAL_9:.*]] = scf.if %[[VAL_0]] -> (tensor<2800xf16>) {
+// CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<2800xf16>
+// CHECK: %[[VAL_11:.*]] = linalg.map { math.absf } ins(%[[VAL_7]] : tensor<2800xf16>) outs(%[[VAL_10]] : tensor<2800xf16>)
+// CHECK: scf.yield %[[VAL_11]] : tensor<2800xf16>
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_8]] : tensor<2800xf16>
+// CHECK: }
+// CHECK: return %[[VAL_9]] : tensor<2800xf16>
+// CHECK: }
diff --git a/test/lit.cfg.py b/test/lit.cfg.py
index 11b0425..ab7d741 100644
--- a/test/lit.cfg.py
+++ b/test/lit.cfg.py
@@ -44,8 +44,7 @@
 # test_exec_root: The root path where tests should be run.
 config.test_exec_root = os.path.join(config.triton_linalg_obj_root, 'test')

-config.triton_tools_dir = os.path.join(config.triton_linalg_obj_root,
-                                       'bin')
+config.triton_tools_dir = os.path.join(config.triton_linalg_obj_root, 'bin')
 config.filecheck_dir = os.path.join(config.triton_obj_root, 'bin', 'FileCheck')

 tool_dirs = [
diff --git a/tools/ci/daily/triton-linalg_daliy.pipeline b/tools/ci/daily/triton-linalg_daliy.pipeline
index 29affda..d2c891a 100644
--- a/tools/ci/daily/triton-linalg_daliy.pipeline
+++ b/tools/ci/daily/triton-linalg_daliy.pipeline
@@ -8,7 +8,7 @@ cnpipe {
         }
         container {
             networkPolicy "cncl-no-internnet-access"
-            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
+            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/genesis:0.7.13-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
             runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl"
         }
         resReq {
@@ -25,8 +25,8 @@ cnpipe {
            cd triton-linalg
            git fetch origin pull/${pr_id}/head:local_test
            git config --global url."http://gitmirror.cambricon.com/git_repos/".insteadOf https://
-           git submodule update --init --recursive
            git checkout local_test
+           git submodule update --init --recursive
            git log -1
            cd ..
            '''
@@ -41,7 +41,7 @@ cnpipe {
         }
         container {
             networkPolicy "cncl-no-internnet-access"
-            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
+            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/genesis:0.7.13-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
             runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl"
         }
         resReq {
@@ -64,7 +64,7 @@ cnpipe {
         stash 'logs', 'task_logs'
         archiveLog 'logs/', false
     }
-    task('build') {
+    task('build_and_unittest') {
         stage 'build'
         node {
             labelSelector "cambricon.com/mm-daily":true
         }
         container {
             networkPolicy "cncl-no-internnet-access"
-            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
+            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/genesis:0.7.13-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
             runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl"
         }
         resReq {
@@ -80,50 +80,29 @@ cnpipe {
             lmtMlus 1
             reqCpu 30
             lmtCpu 30
-            reqMemory '40Gi'
-            lmtMemory '40Gi'
+            reqMemory '60Gi'
+            lmtMemory '60Gi'
         }
         unstash 'triton-linalg-pr'
         script '''
            mkdir logs
            set -e
            export TRITON_PLUGIN_DIRS=${CI_WORK_DIR}/triton-linalg
+           export TRITON_BUILD_PROTON=OFF
            cd triton-linalg/triton
+           sed -i '435,482d' python/setup.py
+           sed -i 's/https:\/\/oaitriton.blob.core.windows.net\/public\/llvm-builds/http:\/\/daily.software.cambricon.com\/download\/genesis/g' python/setup.py
+           sed -i '/packages += \["triton\/profiler"\]/d' python/setup.py
            set -o pipefail
-           TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true pip3 install -e python --no-build-isolation -vvv | tee ${CI_WORK_DIR}/logs/build_log || exit 1
-           '''
-        stash 'triton-linalg', 'triton-linalg-build'
-        stash 'logs', 'task_logs'
-        archiveLog 'logs/', false
-    }
-    task('test') {
-        stage 'test'
-        node {
-            labelSelector "cambricon.com/mm-daily":true
-            cardType 'MLU370'
-        }
-        container {
-            networkPolicy "cncl-no-internnet-access"
-            image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10'
-            runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl"
-        }
-        resReq {
-            reqMlus 1
-            lmtMlus 1
-            reqCpu 30
-            lmtCpu 30
-            reqMemory '40Gi'
-            lmtMemory '40Gi'
-        }
-        unstash 'triton-linalg-build'
-        script '''
+           pip install wheel
+           export MAX_JOBS=32
+           pip3 install -e python --no-build-isolation -vvv | tee ${CI_WORK_DIR}/logs/build_log || exit 1
+
            mkdir logs
-           set -e
-           cd triton-linalg
-           set -o pipefail
+           cd ..
            bash tools/scripts/test_triton-linalg.sh test_linalg_unittest | tee ${CI_WORK_DIR}/logs/test_log || exit 1
            '''
-        stash 'triton-linalg', 'triton-linalg-test'
+        stash 'triton-linalg', 'triton-linalg-build'
         stash 'logs', 'task_logs'
        archiveLog 'logs/', false
     }
diff --git a/triton b/triton
index 6083043..757b6a6 160000
--- a/triton
+++ b/triton
@@ -1 +1 @@
-Subproject commit 6083043eb7a0722db6bd2ad8efc453300da74819
+Subproject commit 757b6a61e7df814ba806f498f8bb3160f84b120c