AIRSegmentLoopFusion: Fixups on affine::DelinearizeIndexOp and rank reduction (Xilinx#752)

* Support affine::DelinearizeIndexOp in scf::ForOp's iv chain

* Post-shrinkage subview mutation now supports rank reduction

* Unit test

* Remove #include <iostream>
erwei-xilinx authored Oct 24, 2024
1 parent 0541d96 commit e7710d5
Showing 3 changed files with 92 additions and 15 deletions.
7 changes: 4 additions & 3 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -4559,9 +4559,10 @@ struct ShrinkMemrefSizesByAccessPattern
auto shrunkMemrefType =
MemRefType::get(overall_access_bounds, elemType, nullptr, memorySpace);
MemRefType inferredSubViewOutputTy =
llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
shrunkMemrefType, subViewOp.getStaticOffsets(),
subViewOp.getStaticSizes(), subViewOp.getStaticStrides()));
llvm::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
subViewOp.getType().getShape(), shrunkMemrefType,
subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
subViewOp.getStaticStrides()));
// Case 1: static size mismatches the shrunk shape.
for (unsigned i = 0; i < static_sizes.size(); i++) {
if (static_sizes[i] < 0) {
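The hunk above swaps memref::SubViewOp::inferResultType for inferRankReducedResultType, so a subview that drops unit dimensions keeps its reduced rank when its source memref is shrunk by the pass. Below is a minimal sketch of that recomputation, pulled out of the pattern for illustration only: the helper name recomputeShrunkSubViewType is hypothetical, and subViewOp / shrunkMemrefType are assumed to come from the surrounding ShrinkMemrefSizesByAccessPattern code, as in the diff.

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

// Hypothetical helper mirroring the change above: re-infer the result type of
// an existing subview against a shrunk source memref type, preserving whatever
// rank reduction the original subview performed.
static MemRefType recomputeShrunkSubViewType(memref::SubViewOp subViewOp,
                                             MemRefType shrunkMemrefType) {
  // Passing the current result shape tells the inference which unit
  // dimensions the subview drops, so the new type stays rank-reduced instead
  // of being forced back to the full rank of the shrunk source.
  return llvm::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
      subViewOp.getType().getShape(), shrunkMemrefType,
      subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
      subViewOp.getStaticStrides()));
}
```

For the unit test below, this is what keeps %[[SUBVIEW1]] at memref<1x4xf32, ...> rather than widening it back to a rank-3 type.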
40 changes: 28 additions & 12 deletions mlir/lib/Util/Util.cpp
@@ -1242,28 +1242,44 @@ static void updateAccessPatternByScfForNest(
&pattern,
SmallVector<Value> indices, OpBuilder builder) {
auto loc = builder.getUnknownLoc();
auto updateWrapAndStride = [&](Value index, int i) {
if (auto scfForOp = scf::getForInductionVarOwner(index)) {
std::get<1>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
loc, *air::getStaticScfForTripCountAsInt(scfForOp));
std::get<2>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
loc, (*getConstantIntValue(scfForOp.getStep())) *
(*getConstantIntValue(std::get<2>(pattern)[i])));

scfForOp.getStep();
auto updateWrapAndStride = [&](int stepSize, int tripCount, int i) {
std::get<1>(pattern)[i] =
builder.create<arith::ConstantIndexOp>(loc, tripCount);
std::get<2>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
loc, stepSize * (*getConstantIntValue(std::get<2>(pattern)[i])));
};
// Infer data access pattern's sizes from parent scf.for loop and any affine
// op applied on the induction variable
auto inferDataAccessSizes = [](scf::ForOp scfForOp, air::ExecuteOp execOp,
Value index) {
int scfForTripCount = *air::getStaticScfForTripCountAsInt(scfForOp);
// If scf.for's iv applies affine::DelinearizeIndexOp
if (auto delinearizeOp =
dyn_cast<affine::AffineDelinearizeIndexOp>(execOp.getChildOp())) {
int resIdx =
llvm::find(execOp.getResults(), index) - execOp.getResults().begin();
scfForTripCount = *getConstantIntValue(delinearizeOp.getBasis()[resIdx]);
}
return scfForTripCount;
};
int dim = -1;
for (auto index : indices) {
dim++;
if (getConstantIntValue(index))
continue;
updateWrapAndStride(index, dim);
if (auto scfForOp = scf::getForInductionVarOwner(index))
updateWrapAndStride(*getConstantIntValue(scfForOp.getStep()),
*air::getStaticScfForTripCountAsInt(scfForOp), dim);
if (!index.getDefiningOp())
continue;
if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp()))
if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp())) {
for (auto oper : execOp.getChildOp()->getOperands())
updateWrapAndStride(oper, dim);
if (auto scfForOp = scf::getForInductionVarOwner(oper)) {
int scfForTripCount = inferDataAccessSizes(scfForOp, execOp, index);
updateWrapAndStride(*getConstantIntValue(scfForOp.getStep()),
scfForTripCount, dim);
}
}
}
}

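The new inferDataAccessSizes lambda above narrows the wrap for an index that reaches the access pattern through affine.delinearize_index: instead of the loop's full trip count, each delinearized result only wraps around its entry of the delinearization basis. The following is a standalone sketch of the same logic with a worked example in the comments; the free-function form and the includes are for illustration only, the air header path and namespace usings are assumed to match the surrounding Util.cpp, and the unchecked dereferences assume static loop bounds and a constant basis, as the pass itself does.

```cpp
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "air/Util/Util.h" // assumed path for air::getStaticScfForTripCountAsInt

using namespace mlir;
using namespace xilinx; // assumed, so air:: resolves as in Util.cpp

// Sketch of the trip-count inference added in this commit. In the unit test
// below, an scf.for with 256 iterations feeds affine.delinearize_index with a
// (16, 16) basis, so each derived index wraps at 16 rather than 256.
static int inferDataAccessSize(scf::ForOp scfForOp, air::ExecuteOp execOp,
                               Value index) {
  int tripCount = *air::getStaticScfForTripCountAsInt(scfForOp);
  // If the induction variable is delinearized inside the air.execute region,
  // the wrap for this particular result is the matching basis entry.
  if (auto delinearizeOp =
          dyn_cast<affine::AffineDelinearizeIndexOp>(execOp.getChildOp())) {
    int resIdx =
        llvm::find(execOp.getResults(), index) - execOp.getResults().begin();
    tripCount = *getConstantIntValue(delinearizeOp.getBasis()[resIdx]);
  }
  return tripCount;
}
```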
@@ -934,3 +934,63 @@ func.func @func10(%arg0: memref<8x512xi32>, %arg1: memref<256x512xi32>, %arg2: m
}
return
}

// Affine::DelinearizeIndexOp support; rank-reduced memref::SubViewOp.

// CHECK-LABEL: func.func @func11
// CHECK: air.herd
// CHECK: %[[SUBVIEW0:.*]] = memref.subview{{.*}} : memref<16x16x4x4xf32, 1 : i32> to memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>
// CHECK: %[[SUBVIEW1:.*]] = memref.subview{{.*}} : memref<1x16x4xf32, 2 : i32> to memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>
// CHECK: %[[SUBVIEW2:.*]] = memref.subview{{.*}} : memref<1x1x16x16x4x4xbf16, 2 : i32> to memref<1x1x4x4xbf16, strided<[4096, 4096, 4, 1], offset: ?>, 2 : i32>
// CHECK: linalg.generic{{.*}} ins(%[[SUBVIEW0]], %[[SUBVIEW1]] {{.*}}outs(%[[SUBVIEW2]]

#map17 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
func.func @func11(%arg0: memref<512x512xbf16>, %arg1: memref<512x16384xbf16>, %arg2: memref<512xf32>, %arg3: memref<512x16384xbf16>) {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%0 = air.launch async (%arg4, %arg5) in (%arg6=%c4, %arg7=%c128) attributes {id = 1 : i32} {
%1 = air.segment @matmul_elementwise_bf16_dispatch_0_matmul_512x16384x512_bf16xbf16xf32_0 async attributes {id = 2 : i32} {
%c2 = arith.constant 2 : index
%async_token, %results = air.execute -> (memref<2x2x16x16x4x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<2x2x16x16x4x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<2x2x16x16x4x4xbf16, 2 : i32>
}
%async_token_0, %results_1 = air.execute -> (memref<1x16x4xf32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x16x4xf32, 2 : i32>
air.execute_terminator %alloc : memref<1x16x4xf32, 2 : i32>
}
%async_token_2, %results_3 = air.execute -> (memref<16x16x4x4xf32, 1 : i32>) {
%alloc = memref.alloc() : memref<16x16x4x4xf32, 1 : i32>
air.execute_terminator %alloc : memref<16x16x4x4xf32, 1 : i32>
}
%2 = air.herd @herd_0 async tile (%arg8, %arg9) in (%arg10=%c2, %arg11=%c2) args(%arg12=%results_3, %arg13=%results_1, %arg14=%results) : memref<16x16x4x4xf32, 1 : i32>, memref<1x16x4xf32, 2 : i32>, memref<2x2x16x16x4x4xbf16, 2 : i32> {
%c16 = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c256 = arith.constant 256 : index
%3 = air.wait_all async
%4 = scf.for %arg15 = %c0 to %c256 step %c1 iter_args(%arg16 = %3) -> (!air.async.token) {
%async_token_4, %results_5:2 = air.execute [%arg16] -> (index, index) {
%6:2 = affine.delinearize_index %arg15 into (%c16, %c16) : index, index
air.execute_terminator %6#0, %6#1 : index, index
}
%subview = memref.subview %arg12[%results_5#0, %results_5#1, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<16x16x4x4xf32, 1 : i32> to memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>
%subview_6 = memref.subview %arg13[0, %results_5#1, 0] [1, 1, 4] [1, 1, 1] : memref<1x16x4xf32, 2 : i32> to memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>
%subview_7 = memref.subview %arg14[%arg8, %arg9, %results_5#0, %results_5#1, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x16x16x4x4xbf16, 2 : i32> to memref<1x1x4x4xbf16, strided<[256, 16, 4, 1], offset: ?>, 2 : i32>
%async_token_8 = air.execute [%arg16] {
linalg.generic {indexing_maps = [#map17, #map18, #map17], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview_6 : memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>, memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>) outs(%subview_7 : memref<1x1x4x4xbf16, strided<[256, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: f32, %in_9: f32, %out: bf16):
%6 = arith.addf %in, %in_9 : f32
%7 = arith.truncf %6 : f32 to bf16
linalg.yield %7 : bf16
}
}
%5 = air.wait_all async [%async_token_4, %async_token_8]
scf.yield %5 : !air.async.token
}
}
}
}
return
}
