Skip to content

Commit

Permalink
add Arm64 ABILowering for lowerVAArgOp
Browse files Browse the repository at this point in the history
  • Loading branch information
ghehg committed Apr 29, 2024
1 parent e197d4e commit a7b095d
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 20 deletions.
13 changes: 13 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,19 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createCast(mlir::cir::CastKind::int_to_ptr, src, newTy);
}

mlir::Value createGetMemberOp(mlir::Location &loc, mlir::Value structPtr,
const char *fldName, unsigned idx) {

assert(structPtr.getType().isa<mlir::cir::PointerType>());
auto structBaseTy =
structPtr.getType().cast<mlir::cir::PointerType>().getPointee();
assert(structBaseTy.isa<mlir::cir::StructType>());
auto fldTy = structBaseTy.cast<mlir::cir::StructType>().getMembers()[idx];
auto fldPtrTy = ::mlir::cir::PointerType::get(getContext(), fldTy);
return create<mlir::cir::GetMemberOp>(loc, fldPtrTy, structPtr, fldName,
idx);
}

mlir::Value createPtrToInt(mlir::Value src, mlir::Type newTy) {
return createCast(mlir::cir::CastKind::ptr_to_int, src, newTy);
}
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_clang_library(MLIRCIRTransforms
LifetimeCheck.cpp
LoweringPrepare.cpp
LoweringPrepareItaniumCXXABI.cpp
LoweringPrepareArm64CXXABI.cpp
MergeCleanups.cpp
DropAST.cpp
IdiomRecognizer.cpp
Expand Down
26 changes: 21 additions & 5 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {

void runOnOp(Operation *op);
void lowerThreeWayCmpOp(CmpThreeWayOp op);
void lowerVAArgOp(VAArgOp op);
void lowerGlobalOp(GlobalOp op);
void lowerDynamicCastOp(DynamicCastOp op);
void lowerStdFindOp(StdFindOp op);
Expand Down Expand Up @@ -110,11 +111,11 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
astCtx = c;
switch (c->getCXXABIKind()) {
case clang::TargetCXXABI::GenericItanium:
cxxABI.reset(::cir::LoweringPrepareCXXABI::createItaniumABI());
break;
case clang::TargetCXXABI::GenericAArch64:
case clang::TargetCXXABI::AppleARM64:
// TODO: this isn't quite right, clang uses AppleARM64CXXABI which
// inherits from ARMCXXABI. We'll have to follow suit.
cxxABI.reset(::cir::LoweringPrepareCXXABI::createItaniumABI());
cxxABI.reset(::cir::LoweringPrepareCXXABI::createArm64ABI());
break;

default:
Expand Down Expand Up @@ -320,6 +321,18 @@ static void canonicalizeIntrinsicThreeWayCmp(CIRBaseBuilderTy &builder,
op.erase();
}

void LoweringPreparePass::lowerVAArgOp(VAArgOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPoint(op);

auto res = cxxABI->lowerVAArg(builder, op);
if (res) {
op.replaceAllUsesWith(res);
op.erase();
}
return;
}

void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
Expand Down Expand Up @@ -601,6 +614,8 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
void LoweringPreparePass::runOnOp(Operation *op) {
if (auto threeWayCmp = dyn_cast<CmpThreeWayOp>(op)) {
lowerThreeWayCmpOp(threeWayCmp);
} else if (auto vaArgOp = dyn_cast<VAArgOp>(op)) {
lowerVAArgOp(vaArgOp);
} else if (auto getGlobal = dyn_cast<GlobalOp>(op)) {
lowerGlobalOp(getGlobal);
} else if (auto dynamicCast = dyn_cast<DynamicCastOp>(op)) {
Expand Down Expand Up @@ -633,8 +648,9 @@ void LoweringPreparePass::runOnOperation() {

SmallVector<Operation *> opsToTransform;
op->walk([&](Operation *op) {
if (isa<CmpThreeWayOp, GlobalOp, DynamicCastOp, StdFindOp, IterEndOp,
IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(op))
if (isa<CmpThreeWayOp, VAArgOp, GlobalOp, DynamicCastOp, StdFindOp,
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(
op))
opsToTransform.push_back(op);
});

Expand Down
127 changes: 127 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepareArm64CXXABI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
//====- LoweringPrepareArm64CXXABI.cpp - Arm64 ABI specific code --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides ARM64 C++ ABI specific code that is used during LLVMIR
// lowering prepare.
//
//===----------------------------------------------------------------------===//

#include "LoweringPrepareItaniumCXXABI.h"

using cir::LoweringPrepareCXXABI;
using namespace mlir;
using namespace mlir::cir;

namespace {
class LoweringPrepareArm64CXXABI : public LoweringPrepareItaniumCXXABI {
public:
mlir::Value lowerVAArg(CIRBaseBuilderTy &builder,
mlir::cir::VAArgOp op) override;
};
} // namespace

LoweringPrepareCXXABI *LoweringPrepareCXXABI::createArm64ABI() {
return new LoweringPrepareArm64CXXABI();
}

/* static mlir::Value createGetMemberOp(CIRBaseBuilderTy &builder,
mlir::Location &loc, mlir::Value structPtr,
const char *fldName, unsigned idx) {
assert(structPtr.getType().isa<mlir::cir::PointerType>());
auto structBaseTy =
structPtr.getType().cast<mlir::cir::PointerType>().getPointee();
assert(structBaseTy.isa<mlir::cir::StructType>());
auto fldTy = structBaseTy.cast<mlir::cir::StructType>().getMembers()[idx];
auto fldPtrTy = ::mlir::cir::PointerType::get(builder.getContext(), fldTy);
return builder.create<mlir::cir::GetMemberOp>(loc, fldPtrTy, structPtr,
fldName, idx);
} */

mlir::Value LoweringPrepareArm64CXXABI::lowerVAArg(CIRBaseBuilderTy &builder,
mlir::cir::VAArgOp op) {
auto loc = op->getLoc();
auto valist = op->getOperand(0);
auto opResTy = op.getType();
bool isFloatingType =
opResTy.isa<mlir::cir::SingleType, mlir::cir::DoubleType>();
auto offsP = builder.createGetMemberOp(loc, valist,
isFloatingType ? "vr_offs" : "gr_offs",
isFloatingType ? 4 : 3);
auto offs = builder.create<mlir::cir::LoadOp>(loc, offsP);
auto boolTy = builder.getBoolTy();
auto zeroValue = builder.create<mlir::cir::ConstantOp>(
loc, offs.getType(), mlir::cir::IntAttr::get(offs.getType(), 0));
auto cmpRes = builder.create<mlir::cir::CmpOp>(loc, boolTy, CmpOpKind::ge,
offs, zeroValue);
auto curInsertionP = builder.saveInsertionPoint();
auto currentBlock = builder.getInsertionBlock();

auto maybeRegBlock = builder.createBlock(builder.getBlock()->getParent());
auto inRegBlock = builder.createBlock(builder.getBlock()->getParent());
auto onStackBlock = builder.createBlock(builder.getBlock()->getParent());

builder.restoreInsertionPoint(curInsertionP);
builder.create<mlir::cir::BrCondOp>(loc, cmpRes, onStackBlock, maybeRegBlock);
auto newEndBlock = currentBlock->splitBlock(op);

builder.setInsertionPointToEnd(onStackBlock);
auto stackP = builder.createGetMemberOp(loc, valist, "stack", 0);
auto stack = builder.create<mlir::cir::LoadOp>(loc, stackP);
auto ptrDiffTy =
mlir::cir::IntType::get(builder.getContext(), 64, /*signed=*/false);
auto eight = builder.create<mlir::cir::ConstantOp>(
loc, ptrDiffTy, mlir::cir::IntAttr::get(ptrDiffTy, 8));
auto i8Ty = IntegerType::get(builder.getContext(), 8);
auto i8PtrTy = PointerType::get(builder.getContext(), i8Ty);
auto castStack =
builder.createCast(mlir::cir::CastKind::bitcast, stack, i8PtrTy);
auto newStackAsi8Ptr = builder.create<mlir::cir::PtrStrideOp>(
loc, castStack.getType(), castStack, eight);
auto newStack = builder.createCast(mlir::cir::CastKind::bitcast,
newStackAsi8Ptr, stack.getType());
builder.createStore(loc, newStack, stackP);
builder.create<mlir::cir::BrOp>(loc, mlir::ValueRange{stack}, newEndBlock);

builder.setInsertionPointToEnd(maybeRegBlock);
auto boundaryValue = builder.create<mlir::cir::ConstantOp>(
loc, offs.getType(),
mlir::cir::IntAttr::get(offs.getType(), isFloatingType ? 16 : 8));
auto newRegsOffs = builder.create<mlir::cir::BinOp>(
loc, offs.getType(), mlir::cir::BinOpKind::Add, offs, boundaryValue);
builder.createStore(loc, newRegsOffs, offsP);
auto maybeRegCmpRes = builder.create<mlir::cir::CmpOp>(
loc, boolTy, CmpOpKind::le, newRegsOffs, zeroValue);
builder.create<mlir::cir::BrCondOp>(loc, maybeRegCmpRes, inRegBlock,
onStackBlock);

builder.setInsertionPointToEnd(inRegBlock);
auto regTopP = builder.createGetMemberOp(loc, valist,
isFloatingType ? "vr_top" : "gr_top",
isFloatingType ? 2 : 1);
auto regTop = builder.create<mlir::cir::LoadOp>(loc, regTopP);
auto castRegTop =
builder.createCast(mlir::cir::CastKind::bitcast, regTop, i8PtrTy);
auto resAsInt8P = builder.create<mlir::cir::PtrStrideOp>(
loc, castRegTop.getType(), castRegTop, offs);
auto resAsVoidP = builder.createCast(mlir::cir::CastKind::bitcast, resAsInt8P,
regTop.getType());
builder.create<mlir::cir::BrOp>(loc, mlir::ValueRange{resAsVoidP},
newEndBlock);

// generate additional instructions for end block
builder.setInsertionPoint(op);
newEndBlock->addArgument(stack.getType(), loc);
auto resP = newEndBlock->getArgument(0);
assert(resP.getType().isa<mlir::cir::PointerType>());
auto opResPTy = PointerType::get(builder.getContext(), opResTy);
auto castResP =
builder.createCast(mlir::cir::CastKind::bitcast, resP, opResPTy);
auto res = builder.create<mlir::cir::LoadOp>(loc, castResP);
return res.getResult();
}
3 changes: 3 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepareCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ namespace cir {
class LoweringPrepareCXXABI {
public:
static LoweringPrepareCXXABI *createItaniumABI();
static LoweringPrepareCXXABI *createArm64ABI();

virtual mlir::Value lowerDynamicCast(CIRBaseBuilderTy &builder,
mlir::cir::DynamicCastOp op) = 0;

virtual mlir::Value lowerVAArg(CIRBaseBuilderTy &builder,
mlir::cir::VAArgOp op) = 0;
virtual ~LoweringPrepareCXXABI() {}
};

Expand Down
27 changes: 14 additions & 13 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//====- LoweringPrepareItaniumCXXABI.h - Itanium ABI specific code --------===//
//====- LoweringPrepareItaniumCXXABI.cpp - Itanium ABI specific code
//--------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -11,26 +12,20 @@
//
//===----------------------------------------------------------------------===//

#include "LoweringPrepareItaniumCXXABI.h"
#include "../IR/MissingFeatures.h"
#include "LoweringPrepareCXXABI.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"

using namespace cir;

namespace {

class LoweringPrepareItaniumCXXABI : public LoweringPrepareCXXABI {
public:
mlir::Value lowerDynamicCast(CIRBaseBuilderTy &builder,
mlir::cir::DynamicCastOp op) override;
};

} // namespace
using cir::CIRBaseBuilderTy;
using cir::LoweringPrepareCXXABI;
using cir::MissingFeatures;
using namespace mlir;
using namespace mlir::cir;

LoweringPrepareCXXABI *LoweringPrepareCXXABI::createItaniumABI() {
return new LoweringPrepareItaniumCXXABI();
Expand Down Expand Up @@ -115,3 +110,9 @@ LoweringPrepareItaniumCXXABI::lowerDynamicCast(CIRBaseBuilderTy &builder,
})
.getResult();
}

mlir::Value LoweringPrepareItaniumCXXABI::lowerVAArg(CIRBaseBuilderTy &builder,
mlir::cir::VAArgOp op) {
// TODO: implement va_arg for more generic Itanium ABI.
return mlir::Value();
}
24 changes: 24 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//====- LoweringPrepareItaniumCXXABI.h - Itanium ABI specific code --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides Itanium C++ ABI specific code that is used during LLVMIR
// lowering prepare.
//
//===----------------------------------------------------------------------===//

#include "LoweringPrepareCXXABI.h"

using cir::CIRBaseBuilderTy;

class LoweringPrepareItaniumCXXABI : public cir::LoweringPrepareCXXABI {
public:
mlir::Value lowerDynamicCast(CIRBaseBuilderTy &builder,
mlir::cir::DynamicCastOp op) override;
mlir::Value lowerVAArg(CIRBaseBuilderTy &builder,
mlir::cir::VAArgOp op) override;
};
71 changes: 71 additions & 0 deletions clang/test/CIR/CodeGen/var-arg-float.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-lowering-prepare %s -o %t.cir 2>&1 | FileCheck %s -check-prefix=BEFORE
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-lowering-prepare %s -o %t.cir 2>&1 | FileCheck %s -check-prefix=AFTER

#include <stdarg.h>

double f1(int n, ...) {
va_list valist;
va_start(valist, n);
double res = va_arg(valist, double);
va_end(valist);
return res;
}

// BEFORE: !ty_22__va_list22 = !cir.struct<struct "__va_list" {!cir.ptr<!cir.void>, !cir.ptr<!cir.void>, !cir.ptr<!cir.void>, !cir.int<s, 32>, !cir.int<s, 32>}
// BEFORE: cir.func @f1(%arg0: !s32i, ...) -> !cir.double
// BEFORE: [[RETP:%.*]] = cir.alloca !cir.double, cir.ptr <!cir.double>, ["__retval"]
// BEFORE: [[RESP:%.*]] = cir.alloca !cir.double, cir.ptr <!cir.double>, ["res", init]
// BEFORE: cir.va.start [[VARLIST:%.*]] : !cir.ptr<!ty_22__va_list22>
// BEFORE: [[TMP0:%.*]] = cir.va.arg [[VARLIST]] : (!cir.ptr<!ty_22__va_list22>) -> !cir.double
// BEFORE: cir.store [[TMP0]], [[RESP]] : !cir.double, cir.ptr <!cir.double>
// BEFORE: cir.va.end [[VARLIST]] : !cir.ptr<!ty_22__va_list22>
// BEFORE: [[RES:%.*]] = cir.load [[RESP]] : cir.ptr <!cir.double>, !cir.double
// BEFORE: cir.store [[RES]], [[RETP]] : !cir.double, cir.ptr <!cir.double>
// BEFORE: [[RETV:%.*]] = cir.load [[RETP]] : cir.ptr <!cir.double>, !cir.double
// BEFORE: cir.return [[RETV]] : !cir.double

// AFTER: !ty_22__va_list22 = !cir.struct<struct "__va_list" {!cir.ptr<!cir.void>, !cir.ptr<!cir.void>, !cir.ptr<!cir.void>, !cir.int<s, 32>, !cir.int<s, 32>}
// AFTER: cir.func @f1(%arg0: !s32i, ...) -> !cir.double
// AFTER: [[RETP:%.*]] = cir.alloca !cir.double, cir.ptr <!cir.double>, ["__retval"]
// AFTER: [[RESP:%.*]] = cir.alloca !cir.double, cir.ptr <!cir.double>, ["res", init]
// AFTER: cir.va.start [[VARLIST:%.*]] : !cir.ptr<!ty_22__va_list22>
// AFTER: [[VR_OFFS_P:%.*]] = cir.get_member [[VARLIST]][4] {name = "vr_offs"} : !cir.ptr<!ty_22__va_list22> -> !cir.ptr<!s32i>
// AFTER: [[VR_OFFS:%.*]] = cir.load [[VR_OFFS_P]] : cir.ptr <!s32i>, !s32i
// AFTER: [[ZERO:%.*]] = cir.const(#cir.int<0> : !s32i) : !s32i
// AFTER: [[CMP0:%.*]] = cir.cmp(ge, [[VR_OFFS]], [[ZERO]]) : !s32i, !cir.bool
// AFTER-NEXT: cir.brcond [[CMP0]] [[BB_ON_STACK:\^bb.*]], [[BB_MAY_REG:\^bb.*]]

// AFTER-NEXT: [[BB_END:\^bb.*]]([[BLK_ARG:%.*]]: !cir.ptr<!void>): // 2 preds: [[BB_IN_REG:\^bb.*]], [[BB_ON_STACK]]
// AFTER-NEXT: [[TMP0:%.*]] = cir.cast(bitcast, [[BLK_ARG]] : !cir.ptr<!void>), !cir.ptr<!cir.double>
// AFTER-NEXT: [[TMP1:%.*]] = cir.load [[TMP0]] : cir.ptr <!cir.double>, !cir.double
// AFTER: cir.store [[TMP1]], [[RESP]] : !cir.double, cir.ptr <!cir.double>
// AFTER: cir.va.end [[VARLIST]] : !cir.ptr<!ty_22__va_list22>
// AFTER: [[RES:%.*]] = cir.load [[RESP]] : cir.ptr <!cir.double>, !cir.double
// AFTER: cir.store [[RES]], [[RETP]] : !cir.double, cir.ptr <!cir.double>
// AFTER: [[RETV:%.*]] = cir.load [[RETP]] : cir.ptr <!cir.double>, !cir.double
// AFTER: cir.return [[RETV]] : !cir.double

// AFTER: [[BB_MAY_REG]]: // pred: [[BB_BEGIN:\^bb.*]]
// AFTER-NEXT: [[SIXTEEN:%.*]] = cir.const(#cir.int<16> : !s32i) : !s32i
// AFTER-NEXT: [[NEW_REG_OFFS:%.*]] = cir.binop(add, [[VR_OFFS]], [[SIXTEEN]]) : !s32i
// AFTER-NEXT: cir.store [[NEW_REG_OFFS]], [[VR_OFFS_P]] : !s32i, cir.ptr <!s32i>
// AFTER-NEXT: [[CMP1:%.*]] = cir.cmp(le, [[NEW_REG_OFFS]], [[ZERO]]) : !s32i, !cir.bool
// AFTER-NEXT: cir.brcond [[CMP1]] [[BB_IN_REG]], [[BB_ON_STACK]]

// AFTER: [[BB_IN_REG]]: // pred: [[BB_MAY_REG]]
// AFTER-NEXT: [[VR_TOP_P:%.*]] = cir.get_member [[VARLIST]][2] {name = "vr_top"} : !cir.ptr<!ty_22__va_list22> -> !cir.ptr<!cir.ptr<!void>>
// AFTER-NEXT: [[VR_TOP:%.*]] = cir.load [[VR_TOP_P]] : cir.ptr <!cir.ptr<!void>>, !cir.ptr<!void>
// AFTER-NEXT: [[TMP2:%.*]] = cir.cast(bitcast, [[VR_TOP]] : !cir.ptr<!void>), !cir.ptr<i8>
// AFTER-NEXT: [[TMP3:%.*]] = cir.ptr_stride([[TMP2]] : !cir.ptr<i8>, [[VR_OFFS]] : !s32i), !cir.ptr<i8>
// AFTER-NEXT: [[IN_REG_OUTPUT:%.*]] = cir.cast(bitcast, [[TMP3]] : !cir.ptr<i8>), !cir.ptr<!void>
// AFTER-NEXT: cir.br [[BB_END]]([[IN_REG_OUTPUT]] : !cir.ptr<!void>)

// AFTER: [[BB_ON_STACK]]: // 2 preds: [[BB_BEGIN]], [[BB_MAY_REG]]
// AFTER-NEXT: [[STACK_P:%.*]] = cir.get_member [[VARLIST]][0] {name = "stack"} : !cir.ptr<!ty_22__va_list22> -> !cir.ptr<!cir.ptr<!void>>
// AFTER-NEXT: [[STACK_V:%.*]] = cir.load [[STACK_P]] : cir.ptr <!cir.ptr<!void>>, !cir.ptr<!void>
// AFTER-NEXT: [[EIGHT_IN_PTR_ARITH:%.*]] = cir.const(#cir.int<8> : !u64i) : !u64i
// AFTER-NEXT: [[TMP4:%.*]] = cir.cast(bitcast, [[STACK_V]] : !cir.ptr<!void>), !cir.ptr<i8>
// AFTER-NEXT: [[TMP5:%.*]] = cir.ptr_stride([[TMP4]] : !cir.ptr<i8>, [[EIGHT_IN_PTR_ARITH]] : !u64i), !cir.ptr<i8>
// AFTER-NEXT: [[NEW_STACK_V:%.*]] = cir.cast(bitcast, [[TMP5]] : !cir.ptr<i8>), !cir.ptr<!void>
// AFTER-NEXT: cir.store [[NEW_STACK_V]], [[STACK_P]] : !cir.ptr<!void>, cir.ptr <!cir.ptr<!void>>
// AFTER-NEXT: cir.br [[BB_END]]([[STACK_V]] : !cir.ptr<!void>)
Loading

0 comments on commit a7b095d

Please sign in to comment.