Skip to content

Commit

Permalink
[CIR][LLVMLowering] Add LLVM lowering for unary fp2fp builtins (#651)
Browse files Browse the repository at this point in the history
This patch adds LLVM lowering support for unary fp2fp builtins.

Those builtins that should be lowered to runtime function calls are
lowered to such calls during lowering prepare. Other builtins are
lowered to LLVM intrinsic calls during LLVM lowering.
  • Loading branch information
Lancern authored Jun 5, 2024
1 parent 4200ad0 commit 76613f6
Show file tree
Hide file tree
Showing 3 changed files with 569 additions and 53 deletions.
126 changes: 75 additions & 51 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Path.h"

Expand Down Expand Up @@ -71,6 +72,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void runOnOperation() override;

void runOnOp(Operation *op);
void runOnMathOp(Operation *op);
void lowerThreeWayCmpOp(CmpThreeWayOp op);
void lowerVAArgOp(VAArgOp op);
void lowerGlobalOp(GlobalOp op);
Expand All @@ -80,8 +82,6 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void lowerIterEndOp(IterEndOp op);
void lowerArrayDtor(ArrayDtor op);
void lowerArrayCtor(ArrayCtor op);
void lowerFModOp(FModOp op);
void lowerPowOp(PowOp op);

/// Build the function that initializes the specified global
FuncOp buildCXXGlobalVarDeclInitFunc(GlobalOp op);
Expand Down Expand Up @@ -627,49 +627,6 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
op.erase();
}

static void lowerBinaryFPToFPBuiltinOp(LoweringPreparePass &pass,
mlir::Operation *op,
llvm::StringRef floatRtFuncName,
llvm::StringRef doubleRtFuncName,
llvm::StringRef longDoubleRtFuncName) {
mlir::Type ty = op->getResult(0).getType();

llvm::StringRef rtFuncName;
if (ty.isa<mlir::cir::SingleType>())
rtFuncName = floatRtFuncName;
else if (ty.isa<mlir::cir::DoubleType>())
rtFuncName = doubleRtFuncName;
else if (ty.isa<mlir::cir::LongDoubleType>())
rtFuncName = longDoubleRtFuncName;
else
llvm_unreachable("unknown binary fp2fp builtin operand type");

CIRBaseBuilderTy builder(*pass.theModule.getContext());
builder.setInsertionPointToStart(pass.theModule.getBody());

auto rtFuncTy = mlir::cir::FuncType::get({ty, ty}, ty);
FuncOp rtFunc =
pass.buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);

builder.setInsertionPointAfter(op);
auto call = builder.create<mlir::cir::CallOp>(op->getLoc(), rtFunc,
mlir::ValueRange{lhs, rhs});

op->replaceAllUsesWith(call);
op->erase();
}

void LoweringPreparePass::lowerFModOp(FModOp op) {
lowerBinaryFPToFPBuiltinOp(*this, op, "fmodf", "fmod", "fmodl");
}

void LoweringPreparePass::lowerPowOp(PowOp op) {
lowerBinaryFPToFPBuiltinOp(*this, op, "powf", "pow", "powl");
}

void LoweringPreparePass::runOnOp(Operation *op) {
if (auto threeWayCmp = dyn_cast<CmpThreeWayOp>(op)) {
lowerThreeWayCmpOp(threeWayCmp);
Expand All @@ -695,13 +652,73 @@ void LoweringPreparePass::runOnOp(Operation *op) {
} else if (auto globalDtor = fnOp.getGlobalDtorAttr()) {
globalDtorList.push_back(globalDtor);
}
} else if (auto fmodOp = dyn_cast<FModOp>(op)) {
lowerFModOp(fmodOp);
} else if (auto powOp = dyn_cast<PowOp>(op)) {
lowerPowOp(powOp);
}
}

