Skip to content

Commit

Permalink
[CIR][Transform] Add constant load elimination pass
Browse files Browse the repository at this point in the history
This patch tries to give a simple initial implementation for eliminating
redundant loads of constant objects, an idea originally posted by OfekShilon.

Specifically, this patch consists of two parts:

* It adds a new unit attribute `const` to the `cir.alloca` operation.  Presence
  of this attribute indicates that the alloca-ed object is declared `const` in
  the input source program.  CIRGen is updated accordingly to start emitting
  this new attribute.

* It adds a new pass to the CIR optimization pipeline.  This new pass runs on
  function level, and identifies and eliminates all redundant loads of a
  constant alloca-ed object.
  • Loading branch information
Lancern committed Sep 27, 2024
1 parent 52323c1 commit 0b45aea
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 7 deletions.
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ def AllocaOp : CIR_Op<"alloca", [
cases, the first use contains the initialization (a cir.store, a cir.call
to a ctor, etc).

The presence of the `const` attribute indicates that the local variable is
declared with the C/C++ `const` keyword.

The `dynAllocSize` specifies the size to dynamically allocate on the stack
and ignores the allocation size based on the original type. This is useful
when handling VLAs and is omitted when declaring regular local variables.
Expand All @@ -492,6 +495,7 @@ def AllocaOp : CIR_Op<"alloca", [
TypeAttr:$allocaType,
StrAttr:$name,
UnitAttr:$init,
UnitAttr:$constant,
ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$alignment,
OptionalAttr<ASTVarDeclInterface>:$ast
);
Expand Down Expand Up @@ -529,6 +533,7 @@ def AllocaOp : CIR_Op<"alloca", [
($dynAllocSize^ `:` type($dynAllocSize) `,`)?
`[` $name
(`,` `init` $init^)?
(`,` `const` $constant^)?
`]`
(`ast` $ast^)? attr-dict
}];
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ std::unique_ptr<Pass> createLifetimeCheckPass(ArrayRef<StringRef> remark,
clang::ASTContext *astCtx);
std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createConstLoadEliminationPass();
std::unique_ptr<Pass> createDropASTPass();
std::unique_ptr<Pass> createSCFPreparePass();
std::unique_ptr<Pass> createLoweringPreparePass();
Expand Down
18 changes: 18 additions & 0 deletions clang/include/clang/CIR/Dialect/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ def CIRSimplify : Pass<"cir-simplify"> {
let dependentDialects = ["cir::CIRDialect"];
}

def ConstLoadElimination : Pass<"cir-const-load-elimination"> {
let summary = "Eliminate redundant loads of constant objects";
let description = [{
This pass eliminates redundant loads from objects that are known to be
constant.

The value of an object declared with `const` cannot change during the
object's whole lifetime. Thus multiple loads of a `const` object can be
merged into a single load when the resulting load dominates all the original
loads.

This pass is a function pass and it processes a single function within a
single run.
}];
let constructor = "mlir::createConstLoadEliminationPass()";
let dependentDialects = ["cir::CIRDialect"];
}

