Skip to content

Commit

Permalink
[CIR][Dialect] Extend UnaryFPToFPBuiltinOp to vector of FP type (#1132)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghehg authored Nov 23, 2024
1 parent 8176d88 commit aa6fe48
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4342,8 +4342,8 @@ def LLrintOp : UnaryFPToIntBuiltinOp<"llrint", "LlrintOp">;

class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins CIR_AnyFloat:$src);
let results = (outs CIR_AnyFloat:$result);
let arguments = (ins CIR_AnyFloatOrVecOfFloat:$src);
let results = (outs CIR_AnyFloatOrVecOfFloat:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
Expand Down
10 changes: 10 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,20 @@ def SignedIntegerVector : Type<
]>, "!cir.vector of !cir.int"> {
}

// Vector of Float type
def FPVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::SingleType, ::cir::DoubleType>("
"::mlir::cast<::cir::VectorType>($_self).getEltType())">,
]>, "!cir.vector of !cir.fp"> {
}

// Constraints
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf<
[PrimitiveSInt, SignedIntegerVector]>;
def CIR_AnyFloatOrVecOfFloat: AnyTypeOf<[CIR_AnyFloat, FPVector]>;

// Pointer to Arrays
def ArrayPtr : Type<
Expand Down
86 changes: 85 additions & 1 deletion clang/test/CIR/Lowering/builtin-floating-point.cir
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,133 @@
// RUN: FileCheck --input-file=%t.ll %s

module {
cir.func @test(%arg0 : !cir.float) {
cir.func @test(%arg0 : !cir.float, %arg1 : !cir.vector<!cir.double x 2>, %arg2 : !cir.vector<!cir.float x 4>) {
%1 = cir.cos %arg0 : !cir.float
// CHECK: llvm.intr.cos(%arg0) : (f32) -> f32

%101 = cir.cos %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.cos(%arg1) : (vector<2xf64>) -> vector<2xf64>

%201 = cir.cos %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.cos(%arg2) : (vector<4xf32>) -> vector<4xf32>

%2 = cir.ceil %arg0 : !cir.float
// CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32

%102 = cir.ceil %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.ceil(%arg1) : (vector<2xf64>) -> vector<2xf64>

%202 = cir.ceil %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.ceil(%arg2) : (vector<4xf32>) -> vector<4xf32>

%3 = cir.exp %arg0 : !cir.float
// CHECK: llvm.intr.exp(%arg0) : (f32) -> f32

%103 = cir.exp %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.exp(%arg1) : (vector<2xf64>) -> vector<2xf64>

%203 = cir.exp %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.exp(%arg2) : (vector<4xf32>) -> vector<4xf32>

%4 = cir.exp2 %arg0 : !cir.float
// CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32

%104 = cir.exp2 %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.exp2(%arg1) : (vector<2xf64>) -> vector<2xf64>

%204 = cir.exp2 %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.exp2(%arg2) : (vector<4xf32>) -> vector<4xf32>

%5 = cir.fabs %arg0 : !cir.float
// CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32

%105 = cir.fabs %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.fabs(%arg1) : (vector<2xf64>) -> vector<2xf64>

%205 = cir.fabs %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.fabs(%arg2) : (vector<4xf32>) -> vector<4xf32>

%6 = cir.floor %arg0 : !cir.float
// CHECK: llvm.intr.floor(%arg0) : (f32) -> f32

%106 = cir.floor %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.floor(%arg1) : (vector<2xf64>) -> vector<2xf64>

%206 = cir.floor %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.floor(%arg2) : (vector<4xf32>) -> vector<4xf32>

%7 = cir.log %arg0 : !cir.float
// CHECK: llvm.intr.log(%arg0) : (f32) -> f32

%107 = cir.log %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.log(%arg1) : (vector<2xf64>) -> vector<2xf64>

%207 = cir.log %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.log(%arg2) : (vector<4xf32>) -> vector<4xf32>

%8 = cir.log10 %arg0 : !cir.float
// CHECK: llvm.intr.log10(%arg0) : (f32) -> f32

%108 = cir.log10 %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.log10(%arg1) : (vector<2xf64>) -> vector<2xf64>

%208 = cir.log10 %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.log10(%arg2) : (vector<4xf32>) -> vector<4xf32>

%9 = cir.log2 %arg0 : !cir.float
// CHECK: llvm.intr.log2(%arg0) : (f32) -> f32

%109 = cir.log2 %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.log2(%arg1) : (vector<2xf64>) -> vector<2xf64>

%209 = cir.log2 %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.log2(%arg2) : (vector<4xf32>) -> vector<4xf32>

%10 = cir.nearbyint %arg0 : !cir.float
// CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32

%110 = cir.nearbyint %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.nearbyint(%arg1) : (vector<2xf64>) -> vector<2xf64>

%210 = cir.nearbyint %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.nearbyint(%arg2) : (vector<4xf32>) -> vector<4xf32>

%11 = cir.rint %arg0 : !cir.float
// CHECK: llvm.intr.rint(%arg0) : (f32) -> f32

%111 = cir.rint %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.rint(%arg1) : (vector<2xf64>) -> vector<2xf64>

%211 = cir.rint %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.rint(%arg2) : (vector<4xf32>) -> vector<4xf32>

%12 = cir.round %arg0 : !cir.float
// CHECK: llvm.intr.round(%arg0) : (f32) -> f32

%112 = cir.round %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.round(%arg1) : (vector<2xf64>) -> vector<2xf64>

%212 = cir.round %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.round(%arg2) : (vector<4xf32>) -> vector<4xf32>

%13 = cir.sin %arg0 : !cir.float
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32

%113 = cir.sin %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.sin(%arg1) : (vector<2xf64>) -> vector<2xf64>

%213 = cir.sin %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.sin(%arg2) : (vector<4xf32>) -> vector<4xf32>

%14 = cir.sqrt %arg0 : !cir.float
// CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32

%114 = cir.sqrt %arg1 : !cir.vector<!cir.double x 2>
// CHECK: llvm.intr.sqrt(%arg1) : (vector<2xf64>) -> vector<2xf64>

%214 = cir.sqrt %arg2 : !cir.vector<!cir.float x 4>
// CHECK: llvm.intr.sqrt(%arg2) : (vector<4xf32>) -> vector<4xf32>

cir.return
}
}

0 comments on commit aa6fe48

Please sign in to comment.