Skip to content

Commit

Permalink
Fixup missing support for 1D scf.parallel ops in -air-to-std (Xilinx#435
Browse files Browse the repository at this point in the history
)

* Fixup missing support for 1D scf.parallel ops

* Add test

* EoF
  • Loading branch information
erwei-xilinx authored Feb 19, 2024
1 parent d748a07 commit 3dc53e8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
11 changes: 9 additions & 2 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,15 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
} else {
opers.push_back(builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), launch.getInductionVars()[0]));
opers.push_back(builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), launch.getInductionVars()[1]));
if (launch.getNumLoops() == 2)
opers.push_back(builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), launch.getInductionVars()[1]));
else if (launch.getNumLoops() == 1)
opers.push_back(builder.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 0)));
else
assert(false && "lowering of air.launch with more than 2 dimensions is "
"currently unsupported");
}

opers.push_back(thisOp.getMemref());
Expand Down
46 changes: 46 additions & 0 deletions mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,49 @@ func.func @par_with_for_put_get(%arg0: memref<32x16xi32>, %arg1: memref<32x16xi3
}
return
}

// CHECK-LABEL: func.func @one_d_scf_parallel
// CHECK: affine.for
// CHECK: airrt.dma_memcpy_nd(%{{.*}}, %{{.*}}, %{{.*}}, %arg0[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}]) : (i32, i64, i64, memref<128xf32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
// CHECK: } {affine_opt_label = "tiling"}

#map = affine_map<()[s0] -> (s0 * 64)>
air.channel @channel_6 [1, 1]
func.func @one_d_scf_parallel(%arg0: memref<128xf32>, %arg1: memref<128xf32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg2) in (%arg3=%c2) args(%arg4=%arg0) : memref<128xf32> attributes {id = 1 : i32} {
%c64 = arith.constant 64 : index
%c1 = arith.constant 1 : index
%async_token, %results = air.execute -> (index) {
%3 = affine.apply #map()[%arg2]
air.execute_terminator %3 : index
}
%1 = air.channel.put async [%async_token] @channel_6[] (%arg4[%results] [%c64] [%c1]) {id = 1 : i32} : (memref<128xf32>)
%2 = air.segment @segment_0 async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 1 : i64, y_loc = 2 : i64, y_size = 4 : i64} {
%c1_0 = arith.constant 1 : index
%c2_1 = arith.constant 2 : index
%3 = air.wait_all async
%async_token_2, %results_3 = air.execute -> (memref<64xf32, 1>) {
%alloc = memref.alloc() : memref<64xf32, 1>
air.execute_terminator %alloc : memref<64xf32, 1>
}
%4 = air.channel.get async [%3, %async_token_2] @channel_6[] (%results_3[] [] []) {id = 3 : i32} : (memref<64xf32, 1>)
%5 = air.herd @herd_0 async [%4] tile (%arg5, %arg6) in (%arg7=%c1_0, %arg8=%c2_1) attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 2 : i64} {
%async_token_5, %results_6 = air.execute -> (memref<32xf32, 2>) {
%alloc = memref.alloc() : memref<32xf32, 2>
air.execute_terminator %alloc : memref<32xf32, 2>
}
%async_token_7 = air.execute [%async_token_5] {
memref.dealloc %results_6 : memref<32xf32, 2>
}
air.herd_terminator
}
%async_token_4 = air.execute [%4] {
memref.dealloc %results_3 : memref<64xf32, 1>
}
air.segment_terminator
}
air.launch_terminator
}
return
}

0 comments on commit 3dc53e8

Please sign in to comment.