Skip to content

Commit

Permalink
Fix global atomic functions (#582)
Browse files Browse the repository at this point in the history
Fixes #581

This change adds a new parameter passing mode `__ref` to exist alongisde `in`, `out`, and `inout`.
The `__ref` modifier indicates true by-reference parameter passing (whereas `inout` is copy-in-copy-out).

This is not intended to be something that users interact with directly, but rather a low-level feature that lets us provide a correct signature for the `Interlocked*()` operations in the standard library.
Most of the support for passing what are logically addresses around already exists in the IR, so the majority of the work here is just in introducing the new type `Ref<T>` and then using it appropriately when lowering `__ref` parameters/arguments to the IR.
  • Loading branch information
Tim Foley authored May 29, 2018
1 parent ace9a8d commit e7a8332
Show file tree
Hide file tree
Showing 21 changed files with 221 additions and 38 deletions.
14 changes: 13 additions & 1 deletion source/slang/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3011,6 +3011,11 @@ namespace Slang
// because there is no way for overload resolution to pick between them.
if (fstParam.getDecl()->HasModifier<OutModifier>() != sndParam.getDecl()->HasModifier<OutModifier>())
return false;

// If one parameter is `ref` and the other isn't, then they don't match.
//
if(fstParam.getDecl()->HasModifier<RefModifier>() != sndParam.getDecl()->HasModifier<RefModifier>())
return false;
}

// Note(tfoley): return type doesn't enter into it, because we can't take
Expand Down Expand Up @@ -7046,8 +7051,15 @@ namespace Slang
for (UInt pp = 0; pp < paramCount; ++pp)
{
auto paramType = funcType->getParamType(pp);
if (auto outParamType = paramType->As<OutTypeBase>())
if (paramType->As<OutTypeBase>() || paramType->As<RefType>())
{
// `out`, `inout`, and `ref` parameters currently require
// an *exact* match on the type of the argument.
//
// TODO: relax this requirement by allowing an argument
// for an `inout` parameter to be converted in both
// directions.
//
if( pp < expr->Arguments.Count() )
{
auto argExpr = expr->Arguments[pp];
Expand Down
3 changes: 3 additions & 0 deletions source/slang/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ namespace Slang
// Construct the type `InOut<valueType>`
RefPtr<InOutType> getInOutType(RefPtr<Type> valueType);

// Construct the type `Ref<valueType>`
RefPtr<RefType> getRefType(RefPtr<Type> valueType);

// Construct a pointer type like `Ptr<valueType>`, but where
// the actual type name for the pointer type is given by `ptrTypeName`
RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName);
Expand Down
6 changes: 6 additions & 0 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ __intrinsic_type($(kIROp_InOutType))
struct InOut
{};

__generic<T>
__magic_type(RefType)
__intrinsic_type($(kIROp_RefType))
struct Ref
{};

__magic_type(StringType)
__intrinsic_type($(kIROp_StringType))
struct String
Expand Down
9 changes: 9 additions & 0 deletions source/slang/core.meta.slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ SLANG_RAW(")\n")
SLANG_RAW("struct InOut\n")
SLANG_RAW("{};\n")
SLANG_RAW("\n")
SLANG_RAW("__generic<T>\n")
SLANG_RAW("__magic_type(RefType)\n")
SLANG_RAW("__intrinsic_type(")
SLANG_SPLICE(kIROp_RefType
)
SLANG_RAW(")\n")
SLANG_RAW("struct Ref\n")
SLANG_RAW("{};\n")
SLANG_RAW("\n")
SLANG_RAW("__magic_type(StringType)\n")
SLANG_RAW("__intrinsic_type(")
SLANG_SPLICE(kIROp_StringType
Expand Down
2 changes: 2 additions & 0 deletions source/slang/diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ DIAGNOSTIC(40006, Error, needCompileTimeConstant, "expected a compile-time const

DIAGNOSTIC(40007, Internal, irValidationFailed, "IR validation failed: $0")

DIAGNOSTIC(40008, Error, invalidLValueForRefParameter, "the form of this l-value argument is not valid for a `ref` parameter")

// 41000 - IR-level validation issues

DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected")
Expand Down
7 changes: 7 additions & 0 deletions source/slang/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4408,6 +4408,13 @@ struct EmitVisitor
emit("inout ");
type = inOutType->getValueType();
}
else if( auto refType = as<IRRefType>(type))
{
// Note: There is no HLSL/GLSL equivalent for by-reference parameters,
// so we don't actually expect to encounter these in user code.
emit("inout ");
type = inOutType->getValueType();
}

emitIRType(ctx, type, name);
}
Expand Down
32 changes: 16 additions & 16 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -617,32 +617,32 @@ __target_intrinsic(glsl, "groupMemoryBarrier()); (barrier()")
void GroupMemoryBarrierWithGroupSync();

// Atomics
void InterlockedAdd(in out int dest, int value, out int original_value);
void InterlockedAdd(in out uint dest, uint value, out uint original_value);
void InterlockedAdd(__ref int dest, int value, out int original_value);
void InterlockedAdd(__ref uint dest, uint value, out uint original_value);

void InterlockedAnd(in out int dest, int value, out int original_value);
void InterlockedAnd(in out uint dest, uint value, out uint original_value);
void InterlockedAnd(__ref int dest, int value, out int original_value);
void InterlockedAnd(__ref uint dest, uint value, out uint original_value);

void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value);
void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value);
void InterlockedCompareExchange(__ref int dest, int compare_value, int value, out int original_value);
void InterlockedCompareExchange(__ref uint dest, uint compare_value, uint value, out uint original_value);

