Skip to content

Commit

Permalink
feat(triton-linalg): update triton-linalg and update to triton3.0.x(7…
Browse files Browse the repository at this point in the history
…57b6a6)
  • Loading branch information
hexi authored and sethbrin committed Oct 12, 2024
1 parent d090db4 commit e601be5
Show file tree
Hide file tree
Showing 93 changed files with 4,106 additions and 1,954 deletions.
18 changes: 13 additions & 5 deletions .github/ci_script/file_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import argparse


def file_guard(guard_status_file, guard_log_file):
# where stores the last position that pointer pointed to.
where= 0
where = 0
while True:
file = open(guard_log_file, "r")
file.seek(where)
Expand All @@ -28,11 +29,18 @@ def file_guard(guard_status_file, guard_log_file):
exit(-1)
# sleep for a while
time.sleep(2)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Monitor a log file and echo lines, check status to stop.")
parser.add_argument('guard_status_file', type=str, help='The path to the status file.')
parser.add_argument('guard_log_file', type=str, help='The path to the log file.')
parser = argparse.ArgumentParser(
description="Monitor a log file and echo lines, check status to stop.")
parser.add_argument('guard_status_file',
type=str,
help='The path to the status file.')
parser.add_argument('guard_log_file',
type=str,
help='The path to the log file.')

args = parser.parse_args()

