diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp index 055f97c63b3e..41311c1408e4 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp @@ -24,6 +24,7 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/LowerToMLIR.h" #include "clang/CIR/Passes.h" +#include "llvm/ADT/TypeSwitch.h" using namespace cir; using namespace llvm; @@ -55,6 +56,19 @@ class SCFLoop { int64_t step = 0; }; +class SCFWhileLoop { +public: + SCFWhileLoop(mlir::cir::WhileOp op, mlir::cir::WhileOp::Adaptor adaptor, + mlir::ConversionPatternRewriter *rewriter) + : whileOp(op), adaptor(adaptor), rewriter(rewriter) {} + void transferToSCFWhileOp(); + +private: + mlir::cir::WhileOp whileOp; + mlir::cir::WhileOp::Adaptor adaptor; + mlir::ConversionPatternRewriter *rewriter; +}; + static int64_t getConstant(mlir::cir::ConstantOp op) { auto attr = op->getAttrs().front().getValue(); const auto IntAttr = attr.dyn_cast(); @@ -233,6 +247,20 @@ void SCFLoop::transferToSCFForOp() { }); } +void SCFWhileLoop::transferToSCFWhileOp() { + auto scfWhileOp = rewriter->create( + whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands()); + rewriter->createBlock(&scfWhileOp.getBefore()); + rewriter->createBlock(&scfWhileOp.getAfter()); + + rewriter->cloneRegionBefore(whileOp.getCond(), + &scfWhileOp.getBefore().back()); + rewriter->eraseBlock(&scfWhileOp.getBefore().back()); + + rewriter->cloneRegionBefore(whileOp.getBody(), &scfWhileOp.getAfter().back()); + rewriter->eraseBlock(&scfWhileOp.getAfter().back()); +} + class CIRForOpLowering : public mlir::OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -248,9 +276,46 @@ class CIRForOpLowering : public mlir::OpConversionPattern { } }; +class CIRWhileOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::WhileOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + SCFWhileLoop loop(op, adaptor, &rewriter); + loop.transferToSCFWhileOp(); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class CIRConditionOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + mlir::LogicalResult + matchAndRewrite(mlir::cir::ConditionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto *parentOp = op->getParentOp(); + return llvm::TypeSwitch(parentOp) + .Case([&](auto) { + auto condition = adaptor.getCondition(); + auto i1Condition = rewriter.create( + op.getLoc(), rewriter.getI1Type(), condition); + rewriter.replaceOpWithNewOp( + op, i1Condition, parentOp->getOperands()); + return mlir::success(); + }) + .Default([](auto) { return mlir::failure(); }); + } +}; + void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } -} // namespace cir +} // namespace cir \ No newline at end of file diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index bf86f8c1c7e0..0f9ade9d6c57 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -31,6 +31,10 @@ #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -43,7 +47,9 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/LowerToMLIR.h" #include "clang/CIR/Passes.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" using namespace cir; @@ -558,7 +564,6 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern { return mlir::failure(); rewriter.eraseOp(op); - return mlir::LogicalResult::success(); } }; @@ -883,7 +888,6 @@ class CIRScopeOpLowering if (mlir::failed(getTypeConverter()->convertTypes(scopeOp->getResultTypes(), mlirResultTypes))) return mlir::LogicalResult::failure(); - rewriter.setInsertionPoint(scopeOp); auto newScopeOp = rewriter.create( scopeOp.getLoc(), mlirResultTypes); @@ -956,7 +960,7 @@ class CIRYieldOpLowering mlir::ConversionPatternRewriter &rewriter) const override { auto *parentOp = op->getParentOp(); return llvm::TypeSwitch(parentOp) - .Case([&](auto) { + .Case([&](auto) { rewriter.replaceOpWithNewOp( op, adaptor.getOperands()); return mlir::success(); diff --git a/clang/test/CIR/Lowering/ThroughMLIR/while.c b/clang/test/CIR/Lowering/ThroughMLIR/while.c new file mode 100644 index 000000000000..df459fd2c27a --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/while.c @@ -0,0 +1,35 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +void foo() { + int a = 0; + while(a < 2) { + a++; + } +} + +//CHECK: func.func @foo() { +//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref +//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 +//CHECK: memref.store %[[C0_I32]], %[[alloca]][] : memref +//CHECK: memref.alloca_scope { +//CHECK: scf.while : () -> () { +//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref +//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32 +//CHECK: %[[ONE:.+]] = arith.cmpi ult, %[[ZERO:.+]], %[[C2_I32]] : i32 +//CHECK: %[[TWO:.+]] = arith.extui %[[ONE:.+]] : i1 to i32 +//CHECK: %[[C0_I32_0:.+]] = arith.constant 0 : i32 +//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO:.+]], %[[C0_I32_0]] : i32 +//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE:.+]] : i1 to i8 +//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR:.+]] : i8 to i1 +//CHECK: scf.condition(%[[FIVE]]) +//CHECK: } do { +//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref +//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32 +//CHECK: %[[ONE:.+]] = arith.addi %0, %[[C1_I32:.+]] : i32 +//CHECK: memref.store %[[ONE:.+]], %[[alloca]][] : memref +//CHECK: scf.yield +//CHECK: } +//CHECK: } +//CHECK: return +//CHECK: } \ No newline at end of file