Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for immutable strings. #92

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions example/ExampleDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,13 @@ def ImmutableOp : Op<ExampleDialect, "immutable.op", [WillReturn]> {
Make an argument immutable
}];
}

def StringAttrOp : Op<ExampleDialect, "string.attr.op", [WillReturn]> {
let results = (outs);
let arguments = (ins ImmutableStringAttr:$val);

let summary = "demonstrate an argument that takes in a StringRef";
let description = [{
The argument should not have a setter method
}];
}
6 changes: 6 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ void createFunctionExample(Module &module, const Twine &name) {
moreVarArgs.push_back(b.getInt32(4));
b.create<xd::InstNameConflictVarargsOp>(moreVarArgs, "four.varargs");

b.create<xd::StringAttrOp>("Hello world!");
tsymalla marked this conversation as resolved.
Show resolved Hide resolved

b.CreateRetVoid();
}

Expand Down Expand Up @@ -242,6 +244,10 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
for (Value *arg : op.getArgs())
out << " " << *arg << '\n';
});
b.add<xd::StringAttrOp>(
[](raw_ostream &out, xd::StringAttrOp &op) {
out << "visiting StringAttrOp: " << op.getVal() << '\n';
});
b.add<ReturnInst>([](raw_ostream &out, ReturnInst &ret) {
out << "visiting ReturnInst: " << ret << '\n';
});
Expand Down
9 changes: 9 additions & 0 deletions include/llvm-dialects/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ def : AttrLlvmType<AttrI16, I16>;
def : AttrLlvmType<AttrI32, I32>;
def : AttrLlvmType<AttrI64, I64>;

def ImmutableStringAttr : Attr<"::llvm::StringRef"> {
let toLlvmValue = [{ $_builder.CreateGlobalString($0) }];
let fromLlvmValue = [{ ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>($0)->getInitializer())->getAsString() }];
let isImmutable = true;
}

// Global string variables are essentially pointers in addrspace(0).
def : AttrLlvmType<ImmutableStringAttr, Ptr>;

// ============================================================================
/// More general attributes
// ============================================================================
Expand Down
120 changes: 101 additions & 19 deletions test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ namespace xd {
state.setError();
});

builder.add<StringAttrOp>([](::llvm_dialects::VerifierState &state, StringAttrOp &op) {
if (!op.verifier(state.out()))
state.setError();
});

builder.add<WriteOp>([](::llvm_dialects::VerifierState &state, WriteOp &op) {
if (!op.verifier(state.out()))
state.setError();
Expand All @@ -154,21 +159,21 @@ namespace xd {
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
Expand Down Expand Up @@ -329,7 +334,7 @@ return true;


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), {
lhs->getType(),
rhs->getType(),
Expand Down Expand Up @@ -451,7 +456,7 @@ uint32_t const extra = getExtra();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {lhs->getType()});
Expand Down Expand Up @@ -546,7 +551,7 @@ rhs


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {::llvm::cast<XdVectorType>(vector->getType())->getElementType()});
Expand Down Expand Up @@ -650,7 +655,7 @@ index


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -820,7 +825,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), {
}, false);

Expand Down Expand Up @@ -882,7 +887,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -980,7 +985,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -1113,7 +1118,7 @@ source
(void)context;

using ::llvm_dialects::printable;

if (arg_size() != 1) {
errs << " wrong number of arguments: " << arg_size()
<< ", expected 1\n";
Expand Down Expand Up @@ -1147,7 +1152,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {vector->getType()});
Expand Down Expand Up @@ -1607,7 +1612,7 @@ instName


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1670,7 +1675,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1744,7 +1749,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1836,7 +1841,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1928,7 +1933,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -2011,6 +2016,75 @@ initial



const ::llvm::StringLiteral StringAttrOp::s_name{"xd.string.attr.op"};

