From 94f8d28ce3d21f715b3a3725d17b3bba8c29a2b8 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Fri, 15 Nov 2024 13:55:27 +0800 Subject: [PATCH] [CIR][NFC] move data member pointer lowering to CXXABI This patch moves the lowering code for data member pointers from the conversion patterns to the implementation of CXXABI because this part should be ABI-specific. --- .../Transforms/TargetLowering/CIRCXXABI.h | 26 +++++++ .../TargetLowering/ItaniumCXXABI.cpp | 62 ++++++++++++++++ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 74 +++++++++---------- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 25 ++++++- clang/test/CIR/Lowering/data-member.cir | 27 ++++--- 5 files changed, 159 insertions(+), 55 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index 0f05ec8040f8..4c2f442721e8 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -15,9 +15,15 @@ #define LLVM_CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRCXXABI_H #include "LowerFunctionInfo.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" #include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" +#include "clang/CIR/Dialect/IR/CIRAttrs.h" #include "clang/CIR/Dialect/IR/CIRDataLayout.h" +#include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/Target/AArch64.h" namespace cir { @@ -59,6 +65,26 @@ class CIRCXXABI { /// Returns how an argument of the given record type should be passed. /// FIXME(cir): This expects a CXXRecordDecl! Not any record type. virtual RecordArgABI getRecordArgABI(const StructType RD) const = 0; + + /// Lower the given data member pointer type to its ABI type. The returned + /// type is also a CIR type. 
+ virtual mlir::Type + lowerDataMemberType(cir::DataMemberType type, + const mlir::TypeConverter &typeConverter) const = 0; + + /// Lower the given data member pointer constant to a constant of the ABI + /// type. The returned constant is represented as an attribute as well. + virtual mlir::TypedAttr + lowerDataMemberConstant(cir::DataMemberAttr attr, + const mlir::DataLayout &layout, + const mlir::TypeConverter &typeConverter) const = 0; + + /// Lower the given cir.get_runtime_member op to a sequence of more + /// "primitive" CIR operations that act on the ABI types. + virtual mlir::Operation * + lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy, + mlir::Value loweredAddr, mlir::Value loweredMember, + mlir::OpBuilder &builder) const = 0; }; /// Creates an Itanium-family ABI. diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index 081db25808d1..a87cdc01ea9d 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -23,6 +23,7 @@ #include "../LoweringPrepareCXXABI.h" #include "CIRCXXABI.h" #include "LowerModule.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "llvm/Support/ErrorHandling.h" namespace cir { @@ -51,6 +52,19 @@ class ItaniumCXXABI : public CIRCXXABI { cir_cconv_assert(!cir::MissingFeatures::recordDeclCanPassInRegisters()); return RAA_Default; } + + mlir::Type + lowerDataMemberType(cir::DataMemberType type, + const mlir::TypeConverter &typeConverter) const override; + + mlir::TypedAttr lowerDataMemberConstant( + cir::DataMemberAttr attr, const mlir::DataLayout &layout, + const mlir::TypeConverter &typeConverter) const override; + + mlir::Operation * + lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy, + mlir::Value loweredAddr, mlir::Value loweredMember, + mlir::OpBuilder &builder) const override; }; } 
// namespace @@ -67,6 +81,54 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const { return false; } +mlir::Type ItaniumCXXABI::lowerDataMemberType( + cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const { + // Itanium C++ ABI 2.3: + // A pointer to data member is an offset from the base address of + // the class object containing it, represented as a ptrdiff_t + const clang::TargetInfo &target = LM.getTarget(); + clang::TargetInfo::IntType ptrdiffTy = + target.getPtrDiffType(clang::LangAS::Default); + return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy), + target.isTypeSigned(ptrdiffTy)); +} + +mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant( + cir::DataMemberAttr attr, const mlir::DataLayout &layout, + const mlir::TypeConverter &typeConverter) const { + uint64_t memberOffset; + if (attr.isNullPtr()) { + // Itanium C++ ABI 2.3: + // A NULL pointer is represented as -1. + memberOffset = -1ull; + } else { + // Itanium C++ ABI 2.3: + // A pointer to data member is an offset from the base address of + // the class object containing it, represented as a ptrdiff_t + auto memberIndex = attr.getMemberIndex().value(); + memberOffset = + attr.getType().getClsTy().getElementOffset(layout, memberIndex); + } + + mlir::Type abiTy = lowerDataMemberType(attr.getType(), typeConverter); + return cir::IntAttr::get(abiTy, memberOffset); +} + +mlir::Operation *ItaniumCXXABI::lowerGetRuntimeMember( + cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy, + mlir::Value loweredAddr, mlir::Value loweredMember, + mlir::OpBuilder &builder) const { + auto byteTy = IntType::get(op.getContext(), 8, true); + auto bytePtrTy = PointerType::get( + byteTy, mlir::cast(op.getAddr().getType()).getAddrSpace()); + auto objectBytesPtr = builder.create(op.getLoc(), bytePtrTy, + CastKind::bitcast, op.getAddr()); + auto memberBytesPtr = builder.create( + op.getLoc(), bytePtrTy, objectBytesPtr, loweredMember); + return 
builder.create(op.getLoc(), op.getType(), CastKind::bitcast, + memberBytesPtr); +} + CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) { switch (LM.getCXXABIKind()) { // Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 040f72562cac..d78fb529e713 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1521,28 +1521,6 @@ bool hasTrailingZeros(cir::ConstArrayAttr attr) { })); } -static mlir::Attribute -lowerDataMemberAttr(mlir::ModuleOp moduleOp, cir::DataMemberAttr attr, - const mlir::TypeConverter &typeConverter) { - mlir::DataLayout layout{moduleOp}; - - uint64_t memberOffset; - if (attr.isNullPtr()) { - // TODO(cir): the numerical value of a null data member pointer is - // ABI-specific and should be queried through ABI. - assert(!MissingFeatures::targetCodeGenInfoGetNullPointer()); - memberOffset = -1ull; - } else { - auto memberIndex = attr.getMemberIndex().value(); - memberOffset = - attr.getType().getClsTy().getElementOffset(layout, memberIndex); - } - - auto underlyingIntTy = mlir::IntegerType::get( - moduleOp->getContext(), layout.getTypeSizeInBits(attr.getType())); - return mlir::IntegerAttr::get(underlyingIntTy, memberOffset); -} - mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( cir::ConstantOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { @@ -1602,9 +1580,13 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( } attr = op.getValue(); } else if (mlir::isa(op.getType())) { + assert(lowerMod && "lower module is not available"); auto dataMember = mlir::cast(op.getValue()); - attr = lowerDataMemberAttr(op->getParentOfType(), - dataMember, *typeConverter); + mlir::DataLayout layout(op->getParentOfType()); + mlir::TypedAttr abiValue = lowerMod->getCXXABI().lowerDataMemberConstant( + 
dataMember, layout, *typeConverter); + rewriter.replaceOpWithNewOp(op, abiValue); + return mlir::success(); } // TODO(cir): constant arrays are currently just pushed into the stack using // the store instruction, instead of being stored as global variables and @@ -2208,8 +2190,15 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( return mlir::success(); } else if (auto dataMemberAttr = mlir::dyn_cast(init.value())) { - init = lowerDataMemberAttr(op->getParentOfType(), - dataMemberAttr, *typeConverter); + assert(lowerMod && "lower module is not available"); + mlir::DataLayout layout(op->getParentOfType()); + mlir::TypedAttr abiValue = lowerMod->getCXXABI().lowerDataMemberConstant( + dataMemberAttr, layout, *typeConverter); + auto abiOp = mlir::cast(rewriter.clone(*op.getOperation())); + abiOp.setInitialValueAttr(abiValue); + abiOp.setSymType(abiValue.getType()); + rewriter.replaceOp(op, abiOp); + return mlir::success(); } else if (const auto structAttr = mlir::dyn_cast(init.value())) { setupRegionInitializedLLVMGlobalOp(op, rewriter); @@ -3237,11 +3226,11 @@ mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite( mlir::LogicalResult CIRToLLVMGetRuntimeMemberOpLowering::matchAndRewrite( cir::GetRuntimeMemberOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - auto llvmResTy = getTypeConverter()->convertType(op.getType()); - auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8); - - rewriter.replaceOpWithNewOp( - op, llvmResTy, llvmElementTy, adaptor.getAddr(), adaptor.getMember()); + assert(lowerMod && "lowering module is not available"); + mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType()); + mlir::Operation *llvmOp = lowerMod->getCXXABI().lowerGetRuntimeMember( + op, llvmResTy, adaptor.getAddr(), adaptor.getMember(), rewriter); + rewriter.replaceOp(op, llvmOp); return mlir::success(); } @@ -3850,7 +3839,7 @@ mlir::LogicalResult CIRToLLVMSignBitOpLowering::matchAndRewrite( void 
populateCIRToLLVMConversionPatterns( mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, - mlir::DataLayout &dataLayout, + mlir::DataLayout &dataLayout, cir::LowerModule *lowerModule, llvm::StringMap &stringGlobalsMap, llvm::StringMap &argStringGlobalsMap, llvm::MapVector &argsVarMap) { @@ -3858,6 +3847,9 @@ void populateCIRToLLVMConversionPatterns( patterns.add(converter, dataLayout, stringGlobalsMap, argStringGlobalsMap, argsVarMap, patterns.getContext()); + patterns.add( + converter, patterns.getContext(), lowerModule); patterns.add< // clang-format off CIRToLLVMAbsOpLowering, @@ -3891,7 +3883,6 @@ void populateCIRToLLVMConversionPatterns( CIRToLLVMComplexImagPtrOpLowering, CIRToLLVMComplexRealOpLowering, CIRToLLVMComplexRealPtrOpLowering, - CIRToLLVMConstantOpLowering, CIRToLLVMCopyOpLowering, CIRToLLVMDerivedClassAddrOpLowering, CIRToLLVMEhInflightOpLowering, @@ -3902,8 +3893,6 @@ void populateCIRToLLVMConversionPatterns( CIRToLLVMGetBitfieldOpLowering, CIRToLLVMGetGlobalOpLowering, CIRToLLVMGetMemberOpLowering, - CIRToLLVMGetRuntimeMemberOpLowering, - CIRToLLVMGlobalOpLowering, CIRToLLVMInlineAsmOpLowering, CIRToLLVMIsConstantOpLowering, CIRToLLVMIsFPClassOpLowering, @@ -3990,10 +3979,13 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS); }); - converter.addConversion([&](cir::DataMemberType type) -> mlir::Type { - return mlir::IntegerType::get(type.getContext(), - dataLayout.getTypeSizeInBits(type)); - }); + converter.addConversion( + [&, lowerModule](cir::DataMemberType type) -> mlir::Type { + assert(lowerModule && "CXXABI is not available"); + mlir::Type abiType = + lowerModule->getCXXABI().lowerDataMemberType(type, converter); + return converter.convertType(abiType); + }); converter.addConversion([&](cir::ArrayType type) -> mlir::Type { auto ty = converter.convertType(type.getEltType()); return mlir::LLVM::LLVMArrayType::get(ty, type.getSize()); @@ -4328,8 
+4320,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { llvm::MapVector argsVarMap; populateCIRToLLVMConversionPatterns(patterns, converter, dataLayout, - stringGlobalsMap, argStringGlobalsMap, - argsVarMap); + lowerModule.get(), stringGlobalsMap, + argStringGlobalsMap, argsVarMap); mlir::populateFuncToLLVMConversionPatterns(converter, patterns); mlir::ConversionTarget target(getContext()); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index a88c30d3dd15..d86d9dc0e1b5 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -291,8 +291,15 @@ class CIRToLLVMStoreOpLowering class CIRToLLVMConstantOpLowering : public mlir::OpConversionPattern { + cir::LowerModule *lowerMod; + public: - using mlir::OpConversionPattern::OpConversionPattern; + CIRToLLVMConstantOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + cir::LowerModule *lowerModule) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) { + setHasBoundedRewriteRecursion(); + } mlir::LogicalResult matchAndRewrite(cir::ConstantOp op, OpAdaptor, @@ -490,8 +497,15 @@ class CIRToLLVMSwitchFlatOpLowering class CIRToLLVMGlobalOpLowering : public mlir::OpConversionPattern { + cir::LowerModule *lowerMod; + public: - using mlir::OpConversionPattern::OpConversionPattern; + CIRToLLVMGlobalOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + cir::LowerModule *lowerModule) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) { + setHasBoundedRewriteRecursion(); + } mlir::LogicalResult matchAndRewrite(cir::GlobalOp op, OpAdaptor, @@ -774,8 +788,13 @@ class CIRToLLVMGetMemberOpLowering class CIRToLLVMGetRuntimeMemberOpLowering : public mlir::OpConversionPattern { + cir::LowerModule *lowerMod; + public: - using mlir::OpConversionPattern::OpConversionPattern; + 
CIRToLLVMGetRuntimeMemberOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + cir::LowerModule *lowerModule) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {} mlir::LogicalResult matchAndRewrite(cir::GetRuntimeMemberOp op, OpAdaptor, diff --git a/clang/test/CIR/Lowering/data-member.cir b/clang/test/CIR/Lowering/data-member.cir index 1609ac43ff03..14f3138bde56 100644 --- a/clang/test/CIR/Lowering/data-member.cir +++ b/clang/test/CIR/Lowering/data-member.cir @@ -5,7 +5,10 @@ !s64i = !cir.int !structT = !cir.struct, !cir.int, !cir.int}> -module @test { +module @test attributes { + cir.triple = "x86_64-unknown-linux-gnu", + llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" +} { cir.global external @pt_member = #cir.data_member<1> : !cir.data_member // MLIR: llvm.mlir.global external @pt_member(4 : i64) {addr_space = 0 : i32} : i64 // LLVM: @pt_member = global i64 4 @@ -15,8 +18,8 @@ module @test { cir.return %0 : !cir.data_member } // MLIR: llvm.func @constant() -> i64 - // MLIR-NEXT: %0 = llvm.mlir.constant(4 : i64) : i64 - // MLIR-NEXT: llvm.return %0 : i64 + // MLIR-NEXT: %[[#VAL:]] = llvm.mlir.constant(4 : i64) : i64 + // MLIR-NEXT: llvm.return %[[#VAL]] : i64 // MLIR-NEXT: } // LLVM: define i64 @constant() @@ -28,8 +31,8 @@ module @test { cir.return %0 : !cir.data_member } // MLIR: llvm.func @null_constant() -> i64 - // MLIR-NEXT: %0 = llvm.mlir.constant(-1 : i64) : i64 - // MLIR-NEXT: llvm.return %0 : i64 + // MLIR-NEXT: %[[#VAL:]] = llvm.mlir.constant(-1 : i64) : i64 + // MLIR-NEXT: llvm.return %[[#VAL]] : i64 // MLIR-NEXT: } // LLVM: define i64 @null_constant() !dbg !7 { @@ -40,13 +43,15 @@ module @test { %0 = cir.get_runtime_member %arg0[%arg1 : !cir.data_member] : !cir.ptr -> !cir.ptr cir.return %0 : !cir.ptr } - // MLIR: llvm.func @get_runtime_member(%arg0: !llvm.ptr, %arg1: i64) -> !llvm.ptr - // MLIR-NEXT: %0 = llvm.getelementptr %arg0[%arg1] : 
(!llvm.ptr, i64) -> !llvm.ptr, i8 - // MLIR-NEXT: llvm.return %0 : !llvm.ptr + // MLIR: llvm.func @get_runtime_member(%[[ARG0:.+]]: !llvm.ptr, %[[ARG1:.+]]: i64) -> !llvm.ptr + // MLIR-NEXT: %[[#PTR:]] = llvm.bitcast %[[ARG0]] : !llvm.ptr to !llvm.ptr + // MLIR-NEXT: %[[#VAL:]] = llvm.getelementptr %[[#PTR]][%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 + // MLIR-NEXT: %[[#RET:]] = llvm.bitcast %[[#VAL]] : !llvm.ptr to !llvm.ptr + // MLIR-NEXT: llvm.return %[[#RET]] : !llvm.ptr // MLIR-NEXT: } - // LLVM: define ptr @get_runtime_member(ptr %0, i64 %1) - // LLVM-NEXT: %3 = getelementptr i8, ptr %0, i64 %1 - // LLVM-NEXT: ret ptr %3 + // LLVM: define ptr @get_runtime_member(ptr %[[ARG0:.+]], i64 %[[ARG1:.+]]) + // LLVM-NEXT: %[[#VAL:]] = getelementptr i8, ptr %[[ARG0]], i64 %[[ARG1]] + // LLVM-NEXT: ret ptr %[[#VAL]] // LLVM-NEXT: } } \ No newline at end of file