Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][TargetLowering][NFC] Refactor LowerModule to a unified global state #759

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- LowerModuleRegistry.h - LowerModule singleton registry ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines the registry of LowerModule so that it can be easily
// accessed from other libraries.
//
//===----------------------------------------------------------------------===//

#ifndef CLANG_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_LOWERMODULEREGISTRY_H
#define CLANG_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_LOWERMODULEREGISTRY_H

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace cir {

class LowerModule;

/// Registry for the LowerModule, enabling easy access to the LowerModule from
/// various libraries.
class LowerModuleRegistry {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the naming here. Advice welcome.

std::unique_ptr<LowerModule> lowerModule;
std::optional<PatternRewriter> rewriter;

public:
/// Initialize the LowerModuleRegistry with the given module and an internal
/// rewriter.
void initializeWithModule(ModuleOp module);

/// Check if the LowerModuleRegistry has been initialized.
bool isInitialized() { return lowerModule != nullptr; }

/// Get the reference to already-initialized LowerModule.
LowerModule &get() {
assert(isInitialized() && "LowerModuleRegistry not initialized");
return *lowerModule;
}

/// Get the LowerModuleRegistry singleton.
static LowerModuleRegistry &instance();
};

} // namespace cir
} // namespace mlir

#endif // CLANG_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_LOWERMODULEREGISTRY_H
17 changes: 12 additions & 5 deletions clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Transforms/TargetLowering/LowerModuleRegistry.h"

#define GEN_PASS_DEF_CALLCONVLOWERING
#include "clang/CIR/Dialect/Passes.h.inc"
Expand All @@ -36,23 +37,22 @@ struct CallConvLoweringPattern : public OpRewritePattern<FuncOp> {
return op.emitError("function has no AST information");

auto modOp = op->getParentOfType<ModuleOp>();
std::unique_ptr<LowerModule> lowerModule =
createLowerModule(modOp, rewriter);
LowerModule &lowerModule = LowerModuleRegistry::instance().get();

// Rewrite function calls before definitions. This should be done before
// lowering the definition.
auto calls = op.getSymbolUses(module);
if (calls.has_value()) {
for (auto call : calls.value()) {
auto callOp = cast<CallOp>(call.getUser());
if (lowerModule->rewriteFunctionCall(callOp, op).failed())
if (lowerModule.rewriteFunctionCall(callOp, op).failed())
return failure();
}
}

// TODO(cir): Instead of re-emmiting every load and store, bitcast arguments
// and return values to their ABI-specific counterparts when possible.
if (lowerModule->rewriteFunctionDefinition(op).failed())
if (lowerModule.rewriteFunctionDefinition(op).failed())
return failure();

return success();
Expand All @@ -76,14 +76,21 @@ void populateCallConvLoweringPassPatterns(RewritePatternSet &patterns) {
}

void CallConvLoweringPass::runOnOperation() {
auto module = cast<ModuleOp>(getOperation());

// Initialize LowerModule here to avoid create it for each function.
if (auto &registry = LowerModuleRegistry::instance();
!registry.isInitialized()) {
registry.initializeWithModule(module);
}

// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateCallConvLoweringPassPatterns(patterns);

// Collect operations to be considered by the pass.
SmallVector<Operation *, 16> ops;
getOperation()->walk([&](FuncOp op) { ops.push_back(op); });
module->walk([&](FuncOp op) { ops.push_back(op); });

// Configure rewrite to ignore new ops created during the pass.
GreedyRewriteConfig config;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_clang_library(TargetLowering
LowerCall.cpp
LowerFunction.cpp
LowerModule.cpp
LowerModuleRegistry.cpp
LowerTypes.cpp
RecordLayoutBuilder.cpp
TargetInfo.cpp
Expand Down
27 changes: 0 additions & 27 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
//
//===----------------------------------------------------------------------===//

// FIXME(cir): This header file is not exposed to the public API, but can be
// reused by CIR ABI lowering since it holds target-specific information.
#include "../../../../Basic/Targets.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/TargetOptions.h"

Expand Down Expand Up @@ -221,29 +218,5 @@ LogicalResult LowerModule::rewriteFunctionCall(CallOp callOp, FuncOp funcOp) {
return success();
}

// TODO: not to create it every time
std::unique_ptr<LowerModule> createLowerModule(ModuleOp module,
PatternRewriter &rewriter) {
// Fetch the LLVM data layout string.
auto dataLayoutStr = cast<StringAttr>(
module->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));

// Fetch target information.
llvm::Triple triple(
cast<StringAttr>(module->getAttr("cir.triple")).getValue());
clang::TargetOptions targetOptions;
targetOptions.Triple = triple.str();
auto targetInfo = clang::targets::AllocateTarget(triple, targetOptions);

// FIXME(cir): This just uses the default language options. We need to account
// for custom options.
// Create context.
assert(!::cir::MissingFeatures::langOpts());
clang::LangOptions langOpts;

return std::make_unique<LowerModule>(langOpts, module, dataLayoutStr,
std::move(targetInfo), rewriter);
}

} // namespace cir
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ class LowerModule {
LogicalResult rewriteFunctionCall(CallOp callOp, FuncOp funcOp);
};

std::unique_ptr<LowerModule> createLowerModule(ModuleOp module,
PatternRewriter &rewriter);

} // namespace cir
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- LowerModuleRegistry.cpp - LowerModule singleton registry -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the registry of LowerModule so that it can be easily
// accessed from other libraries.
//
//===----------------------------------------------------------------------===//