StringAttrOp* StringAttrOp::create(llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName) {
::llvm::LLVMContext& context = b.getContext();
(void)context;
::llvm::Module& module = *b.GetInsertBlock()->getModule();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(4);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), {
::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0),
}, false);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
::llvm::SmallString<32> newName;
for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) ||
::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) {
// If a function with the same name but a different types already exists,
// we get a bitcast of a function or a function with the wrong type.
// Try new names until we get one with the correct type.
newName = "";
::llvm::raw_svector_ostream newNameStream(newName);
newNameStream << s_name << "_" << i;
fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs);
}
assert(::llvm::isa<::llvm::Function>(fn.getCallee()));
assert(fn.getFunctionType() == fnType);
assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType());


::llvm::SmallVector<::llvm::Value*, 1> args = {
b.CreateGlobalString(val)
};

return ::llvm::cast<StringAttrOp>(b.CreateCall(fn, args, instName));
}


bool StringAttrOp::verifier(::llvm::raw_ostream &errs) {
::llvm::LLVMContext &context = getModule()->getContext();
(void)context;

using ::llvm_dialects::printable;

if (arg_size() != 1) {
errs << " wrong number of arguments: " << arg_size()
<< ", expected 1\n";
return false;
}

if (getArgOperand(0)->getType() != ::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0)) {
errs << " argument 0 (val) has type: "
<< *getArgOperand(0)->getType() << '\n';
errs << " expected: " << *::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0) << '\n';
return false;
}
::llvm::StringRef const val = getVal();
(void)val;
return true;
}


::llvm::StringRef StringAttrOp::getVal() {
return ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>(getArgOperand(0))->getInitializer())->getAsString() ;
}



const ::llvm::StringLiteral WriteOp::s_name{"xd.write"};

WriteOp* WriteOp::create(llvm_dialects::Builder& b, ::llvm::Value * data, const llvm::Twine &instName) {
Expand All @@ -2020,7 +2094,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2083,7 +2157,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2297,6 +2371,14 @@ data
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::StringAttrOp>() {
static const ::llvm_dialects::OpDescription desc{false, "xd.string.attr.op"};
return desc;
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::WriteOp>() {
Expand Down
20 changes: 20 additions & 0 deletions test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,26 @@ bool verifier(::llvm::raw_ostream &errs);
::llvm::Value * getResult();


};

class StringAttrOp : public ::llvm::CallInst {
static const ::llvm::StringLiteral s_name; //{"xd.string.attr.op"};

public:
static bool classof(const ::llvm::CallInst* i) {
return ::llvm_dialects::detail::isSimpleOperation(i, s_name);
}
static bool classof(const ::llvm::Value* v) {
return ::llvm::isa<::llvm::CallInst>(v) &&
classof(::llvm::cast<::llvm::CallInst>(v));
}
static StringAttrOp* create(::llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName = "");

bool verifier(::llvm::raw_ostream &errs);

::llvm::StringRef getVal();


};

class WriteOp : public ::llvm::CallInst {
Expand Down
6 changes: 5 additions & 1 deletion test/example/test-builder.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals
; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy.
; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s

;.
; CHECK: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
;.
; CHECK-LABEL: @example(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.read__i32()
Expand Down Expand Up @@ -42,5 +45,6 @@
; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3)
; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4)
; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0:[0-9]+]])
; CHECK-NEXT: ret void
;
6 changes: 5 additions & 1 deletion test/example/visitor-basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
; DEFAULT-NEXT: %v2 =
; DEFAULT-NEXT: %q =
; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q)
; DEFAULT-NEXT: visiting StringAttrOp: Hello world!
; DEFAULT-NEXT: visiting Ret (set): ret void
; DEFAULT-NEXT: visiting ReturnInst: ret void
; DEFAULT-NEXT: inner.counter = 1

@0 = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1

define void @test1(ptr %p) {
entry:
%v = call i32 @xd.read__i32()
Expand All @@ -36,6 +39,7 @@ entry:
call void (...) @xd.set.write(i8 %v.2)
call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q)
%vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q)
call void @xd.string.attr.op(ptr @0)
ret void
}

Expand All @@ -46,6 +50,6 @@ declare void @xd.write(...)
declare void @xd.set.write(...)
declare void @xd.write.vararg(...)
declare i8 @xd.itrunc__i8(...)

declare void @xd.string.attr.op(ptr)
declare i32 @llvm.umax.i32(i32, i32)
declare i32 @llvm.umin.i32(i32, i32)
Loading