def LifetimeCheck : Pass<"cir-lifetime-check"> {
let summary = "Check lifetime safety and generate diagnostics";
let description = [{
Expand Down
14 changes: 8 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,11 @@ mlir::LogicalResult CIRGenFunction::declare(const Decl *var, QualType ty,
assert(!symbolTable.count(var) && "not supposed to be available just yet");

addr = buildAlloca(namedVar->getName(), ty, loc, alignment);
if (isParam) {
auto allocaOp = cast<mlir::cir::AllocaOp>(addr.getDefiningOp());
auto allocaOp = cast<mlir::cir::AllocaOp>(addr.getDefiningOp());
if (isParam)
allocaOp.setInitAttr(mlir::UnitAttr::get(builder.getContext()));
}
if (ty.isConstQualified())
allocaOp.setConstantAttr(mlir::UnitAttr::get(builder.getContext()));

symbolTable.insert(var, addr);
return mlir::success();
Expand All @@ -324,10 +325,11 @@ mlir::LogicalResult CIRGenFunction::declare(Address addr, const Decl *var,
assert(!symbolTable.count(var) && "not supposed to be available just yet");

addrVal = addr.getPointer();
if (isParam) {
auto allocaOp = cast<mlir::cir::AllocaOp>(addrVal.getDefiningOp());
auto allocaOp = cast<mlir::cir::AllocaOp>(addrVal.getDefiningOp());
if (isParam)
allocaOp.setInitAttr(mlir::UnitAttr::get(builder.getContext()));
}
if (ty.isConstQualified())
allocaOp.setConstantAttr(mlir::UnitAttr::get(builder.getContext()));

symbolTable.insert(var, addrVal);
return mlir::success();
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ mlir::LogicalResult runCIRToCIRPasses(
pm.addPass(std::move(libOpPass));
}

if (enableCIRSimplify)
if (enableCIRSimplify) {
pm.addPass(mlir::createCIRSimplifyPass());
pm.addPass(mlir::createConstLoadEliminationPass());
}

pm.addPass(mlir::createLoweringPreparePass(&astCtx));

Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_clang_library(MLIRCIRTransforms
LoweringPrepare.cpp
CIRCanonicalize.cpp
CIRSimplify.cpp
ConstLoadElimination.cpp
DropAST.cpp
IdiomRecognizer.cpp
LibOpt.cpp
Expand Down
116 changes: 116 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/ConstLoadElimination.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===- ConstLoadElimination.cpp - performs redundant load elimination -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

using namespace mlir;
using namespace cir;

namespace {

/// Eliminates redundant loads of the object allocated by a `const` alloca.
///
/// Only loads and stores that use \p alloca directly are considered. Two
/// rewrites are applied to the non-volatile loads:
///   1. Store forwarding: a load dominated by a store *to the alloca-ed
///      slot* is replaced with the value operand of that store.
///   2. Load merging: a remaining load that is properly dominated by another
///      remaining load is replaced with the top-most load of its dominator
///      chain (if A dominates B and B dominates C, C is replaced with A).
/// Volatile loads are never replaced or erased.
///
/// \param dom     dominance information used for every dominance query.
/// \param alloca  the alloca to process; it must carry the `const` attribute.
void processConstAlloca(DominanceInfo &dom, AllocaOp alloca) {
  assert(alloca.getConstant() && "must be a constant alloca");

  Value slot = alloca.getResult();

  // Collect the loads from, and the stores to, the alloca-ed object.
  SmallVector<LoadOp> allLoads;
  SmallVector<StoreOp> allStores;
  for (Operation *user : alloca->getUsers()) {
    if (auto load = dyn_cast<LoadOp>(user)) {
      allLoads.push_back(load);
    } else if (auto store = dyn_cast<StoreOp>(user)) {
      // A store is also a user of the alloca when the slot's *address* is the
      // value being stored somewhere else (e.g. `const int *p = &x;`). Such a
      // store does not write the object and must not be forwarded to loads.
      if (store.getAddr() == slot)
        allStores.push_back(store);
    }
  }

  // Phase 1: forward stored values to the non-volatile loads they dominate.
  // Loads that survive this phase are gathered in `remainingLoads`, so the
  // later phases never touch a handle of an already erased operation.
  SmallVector<LoadOp> remainingLoads;
  for (LoadOp load : allLoads) {
    bool forwarded = false;
    // Volatile loads are not candidates for elimination.
    if (!load.getIsVolatile()) {
      for (StoreOp store : allStores) {
        if (!dom.dominates(store.getOperation(), load.getOperation()))
          continue;
        load.replaceAllUsesWith(store.getValue());
        load.erase();
        forwarded = true;
        break;
      }
    }
    if (!forwarded)
      remainingLoads.push_back(load);
  }

  // Phase 2: for every remaining load record one remaining load that properly
  // dominates it, if any. Proper dominance is irreflexive, so a load is never
  // linked to itself and the recorded links cannot form a cycle.
  DenseMap<LoadOp, LoadOp> idomLoad;
  for (LoadOp load : remainingLoads) {
    for (LoadOp candidate : remainingLoads) {
      if (dom.properlyDominates(candidate.getOperation(),
                                load.getOperation())) {
        idomLoad[load] = candidate;
        break;
      }
    }
  }

  // Phase 3: resolve the replacement of each non-volatile load by following
  // the dominator links up to a load that has no recorded dominator. All
  // replacements are resolved before any operation is erased.
  SmallVector<std::pair<LoadOp, LoadOp>> replacements;
  for (LoadOp load : remainingLoads) {
    // Volatile loads are not candidates for elimination.
    if (load.getIsVolatile())
      continue;

    LoadOp target = load;
    for (auto it = idomLoad.find(target); it != idomLoad.end();
         it = idomLoad.find(target))
      target = it->second;

    if (target != load)
      replacements.emplace_back(load, target);
  }

  // A chain root has no recorded dominator, so it is never scheduled for
  // erasure itself and is still alive when its result is used here.
  for (auto &[redundant, target] : replacements) {
    redundant.replaceAllUsesWith(target.getResult());
    redundant.erase();
  }
}

/// Runs constant load elimination on every `const` alloca found in \p func.
void processFunc(mlir::cir::FuncOp func) {
  // Gather the candidates first; processing them erases operations, which is
  // kept separate from the walk over the function body.
  SmallVector<AllocaOp> worklist;
  func->walk([&worklist](AllocaOp op) {
    if (!op.getConstant())
      return;
    worklist.push_back(op);
  });

  // One dominance object is shared by all allocas of this function.
  DominanceInfo domInfo;
  for (AllocaOp constAlloca : worklist)
    processConstAlloca(domInfo, constAlloca);
}

/// Pass driver: visits every `cir.func` nested under the operation the pass
/// is run on and hands each one to processFunc.
struct ConstLoadEliminationPass
    : public ConstLoadEliminationBase<ConstLoadEliminationPass> {
  using ConstLoadEliminationBase::ConstLoadEliminationBase;

  void runOnOperation() override {
    getOperation()->walk([](mlir::cir::FuncOp func) { processFunc(func); });
  }
};

} // namespace

/// Factory for the constant load elimination pass.
std::unique_ptr<Pass> mlir::createConstLoadEliminationPass() {
  auto pass = std::make_unique<ConstLoadEliminationPass>();
  return pass;
}
33 changes: 33 additions & 0 deletions clang/test/CIR/CodeGen/const-alloca.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

int produce_int();

// A `const`-qualified scalar local whose initializer is a call. The checks
// that follow expect the alloca for `x` to be printed with `init, const`.
void local_const_int() {
const int x = produce_int();
}

// CHECK-LABEL: @_Z15local_const_intv
// CHECK: %{{.+}} = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init, const]
// CHECK: }

// A `const`-qualified by-value parameter. The checks that follow expect the
// alloca for `x` to be printed with `init, const`.
void param_const_int(const int x) {}

// CHECK-LABEL: @_Z15param_const_inti
// CHECK: %{{.+}} = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init, const]
// CHECK: }

// Simple aggregate used to exercise a `const` local of record type.
struct Foo {
int a;
int b;
};

Foo produce_foo();

// A `const`-qualified local of record type whose initializer is a call. The
// checks that follow expect the alloca for `x` to be printed with
// `init, const`.
void local_const_struct() {
const Foo x = produce_foo();
}

// CHECK-LABEL: @_Z18local_const_structv
// CHECK: %{{.+}} = cir.alloca !ty_Foo, !cir.ptr<!ty_Foo>, ["x", init, const]
// CHECK: }
65 changes: 65 additions & 0 deletions clang/test/CIR/Transforms/const-load-elimination.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -O1 -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -O1 -fclangir -fclangir-mem2reg -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=MEM2REG

int produce_int();
void blackbox(const int &);
void blackbox(const volatile int &);

// Reads the `const` local `x` twice, with a call that takes `x` by reference
// in between. The expectations that follow have both `a` and `b` initialized
// from the value returned by produce_int() instead of from reloads of `x`.
int load_local_const_int() {
const int x = produce_int();
int a = x;
blackbox(x);
int b = x;
return a + b;
}

// CHECK-LABEL: @_Z20load_local_const_intv
// CHECK: %[[#x_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init, const] {alignment = 4 : i64}
// CHECK-NEXT: %[[#a_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
// CHECK-NEXT: %[[#b_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
// CHECK-NEXT: %[[#init:]] = cir.call @_Z11produce_intv() : () -> !s32i
// CHECK-NEXT: cir.store %[[#init]], %[[#x_slot]] : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.store %[[#init]], %[[#a_slot]] : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.call @_Z8blackboxRKi(%[[#x_slot]]) : (!cir.ptr<!s32i>) -> ()
// CHECK-NEXT: cir.store %[[#init]], %[[#b_slot]] : !s32i, !cir.ptr<!s32i>
// CHECK: }

// MEM2REG-LABEL: @_Z20load_local_const_intv
// MEM2REG-NEXT: %[[#x_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init, const] {alignment = 4 : i64}
// MEM2REG-NEXT: %[[#init:]] = cir.call @_Z11produce_intv() : () -> !s32i
// MEM2REG-NEXT: cir.store %[[#init]], %[[#x_slot]] : !s32i, !cir.ptr<!s32i>
// MEM2REG-NEXT: cir.call @_Z8blackboxRKi(%[[#x_slot]]) : (!cir.ptr<!s32i>) -> ()
// MEM2REG-NEXT: %{{.+}} = cir.binop(add, %[[#init]], %[[#init]]) nsw : !s32i
// MEM2REG: }

// Same access pattern as above, but on a `const volatile` local. The
// expectations that follow keep both volatile loads of `x` in place.
int load_volatile_local_const_int() {
const volatile int x = produce_int();
int a = x;
blackbox(x);
int b = x;
return a + b;
}

// CHECK-LABEL: @_Z29load_volatile_local_const_intv
// CHECK: %[[#x_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init, const] {alignment = 4 : i64}
// CHECK-NEXT: %[[#a_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
// CHECK-NEXT: %[[#b_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
// CHECK-NEXT: %[[#init:]] = cir.call @_Z11produce_intv() : () -> !s32i
// CHECK-NEXT: cir.store volatile %[[#init]], %[[#x_slot]] : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: %[[#reload_1:]] = cir.load volatile %[[#x_slot]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: cir.store %[[#reload_1]], %[[#a_slot]] : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.call @_Z8blackboxRVKi(%[[#x_slot]]) : (!cir.ptr<!s32i>) -> ()
// CHECK-NEXT: %[[#reload_2:]] = cir.load volatile %[[#x_slot]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: cir.store %[[#reload_2]], %[[#b_slot]] : !s32i, !cir.ptr<!s32i>
// CHECK: }

// MEM2REG-LABEL: @_Z29load_volatile_local_const_intv
// MEM2REG-NEXT: %[[#x_slot:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init, const] {alignment = 4 : i64}
// MEM2REG-NEXT: %[[#init:]] = cir.call @_Z11produce_intv() : () -> !s32i
// MEM2REG-NEXT: cir.store volatile %[[#init]], %[[#x_slot]] : !s32i, !cir.ptr<!s32i>
// MEM2REG-NEXT: %{{.+}} = cir.load volatile %[[#x_slot]] : !cir.ptr<!s32i>, !s32i
// MEM2REG-NEXT: cir.call @_Z8blackboxRVKi(%[[#x_slot]]) : (!cir.ptr<!s32i>) -> ()
// MEM2REG-NEXT: %{{.+}} = cir.load volatile %[[#x_slot]] : !cir.ptr<!s32i>, !s32i
// MEM2REG: }

0 comments on commit 0b45aea

Please sign in to comment.