Skip to content

Commit

Permalink
Implement Packed Dot Product intrinsics (#6068)
Browse files Browse the repository at this point in the history
* implement dot acc intrinsics

* fix sm version

* fix test

* improve comment

---------

Co-authored-by: Yong He <[email protected]>
  • Loading branch information
fairywreath and csyonghe authored Jan 16, 2025
1 parent 9167e0d commit ad7d13a
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 6 deletions.
72 changes: 66 additions & 6 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -16760,20 +16760,80 @@ static const uint HIT_KIND_TRIANGLE_BACK_FACE = 255;

//
// Shader Model 6.4
// @public:
//

/// Treats `left` and `right` as 4-component vectors of `UInt8` and computes `dot(left, right) + acc`
/// Treats `x` and `y` as 4-component vectors of `UInt8` and computes `dot(x, y) + acc`
/// @category math
uint dot4add_u8packed(uint left, uint right, uint acc);
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)]
uint dot4add_u8packed(uint x, uint y, uint acc)
{
// `x` and `y` each carry four unsigned 8-bit lanes packed into one `uint`.
// Every target path below evaluates `dot(x, y) + acc` without saturation.
__target_switch
{
// HLSL: forward directly to the intrinsic of the same name.
case hlsl: __intrinsic_asm "dot4add_u8packed";
// WGSL: `dot4U8Packed` takes only the two packed operands, so `acc` ($2) is added explicitly.
case wgsl: __intrinsic_asm "(dot4U8Packed($0, $1) + $2)";
case spirv:
// OpUDotAccSat cannot be used as there should not be any saturation.
// Instead, a plain OpUDot is followed by a separate OpIAdd of `acc`.
// The trailing literal `0` on OpUDot is the Packed Vector Format operand
// (PackedVectorFormat4x8Bit in the SPIR-V specification).
return spirv_asm
{
OpCapability DotProduct;
OpCapability DotProductInput4x8BitPacked;
OpExtension "SPV_KHR_integer_dot_product";
%dotResult = OpUDot $$uint $x $y 0;
result:$$uint = OpIAdd %dotResult $acc;
};
default:
// Remaining targets: unpack both operands into `uint4` vectors and use the generic `dot`.
uint4 vecX = unpack_u8u32(uint8_t4_packed(x));
uint4 vecY = unpack_u8u32(uint8_t4_packed(y));
return dot(vecX, vecY) + acc;
}
}

/// Treats `left` and `right` as 4-component vectors of `Int8` and computes `dot(left, right) + acc`
/// Treats `x` and `y` as 4-component vectors of `Int8` and computes `dot(x, y) + acc`
/// @category math
int dot4add_i8packed(uint left, uint right, int acc);
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)]
int dot4add_i8packed(uint x, uint y, int acc)
{
// `x` and `y` each carry four signed 8-bit lanes packed into one `uint`.
// Every target path below evaluates `dot(x, y) + acc` without saturation.
__target_switch
{
// HLSL: forward directly to the intrinsic of the same name.
case hlsl: __intrinsic_asm "dot4add_i8packed";
// WGSL: `dot4I8Packed` takes only the two packed operands, so `acc` ($2) is added explicitly.
case wgsl: __intrinsic_asm "(dot4I8Packed($0, $1) + $2)";
case spirv:
// OpSDotAccSat cannot be used as there should not be any saturation.
// Instead, a plain OpSDot is followed by a separate OpIAdd of `acc`.
// The trailing literal `0` on OpSDot is the Packed Vector Format operand
// (PackedVectorFormat4x8Bit in the SPIR-V specification).
return spirv_asm
{
OpCapability DotProduct;
OpCapability DotProductInput4x8BitPacked;
OpExtension "SPV_KHR_integer_dot_product";
%dotResult = OpSDot $$int $x $y 0;
result:$$int = OpIAdd %dotResult $acc;
};
default:
// Remaining targets: unpack both operands into sign-extended `int4` vectors and use the generic `dot`.
int4 vecX = unpack_s8s32(int8_t4_packed(x));
int4 vecY = unpack_s8s32(int8_t4_packed(y));
return dot(vecX, vecY) + acc;
}
}

/// Computes `dot(left, right) + acc`.
/// Computes `dot(x, y) + acc`.
/// May not produce infinities or NaNs for intermediate results that overflow the range of `half`
/// @category math
float dot2add(float2 left, float2 right, float acc);
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)]
float dot2add(half2 x, half2 y, float acc)
{
// Two-component half-precision dot product accumulated into a `float`.
__target_switch
{
// HLSL: forward directly to the intrinsic of the same name.
case hlsl: __intrinsic_asm "dot2add";
default:
// All other targets: take the dot product of the `half2` operands, convert the
// result to `float`, then add `acc`.
// NOTE(review): `dot` is evaluated on the `half2` operands before the conversion to
// `float`, so the intermediate product/sum is presumably subject to `half` range
// limits on these targets - confirm whether the operands should be widened first.
return float(dot(x, y)) + acc;
}
}

//
// Shader Model 6.5
Expand Down
55 changes: 55 additions & 0 deletions tests/hlsl-intrinsic/dot-accumulate.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
// Does not run on DX11 as SM 6.4 is required.
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx11
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile cs_6_4 -use-dxil -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-metal -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-wgsl -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -g0 -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;

// Exercises the three dot-accumulate intrinsics once each and writes every result,
// converted to `int`, into consecutive slots of `outputBuffer`.
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
uint outputIndex = 0;

//
// dot4add_u8packed()
// Lanes are listed from the least significant byte upwards.
// [4 3 2 1] dot [1 2 4 2] + 5
// (4 * 1) + (3 * 2) + (2 * 4) + (1 * 2) + 5 = 25
//
uint unsignedX = 0x01020304U;
uint unsignedY = 0x02040201U;
uint unsignedAcc = 5U;
uint unsignedResult = dot4add_u8packed(unsignedX, unsignedY, unsignedAcc);
outputBuffer[outputIndex++] = unsignedResult;

//
// dot4add_i8packed()
// Lanes are listed from the least significant byte upwards; 0xFF, 0xFE and 0xFA are
// the 8-bit two's-complement encodings of -1, -2 and -6.
// [6 2 3 -1] dot [-2 -6 2 6] - 100
// (6 * -2) + (2 * -6) + (3 * 2) + (-1 * 6) - 100 = -124
//
int signedX = 0xFF030206;
int signedY = 0x0602FAFE;
int signedAcc = -100;
int signedResult = dot4add_i8packed(signedX, signedY, signedAcc);
outputBuffer[outputIndex++] = signedResult;

//
// dot2add()
// [10.8 -3.3] dot [1.4 -20.3] - 2.0
// (10.8 * 1.4) + (-3.3 * -20.3) - 2.0 = 15.12 + 66.99 - 2.0 = 80.11
// The result is converted to `int` before being stored, so the expected output is 80.
//
half2 half2X = half2(half(10.8), half(-3.3));
half2 half2Y = half2(half(1.4), half(-20.3));

// `half2Acc` is assigned -2.0 here.
// Thread index is used so that `half2Acc` will not be implicitly emitted as literal `-2.0` which
// may be treated as a double by DXC and cause it to fail to compile because no overload exists for `dot2add` that
// accepts double.
float half2Acc = float(dispatchThreadID.x + 1) * -2.0f;
float half2Result = dot2add(half2X, half2Y, half2Acc);
outputBuffer[outputIndex++] = int(half2Result);
}
4 changes: 4 additions & 0 deletions tests/hlsl-intrinsic/dot-accumulate.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
type: int32_t
25
-124
80

0 comments on commit ad7d13a

Please sign in to comment.