Skip to content

Commit

Permalink
[CIR] Lower certain cir.cmp3way operations to corresponding LLVM intr…
Browse files Browse the repository at this point in the history
…insics

LLVM recently added two families of intrinsics named `llvm.scmp.*` and
`llvm.ucmp.*` that generate potentially better code for three-way comparison
operations. This patch lowers certain `cir.cmp3way` operations to these
intrinsics.

Not all `cir.cmp3way` operations can be lowered to these intrinsics. The
qualifying conditions are: 1) the comparison is between two integers, and 2) the
comparison produces a strong order. `cir.cmp3way` operations that are not
qualified are not affected by this patch.

Qualifying `cir.cmp3way` operations may still need some canonicalization work
before lowering. The "canonicalized" form of a qualifying three-way comparison
operation yields -1 for lt, 0 for eq, and 1 for gt. This patch converts those
non-canonicalized but qualifying `cir.cmp3way` operations to their canonical
forms in the LLVM lowering prepare pass.
  • Loading branch information
Lancern committed Apr 20, 2024
1 parent ae3e0ad commit 243caeb
Show file tree
Hide file tree
Showing 9 changed files with 361 additions and 33 deletions.
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::cir::UnaryOpKind::Not, value);
}

mlir::cir::CmpOp createCompare(mlir::Location loc, mlir::cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
return create<mlir::cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::Value createBinop(mlir::Value lhs, mlir::cir::BinOpKind kind,
const llvm::APInt &rhs) {
return create<mlir::cir::BinOp>(
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,18 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
}];

let hasVerifier = 0;

let extraClassDeclaration = [{
/// Determine whether this three-way comparison produces a strong ordering.
bool isStrongOrdering() {
return getInfo().getOrdering() == mlir::cir::CmpOrdering::Strong;
}

/// Determine whether this three-way comparison compares integral operands.
bool isIntegralComparison() {
return getLhs().getType().isa<mlir::cir::IntType>();
}
}];
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 0 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,11 +594,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::ContinueOp>(loc);
}

mlir::cir::CmpOp createCompare(mlir::Location loc, mlir::cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
return create<mlir::cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down
72 changes: 72 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,82 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
return f;
}

static void canonicalizeIntrinsicThreeWayCmp(CIRBaseBuilderTy &builder,
CmpThreeWayOp op) {
auto loc = op->getLoc();
auto cmpInfo = op.getInfo();

if (cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 && cmpInfo.getGt() == 1) {
// The comparison is already in canonicalized form.
return;
}

auto canonicalizedCmpInfo =
mlir::cir::CmpThreeWayInfoAttr::get(builder.getContext(), -1, 0, 1);
mlir::Value result =
builder
.create<mlir::cir::CmpThreeWayOp>(loc, op.getType(), op.getLhs(),
op.getRhs(), canonicalizedCmpInfo)
.getResult();

auto compareAndYield = [&](mlir::Value input, int64_t test,
int64_t yield) -> mlir::Value {
// Create a conditional branch that tests whether `input` is equal to
// `test`. If `input` is equal to `test`, yield `yield`. Otherwise, yield
// `input` as is.
auto testValue = builder.getConstant(
loc, mlir::cir::IntAttr::get(input.getType(), test));
auto yieldValue = builder.getConstant(
loc, mlir::cir::IntAttr::get(input.getType(), yield));
auto eqToTest =
builder.createCompare(loc, mlir::cir::CmpOpKind::eq, input, testValue);
return builder
.create<mlir::cir::TernaryOp>(
loc, eqToTest,
[&](OpBuilder &, Location) {
builder.create<mlir::cir::YieldOp>(loc,
mlir::ValueRange{yieldValue});
},
[&](OpBuilder &, Location) {
builder.create<mlir::cir::YieldOp>(loc, mlir::ValueRange{input});
})
->getResult(0);
};

if (cmpInfo.getLt() != -1)
result = compareAndYield(result, -1, cmpInfo.getLt());

if (cmpInfo.getEq() != 0)
result = compareAndYield(result, 0, cmpInfo.getEq());

if (cmpInfo.getGt() != 1)
result = compareAndYield(result, 1, cmpInfo.getGt());

op.replaceAllUsesWith(result);
op.erase();
}

