Skip to content

Commit

Permalink
Add support for b4_SSE2 batched mode (24)
Browse files Browse the repository at this point in the history
Signed-off-by: Tuomas Tonteri <[email protected]>
  • Loading branch information
johnfea committed Jun 25, 2024
1 parent f4d033c commit 6551e79
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
60 changes: 44 additions & 16 deletions src/liboslexec/llvm_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3854,13 +3854,36 @@ LLVM_Util::mask_as_int8(llvm::Value* mask)
/// Pack a 4-lane boolean mask (<4 x i1>) into a single i8 value.
///
/// Two code paths are generated depending on whether the target handles
/// LLVM bit masks natively:
///   * native:      widen the mask to <8 x i1> by appending four false
///                  lanes, then bitcast the 8 bits to an i8.
///   * non-native:  sign-extend each lane to 32 bits and gather the four
///                  sign bits with the SSE `movmskps` intrinsic.
///
/// @param mask  The <4 x i1> mask to convert.
/// @return      An i8 value built from the four mask lanes.
llvm::Value*
LLVM_Util::mask4_as_int8(llvm::Value* mask)
{
    if (m_supports_llvm_bit_masks_natively) {
        // Combine the <4 x i1> mask with a <4 x i1> zero vector to get an
        // <8 x i1>, then reinterpret those 8 bits as an i8.
        llvm::Value* zero_mask4
            = llvm::ConstantDataVector::getSplat(4, constant_bool(false));
        return builder().CreateBitCast(op_combine_4x_vectors(mask, zero_mask4),
                                       type_int8());
    }

    // No native bit masks (e.g. the SSE2 batched mode): build the integer
    // mask from the per-lane sign bits instead.

    // Convert <4 x i1> -> <4 x i32>; sign extension turns every active lane
    // into all ones so that its sign bit is set.
    // NOTE(review): assumes type_wide_int() is <4 x i32> here (vector width
    // of 4); the bitcast below requires matching bit widths -- confirm.
    llvm::Value* wide_int_mask = builder().CreateSExt(mask, type_wide_int());

    // The horizontal sign-extraction intrinsic only exists for floats at
    // 128 bits, so reinterpret the int32 lanes as floats beforehand.
    llvm::Type* w4_float_type  = llvm_vector_type(m_llvm_type_float, 4);
    llvm::Value* w4_float_mask = builder().CreateBitCast(wide_int_mask,
                                                         w4_float_type);

    llvm::Function* func = llvm::Intrinsic::getDeclaration(
        module(), llvm::Intrinsic::x86_sse_movmsk_ps);

    llvm::Value* args[1] = { w4_float_mask };
    llvm::Value* int32   = builder().CreateCall(func, toArrayRef(args));

    // movmskps only populates the low 4 bits of its i32 result, so
    // narrowing to i8 loses nothing.
    return builder().CreateIntCast(int32, type_int8(), true);
}


Expand Down Expand Up @@ -4013,17 +4036,22 @@ LLVM_Util::op_1st_active_lane_of(llvm::Value* mask)
intMaskType = type_int8();
break;
case 4: {
// We can just reinterpret cast a 4 bit mask to a 8 bit integer
// and all types are happy
intMaskType = type_int8();

// extended_int_vector_type = (llvm::Type *) llvm::VectorType::get(llvm::Type::getInt32Ty (*m_llvm_context), m_vector_width);
// llvm::Value * wide_int_mask = builder().CreateSExt(mask, extended_int_vector_type);
//
// int_reinterpret_cast_vector_type = (llvm::Type *) llvm::Type::getInt128Ty (*m_llvm_context);
// zeroConstant = constant128(0);
//
// llvm::Value * mask_as_int = builder().CreateBitCast (wide_int_mask, int_reinterpret_cast_vector_type);
llvm::Value* mask_as_int = mask4_as_int8(mask);

// Count trailing zeros, least significant
llvm::Type* types[] = { intMaskType };
llvm::Function* func_cttz
= llvm::Intrinsic::getDeclaration(module(), llvm::Intrinsic::cttz,
toArrayRef(types));

llvm::Value* args[2] = { mask_as_int, constant_bool(true) };

llvm::Value* firstNonZeroIndex = builder().CreateCall(func_cttz,
toArrayRef(args));
return firstNonZeroIndex;

break;
}
default: OSL_ASSERT(0 && "unsupported native bit mask width");
Expand Down
11 changes: 8 additions & 3 deletions testsuite/example-batched-deformer/oslbatcheddeformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,13 @@ main(int argc, char* argv[])
batch_width = 16;
} else if (shadsys->configure_batch_execution_at(8)) {
batch_width = 8;
} else if (shadsys->configure_batch_execution_at(4)) {
batch_width = 4;
} else {
std::cout
<< "Error: Hardware doesn't support 8 or 16 wide SIMD or the OSL has not been configured and built with a proper USE_BATCHED."
<< "Error: Hardware doesn't support 4, 8 or 16 wide SIMD or the OSL has not been configured and built with a proper USE_BATCHED."
<< std::endl;
std::cout << "Error: e.g.: USE_BATCHED=b8_AVX2,b8_AVX512,b16_AVX512"
std::cout << "Error: e.g.: USE_BATCHED=b4_SSE2,b8_AVX2,b8_AVX512,b16_AVX512"
<< std::endl;
return -1;
}
Expand Down Expand Up @@ -437,8 +439,11 @@ main(int argc, char* argv[])

if (batch_width == 16) {
batched_shadepoints(std::integral_constant<int, 16> {});
} else {
}
else if (batch_width == 8) {
batched_shadepoints(std::integral_constant<int, 8> {});
} else {
batched_shadepoints(std::integral_constant<int, 4> {});
}

// Print some results to prove that we generated an expected Pout.
Expand Down

0 comments on commit 6551e79

Please sign in to comment.