void InterlockedCompareStore(in out int dest, int compare_value, int value);
void InterlockedCompareStore(in out uint dest, uint compare_value, uint value);
void InterlockedCompareStore(__ref int dest, int compare_value, int value);
void InterlockedCompareStore(__ref uint dest, uint compare_value, uint value);

void InterlockedExchange(in out int dest, int value, out int original_value);
void InterlockedExchange(in out uint dest, uint value, out uint original_value);
void InterlockedExchange(__ref int dest, int value, out int original_value);
void InterlockedExchange(__ref uint dest, uint value, out uint original_value);

void InterlockedMax(in out int dest, int value, out int original_value);
void InterlockedMax(in out uint dest, uint value, out uint original_value);
void InterlockedMax(__ref int dest, int value, out int original_value);
void InterlockedMax(__ref uint dest, uint value, out uint original_value);

void InterlockedMin(in out int dest, int value, out int original_value);
void InterlockedMin(in out uint dest, uint value, out uint original_value);

void InterlockedOr(in out int dest, int value, out int original_value);
void InterlockedOr(in out uint dest, uint value, out uint original_value);
void InterlockedOr(__ref int dest, int value, out int original_value);
void InterlockedOr(__ref uint dest, uint value, out uint original_value);

void InterlockedXor(in out int dest, int value, out int original_value);
void InterlockedXor(in out uint dest, uint value, out uint original_value);
void InterlockedXor(__ref int dest, int value, out int original_value);
void InterlockedXor(__ref uint dest, uint value, out uint original_value);

