Skip to content

Commit

Permalink
[CIR][CodeGen][LowerToLLVM] Set calling convention for call ops (llvm…
Browse files Browse the repository at this point in the history
…#836)

This PR implements the CIRGen and Lowering part of calling convention
attribute of `cir.call`-like operations. Here we have **4 kinds of
operations**: (direct or indirect) x (`call` or `try_call`).

According to our need and feasibility of constructing a test case, this
PR includes:

* For CIRGen, only direct `call`. Until now, the only extra calling
conventions are SPIR ones, which cannot be set from source code manually
using attributes. Meanwhile, OpenCL C *does not allow* function pointers
or exceptions, therefore the only case remaining is direct call.
* For Lowering, direct and indirect `call`, but not any `try_call`.
Although it's possible to write all 4 kinds of calls with calling
convention in ClangIR assembly, exceptions is quite hard to write and
read. I prefer source-code-level test for it when it's available in the
future. For example, possibly C++ `thiscall` with exceptions.
* Extra: the verification of calling convention consistency for direct
`call` and direct `try_call`.

All unsupported cases are guarded by assertions or MLIR diags.
  • Loading branch information
seven-mile authored and smeenai committed Oct 9, 2024
1 parent a9a16d4 commit ead58dc
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 15 deletions.
19 changes: 12 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
mlir::cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
mlir::cir::FuncOp directFuncOp,
SmallVectorImpl<mlir::Value> &CIRCallArgs,
mlir::Operation *InvokeDest,
mlir::Operation *InvokeDest, mlir::cir::CallingConv callingConv,
mlir::cir::ExtraFuncAttributesAttr extraFnAttrs) {
auto &builder = CGF.getBuilder();

Expand All @@ -468,6 +468,8 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
}

mlir::cir::CallOp tryCallOp;
// TODO(cir): Set calling convention for `cir.try_call`.
assert(callingConv == mlir::cir::CallingConv::C && "NYI");
if (indirectFuncTy) {
tryCallOp = builder.createIndirectTryCallOp(callLoc, indirectFuncVal,
indirectFuncTy, CIRCallArgs);
Expand All @@ -484,12 +486,15 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
}

assert(builder.getInsertionBlock() && "expected valid basic block");
if (indirectFuncTy)
if (indirectFuncTy) {
// TODO(cir): Set calling convention for indirect calls.
assert(callingConv == mlir::cir::CallingConv::C && "NYI");
return builder.createIndirectCallOp(
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs,
mlir::cir::CallingConv::C, extraFnAttrs);
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs,
mlir::cir::CallingConv::C, extraFnAttrs);
}
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs, callingConv,
extraFnAttrs);
}

RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
Expand Down Expand Up @@ -765,9 +770,9 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
auto extraFnAttrs = mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), Attrs.getDictionary(builder.getContext()));

mlir::cir::CIRCallOpInterface callLikeOp =
buildCallLikeOp(*this, callLoc, indirectFuncTy, indirectFuncVal,
directFuncOp, CIRCallArgs, InvokeDest, extraFnAttrs);
mlir::cir::CIRCallOpInterface callLikeOp = buildCallLikeOp(
*this, callLoc, indirectFuncTy, indirectFuncVal, directFuncOp,
CIRCallArgs, InvokeDest, callingConv, extraFnAttrs);

if (E)
callLikeOp->setAttr(
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2704,6 +2704,12 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
<< op->getOperand(i).getType() << " for operand number " << i;
}

// Calling convention must match.
if (callIf.getCallingConv() != fn.getCallingConv())
return op->emitOpError("calling convention mismatch: expected ")
<< stringifyCallingConv(fn.getCallingConv()) << ", but provided "
<< stringifyCallingConv(callIf.getCallingConv());

