Skip to content

Commit

Permalink
Enable user-provided maximum column size for air-collapse-herd (Xilin…
Browse files Browse the repository at this point in the history
…x#486)

* Enable user-provided maximum column size for air-collapse-herd

* Set default to be 'disabled'

* Add registration method for options
  • Loading branch information
erwei-xilinx authored Mar 13, 2024
1 parent 2b7f5ce commit 2e2d2f1
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 14 deletions.
4 changes: 4 additions & 0 deletions mlir/include/air/Transform/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,10 @@ def AIRLabelBroadcastChannelWithTilePass : Pass<"air-label-broadcast-channel-wit
def AIRCollapseHerdPass : Pass<"air-collapse-herd", "func::FuncOp"> {
let summary = "Collapse a multi-dimensional air.herd into a single column.";
let constructor = "xilinx::air::createAIRCollapseHerdPass()";
let options = [
Option<"clMaxColSize", "max-col-size", "int", /*default=*/"-1",
"The maximum column size after collapse, before collapse is cancelled. Disabled by default.">
];
}

def AIRUnrollOuterPerfectlyNestedLoopsPass : Pass<"air-unroll-outer-affine-loops", "func::FuncOp"> {
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,8 @@ class AIRCollapseHerdPass
public:
AIRCollapseHerdPass() = default;
AIRCollapseHerdPass(const AIRCollapseHerdPass &pass){};
AIRCollapseHerdPass(const ::xilinx::air::AIRCollapseHerdPassOptions &options)
: AIRCollapseHerdPassBase(options) {}

void runOnOperation() override;

Expand All @@ -956,8 +958,12 @@ class AIRCollapseHerdPass
void AIRCollapseHerdPass::runOnOperation() {
SmallVector<air::HerdOp> herds;
auto func = getOperation();
int maximumColumnSize = clMaxColSize;
if (clMaxColSize == -1)
maximumColumnSize = INT_MAX; // max-col-size disabled.
func.walk([&](air::HerdOp op) {
if (op.getNumCols() != 1 && op.getNumDims() == 2)
if (op.getNumCols() != 1 && op.getNumDims() == 2 &&
op.getNumRows() * op.getNumCols() <= maximumColumnSize)
herds.push_back(op);
});

Expand Down
48 changes: 35 additions & 13 deletions mlir/test/Transform/AIRMiscPasses/air_collapse_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@
//
//===----------------------------------------------------------------------===//

// RUN: air-opt %s -air-collapse-herd -canonicalize --split-input-file | FileCheck %s
// RUN: air-opt %s -air-collapse-herd="max-col-size=4" -canonicalize --split-input-file | FileCheck %s
// RUN: air-opt %s -air-collapse-herd -canonicalize --split-input-file | FileCheck %s --check-prefix=MAXCOL

// CHECK: func.func @test0
// CHECK-LABEL: func.func @test0
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[CST9:.*]] = arith.constant 9 : index
// CHECK: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST9]])
// CHECK: %[[CST4:.*]] = arith.constant 4 : index
// CHECK: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST4]])
// MAXCOL-LABEL: func.func @test0
// MAXCOL: %[[CST1:.*]] = arith.constant 1 : index
// MAXCOL: %[[CST4:.*]] = arith.constant 4 : index
// MAXCOL: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST4]])

func.func @test0() -> () {
%c3 = arith.constant 3 : index
air.herd tile (%x, %y) in (%sx=%c3, %sy=%c3) {
%c2 = arith.constant 2 : index
air.herd tile (%x, %y) in (%sx=%c2, %sy=%c2) {
}
return
}
Expand All @@ -25,12 +30,12 @@ func.func @test0() -> () {
// CHECK: [[$SET1:#set[0-9]+]] = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 2 >= 0, s1 == 0)>
// CHECK: func.func @test1
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[CST9:.*]] = arith.constant 9 : index
// CHECK: air.herd tile (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[ARG2:.*]]=%[[CST1]], %[[ARG3:.*]]=%[[CST9]])
// CHECK: %[[CST4:.*]] = arith.constant 4 : index
// CHECK: air.herd tile (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[ARG2:.*]]=%[[CST1]], %[[ARG3:.*]]=%[[CST4]])
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[CST3:.*]] = arith.constant 3 : index
// CHECK: %[[VAL0:.*]] = arith.remsi %[[ARG1]], %[[CST3]] : index
// CHECK: %[[VAL1:.*]] = arith.divsi %[[ARG1]], %[[CST3]] : index
// CHECK: %[[CST2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL0:.*]] = arith.remsi %[[ARG1]], %[[CST2]] : index
// CHECK: %[[VAL1:.*]] = arith.divsi %[[ARG1]], %[[CST2]] : index
// CHECK: affine.if [[$SET0]]()[%[[VAL1]], %[[VAL0]]] {
// CHECK: } else {
// CHECK: }
Expand All @@ -40,8 +45,8 @@ func.func @test0() -> () {
#set0 = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 2 >= 0)>
#set1 = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 2 >= 0, s1 == 0)>
func.func @test1() -> () {
%c3 = arith.constant 3 : index
air.herd tile (%x, %y) in (%sx=%c3, %sy=%c3) {
%c2 = arith.constant 2 : index
air.herd tile (%x, %y) in (%sx=%c2, %sy=%c2) {
%c0 = arith.constant 0 : index
affine.if #set0()[%x, %y] {
%alloc = memref.alloc() : memref<8x16xi32, 2>
Expand All @@ -68,3 +73,20 @@ func.func @test1() -> () {
}
return
}

// -----

// CHECK-LABEL: func.func @test2
// CHECK: %[[CST3:.*]] = arith.constant 3 : index
// CHECK: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST3]], %[[VAL3:.*]]=%[[CST3]])
// MAXCOL-LABEL: func.func @test2
// MAXCOL: %[[CST1:.*]] = arith.constant 1 : index
// MAXCOL: %[[CST9:.*]] = arith.constant 9 : index
// MAXCOL: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST9]])

func.func @test2() -> () {
%c3 = arith.constant 3 : index
air.herd tile (%x, %y) in (%sx=%c3, %sy=%c3) {
}
return
}

0 comments on commit 2e2d2f1

Please sign in to comment.