From 57a10afbebde29ecb0220b56351f823b4c941ac6 Mon Sep 17 00:00:00 2001 From: Maxime France-Pillois Date: Wed, 7 Aug 2024 15:17:44 +0100 Subject: [PATCH] [RAISE-BP] Add support for `arith.remsi|remui` as `tt.addptr` input (#1570) - Add minimal support for handling `arith.remsi|remui` as `tt.addptr` input. - Improve handling of unfolded arithmetic operations when evaluating the modulo property and constant values. Closes Issue: #1436 and #1482 --------- Signed-off-by: Maxime France-Pillois --- test/Triton/raise-block-pointer.mlir | 145 ++++++++++++- .../TritonRaiseBlockPointer.cpp | 203 ++++++++++++++++-- 2 files changed, 333 insertions(+), 15 deletions(-) diff --git a/test/Triton/raise-block-pointer.mlir b/test/Triton/raise-block-pointer.mlir index a5b1cf3b24..36dcd5be60 100644 --- a/test/Triton/raise-block-pointer.mlir +++ b/test/Triton/raise-block-pointer.mlir @@ -237,6 +237,79 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr) -> tensor<128x2x128x tt.return %3 : tensor<128x2x128xf32> } + +// CHECK: tt.func public @wrap_side_by_side_masked([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { +// CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32 +// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 +// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32 +// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 +// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 +// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64 +// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32 +// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64 +// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64 +// CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array} : > +// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64 +// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64 +// CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array} : > +// CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr> +// CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr> +// CHECK: tt.return +module { +tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c4_i32 = arith.constant 4 : i32 + %cst_0 = arith.constant dense<2> : tensor<4x1xi32> + %cst_1 = arith.constant dense<6> : tensor<4xi32> + %cst_2 = arith.constant dense<2> : tensor<4xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = arith.addi %0, %cst_1 : tensor<4xi32> + %3 = tt.splat %arg2 : i32 -> tensor<4xi32> + %4 
= arith.remsi %2, %3 : tensor<4xi32> + %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %6 = tt.splat %arg3 : i32 -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %9 = tt.splat %arg4 : i32 -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> + %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> + %13 = arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %17 = tt.splat %arg5 : i32 -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %22 = tt.splat %arg6 : i32 -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> + %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> + %29 = arith.muli %arg3, %c4_i32 : i32 + %30 = tt.splat %29 : i32 -> tensor<4x4xi32> + %31 = arith.muli %arg4, %c4_i32 : i32 + %32 = tt.splat %31 : i32 -> tensor<4x4xi32> + %34 = tt.load %15 : tensor<4x4x!tt.ptr> + tt.store %26, %34 : tensor<4x4x!tt.ptr> + tt.return + } +} + + // CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { // CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 // CHECK: [[CST_3_:%.+]] = arith.constant 3 : index @@ -319,6 +392,77 @@ module { } } + +// CHECK: tt.func public @wrap_stacked_masked_loop([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { +// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 +// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 +// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 +// CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 +// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 +// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 +// CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64 +// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64 +// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32 +// CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array} : > +// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64 +// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK: 
[[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64 +// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array} : > +// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : !tt.ptr> +// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr> +// CHECK: tt.return +module { + tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c4_i32 = arith.constant 4 : i32 + %cst_0 = arith.constant dense<3> : tensor<1x4xi32> + %cst_1 = arith.constant dense<3> : tensor<4xi32> + %cst_2 = arith.constant dense<2> : tensor<4xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = tt.splat %arg2 : i32 -> tensor<4xi32> + %3 = arith.remui %1, %2 : tensor<4xi32> + %4 = arith.addi %0, %cst_1 : tensor<4xi32> + %5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %6 = tt.splat %arg3 : i32 -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %9 = tt.splat %arg4 : i32 -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> + %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> + %13 = arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %17 = tt.splat %arg5 : i32 -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %22 = tt.splat %arg6 : i32 -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32> + %28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1> + %29 = arith.muli %arg4, %c4_i32 : i32 + %30 = tt.splat %29 : i32 -> tensor<4x4xi32> + %32 = tt.load %15 : tensor<4x4x!tt.ptr> + tt.store %26, %32 : tensor<4x4x!tt.ptr> + tt.return + } +} + + // CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { // CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 // CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64 @@ -423,7 +567,6 @@ module { } - // CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr) { // CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 // CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64 diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index f91e775193..f6e78e3fbd 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -25,6 +25,51 @@ namespace { constexpr unsigned offsetBitwidth = 32; constexpr unsigned 
shapeAndStridesBitwidth = 64; +std::optional getIntAttr(const OpFoldResult ofr) { + if (ofr.is() && isa(ofr.get())) + return cast(ofr.get()).getInt(); + return std::nullopt; +} + +// This function folds the `op` operation and returns the constant value if it +// has successfully folded to a constant. Otherwise, it returns `std::nullopt`. +std::optional getFoldedConstantValue(Operation *op) { + SmallVector results; + if (failed(op->fold(results))) { + return std::nullopt; + } + + // If fold succeeded but `results` is empty, we give a second try, after the + // operands have been switched during the first call to `fold()`. + if (results.empty()) { + if (failed(op->fold(results))) { + return std::nullopt; + } + } + + if (results.size() != 1) { + return std::nullopt; + } + + auto intAttr = getIntAttr(results[0]); + if (intAttr.has_value()) { + return intAttr.value(); + } + + auto val = cast(results[0]); + auto constOp = val.getDefiningOp(); + if (!constOp) + return std::nullopt; + + return getIntAttr(constOp.getValue()); +} + +// return true if the `val` value is a constant containing a value equal to zero +bool hasConstZero(Value val) { + auto intVal = getFoldedConstantValue(val.getDefiningOp()); + return (intVal.has_value() && (intVal.value() == 0)); +} + // Data structure used to decode pointer arithmetics. Offsets, sizes, and // strides are in unit of elements in a linearly laid-out memory, which is the // same as pointer arithmetic operations in Triton language. Scalar is a @@ -70,10 +115,7 @@ struct PtrState { // When PtrState describes a non-block pointer, shape field indicates how // address wraps around. As a result, a constant 0 indicates no wrap around // (i.e. modulo) for the dimension. - if (auto intOp = shape[dim].getDefiningOp()) { - return intOp.value() != 0; - } - return true; + return !hasConstZero(shape[dim]); } // @return true if addresses wrap around in any of the pointer dimension. @@ -113,11 +155,50 @@ struct PtrState { sizes.push_back(lhsState.sizes[i]); } + // AddPtr where both lhs and rhs containing modulo operators not supported + if (lhsState.hasModulo() && rhsState.hasModulo()) { + op->emitRemark( + "TritonRaiseBlockPointer: do not support adding two pointer states " + "that both have modulo"); + return failure(); + } + + assert( + !(lhsState.hasModulo() || rhsState.hasModulo()) || + (lhsState.getRank() <= 2) && + "cannot have rank > 2 if operand one of the operands has a modulo"); + + // dealing with modulo: + // - If lhs has no modulo, skip + // - If rhs has zero offset on dim i, we can just use lhs's modulo + // - Else, the analysis fails + + // An example for the 3rd condition above can look like: + // %0 = tt.splat %scalar + // %1 = tt.splat %ptr + // %2 = tt.arange + // %3 = arith.remsi %2, %size + // %4 = tt.addptr %1, %3 + // %5 = tt.addptr %4, %0 + // %5 may also occur in a loop to increment %4 every iteration. 
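+    // Normalize the operand order below so that the state carrying a modulo
+    // (a non-zero `shape` entry), if any, always ends up in `lhs`; the loop
+    // that follows then only has to keep lhs's shape and check that rhs's
+    // offset is a constant zero wherever lhs wraps around.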
+ const PtrState *lhs = &lhsState; const PtrState *rhs = &rhsState; - for (uint64_t i = 0; i < lhs->getRank(); ++i) { - shape.push_back(lhs->shape[i]); + if (rhs->hasModulo()) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->getRank(); i++) { + if (!lhs->dimHasModulo(i)) { + shape.push_back(lhs->shape[i]); + } else if (hasConstZero(rhs->offsets[i])) { + shape.push_back(lhs->shape[i]); + } else { + op->emitRemark("TritonRaiseBlockPointer: do not support adding to " + "operand with modulo"); + return failure(); + } } return success(); @@ -155,9 +236,18 @@ struct PtrState { ArithBuilder abuilder(builder, loc); for (const auto &[offset, stride, dim, size] : llvm::zip(lhs->offsets, lhs->strides, lhs->shape, lhs->sizes)) { - Value newOffset = abuilder.mul(offset, i32Scalar); - Value newStride = abuilder.mul(stride, i64Scalar); - Value newDim = abuilder.mul(dim, i64Scalar); + + Value newOffset = + abuilder.mul(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI32Type(), offset), + i32Scalar); + Value newStride = + abuilder.mul(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), stride), + i64Scalar); + Value newDim = abuilder.mul(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), dim), + i64Scalar); offsets.push_back(newOffset); strides.push_back(newStride); @@ -695,11 +785,12 @@ struct TritonRaiseBlockPointer } return TypeSwitch(definingOp) - .Case([this, &state, loc, &builder](auto op) { - return visitAddPointerOperand(op, state, loc, builder); - }) + .Case( + [this, &state, loc, &builder](auto op) { + return visitAddPointerOperand(op, state, loc, builder); + }) .Default([](Operation *op) { llvm::dbgs() << "TritonRaiseBlockPointer: encountered addptr operand " "produced by an unsupported operation\n" @@ -712,6 +803,11 @@ struct TritonRaiseBlockPointer LogicalResult visitAddPointerOperand(OpTy op, PtrState &state, Location loc, OpBuilder &builder); + template ::value>> + LogicalResult visitAddPointerRemOperand(OpTy remOp, PtrState &state, + Location loc, OpBuilder &builder); + template ::value>> LogicalResult rewriteLoadStoreOp(OpTy op) { @@ -762,6 +858,85 @@ struct TritonRaiseBlockPointer int level = 0; }; +template ::value>> +LogicalResult TritonRaiseBlockPointer::visitAddPointerRemOperand( + OpTy remOp, PtrState &state, Location loc, OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + PtrState rhsState; + if (failed(visitOperand(remOp.getRhs(), rhsState, loc, builder))) { + return failure(); + } + + if (!rhsState.scalar) { + remOp->emitRemark( + "TritonRaiseBlockPointer: only support cases when rhs of remainder " + "contains scalar"); + return failure(); + } + + if (failed(visitOperand(remOp.getLhs(), state, loc, builder))) { + return failure(); + } + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. 
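+  // A nested modulo such as (a % b) % c cannot be represented: each
+  // dimension of `shape` carries a single wrap-around size, and the inner
+  // remsi/remui has already populated it, so report it and bail out.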
+ if (state.hasModulo()) { + remOp->emitRemark("TritonRaiseBlockPointer: do not support multiple modulo " + "within an expression"); + return failure(); + } + + switch (state.getRank()) { + case 1: + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.shape.back() = rhsState.scalar; + break; + case 2: { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. + auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.shape[1] = rhsState.scalar; + } else if (shape[1] == 1) { + state.shape[0] = rhsState.scalar; + } else { + remOp->emitRemark("TritonRaiseBlockPointer: taking modulo on a 2D tensor " + "with no singleton dimension not supported"); + return failure(); + } + break; + } + default: + remOp->emitRemark("TritonRaiseBlockPointer: unsupported modulo pattern"); + return failure(); + } + + return success(); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::RemSIOp remOp, PtrState &state, Location loc, OpBuilder &builder) { + return visitAddPointerRemOperand(remOp, state, loc, builder); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::RemUIOp remOp, PtrState &state, Location loc, OpBuilder &builder) { + return visitAddPointerRemOperand(remOp, state, loc, builder); +} + template <> LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand(triton::MakeRangeOp rangeOp,