From 87c50675941a3ac853a79f50ec0ce3465631fa8f Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Tue, 15 Aug 2017 15:12:29 -0700 Subject: [PATCH 1/2] More work on IR With this change, basic generation of IR works for a trivial shader, and there is some basic support for dumping the generated IR in an assembly-like format. As with the other IR change, the use of the IR is statically disabled for now, so that existing users won't be affected. --- slang.h | 2 + source/slang/emit.cpp | 110 +++++-- source/slang/ir.cpp | 549 +++++++++++++++++++++++++++++++++- source/slang/ir.h | 253 +++++++++++----- source/slang/lower-to-ir.cpp | 552 ++++++++++++++++++++++++++++++----- source/slang/lower-to-ir.h | 2 +- 6 files changed, 1309 insertions(+), 159 deletions(-) diff --git a/slang.h b/slang.h index eefee9a070..2f6c112c0b 100644 --- a/slang.h +++ b/slang.h @@ -973,6 +973,7 @@ namespace slang #include "source/core/text-io.cpp" #include "source/slang/diagnostics.cpp" #include "source/slang/emit.cpp" +#include "source/slang/ir.cpp" #include "source/slang/lexer.cpp" #include "source/slang/name.cpp" #include "source/slang/options.cpp" @@ -982,6 +983,7 @@ namespace slang #include "source/slang/profile.cpp" #include "source/slang/lookup.cpp" #include "source/slang/lower.cpp" +#include "source/slang/lower-to-ir.cpp" #include "source/slang/check.cpp" #include "source/slang/compiler.cpp" #include "source/slang/slang-stdlib.cpp" diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 71ee31e5e7..3bec174566 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -2,6 +2,7 @@ #include "emit.h" #include "lower.h" +#include "lower-to-ir.h" #include "name.h" #include "syntax.h" #include "type-layout.h" @@ -519,7 +520,7 @@ struct EmitVisitor emitRawText("\n#line "); char buffer[16]; - sprintf(buffer, "%d", sourceLocation.line); + sprintf(buffer, "%llu", (unsigned long long)sourceLocation.line); emitRawText(buffer); emitRawText(" "); @@ -679,7 +680,7 @@ struct EmitVisitor // and how we do. if(sourceLocation.column > context->shared->loc.column) { - int delta = sourceLocation.column - context->shared->loc.column; + Slang::Int delta = sourceLocation.column - context->shared->loc.column; for( int ii = 0; ii < delta; ++ii ) { emitRawText(" "); @@ -3756,7 +3757,7 @@ struct EmitVisitor void EmitDecl(RefPtr decl) { - emitDeclImpl(decl, nullptr); +emitDeclImpl(decl, nullptr); } void EmitDeclUsingLayout(RefPtr decl, RefPtr layout) @@ -3770,9 +3771,9 @@ struct EmitVisitor { EmitDecl(decl); } - else if(auto declGroup = declBase.As()) + else if( auto declGroup = declBase.As() ) { - for(auto d : declGroup->decls) + for( auto d : declGroup->decls ) EmitDecl(d); } else @@ -3805,7 +3806,7 @@ String emitEntryPoint( { globalStructLayout = gs.Ptr(); } - else if(auto globalConstantBufferLayout = globalScopeLayout.As()) + else if( auto globalConstantBufferLayout = globalScopeLayout.As() ) { // TODO: the `cbuffer` case really needs to be emitted very // carefully, but that is beyond the scope of what a simple rewriter @@ -3837,32 +3838,95 @@ String emitEntryPoint( } sharedContext.globalStructLayout = globalStructLayout; + auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr(); + EmitContext context; context.shared = &sharedContext; EmitVisitor visitor(&context); - auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr(); - - // We perform lowering of the program before emitting *anything*, - // because the lowering process might change how we emit some - // boilerplate at the start of the ouput for GLSL (e.g., what - // version we require). - auto lowered = lowerEntryPoint(entryPoint, programLayout, target, &sharedContext.extensionUsageTracker); - sharedContext.program = lowered.program; - - // Note that we emit the main body code of the program *before* - // we emit any leading preprocessor directives for GLSL. - // This is to give the emit logic a change to make last-minute - // adjustments like changing the required GLSL version. + // Depending on how the compiler was invoked, we may need to perform + // some amount of preocessing on the code before we can emit it. + // + // For our purposes, there are basically three different "modes" we + // care about: + // + // 1. "Full rewriter" mode, where the user provides HLSL/GLSL, and + // doesn't make use of any Slang code via `import`. + // + // 2. "Partial rewriter" mode, where the user starts with HLSL/GLSL, + // but also imports some Slang code, and may need us to rewrite + // their HLSL/GLSL function bodies to make things work. + // + // 3. "Full" mode, where all of the input code is in Slang (and/or + // the subset of HLSL we can fully type-check). // - // TODO: All such adjustments would be better handled during - // lowering, but that requires having a semantic rather than - // textual format for the HLSL->GLSL mapping. - visitor.EmitDeclsInContainer(lowered.program.Ptr()); + // We'll try to detect the cases here: + // +#if 0 + if(!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING )) + { + // This seems to be case (3), because the user is asking for full + // checking, and so we can assume we understand the code fully. + // + // In this case we want to translate to our intermediate representation + // and do optimizations/transformations there before we emit final code. + // + + auto lowered = lowerEntryPointToIR(entryPoint, programLayout, target); + + dumpIR(lowered); + + throw 99; + + } + else if(translationUnit->compileRequest->loadedModulesList.Count() != 0) +#else + if(!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING ) + || translationUnit->compileRequest->loadedModulesList.Count() != 0) +#endif + { + // The user has `import`ed some Slang modules, and so we are in case (2) + // + // We need to apply a "rewriting" pass to the code the user wrote, + // and then emit the result. + + // We perform lowering of the program before emitting *anything*, + // because the lowering process might change how we emit some + // boilerplate at the start of the ouput for GLSL (e.g., what + // version we require). + auto lowered = lowerEntryPoint(entryPoint, programLayout, target, &sharedContext.extensionUsageTracker); + sharedContext.program = lowered.program; + + // Note that we emit the main body code of the program *before* + // we emit any leading preprocessor directives for GLSL. + // This is to give the emit logic a change to make last-minute + // adjustments like changing the required GLSL version. + // + // TODO: All such adjustments would be better handled during + // lowering, but that requires having a semantic rather than + // textual format for the HLSL->GLSL mapping. + visitor.EmitDeclsInContainer(lowered.program.Ptr()); + } + else + { + // We are in case (1). + // + // We should be able to just emit the AST we parsed right back out, + // along with whatever annotations we added along the way. + + sharedContext.program = translationUnitSyntax; + visitor.EmitDeclsInContainer(translationUnitSyntax); + } + String code = sharedContext.sb.ProduceString(); sharedContext.sb.Clear(); + // Now that we've emitted the code for all the declaratiosn in the file, + // it is time to stich together the final output. + + + // There may be global-scope modifiers that we should emit now visitor.emitGLSLPreprocessorDirectives(translationUnitSyntax); String prefix = sharedContext.sb.ProduceString(); diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 44eb7a727b..9e5169bd20 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -6,9 +6,341 @@ namespace Slang { +#define OP(ID, MNEMONIC, ARG_COUNT, FLAGS) \ + static const IROpInfo kIROpInfo_##ID { \ + #MNEMONIC, ARG_COUNT, FLAGS, } + +#define PARENT kIROpFlag_Parent + + OP(TypeType, type.type, 0, 0); + OP(VoidType, type.void, 0, 0); + OP(BlockType, type.block, 0, 0); + OP(VectorType, type.vector, 2, 0); + OP(BoolType, type.bool, 0, 0); + OP(Float32Type, type.f32, 0, 0); + OP(Int32Type, type.i32, 0, 0); + OP(UInt32Type, type.u32, 0, 0); + OP(StructType, type.struct, 0, 0); + + OP(IntLit, integer_constant, 0, 0); + OP(FloatLit, float_constant, 0, 0); + + OP(Construct, construct, 0, 0); + + OP(Module, module, 0, PARENT); + OP(Func, func, 0, PARENT); + OP(Block, block, 0, PARENT); + + OP(Param, param, 0, 0); + + OP(FieldExtract, get_field, 1, 0); + OP(ReturnVal, return_val, 1, 0); + OP(ReturnVoid, return_void, 1, 0); + +#define INTRINSIC(NAME) \ + static const IROpInfo kIROpInfo_Intrinsic_##NAME { \ + "intrinsic." #NAME, 0, 0, }; +#include "intrinsic-defs.h" + +#undef PARENT +#undef OP + + + static IROpInfo const* const kIRIntrinsicOpInfos[] = + { + nullptr, + +#define INTRINSIC(NAME) &kIROpInfo_Intrinsic_##NAME, +#include "intrinsic-defs.h" + + }; + + // + + void IRUse::init(IRValue* u, IRValue* v) + { + user = u; + usedValue = v; + + if(v) + { + nextUse = v->firstUse; + prevLink = &v->firstUse; + + v->firstUse = this; + } + } + + // + + IRUse* IRInst::getArgs() + { + return &type; + } + + // + + // Add an instruction into the current scope + static void addInst( + IRBuilder* builder, + IRInst* inst) + { + auto parent = builder->parentInst; + if (!parent) + return; + + inst->parent = parent; + + if (!parent->firstChild) + { + inst->prevInst = nullptr; + inst->nextInst = nullptr; + + parent->firstChild = inst; + parent->lastChild = inst; + } + else + { + auto prev = parent->lastChild; + + inst->prevInst = prev; + inst->nextInst = nullptr; + + prev->nextInst = inst; + parent->lastChild = inst; + } + } + + // Create an IR instruction/value and initialize it. + // + // In this case `argCount` and `args` represnt the + // arguments *after* the type (which is a mandatory + // argument for all instructions). + static IRValue* createInstImpl( + IRBuilder* builder, + UInt size, + IROpInfo const* op, + IRType* type, + UInt argCount, + IRValue* const* args) + { + IRValue* inst = (IRInst*) malloc(size); + memset(inst, 0, size); + + IRUse* instArgs = inst->getArgs(); + + auto module = builder->module; + if (!module || (type && type->op == &kIROpInfo_VoidType)) + { + // Can't or shouldn't assign an ID to this op + } + else + { + inst->id = ++module->idCounter; + } + + inst->op = op; + + inst->type.init(inst, type); + + for( UInt aa = 0; aa < argCount; ++aa ) + { + instArgs[aa+1].init(inst, args[aa]); + } + + addInst(builder, inst); + + return inst; + } + + // Create an IR instruction/value and initialize it. + // + // For this overload, the type of the instruction is + // folded into the argument list (so `args[0]` needs + // to be the type of the instruction) + static IRValue* createInstImpl( + IRBuilder* builder, + UInt size, + IROpInfo const* op, + UInt argCount, + IRValue* const* args) + { + return createInstImpl( + builder, + size, + op, + (IRType*) args[0], + argCount - 1, + args + 1); + } + + template + static T* createInst( + IRBuilder* builder, + IROpInfo const* op, + IRType* type, + UInt argCount, + IRValue* const* args) + { + return (T*)createInstImpl( + builder, + sizeof(T), + op, + type, + argCount, + args); + } + + template + static T* createInst( + IRBuilder* builder, + IROpInfo const* op, + IRType* type) + { + return (T*)createInstImpl( + builder, + sizeof(T), + op, + type, + 0, + nullptr); + } + + template + static T* createInst( + IRBuilder* builder, + IROpInfo const* op, + IRType* type, + IRValue* arg) + { + return (T*)createInstImpl( + builder, + sizeof(T), + op, + type, + 1, + &arg); + } + + template + static T* createInst( + IRBuilder* builder, + IROpInfo const* op, + IRType* type, + IRValue* arg1, + IRValue* arg2) + { + IRValue* args[] = { arg1, arg2 }; + return (T*)createInstImpl( + builder, + sizeof(T), + op, + type, + 2, + &args[0]); + } + + template + static T* createValueWithTrailingArgs( + IRBuilder* builder, + IROpInfo const* op, + IRType* type, + UInt argCount, + IRValue* const* args) + { + return (T*)createInstImpl( + builder, + sizeof(T) + argCount * sizeof(IRUse), + op, + type, + argCount, + args); + } + + + + // + + static IRType* getBaseTypeImpl(IRBuilder* builder, IROpInfo const* op) + { + return createInst( + builder, + op, + builder->getTypeType()); + } + + IRType* IRBuilder::getBaseType(BaseType flavor) + { + switch( flavor ) + { + case BaseType::Bool: return getBaseTypeImpl(this, &kIROpInfo_BoolType); + case BaseType::Float: return getBaseTypeImpl(this, &kIROpInfo_Float32Type); + case BaseType::Int: return getBaseTypeImpl(this, &kIROpInfo_Int32Type); + case BaseType::UInt: return getBaseTypeImpl(this, &kIROpInfo_UInt32Type); + + default: + SLANG_UNEXPECTED("unhandled base type"); + return nullptr; + } + } + IRType* IRBuilder::getBoolType() { - SLANG_UNIMPLEMENTED_X("IR"); + return getBaseType(BaseType::Bool); + } + + IRType* IRBuilder::getVectorType(IRType* elementType, IRValue* elementCount) + { + // TODO: should unique things + return createInst( + this, + &kIROpInfo_VectorType, + getTypeType(), + elementType, + elementCount); + } + + IRType* IRBuilder::getTypeType() + { + // TODO: should unique things + IRType* type = createInst( + this, + &kIROpInfo_TypeType, + nullptr); + + // TODO: we need some way to stop this recursion, + // but just saying that `Type isa Type` is unfounded. + type->type.init(type, type); + + return type; + } + + IRType* IRBuilder::getVoidType() + { + return createInst( + this, + &kIROpInfo_VoidType, + getTypeType()); + } + + IRType* IRBuilder::getBlockType() + { + return createInst( + this, + &kIROpInfo_BlockType, + getTypeType()); + } + + IRType* IRBuilder::getStructType( + UInt fieldCount, + IRType* const* fieldTypes) + { + return createValueWithTrailingArgs( + this, + &kIROpInfo_StructType, + getTypeType(), + fieldCount, + (IRValue* const*)fieldTypes); } IRValue* IRBuilder::getBoolValue(bool value) @@ -18,7 +350,14 @@ namespace Slang IRValue* IRBuilder::getIntValue(IRType* type, IRIntegerValue value) { - SLANG_UNIMPLEMENTED_X("IR"); + IRIntLit* val = createInst( + this, + &kIROpInfo_IntLit, + type); + + val->value = value; + + return val; } IRValue* IRBuilder::getFloatValue(IRType* type, IRFloatingPointValue value) @@ -26,4 +365,210 @@ namespace Slang SLANG_UNIMPLEMENTED_X("IR"); } + IRInst* IRBuilder::emitIntrinsicInst( + IRType* type, + IntrinsicOp intrinsicOp, + UInt argCount, + IRValue* const* args) + { + return createValueWithTrailingArgs( + this, + kIRIntrinsicOpInfos[(int)intrinsicOp], + type, + argCount, + args); + } + + IRInst* IRBuilder::emitConstructorInst( + IRType* type, + UInt argCount, + IRValue* const* args) + { + return createValueWithTrailingArgs( + this, + &kIROpInfo_Construct, + type, + argCount, + args); + } + + IRModule* IRBuilder::createModule() + { + return createInst( + this, + &kIROpInfo_Module, + nullptr); + } + + + IRFunc* IRBuilder::createFunc() + { + return createInst( + this, + &kIROpInfo_Func, + nullptr); + } + + IRBlock* IRBuilder::createBlock() + { + return createInst( + this, + &kIROpInfo_Block, + getBlockType()); + } + + IRParam* IRBuilder::createParam( + IRType* type) + { + return createInst( + this, + &kIROpInfo_Param, + type); + } + + IRInst* IRBuilder::createFieldExtract( + IRType* type, + IRValue* base, + UInt fieldIndex) + { + IRFieldExtract* irInst = createInst( + this, + &kIROpInfo_FieldExtract, + type, + base); + + irInst->fieldIndex = fieldIndex; + + return irInst; + } + + IRInst* IRBuilder::createReturn( + IRValue* val) + { + return createInst( + this, + &kIROpInfo_ReturnVal, + getVoidType(), + val); + } + + IRInst* IRBuilder::createReturn() + { + return createInst( + this, + &kIROpInfo_ReturnVoid, + getVoidType()); + } + + struct IRDumpContext + { + FILE* file; + int indent; + }; + + static void dump( + IRDumpContext* context, + char const* text) + { + fprintf(context->file, "%s", text); + } + + static void dump( + IRDumpContext* context, + UInt val) + { + fprintf(context->file, "%llu", (unsigned long long)val); + } + + static void dumpIndent( + IRDumpContext* context) + { + for (int ii = 0; ii < context->indent; ++ii) + { + dump(context, " "); + } + } + + static void dumpID( + IRDumpContext* context, + IRInst* inst) + { + if (!inst) + { + dump(context, ""); + } + else + { + dump(context, "%"); + dump(context, inst->id); + } + } + + static void dumpInst( + IRDumpContext* context, + IRInst* inst) + { + dumpIndent(context); + if (!inst) + { + dump(context, ""); + } + + // TODO: need to display a name for the result... + + auto op = inst->op; + + if (inst->id) + { + dumpID(context, inst); + dump(context, " = "); + } + + dump(context, op->name); + + // TODO: dump operands + unsigned argCount = op->fixedArgCount + 1; + for (unsigned ii = 0; ii < argCount; ++ii) + { + if (ii != 0) + dump(context, ", "); + else + { + dump(context, " "); + } + + auto argVal = inst->getArgs()[ii].usedValue; + + // TODO: actually print the damn operand... + + dumpID(context, argVal); + } + + dump(context, "\n"); + + if (op->flags & kIROpFlag_Parent) + { + dumpIndent(context); + dump(context, "{\n"); + context->indent++; + auto parent = (IRParentInst*)inst; + for (auto ii = parent->firstChild; ii; ii = ii->nextInst) + { + dumpInst(context, ii); + } + context->indent--; + dumpIndent(context); + dump(context, "}\n"); + } + } + + void dumpIR(IRModule* module) + { + IRDumpContext context; + context.file = stderr; + context.indent = 0; + + dumpInst(&context, module); + } + } diff --git a/source/slang/ir.h b/source/slang/ir.h index db46235b36..95dec6007e 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -7,125 +7,248 @@ // similar in spirit to LLVM (but much simpler). // +// We need the definition of `BaseType` which currently belongs to the AST +#include "syntax.h" + namespace Slang { -struct IRBlock; struct IRFunc; +struct IRInst; struct IRModule; +struct IRParentInst; struct IRType; -// A value that can be referenced in the program. -struct IRValue +typedef unsigned int IROpFlags; +enum : IROpFlags +{ + kIROpFlags_None = 0, + + // This op is a parent op + kIROpFlag_Parent = 1 << 0, +}; + +// A logical operation/opcode in the IR +struct IROpInfo +{ + // What is the name/mnemonic for this operation + char const* name; + + // How many required arguments are there + // (not including the mandatory type argument) + unsigned int fixedArgCount; + + // Flags to control how we emit additional info + IROpFlags flags; +}; + +// A use of another value/inst within an IR operation +struct IRUse +{ + // The value that is doing the using. + IRInst* user; + + // The value that is being used + IRInst* usedValue; + + // The next use of the same value + IRUse* nextUse; + + // A "link" back to where this use is referenced, + // so that we can simplify updates. + IRUse** prevLink; + + void init(IRInst* user, IRInst* usedValue); +}; + +// In the IR, almost *everything* is an instruction, +// in order to make the representation as uniform as possible. +struct IRInst +{ + // The operation that this value represents + IROpInfo const* op; + + // A unique ID to represent the op when printing + // (or zero to indicate that the value of this + // op isn't special). + UInt id; + + // The parent of this instruction. + // This will often be a basic block, but we + // allow instructions to nest in more general ways. + IRParentInst* parent; + + // The next and previous instructions in the same parent block + IRInst* nextInst; + IRInst* prevInst; + + // The first use of this value (start of a linked list) + IRUse* firstUse; + + // The type of this value + IRUse type; + + IRUse* getArgs(); +}; + +typedef IRInst IRValue; + +typedef long long IRIntegerValue; +typedef double IRFloatingPointValue; + +struct IRIntLit : IRInst +{ + IRIntegerValue value; +}; + +struct IRFloatLit : IRInst { - // Type type of this value - IRType* type; + IRFloatingPointValue value; }; // Representation of a type at the IR level. // Such a type may not correspond to the high-level-language notion // of a type as used by the front end. // -// Note that types are values in the IR, so that operations +// Note that types are instructions in the IR, so that operations // may take type operands as easily as values. -struct IRType : IRValue +struct IRType : IRInst { }; -// An instruction in the program. -struct IRInst : IRValue +struct IRVectorType : IRType { - // The basic block that contains this instruction, - // or NULL if the instruction currently has no parent. - IRBlock* parentBlock; + IRUse elementType; + IRUse elementCount; +}; - // The next and previous instructions in the same parent block - IRInst* nextInst; - IRInst* prevInst; +struct IRStructType : IRType +{}; + +struct IRFieldExtract : IRInst +{ + IRUse base; + UInt fieldIndex; }; // A instruction that ends a basic block (usually because of control flow) struct IRTerminatorInst : IRInst {}; -// A basic block, consisting of a sequence of instructions that can only -// be entered at the top, and can only be exited at the last instruction. -// -// Note that a block is itself a value, so that it can be a direct operand -// of an instruction (e.g., an instruction that branches to the block) -struct IRBlock : IRValue +struct IRReturn : IRTerminatorInst +{}; + +struct IRReturnVal : IRReturn { - // The function that contains this block - IRFunc* parentFunc; + IRUse val; +}; - // The first and last instruction in the block (or NULL in - // the case that the block is empty). +struct IRReturnVoid : IRReturn +{}; + +// A parent instruction contains a sequence of other instructions +// +struct IRParentInst : IRInst +{ + // The first and last instruction in the container (or NULL in + // the case that the container is empty). // + IRInst* firstChild; + IRInst* lastChild; +}; + +// A basic block is a parent instruction that adds the constraint +// that all the children need to be "ordinary" instructions (so +// no function declarations, or nested blocks). We also expect +// that the previous/next instruction are always a basic block. +// +struct IRBlock : IRParentInst +{ // Note that in a valid program, every block must end with // a "terminator" instruction, so these should be non-NULL, // and `last` should actually be an `IRTerminatorInst`. - IRInst* first; - IRInst* last; - - // Next and previous block in the same function - IRBlock* nextBlock; - IRBlock* prevBlock; + IRInst* firstChild; + IRInst* lastChild; }; -// A function parameter. -struct IRParam : IRValue +// A function parameter is represented by an instruction +// in the entry block of a function. +struct IRParam : IRInst { - // The function that declared this parameter - IRFunc* parentFunc; - - // The next and previous parameter of the function - IRParam* nextParam; - IRParam* prevParam; - }; -// A function, which consists of zero or more blocks of instructions. +// A function is a parent to zero or more blocks of instructions. // // A function is itself a value, so that it can be a direct operand of // an instruction (e.g., a call). -struct IRFunc : IRValue +struct IRFunc : IRParentInst { - // The IR module that defines this function - IRModule* parentModule; - - // The unique entry block for the function is always the - // first block in the list of blocks. - IRBlock* firstBlock; - - // The last block in the function. - IRBlock* lastBlock; - - // The parameters of the function - IRParam* firstParam; - IRParam* lastParam; - - // The next/previous function in the same IR module - IRFunc* nextFunc; - IRFunc* prevFunc; }; -// A module defining global values -struct IRModule +// A module is a parent to functions, global variables, types, etc. +struct IRModule : IRParentInst { -}; - -typedef long long IRIntegerValue; -typedef double IRFloatingPointValue; + // The designated entry-point function, if any + IRFunc* entryPoint; + // A special counter used to assign logical ids to instructions in this module. + UInt idCounter; +}; struct IRBuilder { + // The module that will own all of the IR + IRModule* module; + + // The parent instruction to add children to. + IRParentInst* parentInst; + + IRType* getBaseType(BaseType flavor); IRType* getBoolType(); + IRType* getVectorType(IRType* elementType, IRValue* elementCount); + IRType* getTypeType(); + IRType* getVoidType(); + IRType* getBlockType(); + IRType* getStructType( + UInt fieldCount, + IRType* const* fieldTypes); IRValue* getBoolValue(bool value); IRValue* getIntValue(IRType* type, IRIntegerValue value); IRValue* getFloatValue(IRType* type, IRFloatingPointValue value); + + IRInst* emitIntrinsicInst( + IRType* type, + IntrinsicOp intrinsicOp, + UInt argCount, + IRValue* const* args); + + IRInst* emitConstructorInst( + IRType* type, + UInt argCount, + IRValue* const* args); + + IRModule* createModule(); + + IRFunc* createFunc(); + + IRBlock* createBlock(); + + IRParam* createParam( + IRType* type); + + IRInst* createFieldExtract( + IRType* type, + IRValue* base, + UInt fieldIndex); + + IRInst* createReturn( + IRValue* val); + + IRInst* createReturn(); }; +void dumpIR(IRModule* module); + } #endif diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index bee3edb16e..e00dffa1d4 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -8,51 +8,74 @@ namespace Slang { -struct SharedIRGenContext -{ - EntryPointRequest* entryPoint; - ProgramLayout* programLayout; - CodeGenTarget target; -}; - -struct LoweredExprInfo +struct LoweredValInfo { enum class Flavor { - Value, + None, + Simple, }; - static LoweredExprInfo createValue(IRValue* value) + union { - LoweredExprInfo result; - result.flavor = Flavor::Value; - result.value = value; - return result; + IRValue* val; + }; + Flavor flavor; + + LoweredValInfo() + { + flavor = Flavor::None; + val = nullptr; } - Flavor flavor; - union + static LoweredValInfo simple(IRValue* v) { - IRValue* value; - }; + LoweredValInfo info; + info.flavor = Flavor::Simple; + info.val = v; + return info; + } }; +struct SharedIRGenContext +{ + EntryPointRequest* entryPoint; + ProgramLayout* programLayout; + CodeGenTarget target; + + Dictionary, LoweredValInfo> declValues; +}; + + struct IRGenContext { - Dictionary declValues; + SharedIRGenContext* shared; IRBuilder* irBuilder; }; -struct LoweredValInfo +IRValue* getSimpleVal(LoweredValInfo lowered) { -}; + switch(lowered.flavor) + { + case LoweredValInfo::Flavor::None: + return nullptr; + + case LoweredValInfo::Flavor::Simple: + return lowered.val; + + default: + SLANG_UNEXPECTED("unhandled value flavor"); + return nullptr; + } +} struct LoweredTypeInfo { enum class Flavor { - Type, + None, + Simple, }; union @@ -60,8 +83,47 @@ struct LoweredTypeInfo IRType* type; }; Flavor flavor; + + LoweredTypeInfo() + { + flavor = Flavor::None; + } + + LoweredTypeInfo(IRType* t) + { + flavor = Flavor::Simple; + type = t; + } }; +IRType* getSimpleType(LoweredTypeInfo lowered) +{ + switch(lowered.flavor) + { + case LoweredTypeInfo::Flavor::None: + return nullptr; + + case LoweredTypeInfo::Flavor::Simple: + return lowered.type; + + default: + SLANG_UNEXPECTED("unhandled value flavor"); + return nullptr; + } +} + +LoweredValInfo lowerVal( + IRGenContext* context, + Val* val); + +IRValue* lowerSimpleVal( + IRGenContext* context, + Val* val) +{ + auto lowered = lowerVal(context, val); + return getSimpleVal(lowered); +} + LoweredTypeInfo lowerType( IRGenContext* context, Type* type); @@ -73,28 +135,105 @@ static LoweredTypeInfo lowerType( return lowerType(context, type.type); } -LoweredExprInfo lowerExpr( +// Lower a type and expect the result to be simple +IRType* lowerSimpleType( + IRGenContext* context, + Type* type) +{ + auto lowered = lowerType(context, type); + return getSimpleType(lowered); +} + +IRType* lowerSimpleType( + IRGenContext* context, + QualType const& type) +{ + auto lowered = lowerType(context, type); + return getSimpleType(lowered); +} + + +LoweredValInfo lowerExpr( IRGenContext* context, Expr* expr); +void lowerStmt( + IRGenContext* context, + Stmt* stmt); + +LoweredValInfo ensureDecl( + IRGenContext* context, + DeclRef const& declRef); + // struct ValLoweringVisitor : ValVisitor { IRGenContext* context; + IRBuilder* getBuilder() { return context->irBuilder; } + LoweredValInfo visitVal(Val* val) { SLANG_UNIMPLEMENTED_X("value lowering"); } + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) + { + // TODO: it is a bit messy here that the `ConstantIntVal` representation + // has no notion of a *type* associated with the value... + + auto type = getBuilder()->getBaseType(BaseType::Int); + return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); + } + LoweredTypeInfo visitType(Type* type) { SLANG_UNIMPLEMENTED_X("type lowering"); } + LoweredTypeInfo visitDeclRefType(DeclRefType* type) + { + // Catch-all for user-defined type references + LoweredValInfo loweredDeclRef = ensureDecl(context, type->declRef); + + // TODO: make sure that the value is actually a type... + + switch (loweredDeclRef.flavor) + { + case LoweredValInfo::Flavor::Simple: + return LoweredTypeInfo((IRType*)loweredDeclRef.val); + + default: + SLANG_UNIMPLEMENTED_X("type lowering"); + } + + } + + LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type) + { + return getBuilder()->getBaseType(type->BaseType); + } + + LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type) + { + auto irElementType = lowerSimpleType(context, type->elementType); + auto irElementCount = lowerSimpleVal(context, type->elementCount); + + return getBuilder()->getVectorType(irElementType, irElementCount); + } + }; +LoweredValInfo lowerVal( + IRGenContext* context, + Val* val) +{ + ValLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(val); +} + LoweredTypeInfo lowerType( IRGenContext* context, Type* type) @@ -115,52 +254,40 @@ struct LoweringVisitor // -struct ExprLoweringVisitor : ExprVisitor +struct ExprLoweringVisitor : ExprVisitor { IRGenContext* context; - LoweredExprInfo visitVarExpr(VarExpr* expr) - { - LoweredExprInfo info; - if(context->declValues.TryGetValue(expr->declRef.getDecl(), info)) - return info; - - throw 99; + IRBuilder* getBuilder() { return context->irBuilder; } - return LoweredExprInfo(); + LoweredValInfo visitVarExpr(VarExpr* expr) + { + LoweredValInfo info = ensureDecl(context, expr->declRef); + return info; } - LoweredExprInfo visitOverloadedExpr(OverloadedExpr* expr) + LoweredValInfo visitOverloadedExpr(OverloadedExpr* expr) { SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); } - LoweredExprInfo visitInitializerListExpr(InitializerListExpr* expr) + LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for initializer list expression"); } - IRType* getIRType(LoweredTypeInfo const& typeInfo) + LoweredValInfo visitConstantExpr(ConstantExpr* expr) { - switch( typeInfo.flavor ) - { - case LoweredTypeInfo::Flavor::Type: - return typeInfo.type; - } - } - - LoweredExprInfo visitConstantExpr(ConstantExpr* expr) - { - auto type = getIRType(lowerType(context, expr->type)); + auto type = lowerSimpleType(context, expr->type); switch( expr->ConstType ) { case ConstantExpr::ConstantType::Bool: - return LoweredExprInfo::createValue(context->irBuilder->getBoolValue(expr->integerValue != 0)); + return LoweredValInfo::simple(context->irBuilder->getBoolValue(expr->integerValue != 0)); case ConstantExpr::ConstantType::Int: - return LoweredExprInfo::createValue(context->irBuilder->getIntValue(type, expr->integerValue)); + return LoweredValInfo::simple(context->irBuilder->getIntValue(type, expr->integerValue)); case ConstantExpr::ConstantType::Float: - return LoweredExprInfo::createValue(context->irBuilder->getFloatValue(type, expr->floatingPointValue)); + return LoweredValInfo::simple(context->irBuilder->getFloatValue(type, expr->floatingPointValue)); case ConstantExpr::ConstantType::String: break; } @@ -168,68 +295,198 @@ struct ExprLoweringVisitor : ExprVisitor SLANG_UNEXPECTED("unexpected constant type"); } - LoweredExprInfo visitAggTypeCtorExpr(AggTypeCtorExpr* expr) + LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression"); } - LoweredExprInfo visitInvokeExpr(InvokeExpr* expr) + void addArgs(List* ioArgs, LoweredValInfo argInfo) + { + auto& args = *ioArgs; + switch( argInfo.flavor ) + { + case LoweredValInfo::Flavor::Simple: + args.Add(getSimpleVal(argInfo)); + break; + + default: + SLANG_UNIMPLEMENTED_X("addArgs case"); + break; + } + } + + LoweredValInfo lowerIntrinsicCall( + InvokeExpr* expr, + IntrinsicOp intrinsicOp) + { + auto type = lowerSimpleType(context, expr->type); + + List irArgs; + for( auto arg : expr->Arguments ) + { + auto loweredArg = lowerExpr(context, arg); + addArgs(&irArgs, loweredArg); + } + + UInt argCount = irArgs.Count(); + + return LoweredValInfo::simple(getBuilder()->emitIntrinsicInst(type, intrinsicOp, argCount, &irArgs[0])); + } + + LoweredValInfo lowerSimpleCall(InvokeExpr* expr) { + auto loweredFunc = lowerExpr(context, expr->FunctionExpr); + + for( auto arg : expr->Arguments ) + { + auto loweredArg = lowerExpr(context, arg); + } + SLANG_UNIMPLEMENTED_X("codegen for invoke expression"); } - LoweredExprInfo visitIndexExpr(IndexExpr* expr) + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) + { + // TODO: need to detect calls to builtins here, so that we can expand + // them as their own special opcodes... + + auto funcExpr = expr->FunctionExpr; + if( auto funcDeclRefExpr = funcExpr.As() ) + { + auto funcDeclRef = funcDeclRefExpr->declRef; + auto funcDecl = funcDeclRef.getDecl(); + if(auto intrinsicOpModifier = funcDecl->FindModifier()) + { + return lowerIntrinsicCall(expr, intrinsicOpModifier->op); + // + } + // TODO: handle target intrinsic modifier too... + + if( auto ctorDeclRef = funcDeclRef.As() ) + { + // HACK: we know all constructors are builtins for now, + // so we need to emit them as a call to the corresponding + // builtin operation. + + auto type = lowerSimpleType(context, expr->type); + + List irArgs; + for( auto arg : expr->Arguments ) + { + auto loweredArg = lowerExpr(context, arg); + addArgs(&irArgs, loweredArg); + } + + UInt argCount = irArgs.Count(); + + return LoweredValInfo::simple(getBuilder()->emitConstructorInst(type, argCount, &irArgs[0])); + } + } + + return lowerSimpleCall(expr); + } + + LoweredValInfo visitIndexExpr(IndexExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); } - LoweredExprInfo visitMemberExpr(MemberExpr* expr) + LoweredValInfo extractField( + LoweredTypeInfo fieldType, + LoweredValInfo base, + UInt fieldIndex) + { + switch (base.flavor) + { + case LoweredValInfo::Flavor::Simple: + { + IRValue* irBase = base.val; + return LoweredValInfo::simple( + getBuilder()->createFieldExtract( + getSimpleType(fieldType), + irBase, + fieldIndex)); + } + break; + + default: + SLANG_UNIMPLEMENTED_X("codegen for field extract"); + } + } + + LoweredValInfo visitMemberExpr(MemberExpr* expr) { + auto loweredType = lowerType(context, expr->type); + auto loweredBase = lowerExpr(context, expr->BaseExpression); + + auto declRef = expr->declRef; + if (auto fieldDeclRef = declRef.As()) + { + // Okay, easy enough: we have a reference to a field of a struct type... + + // HACK: for now just scan the decl to find the right index. + // TODO: we need to deal with the fact that the struct might get + // tuple-ified. + // + UInt index = 0; + for (auto fieldDecl : getMembersOfType(fieldDeclRef.GetParent().As())) + { + if (fieldDecl == fieldDeclRef.getDecl()) + { + break; + } + + index++; + } + + return extractField(loweredType, loweredBase, index); + } + SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); } - LoweredExprInfo visitSwizzleExpr(SwizzleExpr* expr) + LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for swizzle expression"); } - LoweredExprInfo visitDerefExpr(DerefExpr* expr) + LoweredValInfo visitDerefExpr(DerefExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for deref expression"); } - LoweredExprInfo visitTypeCastExpr(TypeCastExpr* expr) + LoweredValInfo visitTypeCastExpr(TypeCastExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for type cast expression"); } - LoweredExprInfo visitSelectExpr(SelectExpr* expr) + LoweredValInfo visitSelectExpr(SelectExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for select expression"); } - LoweredExprInfo visitGenericAppExpr(GenericAppExpr* expr) + LoweredValInfo visitGenericAppExpr(GenericAppExpr* expr) { SLANG_UNIMPLEMENTED_X("generic application expression during code generation"); } - LoweredExprInfo visitSharedTypeExpr(SharedTypeExpr* expr) + LoweredValInfo visitSharedTypeExpr(SharedTypeExpr* expr) { SLANG_UNIMPLEMENTED_X("shared type expression during code generation"); } - LoweredExprInfo visitAssignExpr(AssignExpr* expr) + LoweredValInfo visitAssignExpr(AssignExpr* expr) { SLANG_UNIMPLEMENTED_X("shared type expression during code generation"); } - LoweredExprInfo visitParenExpr(ParenExpr* expr) + LoweredValInfo visitParenExpr(ParenExpr* expr) { return lowerExpr(context, expr->base); } }; -LoweredExprInfo lowerExpr( +LoweredValInfo lowerExpr( IRGenContext* context, Expr* expr) { @@ -238,25 +495,141 @@ LoweredExprInfo lowerExpr( return visitor.dispatch(expr); } -struct LoweredDeclInfo -{}; +struct StmtLoweringVisitor : StmtVisitor +{ + IRGenContext* context; + + IRBuilder* getBuilder() { return context->irBuilder; } -struct DeclLoweringVisitor : DeclVisitor + void visitStmt(Stmt* stmt) + { + SLANG_UNIMPLEMENTED_X("stmt catch-all"); + } + + void visitBlockStmt(BlockStmt* stmt) + { + lowerStmt(context, stmt->body); + } + + void visitReturnStmt(ReturnStmt* stmt) + { + if( auto expr = stmt->Expression ) + { + auto loweredExpr = lowerExpr(context, expr); + + getBuilder()->createReturn(getSimpleVal(loweredExpr)); + } + else + { + getBuilder()->createReturn(); + } + } +}; + +void lowerStmt( + IRGenContext* context, + Stmt* stmt) +{ + StmtLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(stmt); +} + +struct DeclLoweringVisitor : DeclVisitor { IRGenContext* context; - LoweredDeclInfo visitDeclBase(DeclBase* decl) + IRBuilder* getBuilder() + { + return context->irBuilder; + } + + LoweredValInfo visitDeclBase(DeclBase* decl) { SLANG_UNIMPLEMENTED_X("decl catch-all"); } - LoweredDeclInfo visitDecl(Decl* decl) + LoweredValInfo visitDecl(Decl* decl) { SLANG_UNIMPLEMENTED_X("decl catch-all"); } + + LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) + { + // User-defined aggregate type: need to translate into + // a corresponding IR aggregate type. + + List fieldTypes; + List irFieldTypes; + + for (auto fieldDecl : decl->GetFields()) + { + // TODO: need to be prepared to deal with tuple-ness of fields here + auto fieldType = lowerType(context, fieldDecl->getType()); + + fieldTypes.Add(fieldType); + + switch (fieldType.flavor) + { + case LoweredTypeInfo::Flavor::Simple: + irFieldTypes.Add(fieldType.type); + break; + + default: + SLANG_UNIMPLEMENTED_X("struct field type"); + } + } + + // TODO: need to track relationship to original fields... + + IRType* irStructType = getBuilder()->getStructType( + irFieldTypes.Count(), + &irFieldTypes[0]); + + return LoweredValInfo::simple(irStructType); + } + + LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) + { + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + + // need to create an IR function here + + IRFunc* irFunc = subBuilder->createFunc(); + subBuilder->parentInst = irFunc; + + IRBlock* entryBlock = subBuilder->createBlock(); + subBuilder->parentInst = entryBlock; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->irBuilder = subBuilder; + + // set up sub context for generating our new function + + for( auto paramDecl : decl->GetParameters() ) + { + IRType* irParamType = lowerSimpleType(context, paramDecl->getType()); + IRParam* irParam = subBuilder->createParam(irParamType); + + DeclRef paramDeclRef = makeDeclRef(paramDecl.Ptr()); + + LoweredValInfo irParamVal = LoweredValInfo::simple(irParam); + + subContext->shared->declValues.Add(paramDeclRef, irParamVal); + } + + auto irResultType = lowerType(context, decl->ReturnType); + + + lowerStmt(subContext, decl->Body); + + return LoweredValInfo::simple(irFunc); + } }; -LoweredDeclInfo lowerDecl( +LoweredValInfo lowerDecl( IRGenContext* context, Decl* decl) { @@ -265,6 +638,30 @@ LoweredDeclInfo lowerDecl( return visitor.dispatch(decl); } +LoweredValInfo ensureDecl( + IRGenContext* context, + DeclRef const& declRef) +{ + auto shared = context->shared; + + LoweredValInfo result; + if(shared->declValues.TryGetValue(declRef, result)) + return result; + + // TODO: this is where we need to apply any specializations + // from the declaration reference, so that they can be + // applied correctly to the declaration itself... + + IRGenContext subContext = *context; + + result = lowerDecl(context, declRef.getDecl()); + + shared->declValues[declRef] = result; + + return result; +} + + EntryPointLayout* findEntryPointLayout( SharedIRGenContext* shared, EntryPointRequest* entryPointRequest) @@ -299,7 +696,7 @@ static void lowerEntryPointToIR( lowerDecl(context, entryPointFunc); } -void lowerEntryPointToIR( +IRModule* lowerEntryPointToIR( EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target) @@ -307,13 +704,32 @@ void lowerEntryPointToIR( SharedIRGenContext sharedContextStorage; SharedIRGenContext* sharedContext = &sharedContextStorage; + sharedContext->entryPoint = entryPoint; + sharedContext->programLayout = programLayout; + sharedContext->target = target; + IRGenContext contextStorage; IRGenContext* context = &contextStorage; + context->shared = sharedContext; + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->module = nullptr; + builder->parentInst = nullptr; + + IRModule* module = builder->createModule(); + builder->module = module; + builder->parentInst = module; + + context->irBuilder = builder; + auto entryPointLayout = findEntryPointLayout(sharedContext, entryPoint); lowerEntryPointToIR(context, entryPoint, entryPointLayout); + return module; + } } diff --git a/source/slang/lower-to-ir.h b/source/slang/lower-to-ir.h index b8e0f1a675..aa2cef6311 100644 --- a/source/slang/lower-to-ir.h +++ b/source/slang/lower-to-ir.h @@ -19,7 +19,7 @@ namespace Slang struct ExtensionUsageTracker; - void lowerEntryPointToIR( + IRModule* lowerEntryPointToIR( EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target); From e30ba2f6b7ad346fa5f2d435a9edc9ba1c56efab Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Wed, 16 Aug 2017 16:04:09 -0700 Subject: [PATCH 2/2] Fixups for IR checkpoint - The changes introduced a new path where we don't even go through the current "lowering" (really an AST-to-AST legalization pass), but this exposed a few issues I didn't anticipate: - First, we needed to make sure to pass in the computed layout information when emitting the original program (since the layout info is no longer automatically attached to AST nodes) - Second, we needed to take the sample-rate input checks that were being done in lowering before, and move them to the emit logic (which is really ugly, but I don't see a way around it for GLSL). --- source/slang/emit.cpp | 75 +++++++++++++++++++++++++++++++++++++++++- source/slang/lower.cpp | 39 ---------------------- 2 files changed, 74 insertions(+), 40 deletions(-) diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 3bec174566..c8f3f06cbb 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -48,6 +48,9 @@ struct SharedEmitContext // The entry point we are being asked to compile EntryPointRequest* entryPoint; + // The layout for the entry point + EntryPointLayout* entryPointLayout; + // The target language we want to generate code for CodeGenTarget target; @@ -2283,8 +2286,32 @@ struct EmitVisitor emitName(expr->lookupResult2.getName()); } + void setSampleRateFlag() + { + context->shared->entryPointLayout->flags |= EntryPointLayout::Flag::usesAnySampleRateInput; + } + + void doSampleRateInputCheck(VarDeclBase* decl) + { + if (decl->HasModifier()) + { + setSampleRateFlag(); + } + } + + void doSampleRateInputCheck(Name* name) + { + auto text = getText(name); + if (text == "gl_SampleID") + { + setSampleRateFlag(); + } + } + void visitVarExpr(VarExpr* varExpr, ExprEmitArg const& arg) { + doSampleRateInputCheck(varExpr->name); + auto prec = kEOp_Atomic; auto outerPrec = arg.outerPrec; bool needClose = MaybeEmitParens(outerPrec, kEOp_Atomic); @@ -2485,6 +2512,11 @@ struct EmitVisitor Emit("{\n"); for( auto& token : stmt->tokens ) { + if (token.type == TokenType::Identifier) + { + doSampleRateInputCheck(token.getName()); + } + emitTokenWithLocation(token); } Emit("}\n"); @@ -3560,6 +3592,15 @@ struct EmitVisitor void visitVarDeclBase(RefPtr decl, DeclEmitArg const& arg) { + // Global variable? Check if it is a sample-rate input. + if (dynamic_cast(decl->ParentDecl)) + { + if (decl->HasModifier()) + { + doSampleRateInputCheck(decl); + } + } + // Skip fields that have been tuple-ified and don't contribute // any fields of "ordinary" type. if (auto tupleFieldMod = decl->FindModifier()) @@ -3783,6 +3824,29 @@ emitDeclImpl(decl, nullptr); } }; + +EntryPointLayout* findEntryPointLayout( + ProgramLayout* programLayout, + EntryPointRequest* entryPointRequest) +{ + for( auto entryPointLayout : programLayout->entryPoints ) + { + if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + continue; + + if(entryPointLayout->profile != entryPointRequest->profile) + continue; + + // TODO: can't easily filter on translation unit here... + // Ideally the `EntryPointRequest` should get filled in with a pointer + // the specific function declaration that represents the entry point. + + return entryPointLayout.Ptr(); + } + + return nullptr; +} + String emitEntryPoint( EntryPointRequest* entryPoint, ProgramLayout* programLayout, @@ -3795,6 +3859,13 @@ String emitEntryPoint( sharedContext.finalTarget = entryPoint->compileRequest->Target; sharedContext.entryPoint = entryPoint; + if (entryPoint) + { + sharedContext.entryPointLayout = findEntryPointLayout( + programLayout, + entryPoint); + } + sharedContext.programLayout = programLayout; // Layout information for the global scope is either an ordinary @@ -3916,7 +3987,9 @@ String emitEntryPoint( // along with whatever annotations we added along the way. sharedContext.program = translationUnitSyntax; - visitor.EmitDeclsInContainer(translationUnitSyntax); + visitor.EmitDeclsInContainerUsingLayout( + translationUnitSyntax, + globalStructLayout); } String code = sharedContext.sb.ProduceString(); diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index 2c62a69a81..7a12fc3b70 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -886,8 +886,6 @@ struct LoweringVisitor LoweredExpr visitVarExpr( VarExpr* expr) { - doSampleRateInputCheck(expr->name); - // If the expression didn't get resolved, we can leave it as-is if (!expr->declRef) return expr; @@ -2206,14 +2204,6 @@ struct LoweringVisitor RefPtr loweredStmt = new UnparsedStmt(); lowerStmtFields(loweredStmt, stmt); - for (auto token : stmt->tokens) - { - if (token.type == TokenType::Identifier) - { - doSampleRateInputCheck(token.getName()); - } - } - loweredStmt->tokens = stmt->tokens; addStmt(loweredStmt); @@ -3369,28 +3359,6 @@ struct LoweringVisitor return SourceLanguage::Unknown; } - void setSampleRateFlag() - { - shared->entryPointLayout->flags |= EntryPointLayout::Flag::usesAnySampleRateInput; - } - - void doSampleRateInputCheck(VarDeclBase* decl) - { - if (decl->HasModifier()) - { - setSampleRateFlag(); - } - } - - void doSampleRateInputCheck(Name* name) - { - auto text = getText(name); - if (text == "gl_SampleIndex") - { - setSampleRateFlag(); - } - } - AggTypeDecl* isStructType(RefPtr type) { if (type->As()) return nullptr; @@ -3440,14 +3408,8 @@ struct LoweringVisitor LoweredDecl visitVariable( Variable* decl) { - // Global variable? Check if it is a sample-rate input. if (dynamic_cast(decl->ParentDecl)) { - if (decl->HasModifier()) - { - doSampleRateInputCheck(decl); - } - auto varLayout = tryToFindLayout(decl); if (varLayout) { @@ -3902,7 +3864,6 @@ struct LoweringVisitor } else if (ns == "sv_sampleindex") { - setSampleRateFlag(); globalVarExpr = createGLSLBuiltinRef("gl_SampleID", getIntType()); } else if (ns == "sv_stencilref")