// FIXME(cir): This header file is not exposed to the public API, but can be
// reused by CIR ABI lowering since it holds target-specific information.
#include "clang/CIR/Dialect/Transforms/TargetLowering/LowerModuleRegistry.h"
#include "../../../../Basic/Targets.h"
#include "LowerModule.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

namespace mlir {
namespace cir {

void LowerModuleRegistry::initializeWithModule(ModuleOp module) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Migrated from the removed function createLowerModule.

assert(!isInitialized() && "LowerModuleRegistry already initialized");
// Create a new rewriter.
rewriter.emplace(module->getContext());
// Fetch the LLVM data layout string.
auto dataLayoutStr = cast<StringAttr>(
module->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));

// Fetch target information.
llvm::Triple triple(
cast<StringAttr>(module->getAttr("cir.triple")).getValue());
clang::TargetOptions targetOptions;
targetOptions.Triple = triple.str();
auto targetInfo = clang::targets::AllocateTarget(triple, targetOptions);

// FIXME(cir): This just uses the default language options. We need to account
// for custom options.
// Create context.
assert(!::cir::MissingFeatures::langOpts());
clang::LangOptions langOpts;

lowerModule = std::make_unique<LowerModule>(langOpts, module, dataLayoutStr,
std::move(targetInfo), *rewriter);
}

LowerModuleRegistry &LowerModuleRegistry::instance() {
static llvm::ManagedStatic<LowerModuleRegistry> lowerModuleRegistry;
return *lowerModuleRegistry;
}

} // namespace cir
} // namespace mlir
2 changes: 2 additions & 0 deletions clang/lib/CIR/FrontendAction/CIRGenAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "clang/CIR/CIRGenerator.h"
#include "clang/CIR/CIRToCIRPasses.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Transforms/TargetLowering/LowerModuleRegistry.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/Passes.h"
#include "clang/CodeGen/BackendUtil.h"
Expand Down Expand Up @@ -166,6 +167,7 @@ class CIRGenConsumer : public clang::ASTConsumer {

auto mlirMod = gen->getModule();
auto mlirCtx = gen->takeContext();
mlir::cir::LowerModuleRegistry::instance().initializeWithModule(mlirMod);

auto setupCIRPipelineAndExecute = [&] {
// Sanitize passes options. MLIR uses spaces between pass options
Expand Down
27 changes: 15 additions & 12 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/Dialect/Transforms/TargetLowering/LowerModuleRegistry.h"
#include "clang/CIR/LoweringHelpers.h"
#include "clang/CIR/MissingFeatures.h"
#include "clang/CIR/Passes.h"
Expand Down Expand Up @@ -3482,15 +3483,19 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,

namespace {

std::unique_ptr<mlir::cir::LowerModule>
prepareLowerModule(mlir::ModuleOp module) {
mlir::PatternRewriter rewriter{module->getContext()};
// If the triple is not present, e.g. CIR modules parsed from text, we
// cannot init LowerModule properly.
assert(!::cir::MissingFeatures::makeTripleAlwaysPresent());
if (!module->hasAttr("cir.triple"))
return {};
return mlir::cir::createLowerModule(module, rewriter);
mlir::cir::LowerModule *prepareLowerModule(mlir::ModuleOp module) {
auto &registry = mlir::cir::LowerModuleRegistry::instance();
if (!registry.isInitialized()) {
// If the triple is not present, e.g. CIR modules parsed from text, we
// cannot init LowerModule properly.
assert(!::cir::MissingFeatures::makeTripleAlwaysPresent());
if (!module->hasAttr("cir.triple"))
return nullptr;

registry.initializeWithModule(module);
}

return &registry.get();
}

// FIXME: change the type of lowerModule to `LowerModule &` to have better
Expand Down Expand Up @@ -3742,9 +3747,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
auto module = getOperation();
mlir::DataLayout dataLayout(module);
mlir::LLVMTypeConverter converter(&getContext());
std::unique_ptr<mlir::cir::LowerModule> lowerModule =
prepareLowerModule(module);
prepareTypeConverter(converter, dataLayout, lowerModule.get());
prepareTypeConverter(converter, dataLayout, prepareLowerModule(module));

mlir::RewritePatternSet patterns(&getContext());

Expand Down
Loading