Skip to content

Commit

Permalink
[CIR] Add Case Op Kind Range (#650)
Browse files Browse the repository at this point in the history
Make lowering result of case range smart.

Resolve #632
  • Loading branch information
wenpen authored Jun 5, 2024
1 parent c23ea3d commit 4200ad0
Show file tree
Hide file tree
Showing 6 changed files with 385 additions and 62 deletions.
8 changes: 7 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1468,11 +1468,12 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
def CaseOpKind_DT : I32EnumAttrCase<"Default", 1, "default">;
def CaseOpKind_EQ : I32EnumAttrCase<"Equal", 2, "equal">;
def CaseOpKind_AO : I32EnumAttrCase<"Anyof", 3, "anyof">;
def CaseOpKind_RG : I32EnumAttrCase<"Range", 4, "range">;

def CaseOpKind : I32EnumAttr<
"CaseOpKind",
"case kind",
[CaseOpKind_DT, CaseOpKind_EQ, CaseOpKind_AO]> {
[CaseOpKind_DT, CaseOpKind_EQ, CaseOpKind_AO, CaseOpKind_RG]> {
let cppNamespace = "::mlir::cir";
}

Expand Down Expand Up @@ -1510,6 +1511,7 @@ def SwitchOp : CIR_Op<"switch",
condition.
- `anyof, [constant-list]`: equals to any of the values in a subsequent
following list.
- `range, [lower-bound, upper-bound]`: the condition is within the closed interval.
- `default`: any other value.

Each case region must be explicitly terminated.
Expand All @@ -1526,6 +1528,10 @@ def SwitchOp : CIR_Op<"switch",
...
cir.return ...
}
case (range, [10, 15]) {
...
cir.yield break
},
case (default) {
...
cir.yield fallthrough
Expand Down
46 changes: 32 additions & 14 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,37 +608,55 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
const CaseStmt *
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
auto *ctxt = builder.getContext();

const CaseStmt *caseStmt = &S;
const CaseStmt *lastCase = &S;
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;

int caseAttrCount = 0;

// Fold cascading cases whenever possible to simplify codegen a bit.
while (caseStmt) {
lastCase = caseStmt;

auto startVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
auto endVal = startVal;
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());

if (auto *rhs = caseStmt->getRHS()) {
endVal = rhs->EvaluateKnownConstInt(getContext());
}
for (auto intVal = startVal; intVal <= endVal; ++intVal) {
auto endVal = rhs->EvaluateKnownConstInt(getContext());
SmallVector<mlir::Attribute, 4> rangeCaseAttr = {
mlir::cir::IntAttr::get(condType, intVal),
mlir::cir::IntAttr::get(condType, endVal)};
auto caseAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(rangeCaseAttr),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Range));
caseAttrs.push_back(caseAttr);
++caseAttrCount;
} else {
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
}

caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
}

assert(!caseEltValueListAttr.empty() && "empty case value NYI");

auto *ctxt = builder.getContext();
if (!caseEltValueListAttr.empty()) {
auto caseOpKind = caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal;
auto caseAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt, caseOpKind));
caseAttrs.push_back(caseAttr);
++caseAttrCount;
}

auto caseAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal));
assert(caseAttrCount > 0 && "there should be at least one valid case attr");

caseAttrs.push_back(caseAttr);
for (int i = 1; i < caseAttrCount; ++i) {
// If there are multiple case attributes, we need to create a new region
auto *region = currLexScope->createSwitchRegion();
auto *block = builder.createBlock(region);
}

