Skip to content

Commit

Permalink
[CIR][NFC] move data member pointer lowering to CXXABI
Browse files Browse the repository at this point in the history
This patch moves the lowering code for data member pointers from the conversion
patterns to the implementation of CXXABI because this part should be ABI-
specific.
  • Loading branch information
Lancern committed Nov 15, 2024
1 parent c10f493 commit edc33f9
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 74 deletions.
24 changes: 24 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
#define LLVM_CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRCXXABI_H

#include "LowerFunctionInfo.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Target/AArch64.h"

namespace cir {
Expand Down Expand Up @@ -59,6 +65,24 @@ class CIRCXXABI {
/// Returns how an argument of the given record type should be passed.
/// FIXME(cir): This expects a CXXRecordDecl! Not any record type.
virtual RecordArgABI getRecordArgABI(const StructType RD) const = 0;

/// Get the LLVM type corresponding to the given data member pointer type.
virtual mlir::Type
getLLVMTypeForDataMember(cir::DataMemberType type,
mlir::TypeConverter &typeConverter) const = 0;

/// Lower the given data member pointer constant to its corresponding LLVM
/// constant as an attribute.
virtual mlir::Attribute
lowerDataMemberConstant(cir::DataMemberAttr attr,
const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const = 0;

/// Lower the given cir.get_runtime_member op to its equivalent LLVM op.
virtual mlir::Operation *
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const = 0;
};

/// Creates an Itanium-family ABI.
Expand Down
56 changes: 56 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "../LoweringPrepareCXXABI.h"
#include "CIRCXXABI.h"
#include "LowerModule.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/ErrorHandling.h"

namespace cir {
Expand Down Expand Up @@ -51,6 +52,19 @@ class ItaniumCXXABI : public CIRCXXABI {
cir_cconv_assert(!cir::MissingFeatures::recordDeclCanPassInRegisters());
return RAA_Default;
}

mlir::Type
getLLVMTypeForDataMember(cir::DataMemberType type,
mlir::TypeConverter &typeConverter) const override;

mlir::Attribute lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const override;

mlir::Operation *
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const override;
};

} // namespace
Expand All @@ -67,6 +81,48 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
return false;
}

mlir::Type ItaniumCXXABI::getLLVMTypeForDataMember(
cir::DataMemberType type, mlir::TypeConverter &typeConverter) const {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
const clang::TargetInfo &target = LM.getTarget();
unsigned width =
target.getTypeWidth(target.getPtrDiffType(clang::LangAS::Default));
return mlir::IntegerType::get(type.getContext(), width);
}

mlir::Attribute ItaniumCXXABI::lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const {
uint64_t memberOffset;
if (attr.isNullPtr()) {
// Itanium C++ ABI 2.3:
// A NULL pointer is represented as -1.
memberOffset = -1ull;
} else {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
auto memberIndex = attr.getMemberIndex().value();
memberOffset =
attr.getType().getClsTy().getElementOffset(layout, memberIndex);
}

auto underlyingIntTy = mlir::IntegerType::get(
attr.getContext(), layout.getTypeSizeInBits(attr.getType()));
return mlir::IntegerAttr::get(underlyingIntTy, memberOffset);
}

mlir::Operation *ItaniumCXXABI::lowerGetRuntimeMember(
cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const {
auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8);
return builder.create<mlir::LLVM::GEPOp>(
op.getLoc(), loweredResultTy, llvmElementTy, loweredAddr, loweredMember);
}

CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
switch (LM.getCXXABIKind()) {
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't
Expand Down
125 changes: 55 additions & 70 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1509,28 +1509,6 @@ bool hasTrailingZeros(cir::ConstArrayAttr attr) {
}));
}

static mlir::Attribute
lowerDataMemberAttr(mlir::ModuleOp moduleOp, cir::DataMemberAttr attr,
const mlir::TypeConverter &typeConverter) {
mlir::DataLayout layout{moduleOp};

uint64_t memberOffset;
if (attr.isNullPtr()) {
// TODO(cir): the numerical value of a null data member pointer is
// ABI-specific and should be queried through ABI.
assert(!MissingFeatures::targetCodeGenInfoGetNullPointer());
memberOffset = -1ull;
} else {
auto memberIndex = attr.getMemberIndex().value();
memberOffset =
attr.getType().getClsTy().getElementOffset(layout, memberIndex);
}

auto underlyingIntTy = mlir::IntegerType::get(
moduleOp->getContext(), layout.getTypeSizeInBits(attr.getType()));
return mlir::IntegerAttr::get(underlyingIntTy, memberOffset);
}

mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
cir::ConstantOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -1590,9 +1568,11 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
}
attr = op.getValue();
} else if (mlir::isa<cir::DataMemberType>(op.getType())) {
assert(lowerMod && "lower module is not available");
auto dataMember = mlir::cast<cir::DataMemberAttr>(op.getValue());
attr = lowerDataMemberAttr(op->getParentOfType<mlir::ModuleOp>(),
dataMember, *typeConverter);
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
attr = lowerMod->getCXXABI().lowerDataMemberConstant(dataMember, layout,
*typeConverter);
}
// TODO(cir): constant arrays are currently just pushed into the stack using
// the store instruction, instead of being stored as global variables and
Expand Down Expand Up @@ -1851,8 +1831,8 @@ mlir::LogicalResult CIRToLLVMVAArgOpLowering::matchAndRewrite(
return op.emitError("cir.vaarg lowering is NYI");
}

/// Returns the name used for the linkage attribute. This *must* correspond
/// to the name of the attribute in ODS.
/// Returns the name used for the linkage attribute. This *must* correspond
/// to the name of the attribute in ODS.
StringRef CIRToLLVMFuncOpLowering::getLinkageAttrNameString() {
return "linkage";
}
Expand Down Expand Up @@ -1886,8 +1866,8 @@ void CIRToLLVMFuncOpLowering::lowerFuncAttributes(
}
}

/// When do module translation, we can only translate LLVM-compatible types.
/// Here we lower possible OpenCLKernelMetadataAttr to use the converted type.
/// When do module translation, we can only translate LLVM-compatible types.
/// Here we lower possible OpenCLKernelMetadataAttr to use the converted type.
void CIRToLLVMFuncOpLowering::lowerFuncOpenCLKernelMetadata(
mlir::NamedAttribute &extraAttrsEntry) const {
const auto attrKey = cir::OpenCLKernelMetadataAttr::getMnemonic();
Expand Down Expand Up @@ -2100,8 +2080,8 @@ mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite(
return mlir::success();
}

/// Replace CIR global with a region initialized LLVM global and update
/// insertion point to the end of the initializer block.
/// Replace CIR global with a region initialized LLVM global and update
/// insertion point to the end of the initializer block.
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
const auto llvmType = getTypeConverter()->convertType(op.getSymType());
Expand Down Expand Up @@ -2196,8 +2176,10 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
return mlir::success();
} else if (auto dataMemberAttr =
mlir::dyn_cast<cir::DataMemberAttr>(init.value())) {
init = lowerDataMemberAttr(op->getParentOfType<mlir::ModuleOp>(),
dataMemberAttr, *typeConverter);
assert(lowerMod && "lower module is not available");
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
init = lowerMod->getCXXABI().lowerDataMemberConstant(dataMemberAttr, layout,
*typeConverter);
} else if (const auto structAttr =
mlir::dyn_cast<cir::ConstStructAttr>(init.value())) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
Expand Down Expand Up @@ -3225,11 +3207,11 @@ mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
mlir::LogicalResult CIRToLLVMGetRuntimeMemberOpLowering::matchAndRewrite(
cir::GetRuntimeMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto llvmResTy = getTypeConverter()->convertType(op.getType());
auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8);

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
op, llvmResTy, llvmElementTy, adaptor.getAddr(), adaptor.getMember());
assert(lowerMod && "lowering module is not available");
mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType());
mlir::Operation *llvmOp = lowerMod->getCXXABI().lowerGetRuntimeMember(
op, llvmResTy, adaptor.getAddr(), adaptor.getMember(), rewriter);
rewriter.replaceOp(op, llvmOp);
return mlir::success();
}

