Skip to content

Commit

Permalink
[CIR][Lowering] Add MLIR lowering support for CIR shift operations (l…
Browse files Browse the repository at this point in the history
…lvm#630)

This pr adds cir.shift lowering to MLIR passes and test files.
  • Loading branch information
Krito authored and lanza committed Jun 20, 2024
1 parent eba61e7 commit 0143d85
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 33 deletions.
100 changes: 67 additions & 33 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,61 @@ class CIRExpOpLowering : public mlir::OpConversionPattern<mlir::cir::ExpOp> {
}
};

static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Value src, mlir::Type dstTy,
bool isSigned = false) {
auto srcTy = src.getType();
assert(isa<mlir::IntegerType>(srcTy));
assert(isa<mlir::IntegerType>(dstTy));

auto srcWidth = srcTy.cast<mlir::IntegerType>().getWidth();
auto dstWidth = dstTy.cast<mlir::IntegerType>().getWidth();
auto loc = src.getLoc();

if (dstWidth > srcWidth && isSigned)
return rewriter.create<mlir::arith::ExtSIOp>(loc, dstTy, src);
else if (dstWidth > srcWidth)
return rewriter.create<mlir::arith::ExtUIOp>(loc, dstTy, src);
else if (dstWidth < srcWidth)
return rewriter.create<mlir::arith::TruncIOp>(loc, dstTy, src);
else
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
}

class CIRShiftOpLowering
: public mlir::OpConversionPattern<mlir::cir::ShiftOp> {
public:
using mlir::OpConversionPattern<mlir::cir::ShiftOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::ShiftOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto cirAmtTy = op.getAmount().getType().dyn_cast<mlir::cir::IntType>();
auto cirValTy = op.getValue().getType().dyn_cast<mlir::cir::IntType>();
auto mlirTy = getTypeConverter()->convertType(op.getType());
mlir::Value amt = adaptor.getAmount();
mlir::Value val = adaptor.getValue();

assert(cirValTy && cirAmtTy && "non-integer shift is NYI");
assert(cirValTy == op.getType() && "inconsistent operands' types NYI");

// Ensure shift amount is the same type as the value. Some undefined
// behavior might occur in the casts below as per [C99 6.5.7.3].
amt = createIntCast(rewriter, amt, mlirTy, cirAmtTy.isSigned());

// Lower to the proper arith shift operation.
if (op.getIsShiftleft())
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(op, mlirTy, val, amt);
else {
if (cirValTy.isUnsigned())
rewriter.replaceOpWithNewOp<mlir::arith::ShRUIOp>(op, mlirTy, val, amt);
else
rewriter.replaceOpWithNewOp<mlir::arith::ShRSIOp>(op, mlirTy, val, amt);
}

return mlir::success();
}
};

class CIRExp2OpLowering : public mlir::OpConversionPattern<mlir::cir::Exp2Op> {
public:
using mlir::OpConversionPattern<mlir::cir::Exp2Op>::OpConversionPattern;
Expand Down Expand Up @@ -901,27 +956,6 @@ class CIRGetGlobalOpLowering
}
};

static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Value src, mlir::Type dstTy,
bool isSigned = false) {
auto srcTy = src.getType();
assert(isa<mlir::IntegerType>(srcTy));
assert(isa<mlir::IntegerType>(dstTy));

auto srcWidth = srcTy.cast<mlir::IntegerType>().getWidth();
auto dstWidth = dstTy.cast<mlir::IntegerType>().getWidth();
auto loc = src.getLoc();

if (dstWidth > srcWidth && isSigned)
return rewriter.create<mlir::arith::ExtSIOp>(loc, dstTy, src);
else if (dstWidth > srcWidth)
return rewriter.create<mlir::arith::ExtUIOp>(loc, dstTy, src);
else if (dstWidth < srcWidth)
return rewriter.create<mlir::arith::TruncIOp>(loc, dstTy, src);
else
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
}

class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
public:
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
Expand Down Expand Up @@ -1124,18 +1158,18 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering,
CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering>(
converter, patterns.getContext());
patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
CIRSinOpLowering, CIRShiftOpLowering>(converter,
patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
31 changes: 31 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/shift.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
// RUN: FileCheck %s --input-file %t.mlir

!s16i = !cir.int<s, 16>
!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>
!u16i = !cir.int<u, 16>
module {
cir.func @testShiftWithDifferentValueAndAmountTypes(%arg0: !s16i, %arg1: !s32i, %arg2: !s64i, %arg3: !u16i) {
%1 = cir.shift(left, %arg1: !s32i, %arg2 : !s64i) -> !s32i
%2 = cir.shift(left, %arg1 : !s32i, %arg0 : !s16i) -> !s32i
%3 = cir.shift(left, %arg1 : !s32i, %arg3 : !u16i) -> !s32i
%4 = cir.shift(left, %arg1 : !s32i, %arg1 : !s32i) -> !s32i
cir.return
}
}

// CHECK: module {
// CHECK-NEXT: func.func @testShiftWithDifferentValueAndAmountTypes(%arg0: i16, %arg1: i32, %arg2: i64, %arg3: i16) {
// CHECK-NEXT: %[[TRUNC:.+]] = arith.trunci %arg2 : i64 to i32
// CHECK-NEXT: %[[SHIFT_TRUNC:.+]] = arith.shli %arg1, %[[TRUNC]] : i32
// CHECK-NEXT: %[[EXTS:.+]] = arith.extsi %arg0 : i16 to i32
// CHECK-NEXT: %[[SHIFT_EXTS:.+]] = arith.shli %arg1, %[[EXTS]] : i32
// CHECK-NEXT: %[[EXTU:.+]] = arith.extui %arg3 : i16 to i32
// CHECK-NEXT: %[[SHIFT_EXTU:.+]] = arith.shli %arg1, %[[EXTU]] : i32
// CHECK-NEXT: %[[BITCAST:.+]] = arith.bitcast %arg1 : i32 to i32
// CHECK-NEXT: %[[SHIFT_BITCAST:.+]] = arith.shli %arg1, %[[BITCAST]] : i32
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 0143d85

Please sign in to comment.