Skip to content

Commit

Permalink
Create DirectDeclRef when creating Decl to prevent invalid dedup. (#5945
Browse files Browse the repository at this point in the history
)

* Create DirectDeclRef when creating Decl to prevent invalid dedup.

* Fix test.

* fix

* update slang-rhi
  • Loading branch information
csyonghe authored Jan 3, 2025
1 parent 5df3a74 commit 114c976
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 13 deletions.
10 changes: 1 addition & 9 deletions source/slang/slang-ast-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,7 @@ void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
}
DeclRefBase* Decl::getDefaultDeclRef()
{
if (auto astBuilder = getCurrentASTBuilder())
{
const Index currentEpoch = astBuilder->getEpoch();
if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef)
{
m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this);
m_defaultDeclRefEpoch = currentEpoch;
}
}
SLANG_ASSERT(m_defaultDeclRef);
return m_defaultDeclRef;
}

Expand Down
1 change: 0 additions & 1 deletion source/slang/slang-ast-base.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,6 @@ class Decl : public DeclBase

private:
SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1;
};

class Expr : public SyntaxNode
Expand Down
5 changes: 4 additions & 1 deletion source/slang/slang-ast-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,10 @@ class ASTBuilder : public RefObject
auto val = (Val*)(node);
val->m_resolvedValEpoch = getEpoch();
}

else if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType)))
{
((Decl*)node)->m_defaultDeclRef = getOrCreate<DirectDeclRef>((Decl*)node);
}
return node;
}

Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4598,7 +4598,7 @@ void Module::_processFindDeclsExportSymbolsRec(Decl* decl)
if (_canExportDeclSymbol(decl->astNodeType))
{
// It's a reference to a declaration in another module, so first get the symbol name.
String mangledName = getMangledName(getASTBuilder(), decl);
String mangledName = getMangledName(getCurrentASTBuilder(), decl);

Index index = Index(m_mangledExportPool.add(mangledName));

Expand Down
3 changes: 2 additions & 1 deletion source/slang/slang.natvis
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@
<Type Name="Slang::Val" Inheritable="true">
<DisplayString Optional="true" Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType#{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
<DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
<DisplayString Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">DirectRef {*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
<DisplayString Optional="true" Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">DirectRef#{_debugUID} {*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
<DisplayString Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">DirectRef {*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
<DisplayString Optional="true">{astNodeType,en} #{_debugUID}</DisplayString>
<DisplayString>{astNodeType,en}</DisplayString>

Expand Down
92 changes: 92 additions & 0 deletions tools/slang-unit-test/unit-test-module-ptr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// unit-test-module-ptr.cpp

#include "core/slang-memory-file-system.h"
#include "slang-com-ptr.h"
#include "slang.h"
#include "unit-test/slang-unit-test.h"

#include <stdio.h>
#include <stdlib.h>

using namespace Slang;

SLANG_UNIT_TEST(modulePtr)
{
const char* testModuleSource = R"(
module test_module;
public void atomicFunc(__ref Atomic<int> ptr) {
ptr.add(1);
}
)";

const char* testSource = R"(
import "test_module";
RWStructuredBuffer<Atomic<int>> input0;
[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 workGroup : SV_GroupID)
{
atomicFunc(input0[0]);
}
)";
ComPtr<ISlangMutableFileSystem> memoryFileSystem =
ComPtr<ISlangMutableFileSystem>(new Slang::MemoryFileSystem());

ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_SPIRV;
targetDesc.profile = globalSession->findProfile("spirv_1_5");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
sessionDesc.compilerOptionEntryCount = 0;
sessionDesc.fileSystem = memoryFileSystem;

// Precompile test_module to file.
{
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"test_module",
"test_module.slang",
testModuleSource,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

ComPtr<slang::IBlob> moduleBlob;
module->serialize(moduleBlob.writeRef());
memoryFileSystem->saveFile(
"test_module.slang-module",
moduleBlob->getBufferPointer(),
moduleBlob->getBufferSize());
}

// compile test.
{
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"test",
"test.slang",
testSource,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

ComPtr<slang::IComponentType> linkedProgram;
module->link(linkedProgram.writeRef());

ComPtr<slang::IBlob> code;

linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef());

SLANG_CHECK(code->getBufferSize() > 0);
}
}

0 comments on commit 114c976

Please sign in to comment.