Skip to content

Commit

Permalink
Implement specialization constant support in numthreads / local_size (#…
Browse files Browse the repository at this point in the history
…5963)

* Allow using specialization constants in numthreads attribute

* Add support for GLSL local_size_x_id syntax

* Fix overeager specialization constant parsing

* Add diagnostics for specialization constant numthreads

* Remove unused variable

* Fix local_size_x_id not finding existing specialization constant

* Allow materializeGetWorkGroupSize to reference specialization constants

* Use SpvOpExecutionModeId for modes that require it

* Cleanup specialization constant numthreads code

* Add tests for specialization constant work group sizes

* Fix implicit Slang::Int -> int32_t cast

* Fix querying thread group size in reflection API

---------

Co-authored-by: Yong He <[email protected]>
  • Loading branch information
juliusikkala and csyonghe authored Jan 14, 2025
1 parent 971996b commit cbdc7e1
Show file tree
Hide file tree
Showing 20 changed files with 400 additions and 86 deletions.
20 changes: 14 additions & 6 deletions source/slang/slang-ast-modifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,9 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute
//
// TODO: These should be accessors that use the
// ordinary `args` list, rather than side data.
IntVal* x;
IntVal* y;
IntVal* z;
IntVal* extents[3];

bool axisIsSpecConstId[3];

// References to specialization constants, for defining the number of
// threads with them. If set, the corresponding axis is set to nullptr
// above.
DeclRef<VarDeclBase> specConstExtents[3];
};

class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute
Expand Down Expand Up @@ -1038,9 +1043,12 @@ class NumThreadsAttribute : public Attribute
//
// TODO: These should be accessors that use the
// ordinary `args` list, rather than side data.
IntVal* x;
IntVal* y;
IntVal* z;
IntVal* extents[3];

// References to specialization constants, for defining the number of
// threads with them. If set, the corresponding axis is set to nullptr
// above.
DeclRef<VarDeclBase> specConstExtents[3];
};

class WaveSizeAttribute : public Attribute
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,8 @@ struct SemanticsVisitor : public SemanticsContext

void visitModifier(Modifier*);

DeclRef<VarDeclBase> tryGetIntSpecializationConstant(Expr* expr);

AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope);

bool hasIntArgs(Attribute* attr, int numArgs);
Expand Down
108 changes: 91 additions & 17 deletions source/slang/slang-check-modifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*)
// Do nothing with modifiers for now
}

DeclRef<VarDeclBase> SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr)
{
// First type-check the expression as normal
expr = CheckExpr(expr);

if (IsErrorExpr(expr))
return DeclRef<VarDeclBase>();

if (!isScalarIntegerType(expr->type))
return DeclRef<VarDeclBase>();

auto specConstVar = as<VarExpr>(expr);
if (!specConstVar || !specConstVar->declRef)
return DeclRef<VarDeclBase>();

auto decl = specConstVar->declRef.getDecl();
if (!decl)
return DeclRef<VarDeclBase>();

for (auto modifier : decl->modifiers)
{
if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier))
{
return specConstVar->declRef.as<VarDeclBase>();
}
}

return DeclRef<VarDeclBase>();
}

