Skip to content

Commit

Permalink
add bf16 test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhczhong committed Oct 15, 2024
1 parent aecf9ab commit 4df3eb8
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions test/Transforms/BF16ToGPU/EltwiseAdd.bf16.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,43 @@ module @eltwise_add attributes {gpu.container_module} {
}
func.func private @printMemrefBF16(memref<*xbf16>)
}


module @eltwise_add_usm attributes {gpu.container_module} {
memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01>
func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
%c20 = arith.constant 20 : index
%c10 = arith.constant 10 : index
%c1 = arith.constant 1 : index
%memref_1 = gpu.alloc host_shared () : memref<10x20xbf16>
gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%arg0 : memref<10x20xbf16>, %arg1 : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>)
%alloc = memref.alloc() : memref<10x20xbf16>
memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16>
gpu.dealloc %memref_1 : memref<10x20xbf16>
return %alloc : memref<10x20xbf16>
}
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y
%cst = arith.constant 0.5 : bf16
%0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
%1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
%2 = arith.addf %0, %1 : bf16
%3 = arith.addf %2, %cst : bf16
memref.store %3, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
gpu.return
}
}
func.func @main() {
%0 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
%1 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
%2 = call @test(%0, %1) : (memref<10x20xbf16>, memref<10x20xbf16>) -> memref<10x20xbf16>
%cast = memref.cast %2 : memref<10x20xbf16> to memref<*xbf16>
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
// CHECK-COUNT-200: 1.5
call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
return
}
func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
}

0 comments on commit 4df3eb8

Please sign in to comment.