Skip to content

Commit

Permalink
no pattern rw
Browse files Browse the repository at this point in the history
  • Loading branch information
gitoleg committed Oct 23, 2024
1 parent eab0441 commit 11547ee
Showing 1 changed file with 114 additions and 15 deletions.
129 changes: 114 additions & 15 deletions clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,100 @@ FuncType lowerFuncType(LowerModule &mod, FuncType ftyp) {
// Rewrite Patterns
//===----------------------------------------------------------------------===//

struct CCLowering {

CCLowering(LowerModule& mod,
mlir::PatternRewriter& rw,
mlir::TypeConverter& tyConv)
: lowerModule(mod)
, rewriter(rw)
, typeConverter(tyConv)
{}

void lower(Operation* op) {

rewriter.setInsertionPoint(op);
if (auto fun = dyn_cast<FuncOp>(op))
lowerFuncOp(fun);
else if (auto al = dyn_cast<AllocaOp>(op))
lowerAllocaOp(al);
else if (auto glob = dyn_cast<GetGlobalOp>(op))
lowerGetGlobalOp(glob);
else if (auto call = dyn_cast<CallOp>(op))
lowerCallOp(call);
}

private:

void lowerFuncOp(FuncOp op) {
// const auto module = op->getParentOfType<mlir::ModuleOp>();
// auto calls = op.getSymbolUses(module);
// if (calls.has_value()) {
// for (auto call : calls.value()) {
// // FIXME(cir): Function pointers are ignored except the next cases
// if (!isa<GetGlobalOp, CallOp>(call.getUser())) {
// cir_cconv_assert_or_abort(!::cir::MissingFeatures::ABIFuncPtr(),
// "NYI");
// continue;
// }

// if (auto callOp = dyn_cast_or_null<CallOp>(call.getUser()))
// lowerModule.rewriteFunctionCall(callOp, op);


// }
// }
lowerModule.rewriteFunctionDefinition(op);
}

void lowerAllocaOp(AllocaOp op) {
auto eltTy = typeConverter.convertType(op.getAllocaType());
if (op.getAllocaType() != eltTy)
rewriter.replaceOpWithNewOp<AllocaOp>(
op, typeConverter.convertType(op.getResult().getType()),
eltTy, op.getName(), op.getAlignmentAttr(), op.getDynAllocSize());
}

void lowerGetGlobalOp(GetGlobalOp op) {
auto resTy = op.getResult().getType();
if (auto ptrTy = dyn_cast<PointerType>(resTy)) {
if (isa<FuncType>(ptrTy.getPointee())) {
rewriter.replaceOpWithNewOp<GetGlobalOp>(
op, typeConverter.convertType(resTy), op.getName());
}
}
}

void lowerCallOp(CallOp op) {
auto mod = op->getParentOfType<ModuleOp>();
if (auto callee = op.getCallee()) {
if (auto fun = findFun(mod, *callee)) {
lowerModule.rewriteFunctionCall(op, fun);
}
}

}

FuncOp findFun(mlir::ModuleOp mod, llvm::StringRef name) {
FuncOp fun;
mod->walk([&](FuncOp f) {
if (f.getName() == name) {
fun = f;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
return fun;
}

private:

LowerModule &lowerModule;
mlir::PatternRewriter& rewriter;
mlir::TypeConverter& typeConverter;
};


class CCFuncOpLowering : public mlir::OpRewritePattern<FuncOp> {
using OpRewritePattern<FuncOp>::OpRewritePattern;
LowerModule &lowerModule;
Expand Down Expand Up @@ -170,24 +264,29 @@ void CallConvLoweringPass::runOnOperation() {
mlir::TypeConverter converter;
initTypeConverter(converter, *lowerModule.get());

// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateCallConvLoweringPassPatterns(converter, *lowerModule.get(), patterns);

// Collect operations to be considered by the pass.
SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<AllocaOp, FuncOp, GetGlobalOp>(op))
ops.push_back(op);
CCLowering low(*lowerModule.get(), rewriter, converter);
getOperation()->walk([&](Operation* op) {
low.lower(op);
});

// Configure rewrite to ignore new ops created during the pass.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
// // Collect rewrite patterns.
// RewritePatternSet patterns(&getContext());
// populateCallConvLoweringPassPatterns(converter, *lowerModule.get(), patterns);

// // Collect operations to be considered by the pass.
// SmallVector<Operation *, 16> ops;
// getOperation()->walk([&](Operation *op) {
// if (isa<AllocaOp, FuncOp, GetGlobalOp>(op))
// ops.push_back(op);
// });

// // Configure rewrite to ignore new ops created during the pass.
// GreedyRewriteConfig config;
// config.strictMode = GreedyRewriteStrictness::ExistingOps;

// Apply patterns.
if (failed(applyOpPatternsAndFold(ops, std::move(patterns), config)))
signalPassFailure();
// // Apply patterns.
// if (failed(applyOpPatternsAndFold(ops, std::move(patterns), config)))
// signalPassFailure();
}

} // namespace cir
Expand Down

0 comments on commit 11547ee

Please sign in to comment.