// Is floating-point value finite?
__generic<T : __BuiltinFloatingPointType> bool isfinite(T x);
Expand Down
32 changes: 16 additions & 16 deletions source/slang/hlsl.meta.slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -650,32 +650,32 @@ SLANG_RAW("__target_intrinsic(glsl, \"groupMemoryBarrier()); (barrier()\")\n")
SLANG_RAW("void GroupMemoryBarrierWithGroupSync();\n")
SLANG_RAW("\n")
SLANG_RAW("// Atomics\n")
SLANG_RAW("void InterlockedAdd(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedAdd(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedAdd(__ref int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedAdd(__ref uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedAnd(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedAnd(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedAnd(__ref int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedAnd(__ref uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value);\n")
SLANG_RAW("void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedCompareExchange(__ref int dest, int compare_value, int value, out int original_value);\n")
SLANG_RAW("void InterlockedCompareExchange(__ref uint dest, uint compare_value, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedCompareStore(in out int dest, int compare_value, int value);\n")
SLANG_RAW("void InterlockedCompareStore(in out uint dest, uint compare_value, uint value);\n")
SLANG_RAW("void InterlockedCompareStore(__ref int dest, int compare_value, int value);\n")
SLANG_RAW("void InterlockedCompareStore(__ref uint dest, uint compare_value, uint value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedExchange(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedExchange(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedExchange(__ref int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedExchange(__ref uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedMax(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedMax(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedMax(__ref int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedMax(__ref uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedMin(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedMin(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedOr(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedOr(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedOr(__ref int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedOr(__ref uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("void InterlockedXor(in out int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedXor(in out uint dest, uint value, out uint original_value);\n")
SLANG_RAW("void InterlockedXor(__ref int dest, int value, out int original_value);\n")
SLANG_RAW("void InterlockedXor(__ref uint dest, uint value, out uint original_value);\n")
SLANG_RAW("\n")
SLANG_RAW("// Is floating-point value finite?\n")
SLANG_RAW("__generic<T : __BuiltinFloatingPointType> bool isfinite(T x);\n")
Expand Down
1 change: 1 addition & 0 deletions source/slang/ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ INST(Nop, nop, 0, 0)

/* PtrTypeBase */
INST(PtrType, Ptr, 1, 0)
INST(RefType, Ref, 1, 0)
/* OutTypeBase */
INST(OutType, Out, 1, 0)
INST(InOutType, InOut, 1, 0)
Expand Down
1 change: 1 addition & 0 deletions source/slang/ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ struct IRBuilder
IRPtrType* getPtrType(IRType* valueType);
IROutType* getOutType(IRType* valueType);
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);

IRArrayTypeBase* getArrayTypeBase(
Expand Down
5 changes: 5 additions & 0 deletions source/slang/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,11 @@ namespace Slang
return (IRInOutType*) getPtrType(kIROp_InOutType, valueType);
}

IRRefType* IRBuilder::getRefType(IRType* valueType)
{
return (IRRefType*) getPtrType(kIROp_RefType, valueType);
}

IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType)
{
IRInst* operands[] = { valueType };
Expand Down
1 change: 1 addition & 0 deletions source/slang/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ struct IRPtrType : IRPtrTypeBase
SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase)
SIMPLE_IR_TYPE(OutType, OutTypeBase)
SIMPLE_IR_TYPE(InOutType, OutTypeBase)
SIMPLE_IR_TYPE(RefType, OutTypeBase)

struct IRFuncType : IRType
{
Expand Down
59 changes: 55 additions & 4 deletions source/slang/lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,11 @@ void assign(
LoweredValInfo const& left,
LoweredValInfo const& right);

IRInst* getAddress(
IRGenContext* context,
LoweredValInfo const& inVal,
SourceLoc diagnosticLocation);

void lowerStmt(
IRGenContext* context,
Stmt* stmt);
Expand Down Expand Up @@ -1668,7 +1673,24 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// make a conscious decision at some point.
}

if (paramDecl->HasModifier<OutModifier>()
if(paramDecl->HasModifier<RefModifier>())
{
// A `ref` qualified parameter must be implemented with by-reference
// parameter passing, so the argument value should be lowered as
// an l-value.
//
LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr);

// According to our "calling convention" we need to
// pass a pointer into the callee. Unlike the case for
// `out` and `inout` below, it is never valid to do
// copy-in/copy-out for a `ref` parameter, so we just
// pass in the actual pointer.
//
IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc);
(*ioArgs).Add(argPtr);
}
else if (paramDecl->HasModifier<OutModifier>()
|| paramDecl->HasModifier<InOutModifier>())
{
// This is a `out` or `inout` parameter, and so
Expand Down Expand Up @@ -2930,6 +2952,26 @@ static LoweredValInfo maybeMoveMutableTemp(
}
}