void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);

if (op.isIntegralComparison() && op.isStrongOrdering()) {
// For three-way comparisons on integral operands that produce strong
// ordering, we can generate potentially better code with the `llvm.scmp.*`
// and `llvm.ucmp.*` intrinsics. Thus we don't replace these comparisons
// here. They will be lowered directly to LLVMIR during the LLVM lowering
// pass.
//
// But we still need to take a step here. `llvm.scmp.*` and `llvm.ucmp.*`
// returns -1, 0, or 1 to represent lt, eq, and gt, which are the
// "canonicalized" result values of three-way comparisons. However,
// `cir.cmp3way` may not produce canonicalized result. We need to
// canonicalize the comparison if necessary. This is what we're doing in
// this special branch.
canonicalizeIntrinsicThreeWayCmp(builder, op);
return;
}

auto loc = op->getLoc();
auto cmpInfo = op.getInfo();

Expand Down
86 changes: 77 additions & 9 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -2057,6 +2058,16 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
}
};

static mlir::LLVM::CallIntrinsicOp
createCallLLVMIntrinsicOp(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, const llvm::Twine &intrinsicName,
mlir::Type resultTy, mlir::ValueRange operands) {
auto intrinsicNameAttr =
mlir::StringAttr::get(rewriter.getContext(), intrinsicName);
return rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, resultTy, intrinsicNameAttr, operands);
}

static mlir::Value createLLVMBitOp(mlir::Location loc,
const llvm::Twine &llvmIntrinBaseName,
mlir::Type resultTy, mlir::Value operand,
Expand All @@ -2069,21 +2080,19 @@ static mlir::Value createLLVMBitOp(mlir::Location loc,
llvmIntrinBaseName.concat(".i")
.concat(std::to_string(operandIntTy.getWidth()))
.str();
auto llvmIntrinNameAttr =
mlir::StringAttr::get(rewriter.getContext(), llvmIntrinName);

// Note that LLVM intrinsic calls to bit intrinsics have the same type as the
// operand.
mlir::LLVM::CallIntrinsicOp op;
if (poisonZeroInputFlag.has_value()) {
auto poisonZeroInputValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, rewriter.getI1Type(), static_cast<int64_t>(*poisonZeroInputFlag));
op = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, operand.getType(), llvmIntrinNameAttr,
mlir::ValueRange{operand, poisonZeroInputValue});
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(),
{operand, poisonZeroInputValue});
} else {
op = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, operand.getType(), llvmIntrinNameAttr, operand);
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(), operand);
}

mlir::Value result = op->getResult(0);
Expand Down Expand Up @@ -2902,6 +2911,65 @@ class CIRIsConstantOpLowering
}
};

class CIRCmpThreeWayOpLowering
: public mlir::OpConversionPattern<mlir::cir::CmpThreeWayOp> {
public:
using mlir::OpConversionPattern<
mlir::cir::CmpThreeWayOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::CmpThreeWayOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
assert(op.isIntegralComparison() && op.isStrongOrdering());

auto cmpInfo = op.getInfo();
assert(cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 &&
cmpInfo.getGt() == 1);

auto operandTy = op.getLhs().getType().cast<mlir::cir::IntType>();
auto resultTy = op.getType();
auto llvmIntrinsicName = getLLVMIntrinsicName(
operandTy.isSigned(), operandTy.getWidth(), resultTy.getWidth());

rewriter.setInsertionPoint(op);

auto llvmLhs = adaptor.getLhs();
auto llvmRhs = adaptor.getRhs();
auto llvmResultTy = getTypeConverter()->convertType(resultTy);
auto callIntrinsicOp =
createCallLLVMIntrinsicOp(rewriter, op.getLoc(), llvmIntrinsicName,
llvmResultTy, {llvmLhs, llvmRhs});

rewriter.replaceOp(op, callIntrinsicOp);
return mlir::success();
}

private:
static std::string getLLVMIntrinsicName(bool signedCmp, unsigned operandWidth,
unsigned resultWidth) {
// The intrinsic's name takes the form:
// `llvm.<scmp|ucmp>.i<resultWidth>.i<operandWidth>`

std::string result = "llvm.";

if (signedCmp)
result.append("scmp.");
else
result.append("ucmp.");

// Result type part.
result.push_back('i');
result.append(std::to_string(resultWidth));
result.push_back('.');

// Operand type part.
result.push_back('i');
result.append(std::to_string(operandWidth));

return result;
}
};