// Void function must not return any results.
if (fnType.isVoid() && op->getNumResults() != 0)
return op->emitOpError("callee returns void but call has results");
Expand Down
29 changes: 21 additions & 8 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,18 +875,24 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
mlir::Block *landingPadBlock = nullptr) {
llvm::SmallVector<mlir::Type, 8> llvmResults;
auto cirResults = op->getResultTypes();
auto callIf = cast<mlir::cir::CIRCallOpInterface>(op);

if (converter->convertTypes(cirResults, llvmResults).failed())
return mlir::failure();

auto cconv = convertCallingConv(callIf.getCallingConv());

if (calleeAttr) { // direct call
if (landingPadBlock)
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
if (landingPadBlock) {
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
op, llvmResults, calleeAttr, callOperands, continueBlock,
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
else
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, llvmResults,
calleeAttr, callOperands);
newOp.setCConv(cconv);
} else {
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op, llvmResults, calleeAttr, callOperands);
newOp.setCConv(cconv);
}
} else { // indirect call
assert(op->getOperands().size() &&
"operands list must no be empty for the indirect call");
Expand All @@ -899,14 +905,17 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
if (landingPadBlock) {
auto llvmFnTy =
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp));
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
op, llvmFnTy, mlir::FlatSymbolRefAttr{}, callOperands, continueBlock,
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
newOp.setCConv(cconv);
} else {
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op,
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp)),
callOperands);
newOp.setCConv(cconv);
}
}
return mlir::success();
}
Expand All @@ -932,6 +941,10 @@ class CIRTryCallLowering
mlir::LogicalResult
matchAndRewrite(mlir::cir::TryCallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (op.getCallingConv() != mlir::cir::CallingConv::C) {
return op.emitError(
"non-C calling convention is not implemented for try_call");
}
return rewriteToCallOrInvoke(
op.getOperation(), adaptor.getOperands(), rewriter, getTypeConverter(),
op.getCalleeAttr(), op.getCont(), op.getLandingPad());
Expand Down
4 changes: 4 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/spir-calling-conv.cl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ kernel void bar(global int *A);
// LLVM-DAG: define{{.*}} spir_kernel void @foo(
kernel void foo(global int *A) {
int id = get_dummy_id(0);
// CIR: %{{[0-9]+}} = cir.call @get_dummy_id(%2) : (!s32i) -> !s32i cc(spir_function)
// LLVM: %{{[a-z0-9_]+}} = call spir_func i32 @get_dummy_id(
A[id] = id;
bar(A);
// CIR: cir.call @bar(%8) : (!cir.ptr<!s32i, addrspace(offload_global)>) -> () cc(spir_kernel)
// LLVM: call spir_kernel void @bar(ptr addrspace(1)
}
15 changes: 15 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,22 @@ module {
!s32i = !cir.int<s, 32>

module {
cir.func @subroutine() cc(spir_function) {
cir.return
}

cir.func @call_conv_match() {
// expected-error@+1 {{'cir.call' op calling convention mismatch: expected spir_function, but provided spir_kernel}}
cir.call @subroutine(): () -> !cir.void cc(spir_kernel)
cir.return
}
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @test_bitcast_addrspace() {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["tmp"] {alignment = 4 : i64}
// expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}}
Expand Down
22 changes: 22 additions & 0 deletions clang/test/CIR/Lowering/call-op-call-conv.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: cir-translate -cir-to-llvmir %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM

!s32i = !cir.int<s, 32>
!fnptr = !cir.ptr<!cir.func<!s32i (!s32i)>>

module {
cir.func private @my_add(%a: !s32i, %b: !s32i) -> !s32i cc(spir_function)

cir.func @ind(%fnptr: !fnptr, %a : !s32i) {
%1 = cir.call %fnptr(%a) : (!fnptr, !s32i) -> !s32i cc(spir_kernel)
// LLVM: %{{[0-9]+}} = call spir_kernel i32 %{{[0-9]+}}(i32 %{{[0-9]+}})

%2 = cir.call %fnptr(%a) : (!fnptr, !s32i) -> !s32i cc(spir_function)
// LLVM: %{{[0-9]+}} = call spir_func i32 %{{[0-9]+}}(i32 %{{[0-9]+}})

%3 = cir.call @my_add(%1, %2) : (!s32i, !s32i) -> !s32i cc(spir_function)
// LLVM: %{{[0-9]+}} = call spir_func i32 @my_add(i32 %{{[0-9]+}}, i32 %{{[0-9]+}})

cir.return
}
}

0 comments on commit ead58dc

Please sign in to comment.