From 7377ed43d22cd99e7560c47a9b100a9fc8b48008 Mon Sep 17 00:00:00 2001 From: Thomas Symalla Date: Tue, 21 May 2024 11:03:13 +0200 Subject: [PATCH] Add support for immutable attributes. Sometimes we don't want or can't update operands with a auto-generated setter method, so this change adds support to make attributes immutable. --- example/ExampleDialect.td | 16 +++ include/llvm-dialects/Dialect/Dialect.td | 3 + include/llvm-dialects/TableGen/Constraints.h | 3 + lib/TableGen/Constraints.cpp | 8 ++ lib/TableGen/Operations.cpp | 5 +- test/example/generated/ExampleDialect.cpp.inc | 118 +++++++++++++++--- test/example/generated/ExampleDialect.h.inc | 20 +++ 7 files changed, 154 insertions(+), 19 deletions(-) diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 58eb196..616a73a 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -36,6 +36,12 @@ def VectorKindLittleEndian : CppConstant<"xd::VectorKind::LittleEndian">; def VectorKindBigEndian : CppConstant<"xd::VectorKind::BigEndian">; def VectorKindMiddleEndian : CppConstant<"xd::VectorKind::MiddleEndian">; +def ImmutableAttrI1 : IntegerAttr<"bool"> { + let isImmutable = true; +} + +def : AttrLlvmType; + def isReasonableVectorKind : TgPredicate< (args AttrVectorKind:$kind), (eq $kind, (or VectorKindLittleEndian, VectorKindBigEndian))>; @@ -301,3 +307,13 @@ def InstNameConflictVarargsOp : Op { + let results = (outs); + let arguments = (ins ImmutableAttrI1:$val); + + let summary = "demonstrate how an argument will not get a setter method"; + let description = [{ + Make an argument immutable + }]; +} diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index b425ed4..7e3c949 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -69,6 +69,9 @@ class Attr : MetaType { // A check statement that is issued before using the C++ value in builders. // $0 is the C++ value. string check = ""; + + // Overriding prevents generating a setter method. Attributes are mutable by default. + bit isImmutable = false; } class IntegerAttr : Attr { diff --git a/include/llvm-dialects/TableGen/Constraints.h b/include/llvm-dialects/TableGen/Constraints.h index 0f8d70d..77cd364 100644 --- a/include/llvm-dialects/TableGen/Constraints.h +++ b/include/llvm-dialects/TableGen/Constraints.h @@ -203,6 +203,7 @@ class MetaType { bool isTypeArg() const { return m_kind == Kind::Type; } bool isValueArg() const { return m_kind == Kind::Value; } bool isVarArgList() const { return m_kind == Kind::VarArgList; } + bool isImmutable() const; protected: MetaType(Kind kind) : m_kind(kind) {} @@ -231,6 +232,7 @@ class Attr : public MetaType { llvm::StringRef getToUnsigned() const { return m_toUnsigned; } llvm::StringRef getFromUnsigned() const { return m_fromUnsigned; } llvm::StringRef getCheck() const { return m_check; } + bool getIsImmutable() const { return m_isImmutable; } // Set the LLVMType once -- used during initialization to break a circular // dependency in how IntegerType is defined. @@ -249,6 +251,7 @@ class Attr : public MetaType { std::string m_toUnsigned; std::string m_fromUnsigned; std::string m_check; + bool m_isImmutable; }; } // namespace llvm_dialects diff --git a/lib/TableGen/Constraints.cpp b/lib/TableGen/Constraints.cpp index aab3649..81333c6 100644 --- a/lib/TableGen/Constraints.cpp +++ b/lib/TableGen/Constraints.cpp @@ -361,6 +361,13 @@ StringRef MetaType::getBuilderCppType() const { return getCppType(); } +bool MetaType::isImmutable() const { + if (auto *attr = dyn_cast(this)) + return attr->getIsImmutable(); + + return false; +} + /// Return the C++ expression @p value transformed to be suitable for printing /// using LLVM's raw_ostream. std::string MetaType::printable(const MetaType *type, llvm::StringRef value) { @@ -394,6 +401,7 @@ std::unique_ptr Attr::parse(raw_ostream &errs, attr->m_toUnsigned = record->getValueAsString("toUnsigned"); attr->m_fromUnsigned = record->getValueAsString("fromUnsigned"); attr->m_check = record->getValueAsString("check"); + attr->m_isImmutable = record->getValueAsBit("isImmutable"); return attr; } diff --git a/lib/TableGen/Operations.cpp b/lib/TableGen/Operations.cpp index 0ee493a..a7d09af 100644 --- a/lib/TableGen/Operations.cpp +++ b/lib/TableGen/Operations.cpp @@ -160,7 +160,7 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out, FmtContext &fmt) const { for (const auto &arg : m_arguments) { std::string defaultDeclaration = "$0 get$1();"; - if (!arg.type->isVarArgList()) { + if (!arg.type->isVarArgList() && !arg.type->isImmutable()) { defaultDeclaration += R"( void set$1($0 $2); )"; @@ -205,6 +205,9 @@ void AccessorBuilder::emitGetterDefinition() const { } void AccessorBuilder::emitSetterDefinition() const { + if (m_arg.type->isImmutable()) + return; + std::string toLlvm = m_arg.name; if (auto *attr = dyn_cast(m_arg.type)) { diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 251608d..138817b 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -69,6 +69,11 @@ namespace xd { state.setError(); }); + builder.add([](::llvm_dialects::VerifierState &state, ImmutableOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + builder.add([](::llvm_dialects::VerifierState &state, InsertElementOp &op) { if (!op.verifier(state.out())) state.setError(); @@ -149,21 +154,21 @@ namespace xd { ::llvm::AttrBuilder attrBuilder{context}; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod)); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod)); m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } { @@ -324,7 +329,7 @@ return true; const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), { lhs->getType(), rhs->getType(), @@ -446,7 +451,7 @@ uint32_t const extra = getExtra(); const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {lhs->getType()}); @@ -541,7 +546,7 @@ rhs const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {::llvm::cast(vector->getType())->getElementType()}); @@ -645,7 +650,7 @@ index const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -815,7 +820,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), { }, false); @@ -877,7 +882,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -975,7 +980,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -1064,6 +1069,75 @@ source + const ::llvm::StringLiteral ImmutableOp::s_name{"xd.immutable.op"}; + + ImmutableOp* ImmutableOp::create(llvm_dialects::Builder& b, bool val, const llvm::Twine &instName) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(4); + auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), { +::llvm::IntegerType::get(context, 1), +}, false); + + auto fn = module.getOrInsertFunction(s_name, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << s_name << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + +::llvm::SmallVector<::llvm::Value*, 1> args = { + ::llvm::ConstantInt::get(::llvm::IntegerType::get(context, 1), val) + }; + + return ::llvm::cast(b.CreateCall(fn, args, instName)); + } + + + bool ImmutableOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 1) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 1\n"; + return false; + } + + if (getArgOperand(0)->getType() != ::llvm::IntegerType::get(context, 1)) { + errs << " argument 0 (val) has type: " + << *getArgOperand(0)->getType() << '\n'; + errs << " expected: " << *::llvm::IntegerType::get(context, 1) << '\n'; + return false; + } + bool const val = getVal(); +(void)val; + return true; +} + + + bool ImmutableOp::getVal() { + return ::llvm::cast<::llvm::ConstantInt>(getArgOperand(0))->getZExtValue() ; + } + + + const ::llvm::StringLiteral InsertElementOp::s_name{"xd.insertelement"}; InsertElementOp* InsertElementOp::create(llvm_dialects::Builder& b, ::llvm::Value * vector, ::llvm::Value * value, ::llvm::Value * index, const llvm::Twine &instName) { @@ -1073,7 +1147,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {vector->getType()}); @@ -1539,7 +1613,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1602,7 +1676,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1676,7 +1750,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -1768,7 +1842,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -1860,7 +1934,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -1952,7 +2026,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -2015,7 +2089,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -2133,6 +2207,14 @@ data } + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{false, "xd.immutable.op"}; + return desc; + } + + template <> const ::llvm_dialects::OpDescription & ::llvm_dialects::OpDescription::get() { diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 4f7f5db..0b29e7d 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -269,6 +269,26 @@ bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getResult(); + }; + + class ImmutableOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.immutable.op"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isSimpleOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static ImmutableOp* create(::llvm_dialects::Builder& b, bool val, const llvm::Twine &instName = ""); + +bool verifier(::llvm::raw_ostream &errs); + +bool getVal(); + + }; class InsertElementOp : public ::llvm::CallInst {