void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering>(patterns.getContext());
Expand All @@ -2923,8 +2991,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRVectorShuffleVecLowering, CIRStackSaveLowering,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering>(
converter, patterns.getContext());
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
110 changes: 110 additions & 0 deletions clang/test/CIR/CodeGen/Inputs/std-compare-noncanonical.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#ifndef STD_COMPARE_H
#define STD_COMPARE_H

namespace std {
inline namespace __1 {

// exposition only
enum class _EqResult : unsigned char {
__equal = 2,
__equiv = __equal,
};

enum class _OrdResult : signed char {
__less = 1,
__greater = 3
};

struct _CmpUnspecifiedType;
using _CmpUnspecifiedParam = void (_CmpUnspecifiedType::*)();

class strong_ordering {
using _ValueT = signed char;
explicit constexpr strong_ordering(_EqResult __v) noexcept : __value_(static_cast<signed char>(__v)) {}
explicit constexpr strong_ordering(_OrdResult __v) noexcept : __value_(static_cast<signed char>(__v)) {}

public:
static const strong_ordering less;
static const strong_ordering equal;
static const strong_ordering equivalent;
static const strong_ordering greater;

// comparisons
friend constexpr bool operator==(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr bool operator!=(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr bool operator<(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr bool operator<=(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr bool operator>(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr bool operator>=(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr bool operator==(_CmpUnspecifiedParam, strong_ordering __v) noexcept;
friend constexpr bool operator!=(_CmpUnspecifiedParam, strong_ordering __v) noexcept;
friend constexpr bool operator<(_CmpUnspecifiedParam, strong_ordering __v) noexcept;
friend constexpr bool operator<=(_CmpUnspecifiedParam, strong_ordering __v) noexcept;
friend constexpr bool operator>(_CmpUnspecifiedParam, strong_ordering __v) noexcept;
friend constexpr bool operator>=(_CmpUnspecifiedParam, strong_ordering __v) noexcept;

friend constexpr strong_ordering operator<=>(strong_ordering __v, _CmpUnspecifiedParam) noexcept;
friend constexpr strong_ordering operator<=>(_CmpUnspecifiedParam, strong_ordering __v) noexcept;

// test helper
constexpr bool test_eq(strong_ordering const &other) const noexcept {
return __value_ == other.__value_;
}

private:
_ValueT __value_;
};

inline constexpr strong_ordering strong_ordering::less(_OrdResult::__less);
inline constexpr strong_ordering strong_ordering::equal(_EqResult::__equal);
inline constexpr strong_ordering strong_ordering::equivalent(_EqResult::__equiv);
inline constexpr strong_ordering strong_ordering::greater(_OrdResult::__greater);

constexpr bool operator==(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v.__value_ == 0;
}
constexpr bool operator!=(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v.__value_ != 0;
}
constexpr bool operator<(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v.__value_ < 0;
}
constexpr bool operator<=(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v.__value_ <= 0;
}
constexpr bool operator>(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v.__value_ > 0;
}
constexpr bool operator>=(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v.__value_ >= 0;
}
constexpr bool operator==(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return 0 == __v.__value_;
}
constexpr bool operator!=(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return 0 != __v.__value_;
}
constexpr bool operator<(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return 0 < __v.__value_;
}
constexpr bool operator<=(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return 0 <= __v.__value_;
}
constexpr bool operator>(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return 0 > __v.__value_;
}
constexpr bool operator>=(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return 0 >= __v.__value_;
}

constexpr strong_ordering operator<=>(strong_ordering __v, _CmpUnspecifiedParam) noexcept {
return __v;
}
constexpr strong_ordering operator<=>(_CmpUnspecifiedParam, strong_ordering __v) noexcept {
return __v < 0 ? strong_ordering::greater : (__v > 0 ? strong_ordering::less : __v);
}

} // namespace __1
} // end namespace std

#endif // STD_COMPARE_H
Loading

0 comments on commit 243caeb

Please sign in to comment.