From a8471a1d9a5591202bf4a552aa7d1bf11088fdce Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 31 Dec 2024 09:52:51 -0800 Subject: [PATCH] Fix `getInheritanceInfo` for `ExtractExistentialType`. (#5971) --- source/slang/slang-check-impl.h | 4 +- source/slang/slang-check-inheritance.cpp | 21 +++++----- source/slang/slang-ir-legalize-types.cpp | 4 +- source/slang/slang-syntax.cpp | 20 ++++++---- .../language-feature/interfaces/gh-5900.slang | 40 +++++++++++++++++++ 5 files changed, 66 insertions(+), 23 deletions(-) create mode 100644 tests/language-feature/interfaces/gh-5900.slang diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f464f92985..300596caa9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -789,12 +789,12 @@ struct SharedSemanticsContext : public RefObject InheritanceInfo _getInheritanceInfo( DeclRef declRef, - DeclRefType* correspondingType, + Type* selfType, InheritanceCircularityInfo* circularityInfo); InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo); InheritanceInfo _calcInheritanceInfo( DeclRef declRef, - DeclRefType* correspondingType, + Type* selfType, InheritanceCircularityInfo* circularityInfo); void getDependentGenericParentImpl(DeclRef& genericParent, DeclRef declRef); diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 4b0ec0f557..f774aae383 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -76,7 +76,7 @@ bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType( InheritanceInfo SharedSemanticsContext::_getInheritanceInfo( DeclRef declRef, - DeclRefType* declRefType, + Type* selfType, InheritanceCircularityInfo* circularityInfo) { // Just as with `Type`s, we cache and re-use the inheritance @@ -95,7 +95,7 @@ InheritanceInfo SharedSemanticsContext::_getInheritanceInfo( // m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo(); - auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo); + auto info = _calcInheritanceInfo(declRef, selfType, circularityInfo); m_mapDeclRefToInheritanceInfo[declRef] = info; getSession()->m_typeDictionarySize = Math::Max( @@ -154,7 +154,7 @@ DeclRef SharedSemanticsContext::getDependentGenericParent(DeclRef declRef, - DeclRefType* declRefType, + Type* selfType, InheritanceCircularityInfo* circularityInfo) { // This method is the main engine for computing linearized inheritance @@ -200,14 +200,6 @@ InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo( // FacetList::Builder allFacets; - // It is possible that `declRef` is itself a type declaration, - // in which case `declRefType` will be the coresponding type. - // However, if `declRef` is an `extension` declaration, we - // will extract the type that the extension applies to, so - // that we can have a consistent "self type" to represent - // the type that is at the root of the inheritance list. - // - Type* selfType = declRefType; Facet::Kind selfFacetKind = Facet::Kind::Type; auto astBuilder = _getASTBuilder(); @@ -1043,6 +1035,13 @@ InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo( // return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); } + else if (auto extractExistentialType = as(type)) + { + return _getInheritanceInfo( + extractExistentialType->getThisTypeDeclRef(), + extractExistentialType, + circularityInfo); + } else if (auto conjunctionType = as(type)) { // In this case, we have a type of the form `L & R`, diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 9154277f50..962514b088 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -837,7 +837,7 @@ static LegalVal legalizeDebugVar( case LegalType::Flavor::simple: { auto legalVal = context->builder->emitDebugVar( - type.getSimple(), + tryGetPointedToType(context->builder, type.getSimple()), originalInst->getSource(), originalInst->getLine(), originalInst->getCol(), @@ -887,7 +887,7 @@ static LegalVal legalizeDebugValue( { auto ordinaryVal = legalizeDebugValue( context, - debugVar, + debugVar.getPair()->ordinaryVal, debugValue.getPair()->ordinaryVal, originalInst); return ordinaryVal; diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 5dc6ca695f..6fdd5088a6 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -701,11 +701,18 @@ Type* DeclRefType::create(ASTBuilder* astBuilder, DeclRef declRef) } return declRefType; } - else if (as(declRef.getDecl()) && as(declRef.declRefBase)) + else if (as(declRef.getDecl())) { - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + if (as(declRef.declRefBase)) + { + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); - return astBuilder->getOrCreate(declRef.declRefBase); + return astBuilder->getOrCreate(declRef.declRefBase); + } + else if (auto lookupDeclRef = as(declRef.declRefBase)) + { + return lookupDeclRef->getWitness()->getSub(); + } } else if (auto typedefDecl = as(declRef.getDecl())) { @@ -714,12 +721,9 @@ Type* DeclRefType::create(ASTBuilder* astBuilder, DeclRef declRef) typedefDecl->type.type->substitute(astBuilder, SubstitutionSet(declRef))); return astBuilder->getErrorType(); } - else - { - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); - return astBuilder->getOrCreate(declRef.declRefBase); - } + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + return astBuilder->getOrCreate(declRef.declRefBase); } // diff --git a/tests/language-feature/interfaces/gh-5900.slang b/tests/language-feature/interfaces/gh-5900.slang new file mode 100644 index 0000000000..996347b41a --- /dev/null +++ b/tests/language-feature/interfaces/gh-5900.slang @@ -0,0 +1,40 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type + +interface IFoo +{ + float get(); +} + +extension FooType { + float load() + { + return get(); + } +} + +struct Foo : IFoo +{ + RWStructuredBuffer buffer; + int dummy; + + float get() { return buffer[0]; } +} + +float bugTest(IFoo t) +{ + return t.load(); +} + +//TEST_INPUT: set input = new Foo { ubuffer(data=[1.0 0 0 0], stride=4), 0 } +ConstantBuffer input; + +//TEST_INPUT: set output = out ubuffer(data=[0], stride=4) +uniform RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + // CHECK: 1.0 + output[0] = bugTest(input); +} \ No newline at end of file