static bool _isDeclAllowedAsAttribute(DeclRef<Decl> declRef)
{
if (as<AttributeDecl>(declRef.getDecl()))
Expand Down Expand Up @@ -350,15 +380,21 @@ Modifier* SemanticsVisitor::validateAttribute(
{
SLANG_ASSERT(attr->args.getCount() == 3);

IntVal* values[3];

for (int i = 0; i < 3; ++i)
{
IntVal* value = nullptr;

auto arg = attr->args[i];
if (arg)
{
auto specConstDecl = tryGetIntSpecializationConstant(arg);
if (specConstDecl)
{
numThreadsAttr->extents[i] = nullptr;
numThreadsAttr->specConstExtents[i] = specConstDecl;
continue;
}

auto intValue = checkLinkTimeConstantIntVal(arg);
if (!intValue)
{
Expand Down Expand Up @@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute(
{
value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
values[i] = value;
numThreadsAttr->extents[i] = value;
}

numThreadsAttr->x = values[0];
numThreadsAttr->y = values[1];
numThreadsAttr->z = values[2];
}
else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr))
{
Expand Down Expand Up @@ -1831,23 +1863,70 @@ Modifier* SemanticsVisitor::checkModifier(
{
SLANG_ASSERT(attr->args.getCount() == 3);

IntVal* values[3];
// GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl.
auto decl = as<EmptyDecl>(syntaxNode);
SLANG_ASSERT(decl);

for (int i = 0; i < 3; ++i)
{
IntVal* value = nullptr;
attr->extents[i] = nullptr;

auto arg = attr->args[i];
if (arg)
{
auto specConstDecl = tryGetIntSpecializationConstant(arg);
if (specConstDecl)
{
attr->specConstExtents[i] = specConstDecl;
continue;
}

auto intValue = checkConstantIntVal(arg);
if (!intValue)
{
return nullptr;
}
if (auto cintVal = as<ConstantIntVal>(intValue))
{
if (cintVal->getValue() < 1)
if (attr->axisIsSpecConstId[i])
{
// This integer should actually be a reference to a
// specialization constant with this ID.
Int specConstId = cintVal->getValue();

for (auto member : decl->parentDecl->members)
{
auto constantId = member->findModifier<VkConstantIdAttribute>();
if (constantId)
{
SLANG_ASSERT(constantId->args.getCount() == 1);
auto id = checkConstantIntVal(constantId->args[0]);
if (id->getValue() == specConstId)
{
attr->specConstExtents[i] =
DeclRef<VarDeclBase>(member->getDefaultDeclRef());
break;
}
}
}

// If not found, we need to create a new specialization
// constant with this ID.
if (!attr->specConstExtents[i])
{
auto specConstVarDecl = getASTBuilder()->create<VarDecl>();
auto constantIdModifier =
getASTBuilder()->create<VkConstantIdAttribute>();
constantIdModifier->location = (int32_t)specConstId;
specConstVarDecl->type.type = getASTBuilder()->getIntType();
addModifier(specConstVarDecl, constantIdModifier);
decl->parentDecl->addMember(specConstVarDecl);
attr->specConstExtents[i] =
DeclRef<VarDeclBase>(specConstVarDecl->getDefaultDeclRef());
}
continue;
}
else if (cintVal->getValue() < 1)
{
getSink()->diagnose(
attr,
Expand All @@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier(
return nullptr;
}
}
value = intValue;
attr->extents[i] = intValue;
}
else
{
value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
values[i] = value;
}

attr->x = values[0];
attr->y = values[1];
attr->z = values[2];
}

// Default behavior is to leave things as they are,
Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,12 @@ DIAGNOSTIC(
Error,
unsupportedTargetIntrinsic,
"intrinsic operation '$0' is not supported for the current target.")
DIAGNOSTIC(
55205,
Error,
unsupportedSpecializationConstantForNumThreads,
"Specialization constants are not supported in the 'numthreads' attribute for the current "
"target.")
DIAGNOSTIC(
56001,
Error,
Expand Down
40 changes: 37 additions & 3 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,48 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
}


/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount])
{
Int specializationConstantIds[kThreadGroupAxisCount];
IRNumThreadsDecoration* decor =
getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds);

for (auto id : specializationConstantIds)
{
if (id >= 0)
{
getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads);
break;
}
}
return decor;
}

/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount],
Int outSpecializationConstantIds[kThreadGroupAxisCount])
{
IRNumThreadsDecoration* decor = func->findDecoration<IRNumThreadsDecoration>();
for (int i = 0; i < 3; ++i)
for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1;
if (!decor)
{
outNumThreads[i] = 1;
outSpecializationConstantIds[i] = -1;
}
else if (auto specConst = as<IRGlobalParam>(decor->getOperand(i)))
{
outNumThreads[i] = 1;
outSpecializationConstantIds[i] = getSpecializationConstantId(specConst);
}
else
{
outNumThreads[i] = Int(getIntVal(decor->getOperand(i)));
outSpecializationConstantIds[i] = -1;
}
}
return decor;
}
Expand Down
13 changes: 11 additions & 2 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,20 @@ class CLikeSourceEmitter : public SourceEmitterBase
/// different. Returns an empty slice if not a built in type
static UnownedStringSlice getDefaultBuiltinTypeName(IROp op);

/// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1
static IRNumThreadsDecoration* getComputeThreadGroupSize(
/// Finds the IRNumThreadsDecoration and gets the size from that or sets all
/// dimensions to 1
IRNumThreadsDecoration* getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount]);

/// Finds the IRNumThreadsDecoration and gets the size from that or sets all
/// dimensions to 1. If specialization constants are used for an axis, their
/// IDs is reported in non-negative entries of outSpecializationConstantIds.
static IRNumThreadsDecoration* getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount],
Int outSpecializationConstantIds[kThreadGroupAxisCount]);

/// Finds the IRWaveSizeDecoration and gets the size from that.
static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize);

Expand Down
16 changes: 13 additions & 3 deletions source/slang/slang-emit-glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(
auto emitLocalSizeLayout = [&]()
{
Int sizeAlongAxis[kThreadGroupAxisCount];
getComputeThreadGroupSize(irFunc, sizeAlongAxis);
Int specializationConstantIds[kThreadGroupAxisCount];
getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds);

m_writer->emit("layout(");
char const* axes[] = {"x", "y", "z"};
Expand All @@ -1345,8 +1346,17 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(
m_writer->emit(", ");
m_writer->emit("local_size_");
m_writer->emit(axes[ii]);
m_writer->emit(" = ");
m_writer->emit(sizeAlongAxis[ii]);

if (specializationConstantIds[ii] >= 0)
{
m_writer->emit("_id = ");
m_writer->emit(specializationConstantIds[ii]);
}
else
{
m_writer->emit(" = ");
m_writer->emit(sizeAlongAxis[ii]);
}
}
m_writer->emit(") in;\n");
};
Expand Down
Loading

0 comments on commit cbdc7e1

Please sign in to comment.