Skip to content

Commit

Permalink
[CIR][Fix] FP builtins should lower directly to LLVM builtins (#670)
Browse files Browse the repository at this point in the history
LLVM lowering for the following operations was introduced in #616 and
#651: `cos`, `exp`, `exp2`, `log`, `log10`, `log2`, `sin`, `sqrt`,
`fmod`, and `pow`. However, these operations are not lowered to their
corresponding LLVM intrinsics; instead, they are transformed into libc calls
during lowering prepare. This does not match the upstream behavior.

This PR tries to correct this mistake. It makes all CIR FP intrinsic ops
lower to their corresponding LLVM intrinsics (`fmod` is a special case:
it is lowered to the `frem` LLVM instruction instead of an intrinsic call).
  • Loading branch information
Lancern authored and lanza committed Nov 3, 2024
1 parent 01548c1 commit b6a635d
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 137 deletions.
20 changes: 20 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_cosf16:
case Builtin::BI__builtin_cosl:
case Builtin::BI__builtin_cosf128:
assert(getContext().getLangOpts().FastMath &&
"cir.cos is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::CosOp>(*this, *E);

case Builtin::BIexp:
Expand All @@ -458,6 +460,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_expf16:
case Builtin::BI__builtin_expl:
case Builtin::BI__builtin_expf128:
assert(getContext().getLangOpts().FastMath &&
"cir.exp is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::ExpOp>(*this, *E);

case Builtin::BIexp2:
Expand All @@ -468,6 +472,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_exp2f16:
case Builtin::BI__builtin_exp2l:
case Builtin::BI__builtin_exp2f128:
assert(getContext().getLangOpts().FastMath &&
"cir.exp2 is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::Exp2Op>(*this, *E);

case Builtin::BIfabs:
Expand Down Expand Up @@ -534,6 +540,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_fmod:
case Builtin::BI__builtin_fmodf:
case Builtin::BI__builtin_fmodl:
assert(getContext().getLangOpts().FastMath &&
"cir.fmod is only expected under -ffast-math");
return buildBinaryFPBuiltin<mlir::cir::FModOp>(*this, *E);

case Builtin::BI__builtin_fmodf16:
Expand All @@ -548,6 +556,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_logf16:
case Builtin::BI__builtin_logl:
case Builtin::BI__builtin_logf128:
assert(getContext().getLangOpts().FastMath &&
"cir.log is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::LogOp>(*this, *E);

case Builtin::BIlog10:
Expand All @@ -558,6 +568,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_log10f16:
case Builtin::BI__builtin_log10l:
case Builtin::BI__builtin_log10f128:
assert(getContext().getLangOpts().FastMath &&
"cir.log10 is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::Log10Op>(*this, *E);

case Builtin::BIlog2:
Expand All @@ -568,6 +580,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_log2f16:
case Builtin::BI__builtin_log2l:
case Builtin::BI__builtin_log2f128:
assert(getContext().getLangOpts().FastMath &&
"cir.log2 is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::Log2Op>(*this, *E);

case Builtin::BInearbyint:
Expand All @@ -585,6 +599,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_pow:
case Builtin::BI__builtin_powf:
case Builtin::BI__builtin_powl:
assert(getContext().getLangOpts().FastMath &&
"cir.pow is only expected under -ffast-math");
return RValue::get(
buildBinaryMaybeConstrainedFPBuiltin<mlir::cir::PowOp>(*this, *E));

Expand Down Expand Up @@ -620,6 +636,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_sinf16:
case Builtin::BI__builtin_sinl:
case Builtin::BI__builtin_sinf128:
assert(getContext().getLangOpts().FastMath &&
"cir.sin is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::SinOp>(*this, *E);

case Builtin::BIsqrt:
Expand All @@ -630,6 +648,8 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_sqrtf16:
case Builtin::BI__builtin_sqrtl:
case Builtin::BI__builtin_sqrtf128:
assert(getContext().getLangOpts().FastMath &&
"cir.sqrt is only expected under -ffast-math");
return buildUnaryFPBuiltin<mlir::cir::SqrtOp>(*this, *E);

case Builtin::BItrunc:
Expand Down
71 changes: 0 additions & 71 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Path.h"

Expand Down Expand Up @@ -71,7 +70,6 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void runOnOperation() override;

void runOnOp(Operation *op);
void runOnMathOp(Operation *op);
void lowerThreeWayCmpOp(CmpThreeWayOp op);
void lowerVAArgOp(VAArgOp op);
void lowerGlobalOp(GlobalOp op);
Expand Down Expand Up @@ -650,69 +648,6 @@ void LoweringPreparePass::runOnOp(Operation *op) {
}
}

// Rewrites a floating-point math operation into a call to a runtime function
// with the matching C math name (for example, a CosOp producing a double turns
// into a call to `cos`). The callee is chosen first by the kind of operation
// and then by the type of the operation's first result; every operand is given
// that same type in the callee's signature. The original operation is erased
// once all of its uses have been redirected to the call.
void LoweringPreparePass::runOnMathOp(Operation *op) {
  // The three spellings of one math routine, indexed by floating-point width.
  struct RuntimeNameSet {
    llvm::StringRef forFloat;
    llvm::StringRef forDouble;
    llvm::StringRef forLongDouble;
  };

  mlir::Type resultTy = op->getResult(0).getType();

  // Step 1: map the operation kind to its family of runtime function names.
  RuntimeNameSet candidates =
      llvm::TypeSwitch<Operation *, RuntimeNameSet>(op)
          .Case<FModOp>([](auto) {
            return RuntimeNameSet{"fmodf", "fmod", "fmodl"};
          })
          .Case<PowOp>([](auto) {
            return RuntimeNameSet{"powf", "pow", "powl"};
          })
          .Case<CosOp>([](auto) {
            return RuntimeNameSet{"cosf", "cos", "cosl"};
          })
          .Case<ExpOp>([](auto) {
            return RuntimeNameSet{"expf", "exp", "expl"};
          })
          .Case<Exp2Op>([](auto) {
            return RuntimeNameSet{"exp2f", "exp2", "exp2l"};
          })
          .Case<LogOp>([](auto) {
            return RuntimeNameSet{"logf", "log", "logl"};
          })
          .Case<Log10Op>([](auto) {
            return RuntimeNameSet{"log10f", "log10", "log10l"};
          })
          .Case<Log2Op>([](auto) {
            return RuntimeNameSet{"log2f", "log2", "log2l"};
          })
          .Case<SinOp>([](auto) {
            return RuntimeNameSet{"sinf", "sin", "sinl"};
          })
          .Case<SqrtOp>([](auto) {
            return RuntimeNameSet{"sqrtf", "sqrt", "sqrtl"};
          });

  // Step 2: narrow the family down to one name using the result type.
  llvm::StringRef calleeName =
      llvm::TypeSwitch<mlir::Type, llvm::StringRef>(resultTy)
          .Case<mlir::cir::SingleType>(
              [&](auto) { return candidates.forFloat; })
          .Case<mlir::cir::DoubleType>(
              [&](auto) { return candidates.forDouble; })
          .Case<mlir::cir::LongDoubleType>(
              [&](auto) { return candidates.forLongDouble; });

  // Step 3: obtain the runtime function, with the builder positioned at the
  // start of the module body while doing so.
  CIRBaseBuilderTy builder(*theModule.getContext());
  builder.setInsertionPointToStart(theModule.getBody());

  llvm::SmallVector<mlir::Type, 2> paramTypes(op->getNumOperands(), resultTy);
  auto calleeTy = mlir::cir::FuncType::get(paramTypes, resultTy);
  FuncOp callee =
      buildRuntimeFunction(builder, calleeName, op->getLoc(), calleeTy);

  // Step 4: emit the call right after the math op, then retire the math op.
  builder.setInsertionPointAfter(op);
  auto runtimeCall =
      builder.createCallOp(op->getLoc(), callee, op->getOperands());

  op->replaceAllUsesWith(runtimeCall);
  op->erase();
}

void LoweringPreparePass::runOnOperation() {
assert(astCtx && "Missing ASTContext, please construct with the right ctor");
auto *op = getOperation();
Expand All @@ -721,22 +656,16 @@ void LoweringPreparePass::runOnOperation() {
}

SmallVector<Operation *> opsToTransform;
SmallVector<Operation *> mathOpsToTransform;

op->walk([&](Operation *op) {
if (isa<CmpThreeWayOp, VAArgOp, GlobalOp, DynamicCastOp, StdFindOp,
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(
op))
opsToTransform.push_back(op);
else if (isa<CosOp, ExpOp, Exp2Op, LogOp, Log10Op, Log2Op, SinOp, SqrtOp,
FModOp, PowOp>(op))
mathOpsToTransform.push_back(op);
});

for (auto *o : opsToTransform)
runOnOp(o);
for (auto *o : mathOpsToTransform)
runOnMathOp(o);

buildCXXGlobalInitFunc();
buildGlobalCtorDtorList();
Expand Down
46 changes: 41 additions & 5 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3395,17 +3395,33 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern<CIROp> {

using CIRCeilOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
using CIRCosOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CosOp, mlir::LLVM::CosOp>;
using CIRExpOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::ExpOp, mlir::LLVM::ExpOp>;
using CIRExp2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Exp2Op, mlir::LLVM::Exp2Op>;
using CIRFloorOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
using CIRFabsOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
using CIRLogOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::LogOp, mlir::LLVM::LogOp>;
using CIRLog10OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log10Op, mlir::LLVM::Log10Op>;
using CIRLog2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log2Op, mlir::LLVM::Log2Op>;
using CIRNearbyintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::NearbyintOp,
mlir::LLVM::NearbyintOp>;
using CIRRintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
using CIRRoundOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
using CIRSinOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SinOp, mlir::LLVM::SinOp>;
using CIRSqrtOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SqrtOp, mlir::LLVM::SqrtOp>;
using CIRTruncOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;

Expand Down Expand Up @@ -3442,6 +3458,24 @@ using CIRFMaxOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::FMaxOp, mlir::LLVM::MaxNumOp>;
using CIRFMinOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::FMinOp, mlir::LLVM::MinNumOp>;
using CIRPowOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::PowOp, mlir::LLVM::PowOp>;

// Conversion pattern for cir.fmod. This op is the odd one out among the FP
// builtin lowerings: rather than becoming an intrinsic call, it is rewritten
// into the LLVM frem instruction.
class CIRFModOpLowering : public mlir::OpConversionPattern<mlir::cir::FModOp> {
public:
  using mlir::OpConversionPattern<mlir::cir::FModOp>::OpConversionPattern;

  // Replaces `op` with an LLVM::FRemOp whose result type is the converted type
  // of `op` and whose operands are the already-converted lhs and rhs.
  mlir::LogicalResult
  matchAndRewrite(mlir::cir::FModOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto convertedResultTy =
        this->getTypeConverter()->convertType(op.getType());
    auto dividend = adaptor.getLhs();
    auto divisor = adaptor.getRhs();
    rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, convertedResultTy,
                                                    dividend, divisor);
    return mlir::success();
  }
};

class CIRClearCacheOpLowering
: public mlir::OpConversionPattern<mlir::cir::ClearCacheOp> {
Expand Down Expand Up @@ -3489,11 +3523,13 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering,
CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering,
CIRCeilOpLowering, CIRFloorOpLowering, CIRFAbsOpLowering,
CIRNearbyintOpLowering, CIRRintOpLowering, CIRRoundOpLowering,
CIRTruncOpLowering, CIRCopysignOpLowering, CIRFMaxOpLowering,
CIRFMinOpLowering, CIRClearCacheOpLowering>(converter,
patterns.getContext());
CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering,
CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering,
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
CIRClearCacheOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
Loading

0 comments on commit b6a635d

Please sign in to comment.