Skip to content

Commit

Permalink
[CIR][AArch64][Lowering] Support fields with structs containing const…
Browse files Browse the repository at this point in the history
…ant arrays or pointers (#1136)

This PR adds support for function arguments with structs that contain
constant arrays or pointers for AArch64.

For example, 
```
typedef struct {
  int a[42];
} CAT;

void pass_cat(CAT a) {}
```

As usual, the main ideas are gotten from the original
[CodeGen](https://github.com/llvm/clangir/blob/3aed38cf52e72cb51a907fad9dd53802f6505b81/clang/lib/AST/ASTContext.cpp#L1823),
and I have added a couple of tests.
  • Loading branch information
bruteforceboy authored Nov 25, 2024
1 parent 9b73052 commit 0ace889
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ clang::TypeInfo CIRLowerContext::getTypeInfoImpl(const mlir::Type T) const {
Align = Target->getDoubleAlign();
break;
}
if (mlir::isa<PointerType>(T)) {
Width = Target->getPointerWidth(clang::LangAS::Default);
Align = Target->getPointerAlign(clang::LangAS::Default);
break;
}
cir_cconv_unreachable("Unknown builtin type!");
break;
}
Expand Down Expand Up @@ -167,9 +172,28 @@ int64_t CIRLowerContext::toBits(clang::CharUnits CharSize) const {
return CharSize.getQuantity() * getCharWidth();
}

/// Performing the computation in CharUnits
/// instead of in bits prevents overflowing the uint64_t for some large arrays.
clang::TypeInfoChars getConstantArrayInfoInChars(const CIRLowerContext &ctx,
cir::ArrayType arrTy) {
clang::TypeInfoChars eltInfo = ctx.getTypeInfoInChars(arrTy.getEltType());
uint64_t tySize = arrTy.getSize();
assert((tySize == 0 || static_cast<uint64_t>(eltInfo.Width.getQuantity()) <=
(uint64_t)(-1) / tySize) &&
"Overflow in array type char size evaluation");
uint64_t width = eltInfo.Width.getQuantity() * tySize;
unsigned align = eltInfo.Align.getQuantity();
if (!ctx.getTargetInfo().getCXXABI().isMicrosoft() ||
ctx.getTargetInfo().getPointerWidth(clang::LangAS::Default) == 64)
width = llvm::alignTo(width, align);
return clang::TypeInfoChars(clang::CharUnits::fromQuantity(width),
clang::CharUnits::fromQuantity(align),
eltInfo.AlignRequirement);
}

clang::TypeInfoChars CIRLowerContext::getTypeInfoInChars(mlir::Type T) const {
if (auto arrTy = mlir::dyn_cast<ArrayType>(T))
cir_cconv_unreachable("NYI");
return getConstantArrayInfoInChars(*this, arrTy);
clang::TypeInfo Info = getTypeInfo(T);
return clang::TypeInfoChars(toCharUnitsFromBits(Info.Width),
toCharUnitsFromBits(Info.Align),
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,20 @@ S_PAD ret_s_pad() {
S_PAD s;
return s;
}

typedef struct {
int a[42];
} CAT;

// CHECK: cir.func @pass_cat(%arg0: !cir.ptr<!ty_CAT>
// CHECK: %[[#V0:]] = cir.alloca !cir.ptr<!ty_CAT>, !cir.ptr<!cir.ptr<!ty_CAT>>, [""] {alignment = 8 : i64}
// CHECK: cir.store %arg0, %[[#V0]] : !cir.ptr<!ty_CAT>, !cir.ptr<!cir.ptr<!ty_CAT>>
// CHECK: %[[#V1:]] = cir.load %[[#V0]] : !cir.ptr<!cir.ptr<!ty_CAT>>, !cir.ptr<!ty_CAT>
// CHECK: cir.return

// LLVM: void @pass_cat(ptr %[[#V0:]])
// LLVM: %[[#V2:]] = alloca ptr, i64 1, align 8
// LLVM: store ptr %[[#V0]], ptr %[[#V2]], align 8
// LLVM: %[[#V3:]] = load ptr, ptr %[[#V2]], align 8
// LLVM: ret void
void pass_cat(CAT a) {}
49 changes: 49 additions & 0 deletions clang/test/CIR/CallConvLowering/AArch64/ptr-fields.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: %clang_cc1 -triple aarch64-unknown-linux-gnu -fclangir -fclangir-call-conv-lowering -emit-cir-flat -mmlir --mlir-print-ir-after=cir-call-conv-lowering %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple aarch64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll -fclangir-call-conv-lowering
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM

typedef int (*myfptr)(int);

typedef struct {
myfptr f;
} A;

int foo(int x) { return x; }

// CIR: cir.func @passA(%arg0: !u64i
// CIR: %[[#V0:]] = cir.alloca !ty_A, !cir.ptr<!ty_A>, [""] {alignment = 4 : i64}
// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_A>), !cir.ptr<!u64i>
// CIR: cir.store %arg0, %[[#V1]] : !u64i, !cir.ptr<!u64i>
// CIR: %[[#V2:]] = cir.get_global @foo : !cir.ptr<!cir.func<!s32i (!s32i)>>
// CIR: %[[#V3:]] = cir.get_member %[[#V0]][0] {name = "f"} : !cir.ptr<!ty_A> -> !cir.ptr<!cir.ptr<!cir.func<!s32i (!s32i)>>>
// CIR: cir.store %[[#V2]], %[[#V3]] : !cir.ptr<!cir.func<!s32i (!s32i)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!s32i)>>>
// CIR: cir.return

// LLVM: void @passA(i64 %[[#V0:]])
// LLVM: %[[#V2:]] = alloca %struct.A, i64 1, align 4
// LLVM: store i64 %[[#V0]], ptr %[[#V2]], align 8
// LLVM: %[[#V3:]] = getelementptr %struct.A, ptr %[[#V2]], i32 0, i32 0
// LLVM: store ptr @foo, ptr %[[#V3]], align 8
// LLVM: ret void
void passA(A a) { a.f = foo; }

typedef struct {
int a;
} S_1;

typedef struct {
S_1* s;
} S_2;

// CIR: cir.func @passB(%arg0: !u64i
// CIR: %[[#V0:]] = cir.alloca !ty_S_2_, !cir.ptr<!ty_S_2_>, [""] {alignment = 4 : i64}
// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_S_2_>), !cir.ptr<!u64i>
// CIR: cir.store %arg0, %[[#V1]] : !u64i, !cir.ptr<!u64i>
// CIR: cir.return

// LLVM: void @passB(i64 %[[#V0:]])
// LLVM: %[[#V2:]] = alloca %struct.S_2, i64 1, align 4
// LLVM: store i64 %[[#V0]], ptr %[[#V2]], align 8
// LLVM: ret void
void passB(S_2 s) {}

0 comments on commit 0ace889

Please sign in to comment.