Skip to content

Commit

Permalink
[CIR][AArch64][ABI] fixes calls with union type
Browse files Browse the repository at this point in the history
  • Loading branch information
gitoleg committed Nov 13, 2024
1 parent a18a580 commit e369fc4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
54 changes: 43 additions & 11 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,25 @@ mlir::Value createCoercedBitcast(mlir::Value Src, mlir::Type DestTy,
CastKind::bitcast, Src);
}


// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

mlir::Type getLargestMember(const CIRDataLayout& dataLayout, StructType s) {
mlir::Type typ;
for (auto t : s.getMembers()) {
if (!typ ||
dataLayout.getABITypeAlign(t) > dataLayout.getABITypeAlign(typ) ||
(dataLayout.getABITypeAlign(t) == dataLayout.getABITypeAlign(typ) &&
dataLayout.getTypeSizeInBits(t) > dataLayout.getTypeSizeInBits(typ)))
typ = t;
}
return typ;
}

/// Given a struct pointer that we are accessing some number of bytes out of it,
/// try to gep into the struct to get at its inner goodness. Dive as deep as
/// possible without entering an element with an in-memory size smaller than
Expand All @@ -66,6 +85,9 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,

mlir::Type FirstElt = SrcSTy.getMembers()[0];

if (SrcSTy.isUnion())
FirstElt = getLargestMember(CGF.LM.getDataLayout(), SrcSTy);

// If the first elt is at least as large as what we're looking for, or if the
// first element is the same size as the whole struct, we can enter it. The
// comparison must be made on the store size and not the alloca size. Using
Expand All @@ -75,10 +97,26 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,
FirstEltSize < CGF.LM.getDataLayout().getTypeStoreSize(SrcSTy))
return SrcPtr;

cir_cconv_assert_or_abort(
!cir::MissingFeatures::ABIEnterStructForCoercedAccess(), "NYI");
return SrcPtr; // FIXME: This is a temporary workaround for the assertion
// above.
auto& rw = CGF.getRewriter();
auto* ctxt = rw.getContext();
auto ptrTy = PointerType::get(ctxt, FirstElt);
if (mlir::isa<StructType>(SrcPtr.getType())) {
auto addr = SrcPtr;
if (auto load = mlir::dyn_cast<LoadOp>(SrcPtr.getDefiningOp()))
addr = load.getAddr();
cir_cconv_assert(mlir::isa<PointerType>(addr.getType()));
// we can not use getMemberOp here since we need a pointer to the first
// element. And in the case of unions we pick a type of the largest elt,
// that may or may not be the first one. Thus, getMemberOp verification
// may fail.
auto cast = createBitcast(addr, ptrTy, CGF);
SrcPtr = rw.create<LoadOp>(SrcPtr.getLoc(), cast);
}

if (auto sty = mlir::dyn_cast<StructType>(SrcPtr.getType()))
return enterStructPointerForCoercedAccess(SrcPtr, sty, DstSize, CGF);

return SrcPtr;
}

/// Convert a value Val to the specific Ty where both
Expand Down Expand Up @@ -191,12 +229,6 @@ void createCoercedStore(mlir::Value Src, mlir::Value Dst, bool DstIsVolatile,
}
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

/// Coerces a \param Src value to a value of type \param Ty.
///
/// This safely handles the case when the src type is smaller than the
Expand Down Expand Up @@ -235,7 +267,7 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty,
// extension or truncation to the desired type.
if ((mlir::isa<IntType>(Ty) || mlir::isa<PointerType>(Ty)) &&
(mlir::isa<IntType>(SrcTy) || mlir::isa<PointerType>(SrcTy))) {
cir_cconv_unreachable("NYI");
return coerceIntOrPtrToIntOrPtr(Src, Ty, CGF);
}

// If load is legal, just bitcast the src pointer.
Expand Down
32 changes: 31 additions & 1 deletion clang/test/CIR/CallConvLowering/AArch64/union.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,34 @@ void foo(U u) {}
U init() {
U u;
return u;
}
}

typedef union {

struct {
short a;
char b;
char c;
};

int x;
} A;

void passA(A x) {}

// CIR: cir.func {{.*@callA}}()
// CIR: %[[#V0:]] = cir.alloca !ty_A, !cir.ptr<!ty_A>, ["x"] {alignment = 4 : i64}
// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0:]] : !cir.ptr<!ty_A>), !cir.ptr<!s32i>
// CIR: %[[#V2:]] = cir.load %[[#V1]] : !cir.ptr<!s32i>, !s32i
// CIR: %[[#V3:]] = cir.cast(integral, %[[#V2]] : !s32i), !u64i
// CIR: cir.call @passA(%[[#V3]]) : (!u64i) -> ()

// LLVM: void @callA()
// LLVM: %[[#V0:]] = alloca %union.A, i64 1, align 4
// LLVM: %[[#V1:]] = load i32, ptr %[[#V0]], align 4
// LLVM: %[[#V2:]] = sext i32 %[[#V1]] to i64
// LLVM: call void @passA(i64 %[[#V2]])
void callA() {
A x;
passA(x);
}

0 comments on commit e369fc4

Please sign in to comment.