Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argument buffer tier2 layout computation. #6101

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions source/slang/slang-type-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ struct MetalLayoutRulesImpl : public CPULayoutRulesImpl
auto alignedElementCount = 1 << Math::Log2Ceil((uint32_t)elementCount);

// Metal aligns vectors to 2/4 element boundaries.
size_t size = elementSize * elementCount;
size_t size = alignedElementCount * elementSize;
size_t alignment = alignedElementCount * elementSize;

SimpleLayoutInfo vectorInfo;
Expand Down Expand Up @@ -1147,6 +1147,14 @@ struct MetalLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
LayoutRulesImpl* getStructuredBufferRules(CompilerOptionSet& compilerOptions) override;
};

struct MetalArgumentBufferTier2LayoutRulesFamilyImpl : MetalLayoutRulesFamilyImpl
{
virtual LayoutRulesImpl* getConstantBufferRules(
CompilerOptionSet& compilerOptions,
Type* containerType) override;
virtual LayoutRulesImpl* getParameterBlockRules(CompilerOptionSet& compilerOptions) override;
};

struct WGSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
{
virtual LayoutRulesImpl* getAnyValueRules() override;
Expand Down Expand Up @@ -1175,6 +1183,7 @@ HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl;
CPULayoutRulesFamilyImpl kCPULayoutRulesFamilyImpl;
CUDALayoutRulesFamilyImpl kCUDALayoutRulesFamilyImpl;
MetalLayoutRulesFamilyImpl kMetalLayoutRulesFamilyImpl;
MetalArgumentBufferTier2LayoutRulesFamilyImpl kMetalArgumentBufferTier2LayoutRulesFamilyImpl;
WGSLLayoutRulesFamilyImpl kWGSLLayoutRulesFamilyImpl;

// CPU case
Expand Down Expand Up @@ -1969,8 +1978,44 @@ struct MetalArgumentBufferElementLayoutRulesImpl : ObjectLayoutRulesImpl, Defaul
}
};

struct MetalTier2ObjectLayoutRulesImpl : ObjectLayoutRulesImpl
{
virtual ObjectLayoutInfo GetObjectLayout(ShaderParameterKind kind, const Options& /* options */)
override
{
switch (kind)
{
case ShaderParameterKind::ConstantBuffer:
case ShaderParameterKind::ParameterBlock:
case ShaderParameterKind::StructuredBuffer:
case ShaderParameterKind::MutableStructuredBuffer:
case ShaderParameterKind::RawBuffer:
case ShaderParameterKind::Buffer:
case ShaderParameterKind::MutableRawBuffer:
case ShaderParameterKind::MutableBuffer:
case ShaderParameterKind::ShaderStorageBuffer:
case ShaderParameterKind::AccelerationStructure:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8);
case ShaderParameterKind::AppendConsumeStructuredBuffer:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 16, 8);
case ShaderParameterKind::MutableTexture:
case ShaderParameterKind::TextureUniformBuffer:
case ShaderParameterKind::Texture:
case ShaderParameterKind::SamplerState:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8);
case ShaderParameterKind::TextureSampler:
case ShaderParameterKind::MutableTextureSampler:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 16, 8);
default:
SLANG_UNEXPECTED("unhandled shader parameter kind");
UNREACHABLE_RETURN(SimpleLayoutInfo());
}
}
};

static MetalObjectLayoutRulesImpl kMetalObjectLayoutRulesImpl;
static MetalArgumentBufferElementLayoutRulesImpl kMetalArgumentBufferElementLayoutRulesImpl;
static MetalTier2ObjectLayoutRulesImpl kMetalTier2ObjectLayoutRulesImpl;
static MetalLayoutRulesImpl kMetalLayoutRulesImpl;

LayoutRulesImpl kMetalAnyValueLayoutRulesImpl_ = {
Expand All @@ -1991,6 +2036,18 @@ LayoutRulesImpl kMetalParameterBlockLayoutRulesImpl_ = {
&kMetalArgumentBufferElementLayoutRulesImpl,
};

LayoutRulesImpl kMetalTier2ConstantBufferLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
&kMetalTier2ObjectLayoutRulesImpl,
};

LayoutRulesImpl kMetalTier2ParameterBlockLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
&kMetalTier2ObjectLayoutRulesImpl,
};

LayoutRulesImpl kMetalStructuredBufferLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
Expand Down Expand Up @@ -2079,6 +2136,20 @@ LayoutRulesImpl* MetalLayoutRulesFamilyImpl::getHitAttributesParameterRules()
return nullptr;
}

LayoutRulesImpl* MetalArgumentBufferTier2LayoutRulesFamilyImpl::getConstantBufferRules(
CompilerOptionSet&,
Type*)
{
return &kMetalTier2ConstantBufferLayoutRulesImpl_;
}

LayoutRulesImpl* MetalArgumentBufferTier2LayoutRulesFamilyImpl::getParameterBlockRules(
CompilerOptionSet&)
{
return &kMetalTier2ParameterBlockLayoutRulesImpl_;
}


// WGSL Family

LayoutRulesImpl kWGSLConstantBufferLayoutRulesImpl_ = {
Expand Down Expand Up @@ -2229,7 +2300,7 @@ TypeLayoutContext getInitialLayoutContextForTarget(
rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq);
break;
case slang::LayoutRules::MetalArgumentBufferTier2:
rulesFamily = &kCPULayoutRulesFamilyImpl;
rulesFamily = &kMetalArgumentBufferTier2LayoutRulesFamilyImpl;
break;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// unit-test-argument-buffer-tier-2-reflection.cpp

#include "../../source/core/slang-io.h"
#include "../../source/core/slang-process.h"
#include "slang-com-ptr.h"
#include "slang.h"
#include "unit-test/slang-unit-test.h"

#include <stdio.h>
#include <stdlib.h>

using namespace Slang;

// Test metal argument buffer tier2 layout rules.

SLANG_UNIT_TEST(metalArgumentBufferTier2Reflection)
{
const char* userSourceBody = R"(
struct A
{
float3 one;
float3 two;
float three;
}

struct Args{
ParameterBlock<A> a;
}
ParameterBlock<Args> argument_buffer;
RWStructuredBuffer<float> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
outputBuffer[0] = argument_buffer.a.two.x;
}
)";

auto moduleName = "moduleG" + String(Process::getId());
String userSource = "import " + moduleName + ";\n" + userSourceBody;
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_SPIRV;
targetDesc.profile = globalSession->findProfile("spirv_1_5");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"m",
"m.slang",
userSourceBody,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

auto layout = module->getLayout();

auto type = layout->findTypeByName("A");
auto typeLayout = layout->getTypeLayout(type, slang::LayoutRules::MetalArgumentBufferTier2);
SLANG_CHECK(typeLayout->getFieldByIndex(0)->getOffset() == 0);
SLANG_CHECK(typeLayout->getFieldByIndex(0)->getTypeLayout()->getSize() == 16);
SLANG_CHECK(typeLayout->getFieldByIndex(1)->getOffset() == 16);
SLANG_CHECK(typeLayout->getFieldByIndex(1)->getTypeLayout()->getSize() == 16);
SLANG_CHECK(typeLayout->getFieldByIndex(2)->getOffset() == 32);
SLANG_CHECK(typeLayout->getFieldByIndex(2)->getTypeLayout()->getSize() == 4);
}
Loading