Skip to content

Commit

Permalink
[CIR][ABI][Lowering] Supports function pointers in the calling conven…
Browse files Browse the repository at this point in the history
…tion lowering pass (#1003)

This PR adds initial function pointers support for the calling
convention lowering pass. This is a suggestion, so any other ideas are
welcome.

Several ideas was described in the #995 and basically what I'm trying to
do is to generate a clean CIR code without additional `bitcast`
operations for function pointers and without mix of lowered and initial
function types.

#### Problem
Looks like we can not just lower the function type and cast the value
since too many operations are involved. For instance, for the
next simple code: 
```
typedef struct {
  int a;
} S;

typedef int (*myfptr)(S);

int foo(S s) { return 42 + s.a; }

void bar() {
  myfptr a = foo;
}
```
we get the next CIR for the function `bar` , before the calling
convention lowering pass:
```
cir.func no_proto  @bar() extra(#fn_attr) {
    %0 = cir.alloca !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>, ["a", init]
    %1 = cir.get_global @foo : !cir.ptr<!cir.func<!s32i (!ty_S)>> 
    cir.store %1, %0 : !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>
    cir.return 
  } 
```
As one can see, first three operations depend on the function type. Once
`foo` is lowered, we need to fix `GetGlobalOp`:
otherwise the code will fail with the verification error since actual
`foo` type (lowered) differs from the one currently expected by the
`GetGlobalOp`.
First idea would just rewrite only the `GetGlobalOp` and insert a
bitcast after, so both `AllocaOp` and `StoreOp` would work witth proper
types.

Once the code will be more complex, we will need to take care about
possible use cases, e.g. if we use arrays, we will need to track array
accesses to it as well in order to insert this bitcast every time the
array element is needed.

One workaround I can think of: we fix the `GetGlobalOp` type and cast
from the lowered type to the initial, and cast back before the actual
call happens - but it doesn't sound as a good and clean approach (from
my point of view, of course).

So I suggest to use type converter and rewrite any operation that may
deal with function pointers and make sure it has a proper type, and we
don't have any unlowered function type in the program after the calling
convention lowering pass.

#### Implementation
I added lowering for `AllocaOp`, `GetGlobalOp`, and split the lowering
for `FuncOp` (former `CallConvLoweringPattern`) and lower `CallOp`
separately.
Frankly speaking, I tried to implement a pattern for each operation, but
for some reasons the tests are not passed for
windows and macOs in this case - something weird happens inside
`applyPatternsAndFold` function. I suspect it's due to two different
rewriters used - one in the `LoweringModule` and one in the mentioned
function.
So I decided to follow the same approach as it's done for the
`LoweringPrepare` pass and don't involve this complex rewriting
framework.

Next I will add a type converter for the struct type, patterns for
`ConstantOp` (for const arrays and `GlobalViewAttr`)
In the end of the day we'll have (at least I hope so) a clean CIR code
without any bitcasts for function pointers.

cc @sitio-couto  @bcardosolopes
  • Loading branch information
