From c947bf3b7305c8630799380e94042bd9e9737e88 Mon Sep 17 00:00:00 2001 From: Tuomas Tonteri Date: Tue, 25 Jun 2024 11:59:00 +0300 Subject: [PATCH] Add support for b4_SSE2 batched mode (23) Signed-off-by: Tuomas Tonteri --- src/liboslexec/llvm_util.cpp | 60 ++++++++++++++----- .../oslbatcheddeformer.cpp | 10 +++- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/src/liboslexec/llvm_util.cpp b/src/liboslexec/llvm_util.cpp index 807f8ab3b..044f7a270 100644 --- a/src/liboslexec/llvm_util.cpp +++ b/src/liboslexec/llvm_util.cpp @@ -3854,13 +3854,36 @@ LLVM_Util::mask_as_int8(llvm::Value* mask) llvm::Value* LLVM_Util::mask4_as_int8(llvm::Value* mask) { - OSL_ASSERT(m_supports_llvm_bit_masks_natively); - // combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it - // to i8 - llvm::Value* zero_mask4 - = llvm::ConstantDataVector::getSplat(4, constant_bool(false)); - return builder().CreateBitCast(op_combine_4x_vectors(mask, zero_mask4), - type_int8()); + if (m_supports_llvm_bit_masks_natively) { + // combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it + // to i8 + llvm::Value* zero_mask4 + = llvm::ConstantDataVector::getSplat(4, constant_bool(false)); + return builder().CreateBitCast(op_combine_4x_vectors(mask, zero_mask4), + type_int8()); + } else { + // Convert <4 x i1> -> <4 x i32> + llvm::Value* wide_int_mask = builder().CreateSExt(mask, + type_wide_int()); + + // Now we will use the horizontal sign extraction intrinsic + // to build a 32 bit mask value. 
However the only 128bit + // version works on floats, so we will cast from int32 to + // float beforehand + llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4); + llvm::Value* w4_float_mask = builder().CreateBitCast(wide_int_mask, + w4_float_type); + + llvm::Function* func = llvm::Intrinsic::getDeclaration( + module(), llvm::Intrinsic::x86_sse_movmsk_ps); + + llvm::Value* args[1] = { w4_float_mask }; + llvm::Value* int32 = builder().CreateCall(func, toArrayRef(args)); + + llvm::Value* i8 = builder().CreateIntCast(int32, type_int8(), true); + + return i8; + } } @@ -4013,17 +4036,22 @@ LLVM_Util::op_1st_active_lane_of(llvm::Value* mask) intMaskType = type_int8(); break; case 4: { - // We can just reinterpret cast a 4 bit mask to a 8 bit integer - // and all types are happy intMaskType = type_int8(); - // extended_int_vector_type = (llvm::Type *) llvm::VectorType::get(llvm::Type::getInt32Ty (*m_llvm_context), m_vector_width); - // llvm::Value * wide_int_mask = builder().CreateSExt(mask, extended_int_vector_type); - // - // int_reinterpret_cast_vector_type = (llvm::Type *) llvm::Type::getInt128Ty (*m_llvm_context); - // zeroConstant = constant128(0); - // - // llvm::Value * mask_as_int = builder().CreateBitCast (wide_int_mask, int_reinterpret_cast_vector_type); + llvm::Value* mask_as_int = mask4_as_int8(mask); + + // Count trailing zeros, least significant + llvm::Type* types[] = { intMaskType }; + llvm::Function* func_cttz + = llvm::Intrinsic::getDeclaration(module(), llvm::Intrinsic::cttz, + toArrayRef(types)); + + llvm::Value* args[2] = { mask_as_int, constant_bool(true) }; + + llvm::Value* firstNonZeroIndex = builder().CreateCall(func_cttz, + toArrayRef(args)); + return firstNonZeroIndex; + break; } default: OSL_ASSERT(0 && "unsupported native bit mask width"); diff --git a/testsuite/example-batched-deformer/oslbatcheddeformer.cpp b/testsuite/example-batched-deformer/oslbatcheddeformer.cpp index fe620fefa..aaaccd710 100644 --- 
a/testsuite/example-batched-deformer/oslbatcheddeformer.cpp +++ b/testsuite/example-batched-deformer/oslbatcheddeformer.cpp @@ -237,11 +237,13 @@ main(int argc, char* argv[]) batch_width = 16; } else if (shadsys->configure_batch_execution_at(8)) { batch_width = 8; + } else if (shadsys->configure_batch_execution_at(4)) { + batch_width = 4; } else { std::cout - << "Error: Hardware doesn't support 8 or 16 wide SIMD or the OSL has not been configured and built with a proper USE_BATCHED." + << "Error: Hardware doesn't support 4, 8 or 16 wide SIMD or the OSL has not been configured and built with a proper USE_BATCHED." << std::endl; - std::cout << "Error: e.g.: USE_BATCHED=b8_AVX2,b8_AVX512,b16_AVX512" + std::cout << "Error: e.g.: USE_BATCHED=b4_SSE2,b8_AVX2,b8_AVX512,b16_AVX512" << std::endl; return -1; } @@ -437,8 +439,10 @@ main(int argc, char* argv[]) if (batch_width == 16) { batched_shadepoints(std::integral_constant {}); - } else { + } else if (batch_width == 8) { batched_shadepoints(std::integral_constant {}); + } else { + batched_shadepoints(std::integral_constant {}); } // Print some results to prove that we generated an expected Pout.