return lastCase;
}
Expand Down
14 changes: 10 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,8 @@ parseSwitchOp(OpAsmParser &parser,
// 2. Get the value (next in list)

// These needs to be in sync with CIROps.td
if (parser.parseOptionalKeyword(&attrStr, {"default", "equal", "anyof"})) {
if (parser.parseOptionalKeyword(&attrStr,
{"default", "equal", "anyof", "range"})) {
::mlir::StringAttr attrVal;
::mlir::OptionalParseResult parseResult = parser.parseOptionalAttribute(
attrVal, parser.getBuilder().getNoneType(), "kind", attrStorage);
Expand All @@ -1016,8 +1017,9 @@ parseSwitchOp(OpAsmParser &parser,

if (attrStr.empty()) {
return parser.emitError(
loc, "expected string or keyword containing one of the following "
"enum values for attribute 'kind' [default, equal, anyof]");
loc,
"expected string or keyword containing one of the following "
"enum values for attribute 'kind' [default, equal, anyof, range]");
}

auto attrOptional = ::mlir::cir::symbolizeCaseOpKind(attrStr.str());
Expand All @@ -1042,6 +1044,7 @@ parseSwitchOp(OpAsmParser &parser,
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(intCondType, val));
break;
}
case cir::CaseOpKind::Range:
case cir::CaseOpKind::Anyof: {
if (parser.parseComma().failed())
return mlir::failure();
Expand Down Expand Up @@ -1129,7 +1132,7 @@ void printSwitchOp(OpAsmPrinter &p, SwitchOp op,
auto attr = casesAttr[idx].cast<CaseAttr>();
auto kind = attr.getKind().getValue();
assert((kind == CaseOpKind::Default || kind == CaseOpKind::Equal ||
kind == CaseOpKind::Anyof) &&
kind == CaseOpKind::Anyof || kind == CaseOpKind::Range) &&
"unknown case");

// Case kind
Expand All @@ -1144,6 +1147,9 @@ void printSwitchOp(OpAsmPrinter &p, SwitchOp op,
(intAttrTy.isSigned() ? p << intAttr.getSInt() : p << intAttr.getUInt());
break;
}
case cir::CaseOpKind::Range:
assert(attr.getValue().size() == 2 && "range must have two values");
// The print format of the range is the same as anyof
case cir::CaseOpKind::Anyof: {
p << ", [";
llvm::interleaveComma(attr.getValue(), p, [&](const Attribute &a) {
Expand Down
87 changes: 85 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,43 @@ class CIRSwitchOpFlattening
destination);
}

// Return the new defaultDestination block.
Block *condBrToRangeDestination(mlir::cir::SwitchOp op,
mlir::PatternRewriter &rewriter,
mlir::Block *rangeDestination,
mlir::Block *defaultDestination,
APInt lowerBound, APInt upperBound) const {
assert(lowerBound.sle(upperBound) && "Invalid range");
auto resBlock = rewriter.createBlock(defaultDestination);
auto sIntType = mlir::cir::IntType::get(op.getContext(), 32, true);
auto uIntType = mlir::cir::IntType::get(op.getContext(), 32, false);

auto rangeLength = rewriter.create<mlir::cir::ConstantOp>(
op.getLoc(), sIntType,
mlir::cir::IntAttr::get(op.getContext(), sIntType,
upperBound - lowerBound));

auto lowerBoundValue = rewriter.create<mlir::cir::ConstantOp>(
op.getLoc(), sIntType,
mlir::cir::IntAttr::get(op.getContext(), sIntType, lowerBound));
auto diffValue = rewriter.create<mlir::cir::BinOp>(
op.getLoc(), sIntType, mlir::cir::BinOpKind::Sub, op.getCondition(),
lowerBoundValue);

// Use unsigned comparison to check if the condition is in the range.
auto uDiffValue = rewriter.create<mlir::cir::CastOp>(
op.getLoc(), uIntType, CastKind::integral, diffValue);
auto uRangeLength = rewriter.create<mlir::cir::CastOp>(
op.getLoc(), uIntType, CastKind::integral, rangeLength);

auto cmpResult = rewriter.create<mlir::cir::CmpOp>(
op.getLoc(), mlir::cir::BoolType::get(op.getContext()),
mlir::cir::CmpOpKind::le, uDiffValue, uRangeLength);
rewriter.create<mlir::cir::BrCondOp>(op.getLoc(), cmpResult,
rangeDestination, defaultDestination);
return resBlock;
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::SwitchOp op,
mlir::PatternRewriter &rewriter) const override {
Expand All @@ -279,6 +316,10 @@ class CIRSwitchOpFlattening
llvm::SmallVector<mlir::Block *, 8> caseDestinations;
llvm::SmallVector<mlir::ValueRange, 8> caseOperands;

llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
llvm::SmallVector<mlir::Block *> rangeDestinations;
llvm::SmallVector<mlir::ValueRange> rangeOperands;

// Initialize default case as optional.
mlir::Block *defaultDestination = exitBlock;
mlir::ValueRange defaultOperands = exitBlock->getArguments();
Expand All @@ -292,16 +333,31 @@ class CIRSwitchOpFlattening
auto caseAttr = op.getCases()->getValue()[i].cast<mlir::cir::CaseAttr>();

// Found default case: save destination and operands.
if (caseAttr.getKind().getValue() == mlir::cir::CaseOpKind::Default) {
switch (caseAttr.getKind().getValue()) {
case mlir::cir::CaseOpKind::Default:
defaultDestination = &region.front();
defaultOperands = region.getArguments();
} else {
break;
case mlir::cir::CaseOpKind::Range:
assert(caseAttr.getValue().size() == 2 &&
"Case range should have 2 case value");
rangeValues.push_back(
{caseAttr.getValue()[0].cast<mlir::cir::IntAttr>().getValue(),
caseAttr.getValue()[1].cast<mlir::cir::IntAttr>().getValue()});
rangeDestinations.push_back(&region.front());
rangeOperands.push_back(region.getArguments());
break;
case mlir::cir::CaseOpKind::Anyof:
case mlir::cir::CaseOpKind::Equal:
// AnyOf cases kind can have multiple values, hence the loop below.
for (auto &value : caseAttr.getValue()) {
caseValues.push_back(value.cast<mlir::cir::IntAttr>().getValue());
caseOperands.push_back(region.getArguments());
caseDestinations.push_back(&region.front());
}
break;
default:
llvm_unreachable("unsupported case kind");
}

// Previous case is a fallthrough: branch it to this case.
Expand Down Expand Up @@ -336,6 +392,33 @@ class CIRSwitchOpFlattening
fallthroughYieldOp = nullptr;
}

for (size_t index = 0; index < rangeValues.size(); ++index) {
auto lowerBound = rangeValues[index].first;
auto upperBound = rangeValues[index].second;

// The case range is unreachable, skip it.
if (lowerBound.sgt(upperBound))
continue;

// If range is small, add multiple switch instruction cases.
// This magical number is from the original CGStmt code.
constexpr int kSmallRangeThreshold = 64;
if ((upperBound - lowerBound)
.ult(llvm::APInt(32, kSmallRangeThreshold))) {
for (auto iValue = lowerBound; iValue.sle(upperBound); iValue++) {
caseValues.push_back(iValue);
caseOperands.push_back(rangeOperands[index]);
caseDestinations.push_back(rangeDestinations[index]);
}
continue;
}

defaultDestination =
condBrToRangeDestination(op, rewriter, rangeDestinations[index],
defaultDestination, lowerBound, upperBound);
defaultOperands = rangeOperands[index];
}

// Set switch op to branch to the newly created blocks.
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::SwitchFlatOp>(
Expand Down
Loading

0 comments on commit 4200ad0

Please sign in to comment.