From a0257736dc3fce682310ba1878caf84163fd4175 Mon Sep 17 00:00:00 2001 From: Krito Date: Wed, 29 May 2024 14:36:39 +0800 Subject: [PATCH] [CIR][Lowering] Add MLIR lowering support for CIR shift operations (#630) This pr adds cir.shift lowering to MLIR passes and test files. --- .../Lowering/ThroughMLIR/LowerCIRToMLIR.cpp | 100 ++++++++++++------ clang/test/CIR/Lowering/ThroughMLIR/shift.cir | 31 ++++++ 2 files changed, 98 insertions(+), 33 deletions(-) create mode 100644 clang/test/CIR/Lowering/ThroughMLIR/shift.cir diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index cb51dfd16ee0..4df181bef5db 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -321,6 +321,61 @@ class CIRExpOpLowering : public mlir::OpConversionPattern { } }; +static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter, + mlir::Value src, mlir::Type dstTy, + bool isSigned = false) { + auto srcTy = src.getType(); + assert(isa(srcTy)); + assert(isa(dstTy)); + + auto srcWidth = srcTy.cast().getWidth(); + auto dstWidth = dstTy.cast().getWidth(); + auto loc = src.getLoc(); + + if (dstWidth > srcWidth && isSigned) + return rewriter.create(loc, dstTy, src); + else if (dstWidth > srcWidth) + return rewriter.create(loc, dstTy, src); + else if (dstWidth < srcWidth) + return rewriter.create(loc, dstTy, src); + else + return rewriter.create(loc, dstTy, src); +} + +class CIRShiftOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + mlir::LogicalResult + matchAndRewrite(mlir::cir::ShiftOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto cirAmtTy = op.getAmount().getType().dyn_cast(); + auto cirValTy = op.getValue().getType().dyn_cast(); + auto mlirTy = getTypeConverter()->convertType(op.getType()); + mlir::Value amt = adaptor.getAmount(); + mlir::Value val = adaptor.getValue(); + + assert(cirValTy && cirAmtTy && "non-integer shift is NYI"); + assert(cirValTy == op.getType() && "inconsistent operands' types NYI"); + + // Ensure shift amount is the same type as the value. Some undefined + // behavior might occur in the casts below as per [C99 6.5.7.3]. + amt = createIntCast(rewriter, amt, mlirTy, cirAmtTy.isSigned()); + + // Lower to the proper arith shift operation. + if (op.getIsShiftleft()) + rewriter.replaceOpWithNewOp(op, mlirTy, val, amt); + else { + if (cirValTy.isUnsigned()) + rewriter.replaceOpWithNewOp(op, mlirTy, val, amt); + else + rewriter.replaceOpWithNewOp(op, mlirTy, val, amt); + } + + return mlir::success(); + } +}; + class CIRExp2OpLowering : public mlir::OpConversionPattern { public: using mlir::OpConversionPattern::OpConversionPattern; @@ -901,27 +956,6 @@ class CIRGetGlobalOpLowering } }; -static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter, - mlir::Value src, mlir::Type dstTy, - bool isSigned = false) { - auto srcTy = src.getType(); - assert(isa(srcTy)); - assert(isa(dstTy)); - - auto srcWidth = srcTy.cast().getWidth(); - auto dstWidth = dstTy.cast().getWidth(); - auto loc = src.getLoc(); - - if (dstWidth > srcWidth && isSigned) - return rewriter.create(loc, dstTy, src); - else if (dstWidth > srcWidth) - return rewriter.create(loc, dstTy, src); - else if (dstWidth < srcWidth) - return rewriter.create(loc, dstTy, src); - else - return rewriter.create(loc, dstTy, src); -} - class CIRCastOpLowering : public mlir::OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1124,18 +1158,18 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); - patterns - .add( - converter, patterns.getContext()); + patterns.add(converter, + patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { diff --git a/clang/test/CIR/Lowering/ThroughMLIR/shift.cir b/clang/test/CIR/Lowering/ThroughMLIR/shift.cir new file mode 100644 index 000000000000..aecbc3f45940 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/shift.cir @@ -0,0 +1,31 @@ +// RUN: cir-opt %s -cir-to-mlir -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir + +!s16i = !cir.int +!s32i = !cir.int +!s64i = !cir.int +!u16i = !cir.int +module { + cir.func @testShiftWithDifferentValueAndAmountTypes(%arg0: !s16i, %arg1: !s32i, %arg2: !s64i, %arg3: !u16i) { + %1 = cir.shift(left, %arg1: !s32i, %arg2 : !s64i) -> !s32i + %2 = cir.shift(left, %arg1 : !s32i, %arg0 : !s16i) -> !s32i + %3 = cir.shift(left, %arg1 : !s32i, %arg3 : !u16i) -> !s32i + %4 = cir.shift(left, %arg1 : !s32i, %arg1 : !s32i) -> !s32i + cir.return + } +} + +// CHECK: module { +// CHECK-NEXT: func.func @testShiftWithDifferentValueAndAmountTypes(%arg0: i16, %arg1: i32, %arg2: i64, %arg3: i16) { +// CHECK-NEXT: %[[TRUNC:.+]] = arith.trunci %arg2 : i64 to i32 +// CHECK-NEXT: %[[SHIFT_TRUNC:.+]] = arith.shli %arg1, %[[TRUNC]] : i32 +// CHECK-NEXT: %[[EXTS:.+]] = arith.extsi %arg0 : i16 to i32 +// CHECK-NEXT: %[[SHIFT_EXTS:.+]] = arith.shli %arg1, %[[EXTS]] : i32 +// CHECK-NEXT: %[[EXTU:.+]] = arith.extui %arg3 : i16 to i32 +// CHECK-NEXT: %[[SHIFT_EXTU:.+]] = arith.shli %arg1, %[[EXTU]] : i32 +// CHECK-NEXT: %[[BITCAST:.+]] = arith.bitcast %arg1 : i32 to i32 +// CHECK-NEXT: %[[SHIFT_BITCAST:.+]] = arith.shli %arg1, %[[BITCAST]] : i32 +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } +