Skip to content

Commit

Permalink
[CIR] Add select operation
Browse files Browse the repository at this point in the history
This patch adds a new `cir.select` operation. This operation won't be generated
directly by CIRGen but it is useful during further CIR to CIR transformations.

This patch addresses #785 .
  • Loading branch information
Lancern committed Aug 18, 2024
1 parent e1fe8e1 commit 3582452
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 15 deletions.
18 changes: 18 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,24 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createBinop(lhs, mlir::cir::BinOpKind::Mul, val);
}

mlir::Value createSelect(mlir::Location loc, mlir::Value condition,
mlir::Value trueValue, mlir::Value falseValue) {
assert(trueValue.getType() == falseValue.getType() &&
"trueValue and falseValue should have the same type");
return create<mlir::cir::SelectOp>(loc, trueValue.getType(), condition,
trueValue, falseValue);
}

mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs,
mlir::Value rhs) {
return createSelect(loc, lhs, rhs, getBool(false, loc));
}

mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs,
mlir::Value rhs) {
return createSelect(loc, lhs, getBool(true, loc), rhs);
}

mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
mlir::Value imag) {
auto resultComplexTy =
Expand Down
40 changes: 40 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,46 @@ def TernaryOp : CIR_Op<"ternary",
}];
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

def SelectOp : CIR_Op<"select", [Pure]> {
let summary = "Yield one of two values based on a boolean value";
let description = [{
The `cir.select` operation takes three operands. The first operand
`condition` is a boolean value of type `!cir.bool`. The second and the third
operand can be of any CIR types, but their types must be the same. If the
first operand is `true`, the operation yields its second operand. Otherwise,
the operation yields its third operand.

Example:

```mlir
%0 = cir.const #cir.bool<true> : !cir.bool
%1 = cir.const #cir.int<42> : !s32i
%2 = cir.const #cir.int<72> : !s32i
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
```
}];

let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
CIR_AnyType:$false_value);
let results = (outs CIR_AnyType:$result);

let assemblyFormat = [{
`if` $condition `then` $true_value `else` $false_value
`:` `(`
qualified(type($condition)) `,`
qualified(type($true_value)) `,`
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,35 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

LogicalResult SelectOp::verify() {
auto trueValueTy = getTrueValue().getType();
auto falseValueTy = getFalseValue().getType();
auto resultTy = getType();

// true_value, false_value, and the result must have the same type.
if (trueValueTy != falseValueTy)
return emitOpError()
<< "true_value and false_value must have the same type";
if (trueValueTy != resultTy)
return emitOpError()
<< "true_value, false_value, and result must have the same type";

return mlir::success();
}

OpFoldResult SelectOp::fold(FoldAdaptor adaptor) {
auto condition = adaptor.getCondition();
if (!condition)
return nullptr;

auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
return conditionValue ? getTrueValue() : getFalseValue();
}

//===----------------------------------------------------------------------===//
// BrOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void CIRSimplifyPass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
// CastOp here is to perform a manual `fold` in
// applyOpPatternsAndFold
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
ops.push_back(op);
});
Expand Down
82 changes: 68 additions & 14 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2983,6 +2983,60 @@ class CIRRotateOpLowering
}
};

class CIRSelectOpLowering
: public mlir::OpConversionPattern<mlir::cir::SelectOp> {
public:
using OpConversionPattern<mlir::cir::SelectOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
auto definingOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
value.getDefiningOp());
if (!definingOp)
return std::nullopt;

auto constValue =
mlir::dyn_cast<mlir::cir::BoolAttr>(definingOp.getValue());
if (!constValue)
return std::nullopt;

return constValue.getValue();
};

// Two special cases in the LLVMIR codegen of select op:
// - select %0, %1, false => and %0, %1
// - select %0, true, %1 => or %0, %1
auto trueValue = op.getTrueValue();
auto falseValue = op.getFalseValue();
if (mlir::isa<mlir::cir::BoolType>(trueValue.getType())) {
if (std::optional<bool> falseValueBool = getConstantBool(falseValue);
falseValueBool.has_value() && !*falseValueBool) {
// select %0, %1, false => and %0, %1
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(
op, adaptor.getCondition(), adaptor.getTrueValue());
return mlir::success();
}
if (std::optional<bool> trueValueBool = getConstantBool(trueValue);
trueValueBool.has_value() && *trueValueBool) {
// select %0, true, %1 => or %0, %1
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(
op, adaptor.getCondition(), adaptor.getFalseValue());
return mlir::success();
}
}