void LoweringPreparePass::runOnMathOp(Operation *op) {
struct MathOpFunctionNames {
llvm::StringRef floatVer;
llvm::StringRef doubleVer;
llvm::StringRef longDoubleVer;
};

mlir::Type ty = op->getResult(0).getType();

MathOpFunctionNames rtFuncNames =
llvm::TypeSwitch<Operation *, MathOpFunctionNames>(op)
.Case<FModOp>([](auto) {
return MathOpFunctionNames{"fmodf", "fmod", "fmodl"};
})
.Case<PowOp>(
[](auto) { return MathOpFunctionNames{"powf", "pow", "powl"}; })
.Case<CosOp>(
[](auto) { return MathOpFunctionNames{"cosf", "cos", "cosl"}; })
.Case<ExpOp>(
[](auto) { return MathOpFunctionNames{"expf", "exp", "expl"}; })
.Case<Exp2Op>([](auto) {
return MathOpFunctionNames{"exp2f", "exp2", "exp2l"};
})
.Case<LogOp>(
[](auto) { return MathOpFunctionNames{"logf", "log", "logl"}; })
.Case<Log10Op>([](auto) {
return MathOpFunctionNames{"log10f", "log10", "log10l"};
})
.Case<Log2Op>([](auto) {
return MathOpFunctionNames{"log2f", "log2", "log2l"};
})
.Case<SinOp>(
[](auto) { return MathOpFunctionNames{"sinf", "sin", "sinl"}; })
.Case<SqrtOp>([](auto) {
return MathOpFunctionNames{"sqrtf", "sqrt", "sqrtl"};
});
llvm::StringRef rtFuncName = llvm::TypeSwitch<mlir::Type, llvm::StringRef>(ty)
.Case<mlir::cir::SingleType>([&](auto) {
return rtFuncNames.floatVer;
})
.Case<mlir::cir::DoubleType>([&](auto) {
return rtFuncNames.doubleVer;
})
.Case<mlir::cir::LongDoubleType>([&](auto) {
return rtFuncNames.longDoubleVer;
});

CIRBaseBuilderTy builder(*theModule.getContext());
builder.setInsertionPointToStart(theModule.getBody());

llvm::SmallVector<mlir::Type, 2> operandTypes(op->getNumOperands(), ty);
auto rtFuncTy =
mlir::cir::FuncType::get(operandTypes, op->getResult(0).getType());
FuncOp rtFunc =
buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

builder.setInsertionPointAfter(op);
auto call = builder.create<mlir::cir::CallOp>(op->getLoc(), rtFunc,
op->getOperands());

op->replaceAllUsesWith(call);
op->erase();
}

void LoweringPreparePass::runOnOperation() {
assert(astCtx && "Missing ASTContext, please construct with the right ctor");
auto *op = getOperation();
Expand All @@ -710,15 +727,22 @@ void LoweringPreparePass::runOnOperation() {
}

SmallVector<Operation *> opsToTransform;
SmallVector<Operation *> mathOpsToTransform;

op->walk([&](Operation *op) {
if (isa<CmpThreeWayOp, VAArgOp, GlobalOp, DynamicCastOp, StdFindOp,
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp,
FModOp, PowOp>(op))
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(
op))
opsToTransform.push_back(op);
else if (isa<CosOp, ExpOp, Exp2Op, LogOp, Log10Op, Log2Op, SinOp, SqrtOp,
FModOp, PowOp>(op))
mathOpsToTransform.push_back(op);
});

for (auto *o : opsToTransform)
runOnOp(o);
for (auto *o : mathOpsToTransform)
runOnMathOp(o);

buildCXXGlobalInitFunc();
buildGlobalCtorDtorList();
Expand Down
38 changes: 36 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3161,6 +3161,38 @@ class CIRCmpThreeWayOpLowering
}
};

template <typename CIROp, typename LLVMOp>
class CIRUnaryFPToFPBuiltinOpLowering
: public mlir::OpConversionPattern<CIROp> {
public:
using mlir::OpConversionPattern<CIROp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(CIROp op,
typename mlir::OpConversionPattern<CIROp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto resTy = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<LLVMOp>(op, resTy, adaptor.getSrc());
return mlir::success();
}
};

using CIRCeilOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
using CIRFloorOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
using CIRFabsOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
using CIRNearbyintOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::NearbyintOp,
mlir::LLVM::NearbyintOp>;
using CIRRintOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
using CIRRoundOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
using CIRTruncOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;

template <typename CIROp, typename LLVMOp>
class CIRBinaryFPToFPBuiltinOpLowering
: public mlir::OpConversionPattern<CIROp> {
Expand Down Expand Up @@ -3210,8 +3242,10 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering, CIRCopysignOpLowering, CIRFMaxOpLowering,
CIRFMinOpLowering>(converter, patterns.getContext());
CIRCmpThreeWayOpLowering, CIRCeilOpLowering, CIRFloorOpLowering,
CIRFAbsOpLowering, CIRNearbyintOpLowering, CIRRintOpLowering,
CIRRoundOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
CIRFMaxOpLowering, CIRFMinOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
Loading

0 comments on commit 76613f6

Please sign in to comment.