Skip to content

Commit

Permalink
[RAISE-BP] Add support for arith.remsi|remui as tt.addptr input (#1570)

Browse files Browse the repository at this point in the history

- Add minimal support for handling `arith.remsi|remui` as `tt.addptr`
input.
- Improve handling of unfolded arithmetic operations when evaluating the
modulo property and constant values.

Closes Issues: #1436 and #1482

---------

Signed-off-by: Maxime France-Pillois <[email protected]>
  • Loading branch information
mfrancepillois authored Aug 7, 2024
1 parent 75eee6f commit 57a10af
Show file tree
Hide file tree
Showing 2 changed files with 333 additions and 15 deletions.
145 changes: 144 additions & 1 deletion test/Triton/raise-block-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,79 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
tt.return %3 : tensor<128x2x128xf32>
}


// Verifies that the raise-block-pointer pass accepts an `arith.remsi`
// (signed remainder) feeding the offset chain of a `tt.addptr`: the
// tensor-of-pointers addressing in the function below is expected to be
// rewritten into two `tt.make_tensor_ptr` ops with block-pointer
// load/store (no `tt.splat`/`tt.addptr` on tensors of pointers remain).
// CHECK: tt.func public @wrap_side_by_side_masked([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64
// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32
// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64
// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64
// CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64
// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64
// CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.return
module {
tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%cst_0 = arith.constant dense<2> : tensor<4x1xi32>
%cst_1 = arith.constant dense<6> : tensor<4xi32>
%cst_2 = arith.constant dense<2> : tensor<4xi32>
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%1 = arith.addi %0, %cst_2 : tensor<4xi32>
%2 = arith.addi %0, %cst_1 : tensor<4xi32>
%3 = tt.splat %arg2 : i32 -> tensor<4xi32>
// The signed modulo on the column indices is the construct this test
// exercises: it feeds (via expand_dims/muli/broadcast/addi) the offset
// operand of the `tt.addptr` producing %15.
%4 = arith.remsi %2, %3 : tensor<4xi32>
%5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%6 = tt.splat %arg3 : i32 -> tensor<4x1xi32>
%7 = arith.muli %5, %6 : tensor<4x1xi32>
%8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%9 = tt.splat %arg4 : i32 -> tensor<1x4xi32>
%10 = arith.muli %8, %9 : tensor<1x4xi32>
%11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32>
%12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32>
%13 = arith.addi %11, %12 : tensor<4x4xi32>
%14 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
// Source pointer tensor: expected to be raised to the first
// tt.make_tensor_ptr in the CHECK lines above.
%15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%17 = tt.splat %arg5 : i32 -> tensor<4x1xi32>
%18 = arith.muli %17, %16 : tensor<4x1xi32>
%19 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x1x!tt.ptr<f32>>
%20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
%21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%22 = tt.splat %arg6 : i32 -> tensor<1x4xi32>
%23 = arith.muli %22, %21 : tensor<1x4xi32>
%24 = tt.broadcast %20 : tensor<4x1x!tt.ptr<f32>> -> tensor<4x4x!tt.ptr<f32>>
%25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32>
// Destination pointer tensor: expected to become the second
// tt.make_tensor_ptr, used by the tt.store.
%26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
// NOTE(review): %27, %28, %30 and %32 below are never used (the load is
// unmasked) — presumably left over from a masked-load variant of this
// test; the CHECK lines contain no corresponding ops. Confirm they are
// intentionally dead.
%27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32>
%28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1>
%29 = arith.muli %arg3, %c4_i32 : i32
%30 = tt.splat %29 : i32 -> tensor<4x4xi32>
%31 = arith.muli %arg4, %c4_i32 : i32
%32 = tt.splat %31 : i32 -> tensor<4x4xi32>
%34 = tt.load %15 : tensor<4x4x!tt.ptr<f32>>
tt.store %26, %34 : tensor<4x4x!tt.ptr<f32>>
tt.return
}
}


// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
Expand Down Expand Up @@ -319,6 +392,77 @@ module {
}
}


// Companion to @wrap_side_by_side_masked above, but exercising the
// unsigned remainder: an `arith.remui` on the row indices feeds the
// offset chain of the `tt.addptr`. The pass is expected to raise both
// tensor-of-pointer accesses into `tt.make_tensor_ptr` ops.
// CHECK: tt.func public @wrap_stacked_masked_loop([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64
// CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64
// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32
// CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64
// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.return
module {
tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%cst_0 = arith.constant dense<3> : tensor<1x4xi32>
%cst_1 = arith.constant dense<3> : tensor<4xi32>
%cst_2 = arith.constant dense<2> : tensor<4xi32>
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%1 = arith.addi %0, %cst_2 : tensor<4xi32>
%2 = tt.splat %arg2 : i32 -> tensor<4xi32>
// Unsigned modulo on the row indices — the construct this test
// exercises; it reaches the `tt.addptr` producing %15 through the
// expand_dims/muli/broadcast/addi chain below.
%3 = arith.remui %1, %2 : tensor<4xi32>
%4 = arith.addi %0, %cst_1 : tensor<4xi32>
%5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%6 = tt.splat %arg3 : i32 -> tensor<4x1xi32>
%7 = arith.muli %5, %6 : tensor<4x1xi32>
%8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%9 = tt.splat %arg4 : i32 -> tensor<1x4xi32>
%10 = arith.muli %8, %9 : tensor<1x4xi32>
%11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32>
%12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32>
%13 = arith.addi %11, %12 : tensor<4x4xi32>
%14 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
// Source pointer tensor: expected to be raised to the first
// tt.make_tensor_ptr in the CHECK lines above.
%15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%17 = tt.splat %arg5 : i32 -> tensor<4x1xi32>
%18 = arith.muli %17, %16 : tensor<4x1xi32>
%19 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x1x!tt.ptr<f32>>
%20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
%21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%22 = tt.splat %arg6 : i32 -> tensor<1x4xi32>
%23 = arith.muli %22, %21 : tensor<1x4xi32>
%24 = tt.broadcast %20 : tensor<4x1x!tt.ptr<f32>> -> tensor<4x4x!tt.ptr<f32>>
%25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32>
// Destination pointer tensor: expected to become the second
// tt.make_tensor_ptr, used by the tt.store.
%26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
// NOTE(review): %27, %28 and %30 below are never used (the load is
// unmasked) — presumably left over from a masked-load variant of this
// test; the CHECK lines contain no corresponding ops. Confirm they are
// intentionally dead.
%27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32>
%28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1>
%29 = arith.muli %arg4, %c4_i32 : i32
%30 = tt.splat %29 : i32 -> tensor<4x4xi32>
%32 = tt.load %15 : tensor<4x4x!tt.ptr<f32>>
tt.store %26, %32 : tensor<4x4x!tt.ptr<f32>>
tt.return
}
}


// CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
Expand Down Expand Up @@ -423,7 +567,6 @@ module {
}



// CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
Expand Down
Loading

0 comments on commit 57a10af

Please sign in to comment.