gitoleg authored and lanza committed Nov 4, 2024
1 parent 673338c commit 985e849
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 50 deletions.
126 changes: 76 additions & 50 deletions clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TargetLowering/LowerModule.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/MissingFeatures.h"
Expand All @@ -23,50 +23,93 @@
namespace mlir {
namespace cir {

//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//

struct CallConvLoweringPattern : public OpRewritePattern<FuncOp> {
using OpRewritePattern<FuncOp>::OpRewritePattern;
FuncType getFuncPointerTy(mlir::Type typ) {
if (auto ptr = dyn_cast<PointerType>(typ))
return dyn_cast<FuncType>(ptr.getPointee());
return {};
}

LogicalResult matchAndRewrite(FuncOp op,
PatternRewriter &rewriter) const final {
llvm::TimeTraceScope scope("Call Conv Lowering Pass", op.getSymName().str());
bool isFuncPointerTy(mlir::Type typ) { return (bool)getFuncPointerTy(typ); }

const auto module = op->getParentOfType<mlir::ModuleOp>();
struct CallConvLowering {

auto modOp = op->getParentOfType<ModuleOp>();
std::unique_ptr<LowerModule> lowerModule =
createLowerModule(modOp, rewriter);
CallConvLowering(ModuleOp module)
: rewriter(module.getContext()),
lowerModule(createLowerModule(module, rewriter)) {}

// Rewrite function calls before definitions. This should be done before
// lowering the definition.
void lower(FuncOp op) {
// Fail the pass on unimplemented function users
const auto module = op->getParentOfType<mlir::ModuleOp>();
auto calls = op.getSymbolUses(module);
if (calls.has_value()) {
for (auto call : calls.value()) {
// FIXME(cir): Function pointers are ignored.
if (isa<GetGlobalOp>(call.getUser())) {
if (auto g = dyn_cast<GetGlobalOp>(call.getUser()))
rewriteGetGlobalOp(g);
else if (auto c = dyn_cast<CallOp>(call.getUser()))
lowerDirectCallOp(c, op);
else {
cir_cconv_assert_or_abort(!::cir::MissingFeatures::ABIFuncPtr(),
"NYI");
continue;
}

auto callOp = dyn_cast_or_null<CallOp>(call.getUser());
if (!callOp)
cir_cconv_unreachable("NYI empty callOp");
if (lowerModule->rewriteFunctionCall(callOp, op).failed())
return failure();
}
}

// TODO(cir): Instead of re-emmiting every load and store, bitcast arguments
// and return values to their ABI-specific counterparts when possible.
if (lowerModule->rewriteFunctionDefinition(op).failed())
return failure();
op.walk([&](CallOp c) {
if (c.isIndirect())
lowerIndirectCallOp(c);
});

return success();
lowerModule->rewriteFunctionDefinition(op);
}

private:
FuncType convert(FuncType t) {
auto &typs = lowerModule->getTypes();
return typs.getFunctionType(typs.arrangeFreeFunctionType(t));
}

mlir::Type convert(mlir::Type t) {
if (auto fTy = getFuncPointerTy(t))
return PointerType::get(rewriter.getContext(), convert(fTy));
return t;
}

void bitcast(Value src, Type newTy) {
if (src.getType() != newTy) {
auto cast =
rewriter.create<CastOp>(src.getLoc(), newTy, CastKind::bitcast, src);
rewriter.replaceAllUsesExcept(src, cast, cast);
}
}

void rewriteGetGlobalOp(GetGlobalOp op) {
auto resTy = op.getResult().getType();
if (isFuncPointerTy(resTy)) {
rewriter.setInsertionPoint(op);
auto newOp = rewriter.replaceOpWithNewOp<GetGlobalOp>(op, convert(resTy),
op.getName());
rewriter.setInsertionPointAfter(newOp);
bitcast(newOp, resTy);
}
}

void lowerDirectCallOp(CallOp op, FuncOp callee) {
lowerModule->rewriteFunctionCall(op, callee);
}

void lowerIndirectCallOp(CallOp op) {
cir_cconv_assert(op.isIndirect());

rewriter.setInsertionPoint(op);
auto typ = op.getIndirectCall().getType();
if (isFuncPointerTy(typ)) {
cir_cconv_unreachable("Indirect calls NYI");
}
}

private:
mlir::PatternRewriter rewriter;
std::unique_ptr<LowerModule> lowerModule;
};

//===----------------------------------------------------------------------===//
Expand All @@ -81,27 +124,10 @@ struct CallConvLoweringPass
StringRef getArgument() const override { return "cir-call-conv-lowering"; };
};

void populateCallConvLoweringPassPatterns(RewritePatternSet &patterns) {
patterns.add<CallConvLoweringPattern>(patterns.getContext());
}

void CallConvLoweringPass::runOnOperation() {

// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateCallConvLoweringPassPatterns(patterns);

// Collect operations to be considered by the pass.
SmallVector<Operation *, 16> ops;
getOperation()->walk([&](FuncOp op) { ops.push_back(op); });

// Configure rewrite to ignore new ops created during the pass.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;

// Apply patterns.
if (failed(applyOpPatternsAndFold(ops, std::move(patterns), config)))
signalPassFailure();
auto module = dyn_cast<ModuleOp>(getOperation());
CallConvLowering cc(module);
module.walk([&](FuncOp op) { cc.lower(op); });
}

} // namespace cir
Expand Down
18 changes: 18 additions & 0 deletions clang/test/CIR/CallConvLowering/x86_64/fptrs.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir-flat -fclangir-call-conv-lowering %s -o - | FileCheck %s

typedef struct {
int a;
} S;

typedef int (*myfptr)(S);

int foo(S s) { return 42 + s.a; }

// CHECK: cir.func {{.*@bar}}
// CHECK: %[[#V0:]] = cir.alloca !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>, ["a", init]
// CHECK: %[[#V1:]] = cir.get_global @foo : !cir.ptr<!cir.func<!s32i (!s32i)>>
// CHECK: %[[#V2:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.func<!s32i (!s32i)>>), !cir.ptr<!cir.func<!s32i (!ty_S)>>
// CHECK: cir.store %[[#V2]], %[[#V0]] : !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>
void bar() {
myfptr a = foo;
}

0 comments on commit 985e849

Please sign in to comment.