Skip to content

Commit

Permalink
[CIR][LLVMLowering] Add LLVM lowering for unary fp2fp builtins
Browse files Browse the repository at this point in the history
This patch adds LLVM lowering support for unary fp2fp builtins.

Those builtins that should be lowered to runtime function calls are lowered to
such calls during lowering prepare. Other builtins are lowered to LLVM intrinsic
calls during LLVM lowering.
  • Loading branch information
Lancern committed Jun 2, 2024
1 parent 4ffa090 commit ae75a19
Show file tree
Hide file tree
Showing 3 changed files with 586 additions and 3 deletions.
93 changes: 92 additions & 1 deletion clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void lowerIterEndOp(IterEndOp op);
void lowerArrayDtor(ArrayDtor op);
void lowerArrayCtor(ArrayCtor op);
void lowerCosOp(CosOp op);
void lowerExpOp(ExpOp op);
void lowerExp2Op(Exp2Op op);
void lowerLogOp(LogOp op);
void lowerLog10Op(Log10Op op);
void lowerLog2Op(Log2Op op);
void lowerSinOp(SinOp op);
void lowerSqrtOp(SqrtOp op);
void lowerFModOp(FModOp op);
void lowerPowOp(PowOp op);

Expand Down Expand Up @@ -627,6 +635,72 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
op.erase();
}

static void lowerUnaryFPToFPBuiltinOp(LoweringPreparePass &pass,
mlir::Operation *op,
llvm::StringRef floatRtFuncName,
llvm::StringRef doubleRtFuncName,
llvm::StringRef longDoubleRtFuncName) {
mlir::Type ty = op->getResult(0).getType();

llvm::StringRef rtFuncName;
if (ty.isa<mlir::cir::SingleType>())
rtFuncName = floatRtFuncName;
else if (ty.isa<mlir::cir::DoubleType>())
rtFuncName = doubleRtFuncName;
else if (ty.isa<mlir::cir::LongDoubleType>())
rtFuncName = longDoubleRtFuncName;
else
llvm_unreachable("unknown unary fp2fp builtin operand type");

CIRBaseBuilderTy builder(*pass.theModule.getContext());
builder.setInsertionPointToStart(pass.theModule.getBody());

auto rtFuncTy = mlir::cir::FuncType::get({ty}, ty);
FuncOp rtFunc =
pass.buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

auto src = op->getOperand(0);

builder.setInsertionPointAfter(op);
auto call = builder.create<mlir::cir::CallOp>(op->getLoc(), rtFunc,
mlir::ValueRange{src});

op->replaceAllUsesWith(call);
op->erase();
}

void LoweringPreparePass::lowerCosOp(CosOp op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "cosf", "cos", "cosl");
}

void LoweringPreparePass::lowerExpOp(ExpOp op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "expf", "exp", "expl");
}

void LoweringPreparePass::lowerExp2Op(Exp2Op op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "exp2f", "exp2", "exp2l");
}

void LoweringPreparePass::lowerLogOp(LogOp op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "logf", "log", "logl");
}

void LoweringPreparePass::lowerLog10Op(Log10Op op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "log10f", "log10", "log10l");
}

void LoweringPreparePass::lowerLog2Op(Log2Op op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "log2f", "log2", "log2l");
}

void LoweringPreparePass::lowerSinOp(SinOp op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "sinf", "sin", "sinl");
}

void LoweringPreparePass::lowerSqrtOp(SqrtOp op) {
lowerUnaryFPToFPBuiltinOp(*this, op, "sqrtf", "sqrt", "sqrtl");
}

static void lowerBinaryFPToFPBuiltinOp(LoweringPreparePass &pass,
mlir::Operation *op,
llvm::StringRef floatRtFuncName,
Expand Down Expand Up @@ -699,6 +773,22 @@ void LoweringPreparePass::runOnOp(Operation *op) {
lowerFModOp(fmodOp);
} else if (auto powOp = dyn_cast<PowOp>(op)) {
lowerPowOp(powOp);
} else if (auto cosOp = dyn_cast<CosOp>(op)) {
lowerCosOp(cosOp);
} else if (auto expOp = dyn_cast<ExpOp>(op)) {
lowerExpOp(expOp);
} else if (auto exp2Op = dyn_cast<Exp2Op>(op)) {
lowerExp2Op(exp2Op);
} else if (auto logOp = dyn_cast<LogOp>(op)) {
lowerLogOp(logOp);
} else if (auto log10Op = dyn_cast<Log10Op>(op)) {
lowerLog10Op(log10Op);
} else if (auto log2Op = dyn_cast<Log2Op>(op)) {
lowerLog2Op(log2Op);
} else if (auto sinOp = dyn_cast<SinOp>(op)) {
lowerSinOp(sinOp);
} else if (auto sqrtOp = dyn_cast<SqrtOp>(op)) {
lowerSqrtOp(sqrtOp);
}
}

Expand All @@ -713,7 +803,8 @@ void LoweringPreparePass::runOnOperation() {
op->walk([&](Operation *op) {
if (isa<CmpThreeWayOp, VAArgOp, GlobalOp, DynamicCastOp, StdFindOp,
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp,
FModOp, PowOp>(op))
CosOp, ExpOp, Exp2Op, LogOp, Log10Op, Log2Op, SinOp, SqrtOp, FModOp,
PowOp>(op))
opsToTransform.push_back(op);
});

Expand Down
38 changes: 36 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3161,6 +3161,38 @@ class CIRCmpThreeWayOpLowering
}
};

template <typename CIROp, typename LLVMOp>
class CIRUnaryFPToFPBuiltinOpLowering
: public mlir::OpConversionPattern<CIROp> {
public:
using mlir::OpConversionPattern<CIROp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(CIROp op,
typename mlir::OpConversionPattern<CIROp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto resTy = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<LLVMOp>(op, resTy, adaptor.getSrc());
return mlir::success();
}
};

using CIRCeilOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
using CIRFloorOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
using CIRFabsOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
using CIRNearbyintOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::NearbyintOp,
mlir::LLVM::NearbyintOp>;
using CIRRintOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
using CIRRoundOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
using CIRTruncOpLowering =
CIRUnaryFPToFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;

template <typename CIROp, typename LLVMOp>
class CIRBinaryFPToFPBuiltinOpLowering
: public mlir::OpConversionPattern<CIROp> {
Expand Down Expand Up @@ -3210,8 +3242,10 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering, CIRCopysignOpLowering, CIRFMaxOpLowering,
CIRFMinOpLowering>(converter, patterns.getContext());
CIRCmpThreeWayOpLowering, CIRCeilOpLowering, CIRFloorOpLowering,
CIRFAbsOpLowering, CIRNearbyintOpLowering, CIRRintOpLowering,
CIRRoundOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
CIRFMaxOpLowering, CIRFMinOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
Loading

0 comments on commit ae75a19

Please sign in to comment.