Skip to content

Commit

Permalink
[CIR][ThroughMLIR] Support lowering cir.condition and cir.while to sc…
Browse files Browse the repository at this point in the history
…f.condition, scf.while (llvm#636)

This pr intruduces CIRConditionLowering and CIRWhileLowering for
lowering to scf.
  • Loading branch information
GaoXiangYa authored and lanza committed Jun 20, 2024
1 parent 45b8509 commit 3d83a59
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 5 deletions.
69 changes: 67 additions & 2 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToMLIR.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace cir;
using namespace llvm;
Expand Down Expand Up @@ -55,6 +56,19 @@ class SCFLoop {
int64_t step = 0;
};

class SCFWhileLoop {
public:
SCFWhileLoop(mlir::cir::WhileOp op, mlir::cir::WhileOp::Adaptor adaptor,
mlir::ConversionPatternRewriter *rewriter)
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
void transferToSCFWhileOp();

private:
mlir::cir::WhileOp whileOp;
mlir::cir::WhileOp::Adaptor adaptor;
mlir::ConversionPatternRewriter *rewriter;
};

static int64_t getConstant(mlir::cir::ConstantOp op) {
auto attr = op->getAttrs().front().getValue();
const auto IntAttr = attr.dyn_cast<mlir::cir::IntAttr>();
Expand Down Expand Up @@ -233,6 +247,20 @@ void SCFLoop::transferToSCFForOp() {
});
}

void SCFWhileLoop::transferToSCFWhileOp() {
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
rewriter->createBlock(&scfWhileOp.getBefore());
rewriter->createBlock(&scfWhileOp.getAfter());

rewriter->cloneRegionBefore(whileOp.getCond(),
&scfWhileOp.getBefore().back());
rewriter->eraseBlock(&scfWhileOp.getBefore().back());

rewriter->cloneRegionBefore(whileOp.getBody(), &scfWhileOp.getAfter().back());
rewriter->eraseBlock(&scfWhileOp.getAfter().back());
}

class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
public:
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
Expand All @@ -248,9 +276,46 @@ class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
}
};

class CIRWhileOpLowering
: public mlir::OpConversionPattern<mlir::cir::WhileOp> {
public:
using OpConversionPattern<mlir::cir::WhileOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::WhileOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
SCFWhileLoop loop(op, adaptor, &rewriter);
loop.transferToSCFWhileOp();
rewriter.eraseOp(op);
return mlir::success();
}
};

class CIRConditionOpLowering
: public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
public:
using OpConversionPattern<mlir::cir::ConditionOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::ConditionOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::WhileOp>([&](auto) {
auto condition = adaptor.getCondition();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
op.getLoc(), rewriter.getI1Type(), condition);
rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
op, i1Condition, parentOp->getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
}
};

void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRForOpLowering>(converter, patterns.getContext());
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering>(
converter, patterns.getContext());
}

} // namespace cir
} // namespace cir
10 changes: 7 additions & 3 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -43,7 +47,9 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToMLIR.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace cir;
Expand Down Expand Up @@ -558,7 +564,6 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
return mlir::failure();

rewriter.eraseOp(op);

return mlir::LogicalResult::success();
}
};
Expand Down Expand Up @@ -883,7 +888,6 @@ class CIRScopeOpLowering
if (mlir::failed(getTypeConverter()->convertTypes(scopeOp->getResultTypes(),
mlirResultTypes)))
return mlir::LogicalResult::failure();

rewriter.setInsertionPoint(scopeOp);
auto newScopeOp = rewriter.create<mlir::memref::AllocaScopeOp>(
scopeOp.getLoc(), mlirResultTypes);
Expand Down Expand Up @@ -956,7 +960,7 @@ class CIRYieldOpLowering
mlir::ConversionPatternRewriter &rewriter) const override {
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::IfOp, mlir::scf::ForOp>([&](auto) {
.Case<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp>([&](auto) {
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
op, adaptor.getOperands());
return mlir::success();
Expand Down
35 changes: 35 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/while.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

void foo() {
int a = 0;
while(a < 2) {
a++;
}
}

//CHECK: func.func @foo() {
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
//CHECK: memref.store %[[C0_I32]], %[[alloca]][] : memref<i32>
//CHECK: memref.alloca_scope {
//CHECK: scf.while : () -> () {
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
//CHECK: %[[ONE:.+]] = arith.cmpi ult, %[[ZERO:.+]], %[[C2_I32]] : i32
//CHECK: %[[TWO:.+]] = arith.extui %[[ONE:.+]] : i1 to i32
//CHECK: %[[C0_I32_0:.+]] = arith.constant 0 : i32
//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO:.+]], %[[C0_I32_0]] : i32
//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE:.+]] : i1 to i8
//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR:.+]] : i8 to i1
//CHECK: scf.condition(%[[FIVE]])
//CHECK: } do {
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
//CHECK: %[[ONE:.+]] = arith.addi %0, %[[C1_I32:.+]] : i32
//CHECK: memref.store %[[ONE:.+]], %[[alloca]][] : memref<i32>
//CHECK: scf.yield
//CHECK: }
//CHECK: }
//CHECK: return
//CHECK: }

0 comments on commit 3d83a59

Please sign in to comment.