Skip to content

Commit

Permalink
[CIR][Transform] Add ternary simplification
Browse files Browse the repository at this point in the history
This patch adds a new transformation that transform suitable ternary operations
into select operations.
  • Loading branch information
Lancern committed Aug 28, 2024
1 parent d3af20c commit 92339bf
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 3 deletions.
98 changes: 95 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace cir;
Expand Down Expand Up @@ -107,6 +111,92 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
}
};

/// Simplify suitable ternary operations into select operations.
///
/// Only those ternary operations that meet the following criteria can be
/// simplified:
/// - The true branch and the false branch cannot have any side effects;
/// - The true branch and the false branch cannot be "too costly" since both of
/// them will be executed after the folding happens.
///
/// For now we only simplify those ternary operations whose true and false
/// branches either directly yield a value or directly yield a constant. That
/// is, both of the two branches of these ternary operation must either:
/// - Only contain a single cir.yield operation, or
/// - Contain a cir.const operation followed by a cir.yield operation that
/// yields the constant value produced by the cir.const operation.
///
/// For example, we will simplify the following ternary operation:
///
/// %0 = cir.ternary (%condition, true {
/// %1 = cir.const ...
/// cir.yield %1
/// } false {
/// cir.yield %2
/// })
///
/// into the following sequence of operations:
///
/// %1 = cir.const ...
/// %0 = cir.select if %condition then %1 else %2
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
using OpRewritePattern<TernaryOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TernaryOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<mlir::Operation *> opsToHoist;

mlir::Value trueValue =
simplifyTernaryBranch(op.getTrueRegion(), opsToHoist);
if (!trueValue)
return mlir::failure();

mlir::Value falseValue =
simplifyTernaryBranch(op.getFalseRegion(), opsToHoist);
if (!falseValue)
return mlir::failure();

for (auto *hoistOp : opsToHoist)
rewriter.moveOpBefore(hoistOp, op);
rewriter.replaceOpWithNewOp<mlir::cir::SelectOp>(op, op.getCond(),
trueValue, falseValue);

return mlir::success();
}

private:
mlir::Value simplifyTernaryBranch(
mlir::Region &region,
llvm::SmallVector<mlir::Operation *> &opsToHoist) const {
if (!region.hasOneBlock())
return nullptr;

mlir::Block &block = region.front();

// The block can contain at most 2 operations: one cir.const operation
// followed by one cir.yield operation
if (block.getOperations().size() > 2)
return nullptr;

auto yieldOp = mlir::cast<mlir::cir::YieldOp>(block.getTerminator());
auto yieldValue = yieldOp.getArgs()[0];
if (block.getOperations().size() == 1)
return yieldValue;

// The yielded value must be produced by a cir.const operation in the same
// block to make the branch simplifiable.
auto yieldValueDef = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
yieldValue.getDefiningOp());
if (!yieldValueDef)
return nullptr;
if (yieldValueDef->getBlock() != &block)
return nullptr;

opsToHoist.push_back(yieldValueDef);
return yieldValue;
}
};

//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
Expand All @@ -131,7 +221,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
RemoveRedundantBranches,
RemoveEmptyScope,
RemoveEmptySwitch,
RemoveTrivialTry
RemoveTrivialTry,
SimplifyTernary
>(patterns.getContext());
// clang-format on
}
Expand All @@ -146,8 +237,9 @@ void CIRSimplifyPass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
// CastOp here is to perform a manual `fold` in
// applyOpPatternsAndFold
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
TernaryOp, SelectOp, ComplexCreateOp, ComplexRealOp, ComplexImagOp>(
op))
ops.push_back(op);
});

Expand Down
59 changes: 59 additions & 0 deletions clang/test/CIR/Transforms/ternary-fold.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: cir-opt -cir-simplify -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s

!s32i = !cir.int<s, 32>

module {
cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
%0 = cir.const #cir.bool<false> : !cir.bool
%1 = cir.ternary (%0, true {
cir.yield %arg0 : !s32i
}, false {
cir.yield %arg1 : !s32i
}) : (!cir.bool) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: cir.return %[[ARG]] : !s32i
// CHECK-NEXT: }

cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
%0 = cir.ternary (%arg0, true {
%1 = cir.const #cir.int<42> : !s32i
cir.yield %1 : !s32i
}, false {
cir.yield %arg1 : !s32i
}) : (!cir.bool) -> !s32i
cir.return %0 : !s32i
}

// CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
// CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
// CHECK-NEXT: cir.return %[[#B]] : !s32i
// CHECK-NEXT: }

cir.func @non_simplifiable_ternary(%arg0 : !cir.bool, %arg1 : !cir.ptr<!s32i>) -> !s32i {
// Not simplifiable, should keep as-is.
%0 = cir.ternary (%arg0, true {
%1 = cir.load %arg1 : !cir.ptr<!s32i>, !s32i
cir.yield %1 : !s32i
}, false {
%2 = cir.const #cir.int<42> : !s32i
cir.yield %2 : !s32i
}) : (!cir.bool) -> !s32i
cir.return %0 : !s32i
}

// CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !cir.ptr<!s32i>) -> !s32i {
// CHECK-NEXT: %[[#A:]] = cir.ternary(%[[ARG0]], true {
// CHECK-NEXT: %[[#B:]] = cir.load %[[ARG1]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: cir.yield %[[#B]] : !s32i
// CHECK-NEXT: }, false {
// CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i
// CHECK-NEXT: cir.yield %[[#C]] : !s32i
// CHECK-NEXT: }) : (!cir.bool) -> !s32i
// CHECK-NEXT: cir.return %[[#A]] : !s32i
// CHECK-NEXT: }
}
35 changes: 35 additions & 0 deletions clang/test/CIR/Transforms/ternary-fold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -fclangir-mem2reg -mmlir --mlir-print-ir-before=cir-simplify %s -o %t1.cir 2>&1 | FileCheck -check-prefix=CIR-BEFORE %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -fclangir-mem2reg -mmlir --mlir-print-ir-after=cir-simplify %s -o %t2.cir 2>&1 | FileCheck -check-prefix=CIR-AFTER %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s

int test1(bool x) {
return x ? 1 : 2;
}

// CIR-BEFORE: cir.func @_Z5test1b
// CIR-BEFORE: %{{.+}} = cir.ternary(%{{.+}}, true {
// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i
// CIR-BEFORE-NEXT: }, false {
// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i
// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i
// CIR-BEFORE: }

// CIR-AFTER: cir.func @_Z5test1b
// CIR-AFTER: %[[#A:]] = cir.const #cir.int<1> : !s32i
// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
// CIR-AFTER-NEXT: %{{.+}} = cir.select if %{{.+}} then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i
// CIR-AFTER: }

// LLVM: define dso_local i32 @_Z5test1b
// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2
// LLVM: }

// The following test does not work yet because mem2reg does not happen before
// ternary simplify.

// int test2(bool x, int a, int b) {
// return x ? a : b;
// }

0 comments on commit 92339bf

Please sign in to comment.