Expand Down Expand Up @@ -3838,14 +3820,17 @@ mlir::LogicalResult CIRToLLVMSignBitOpLowering::matchAndRewrite(

void populateCIRToLLVMConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter,
mlir::DataLayout &dataLayout,
mlir::DataLayout &dataLayout, cir::LowerModule *lowerModule,
llvm::StringMap<mlir::LLVM::GlobalOp> &stringGlobalsMap,
llvm::StringMap<mlir::LLVM::GlobalOp> &argStringGlobalsMap,
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap) {
patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext());
patterns.add<CIRToLLVMAllocaOpLowering>(converter, dataLayout,
stringGlobalsMap, argStringGlobalsMap,
argsVarMap, patterns.getContext());
patterns.add<CIRToLLVMConstantOpLowering, CIRToLLVMGlobalOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering>(
converter, patterns.getContext(), lowerModule);
patterns.add<
CIRToLLVMCmpOpLowering, CIRToLLVMSelectOpLowering,
CIRToLLVMBitClrsbOpLowering, CIRToLLVMBitClzOpLowering,
Expand All @@ -3858,40 +3843,39 @@ void populateCIRToLLVMConversionPatterns(
CIRToLLVMTryCallOpLowering, CIRToLLVMEhInflightOpLowering,
CIRToLLVMUnaryOpLowering, CIRToLLVMBinOpLowering,
CIRToLLVMBinOpOverflowOpLowering, CIRToLLVMShiftOpLowering,
CIRToLLVMLoadOpLowering, CIRToLLVMConstantOpLowering,
CIRToLLVMStoreOpLowering, CIRToLLVMFuncOpLowering,
CIRToLLVMCastOpLowering, CIRToLLVMGlobalOpLowering,
CIRToLLVMLoadOpLowering, CIRToLLVMStoreOpLowering,
CIRToLLVMFuncOpLowering, CIRToLLVMCastOpLowering,
CIRToLLVMGetGlobalOpLowering, CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexRealOpLowering, CIRToLLVMComplexImagOpLowering,
CIRToLLVMComplexRealPtrOpLowering, CIRToLLVMComplexImagPtrOpLowering,
CIRToLLVMVAStartOpLowering, CIRToLLVMVAEndOpLowering,
CIRToLLVMVACopyOpLowering, CIRToLLVMVAArgOpLowering,
CIRToLLVMBrOpLowering, CIRToLLVMGetMemberOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering, CIRToLLVMSwitchFlatOpLowering,
CIRToLLVMPtrDiffOpLowering, CIRToLLVMCopyOpLowering,
CIRToLLVMMemCpyOpLowering, CIRToLLVMMemChrOpLowering,
CIRToLLVMAbsOpLowering, CIRToLLVMExpectOpLowering,
CIRToLLVMVTableAddrPointOpLowering, CIRToLLVMVecCreateOpLowering,
CIRToLLVMVecCmpOpLowering, CIRToLLVMVecSplatOpLowering,
CIRToLLVMVecTernaryOpLowering, CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecShuffleOpLowering, CIRToLLVMStackSaveOpLowering,
CIRToLLVMUnreachableOpLowering, CIRToLLVMTrapOpLowering,
CIRToLLVMInlineAsmOpLowering, CIRToLLVMSetBitfieldOpLowering,
CIRToLLVMGetBitfieldOpLowering, CIRToLLVMPrefetchOpLowering,
CIRToLLVMObjSizeOpLowering, CIRToLLVMIsConstantOpLowering,
CIRToLLVMCmpThreeWayOpLowering, CIRToLLVMMemCpyOpLowering,
CIRToLLVMSwitchFlatOpLowering, CIRToLLVMPtrDiffOpLowering,
CIRToLLVMCopyOpLowering, CIRToLLVMMemCpyOpLowering,
CIRToLLVMMemChrOpLowering, CIRToLLVMAbsOpLowering,
CIRToLLVMExpectOpLowering, CIRToLLVMVTableAddrPointOpLowering,
CIRToLLVMVecCreateOpLowering, CIRToLLVMVecCmpOpLowering,
CIRToLLVMVecSplatOpLowering, CIRToLLVMVecTernaryOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering, CIRToLLVMVecShuffleOpLowering,
CIRToLLVMStackSaveOpLowering, CIRToLLVMUnreachableOpLowering,
CIRToLLVMTrapOpLowering, CIRToLLVMInlineAsmOpLowering,
CIRToLLVMSetBitfieldOpLowering, CIRToLLVMGetBitfieldOpLowering,
CIRToLLVMPrefetchOpLowering, CIRToLLVMObjSizeOpLowering,
CIRToLLVMIsConstantOpLowering, CIRToLLVMCmpThreeWayOpLowering,
CIRToLLVMReturnAddrOpLowering, CIRToLLVMClearCacheOpLowering,
CIRToLLVMEhTypeIdOpLowering, CIRToLLVMCatchParamOpLowering,
CIRToLLVMResumeOpLowering, CIRToLLVMAllocExceptionOpLowering,
CIRToLLVMFreeExceptionOpLowering, CIRToLLVMThrowOpLowering,
CIRToLLVMLLVMIntrinsicCallOpLowering, CIRToLLVMAssumeOpLowering,
CIRToLLVMAssumeAlignedOpLowering, CIRToLLVMAssumeSepStorageOpLowering,
CIRToLLVMBaseClassAddrOpLowering, CIRToLLVMDerivedClassAddrOpLowering,
CIRToLLVMVTTAddrPointOpLowering, CIRToLLVMIsFPClassOpLowering,
CIRToLLVMAbsOpLowering, CIRToLLVMMemMoveOpLowering,
CIRToLLVMMemSetOpLowering, CIRToLLVMMemCpyInlineOpLowering,
CIRToLLVMSignBitOpLowering, CIRToLLVMPtrMaskOpLowering
CIRToLLVMMemCpyOpLowering, CIRToLLVMIsConstantOpLowering,
CIRToLLVMCmpThreeWayOpLowering, CIRToLLVMReturnAddrOpLowering,
CIRToLLVMClearCacheOpLowering, CIRToLLVMEhTypeIdOpLowering,
CIRToLLVMCatchParamOpLowering, CIRToLLVMResumeOpLowering,
CIRToLLVMAllocExceptionOpLowering, CIRToLLVMFreeExceptionOpLowering,
CIRToLLVMThrowOpLowering, CIRToLLVMLLVMIntrinsicCallOpLowering,
CIRToLLVMAssumeOpLowering, CIRToLLVMAssumeAlignedOpLowering,
CIRToLLVMAssumeSepStorageOpLowering, CIRToLLVMBaseClassAddrOpLowering,
CIRToLLVMDerivedClassAddrOpLowering, CIRToLLVMVTTAddrPointOpLowering,
CIRToLLVMIsFPClassOpLowering, CIRToLLVMAbsOpLowering,
CIRToLLVMMemMoveOpLowering, CIRToLLVMMemSetOpLowering,
CIRToLLVMMemCpyInlineOpLowering, CIRToLLVMSignBitOpLowering,
CIRToLLVMPtrMaskOpLowering
#define GET_BUILTIN_LOWERING_LIST
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
#undef GET_BUILTIN_LOWERING_LIST
Expand Down Expand Up @@ -3934,9 +3918,10 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter,

return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS);
});
converter.addConversion([&](cir::DataMemberType type) -> mlir::Type {
return mlir::IntegerType::get(type.getContext(),
dataLayout.getTypeSizeInBits(type));
converter.addConversion([&, lowerModule](
cir::DataMemberType type) -> mlir::Type {
assert(lowerModule && "CXXABI is not available");
return lowerModule->getCXXABI().getLLVMTypeForDataMember(type, converter);
});
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
auto ty = converter.convertType(type.getEltType());
Expand Down Expand Up @@ -4272,8 +4257,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> argsVarMap;

populateCIRToLLVMConversionPatterns(patterns, converter, dataLayout,
stringGlobalsMap, argStringGlobalsMap,
argsVarMap);
lowerModule.get(), stringGlobalsMap,
argStringGlobalsMap, argsVarMap);
mlir::populateFuncToLLVMConversionPatterns(converter, patterns);

mlir::ConversionTarget target(getContext());
Expand Down
21 changes: 18 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,13 @@ class CIRToLLVMStoreOpLowering

class CIRToLLVMConstantOpLowering
: public mlir::OpConversionPattern<cir::ConstantOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::ConstantOp>::OpConversionPattern;
CIRToLLVMConstantOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}

