Skip to content

Commit

Permalink
[CIR][ThroughMLIR] lowering cir.bit.clz and cir.bit.ctz to MLIR (llvm…
Browse files Browse the repository at this point in the history
…#645)

This pr adds cir.bit.clz and cir.bit.ctz lowering to MLIR passes and
test files.
I will complete the lowering of other `cir.bit` operations in subsequent
PRs.
  • Loading branch information
Krito authored and smeenai committed Oct 9, 2024
1 parent 99033b7 commit 7b496a4
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 12 deletions.
50 changes: 38 additions & 12 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,31 @@ class CIRSinOpLowering : public mlir::OpConversionPattern<mlir::cir::SinOp> {
}
};

template <typename CIROp, typename MLIROp>
class CIRBitOpLowering : public mlir::OpConversionPattern<CIROp> {
public:
using mlir::OpConversionPattern<CIROp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(CIROp op,
typename mlir::OpConversionPattern<CIROp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto resultIntTy = this->getTypeConverter()
->convertType(op.getType())
.template cast<mlir::IntegerType>();
auto res = rewriter.create<MLIROp>(op->getLoc(), adaptor.getInput());
auto newOp = createIntCast(rewriter, res->getResult(0), resultIntTy,
/*isSigned=*/false);
rewriter.replaceOp(op, newOp);
return mlir::LogicalResult::success();
}
};

using CIRBitClzOpLowering =
CIRBitOpLowering<mlir::cir::BitClzOp, mlir::math::CountLeadingZerosOp>;
using CIRBitCtzOpLowering =
CIRBitOpLowering<mlir::cir::BitCtzOp, mlir::math::CountTrailingZerosOp>;

class CIRConstantOpLowering
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
public:
Expand Down Expand Up @@ -1158,18 +1183,19 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
CIRSinOpLowering, CIRShiftOpLowering>(converter,
patterns.getContext());
patterns
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering,
CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering>(
converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
94 changes: 94 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/bit.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
// RUN: FileCheck %s --input-file %t.mlir

!s16i = !cir.int<s, 16>
!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>
!u16i = !cir.int<u, 16>
!u32i = !cir.int<u, 32>
!u64i = !cir.int<u, 64>


// int clz_u16(unsigned short x) {
// return __builtin_clzs(x);
// }
cir.func @clz_u16(%arg : !u16i) {
%0 = cir.bit.clz(%arg : !u16i) : !s32i
cir.return
}

// CHECK: func.func @clz_u16(%arg0: i16) {
// CHECK-NEXT: %[[CLZ_U16:.+]] = math.ctlz %arg0 : i16
// CHECK-NEXT: %[[EXTUI_U16:.+]] = arith.extui %[[CLZ_U16]] : i16 to i32
// CHECK-NEXT: return
// CHECK-NEXT: }

// int clz_u32(unsigned x) {
// return __builtin_clz(x);
// }
cir.func @clz_u32(%arg : !u32i) {
%0 = cir.bit.clz(%arg : !u32i) : !s32i
cir.return
}

// CHECK: func.func @clz_u32(%arg0: i32) {
// CHECK-NEXT: %[[CLZ_U32:.+]] = math.ctlz %arg0 : i32
// CHECK-NEXT: %[[BITCAST_U32:.+]] = arith.bitcast %[[CLZ_U32]] : i32 to i32
// CHECK-NEXT: return
// CHECK-NEXT: }

// int clz_u64(unsigned long x) {
// return __builtin_clzl(x);
// }
cir.func @clz_u64(%arg : !u64i) {
%0 = cir.bit.clz(%arg : !u64i) : !s32i
cir.return
}

// CHECK: func.func @clz_u64(%arg0: i64) {
// CHECK-NEXT: %[[CLZ_U64:.+]] = math.ctlz %arg0 : i64
// CHECK-NEXT: %[[TRUNCI_U64:.+]] = arith.trunci %[[CLZ_U64]] : i64 to i32
// CHECK-NEXT: return
// CHECK-NEXT: }

// int ctz_u16(unsigned short x) {
// return __builtin_ctzs(x);
// }
cir.func @ctz_u16(%arg : !u16i) {
%0 = cir.bit.ctz(%arg : !u16i) : !s32i
cir.return
}

// CHECK: func.func @ctz_u16(%arg0: i16) {
// CHECK-NEXT: %[[CTZ_U16:.+]] = math.cttz %arg0 : i16
// CHECK-NEXT: %[[EXTUI_U16:.+]] = arith.extui %[[CTZ_U16]] : i16 to i32
// CHECK-NEXT: return
// CHECK-NEXT: }

// int ctz_u32(unsigned x) {
// return __builtin_ctz(x);
// }
cir.func @ctz_u32(%arg : !u32i) {
%0 = cir.bit.ctz(%arg : !u32i) : !s32i
cir.return
}

// CHECK: func.func @ctz_u32(%arg0: i32) {
// CHECK-NEXT: %[[CTZ_U32:.+]] = math.cttz %arg0 : i32
// CHECK-NEXT: %[[BITCAST_U32:.+]] = arith.bitcast %[[CTZ_U32]] : i32 to i32
// CHECK-NEXT: return
// CHECK-NEXT: }

// int ctz_u64(unsigned long x) {
// return __builtin_ctzl(x);
// }
cir.func @ctz_u64(%arg : !u64i) {
%0 = cir.bit.ctz(%arg : !u64i) : !s32i
cir.return
}

// CHECK: func.func @ctz_u64(%arg0: i64) {
// CHECK-NEXT: %[[CTZ_U64:.+]] = math.cttz %arg0 : i64
// CHECK-NEXT: %[[TRUNCI_U64:.+]] = arith.trunci %[[CTZ_U64]] : i64 to i32
// CHECK-NEXT: return
// CHECK-NEXT: }

0 comments on commit 7b496a4

Please sign in to comment.