diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 19e349807e..c3eb44cfd2 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -3011,6 +3011,11 @@ namespace Slang // because there is no way for overload resolution to pick between them. if (fstParam.getDecl()->HasModifier() != sndParam.getDecl()->HasModifier()) return false; + + // If one parameter is `ref` and the other isn't, then they don't match. + // + if(fstParam.getDecl()->HasModifier() != sndParam.getDecl()->HasModifier()) + return false; } // Note(tfoley): return type doesn't enter into it, because we can't take @@ -7046,8 +7051,15 @@ namespace Slang for (UInt pp = 0; pp < paramCount; ++pp) { auto paramType = funcType->getParamType(pp); - if (auto outParamType = paramType->As()) + if (paramType->As() || paramType->As()) { + // `out`, `inout`, and `ref` parameters currently require + // an *exact* match on the type of the argument. + // + // TODO: relax this requirement by allowing an argument + // for an `inout` parameter to be converted in both + // directions. + // if( pp < expr->Arguments.Count() ) { auto argExpr = expr->Arguments[pp]; diff --git a/source/slang/compiler.h b/source/slang/compiler.h index e7c40bdc8e..526168e3a6 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -520,6 +520,9 @@ namespace Slang // Construct the type `InOut` RefPtr getInOutType(RefPtr valueType); + // Construct the type `Ref` + RefPtr getRefType(RefPtr valueType); + // Construct a pointer type like `Ptr`, but where // the actual type name for the pointer type is given by `ptrTypeName` RefPtr getPtrType(RefPtr valueType, char const* ptrTypeName); diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 35ad77f4f3..00afde2acc 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -117,6 +117,12 @@ __intrinsic_type($(kIROp_InOutType)) struct InOut {}; +__generic +__magic_type(RefType) +__intrinsic_type($(kIROp_RefType)) +struct Ref +{}; + __magic_type(StringType) __intrinsic_type($(kIROp_StringType)) struct String diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h index bbb258d157..95c6ff0f7a 100644 --- a/source/slang/core.meta.slang.h +++ b/source/slang/core.meta.slang.h @@ -126,6 +126,15 @@ SLANG_RAW(")\n") SLANG_RAW("struct InOut\n") SLANG_RAW("{};\n") SLANG_RAW("\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(RefType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_RefType +) +SLANG_RAW(")\n") +SLANG_RAW("struct Ref\n") +SLANG_RAW("{};\n") +SLANG_RAW("\n") SLANG_RAW("__magic_type(StringType)\n") SLANG_RAW("__intrinsic_type(") SLANG_SPLICE(kIROp_StringType diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index f531ae8517..2bfed3d524 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -311,6 +311,8 @@ DIAGNOSTIC(40006, Error, needCompileTimeConstant, "expected a compile-time const DIAGNOSTIC(40007, Internal, irValidationFailed, "IR validation failed: $0") +DIAGNOSTIC(40008, Error, invalidLValueForRefParameter, "the form of this l-value argument is not valid for a `ref` parameter") + // 41000 - IR-level validation issues DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 3e601e119a..af2fa68dc9 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -4408,6 +4408,13 @@ struct EmitVisitor emit("inout "); type = inOutType->getValueType(); } + else if( auto refType = as(type)) + { + // Note: There is no HLSL/GLSL equivalent for by-reference parameters, + // so we don't actually expect to encounter these in user code. + emit("inout "); + type = inOutType->getValueType(); + } emitIRType(ctx, type, name); } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index bf235eb3c6..152e9faa11 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -617,32 +617,32 @@ __target_intrinsic(glsl, "groupMemoryBarrier()); (barrier()") void GroupMemoryBarrierWithGroupSync(); // Atomics -void InterlockedAdd(in out int dest, int value, out int original_value); -void InterlockedAdd(in out uint dest, uint value, out uint original_value); +void InterlockedAdd(__ref int dest, int value, out int original_value); +void InterlockedAdd(__ref uint dest, uint value, out uint original_value); -void InterlockedAnd(in out int dest, int value, out int original_value); -void InterlockedAnd(in out uint dest, uint value, out uint original_value); +void InterlockedAnd(__ref int dest, int value, out int original_value); +void InterlockedAnd(__ref uint dest, uint value, out uint original_value); -void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value); -void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value); +void InterlockedCompareExchange(__ref int dest, int compare_value, int value, out int original_value); +void InterlockedCompareExchange(__ref uint dest, uint compare_value, uint value, out uint original_value); -void InterlockedCompareStore(in out int dest, int compare_value, int value); -void InterlockedCompareStore(in out uint dest, uint compare_value, uint value); +void InterlockedCompareStore(__ref int dest, int compare_value, int value); +void InterlockedCompareStore(__ref uint dest, uint compare_value, uint value); -void InterlockedExchange(in out int dest, int value, out int original_value); -void InterlockedExchange(in out uint dest, uint value, out uint original_value); +void InterlockedExchange(__ref int dest, int value, out int original_value); +void InterlockedExchange(__ref uint dest, uint value, out uint original_value); -void InterlockedMax(in out int dest, int value, out int original_value); -void InterlockedMax(in out uint dest, uint value, out uint original_value); +void InterlockedMax(__ref int dest, int value, out int original_value); +void InterlockedMax(__ref uint dest, uint value, out uint original_value); void InterlockedMin(in out int dest, int value, out int original_value); void InterlockedMin(in out uint dest, uint value, out uint original_value); -void InterlockedOr(in out int dest, int value, out int original_value); -void InterlockedOr(in out uint dest, uint value, out uint original_value); +void InterlockedOr(__ref int dest, int value, out int original_value); +void InterlockedOr(__ref uint dest, uint value, out uint original_value); -void InterlockedXor(in out int dest, int value, out int original_value); -void InterlockedXor(in out uint dest, uint value, out uint original_value); +void InterlockedXor(__ref int dest, int value, out int original_value); +void InterlockedXor(__ref uint dest, uint value, out uint original_value); // Is floating-point value finite? __generic bool isfinite(T x); diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h index a1e40c37b1..ba4998b236 100644 --- a/source/slang/hlsl.meta.slang.h +++ b/source/slang/hlsl.meta.slang.h @@ -650,32 +650,32 @@ SLANG_RAW("__target_intrinsic(glsl, \"groupMemoryBarrier()); (barrier()\")\n") SLANG_RAW("void GroupMemoryBarrierWithGroupSync();\n") SLANG_RAW("\n") SLANG_RAW("// Atomics\n") -SLANG_RAW("void InterlockedAdd(in out int dest, int value, out int original_value);\n") -SLANG_RAW("void InterlockedAdd(in out uint dest, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedAdd(__ref int dest, int value, out int original_value);\n") +SLANG_RAW("void InterlockedAdd(__ref uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedAnd(in out int dest, int value, out int original_value);\n") -SLANG_RAW("void InterlockedAnd(in out uint dest, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedAnd(__ref int dest, int value, out int original_value);\n") +SLANG_RAW("void InterlockedAnd(__ref uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value);\n") -SLANG_RAW("void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedCompareExchange(__ref int dest, int compare_value, int value, out int original_value);\n") +SLANG_RAW("void InterlockedCompareExchange(__ref uint dest, uint compare_value, uint value, out uint original_value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedCompareStore(in out int dest, int compare_value, int value);\n") -SLANG_RAW("void InterlockedCompareStore(in out uint dest, uint compare_value, uint value);\n") +SLANG_RAW("void InterlockedCompareStore(__ref int dest, int compare_value, int value);\n") +SLANG_RAW("void InterlockedCompareStore(__ref uint dest, uint compare_value, uint value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedExchange(in out int dest, int value, out int original_value);\n") -SLANG_RAW("void InterlockedExchange(in out uint dest, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedExchange(__ref int dest, int value, out int original_value);\n") +SLANG_RAW("void InterlockedExchange(__ref uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedMax(in out int dest, int value, out int original_value);\n") -SLANG_RAW("void InterlockedMax(in out uint dest, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedMax(__ref int dest, int value, out int original_value);\n") +SLANG_RAW("void InterlockedMax(__ref uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") SLANG_RAW("void InterlockedMin(in out int dest, int value, out int original_value);\n") SLANG_RAW("void InterlockedMin(in out uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedOr(in out int dest, int value, out int original_value);\n") -SLANG_RAW("void InterlockedOr(in out uint dest, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedOr(__ref int dest, int value, out int original_value);\n") +SLANG_RAW("void InterlockedOr(__ref uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") -SLANG_RAW("void InterlockedXor(in out int dest, int value, out int original_value);\n") -SLANG_RAW("void InterlockedXor(in out uint dest, uint value, out uint original_value);\n") +SLANG_RAW("void InterlockedXor(__ref int dest, int value, out int original_value);\n") +SLANG_RAW("void InterlockedXor(__ref uint dest, uint value, out uint original_value);\n") SLANG_RAW("\n") SLANG_RAW("// Is floating-point value finite?\n") SLANG_RAW("__generic bool isfinite(T x);\n") diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 035defa536..0b7d15fcc6 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -65,6 +65,7 @@ INST(Nop, nop, 0, 0) /* PtrTypeBase */ INST(PtrType, Ptr, 1, 0) + INST(RefType, Ref, 1, 0) /* OutTypeBase */ INST(OutType, Out, 1, 0) INST(InOutType, InOut, 1, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index f35c391cf4..880856b36a 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -559,6 +559,7 @@ struct IRBuilder IRPtrType* getPtrType(IRType* valueType); IROutType* getOutType(IRType* valueType); IRInOutType* getInOutType(IRType* valueType); + IRRefType* getRefType(IRType* valueType); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); IRArrayTypeBase* getArrayTypeBase( diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 92801ec9ad..2a8885b0ea 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -1308,6 +1308,11 @@ namespace Slang return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); } + IRRefType* IRBuilder::getRefType(IRType* valueType) + { + return (IRRefType*) getPtrType(kIROp_RefType, valueType); + } + IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType) { IRInst* operands[] = { valueType }; diff --git a/source/slang/ir.h b/source/slang/ir.h index 91ca377f25..b23e26e5e1 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -828,6 +828,7 @@ struct IRPtrType : IRPtrTypeBase SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase) SIMPLE_IR_TYPE(OutType, OutTypeBase) SIMPLE_IR_TYPE(InOutType, OutTypeBase) +SIMPLE_IR_TYPE(RefType, OutTypeBase) struct IRFuncType : IRType { diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 81520abf58..96b28a15b0 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -922,6 +922,11 @@ void assign( LoweredValInfo const& left, LoweredValInfo const& right); +IRInst* getAddress( + IRGenContext* context, + LoweredValInfo const& inVal, + SourceLoc diagnosticLocation); + void lowerStmt( IRGenContext* context, Stmt* stmt); @@ -1668,7 +1673,24 @@ struct ExprLoweringVisitorBase : ExprVisitor // make a conscious decision at some point. } - if (paramDecl->HasModifier() + if(paramDecl->HasModifier()) + { + // A `ref` qualified parameter must be implemented with by-reference + // parameter passing, so the argument value should be lowered as + // an l-value. + // + LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); + + // According to our "calling convention" we need to + // pass a pointer into the callee. Unlike the case for + // `out` and `inout` below, it is never valid to do + // copy-in/copy-out for a `ref` parameter, so we just + // pass in the actual pointer. + // + IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc); + (*ioArgs).Add(argPtr); + } + else if (paramDecl->HasModifier() || paramDecl->HasModifier()) { // This is a `out` or `inout` parameter, and so @@ -2930,6 +2952,26 @@ static LoweredValInfo maybeMoveMutableTemp( } } +IRInst* getAddress( + IRGenContext* context, + LoweredValInfo const& inVal, + SourceLoc diagnosticLocation) +{ + LoweredValInfo val = inVal; + switch(val.flavor) + { + case LoweredValInfo::Flavor::Ptr: + return val.val; + + // TODO: are there other cases we need to handle here (e.g., + // turning a bound subscript/property into an address) + + default: + context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter); + return nullptr; + } +} + void assign( IRGenContext* context, LoweredValInfo const& inLeft, @@ -3831,9 +3873,10 @@ struct DeclLoweringVisitor : DeclVisitor // enum ParameterDirection { - kParameterDirection_In, - kParameterDirection_Out, - kParameterDirection_InOut, + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference }; struct ParameterInfo { @@ -3856,6 +3899,11 @@ struct DeclLoweringVisitor : DeclVisitor // ParameterDirection getParameterDirection(VarDeclBase* paramDecl) { + if( paramDecl->HasModifier() ) + { + // The AST specified `ref`: + return kParameterDirection_Ref; + } if( paramDecl->HasModifier() ) { // The AST specified `inout`: @@ -4350,6 +4398,9 @@ struct DeclLoweringVisitor : DeclVisitor case kParameterDirection_InOut: irParamType = subBuilder->getInOutType(irParamType); break; + case kParameterDirection_Ref: + irParamType = subBuilder->getRefType(irParamType); + break; default: SLANG_UNEXPECTED("unknown parameter direction"); diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index f9b3b66757..8ccdc75b13 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -71,8 +71,12 @@ SYNTAX_CLASS(RequiredGLSLVersionModifier, Modifier) FIELD(Token, versionNumberToken) END_SYNTAX_CLASS() + SIMPLE_SYNTAX_CLASS(InOutModifier, OutModifier) +// `__ref` modifier for by-reference parameter passing +SIMPLE_SYNTAX_CLASS(RefModifier, Modifier) + // This is a special sentinel modifier that gets added // to the list when we have multiple variable declarations // all sharing the same modifiers: diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index ab95be5ac2..ad4eef9a6f 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -4399,6 +4399,7 @@ namespace Slang MODIFIER(input, InputModifier); MODIFIER(out, OutModifier); MODIFIER(inout, InOutModifier); + MODIFIER(__ref, RefModifier); MODIFIER(const, ConstModifier); MODIFIER(instance, InstanceModifier); MODIFIER(__builtin, BuiltinModifier); diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 5e855de293..74c817b92f 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -303,6 +303,11 @@ void Type::accept(IValVisitor* visitor, void* extra) return getPtrType(valueType, "InOutType").As(); } + RefPtr Session::getRefType(RefPtr valueType) + { + return getPtrType(valueType, "RefType").As(); + } + RefPtr Session::getPtrType(RefPtr valueType, char const* ptrTypeName) { auto genericDecl = findMagicDecl( @@ -2085,7 +2090,11 @@ void Type::accept(IValVisitor* visitor, void* extra) { auto paramDecl = paramDeclRef.getDecl(); auto paramType = GetType(paramDeclRef); - if( paramDecl->FindModifier() ) + if( paramDecl->FindModifier() ) + { + paramType = session->getRefType(paramType); + } + else if( paramDecl->FindModifier() ) { if(paramDecl->FindModifier() || paramDecl->FindModifier()) { diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 14e9c0066a..c7b0004e65 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -355,6 +355,10 @@ END_SYNTAX_CLASS() SYNTAX_CLASS(InOutType, OutTypeBase) END_SYNTAX_CLASS() +// The type for an `ref` parameter, e.g., `ref T` +SYNTAX_CLASS(RefType, PtrTypeBase) +END_SYNTAX_CLASS() + // A type alias of some kind (e.g., via `typedef`) SYNTAX_CLASS(NamedExpressionType, Type) DECL_FIELD(DeclRef, declRef) diff --git a/tests/compute/atomics-groupshared.slang b/tests/compute/atomics-groupshared.slang new file mode 100644 index 0000000000..7ac6809eba --- /dev/null +++ b/tests/compute/atomics-groupshared.slang @@ -0,0 +1,35 @@ +// atomics-groupshared.slang + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer outputBuffer; + +groupshared uint shared[4]; + +uint test(uint val) +{ + uint originalValue; + + outputBuffer[val] = 0; + + GroupMemoryBarrierWithGroupSync(); + + InterlockedAdd(outputBuffer[val], val, originalValue); + InterlockedAdd(outputBuffer[val ^ 1], val*16, originalValue); + InterlockedAdd(outputBuffer[val ^ 2], val*16*16, originalValue); + + GroupMemoryBarrierWithGroupSync(); + + return outputBuffer[val]; +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + uint val = test(tid); + outputBuffer[tid] = val; +} \ No newline at end of file diff --git a/tests/compute/atomics-groupshared.slang.expected.txt b/tests/compute/atomics-groupshared.slang.expected.txt new file mode 100644 index 0000000000..30966f0df2 --- /dev/null +++ b/tests/compute/atomics-groupshared.slang.expected.txt @@ -0,0 +1,4 @@ +210 +301 + 32 +123 diff --git a/tests/compute/atomics.slang b/tests/compute/atomics.slang new file mode 100644 index 0000000000..b769be5d77 --- /dev/null +++ b/tests/compute/atomics.slang @@ -0,0 +1,24 @@ +// atomics.slang + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer outputBuffer; + +void test(uint val) +{ + uint originalValue; + + InterlockedAdd(outputBuffer[val], val, originalValue); + InterlockedAdd(outputBuffer[val ^ 1], val*16, originalValue); + InterlockedAdd(outputBuffer[val ^ 2], val*16*16, originalValue); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + test(tid); +} \ No newline at end of file diff --git a/tests/compute/atomics.slang.expected.txt b/tests/compute/atomics.slang.expected.txt new file mode 100644 index 0000000000..30966f0df2 --- /dev/null +++ b/tests/compute/atomics.slang.expected.txt @@ -0,0 +1,4 @@ +210 +301 + 32 +123