Skip to content

Commit

Permalink
[CIR] Data member pointer comparison and casts
Browse files Browse the repository at this point in the history
This patch adds CIRGen and LLVM lowering support for the following language
features related to pointers to data members:

  - Comparisons between pointers to data members.
  - Casting from pointers to data members to boolean.
  - Reinterpret casts between pointers to data members.
  • Loading branch information
Lancern committed Jan 5, 2025
1 parent 04d7dcf commit 9404c5b
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 15 deletions.
3 changes: 2 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def CK_FloatComplexToIntegralComplex
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
def CK_IntegralComplexToFloatComplex
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;
def CK_MemberPtrToBoolean : I32EnumAttrCase<"member_ptr_to_bool", 25>;

def CastKind : I32EnumAttr<
"CastKind",
Expand All @@ -135,7 +136,7 @@ def CastKind : I32EnumAttr<
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
CK_IntegralComplexToFloatComplex]> {
CK_IntegralComplexToFloatComplex, CK_MemberPtrToBoolean]> {
let cppNamespace = "::cir";
}

Expand Down
22 changes: 17 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
};

if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
assert(0 && "not implemented");
assert(E->getOpcode() == BO_EQ || E->getOpcode() == BO_NE);
mlir::Value lhs = CGF.emitScalarExpr(E->getLHS());
mlir::Value rhs = CGF.emitScalarExpr(E->getRHS());
cir::CmpOpKind kind = ClangCmpToCIRCmp(E->getOpcode());
Result =
Builder.createCompare(CGF.getLoc(E->getExprLoc()), kind, lhs, rhs);
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
BinOpInfo BOInfo = emitBinOps(E);
mlir::Value LHS = BOInfo.LHS;
Expand Down Expand Up @@ -1741,8 +1746,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
auto Ty = mlir::cast<cir::DataMemberType>(CGF.getCIRType(DestTy));
return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc()));
}
case CK_ReinterpretMemberPointer:
llvm_unreachable("NYI");
case CK_ReinterpretMemberPointer: {
mlir::Value src = Visit(E);
return Builder.createBitcast(CGF.getLoc(E->getExprLoc()), src,
CGF.getCIRType(DestTy));
}
case CK_BaseToDerivedMemberPointer:
case CK_DerivedToBaseMemberPointer: {
mlir::Value src = Visit(E);
Expand Down Expand Up @@ -1875,8 +1883,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return emitPointerToBoolConversion(Visit(E), E->getType());
case CK_FloatingToBoolean:
return emitFloatToBoolConversion(Visit(E), CGF.getLoc(E->getExprLoc()));
case CK_MemberPointerToBoolean:
llvm_unreachable("NYI");
case CK_MemberPointerToBoolean: {
mlir::Value memPtr = Visit(E);
return Builder.createCast(CGF.getLoc(CE->getSourceRange()),
cir::CastKind::member_ptr_to_bool, memPtr,
ConvertType(DestTy));
}
case CK_FloatingComplexToReal:
case CK_IntegralComplexToReal:
case CK_FloatingComplexToBoolean:
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ LogicalResult cir::CastOp::verify() {
return success();
}

// Handle the data member pointer types.
if (mlir::isa<cir::DataMemberType>(srcType) &&
mlir::isa<cir::DataMemberType>(resType))
return success();

// This is the only cast kind where we don't want vector types to decay
// into the element type.
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||
Expand Down Expand Up @@ -704,6 +709,13 @@ LogicalResult cir::CastOp::verify() {
<< "requires !cir.complex<!cir.float> type for result";
return success();
}
case cir::CastKind::member_ptr_to_bool: {
if (!mlir::isa<cir::DataMemberType>(srcType))
return emitOpError() << "requires !cir.data_member type for source";
if (!mlir::isa<cir::BoolType>(resType))
return emitOpError() << "requires !cir.bool type for result";
return success();
}
}

llvm_unreachable("Unknown CastOp kind?");
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ class CIRCXXABI {
virtual mlir::Value
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value
lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;
};

/// Creates an Itanium-family ABI.
Expand Down
52 changes: 48 additions & 4 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ class ItaniumCXXABI : public CIRCXXABI {
mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;
};

} // namespace
Expand All @@ -89,18 +101,23 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
return false;
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
static mlir::Type getABITypeForDataMember(LowerModule &lowerMod) {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
const clang::TargetInfo &target = LM.getTarget();
const clang::TargetInfo &target = lowerMod.getTarget();
clang::TargetInfo::IntType ptrdiffTy =
target.getPtrDiffType(clang::LangAS::Default);
return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy),
return cir::IntType::get(lowerMod.getMLIRContext(),
target.getTypeWidth(ptrdiffTy),
target.isTypeSigned(ptrdiffTy));
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
return getABITypeForDataMember(LM);
}

mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const {
Expand Down Expand Up @@ -175,6 +192,33 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
/*isDerivedToBase=*/false, builder);
}

mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const {
return builder.create<cir::CmpOp>(op.getLoc(), op.getKind(), loweredLhs,
loweredRhs);
}