mlir::LogicalResult
matchAndRewrite(cir::ConstantOp op, OpAdaptor,
Expand Down Expand Up @@ -480,8 +485,13 @@ class CIRToLLVMSwitchFlatOpLowering

class CIRToLLVMGlobalOpLowering
: public mlir::OpConversionPattern<cir::GlobalOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
CIRToLLVMGlobalOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}

mlir::LogicalResult
matchAndRewrite(cir::GlobalOp op, OpAdaptor,
Expand Down Expand Up @@ -764,8 +774,13 @@ class CIRToLLVMGetMemberOpLowering

class CIRToLLVMGetRuntimeMemberOpLowering
: public mlir::OpConversionPattern<cir::GetRuntimeMemberOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::GetRuntimeMemberOp>::OpConversionPattern;
CIRToLLVMGetRuntimeMemberOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}

mlir::LogicalResult
matchAndRewrite(cir::GetRuntimeMemberOp op, OpAdaptor,
Expand Down
5 changes: 4 additions & 1 deletion clang/test/CIR/Lowering/data-member.cir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
!s64i = !cir.int<s, 64>
!structT = !cir.struct<struct "Point" {!cir.int<s, 32>, !cir.int<s, 32>, !cir.int<s, 32>}>

module @test {
module @test attributes {
cir.triple = "x86_64-unknown-linux-gnu",
llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
} {
cir.global external @pt_member = #cir.data_member<1> : !cir.data_member<!s32i in !structT>
// MLIR: llvm.mlir.global external @pt_member(4 : i64) {addr_space = 0 : i32} : i64
// LLVM: @pt_member = global i64 4
Expand Down

0 comments on commit edc33f9

Please sign in to comment.