file_guard(args.guard_status_file, args.guard_log_file)
1 change: 0 additions & 1 deletion bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ add_llvm_executable(triton-linalg-opt triton-linalg-opt.cpp PARTIAL_SOURCES_INTE

llvm_update_compile_flags(triton-linalg-opt)
target_link_libraries(triton-linalg-opt PRIVATE
ArithTransforms
AuxiliaryTransforms
LinalgExtTransforms
TritonLinalgAnalysis
Expand Down
4 changes: 2 additions & 2 deletions bin/RegisterTritonLinalgDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
#include "triton-linalg/Dialect/Auxiliary/Transforms/AuxOpTilingInterface.h"
#include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "triton-linalg/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.h"
#include "triton-linalg/Dialect/MathExt/IR/MathExt.h"
#include "triton-linalg/Dialect/Triton/Transforms/InferAxisInfoInterfaceImpl.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "triton-linalg/Conversion/Passes.h"
#include "triton-linalg/Dialect/Arith/Transforms/Passes.h"
#include "triton-linalg/Dialect/Triton/Transforms/Passes.h"

inline void registerTritonLinalgDialects(mlir::DialectRegistry &registry) {
Expand All @@ -17,6 +17,7 @@ inline void registerTritonLinalgDialects(mlir::DialectRegistry &registry) {
// TritonLinalg.
registry.insert<mlir::triton::aux::AuxiliaryDialect>();
registry.insert<mlir::triton::linalg_ext::LinalgExtDialect>();
registry.insert<mlir::math_ext::MathExtDialect>();

mlir::triton::aux::registerTilingInterfaceExternalModels(registry);
mlir::triton::linalg_ext::registerTilingInterfaceExternalModels(registry);
Expand All @@ -26,7 +27,6 @@ inline void registerTritonLinalgDialects(mlir::DialectRegistry &registry) {
}

inline void registerTritonLinalgPasses() {
::mlir::triton::arith_ext::registerArithExtPasses();
::mlir::triton::registerTritonLinalgConversionPasses();
::mlir::triton::registerTritonTransformsExtendPasses();
}
6 changes: 6 additions & 0 deletions include/triton-linalg/Analysis/AxisInfoAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ namespace triton {
class AxisInfoLattice : public mlir::dataflow::Lattice<AxisInfoExt> {
public:
using Lattice::Lattice;
ChangeResult join(const AxisInfoExt &rhs);
bool isInitialized() { return initialized; }

private:
bool initialized = false;
using mlir::dataflow::Lattice<AxisInfoExt>::join;
};

//===--------------------------------------------------------------------===//
Expand Down
1 change: 0 additions & 1 deletion include/triton-linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Interfaces)
4 changes: 2 additions & 2 deletions include/triton-linalg/Conversion/LinalgCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
ConversionPatternRewriter &rewriter) const override {
// Remain unchanged if one of operands is scalar.
if (!llvm::all_of(adaptor.getOperands(),
[&](Value v) { return v.getType().isa<ShapedType>(); })) {
[&](Value v) { return isa<ShapedType>(v.getType()); })) {
return failure();
}
// Apply only if all operands are not scalar.
auto loc = op.getLoc();
auto resType = op.getType().template cast<ShapedType>();
auto resType = cast<ShapedType>(op.getType());
auto initDims = getDims(rewriter, loc, op->getOperand(0));
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, initDims, resType.getElementType());
Expand Down
1 change: 0 additions & 1 deletion include/triton-linalg/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "triton-linalg/Conversion/ArithToLinalg/ArithToLinalg.h"
#include "triton-linalg/Conversion/MathToLinalg/MathToLinalg.h"
#include "triton-linalg/Conversion/TritonToLinalg/TritonToLinalg.h"
#include "triton-linalg/Conversion/TritonToTensor/TritonToTensor.h"

namespace mlir {
class Pass;
Expand Down
7 changes: 0 additions & 7 deletions include/triton-linalg/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,4 @@ def MathToLinalgPass : Pass<"convert-math-to-linalg"> {
];
}

def TritonToTensorPass : Pass<"convert-triton-to-tensor", "ModuleOp"> {
let summary = "Convert the operations from the Triton dialect into the Tensor dialect";
let constructor = "mlir::triton::createTritonToTensorPass()";
let dependentDialects = [
"triton::TritonDialect", "tensor::TensorDialect",
];
}
#endif // TRITON_LINALG_CONVERSION_PASSES_TD
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ class DimInfo {
int64_t getContigSize() const { return contigSize; }
int64_t getBroadcastSize() const { return broadcastSize; }
int64_t getDimSize() const { return dimSize; }
bool isBroadcastDim() const { return getContigSize() == 1; }
bool isBroadcastDim() const {
return getContigSize() == 1 && getDimSize() != 1;
}

private:
int64_t contigSize = -1;
Expand Down Expand Up @@ -331,13 +333,46 @@ class TritonPtrScatterConversionBase
class TritonTensorPtrLoadStoreOpConversionBase
: public TritonPtrConversionBase {
protected:
/// Get the actual size of each dim needed to be load, if boundaryCheck is
/// true, return min(tensorShape[dim], dimSize[dim] - offset[dim]).
SmallVector<OpFoldResult>
getActualSizes(Location loc, std::optional<ArrayRef<int>> boundaryCheck,
ArrayRef<int64_t> tensorShape,
const TensorPointerMetaInfoTracker &tracker,
ConversionPatternRewriter &rewriter) const;
/// Actual offsets, padLeftSizes and sizes.
///
/// For example, in a certain dimension, there are several quantities to
/// describe the actual data range. [0, `shape`) represents the valid data
/// range, `offset` represents the offset value, and `blockShape` represents
/// the size of the data block being retrieved.
///
/// There are 3 cases for the position of `offset`.
///
/// Case1: 0 <= offset < shape
/// offset = offset
/// padLeftSize = 0
/// size = min(shape - offset, blockShape)
///
/// Case2: offset < 0
/// offset = 0
/// padLeftSize = min(-offset, blockShape)
/// size = min(shape, blockShape - padLeftSize)
///
/// Case3: offset >= shape
/// offset = offset
/// padLeftSize = shape
/// size = min(0, blockShape) = 0
///
/// These cases can be summarized by the following formula.
/// originOffset = offset
/// offset = max(offset, 0)
/// padLeftSize = min(offset - originOffset, blockShape)
/// size = min(max(shape - offset, 0), blockShape - padLeftSize)
struct PtrInfo {
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> padLeftSizes;
SmallVector<OpFoldResult> sizes;
};

/// Get the actual ptrinfo of each dim needed to be load.
PtrInfo getPtrInfo(Location loc, std::optional<ArrayRef<int>> boundaryCheck,
ArrayRef<int64_t> tensorShape,
const TensorPointerMetaInfoTracker &tracker,
ConversionPatternRewriter &rewriter) const;

SmallVector<DimInfo> getDimInfos(ArrayRef<OpFoldResult> strides,
ArrayRef<int64_t> tensorShape) const;
Expand Down
20 changes: 0 additions & 20 deletions include/triton-linalg/Conversion/TritonToTensor/TritonToTensor.h

This file was deleted.

1 change: 0 additions & 1 deletion include/triton-linalg/Dialect/Arith/CMakeLists.txt

This file was deleted.

5 changes: 0 additions & 5 deletions include/triton-linalg/Dialect/Arith/Transforms/CMakeLists.txt

This file was deleted.

33 changes: 0 additions & 33 deletions include/triton-linalg/Dialect/Arith/Transforms/PassDetail.h

This file was deleted.

27 changes: 0 additions & 27 deletions include/triton-linalg/Dialect/Arith/Transforms/Passes.h

This file was deleted.

18 changes: 0 additions & 18 deletions include/triton-linalg/Dialect/Arith/Transforms/Passes.td

This file was deleted.

18 changes: 9 additions & 9 deletions include/triton-linalg/Dialect/Auxiliary/IR/AuxiliaryOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def StoreResourceOp : Aux_Op<"store"> {
}];

let arguments = (ins
AnyType:$to,
AnyType:$from
Arg<AnyType, "", [MemWrite]>:$to,
Arg<AnyType, "", [MemRead]>:$from
);

let results = (outs);
Expand All @@ -61,22 +61,22 @@ def StoreResourceOp : Aux_Op<"store"> {

let extraClassDeclaration = [{
bool isScalar(const Value& value) {
return !value.getType().isa<ShapedType>();
return !isa<ShapedType>(value.getType());
}

bool hasPureBufferSemantics() {
return ::llvm::all_of(getOperands(),
[&](const Value& opOperand) {
return isScalar(opOperand) ||
opOperand.getType().isa<::mlir::MemRefType>();
isa<::mlir::MemRefType>(opOperand.getType());
});
}

bool hasPureTensorSemantics() {
return ::llvm::all_of(getOperands(),
[&](const Value& opOperand) {
return isScalar(opOperand) ||
opOperand.getType().isa<::mlir::TensorType>();
isa<::mlir::TensorType>(opOperand.getType());
});
}

Expand Down Expand Up @@ -190,7 +190,7 @@ def ViewOp :
}

// The result of the op is always a ranked memref.
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
MemRefType getType() { return cast<MemRefType>(getResult().getType()); }
Value getViewSource() { return getPtr(); }
Value getOffset() {
return getOffsets().empty() ? nullptr : getOffsets()[0];
Expand All @@ -199,7 +199,7 @@ def ViewOp :
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
unsigned resultRank = cast<ShapedType>(getResult().getType()).getRank();
return {1, resultRank, resultRank};
}

Expand Down Expand Up @@ -265,12 +265,12 @@ def PrintOp : Aux_Op<"print", [DeclareOpInterfaceMethods<MemoryEffectsOpInterfac
bool hasPureBufferSemantics() {
return ::llvm::all_of(getOperands(),
[&](const Value& opOperand) {
return opOperand.getType().isa<::mlir::MemRefType>();
return isa<::mlir::MemRefType>(opOperand.getType());
});
}

ShapedType getInitType() {
return getOperands()[0].getType().cast<ShapedType>();;
return cast<ShapedType>(getOperands()[0].getType());;
}

MutableOperandRange getDpsInitsMutable() { return getValuesMutable(); }
Expand Down
1 change: 0 additions & 1 deletion include/triton-linalg/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
add_subdirectory(Arith)
add_subdirectory(Auxiliary)
add_subdirectory(LinalgExt)
add_subdirectory(MathExt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtEnums.td"
include "triton-linalg/Dialect/LinalgExt/IR/LinalgExtInterface.td"
include "triton-linalg/Interfaces/InferResultTypeOpInterface.td"
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
Expand Down
Loading

0 comments on commit e601be5

Please sign in to comment.