IRInst* getAddress(
IRGenContext* context,
LoweredValInfo const& inVal,
SourceLoc diagnosticLocation)
{
LoweredValInfo val = inVal;
switch(val.flavor)
{
case LoweredValInfo::Flavor::Ptr:
return val.val;

// TODO: are there other cases we need to handle here (e.g.,
// turning a bound subscript/property into an address)

default:
context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter);
return nullptr;
}
}

void assign(
IRGenContext* context,
LoweredValInfo const& inLeft,
Expand Down Expand Up @@ -3831,9 +3873,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
enum ParameterDirection
{
kParameterDirection_In,
kParameterDirection_Out,
kParameterDirection_InOut,
kParameterDirection_In, ///< Copy in
kParameterDirection_Out, ///< Copy out
kParameterDirection_InOut, ///< Copy in, copy out
kParameterDirection_Ref, ///< By-reference
};
struct ParameterInfo
{
Expand All @@ -3856,6 +3899,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
{
if( paramDecl->HasModifier<RefModifier>() )
{
// The AST specified `ref`:
return kParameterDirection_Ref;
}
if( paramDecl->HasModifier<InOutModifier>() )
{
// The AST specified `inout`:
Expand Down Expand Up @@ -4350,6 +4398,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
case kParameterDirection_InOut:
irParamType = subBuilder->getInOutType(irParamType);
break;
case kParameterDirection_Ref:
irParamType = subBuilder->getRefType(irParamType);
break;

default:
SLANG_UNEXPECTED("unknown parameter direction");
Expand Down
4 changes: 4 additions & 0 deletions source/slang/modifier-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ SYNTAX_CLASS(RequiredGLSLVersionModifier, Modifier)
FIELD(Token, versionNumberToken)
END_SYNTAX_CLASS()


SIMPLE_SYNTAX_CLASS(InOutModifier, OutModifier)

// `__ref` modifier for by-reference parameter passing
SIMPLE_SYNTAX_CLASS(RefModifier, Modifier)

// This is a special sentinel modifier that gets added
// to the list when we have multiple variable declarations
// all sharing the same modifiers:
Expand Down
1 change: 1 addition & 0 deletions source/slang/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4399,6 +4399,7 @@ namespace Slang
MODIFIER(input, InputModifier);
MODIFIER(out, OutModifier);
MODIFIER(inout, InOutModifier);
MODIFIER(__ref, RefModifier);
MODIFIER(const, ConstModifier);
MODIFIER(instance, InstanceModifier);
MODIFIER(__builtin, BuiltinModifier);
Expand Down
11 changes: 10 additions & 1 deletion source/slang/syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ void Type::accept(IValVisitor* visitor, void* extra)
return getPtrType(valueType, "InOutType").As<InOutType>();
}

RefPtr<RefType> Session::getRefType(RefPtr<Type> valueType)
{
return getPtrType(valueType, "RefType").As<RefType>();
}

RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName)
{
auto genericDecl = findMagicDecl(
Expand Down Expand Up @@ -2085,7 +2090,11 @@ void Type::accept(IValVisitor* visitor, void* extra)
{
auto paramDecl = paramDeclRef.getDecl();
auto paramType = GetType(paramDeclRef);
if( paramDecl->FindModifier<OutModifier>() )
if( paramDecl->FindModifier<RefModifier>() )
{
paramType = session->getRefType(paramType);
}
else if( paramDecl->FindModifier<OutModifier>() )
{
if(paramDecl->FindModifier<InOutModifier>() || paramDecl->FindModifier<InModifier>())
{
Expand Down
4 changes: 4 additions & 0 deletions source/slang/type-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ END_SYNTAX_CLASS()
SYNTAX_CLASS(InOutType, OutTypeBase)
END_SYNTAX_CLASS()

// The type for an `ref` parameter, e.g., `ref T`
SYNTAX_CLASS(RefType, PtrTypeBase)
END_SYNTAX_CLASS()

// A type alias of some kind (e.g., via `typedef`)
SYNTAX_CLASS(NamedExpressionType, Type)
DECL_FIELD(DeclRef<TypeDefDecl>, declRef)
Expand Down
Loading

0 comments on commit e7a8332

Please sign in to comment.