auto llvmCondition = rewriter.create<mlir::LLVM::TruncOp>(
op.getLoc(), mlir::IntegerType::get(op->getContext(), 1),
adaptor.getCondition());
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());

return mlir::success();
}
};

class CIRBrOpLowering : public mlir::OpConversionPattern<mlir::cir::BrOp> {
public:
using OpConversionPattern<mlir::cir::BrOp>::OpConversionPattern;
Expand Down Expand Up @@ -3831,20 +3885,20 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
patterns.add<CIRReturnLowering>(patterns.getContext());
patterns.add<CIRAllocaLowering>(converter, dataLayout, patterns.getContext());
patterns.add<
CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering,
CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
CIRBitPopcountOpLowering, CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering,
CIRAtomicFetchLowering, CIRByteswapOpLowering, CIRRotateOpLowering,
CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering,
CIRTryCallLowering, CIREhInflightOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRBinOpOverflowOpLowering, CIRShiftOpLowering,
CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRFuncLowering,
CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
CIRComplexCreateOpLowering, CIRComplexRealOpLowering,
CIRComplexImagOpLowering, CIRComplexRealPtrOpLowering,
CIRComplexImagPtrOpLowering, CIRVAStartLowering, CIRVAEndLowering,
CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering,
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering,
CIRBitParityOpLowering, CIRBitPopcountOpLowering,
CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering,
CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering,
CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering,
CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering,
CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
CIRComplexRealOpLowering, CIRComplexImagOpLowering,
CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering,
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
Expand Down
50 changes: 50 additions & 0 deletions clang/test/CIR/Lowering/select.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: cir-translate -cir-to-llvmir -o %t.ll %s
// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s

!s32i = !cir.int<s, 32>

module {
cir.func @select_int(%arg0 : !cir.bool, %arg1 : !s32i, %arg2 : !s32i) -> !s32i {
%0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !s32i, !s32i) -> !s32i
cir.return %0 : !s32i
}

// LLVM: define i32 @select_int(i8 %[[#COND:]], i32 %[[#TV:]], i32 %[[#FV:]])
// LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1
// LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i32 %[[#TV]], i32 %[[#FV]]
// LLVM-NEXT: ret i32 %[[#RES]]
// LLVM-NEXT: }

cir.func @select_bool(%arg0 : !cir.bool, %arg1 : !cir.bool, %arg2 : !cir.bool) -> !cir.bool {
%0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %0 : !cir.bool
}

// LLVM: define i8 @select_bool(i8 %[[#COND:]], i8 %[[#TV:]], i8 %[[#FV:]])
// LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1
// LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i8 %[[#TV]], i8 %[[#FV]]
// LLVM-NEXT: ret i8 %[[#RES]]
// LLVM-NEXT: }

cir.func @logical_and(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool {
%0 = cir.const #cir.bool<false> : !cir.bool
%1 = cir.select if %arg0 then %arg1 else %0 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %1 : !cir.bool
}

// LLVM: define i8 @logical_and(i8 %[[#ARG0:]], i8 %[[#ARG1:]])
// LLVM-NEXT: %[[#RES:]] = and i8 %[[#ARG0]], %[[#ARG1]]
// LLVM-NEXT: ret i8 %[[#RES]]
// LLVM-NEXT: }

cir.func @logical_or(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool {
%0 = cir.const #cir.bool<true> : !cir.bool
%1 = cir.select if %arg0 then %0 else %arg1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %1 : !cir.bool
}

// LLVM: define i8 @logical_or(i8 %[[#ARG0:]], i8 %[[#ARG1:]])
// LLVM-NEXT: %[[#RES:]] = or i8 %[[#ARG0]], %[[#ARG1]]
// LLVM-NEXT: ret i8 %[[#RES]]
// LLVM-NEXT: }
}
26 changes: 26 additions & 0 deletions clang/test/CIR/Transforms/select.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: cir-opt --canonicalize -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s

!s32i = !cir.int<s, 32>

module {
cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
%0 = cir.const #cir.bool<true> : !cir.bool
%1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: cir.return %[[ARG0]] : !s32i
// CHECK-NEXT: }

cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
%0 = cir.const #cir.bool<false> : !cir.bool
%1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: cir.return %[[ARG1]] : !s32i
// CHECK-NEXT: }
}

0 comments on commit 3582452

Please sign in to comment.