Skip to content

Commit

Permalink
Fixup a minor issue with BufferMemrefToFuncArgs function which fails …
Browse files Browse the repository at this point in the history
…with bf16 (Xilinx#504)
  • Loading branch information
erwei-xilinx authored Mar 18, 2024
1 parent c73fd4b commit b83a3e4
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
7 changes: 3 additions & 4 deletions mlir/lib/Conversion/AIRRtToIpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,10 +1152,9 @@ struct AIRRtToIpuPass : public impl::AIRRtToIpuBase<AIRRtToIpuPass> {
memref = cast.getOperand(0);
}
// push back if unique
if (std::find(memrefs.begin(), memrefs.end(), dma.getMemref()) ==
memrefs.end()) {
memrefs.push_back(dma.getMemref());
memrefTypes.push_back(dma.getMemref().getType());
if (std::find(memrefs.begin(), memrefs.end(), memref) == memrefs.end()) {
memrefs.push_back(memref);
memrefTypes.push_back(memref.getType());
}
});

Expand Down
90 changes: 90 additions & 0 deletions mlir/test/Conversion/AIRRtToIpu/buffer_memref_to_args.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,93 @@ module {
return
}
}

// -----

// Bf16 datatype support.

// CHECK-LABEL: aie.device(ipu)
// CHECK: func.func @func2(%[[VAL_0:.*]]: memref<2097152xi32>, %[[VAL_1:.*]]: memref<2097152xi32>, %[[VAL_2:.*]]: memref<2097152xi32>) {
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][4, 8, 128, 128][0, 128, 1024]) {id = 0 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32>

module {
aie.device(ipu) {
aie.shim_dma_allocation @airMemcpyId26(S2MM, 0, 0)
memref.global "public" @airMemcpyId26 : memref<128x128xbf16, 1 : i32>
aie.shim_dma_allocation @airMemcpyId4(MM2S, 0, 0)
memref.global "public" @airMemcpyId4 : memref<128x256xbf16, 1 : i32>
aie.shim_dma_allocation @airMemcpyId10(MM2S, 0, 0)
memref.global "public" @airMemcpyId10 : memref<128x256xbf16, 1 : i32>
aie.shim_dma_allocation @airMemcpyId7(MM2S, 1, 0)
memref.global "public" @airMemcpyId7 : memref<256x128xbf16, 1 : i32>
aie.shim_dma_allocation @airMemcpyId13(MM2S, 1, 0)
memref.global "public" @airMemcpyId13 : memref<256x128xbf16, 1 : i32>
} {sym_name = "segment_0"}
func.func @func2() {
%c128_i64 = arith.constant 128 : i64
%c8_i64 = arith.constant 8 : i64
%c2048_i64 = arith.constant 2048 : i64
%c256_i64 = arith.constant 256 : i64
%c26_i32 = arith.constant 26 : i32
%c7_i32 = arith.constant 7 : i32
%c4_i32 = arith.constant 4 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%c0 = arith.constant 0 : index
%0 = memref.alloc() : memref<2048x2048xbf16>
%1 = airrt.wait_all : !airrt.event
airrt.wait_all %1
memref.assume_alignment %0, 64 : memref<2048x2048xbf16>
%2 = airrt.wait_all : !airrt.event
%3 = memref.alloc() : memref<2048x2048xbf16>
%4 = airrt.wait_all : !airrt.event
airrt.wait_all %4
memref.assume_alignment %3, 64 : memref<2048x2048xbf16>
%5 = airrt.wait_all : !airrt.event
%6 = memref.alloc() : memref<2048x2048xbf16>
%7 = airrt.wait_all : !airrt.event
airrt.wait_all %7
memref.assume_alignment %6, 64 : memref<2048x2048xbf16>
%8 = airrt.wait_all : !airrt.event
%9 = airrt.wait_all %8, %5, %2 : !airrt.event
affine.for %arg0 = 0 to 4 {
affine.for %arg1 = 0 to 4 {
%10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg0]
%11 = airrt.wait_all : !airrt.event
%12 = airrt.wait_all %11 : !airrt.event
%13 = arith.index_cast %arg0 : index to i64
%14 = arith.index_cast %arg1 : index to i64
%15 = arith.index_cast %10 : index to i64
%16 = airrt.dma_memcpy_nd(%c4_i32, %13, %14, %0[%c0_i64, %c0_i64, %15, %c0_i64], [%c1_i64, %c8_i64, %c128_i64, %c256_i64], [%c0_i64, %c256_i64, %c2048_i64]) {metadata = @airMemcpyId10} : (i32, i64, i64, memref<2048x2048xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
%17 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg1]
%18 = airrt.wait_all : !airrt.event
%19 = airrt.wait_all %18 : !airrt.event
%20 = arith.index_cast %arg0 : index to i64
%21 = arith.index_cast %arg1 : index to i64
%22 = arith.index_cast %17 : index to i64
%23 = airrt.dma_memcpy_nd(%c7_i32, %20, %21, %3[%c0_i64, %c0_i64, %c0_i64, %22], [%c1_i64, %c1_i64, %c2048_i64, %c128_i64], [%c0_i64, %c0_i64, %c2048_i64]) {metadata = @airMemcpyId13} : (i32, i64, i64, memref<2048x2048xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
%24 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg0]
%25 = airrt.wait_all : !airrt.event
%26 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg1]
%27 = airrt.wait_all : !airrt.event
%28 = airrt.wait_all %27, %25 : !airrt.event
%29 = arith.index_cast %arg0 : index to i64
%30 = arith.index_cast %arg1 : index to i64
%31 = arith.index_cast %24 : index to i64
%32 = arith.index_cast %26 : index to i64
%33 = airrt.dma_memcpy_nd(%c26_i32, %29, %30, %6[%c0_i64, %c0_i64, %31, %32], [%c1_i64, %c1_i64, %c128_i64, %c128_i64], [%c0_i64, %c0_i64, %c2048_i64]) {metadata = @airMemcpyId26} : (i32, i64, i64, memref<2048x2048xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
%p = airrt.segment_load "segment_0" : i64
%34 = airrt.wait_all : !airrt.event
}
}
return
}
}

0 comments on commit b83a3e4

Please sign in to comment.