Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CodeGen] Goto pass #562

Merged
merged 13 commits into from
May 8, 2024
Merged
42 changes: 42 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3751,6 +3751,48 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments, Terminator]>
];
}

//===----------------------------------------------------------------------===//
// GotoOp
//===----------------------------------------------------------------------===//

def GotoOp : CIR_Op<"goto", [Terminator]> {
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
let description = [{ Transfers control to the specified label.

Example:
```C++
void foo() {
goto exit;

exit:
return;
}
```

```mlir
cir.func @foo() {
cir.goto "exit"
^bb1:
cir.label "exit"
cir.return
}
```
}];
let arguments = (ins StrAttr:$label);
let assemblyFormat = [{ $label attr-dict }];
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
}

//===----------------------------------------------------------------------===//
// LabelOp
//===----------------------------------------------------------------------===//

// The LabelOp has AlwaysSpeculatable trait in order to not to be swept by canonicalizer
def LabelOp : CIR_Op<"label", [AlwaysSpeculatable]> {
let description = [{ An identifier which may be referred by cir.goto operation }];
let arguments = (ins StrAttr:$label);
let assemblyFormat = [{ $label attr-dict }];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Atomic operations
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ std::unique_ptr<Pass> createIdiomRecognizerPass(clang::ASTContext *astCtx);
std::unique_ptr<Pass> createLibOptPass();
std::unique_ptr<Pass> createLibOptPass(clang::ASTContext *astCtx);
std::unique_ptr<Pass> createFlattenCFGPass();
std::unique_ptr<Pass> createGotoSolverPass();

void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);

Expand Down
10 changes: 10 additions & 0 deletions clang/include/clang/CIR/Dialect/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ def FlattenCFG : Pass<"cir-flatten-cfg"> {
let dependentDialects = ["cir::CIRDialect"];
}

def GotoSolver : Pass<"cir-goto-solver"> {
let summary = "Replaces goto operatations with branches";
let description = [{
This pass transforms CIR and replaces goto-s with branch
operations to the proper blocks.
}];
let constructor = "mlir::createGotoSolverPass()";
let dependentDialects = ["cir::CIRDialect"];
}

