From f0a77916f631badf52f69dee76dc752e836d4391 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Thu, 22 Aug 2024 00:19:02 +0800 Subject: [PATCH] [CIR] Add select operation (#796) This PR adds a new `cir.select` operation. This operation won't be generated directly by CIRGen but it is useful during further CIR to CIR transformations. This PR addresses #785 . --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 18 ++++ clang/include/clang/CIR/Dialect/IR/CIROps.td | 40 +++++++++ clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 13 +++ .../CIR/Dialect/Transforms/CIRSimplify.cpp | 2 +- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 82 +++++++++++++++---- clang/test/CIR/Lowering/select.cir | 50 +++++++++++ clang/test/CIR/Transforms/select.cir | 26 ++++++ 7 files changed, 216 insertions(+), 15 deletions(-) create mode 100644 clang/test/CIR/Lowering/select.cir create mode 100644 clang/test/CIR/Transforms/select.cir diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index d2fc105f502d..a458547d330d 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -274,6 +274,24 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { return createBinop(lhs, mlir::cir::BinOpKind::Mul, val); } + mlir::Value createSelect(mlir::Location loc, mlir::Value condition, + mlir::Value trueValue, mlir::Value falseValue) { + assert(trueValue.getType() == falseValue.getType() && + "trueValue and falseValue should have the same type"); + return create(loc, trueValue.getType(), condition, + trueValue, falseValue); + } + + mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, + mlir::Value rhs) { + return createSelect(loc, lhs, rhs, getBool(false, loc)); + } + + mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, + mlir::Value rhs) { + return createSelect(loc, lhs, getBool(true, loc), rhs); + } + mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag) { auto resultComplexTy = diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 551e12f318f1..c508ec6412b6 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -761,6 +761,46 @@ def TernaryOp : CIR_Op<"ternary", }]; } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +def SelectOp : CIR_Op<"select", [Pure, + AllTypesMatch<["true_value", "false_value", "result"]>]> { + let summary = "Yield one of two values based on a boolean value"; + let description = [{ + The `cir.select` operation takes three operands. The first operand + `condition` is a boolean value of type `!cir.bool`. The second and the third + operand can be of any CIR types, but their types must be the same. If the + first operand is `true`, the operation yields its second operand. Otherwise, + the operation yields its third operand. + + Example: + + ```mlir + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.const #cir.int<42> : !s32i + %2 = cir.const #cir.int<72> : !s32i + %3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i + ``` + }]; + + let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value, + CIR_AnyType:$false_value); + let results = (outs CIR_AnyType:$result); + + let assemblyFormat = [{ + `if` $condition `then` $true_value `else` $false_value + `:` `(` + qualified(type($condition)) `,` + qualified(type($true_value)) `,` + qualified(type($false_value)) + `)` `->` qualified(type($result)) attr-dict + }]; + + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // ConditionOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index bf0ef3274c6a..ff2f2268a2dd 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1382,6 +1382,19 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond, result.addTypes(TypeRange{yield.getOperandTypes().front()}); } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +OpFoldResult SelectOp::fold(FoldAdaptor adaptor) { + auto condition = adaptor.getCondition(); + if (!condition) + return nullptr; + + auto conditionValue = mlir::cast(condition).getValue(); + return conditionValue ? getTrueValue() : getFalseValue(); +} + //===----------------------------------------------------------------------===// // BrOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index e88c2ce9a04e..1eea92026134 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -146,7 +146,7 @@ void CIRSimplifyPass::runOnOperation() { getOperation()->walk([&](Operation *op) { // CastOp here is to perform a manual `fold` in // applyOpPatternsAndFold - if (isa(op)) ops.push_back(op); }); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 361aeebc6758..a27e7c0414c0 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2987,6 +2987,60 @@ class CIRRotateOpLowering } }; +class CIRSelectOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::SelectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto getConstantBool = [](mlir::Value value) -> std::optional { + auto definingOp = mlir::dyn_cast_if_present( + value.getDefiningOp()); + if (!definingOp) + return std::nullopt; + + auto constValue = + mlir::dyn_cast(definingOp.getValue()); + if (!constValue) + return std::nullopt; + + return constValue.getValue(); + }; + + // Two special cases in the LLVMIR codegen of select op: + // - select %0, %1, false => and %0, %1 + // - select %0, true, %1 => or %0, %1 + auto trueValue = op.getTrueValue(); + auto falseValue = op.getFalseValue(); + if (mlir::isa(trueValue.getType())) { + if (std::optional falseValueBool = getConstantBool(falseValue); + falseValueBool.has_value() && !*falseValueBool) { + // select %0, %1, false => and %0, %1 + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), adaptor.getTrueValue()); + return mlir::success(); + } + if (std::optional trueValueBool = getConstantBool(trueValue); + trueValueBool.has_value() && *trueValueBool) { + // select %0, true, %1 => or %0, %1 + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), adaptor.getFalseValue()); + return mlir::success(); + } + } + + auto llvmCondition = rewriter.create( + op.getLoc(), mlir::IntegerType::get(op->getContext(), 1), + adaptor.getCondition()); + rewriter.replaceOpWithNewOp( + op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); + + return mlir::success(); + } +}; + class CIRBrOpLowering : public mlir::OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -3835,20 +3889,20 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, patterns.add(patterns.getContext()); patterns.add(converter, dataLayout, patterns.getContext()); patterns.add< - CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering, - CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering, - CIRBitPopcountOpLowering, CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, - CIRAtomicFetchLowering, CIRByteswapOpLowering, CIRRotateOpLowering, - CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering, - CIRTryCallLowering, CIREhInflightOpLowering, CIRUnaryOpLowering, - CIRBinOpLowering, CIRBinOpOverflowOpLowering, CIRShiftOpLowering, - CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, - CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering, - CIRComplexCreateOpLowering, CIRComplexRealOpLowering, - CIRComplexImagOpLowering, CIRComplexRealPtrOpLowering, - CIRComplexImagPtrOpLowering, CIRVAStartLowering, CIRVAEndLowering, - CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering, - CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering, + CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering, + CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering, + CIRBitParityOpLowering, CIRBitPopcountOpLowering, + CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering, + CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering, + CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering, + CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering, + CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering, + CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering, + CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering, + CIRComplexRealOpLowering, CIRComplexImagOpLowering, + CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering, + CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, + CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering, CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering, CIRVTableAddrPointOpLowering, CIRVectorCreateLowering, diff --git a/clang/test/CIR/Lowering/select.cir b/clang/test/CIR/Lowering/select.cir new file mode 100644 index 000000000000..1836210d6a7c --- /dev/null +++ b/clang/test/CIR/Lowering/select.cir @@ -0,0 +1,50 @@ +// RUN: cir-translate -cir-to-llvmir -o %t.ll %s +// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s + +!s32i = !cir.int + +module { + cir.func @select_int(%arg0 : !cir.bool, %arg1 : !s32i, %arg2 : !s32i) -> !s32i { + %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %0 : !s32i + } + + // LLVM: define i32 @select_int(i8 %[[#COND:]], i32 %[[#TV:]], i32 %[[#FV:]]) + // LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1 + // LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i32 %[[#TV]], i32 %[[#FV]] + // LLVM-NEXT: ret i32 %[[#RES]] + // LLVM-NEXT: } + + cir.func @select_bool(%arg0 : !cir.bool, %arg1 : !cir.bool, %arg2 : !cir.bool) -> !cir.bool { + %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %0 : !cir.bool + } + + // LLVM: define i8 @select_bool(i8 %[[#COND:]], i8 %[[#TV:]], i8 %[[#FV:]]) + // LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1 + // LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i8 %[[#TV]], i8 %[[#FV]] + // LLVM-NEXT: ret i8 %[[#RES]] + // LLVM-NEXT: } + + cir.func @logical_and(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.select if %arg0 then %arg1 else %0 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %1 : !cir.bool + } + + // LLVM: define i8 @logical_and(i8 %[[#ARG0:]], i8 %[[#ARG1:]]) + // LLVM-NEXT: %[[#RES:]] = and i8 %[[#ARG0]], %[[#ARG1]] + // LLVM-NEXT: ret i8 %[[#RES]] + // LLVM-NEXT: } + + cir.func @logical_or(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.select if %arg0 then %0 else %arg1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %1 : !cir.bool + } + + // LLVM: define i8 @logical_or(i8 %[[#ARG0:]], i8 %[[#ARG1:]]) + // LLVM-NEXT: %[[#RES:]] = or i8 %[[#ARG0]], %[[#ARG1]] + // LLVM-NEXT: ret i8 %[[#RES]] + // LLVM-NEXT: } +} diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir new file mode 100644 index 000000000000..c3db14daaf4e --- /dev/null +++ b/clang/test/CIR/Transforms/select.cir @@ -0,0 +1,26 @@ +// RUN: cir-opt --canonicalize -o %t.cir %s +// RUN: FileCheck --input-file=%t.cir %s + +!s32i = !cir.int + +module { + cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: cir.return %[[ARG0]] : !s32i + // CHECK-NEXT: } + + cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: cir.return %[[ARG1]] : !s32i + // CHECK-NEXT: } +}