mlir::Value
ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return builder.create<cir::CastOp>(op.getLoc(), loweredDstTy,
cir::CastKind::bitcast, loweredSrc);
}

mlir::Value
ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
// Itanium C++ ABI 2.3:
// A NULL pointer is represented as -1.
auto nullAttr = cir::IntAttr::get(getABITypeForDataMember(LM), -1);
auto nullValue = builder.create<cir::ConstantOp>(op.getLoc(), nullAttr);
return builder.create<cir::CmpOp>(op.getLoc(), cir::CmpOpKind::ne, loweredSrc,
nullValue);
}

CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
switch (LM.getCXXABIKind()) {
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't
Expand Down
35 changes: 32 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1179,8 +1179,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
}
case cir::CastKind::bitcast: {
auto dstTy = castOp.getType();
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy = getTypeConverter()->convertType(dstTy);

if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType())) {
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
castOp, llvmDstTy, src, rewriter);
rewriter.replaceOp(castOp, loweredResult);
return mlir::success();
}
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
llvm_unreachable("NYI");

auto llvmSrcVal = adaptor.getOperands().front();
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
llvmSrcVal);
return mlir::success();
Expand All @@ -1204,6 +1214,16 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
llvmSrcVal);
break;
}
case cir::CastKind::member_ptr_to_bool: {
mlir::Value loweredResult;
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
llvm_unreachable("NYI");
else
loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast(
castOp, src, rewriter);
rewriter.replaceOp(castOp, loweredResult);
break;
}
default: {
return castOp.emitError("Unhandled cast kind: ")
<< castOp.getKindAttrName();
Expand Down Expand Up @@ -2748,6 +2768,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
cir::CmpOp cmpOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto type = cmpOp.getLhs().getType();

if (mlir::isa<cir::DataMemberType>(type)) {
assert(lowerMod && "lowering module is not available");
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
rewriter.replaceOp(cmpOp, loweredResult);
return mlir::success();
}

mlir::Value llResult;

// Lower to LLVM comparison op.
Expand Down Expand Up @@ -3963,6 +3992,8 @@ void populateCIRToLLVMConversionPatterns(
patterns.add<
// clang-format off
CIRToLLVMBaseDataMemberOpLowering,
CIRToLLVMCastOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMConstantOpLowering,
CIRToLLVMDerivedDataMemberOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering,
Expand Down Expand Up @@ -3994,10 +4025,8 @@ void populateCIRToLLVMConversionPatterns(
CIRToLLVMBrOpLowering,
CIRToLLVMByteswapOpLowering,
CIRToLLVMCallOpLowering,
CIRToLLVMCastOpLowering,
CIRToLLVMCatchParamOpLowering,
CIRToLLVMClearCacheOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMCmpThreeWayOpLowering,
CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexImagOpLowering,
Expand Down
18 changes: 16 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,17 @@ class CIRToLLVMBrCondOpLowering
};

class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
cir::LowerModule *lowerMod;

mlir::Type convertTy(mlir::Type ty) const;

public:
using mlir::OpConversionPattern<cir::CastOp>::OpConversionPattern;
CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::CastOp op, OpAdaptor,
Expand Down Expand Up @@ -615,8 +622,15 @@ class CIRToLLVMShiftOpLowering
};

class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::CmpOp>::OpConversionPattern;
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::CmpOp op, OpAdaptor,
Expand Down
26 changes: 26 additions & 0 deletions clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,29 @@ auto derived_to_base_zero_offset(int Derived::*ptr) -> int Base1::* {
// LLVM-NEXT: %[[#ret:]] = load i64, ptr %[[#ret_slot]]
// LLVM-NEXT: ret i64 %[[#ret]]
}

struct Foo {
int a;
};

struct Bar {
int a;
};

bool to_bool(int Foo::*x) {
return x;
}

// CIR-LABEL: @_Z7to_boolM3Fooi
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cast(member_ptr_to_bool, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.bool
// CIR: }

auto bitcast(int Foo::*x) {
return reinterpret_cast<int Bar::*>(x);
}

// CIR-LABEL: @_Z7bitcastM3Fooi
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cast(bitcast, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.data_member<!s32i in !ty_Bar>
// CIR: }
44 changes: 44 additions & 0 deletions clang/test/CIR/CodeGen/pointer-to-data-member-cmp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir --check-prefix=CIR %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s

struct Foo {
int a;
};

struct Bar {
int a;
};

bool eq(int Foo::*x, int Foo::*y) {
return x == y;
}

// CIR-LABEL: @_Z2eqM3FooiS0_
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cmp(eq, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
// CIR: }

// LLVM-LABEL: @_Z2eqM3FooiS0_
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %{{.+}} = icmp eq i64 %[[#x]], %[[#y]]
// LLVM: }

bool ne(int Foo::*x, int Foo::*y) {
return x != y;
}

// CIR-LABEL: @_Z2neM3FooiS0_
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cmp(ne, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
// CIR: }

// LLVM-LABEL: @_Z2neM3FooiS0_
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#x]], %[[#y]]
// LLVM: }

0 comments on commit 9404c5b

Please sign in to comment.