def IdiomRecognizer : Pass<"cir-idiom-recognizer"> {
let summary = "Raise calls to C/C++ libraries to CIR operations";
let description = [{
Expand Down
16 changes: 0 additions & 16 deletions clang/lib/CIR/CodeGen/CIRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,6 @@ void CIRGenFunction::LexicalScope::cleanup() {
auto &builder = CGF.builder;
auto *localScope = CGF.currLexScope;

// Handle pending gotos and the solved labels in this scope.
while (!localScope->PendingGotos.empty()) {
auto gotoInfo = localScope->PendingGotos.back();
// FIXME: Currently only support resolving goto labels inside the
// same lexical ecope.
assert(localScope->SolvedLabels.count(gotoInfo.second) &&
"goto across scopes not yet supported");

// The goto in this lexical context actually maps to a basic
// block.
auto g = cast<mlir::cir::BrOp>(gotoInfo.first);
g.setSuccessor(CGF.LabelMap[gotoInfo.second].getBlock());
localScope->PendingGotos.pop_back();
}
localScope->SolvedLabels.clear();

auto applyCleanup = [&]() {
if (PerformCleanup) {
// ApplyDebugLocation
Expand Down
7 changes: 0 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1962,13 +1962,6 @@ class CIRGenFunction : public CIRGenTypeCache {
return CleanupBlock;
}

// Goto's introduced in this scope but didn't get fixed.
llvm::SmallVector<std::pair<mlir::Operation *, const clang::LabelDecl *>, 4>
PendingGotos;

// Labels solved inside this scope.
llvm::SmallPtrSet<const clang::LabelDecl *, 4> SolvedLabels;

// ---
// Exception handling
// ---
Expand Down
30 changes: 13 additions & 17 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,14 +551,19 @@ mlir::LogicalResult CIRGenFunction::buildGotoStmt(const GotoStmt &S) {
// info support just yet, look at this again once we have it.
assert(builder.getInsertionBlock() && "not yet implemented");

mlir::Block *currBlock = builder.getBlock();
mlir::Block *gotoBlock = currBlock;
if (!currBlock->empty() &&
currBlock->back().hasTrait<mlir::OpTrait::IsTerminator>()) {
gotoBlock = builder.createBlock(builder.getBlock()->getParent());
builder.setInsertionPointToEnd(gotoBlock);
}

// A goto marks the end of a block, create a new one for codegen after
// buildGotoStmt can resume building in that block.

// Build a cir.br to the target label.
auto &JD = LabelMap[S.getLabel()];
auto brOp = buildBranchThroughCleanup(getLoc(S.getSourceRange()), JD);
if (!JD.isValid())
currLexScope->PendingGotos.push_back(std::make_pair(brOp, S.getLabel()));
builder.create<mlir::cir::GotoOp>(getLoc(S.getSourceRange()),
S.getLabel()->getName());
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved

// Insert the new block to continue codegen after goto.
builder.createBlock(builder.getBlock()->getParent());
Expand All @@ -568,31 +573,22 @@ mlir::LogicalResult CIRGenFunction::buildGotoStmt(const GotoStmt &S) {
}

mlir::LogicalResult CIRGenFunction::buildLabel(const LabelDecl *D) {
JumpDest &Dest = LabelMap[D];

// Create a new block to tag with a label and add a branch from
// the current one to it. If the block is empty just call attach it
// to this label.
mlir::Block *currBlock = builder.getBlock();
mlir::Block *labelBlock = currBlock;
if (!currBlock->empty()) {

{
mlir::OpBuilder::InsertionGuard guard(builder);
labelBlock = builder.createBlock(builder.getBlock()->getParent());
}

builder.create<BrOp>(getLoc(D->getSourceRange()), labelBlock);
builder.setInsertionPointToEnd(labelBlock);
}

if (!Dest.isValid()) {
Dest.Block = labelBlock;
currLexScope->SolvedLabels.insert(D);
// FIXME: add a label attribute to block...
} else {
assert(0 && "unimplemented");
}
builder.setInsertionPointToEnd(labelBlock);
builder.create<mlir::cir::LabelOp>(getLoc(D->getSourceRange()), D->getName());
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
builder.setInsertionPointToEnd(labelBlock);

// FIXME: emit debug info for labels, incrementProfileCounter
return mlir::success();
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ namespace mlir {

void populateCIRPreLoweringPasses(OpPassManager &pm) {
pm.addPass(createFlattenCFGPass());
// add other passes here
pm.addPass(createGotoSolverPass());
}

} // namespace mlir
31 changes: 31 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llvm/Support/ErrorHandling.h"
#include <numeric>
#include <optional>
#include <set>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
Expand Down Expand Up @@ -2174,6 +2175,24 @@ LogicalResult cir::FuncOp::verify() {
<< "' must have empty body";
}

std::set<llvm::StringRef> labels;
std::set<llvm::StringRef> gotos;

getOperation()->walk([&](mlir::Operation *op) {
if (auto lab = dyn_cast<mlir::cir::LabelOp>(op)) {
labels.emplace(lab.getLabel());
} else if (auto goTo = dyn_cast<mlir::cir::GotoOp>(op)) {
gotos.emplace(goTo.getLabel());
}
});

std::vector<llvm::StringRef> mismatched;
std::set_difference(gotos.begin(), gotos.end(), labels.begin(), labels.end(),
std::back_inserter(mismatched));

if (!mismatched.empty())
return emitOpError() << "goto/label mismatch";

return success();
}

Expand Down Expand Up @@ -3083,6 +3102,18 @@ LogicalResult BinOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// LabelOp Definitions
//===----------------------------------------------------------------------===//

LogicalResult LabelOp::verify() {
auto *op = getOperation();
auto *blk = op->getBlock();
if (&blk->front() != op)
return emitError() << "must be the first operation in a block";
return mlir::success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_clang_library(MLIRCIRTransforms
LibOpt.cpp
StdHelpers.cpp
FlattenCFG.cpp
GotoSolver.cpp

DEPENDS
MLIRCIRPassIncGen
Expand Down
8 changes: 5 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,11 @@ class CIRLoopOpInterfaceFlattening
});

// Lower optional body region yield.
auto bodyYield = dyn_cast<mlir::cir::YieldOp>(body->getTerminator());
if (bodyYield)
lowerTerminator(bodyYield, (step ? step : cond), rewriter);
for (auto &blk : op.getBody().getBlocks()) {
auto bodyYield = dyn_cast<mlir::cir::YieldOp>(blk.getTerminator());
if (bodyYield)
lowerTerminator(bodyYield, (step ? step : cond), rewriter);
}

// Lower mandatory step region yield.
if (step)
Expand Down
54 changes: 54 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

using namespace mlir;
using namespace mlir::cir;

namespace {

struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {

GotoSolverPass() = default;
void runOnOperation() override;
};

static void process(mlir::cir::FuncOp func) {

mlir::OpBuilder rewriter(func.getContext());
std::map<std::string, Block *> labels;
std::vector<mlir::cir::GotoOp> gotos;

func.getBody().walk([&](mlir::Operation *op) {
if (auto lab = dyn_cast<mlir::cir::LabelOp>(op)) {
labels.emplace(lab.getLabel().str(), lab->getBlock());
lab.erase();
} else if (auto goTo = dyn_cast<mlir::cir::GotoOp>(op)) {
gotos.push_back(goTo);
}
});

for (auto goTo : gotos) {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(goTo);
auto dest = labels[goTo.getLabel().str()];
rewriter.create<mlir::cir::BrOp>(goTo.getLoc(), dest);
goTo.erase();
}
}

void GotoSolverPass::runOnOperation() {
SmallVector<Operation *, 16> ops;
getOperation()->walk([&](mlir::cir::FuncOp op) { process(op); });
}

} // namespace

std::unique_ptr<Pass> mlir::createGotoSolverPass() {
return std::make_unique<GotoSolverPass>();
}
3 changes: 3 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ struct RemoveRedudantBranches : public OpRewritePattern<BrOp> {
Block *block = op.getOperation()->getBlock();
Block *dest = op.getDest();

if (isa<mlir::cir::LabelOp>(dest->front()))
return failure();

// Single edge between blocks: merge it.
if (block->getNumSuccessors() == 1 &&
dest->getSinglePredecessor() == block) {
Expand Down
Loading
Loading