Skip to content

Commit

Permalink
Add bf16 support in tile / memtile DMA BDs (Xilinx#465)
Browse files Browse the repository at this point in the history
* Add bf16 support in tile / memtile DMA BDs

* Clang format
  • Loading branch information
erwei-xilinx authored Mar 1, 2024
1 parent 9859ffe commit a9bc6ed
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 6 deletions.
3 changes: 2 additions & 1 deletion mlir/include/air/Conversion/AIRToAIESchedulingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ int getRepeatCount(Operation *memcpy_op);

std::vector<AIE::BDDimLayoutAttr>
getWrapsAndStrides(SmallVector<Value> memcpy_sizes,
SmallVector<Value> memcpy_strides, MLIRContext *ctx);
SmallVector<Value> memcpy_strides, int byte_count_per_elem,
MLIRContext *ctx);

bool isDefaultDataAccessPattern(SmallVector<Value> memcpy_sizes,
SmallVector<Value> memcpy_strides,
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/AIRToAIEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2442,9 +2442,9 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
: AIE::LockAction::Acquire,
lockAqValue);

std::vector<AIE::BDDimLayoutAttr> dims =
getWrapsAndStrides(sizes, strides, ndcpy->getContext());

std::vector<AIE::BDDimLayoutAttr> dims = getWrapsAndStrides(
sizes, strides, getElementSizeInBytes(memref.getType()),
ndcpy->getContext());
auto wraps_and_strides =
AIE::BDDimLayoutArrayAttr::get(ndcpy->getContext(), ArrayRef(dims));
bool useDefaultDataAccessPattern =
Expand Down
14 changes: 12 additions & 2 deletions mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ int air::getRepeatCount(Operation *memcpy_op) {

std::vector<AIE::BDDimLayoutAttr>
air::getWrapsAndStrides(SmallVector<Value> memcpy_sizes,
SmallVector<Value> memcpy_strides, MLIRContext *ctx) {
SmallVector<Value> memcpy_strides,
int byte_count_per_elem, MLIRContext *ctx) {
assert(byte_count_per_elem == 4 || byte_count_per_elem == 2 ||
byte_count_per_elem == 1 && "unsupported data format");
int div_factor = mlir::ceilDiv(4, byte_count_per_elem);
if (memcpy_sizes.empty() || memcpy_strides.empty())
return std::vector<AIE::BDDimLayoutAttr>{};
assert(memcpy_sizes.size() == memcpy_strides.size() &&
Expand All @@ -207,9 +211,15 @@ air::getWrapsAndStrides(SmallVector<Value> memcpy_sizes,
for (unsigned i = 0; i < memcpy_sizes.size(); i++) {
auto stepsize = mlir::getConstantIntValue(memcpy_strides[i]);
assert(stepsize && "non-static stride");
int stepsize_v = *stepsize;
auto wrap = mlir::getConstantIntValue(memcpy_sizes[i]);
assert(wrap && "non-static wrap");
auto tuple = AIE::BDDimLayoutAttr::get(ctx, *wrap, *stepsize);
int wrap_v = *wrap;
if (i < memcpy_sizes.size() - 1)
stepsize_v /= div_factor;
if (i == memcpy_sizes.size() - 1)
wrap_v /= div_factor;
auto tuple = AIE::BDDimLayoutAttr::get(ctx, wrap_v, stepsize_v);
output.push_back(tuple);
}
return output;
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ std::string air::getElementTypeAsString(const mlir::Type ty) {

// An incomplete lookup table of common data types
uint64_t air::getElementSizeInBytes(const mlir::Type ty) {
if (auto memrefTy = ty.cast<MemRefType>()) {
return memrefTy.getElementTypeBitWidth() / 8;
}
auto typeAsString = getElementTypeAsString(ty);
if (typeAsString == "i32")
return 4;
Expand Down
67 changes: 67 additions & 0 deletions mlir/test/Conversion/AIRToAIE/air_shimcpy_to_aie2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,70 @@ func.func @func10(%arg0: memref<128xf32>, %arg1: memref<128xf32>) {
return
}

// -----

// Bf16 datatype support.
// CHECK: aie.device(xcve2802)
// CHECK: %[[tileDMA_0_4:.*]] = aie.mem
// CHECK: aie.dma_start(S2MM, 0, ^bb1, ^bb2)
// CHECK: aie.dma_bd({{.*}} : memref<32x256xbf16, 2>, 16384, 8192, [<size = 8, stride = 16>, <size = 32, stride = 128>, <size = 16, stride = 1>])
// CHECK: %[[tileDMA_0_3:.*]] = aie.mem
// CHECK: aie.dma_start(S2MM, 0, ^bb1, ^bb2)
// CHECK: aie.dma_bd({{.*}} : memref<32x256xbf16, 2>, 16384, 8192, [<size = 8, stride = 16>, <size = 32, stride = 128>, <size = 16, stride = 1>])
// CHECK: %[[memTileDMA_2_1:.*]] = aie.memtile_dma
// CHECK: aie.dma_start(MM2S, 0, ^bb1, ^bb3)
// CHECK: memref<32x256xbf16, 1>, 0, 65536, [<size = 8, stride = 16>, <size = 32, stride = 128>, <size = 128, stride = 1>])
// CHECK: aie.dma_start(MM2S, 1, ^bb4, ^bb2)
// CHECK: memref<32x256xbf16, 1>, 0, 65536, [<size = 8, stride = 16>, <size = 32, stride = 128>, <size = 128, stride = 1>])

#map = affine_map<()[s0] -> (s0 * 32)>
air.channel @channel_1 [2, 1]
func.func @func10(%arg0: memref<128xbf16>, %arg1: memref<128xbf16>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg2) in (%arg3=%c2) attributes {id = 1 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 1 : i64, y_loc = 3 : i64, y_size = 2 : i64} {
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2_0 = arith.constant 2 : index
%c8 = arith.constant 8 : index
%async_token, %results = air.execute -> (memref<32x256xbf16, 1>) {
%alloc = memref.alloc() : memref<32x256xbf16, 1>
air.execute_terminator %alloc : memref<32x256xbf16, 1>
}
%2 = scf.parallel (%arg4) = (%c0) to (%c2_0) step (%c1) init (%async_token) -> !air.async.token {
%4 = air.channel.put async [%async_token] @channel_1[%arg4, %c0] (%results[%c0, %c0, %c0] [%c8, %c32, %c256] [%c32, %c256, %c1]) {id = 4 : i32} : (memref<32x256xbf16, 1>)
scf.reduce(%4 : !air.async.token) {
^bb0(%arg5: !air.async.token, %arg6: !air.async.token):
%5 = air.wait_all async [%arg5, %arg6]
scf.reduce.return %5 : !air.async.token
}
}
%3 = air.herd @herd_0 async [%async_token] tile (%arg4, %arg5) in (%arg6=%c1, %arg7=%c2_0) attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 3 : i64} {
%c0_2 = arith.constant 0 : index
%c1_4 = arith.constant 1 : index
%c32_3 = arith.constant 32 : index
%c256_5 = arith.constant 256 : index
%c8_6 = arith.constant 8 : index
%4 = air.wait_all async
%async_token_3, %results_4 = air.execute -> (memref<32x256xbf16, 2>) {
%alloc = memref.alloc() : memref<32x256xbf16, 2>
air.execute_terminator %alloc : memref<32x256xbf16, 2>
}
%5 = air.channel.get async [%4, %async_token_3] @channel_1[%arg5, %c0_2] (%results_4[%c0_2, %c32_3, %c0_2] [%c8_6, %c32_3, %c32_3] [%c32_3, %c256_5, %c1_4]) {id = 6 : i32} : (memref<32x256xbf16, 2>)
%async_token_5 = air.execute [%5] {
memref.dealloc %results_4 : memref<32x256xbf16, 2>
}
air.herd_terminator
}
%async_token_1 = air.execute [%3] {
memref.dealloc %results : memref<32x256xbf16, 1>
}
air.segment_terminator
}
air.launch_terminator
}
return
}

0 comments on commit a9bc6ed

Please sign in to comment.