Skip to content

Commit

Permalink
Fix getInheritanceInfo for ExtractExistentialType. (#5971)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Dec 31, 2024
1 parent b7eb585 commit a8471a1
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 23 deletions.
4 changes: 2 additions & 2 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,12 @@ struct SharedSemanticsContext : public RefObject

InheritanceInfo _getInheritanceInfo(
DeclRef<Decl> declRef,
DeclRefType* correspondingType,
Type* selfType,
InheritanceCircularityInfo* circularityInfo);
InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo);
InheritanceInfo _calcInheritanceInfo(
DeclRef<Decl> declRef,
DeclRefType* correspondingType,
Type* selfType,
InheritanceCircularityInfo* circularityInfo);

void getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef);
Expand Down
21 changes: 10 additions & 11 deletions source/slang/slang-check-inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType(

InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(
DeclRef<Decl> declRef,
DeclRefType* declRefType,
Type* selfType,
InheritanceCircularityInfo* circularityInfo)
{
// Just as with `Type`s, we cache and re-use the inheritance
Expand All @@ -95,7 +95,7 @@ InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(
//
m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo();

auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo);
auto info = _calcInheritanceInfo(declRef, selfType, circularityInfo);
m_mapDeclRefToInheritanceInfo[declRef] = info;

getSession()->m_typeDictionarySize = Math::Max(
Expand Down Expand Up @@ -154,7 +154,7 @@ DeclRef<GenericDecl> SharedSemanticsContext::getDependentGenericParent(DeclRef<D

InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(
DeclRef<Decl> declRef,
DeclRefType* declRefType,
Type* selfType,
InheritanceCircularityInfo* circularityInfo)
{
// This method is the main engine for computing linearized inheritance
Expand Down Expand Up @@ -200,14 +200,6 @@ InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(
//
FacetList::Builder allFacets;

// It is possible that `declRef` is itself a type declaration,
// in which case `declRefType` will be the coresponding type.
// However, if `declRef` is an `extension` declaration, we
// will extract the type that the extension applies to, so
// that we can have a consistent "self type" to represent
// the type that is at the root of the inheritance list.
//
Type* selfType = declRefType;
Facet::Kind selfFacetKind = Facet::Kind::Type;

auto astBuilder = _getASTBuilder();
Expand Down Expand Up @@ -1043,6 +1035,13 @@ InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(
//
return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo);
}
else if (auto extractExistentialType = as<ExtractExistentialType>(type))
{
return _getInheritanceInfo(
extractExistentialType->getThisTypeDeclRef(),
extractExistentialType,
circularityInfo);
}
else if (auto conjunctionType = as<AndType>(type))
{
// In this case, we have a type of the form `L & R`,
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-ir-legalize-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ static LegalVal legalizeDebugVar(
case LegalType::Flavor::simple:
{
auto legalVal = context->builder->emitDebugVar(
type.getSimple(),
tryGetPointedToType(context->builder, type.getSimple()),
originalInst->getSource(),
originalInst->getLine(),
originalInst->getCol(),
Expand Down Expand Up @@ -887,7 +887,7 @@ static LegalVal legalizeDebugValue(
{
auto ordinaryVal = legalizeDebugValue(
context,
debugVar,
debugVar.getPair()->ordinaryVal,
debugValue.getPair()->ordinaryVal,
originalInst);
return ordinaryVal;
Expand Down
20 changes: 12 additions & 8 deletions source/slang/slang-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,11 +701,18 @@ Type* DeclRefType::create(ASTBuilder* astBuilder, DeclRef<Decl> declRef)
}
return declRefType;
}
else if (as<ThisTypeDecl>(declRef.getDecl()) && as<DirectDeclRef>(declRef.declRefBase))
else if (as<ThisTypeDecl>(declRef.getDecl()))
{
declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
if (as<DirectDeclRef>(declRef.declRefBase))
{
declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);

return astBuilder->getOrCreate<ThisType>(declRef.declRefBase);
return astBuilder->getOrCreate<ThisType>(declRef.declRefBase);
}
else if (auto lookupDeclRef = as<LookupDeclRef>(declRef.declRefBase))
{
return lookupDeclRef->getWitness()->getSub();
}
}
else if (auto typedefDecl = as<TypeDefDecl>(declRef.getDecl()))
{
Expand All @@ -714,12 +721,9 @@ Type* DeclRefType::create(ASTBuilder* astBuilder, DeclRef<Decl> declRef)
typedefDecl->type.type->substitute(astBuilder, SubstitutionSet(declRef)));
return astBuilder->getErrorType();
}
else
{
declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);

return astBuilder->getOrCreate<DeclRefType>(declRef.declRefBase);
}
declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
return astBuilder->getOrCreate<DeclRefType>(declRef.declRefBase);
}

//
Expand Down
40 changes: 40 additions & 0 deletions tests/language-feature/interfaces/gh-5900.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type

interface IFoo
{
float get();
}

extension<FooType : IFoo> FooType {
float load()
{
return get();
}
}

struct Foo : IFoo
{
RWStructuredBuffer<float> buffer;
int dummy;

float get() { return buffer[0]; }
}

float bugTest(IFoo t)
{
return t.load();
}

//TEST_INPUT: set input = new Foo { ubuffer(data=[1.0 0 0 0], stride=4), 0 }
ConstantBuffer<Foo> input;

//TEST_INPUT: set output = out ubuffer(data=[0], stride=4)
uniform RWStructuredBuffer<float> output;

[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
// CHECK: 1.0
output[0] = bugTest(input);
}

0 comments on commit a8471a1

Please sign in to comment.