diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 3b36b270b7..068a240bbf 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1,4 +1,5 @@ // slang-check-overload.cpp +#include "slang-ast-base.h" #include "slang-check-impl.h" #include "slang-lookup.h" @@ -1199,6 +1200,28 @@ namespace Slang return parent; } + void countDistanceToGloablScope(DeclRef const& leftDecl, + DeclRef const& rightDecl, + int& leftDistance, int& rightDistance) + { + leftDistance = 0; + rightDistance = 0; + + DeclRef decl = leftDecl; + while(decl) + { + leftDistance++; + decl = decl.getParent(); + } + + decl = rightDecl; + while(decl) + { + rightDistance++; + decl = decl.getParent(); + } + } + // Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. // int SemanticsVisitor::CompareLookupResultItems( @@ -1324,6 +1347,24 @@ namespace Slang } } + // We need to consider the distance of the declarations to the global scope to resolve this case: + // float f(float x); + // struct S + // { + // float f(float x); + // float g(float y) { return f(y); } // will call S::f() instead of ::f() + // } + // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope, + // because this function will only choose the valid candidates. So if there is situation like this: + // void main() { S s; s.f(1.0);} or + // struct T { float g(y) { f(y); } }, there won't be ambiguity. + // So we just need to count which declaration is farther from the global scope and favor the farther one. + int leftDistance = 0; + int rightDistance = 0; + countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance); + if (leftDistance != rightDistance) + return leftDistance > rightDistance ? -1 : 1; + // TODO: We should generalize above rules such that in a tie a declaration // A::m is better than B::m when all other factors are equal and // A inherits from B. diff --git a/tests/bugs/overload-ambiguous.slang b/tests/bugs/overload-ambiguous.slang new file mode 100644 index 0000000000..1b74cb68c2 --- /dev/null +++ b/tests/bugs/overload-ambiguous.slang @@ -0,0 +1,48 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + + +uint getData() +{ + return 1u; +} + +struct DataObtainer +{ + uint data; + uint getData() + { + return data; + } + + uint getValue() + { + return getData(); // will call DataObtainer::getData() + } + + uint getValue2() + { + return ::getData(); // will call global getData() + } +} + +RWStructuredBuffer output; + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + DataObtainer obtainer = {2u}; + outputBuffer[0] = obtainer.getValue(); + outputBuffer[1] = obtainer.getValue2(); + // BUF: 2 + // BUF-NEXT: 1 +}