Skip to content

Commit

Permalink
[CIR][Fix] FP builtins should lower directly to LLVM builtins
Browse files Browse the repository at this point in the history
LLVM lowering for the following operations is introduced in llvm#616 and llvm#651: cos,
exp, exp2, log, log10, log2, sin, sqrt, fmod, and pow. However, they are not
lowered to their corresponding LLVM intrinsics; instead they are transformed to
libc calls during lowering prepare. This does not match the upstream behavior.

This patch tries to correct this mistake.
  • Loading branch information
Lancern committed Jun 24, 2024
1 parent 2dd4609 commit f47ab9e
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 136 deletions.
71 changes: 0 additions & 71 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Path.h"

Expand Down Expand Up @@ -72,7 +71,6 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void runOnOperation() override;

void runOnOp(Operation *op);
void runOnMathOp(Operation *op);
void lowerThreeWayCmpOp(CmpThreeWayOp op);
void lowerVAArgOp(VAArgOp op);
void lowerGlobalOp(GlobalOp op);
Expand Down Expand Up @@ -655,69 +653,6 @@ void LoweringPreparePass::runOnOp(Operation *op) {
}
}

void LoweringPreparePass::runOnMathOp(Operation *op) {
struct MathOpFunctionNames {
llvm::StringRef floatVer;
llvm::StringRef doubleVer;
llvm::StringRef longDoubleVer;
};

mlir::Type ty = op->getResult(0).getType();

MathOpFunctionNames rtFuncNames =
llvm::TypeSwitch<Operation *, MathOpFunctionNames>(op)
.Case<FModOp>([](auto) {
return MathOpFunctionNames{"fmodf", "fmod", "fmodl"};
})
.Case<PowOp>(
[](auto) { return MathOpFunctionNames{"powf", "pow", "powl"}; })
.Case<CosOp>(
[](auto) { return MathOpFunctionNames{"cosf", "cos", "cosl"}; })
.Case<ExpOp>(
[](auto) { return MathOpFunctionNames{"expf", "exp", "expl"}; })
.Case<Exp2Op>([](auto) {
return MathOpFunctionNames{"exp2f", "exp2", "exp2l"};
})
.Case<LogOp>(
[](auto) { return MathOpFunctionNames{"logf", "log", "logl"}; })
.Case<Log10Op>([](auto) {
return MathOpFunctionNames{"log10f", "log10", "log10l"};
})
.Case<Log2Op>([](auto) {
return MathOpFunctionNames{"log2f", "log2", "log2l"};
})
.Case<SinOp>(
[](auto) { return MathOpFunctionNames{"sinf", "sin", "sinl"}; })
.Case<SqrtOp>([](auto) {
return MathOpFunctionNames{"sqrtf", "sqrt", "sqrtl"};
});
llvm::StringRef rtFuncName = llvm::TypeSwitch<mlir::Type, llvm::StringRef>(ty)
.Case<mlir::cir::SingleType>([&](auto) {
return rtFuncNames.floatVer;
})
.Case<mlir::cir::DoubleType>([&](auto) {
return rtFuncNames.doubleVer;
})
.Case<mlir::cir::LongDoubleType>([&](auto) {
return rtFuncNames.longDoubleVer;
});

CIRBaseBuilderTy builder(*theModule.getContext());
builder.setInsertionPointToStart(theModule.getBody());

llvm::SmallVector<mlir::Type, 2> operandTypes(op->getNumOperands(), ty);
auto rtFuncTy =
mlir::cir::FuncType::get(operandTypes, op->getResult(0).getType());
FuncOp rtFunc =
buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

builder.setInsertionPointAfter(op);
auto call = builder.createCallOp(op->getLoc(), rtFunc, op->getOperands());

op->replaceAllUsesWith(call);
op->erase();
}

void LoweringPreparePass::runOnOperation() {
assert(astCtx && "Missing ASTContext, please construct with the right ctor");
auto *op = getOperation();
Expand All @@ -726,22 +661,16 @@ void LoweringPreparePass::runOnOperation() {
}

SmallVector<Operation *> opsToTransform;
SmallVector<Operation *> mathOpsToTransform;

op->walk([&](Operation *op) {
if (isa<CmpThreeWayOp, VAArgOp, GlobalOp, DynamicCastOp, StdFindOp,
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(
op))
opsToTransform.push_back(op);
else if (isa<CosOp, ExpOp, Exp2Op, LogOp, Log10Op, Log2Op, SinOp, SqrtOp,
FModOp, PowOp>(op))
mathOpsToTransform.push_back(op);
});

for (auto *o : opsToTransform)
runOnOp(o);
for (auto *o : mathOpsToTransform)
runOnMathOp(o);

buildCXXGlobalInitFunc();
buildGlobalCtorDtorList();
Expand Down
45 changes: 41 additions & 4 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3235,17 +3235,33 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern<CIROp> {

using CIRCeilOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
using CIRCosOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CosOp, mlir::LLVM::CosOp>;
using CIRExpOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::ExpOp, mlir::LLVM::ExpOp>;
using CIRExp2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Exp2Op, mlir::LLVM::Exp2Op>;
using CIRFloorOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
using CIRFabsOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
using CIRLogOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::LogOp, mlir::LLVM::LogOp>;
using CIRLog10OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log10Op, mlir::LLVM::Log10Op>;
using CIRLog2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log2Op, mlir::LLVM::Log2Op>;
using CIRNearbyintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::NearbyintOp,
mlir::LLVM::NearbyintOp>;
using CIRRintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
using CIRRoundOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
using CIRSinOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SinOp, mlir::LLVM::SinOp>;
using CIRSqrtOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SqrtOp, mlir::LLVM::SqrtOp>;
using CIRTruncOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;

Expand Down Expand Up @@ -3282,6 +3298,24 @@ using CIRFMaxOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::FMaxOp, mlir::LLVM::MaxNumOp>;
using CIRFMinOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::FMinOp, mlir::LLVM::MinNumOp>;
using CIRPowOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::PowOp, mlir::LLVM::PowOp>;

// cir.fmod is special. Instead of lowering it to an intrinsic call, lower it to
// the frem LLVM instruction.
class CIRFModOpLowering : public mlir::OpConversionPattern<mlir::cir::FModOp> {
public:
using mlir::OpConversionPattern<mlir::cir::FModOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::FModOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto resTy = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, resTy, adaptor.getLhs(),
adaptor.getRhs());
return mlir::success();
}
};

void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
Expand Down Expand Up @@ -3309,10 +3343,13 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering,
CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering,
CIRCeilOpLowering, CIRFloorOpLowering, CIRFAbsOpLowering,
CIRNearbyintOpLowering, CIRRintOpLowering, CIRRoundOpLowering,
CIRTruncOpLowering, CIRCopysignOpLowering, CIRFMaxOpLowering,
CIRFMinOpLowering>(converter, patterns.getContext());
CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering,
CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering,
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering,
CIRPowOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
Loading

0 comments on commit f47ab9e

Please sign in to comment.