[CIR] Derived-to-base conversions
Implement derived-to-base address conversions for non-virtual base
classes.  The code gen for this situation was only implemented when the
offset was zero, and it simply created a `cir.base_class_addr` op for
which no lowering or other transformation existed.

Conversion to a virtual base class is not yet implemented.

Two new fields are added to the `cir.base_class_addr` operation: the byte
offset of the necessary adjustment, and a boolean flag indicating
whether the source operand may be null.  The offset is easy to compute
in the front end while the entire path of intermediate classes is still
available, and it would be difficult for the back end to recompute it,
so it is stored in the operation.  The null-pointer check is best done
late in the lowering process, but only the front end knows whether the
check is needed, so that flag is stored in the operation as well.
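
As an illustration of the byte offset described above, the following sketch
measures where a non-first base subobject sits inside a derived object by
letting the compiler perform the conversion.  The struct names mirror the new
tests in this commit; the helper `base2OffsetInDerived` is hypothetical and is
not part of ClangIR.

```cpp
#include <cassert>
#include <cstddef>

struct Base1 { int a; };
struct Base2 { int b; };
struct Derived : Base1, Base2 { int c; };

// Hypothetical helper: byte distance from the start of a Derived object
// to its Base2 subobject, obtained from an ordinary derived-to-base cast.
inline std::ptrdiff_t base2OffsetInDerived() {
  Derived d{};
  auto *derivedBytes = reinterpret_cast<unsigned char *>(&d);
  auto *baseBytes =
      reinterpret_cast<unsigned char *>(static_cast<Base2 *>(&d));
  return baseBytes - derivedBytes;
}
```

On common ABIs `Base1` occupies the first `sizeof(int)` bytes, so the result
matches the `[4]` offset expected by the new code gen test.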

`CIRGenFunction::getAddressOfBaseClass` was largely rewritten.  The code
path no longer matches the equivalent function in the LLVM IR code gen,
because the generated ClangIR is quite different from the generated LLVM
IR.

`cir.base_class_addr` is lowered to LLVM IR as a `getelementptr`
operation on a byte (`i8`) pointer.  If a null-pointer check is needed,
the result is wrapped in a `select` operation that yields the original
null pointer when the source is null.
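
The semantics of that lowering can be sketched in plain C++.  This is only an
analogy under stated assumptions: `adjustToBase` is a hypothetical helper, and
it branches instead of computing the adjusted pointer unconditionally, because
pointer arithmetic on a null pointer is undefined in C++ even though the
lowered IR computes the `getelementptr` first and selects afterwards.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper mirroring the lowered sequence: a byte-wise
// adjustment, optionally guarded so that a null input stays null.
inline void *adjustToBase(void *derived, std::size_t offset,
                          bool assumeNotNull) {
  auto *bytes = static_cast<unsigned char *>(derived);
  if (assumeNotNull)
    return bytes + offset; // corresponds to the bare getelementptr
  // Corresponds to icmp eq + select: null in, null out.
  if (derived == nullptr)
    return derived;
  return bytes + offset;
}
```

A reference binding such as `Base2& bref = d;` corresponds to
`assumeNotNull == true`, while a pointer conversion such as `Base2* bptr = &d;`
takes the guarded path.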

When generating code for a constructor or destructor, an incorrect
`cir.ptr_stride` op was used to convert the pointer to a base class.
The code was assuming that the operand of `cir.ptr_stride` was measured
in bytes; the operand is the number of elements, not the number of bytes.
So the base class constructor was being called on the wrong chunk of
memory.  Fix this by using a `cir.base_class_addr` op instead of
`cir.ptr_stride` in this scenario.

The use of `cir.ptr_stride` in `ApplyNonVirtualAndVirtualOffset` had the
same problem.  Continue using `cir.ptr_stride` here, but temporarily
convert the pointer to type `char*` so the pointer is adjusted
correctly.
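
The difference between an element stride and a byte stride can be sketched
with ordinary C++ pointer arithmetic, which behaves like `cir.ptr_stride` in
this respect.  The type `Wide` and both helper functions are hypothetical and
exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

struct Wide { std::int64_t x; std::int64_t y; };

// Element-wise stride: `p + n` moves by n * sizeof(Wide) bytes.
inline Wide *strideByElements(Wide *p, std::ptrdiff_t n) { return p + n; }

// Byte-wise stride: cast to a byte pointer, adjust, then cast back.
inline Wide *strideByBytes(Wide *p, std::ptrdiff_t nBytes) {
  auto *bytes = reinterpret_cast<unsigned char *>(p);
  return reinterpret_cast<Wide *>(bytes + nBytes);
}
```

Passing a byte count to `strideByElements` would overshoot by a factor of
`sizeof(Wide)`, which is the kind of error described above.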

Adjust the expected results of three existing tests in response to these
changes.

Add two new tests, one code gen and one lowering, to cover the case
where a base class is at a non-zero offset.
dkolsen-pgi committed Oct 3, 2024
1 parent 9975749 commit c21cbfa
Showing 8 changed files with 159 additions and 61 deletions.
22 changes: 18 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2960,23 +2960,37 @@ def BaseClassAddrOp : CIR_Op<"base_class_addr"> {
let summary = "Get the base class address for a class/struct";
let description = [{
The `cir.base_class_addr` operation gets the address of a particular
base class given a derived class pointer.
non-virtual base class given a derived class pointer. The offset in bytes
of the base class must be passed in, since it is easier for the front end
to calculate that than the MLIR passes. The operation contains a flag for
whether or not the operand may be nullptr. That depends on the context and
cannot be known by the operation, and that information affects how the
operation is lowered.

Example:
```c++
struct Base { };
struct Derived : Base { };
Derived d;
Base& b = d;
```
will generate
```mlir
TBD
%3 = cir.base_class_addr (%1 : !cir.ptr<!ty_Derived> nonnull) [0] -> !cir.ptr<!ty_Base>
```
}];

let arguments = (ins
Arg<CIR_PointerType, "derived class pointer", [MemRead]>:$derived_addr);
Arg<CIR_PointerType, "derived class pointer", [MemRead]>:$derived_addr,
IndexAttr:$offset, UnitAttr:$assume_not_null);

let results = (outs Res<CIR_PointerType, "">:$base_addr);

let assemblyFormat = [{
`(`
$derived_addr `:` qualified(type($derived_addr))
`)` `->` qualified(type($base_addr)) attr-dict
(`nonnull` $assume_not_null^)?
`)` `[` $offset `]` `->` qualified(type($base_addr)) attr-dict
}];

// FIXME: add verifier.
8 changes: 4 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -684,14 +684,14 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}

cir::Address createBaseClassAddr(mlir::Location loc, cir::Address addr,
mlir::Type destType) {
mlir::Type destType, unsigned offset,
bool assumeNotNull) {
if (destType == addr.getElementType())
return addr;

auto ptrTy = getPointerTo(destType);
auto baseAddr =
create<mlir::cir::BaseClassAddrOp>(loc, ptrTy, addr.getPointer());

auto baseAddr = create<mlir::cir::BaseClassAddrOp>(
loc, ptrTy, addr.getPointer(), mlir::APInt(64, offset), assumeNotNull);
return Address(baseAddr, ptrTy, addr.getAlignment());
}

83 changes: 39 additions & 44 deletions clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -530,17 +530,9 @@ Address CIRGenFunction::getAddressOfDirectBaseInCompleteClass(
else
Offset = Layout.getBaseClassOffset(Base);

// Shift and cast down to the base type.
// TODO: for complete types, this should be possible with a GEP.
Address V = This;
if (!Offset.isZero()) {
mlir::Value OffsetVal = builder.getSInt32(Offset.getQuantity(), loc);
mlir::Value VBaseThisPtr = builder.create<mlir::cir::PtrStrideOp>(
loc, This.getPointer().getType(), This.getPointer(), OffsetVal);
V = Address(VBaseThisPtr, CXXABIThisAlignment);
}
V = builder.createElementBitCast(loc, V, ConvertType(Base));
return V;
return builder.createBaseClassAddr(loc, This, ConvertType(Base),
Offset.getQuantity(),
/*assume_not_null=*/true);
}

static void buildBaseInitializer(mlir::Location loc, CIRGenFunction &CGF,
@@ -680,10 +672,17 @@ static Address ApplyNonVirtualAndVirtualOffset(
baseOffset = virtualOffset;
}

// Apply the base offset.
// Apply the base offset. cir.ptr_stride adjusts by a number of elements,
// not bytes. So the pointer must be cast to a byte pointer and back.

mlir::Value ptr = addr.getPointer();
ptr = CGF.getBuilder().create<mlir::cir::PtrStrideOp>(loc, ptr.getType(), ptr,
baseOffset);
mlir::Type charPtrType = CGF.CGM.UInt8PtrTy;
mlir::Value charPtr = CGF.getBuilder().createCast(
mlir::cir::CastKind::bitcast, ptr, charPtrType);
mlir::Value adjusted = CGF.getBuilder().create<mlir::cir::PtrStrideOp>(
loc, charPtrType, charPtr, baseOffset);
ptr = CGF.getBuilder().createCast(mlir::cir::CastKind::bitcast, adjusted,
ptr.getType());

// If we have a virtual component, the alignment of the result will
// be relative only to the known alignment of that vbase.
@@ -1481,7 +1480,7 @@ CIRGenFunction::getAddressOfBaseClass(Address Value,
// *start* with a step down to the correct virtual base subobject,
// and hence will not require any further steps.
if ((*Start)->isVirtual()) {
llvm_unreachable("NYI");
llvm_unreachable("NYI: Cast to virtual base class");
}

// Compute the static offset of the ultimate destination within its
@@ -1494,55 +1493,51 @@ CIRGenFunction::getAddressOfBaseClass(Address Value,
// For now, that's limited to when the derived type is final.
// TODO: "devirtualize" this for accesses to known-complete objects.
if (VBase && Derived->hasAttr<FinalAttr>()) {
llvm_unreachable("NYI");
const ASTRecordLayout &layout = getContext().getASTRecordLayout(Derived);
CharUnits vBaseOffset = layout.getVBaseClassOffset(VBase);
NonVirtualOffset += vBaseOffset;
VBase = nullptr; // we no longer have a virtual step
}

// Get the base pointer type.
auto BaseValueTy = convertType((PathEnd[-1])->getType());
assert(!MissingFeatures::addressSpace());
// auto BasePtrTy = builder.getPointerTo(BaseValueTy);
// QualType DerivedTy = getContext().getRecordType(Derived);
// CharUnits DerivedAlign = CGM.getClassPointerAlignment(Derived);

// If the static offset is zero and we don't have a virtual step,
// just do a bitcast; null checks are unnecessary.
if (NonVirtualOffset.isZero() && !VBase) {
// If there is no virtual base, use cir.base_class_addr. It takes care of
// the adjustment and the null pointer check.
if (!VBase) {
if (sanitizePerformTypeCheck()) {
llvm_unreachable("NYI");
llvm_unreachable("NYI: sanitizePerformTypeCheck");
}
return builder.createBaseClassAddr(getLoc(Loc), Value, BaseValueTy);
return builder.createBaseClassAddr(getLoc(Loc), Value, BaseValueTy,
NonVirtualOffset.getQuantity(),
/*assumeNotNull=*/not NullCheckValue);
}

// Skip over the offset (and the vtable load) if we're supposed to
// null-check the pointer.
if (NullCheckValue) {
llvm_unreachable("NYI");
}

if (sanitizePerformTypeCheck()) {
llvm_unreachable("NYI");
}
// Conversion to a virtual base. cir.base_class_addr can't handle this.
// Generate the code to look up the address in the virtual table.

// Compute the virtual offset.
mlir::Value VirtualOffset{};
if (VBase) {
llvm_unreachable("NYI");
}
llvm_unreachable("NYI: Cast to virtual base class");

// Apply both offsets.
// This is just an outline of what the code might look like, since I can't
// actually test it.
#if 0
mlir::Value VirtualOffset = ...; // This is a dynamic expression. Creating
// it requires calling an ABI-specific
// function.
Value = ApplyNonVirtualAndVirtualOffset(getLoc(Loc), *this, Value,
NonVirtualOffset, VirtualOffset,
Derived, VBase);
// Cast to the destination type.
Value = builder.createElementBitCast(Value.getPointer().getLoc(), Value,
BaseValueTy);

// Build a phi if we needed a null check.
if (sanitizePerformTypeCheck()) {
// Do something here
}
if (NullCheckValue) {
llvm_unreachable("NYI");
// Convert to 'derivedPtr == nullptr ? nullptr : basePtr'
}
#endif

llvm_unreachable("NYI");
return Value;
}

36 changes: 35 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -624,6 +624,39 @@ class CIRPtrStrideOpLowering
}
};

class CIRBaseClassAddrOpLowering
: public mlir::OpConversionPattern<mlir::cir::BaseClassAddrOp> {
public:
using mlir::OpConversionPattern<
mlir::cir::BaseClassAddrOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
const auto resultType =
getTypeConverter()->convertType(baseClassOp.getType());
mlir::Value derivedAddr = adaptor.getDerivedAddr();
llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {
adaptor.getOffset().getZExtValue()};
mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
mlir::IntegerType::Signless);
if (baseClassOp.getAssumeNotNull()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
baseClassOp, resultType, byteType, derivedAddr, offset);
} else {
auto loc = baseClassOp.getLoc();
mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr,
rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType()));
mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>(
loc, resultType, byteType, derivedAddr, offset);
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull,
derivedAddr, adjusted);
}
return mlir::success();
}
};

class CIRBrCondOpLowering
: public mlir::OpConversionPattern<mlir::cir::BrCondOp> {
public:
@@ -3823,7 +3856,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering, CIRClearCacheOpLowering, CIRUndefOpLowering,
CIREhTypeIdOpLowering, CIRCatchParamOpLowering, CIRResumeOpLowering,
CIRAllocExceptionOpLowering, CIRThrowOpLowering, CIRIntrinsicCallLowering
CIRAllocExceptionOpLowering, CIRThrowOpLowering, CIRIntrinsicCallLowering,
CIRBaseClassAddrOpLowering
#define GET_BUILTIN_LOWERING_LIST
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
#undef GET_BUILTIN_LOWERING_LIST
33 changes: 29 additions & 4 deletions clang/test/CIR/CodeGen/derived-to-base.cpp
@@ -84,7 +84,7 @@ void C3::Layer::Initialize() {
// CHECK: cir.func @_ZN2C35Layer10InitializeEv

// CHECK: cir.scope {
// CHECK: %2 = cir.base_class_addr(%1 : !cir.ptr<!ty_C33A3ALayer>) -> !cir.ptr<!ty_C23A3ALayer>
// CHECK: %2 = cir.base_class_addr(%1 : !cir.ptr<!ty_C33A3ALayer> nonnull) [0] -> !cir.ptr<!ty_C23A3ALayer>
// CHECK: %3 = cir.get_member %2[1] {name = "m_C1"} : !cir.ptr<!ty_C23A3ALayer> -> !cir.ptr<!cir.ptr<!ty_C2_>>
// CHECK: %4 = cir.load %3 : !cir.ptr<!cir.ptr<!ty_C2_>>, !cir.ptr<!ty_C2_>
// CHECK: %5 = cir.const #cir.ptr<null> : !cir.ptr<!ty_C2_>
@@ -99,7 +99,7 @@ enumy C3::Initialize() {

// CHECK: cir.store %arg0, %0 : !cir.ptr<!ty_C3_>, !cir.ptr<!cir.ptr<!ty_C3_>>
// CHECK: %2 = cir.load %0 : !cir.ptr<!cir.ptr<!ty_C3_>>, !cir.ptr<!ty_C3_>
// CHECK: %3 = cir.base_class_addr(%2 : !cir.ptr<!ty_C3_>) -> !cir.ptr<!ty_C2_>
// CHECK: %3 = cir.base_class_addr(%2 : !cir.ptr<!ty_C3_> nonnull) [0] -> !cir.ptr<!ty_C2_>
// CHECK: %4 = cir.call @_ZN2C210InitializeEv(%3) : (!cir.ptr<!ty_C2_>) -> !s32i

void vcall(C1 &c1) {
@@ -144,7 +144,7 @@ class B : public A {
// CHECK: %1 = cir.load deref %0 : !cir.ptr<!cir.ptr<!ty_B>>, !cir.ptr<!ty_B>
// CHECK: cir.scope {
// CHECK: %2 = cir.alloca !ty_A, !cir.ptr<!ty_A>, ["ref.tmp0"] {alignment = 8 : i64}
// CHECK: %3 = cir.base_class_addr(%1 : !cir.ptr<!ty_B>) -> !cir.ptr<!ty_A>
// CHECK: %3 = cir.base_class_addr(%1 : !cir.ptr<!ty_B> nonnull) [0] -> !cir.ptr<!ty_A>

// Call @A::A(A const&)
// CHECK: cir.call @_ZN1AC2ERKS_(%2, %3) : (!cir.ptr<!ty_A>, !cir.ptr<!ty_A>) -> ()
@@ -171,4 +171,29 @@ int test_ref() {
int x = 42;
C c(x);
return c.ref;
}
}

// Multiple base classes, to test non-zero offsets
struct Base1 { int a; };
struct Base2 { int b; };
struct Derived : Base1, Base2 { int c; };
void test_multi_base() {
Derived d;

Base2& bref = d; // no null check needed
// CHECK: %6 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived> nonnull) [4] -> !cir.ptr<!ty_Base2_>

Base2* bptr = &d; // has null pointer check
// CHECK: %7 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived>) [4] -> !cir.ptr<!ty_Base2_>

int a = d.a;
// CHECK: %8 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived> nonnull) [0] -> !cir.ptr<!ty_Base1_>
// CHECK: %9 = cir.get_member %8[0] {name = "a"} : !cir.ptr<!ty_Base1_> -> !cir.ptr<!s32i>

int b = d.b;
// CHECK: %11 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived> nonnull) [4] -> !cir.ptr<!ty_Base2_>
// CHECK: %12 = cir.get_member %11[0] {name = "b"} : !cir.ptr<!ty_Base2_> -> !cir.ptr<!s32i>

int c = d.c;
// CHECK: %14 = cir.get_member %0[2] {name = "c"} : !cir.ptr<!ty_Derived> -> !cir.ptr<!s32i>
}
8 changes: 5 additions & 3 deletions clang/test/CIR/CodeGen/multi-vtable.cpp
@@ -56,8 +56,10 @@ int main() {
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV5Child, vtable_index = 1, address_point_index = 2) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>
// CIR: %{{[0-9]+}} = cir.const #cir.int<8> : !s64i
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!ty_Child>, %{{[0-9]+}} : !s64i), !cir.ptr<!ty_Child>
// CIR: %11 = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!u8i>, %{{[0-9]+}} : !s64i), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!u8i>), !cir.ptr<!ty_Child>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.return
// CIR: }
@@ -68,7 +70,7 @@ int main() {

// LLVM-DAG: define linkonce_odr void @_ZN5ChildC2Ev(ptr %0)
// LLVM-DAG: store ptr getelementptr inbounds ({ [4 x ptr], [3 x ptr] }, ptr @_ZTV5Child, i32 0, i32 0, i32 2), ptr %{{[0-9]+}}, align 8
// LLVM-DAG: %{{[0-9]+}} = getelementptr %class.Child, ptr %3, i64 8
// LLVM-DAG: %{{[0-9]+}} = getelementptr i8, ptr %3, i64 8
// LLVM-DAG: store ptr getelementptr inbounds ({ [4 x ptr], [3 x ptr] }, ptr @_ZTV5Child, i32 0, i32 1, i32 2), ptr %{{[0-9]+}}, align 8
// LLVM-DAG: ret void
// }
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/vtable-rtti.cpp
@@ -42,7 +42,7 @@ class B : public A
// CHECK: %0 = cir.alloca !cir.ptr<![[ClassB]]>, !cir.ptr<!cir.ptr<![[ClassB]]>>, ["this", init] {alignment = 8 : i64}
// CHECK: cir.store %arg0, %0 : !cir.ptr<![[ClassB]]>, !cir.ptr<!cir.ptr<![[ClassB]]>>
// CHECK: %1 = cir.load %0 : !cir.ptr<!cir.ptr<![[ClassB]]>>, !cir.ptr<![[ClassB]]>
// CHECK: %2 = cir.cast(bitcast, %1 : !cir.ptr<![[ClassB]]>), !cir.ptr<![[ClassA]]>
// CHECK: %2 = cir.base_class_addr(%1 : !cir.ptr<![[ClassB]]> nonnull) [0] -> !cir.ptr<![[ClassA]]>
// CHECK: cir.call @_ZN1AC2Ev(%2) : (!cir.ptr<![[ClassA]]>) -> ()
// CHECK: %3 = cir.vtable.address_point(@_ZTV1B, vtable_index = 0, address_point_index = 2) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>
// CHECK: %4 = cir.cast(bitcast, %1 : !cir.ptr<![[ClassB]]>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
28 changes: 28 additions & 0 deletions clang/test/CIR/Lowering/derived-to-base.cpp
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM

struct Base1 { int a; };
struct Base2 { int b; };
struct Derived : Base1, Base2 { int c; };
void test_multi_base() {
Derived d;

Base2& bref = d; // no null check needed
// LLVM: %7 = getelementptr i8, ptr %1, i32 4

Base2* bptr = &d; // has null pointer check
// LLVM: %8 = icmp eq ptr %1, null
// LLVM: %9 = getelementptr i8, ptr %1, i32 4
// LLVM: %10 = select i1 %8, ptr %1, ptr %9

int a = d.a;
// LLVM: %11 = getelementptr i8, ptr %1, i32 0
// LLVM: %12 = getelementptr %struct.Base1, ptr %11, i32 0, i32 0

int b = d.b;
// LLVM: %14 = getelementptr i8, ptr %1, i32 4
// LLVM: %15 = getelementptr %struct.Base2, ptr %14, i32 0, i32 0

int c = d.c;
// LLVM: %17 = getelementptr %struct.Derived, ptr %1, i32 0, i32 2
}
