Skip to content

Commit

Permalink
[CIR][Lowering] Add long double types for cos operation lowering
Browse files Browse the repository at this point in the history
Signed-off-by: zhoujing <[email protected]>
  • Loading branch information
zhoujingya committed Apr 26, 2024
1 parent 65ca98b commit fa371f2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
10 changes: 7 additions & 3 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,8 @@ class CIRConstantOpLowering
auto boolValue = mlir::cast<mlir::cir::BoolAttr>(op.getValue());
value = rewriter.getIntegerAttr(ty, boolValue.getValue());
} else if (op.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
assert(ty.isF32() || ty.isF64() && "NYI");
value = rewriter.getFloatAttr(
typeConverter->convertType(op.getType()),
op.getValue().cast<mlir::cir::FPAttr>().getValue());
ty, op.getValue().cast<mlir::cir::FPAttr>().getValue());
} else {
auto cirIntAttr = mlir::dyn_cast<mlir::cir::IntAttr>(op.getValue());
assert(cirIntAttr && "NYI non cir.int attr");
Expand Down Expand Up @@ -664,6 +662,12 @@ static mlir::TypeConverter prepareTypeConverter() {
converter.addConversion([&](mlir::cir::DoubleType type) -> mlir::Type {
return mlir::FloatType::getF64(type.getContext());
});
converter.addConversion([&](mlir::cir::FP80Type type) -> mlir::Type {
return mlir::FloatType::getF80(type.getContext());
});
converter.addConversion([&](mlir::cir::LongDoubleType type) -> mlir::Type {
return converter.convertType(type.getUnderlying());
});
converter.addConversion([&](mlir::cir::ArrayType type) -> mlir::Type {
auto elementType = converter.convertType(type.getEltType());
return mlir::MemRefType::get(type.getSize(), elementType);
Expand Down
30 changes: 19 additions & 11 deletions clang/test/CIR/Lowering/ThroughMLIR/cos.cir
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@ module {
cir.func @foo() {
%1 = cir.const(#cir.fp<1.0> : !cir.float) : !cir.float
%2 = cir.const(#cir.fp<1.0> : !cir.double) : !cir.double
%3 = cir.cos %1 : !cir.float
%4 = cir.cos %2 : !cir.double
%3 = cir.const(#cir.fp<1.0> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%4 = cir.const(#cir.fp<1.0> : !cir.long_double<!cir.double>) : !cir.long_double<!cir.double>
%5 = cir.cos %1 : !cir.float
%6 = cir.cos %2 : !cir.double
%7 = cir.cos %3 : !cir.long_double<!cir.f80>
%8 = cir.cos %4 : !cir.long_double<!cir.double>
cir.return
}
}

//CHECK: module {
//CHECK: func.func @foo() {
//CHECK: %cst = arith.constant 1.000000e+00 : f32
//CHECK: %cst_0 = arith.constant 1.000000e+00 : f64
//CHECK: %0 = math.cos %cst : f32
//CHECK: %1 = math.cos %cst_0 : f64
//CHECK: return
//CHECK: }
//CHECK: }
// CHECK: module {
// CHECK: func.func @foo() {
// CHECK: %cst = arith.constant 1.000000e+00 : f32
// CHECK: %cst_0 = arith.constant 1.000000e+00 : f64
// CHECK: %cst_1 = arith.constant 1.000000e+00 : f80
// CHECK: %cst_2 = arith.constant 1.000000e+00 : f64
// CHECK: %0 = math.cos %cst : f32
// CHECK: %1 = math.cos %cst_0 : f64
// CHECK: %2 = math.cos %cst_1 : f80
// CHECK: %3 = math.cos %cst_2 : f64
// CHECK: return
// CHECK: }
// CHECK: }

0 comments on commit fa371f2

Please sign in to comment.