diff --git a/source/core/slang-list.h b/source/core/slang-list.h index d27afd4153..7c96e38445 100644 --- a/source/core/slang-list.h +++ b/source/core/slang-list.h @@ -537,7 +537,7 @@ class List } } - inline void swapElements(T* vals, Index index1, Index index2) + inline static void swapElements(T* vals, Index index1, Index index2) { if (index1 != index2) { @@ -547,6 +547,8 @@ class List } } + inline void swapElements(Index index1, Index index2) { swapElements(m_buffer, index1, index2); } + template Index binarySearch(const T2& obj, Comparer comparer) const { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c9497ce545..3fea267ee9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6541,7 +6541,25 @@ bool SemanticsVisitor::findWitnessForInterfaceRequirement( } } } - + if (lookupResult.isOverloaded()) + { + // If we found multiple members with the same name, + // we want to move the declarations in the same parent as inheritanceDecl + // to the front of the list, so that we always consider them first instead of + // the members declared in other extension decls. + // + Index front = 0; + auto parentOfInheritanceDecl = getParentAggTypeDeclBase(inheritanceDecl); + for (Index i = 0; i < lookupResult.items.getCount(); i++) + { + if (getParentAggTypeDeclBase(lookupResult.items[i].declRef.getDecl()) == + parentOfInheritanceDecl) + { + lookupResult.items.swapElements(i, front); + front++; + } + } + } // Iterate over the members and look for one that matches // the expected signature for the requirement. for (auto member : lookupResult) diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 1d3763299a..5dc6ca695f 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1051,6 +1051,18 @@ Decl* getParentAggTypeDecl(Decl* decl) return nullptr; } +Decl* getParentAggTypeDeclBase(Decl* decl) +{ + decl = decl->parentDecl; + while (decl) + { + if (as(decl)) + return decl; + decl = decl->parentDecl; + } + return nullptr; +} + Decl* getParentFunc(Decl* decl) { while (decl) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 5e31b54479..accc490f2b 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -370,6 +370,7 @@ Module* getModule(Decl* decl); /// Get the parent decl, skipping any generic decls in between. Decl* getParentDecl(Decl* decl); Decl* getParentAggTypeDecl(Decl* decl); +Decl* getParentAggTypeDeclBase(Decl* decl); Decl* getParentFunc(Decl* decl); } // namespace Slang diff --git a/tests/language-feature/interfaces/overloaded-associatedtype.slang b/tests/language-feature/interfaces/overloaded-associatedtype.slang new file mode 100644 index 0000000000..4630c76a60 --- /dev/null +++ b/tests/language-feature/interfaces/overloaded-associatedtype.slang @@ -0,0 +1,43 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type + +interface IFoo { + associatedtype Output : IBar; + func foo(other: T) -> Output; +} + +interface IBar { int getId(); } + +struct Ant:IBar { int getId() { return 0; } }; +struct Bat:IBar { int getId() { return 1; } }; +struct Cat:IBar { int getId() { return 2; } }; +struct Dog:IBar { int getId() { return 3; } }; +struct Ewe:IBar { int getId() { return 4; } }; +struct Fox:IBar { int getId() { return 5; } }; +struct Gnu:IBar { int getId() { return 6; } }; + +extension Ant: IFoo { + typedef Cat Output; + func foo(other: Bat) -> Cat { return Cat(); } +} +extension Ant: IFoo { + typedef Ewe Output; + func foo(other: Dog) -> Ewe { return Ewe(); } +} +extension Ant: IFoo { + typedef Gnu Output; + func foo(other: Fox) -> Gnu { return Gnu(); } +} + +int test>(T v) { + return v.foo(Fox()).getId(); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=output +RWStructuredBuffer output; + +[numthreads(1,1,1)] +void computeMain() { + Ant a; + // CHECK: 6 + output[0] = test(a); +} \ No newline at end of file