diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f5dd86df15..ee29750a6a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -973,9 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; + + bool axisIsSpecConstId[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. + DeclRef specConstExtents[3]; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute @@ -1038,9 +1043,12 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. 
+ DeclRef specConstExtents[3]; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index b3e30dbc23..3ef1e8f3be 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1656,6 +1656,8 @@ struct SemanticsVisitor : public SemanticsContext void visitModifier(Modifier*); + DeclRef tryGetIntSpecializationConstant(Expr* expr); + AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); bool hasIntArgs(Attribute* attr, int numArgs); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 3723c98f86..6e451b5cf9 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } +DeclRef SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + if (IsErrorExpr(expr)) + return DeclRef(); + + if (!isScalarIntegerType(expr->type)) + return DeclRef(); + + auto specConstVar = as(expr); + if (!specConstVar || !specConstVar->declRef) + return DeclRef(); + + auto decl = specConstVar->declRef.getDecl(); + if (!decl) + return DeclRef(); + + for (auto modifier : decl->modifiers) + { + if (as(modifier) || as(modifier)) + { + return specConstVar->declRef.as(); + } + } + + return DeclRef(); +} + static bool _isDeclAllowedAsAttribute(DeclRef declRef) { if (as(declRef.getDecl())) @@ -350,8 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; - for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -359,6 +387,14 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + numThreadsAttr->extents[i] = nullptr; + 
numThreadsAttr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; + numThreadsAttr->extents[i] = value; } - - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; } else if (auto waveSizeAttr = as(attr)) { @@ -1831,15 +1863,24 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; + // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. + auto decl = as(syntaxNode); + SLANG_ASSERT(decl); for (int i = 0; i < 3; ++i) { - IntVal* value = nullptr; + attr->extents[i] = nullptr; auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + attr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1847,7 +1888,45 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as(intValue)) { - if (cintVal->getValue() < 1) + if (attr->axisIsSpecConstId[i]) + { + // This integer should actually be a reference to a + // specialization constant with this ID. + Int specConstId = cintVal->getValue(); + + for (auto member : decl->parentDecl->members) + { + auto constantId = member->findModifier(); + if (constantId) + { + SLANG_ASSERT(constantId->args.getCount() == 1); + auto id = checkConstantIntVal(constantId->args[0]); + if (id->getValue() == specConstId) + { + attr->specConstExtents[i] = + DeclRef(member->getDefaultDeclRef()); + break; + } + } + } + + // If not found, we need to create a new specialization + // constant with this ID. 
+ if (!attr->specConstExtents[i]) + { + auto specConstVarDecl = getASTBuilder()->create(); + auto constantIdModifier = + getASTBuilder()->create(); + constantIdModifier->location = (int32_t)specConstId; + specConstVarDecl->type.type = getASTBuilder()->getIntType(); + addModifier(specConstVarDecl, constantIdModifier); + decl->parentDecl->addMember(specConstVarDecl); + attr->specConstExtents[i] = + DeclRef(specConstVarDecl->getDefaultDeclRef()); + } + continue; + } + else if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - value = intValue; + attr->extents[i] = intValue; } else { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; } - - attr->x = values[0]; - attr->y = values[1]; - attr->z = values[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 1d09189cce..d86cd8be2a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2459,6 +2459,12 @@ DIAGNOSTIC( Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") +DIAGNOSTIC( + 55205, + Error, + unsupportedSpecializationConstantForNumThreads, + "Specialization constants are not supported in the 'numthreads' attribute for the current " + "target.") DIAGNOSTIC( 56001, Error, diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 7b51495e2b..d3a9359ff2 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -295,14 +295,48 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } -/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( +IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, Int 
outNumThreads[kThreadGroupAxisCount]) +{ + Int specializationConstantIds[kThreadGroupAxisCount]; + IRNumThreadsDecoration* decor = + getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); + + for (auto id : specializationConstantIds) + { + if (id >= 0) + { + getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads); + break; + } + } + return decor; +} + +/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); - for (int i = 0; i < 3; ++i) + for (int i = 0; i < kThreadGroupAxisCount; ++i) { - outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1; + if (!decor) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = -1; + } + else if (auto specConst = as(decor->getOperand(i))) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = getSpecializationConstantId(specConst); + } + else + { + outNumThreads[i] = Int(getIntVal(decor->getOperand(i))); + outSpecializationConstantIds[i] = -1; + } } return decor; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index e5080f731b..1354b7cbd8 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -500,11 +500,20 @@ class CLikeSourceEmitter : public SourceEmitterBase /// different. 
Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 - static IRNumThreadsDecoration* getComputeThreadGroupSize( + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1 + IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1. If specialization constants are used for an axis, their + /// IDs are reported in non-negative entries of outSpecializationConstantIds. + static IRNumThreadsDecoration* getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]); + /// Finds the IRWaveSizeDecoration and gets the size from that. static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 23fff37acb..0dab07cfce 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( auto emitLocalSizeLayout = [&]() { Int sizeAlongAxis[kThreadGroupAxisCount]; - getComputeThreadGroupSize(irFunc, sizeAlongAxis); + Int specializationConstantIds[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds); m_writer->emit("layout("); char const* axes[] = {"x", "y", "z"}; @@ -1345,8 +1346,17 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( m_writer->emit(", "); m_writer->emit("local_size_"); m_writer->emit(axes[ii]); - m_writer->emit(" = "); - m_writer->emit(sizeAlongAxis[ii]); + + if (specializationConstantIds[ii] >= 0) + { + m_writer->emit("_id = "); + m_writer->emit(specializationConstantIds[ii]); + } + else + { + 
m_writer->emit(" = "); + m_writer->emit(sizeAlongAxis[ii]); + } } m_writer->emit(") in;\n"); }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 068e1563ca..2cf84a8540 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4353,23 +4353,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // [3.6. Execution Mode]: LocalSize case kIROp_NumThreadsDecoration: { - // TODO: The `LocalSize` execution mode option requires - // literal values for the X,Y,Z thread-group sizes. - // There is a `LocalSizeId` variant that takes ``s - // for those sizes, and we should consider using that - // and requiring the appropriate capabilities - // if any of the operands to the decoration are not - // literals (in a future where we support non-literals - // in those positions in the Slang IR). - // auto numThreads = cast(decoration); - requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSize, - SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + if (numThreads->getXSpecConst() || numThreads->getYSpecConst() || + numThreads->getZSpecConst()) + { + // If any of the dimensions needs an ID, we need to emit + // all dimensions as an ID due to how LocalSizeId works. + int32_t ids[3]; + for (int i = 0; i < 3; ++i) + ids[i] = ensureInst(numThreads->getOperand(i))->id; + + // LocalSizeId is supported from SPIR-V 1.2 onwards without + // any extra capabilities. 
+ requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSizeId, + SpvLiteralInteger::from32(int32_t(ids[0])), + SpvLiteralInteger::from32(int32_t(ids[1])), + SpvLiteralInteger::from32(int32_t(ids[2]))); + } + else + { + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSize, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + } } break; case kIROp_MaxVertexCountDecoration: @@ -7977,10 +7990,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (m_executionModes[entryPoint].add(executionMode)) { + SpvOp execModeOp = SpvOpExecutionMode; + if (executionMode == SpvExecutionModeLocalSizeId || + executionMode == SpvExecutionModeLocalSizeHintId || + executionMode == SpvExecutionModeSubgroupsPerWorkgroupId) + { + execModeOp = SpvOpExecutionModeId; + } + emitInst( getSection(SpvLogicalSectionID::ExecutionModes), parentInst, - SpvOpExecutionMode, + execModeOp, entryPoint, executionMode, ops...); diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index 1c833a2948..372ef298e7 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -279,6 +279,16 @@ struct CollectGlobalUniformParametersContext continue; } + // NumThreadsDecoration may sometimes be the user for a global + // parameter. This occurs when the parameter was supposed to be + // a specialization constant, but isn't due to that not being + // supported for the target. These can be skipped here and + // diagnosed later. + if (as(user)) + { + continue; + } + // For each use site for the global parameter, we will // insert new code right before the instruction that uses // the parameter. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a58c2e900c..f46586aa2b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -570,6 +570,7 @@ struct IRInstanceDecoration : IRDecoration IRIntLit* getCount() { return cast(getOperand(0)); } }; +struct IRGlobalParam; struct IRNumThreadsDecoration : IRDecoration { enum @@ -578,11 +579,13 @@ struct IRNumThreadsDecoration : IRDecoration }; IR_LEAF_ISA(NumThreadsDecoration) - IRIntLit* getX() { return cast(getOperand(0)); } - IRIntLit* getY() { return cast(getOperand(1)); } - IRIntLit* getZ() { return cast(getOperand(2)); } + IRIntLit* getX() { return as(getOperand(0)); } + IRIntLit* getY() { return as(getOperand(1)); } + IRIntLit* getZ() { return as(getOperand(2)); } - IRIntLit* getExtentAlongAxis(int axis) { return cast(getOperand(axis)); } + IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } + IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } + IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 33f3944fd9..5a18b533ab 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -188,7 +188,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp for (int axis = 0; axis < kAxisCount; axis++) { - auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); + auto litValue = as(numThreadsDecor->getOperand(axis)); if (!litValue) return nullptr; @@ -1432,6 +1432,20 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); + if (!groupExtents) + { + m_sink->diagnose( + m_entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder 
values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1); + groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } + dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 5bfa62e4af..3b47bd59ef 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1828,12 +1828,26 @@ struct LegalizeMetalEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); + auto uint3Type = builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 3)); + auto computeExtent = + emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); + if (!computeExtent) + { + m_sink->diagnose( + entryPoint.entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder values. 
+ static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = + builder.getIntValue(uint3Type->getElementType(), 1); + computeExtent = + builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index a44e16a7ce..077cdb98d0 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -282,10 +282,11 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* values[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; + auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), 3, @@ -328,10 +329,10 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* args[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index c753600a7c..d05e1db7d4 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1973,4 +1973,17 @@ IRType* getIRVectorBaseType(IRType* type) 
return as(type)->getElementType(); } +Int getSpecializationConstantId(IRGlobalParam* param) +{ + auto layout = findVarLayout(param); + if (!layout) + return 0; + + auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant); + if (!offset) + return 0; + + return offset->getOffset(); +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index e23aeb6180..666ac71c03 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -373,6 +373,8 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget) int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); +Int getSpecializationConstantId(IRGlobalParam* param); + } // namespace Slang #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e82fc03fde..0863457198 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7625,12 +7625,29 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); + + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = layoutLocalSizeAttr->specConstExtents[i] + ? 
emitDeclRef( + context, + layoutLocalSizeAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + layoutLocalSizeAttr->specConstExtents[i]))) + : lowerVal(context, layoutLocalSizeAttr->extents[i]); + } + for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); } else if (as(modifier)) { @@ -10336,11 +10353,28 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = numThreadsAttr->specConstExtents[i] + ? emitDeclRef( + context, + numThreadsAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + numThreadsAttr->specConstExtents[i]))) + : lowerVal(context, numThreadsAttr->extents[i]); + } + numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->z)))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); + numThreadsDecor->sourceLoc = numThreadsAttr->loc; } else if (auto waveSizeAttr = as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c275a868b5..6ae41a2eb9 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8437,7 +8437,9 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) + 
(nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || + (nameText.endsWith("_id") && + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8451,6 +8453,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; + for (auto& b : numThreadsAttrib->axisIsSpecConstId) + b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8467,6 +8471,11 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; + + // We can't resolve the specialization constant declaration + // here, because it may not even exist. IDs pointing to unnamed + // specialization constants are allowed in GLSL. 
+ numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d235c82703..d1adfedc0b 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,18 +4033,14 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier(); if (numThreadsAttribute) { - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) - sizeAlongAxis[0] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->x) - sizeAlongAxis[0] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) - sizeAlongAxis[1] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->y) - sizeAlongAxis[1] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) - sizeAlongAxis[2] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->z) - sizeAlongAxis[2] = 0; + for (int i = 0; i < 3; ++i) + { + if (auto cint = + entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) + sizeAlongAxis[i] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->extents[i]) + sizeAlongAxis[i] = 0; + } } // diff --git a/tests/glsl/compute-shader-layout-id.slang b/tests/glsl/compute-shader-layout-id.slang new file mode 100644 index 0000000000..bee8137d82 --- /dev/null +++ b/tests/glsl/compute-shader-layout-id.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main -allow-glsl +#version 450 + +[vk::constant_id(1)] +const int constValue1 = 0; + +[vk::constant_id(2)] +const int constValue3 = 5; + +// CHECK-DAG: OpExecutionModeId %main LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] 
SpecId 0 +// CHECK-DAG: OpDecorate %[[C2]] SpecId 2 + +layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = constValue3) in; +void main() +{ +} + diff --git a/tests/spirv/spec-constant-numthreads.slang b/tests/spirv/spec-constant-numthreads.slang new file mode 100644 index 0000000000..5c133219cf --- /dev/null +++ b/tests/spirv/spec-constant-numthreads.slang @@ -0,0 +1,35 @@ +//TEST:SIMPLE(filecheck=GLSL): -target glsl -allow-glsl +//TEST:SIMPLE(filecheck=GLSL): -target glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv -allow-glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// CHECK-DAG: OpExecutionModeId %computeMain1 LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: %[[C2]] = OpConstant %int 4 +// CHECK-DAG: OpStore %{{.*}} %[[C0]] +// CHECK-DAG: OpStore %{{.*}} %[[C1]] +// CHECK-DAG: OpStore %{{.*}} %[[C2]] + +// GLSL-DAG: layout(constant_id = 1) +// GLSL-DAG: int constValue0_0 = 0; +// GLSL-DAG: layout(constant_id = 0) +// GLSL-DAG: int constValue1_0 = 0; +// GLSL-DAG: layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = 4) in; + +[vk::specialization_constant] +const int constValue0 = 0; + +[vk::constant_id(0)] +const int constValue1 = 0; + +RWStructuredBuffer outputBuffer; + +[numthreads(constValue0, constValue1, 4)] +void computeMain1() +{ + int3 size = WorkgroupSize(); + outputBuffer[0] = size.x; + outputBuffer[1] = size.y; + outputBuffer[2] = size.z; +}