Skip to content

Commit

Permalink
Fixup AIRRtToIpuPass wrap tiling, if wrap is larger than 1023 but not…
Browse files Browse the repository at this point in the history
… divisible by 512 (Xilinx#544)

* Retiling SHIM DMA BD wrap beyond 512

* Test
  • Loading branch information
erwei-xilinx authored Apr 23, 2024
1 parent 7cafce1 commit 0b3ba11
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 19 deletions.
23 changes: 18 additions & 5 deletions mlir/lib/Conversion/AIRRtToIpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,19 @@ bool violatesAIE2WrapLimit(airrt::DmaMemcpyNdOp dma) {
return false;
}

// Returns the largest factor of `inputInt` that is at least 2, strictly
// smaller than `inputInt` itself and, when `smallerThanInt` is non-zero, also
// strictly smaller than `smallerThanInt`.
// Returns 1 when no such factor exists (e.g. `inputInt` is prime, `inputInt`
// is smaller than 3, or the bound leaves no candidates).
//
// Candidates are scanned downwards, so the first divisor found is the answer
// and the search stops immediately.
int getLargestFactorSmallerThan(int inputInt, int smallerThanInt = 0) {
  // Exclusive upper bound on the candidates: a factor must be smaller than
  // `inputInt` and, if a bound was supplied, smaller than that bound too.
  int exclusiveUpper = inputInt;
  if (smallerThanInt != 0 && smallerThanInt < exclusiveUpper)
    exclusiveUpper = smallerThanInt;
  // The candidate range [2, exclusiveUpper) is empty. Bailing out here also
  // keeps the decrement below from overflowing on extreme negative values.
  if (exclusiveUpper <= 2)
    return 1;
  for (int candidate = exclusiveUpper - 1; candidate >= 2; --candidate) {
    if (inputInt % candidate == 0)
      return candidate;
  }
  return 1;
}

void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
auto loc = memcpy_op->getLoc();
auto oper_begin = memcpy_op.getOperands().begin();
Expand All @@ -547,18 +560,18 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
auto const_stride = *getConstantIntValue(strides[i]);
if (const_wrap >= AIE2_WRAP_UPPER_BOUND) {
// Found dimension with illegal wrap. Tiling.
assert(!(const_wrap % (AIE2_WRAP_UPPER_BOUND / 2)) &&
"Currently do not support remainder tiles");
int new_wrap = mlir::ceilDiv(const_wrap, AIE2_WRAP_UPPER_BOUND / 2);
int inner_wrap =
getLargestFactorSmallerThan(const_wrap, AIE2_WRAP_UPPER_BOUND);
int new_wrap = mlir::ceilDiv(const_wrap, inner_wrap);
wraps[i] = builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), AIE2_WRAP_UPPER_BOUND / 2));
IntegerAttr::get(builder.getI64Type(), inner_wrap));
wraps.insert(wraps.begin() + i,
builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), new_wrap)));
auto new_const_stride =
(const_stride * AIE2_WRAP_UPPER_BOUND / 2) %
(const_stride * inner_wrap) %
air::getTensorVolume(
memcpy_op.getMemref().getType().cast<MemRefType>());
strides.insert(
Expand Down
72 changes: 58 additions & 14 deletions mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -517,14 +517,58 @@ module {

// -----

// Handling the case where a wrap dimension in airrt.dma_memcpy_nd exceeds the [0, 1023] hardware limit and is not a multiple of 512 (test case 2).

// CHECK-LABEL: aie.device(ipu)
// CHECK: func.func @func10(%[[ARG0:.*]]: memref<2654208xi32>)
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>

// Named affine map alias (s0 * 64). The affine.apply further down spells out
// its own inline map instead of referencing this alias.
#map = affine_map<()[s0] -> (s0 * 64)>
module {
aie.device(ipu) {
%tile_0_0 = aie.tile(0, 0)
// Shim DMA allocation and global buffer that the memcpy below names through
// its `metadata` attribute.
aie.shim_dma_allocation @airMemcpyId21(MM2S, 0, 2)
memref.global "public" @airMemcpyId21 : memref<256x64xbf16, 1>
} {sym_name = "segment_0"}
airrt.module_metadata{
}
func.func @func10(%arg2: memref<2304x2304xbf16>) {
%c64_i64 = arith.constant 64 : i64
%c8_i64 = arith.constant 8 : i64
%c2304_i64 = arith.constant 2304 : i64
%c256_i64 = arith.constant 256 : i64
%c26_i32 = arith.constant 26 : i32
%c15_i32 = arith.constant 15 : i32
%c14_i32 = arith.constant 14 : i32
%c21_i32 = arith.constant 21 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
// 3 x 3 loop nest; both induction variables are passed to the memcpy and
// %arg1 additionally selects the innermost offset.
affine.for %arg0 = 0 to 3 {
affine.for %arg1 = 0 to 3 {
%p = airrt.segment_load "segment_0" : i64
// Innermost offset for this iteration: %arg1 * 256.
%34 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%arg1]
%39 = arith.index_cast %arg0 : index to i64
%40 = arith.index_cast %arg1 : index to i64
%41 = arith.index_cast %34 : index to i64
// Wraps are [1, 1, 2304, 64] with strides [0, 0, 2304]. The wrap of 2304 lies
// outside the [0, 1023] range this test targets, so the pass has to tile it.
%42 = airrt.dma_memcpy_nd(%c21_i32, %39, %40, %arg2[%c0_i64, %c0_i64, %c0_i64, %41], [%c1_i64, %c1_i64, %c2304_i64, %c64_i64], [%c0_i64, %c0_i64, %c2304_i64]) {metadata = @airMemcpyId21} : (i32, i64, i64, memref<2304x2304xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
}
}
return
}
}

// -----

// 16-bit type conversion

// CHECK-LABEL: func.func @func10
// CHECK-LABEL: func.func @func11
// CHECK-SAME: %arg0: memref<8192xi32>
// CHECK-NEXT: aiex.ipu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][4, 4, 32, 16][2048, 16, 64]){{.*}}: memref<8192xi32>
module {
aie.device(ipu) {
func.func @func10(%arg0: memref<128x128xbf16>, %arg1: memref<128x128xbf16>) {
func.func @func11(%arg0: memref<128x128xbf16>, %arg1: memref<128x128xbf16>) {
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c4_i64 = arith.constant 4 : i64
Expand All @@ -541,11 +585,11 @@ module {

// 16-bit conversion with dma operands that aren't function arguments

// CHECK-LABEL: func.func @func11
// CHECK-LABEL: func.func @func12
// CHECK-SAME: %arg0: memref<16xi32>
// CHECK-NEXT: aiex.ipu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][1, 1, 1, 16][0, 0, 0]) {{.*}} : memref<16xi32>
module {
func.func @func11() {
func.func @func12() {
%c1_i32 = arith.constant 1 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
Expand Down Expand Up @@ -580,14 +624,14 @@ module {
//
// The key difference is that memref.alloc is removed.

// CHECK-LABEL: func12
// CHECK-LABEL: func13
// CHECK-NOT: memref.alloc
// CHECK: memref.assume_alignment
// CHECK-SAME: memref<16xi32>
// CHECK-NOT: memref.alloc
// CHECK: return
module {
func.func @func12() {
func.func @func13() {

%c1_i32 = arith.constant 1 : i32
%c0_i64 = arith.constant 0 : i64
Expand All @@ -606,11 +650,11 @@ module {

// Multi-dimensional offset collapsing

// CHECK-LABEL: func.func @func13
// CHECK-LABEL: func.func @func14
// CHECK-SAME: %arg0: memref<512xi32>
// CHECK-NEXT: aiex.ipu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 264][1, 1, 16, 8][0, 0, 16]) {id = 0 : i64, metadata = @md0} : memref<512xi32>
module {
func.func @func13(%arg0 : memref<32x32xbf16>) {
func.func @func14(%arg0 : memref<32x32xbf16>) {
%c1_i32 = arith.constant 1 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
Expand All @@ -625,10 +669,10 @@ module {

// Loop carried event

// CHECK-LABEL: func.func @func14
// CHECK-LABEL: func.func @func15
// CHECK-NEXT: return
module {
func.func @func14() {
func.func @func15() {
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
Expand All @@ -653,7 +697,7 @@ module {
// CHECK: aie.shim_dma_allocation @airMemcpyId17(MM2S, 0, 0)
// CHECK: aie.shim_dma_allocation @airMemcpyId12(MM2S, 1, 0)
// CHECK: aie.shim_dma_allocation @airMemcpyId22(MM2S, 1, 0)
// CHECK-LABEL: func.func @func15
// CHECK-LABEL: func.func @func16
// CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) {
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
Expand Down Expand Up @@ -691,7 +735,7 @@ module {
aie.shim_dma_allocation @airMemcpyId22(MM2S, 1, 0)
memref.global "public" @airMemcpyId22 : memref<256x256xbf16, 1>
} {sym_name = "segment_0"}
func.func @func15(%arg0: memref<512x1024xbf16>, %arg1: memref<1024x512xbf16>, %arg2: memref<512x512xbf16>) {
func.func @func16(%arg0: memref<512x1024xbf16>, %arg1: memref<1024x512xbf16>, %arg2: memref<512x512xbf16>) {
%c64_i64 = arith.constant 64 : i64
%c512_i64 = arith.constant 512 : i64
%c4_i64 = arith.constant 4 : i64
Expand Down Expand Up @@ -751,10 +795,10 @@ module {

// AIRRt alloc / dealloc.

// CHECK-LABEL: func.func @func16
// CHECK-LABEL: func.func @func17
// CHECK-NEXT: return
module {
func.func @func16() {
func.func @func17() {
%0 = airrt.alloc : memref<8x16xi32, 1 : i32>
%1 = airrt.alloc : memref<32x16xi32, 1 : i32>
%2 = airrt.alloc : memref<8x32xi32, 1 : i32>
Expand Down

0 comments on commit 0b3ba11

Please sign in to comment.