From 5443caf4ad59e0fca0c24d82fe7b92d66b9b48c3 Mon Sep 17 00:00:00 2001 From: axp Date: Thu, 6 Jun 2024 05:26:26 +0800 Subject: [PATCH] [CIR] Add Case Op Kind Range (#650) Make lowering result of case range smart. Resolve #632 --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 8 +- clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 46 ++-- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 14 +- .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 87 ++++++- clang/test/CIR/CodeGen/switch-gnurange.cpp | 227 +++++++++++++++--- clang/test/CIR/Transforms/switch.cir | 65 ++++- 6 files changed, 385 insertions(+), 62 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 40e8911ed835..2ddedea4a8b8 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1468,11 +1468,12 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> { def CaseOpKind_DT : I32EnumAttrCase<"Default", 1, "default">; def CaseOpKind_EQ : I32EnumAttrCase<"Equal", 2, "equal">; def CaseOpKind_AO : I32EnumAttrCase<"Anyof", 3, "anyof">; +def CaseOpKind_RG : I32EnumAttrCase<"Range", 4, "range">; def CaseOpKind : I32EnumAttr< "CaseOpKind", "case kind", - [CaseOpKind_DT, CaseOpKind_EQ, CaseOpKind_AO]> { + [CaseOpKind_DT, CaseOpKind_EQ, CaseOpKind_AO, CaseOpKind_RG]> { let cppNamespace = "::mlir::cir"; } @@ -1510,6 +1511,7 @@ def SwitchOp : CIR_Op<"switch", condition. - `anyof, [constant-list]`: equals to any of the values in a subsequent following list. + - `range, [lower-bound, upper-bound]`: the condition is within the closed interval. - `default`: any other value. Each case region must be explicitly terminated. @@ -1526,6 +1528,10 @@ def SwitchOp : CIR_Op<"switch", ... cir.return ... } + case (range, [10, 15]) { + ... + cir.yield break + }, case (default) { ... cir.yield fallthrough diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 5df3b3586be7..93ab3ee06dea 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -608,37 +608,55 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) { const CaseStmt * CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType, SmallVector &caseAttrs) { + auto *ctxt = builder.getContext(); + const CaseStmt *caseStmt = &S; const CaseStmt *lastCase = &S; SmallVector caseEltValueListAttr; + int caseAttrCount = 0; + // Fold cascading cases whenever possible to simplify codegen a bit. while (caseStmt) { lastCase = caseStmt; - auto startVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext()); - auto endVal = startVal; + auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext()); + if (auto *rhs = caseStmt->getRHS()) { - endVal = rhs->EvaluateKnownConstInt(getContext()); - } - for (auto intVal = startVal; intVal <= endVal; ++intVal) { + auto endVal = rhs->EvaluateKnownConstInt(getContext()); + SmallVector rangeCaseAttr = { + mlir::cir::IntAttr::get(condType, intVal), + mlir::cir::IntAttr::get(condType, endVal)}; + auto caseAttr = mlir::cir::CaseAttr::get( + ctxt, builder.getArrayAttr(rangeCaseAttr), + CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Range)); + caseAttrs.push_back(caseAttr); + ++caseAttrCount; + } else { caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal)); } caseStmt = dyn_cast_or_null(caseStmt->getSubStmt()); } - assert(!caseEltValueListAttr.empty() && "empty case value NYI"); - - auto *ctxt = builder.getContext(); + if (!caseEltValueListAttr.empty()) { + auto caseOpKind = caseEltValueListAttr.size() > 1 + ? mlir::cir::CaseOpKind::Anyof + : mlir::cir::CaseOpKind::Equal; + auto caseAttr = mlir::cir::CaseAttr::get( + ctxt, builder.getArrayAttr(caseEltValueListAttr), + CaseOpKindAttr::get(ctxt, caseOpKind)); + caseAttrs.push_back(caseAttr); + ++caseAttrCount; + } - auto caseAttr = mlir::cir::CaseAttr::get( - ctxt, builder.getArrayAttr(caseEltValueListAttr), - CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1 - ? mlir::cir::CaseOpKind::Anyof - : mlir::cir::CaseOpKind::Equal)); + assert(caseAttrCount > 0 && "there should be at least one valid case attr"); - caseAttrs.push_back(caseAttr); + for (int i = 1; i < caseAttrCount; ++i) { + // If there are multiple case attributes, we need to create a new region + auto *region = currLexScope->createSwitchRegion(); + auto *block = builder.createBlock(region); + } return lastCase; } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index aee0d48f5246..a24af29da3b9 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1003,7 +1003,8 @@ parseSwitchOp(OpAsmParser &parser, // 2. Get the value (next in list) // These needs to be in sync with CIROps.td - if (parser.parseOptionalKeyword(&attrStr, {"default", "equal", "anyof"})) { + if (parser.parseOptionalKeyword(&attrStr, + {"default", "equal", "anyof", "range"})) { ::mlir::StringAttr attrVal; ::mlir::OptionalParseResult parseResult = parser.parseOptionalAttribute( attrVal, parser.getBuilder().getNoneType(), "kind", attrStorage); @@ -1016,8 +1017,9 @@ parseSwitchOp(OpAsmParser &parser, if (attrStr.empty()) { return parser.emitError( - loc, "expected string or keyword containing one of the following " - "enum values for attribute 'kind' [default, equal, anyof]"); + loc, + "expected string or keyword containing one of the following " + "enum values for attribute 'kind' [default, equal, anyof, range]"); } auto attrOptional = ::mlir::cir::symbolizeCaseOpKind(attrStr.str()); @@ -1042,6 +1044,7 @@ parseSwitchOp(OpAsmParser &parser, caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(intCondType, val)); break; } + case cir::CaseOpKind::Range: case cir::CaseOpKind::Anyof: { if (parser.parseComma().failed()) return mlir::failure(); @@ -1129,7 +1132,7 @@ void printSwitchOp(OpAsmPrinter &p, SwitchOp op, auto attr = casesAttr[idx].cast(); auto kind = attr.getKind().getValue(); assert((kind == CaseOpKind::Default || kind == CaseOpKind::Equal || - kind == CaseOpKind::Anyof) && + kind == CaseOpKind::Anyof || kind == CaseOpKind::Range) && "unknown case"); // Case kind @@ -1144,6 +1147,9 @@ void printSwitchOp(OpAsmPrinter &p, SwitchOp op, (intAttrTy.isSigned() ? p << intAttr.getSInt() : p << intAttr.getUInt()); break; } + case cir::CaseOpKind::Range: + assert(attr.getValue().size() == 2 && "range must have two values"); + // The print format of the range is the same as anyof case cir::CaseOpKind::Anyof: { p << ", ["; llvm::interleaveComma(attr.getValue(), p, [&](const Attribute &a) { diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index ce643e6735fa..470ff1dbff3f 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -259,6 +259,43 @@ class CIRSwitchOpFlattening destination); } + // Return the new defaultDestination block. + Block *condBrToRangeDestination(mlir::cir::SwitchOp op, + mlir::PatternRewriter &rewriter, + mlir::Block *rangeDestination, + mlir::Block *defaultDestination, + APInt lowerBound, APInt upperBound) const { + assert(lowerBound.sle(upperBound) && "Invalid range"); + auto resBlock = rewriter.createBlock(defaultDestination); + auto sIntType = mlir::cir::IntType::get(op.getContext(), 32, true); + auto uIntType = mlir::cir::IntType::get(op.getContext(), 32, false); + + auto rangeLength = rewriter.create( + op.getLoc(), sIntType, + mlir::cir::IntAttr::get(op.getContext(), sIntType, + upperBound - lowerBound)); + + auto lowerBoundValue = rewriter.create( + op.getLoc(), sIntType, + mlir::cir::IntAttr::get(op.getContext(), sIntType, lowerBound)); + auto diffValue = rewriter.create( + op.getLoc(), sIntType, mlir::cir::BinOpKind::Sub, op.getCondition(), + lowerBoundValue); + + // Use unsigned comparison to check if the condition is in the range. + auto uDiffValue = rewriter.create( + op.getLoc(), uIntType, CastKind::integral, diffValue); + auto uRangeLength = rewriter.create( + op.getLoc(), uIntType, CastKind::integral, rangeLength); + + auto cmpResult = rewriter.create( + op.getLoc(), mlir::cir::BoolType::get(op.getContext()), + mlir::cir::CmpOpKind::le, uDiffValue, uRangeLength); + rewriter.create(op.getLoc(), cmpResult, + rangeDestination, defaultDestination); + return resBlock; + } + mlir::LogicalResult matchAndRewrite(mlir::cir::SwitchOp op, mlir::PatternRewriter &rewriter) const override { @@ -279,6 +316,10 @@ class CIRSwitchOpFlattening llvm::SmallVector caseDestinations; llvm::SmallVector caseOperands; + llvm::SmallVector> rangeValues; + llvm::SmallVector rangeDestinations; + llvm::SmallVector rangeOperands; + // Initialize default case as optional. mlir::Block *defaultDestination = exitBlock; mlir::ValueRange defaultOperands = exitBlock->getArguments(); @@ -292,16 +333,31 @@ class CIRSwitchOpFlattening auto caseAttr = op.getCases()->getValue()[i].cast(); // Found default case: save destination and operands. - if (caseAttr.getKind().getValue() == mlir::cir::CaseOpKind::Default) { + switch (caseAttr.getKind().getValue()) { + case mlir::cir::CaseOpKind::Default: defaultDestination = ®ion.front(); defaultOperands = region.getArguments(); - } else { + break; + case mlir::cir::CaseOpKind::Range: + assert(caseAttr.getValue().size() == 2 && + "Case range should have 2 case value"); + rangeValues.push_back( + {caseAttr.getValue()[0].cast().getValue(), + caseAttr.getValue()[1].cast().getValue()}); + rangeDestinations.push_back(®ion.front()); + rangeOperands.push_back(region.getArguments()); + break; + case mlir::cir::CaseOpKind::Anyof: + case mlir::cir::CaseOpKind::Equal: // AnyOf cases kind can have multiple values, hence the loop below. for (auto &value : caseAttr.getValue()) { caseValues.push_back(value.cast().getValue()); caseOperands.push_back(region.getArguments()); caseDestinations.push_back(®ion.front()); } + break; + default: + llvm_unreachable("unsupported case kind"); } // Previous case is a fallthrough: branch it to this case. @@ -336,6 +392,33 @@ class CIRSwitchOpFlattening fallthroughYieldOp = nullptr; } + for (size_t index = 0; index < rangeValues.size(); ++index) { + auto lowerBound = rangeValues[index].first; + auto upperBound = rangeValues[index].second; + + // The case range is unreachable, skip it. + if (lowerBound.sgt(upperBound)) + continue; + + // If range is small, add multiple switch instruction cases. + // This magical number is from the original CGStmt code. + constexpr int kSmallRangeThreshold = 64; + if ((upperBound - lowerBound) + .ult(llvm::APInt(32, kSmallRangeThreshold))) { + for (auto iValue = lowerBound; iValue.sle(upperBound); iValue++) { + caseValues.push_back(iValue); + caseOperands.push_back(rangeOperands[index]); + caseDestinations.push_back(rangeDestinations[index]); + } + continue; + } + + defaultDestination = + condBrToRangeDestination(op, rewriter, rangeDestinations[index], + defaultDestination, lowerBound, upperBound); + defaultOperands = rangeOperands[index]; + } + // Set switch op to branch to the newly created blocks. rewriter.setInsertionPoint(op); rewriter.replaceOpWithNewOp( diff --git a/clang/test/CIR/CodeGen/switch-gnurange.cpp b/clang/test/CIR/CodeGen/switch-gnurange.cpp index 7fbd49ad704c..f48a32506252 100644 --- a/clang/test/CIR/CodeGen/switch-gnurange.cpp +++ b/clang/test/CIR/CodeGen/switch-gnurange.cpp @@ -22,10 +22,21 @@ int sw1(enum letter c) { // CIR: cir.func @_Z3sw16letter // CIR: cir.scope { // CIR: cir.switch -// CIR-NEXT: case (anyof, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] : !s32i) { +// CIR-NEXT: case (range, [0, 2] : !s32i) { +// CIR-NEXT: cir.yield +// CIR-NEXT: }, +// CIR-NEXT: case (range, [4, 5] : !s32i) { +// CIR-NEXT: cir.yield +// CIR-NEXT: }, +// CIR-NEXT: case (range, [6, 10] : !s32i) { +// CIR-NEXT: cir.yield +// CIR-NEXT: }, +// CIR-NEXT: case (equal, 3) { +// CIR-NEXT: cir.int<1> // CIR: cir.return // CIR-NEXT: }, // CIR-NEXT: case (default) { +// CIR-NEXT: cir.int<0> // CIR: cir.return // CIR-NEXT: } // CIR-NEXT: ] @@ -33,19 +44,25 @@ int sw1(enum letter c) { // LLVM: @_Z3sw16letter // LLVM: switch i32 %[[C:[0-9]+]], label %[[DEFAULT:[0-9]+]] [ -// LLVM-NEXT: i32 0, label %[[CASE:[0-9]+]] -// LLVM-NEXT: i32 1, label %[[CASE]] -// LLVM-NEXT: i32 2, label %[[CASE]] -// LLVM-NEXT: i32 3, label %[[CASE]] -// LLVM-NEXT: i32 4, label %[[CASE]] -// LLVM-NEXT: i32 5, label %[[CASE]] -// LLVM-NEXT: i32 6, label %[[CASE]] -// LLVM-NEXT: i32 7, label %[[CASE]] -// LLVM-NEXT: i32 8, label %[[CASE]] -// LLVM-NEXT: i32 9, label %[[CASE]] -// LLVM-NEXT: i32 10, label %[[CASE]] +// LLVM-NEXT: i32 3, label %[[CASE_3:[0-9]+]] +// LLVM-NEXT: i32 0, label %[[CASE_0_2:[0-9]+]] +// LLVM-NEXT: i32 1, label %[[CASE_0_2]] +// LLVM-NEXT: i32 2, label %[[CASE_0_2]] +// LLVM-NEXT: i32 4, label %[[CASE_4_5:[0-9]+]] +// LLVM-NEXT: i32 5, label %[[CASE_4_5]] +// LLVM-NEXT: i32 6, label %[[CASE_6_10:[0-9]+]] +// LLVM-NEXT: i32 7, label %[[CASE_6_10]] +// LLVM-NEXT: i32 8, label %[[CASE_6_10]] +// LLVM-NEXT: i32 9, label %[[CASE_6_10]] +// LLVM-NEXT: i32 10, label %[[CASE_6_10]] // LLVM-NEXT: ] -// LLVM: [[CASE]]: +// LLVM: [[CASE_0_2]]: +// LLVM: br label %[[CASE_4_5]] +// LLVM: [[CASE_4_5]]: +// LLVM: br label %[[CASE_6_10]] +// LLVM: [[CASE_6_10]]: +// LLVM: br label %[[CASE_3]] +// LLVM: [[CASE_3]]: // LLVM: store i32 1 // LLVM: ret // LLVM: [[DEFAULT]]: @@ -66,7 +83,7 @@ int sw2(enum letter c) { // CIR: cir.func @_Z3sw26letter // CIR: cir.scope { // CIR: cir.switch -// CIR-NEXT: case (anyof, [0, 1, 2] : !s32i) { +// CIR-NEXT: case (range, [0, 2] : !s32i) { // CIR: cir.return // CIR-NEXT: }, // CIR-NEXT: case (default) { @@ -109,19 +126,19 @@ void sw3(enum letter c) { // CIR: cir.func @_Z3sw36letter // CIR: cir.scope { // CIR: cir.switch -// CIR-NEXT: case (anyof, [0, 1, 2] : !s32i) { +// CIR-NEXT: case (range, [0, 2] : !s32i) { // CIR-NEXT: cir.int<1> // CIR: cir.break // CIR-NEXT: }, -// CIR-NEXT: case (anyof, [3, 4, 5] : !s32i) { +// CIR-NEXT: case (range, [3, 5] : !s32i) { // CIR-NEXT: cir.int<2> // CIR: cir.break // CIR-NEXT: }, -// CIR-NEXT: case (anyof, [6, 7, 8] : !s32i) { +// CIR-NEXT: case (range, [6, 8] : !s32i) { // CIR-NEXT: cir.int<3> // CIR: cir.break // CIR-NEXT: }, -// CIR-NEXT: case (anyof, [9, 10] : !s32i) { +// CIR-NEXT: case (range, [9, 10] : !s32i) { // CIR-NEXT: cir.int<4> // CIR: cir.break // CIR-NEXT: } @@ -155,7 +172,9 @@ void sw3(enum letter c) { // LLVM: store i32 4, ptr %[[X]] // LLVM: br label %[[EPILOG]] // LLVM: [[EPILOG]]: -// LLVM: ret void +// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]] +// LLVM: [[EPILOG_END]]: +// LLVM-NEXT: ret void void sw4(int x) { switch (x) { @@ -169,37 +188,165 @@ void sw4(int x) { // CIR: cir.func @_Z3sw4i // CIR: cir.scope { // CIR: cir.switch -// CIR-NEXT: case (anyof, [66, 67, 68, 69, {{[0-9, ]+}}, 230, 231, 232, 233] : !s32i) { +// CIR-NEXT: case (range, [66, 233] : !s32i) { // CIR-NEXT: cir.break // CIR-NEXT: }, -// CIR-NEXT: case (anyof, [-50, -49, -48, -47, {{[0-9, -]+}}, -1, 0, 1, {{[0-9, ]+}}, 47, 48, 49, 50] : !s32i) { +// CIR-NEXT: case (range, [-50, 50] : !s32i) { // CIR-NEXT: cir.break // CIR-NEXT: } // CIR-NEXT: ] // CIR-NEXT: } // LLVM: @_Z3sw4i +// LLVM: switch i32 %[[X:[0-9]+]], label %[[JUDGE_NEG50_50:[0-9]+]] [ +// LLVM-NEXT: ] +// LLVM: [[CASE_66_233:[0-9]+]]: +// LLVM-NEXT: br label %[[EPILOG:[0-9]+]] +// LLVM: [[CASE_NEG50_50:[0-9]+]]: +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[JUDGE_NEG50_50]]: +// LLVM-NEXT: %[[DIFF:[0-9]+]] = sub i32 %[[X]], -50 +// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 100 +// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_NEG50_50]], label %[[JUDGE_66_233:[0-9]+]] +// LLVM: [[JUDGE_66_233]]: +// LLVM-NEXT: %[[DIFF:[0-9]+]] = sub i32 %[[X]], 66 +// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 167 +// LLVM: br i1 %[[DIFF_CMP]], label %[[CASE_66_233]], label %[[EPILOG]] +// LLVM: [[EPILOG]]: +// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]] +// LLVM: [[EPILOG_END]]: +// LLVM-NEXT: ret void + +void sw5(int x) { + int y = 0; + switch (x) { + case 100 ... -100: + y = 1; + } +} + +// CIR: cir.func @_Z3sw5i +// CIR: cir.scope { +// CIR: cir.switch +// CIR-NEXT: case (range, [100, -100] : !s32i) { +// CIR-NEXT: cir.int<1> +// CIR: cir.yield +// CIR-NEXT: } +// CIR-NEXT: ] + +// LLVM: @_Z3sw5i +// LLVM: switch i32 %[[X:[0-9]+]], label %[[EPILOG:[0-9]+]] [ +// LLVM-NEXT: ] +// LLVM: [[CASE_100_NEG100:[0-9]+]]: +// LLVM-NEXT: store i32 1, ptr %[[Y:[0-9]+]] +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[EPILOG]]: +// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]] +// LLVM: [[EPILOG_END]]: +// LLVM-NEXT: ret void + +void sw6(int x) { + int y = 0; + switch (x) { + case -2147483648 ... 2147483647: + y = 1; + } +} + +// CIR: cir.func @_Z3sw6i +// CIR: cir.scope { +// CIR: cir.switch +// CIR-NEXT: case (range, [-2147483648, 2147483647] : !s32i) { +// CIR-NEXT: cir.int<1> +// CIR: cir.yield +// CIR-NEXT: } +// CIR-NEXT: ] + +// LLVM: @_Z3sw6i // LLVM: switch i32 %[[X:[0-9]+]], label %[[DEFAULT:[0-9]+]] [ -// LLVM-NEXT: i32 66, label %[[CASE_66_233:[0-9]+]] -// LLVM-NEXT: i32 67, label %[[CASE_66_233]] -// ... -// LLVM: i32 232, label %[[CASE_66_233]] -// LLVM-NEXT: i32 233, label %[[CASE_66_233]] -// LLVM-NEXT: i32 -50, label %[[CASE_NEG50_50:[0-9]+]] -// LLVM-NEXT: i32 -49, label %[[CASE_NEG50_50]] -// ... -// LLVM: i32 -1, label %[[CASE_NEG50_50]] -// LLVM-NEXT: i32 0, label %[[CASE_NEG50_50]] -// LLVM-NEXT: i32 1, label %[[CASE_NEG50_50]] -// ... -// LLVM: i32 49, label %[[CASE_NEG50_50]] -// LLVM-NEXT: i32 50, label %[[CASE_NEG50_50]] // LLVM-NEXT: ] -// LLVM: [[CASE_66_233]]: -// LLVM: br label %[[EPILOG:[0-9]+]] -// LLVM: [[CASE_NEG50_50]]: -// LLVM: br label %[[EPILOG]] +// LLVM: [[CASE_MIN_MAX:[0-9]+]]: +// LLVM-NEXT: store i32 1, ptr %[[Y:[0-9]+]] +// LLVM-NEXT: br label %[[EPILOG:[0-9]+]] +// LLVM: [[DEFAULT]]: +// LLVM-NEXT: %[[DIFF:[0-9]+]] = sub i32 %[[X]], -2147483648 +// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], -1 +// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_MIN_MAX]], label %[[EPILOG]] // LLVM: [[EPILOG]]: -// LLVM: ret void +// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]] +// LLVM: [[EPILOG_END]]: +// LLVM-NEXT: ret void + +void sw7(int x) { + switch(x) { + case 0: + break; + case 100 ... 200: + break; + case 1: + break; + case 300 ... 400: + break; + default: + break; + case 500 ... 600: + break; + } +} + +// CIR: cir.func @_Z3sw7i +// CIR: cir.scope { +// CIR: cir.switch +// CIR-NEXT: case (equal, 0) { +// CIR-NEXT: cir.break +// CIR-NEXT: }, +// CIR-NEXT: case (range, [100, 200] : !s32i) { +// CIR-NEXT: cir.break +// CIR-NEXT: }, +// CIR-NEXT: case (equal, 1) { +// CIR-NEXT: cir.break +// CIR-NEXT: }, +// CIR-NEXT: case (range, [300, 400] : !s32i) { +// CIR-NEXT: cir.break +// CIR-NEXT: }, +// CIR-NEXT: case (default) { +// CIR-NEXT: cir.break +// CIR-NEXT: }, +// CIR-NEXT: case (range, [500, 600] : !s32i) { +// CIR-NEXT: cir.break +// CIR-NEXT: } +// LLVM: @_Z3sw7i +// LLVM: switch i32 %[[X:[0-9]+]], label %[[JUDGE_RANGE_500_600:[0-9]+]] [ +// LLVM-NEXT: i32 0, label %[[CASE_0:[0-9]+]] +// LLVM-NEXT: i32 1, label %[[CASE_1:[0-9]+]] +// LLVM-NEXT: ] +// LLVM: [[CASE_0]]: +// LLVM-NEXT: br label %[[EPILOG:[0-9]+]] +// LLVM: [[CASE_100_200:[0-9]+]]: +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[CASE_1]]: +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[CASE_300_400:[0-9]+]]: +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[JUDGE_RANGE_500_600]]: +// LLVM-NEXT: %[[DIFF:[0-9]+]] = sub i32 %[[X]], 500 +// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 100 +// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_500_600:[0-9]+]], label %[[JUDGE_RANGE_300_400:[0-9]+]] +// LLVM: [[JUDGE_RANGE_300_400]]: +// LLVM-NEXT: %[[DIFF:[0-9]+]] = sub i32 %[[X]], 300 +// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 100 +// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_300_400]], label %[[JUDGE_RANGE_100_200:[0-9]+]] +// LLVM: [[JUDGE_RANGE_100_200]]: +// LLVM-NEXT: %[[DIFF:[0-9]+]] = sub i32 %[[X]], 100 +// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 100 +// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_100_200]], label %[[DEFAULT:[0-9]+]] +// LLVM: [[DEFAULT]]: +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[CASE_500_600]]: +// LLVM-NEXT: br label %[[EPILOG]] +// LLVM: [[EPILOG]]: +// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]] +// LLVM: [[EPILOG_END]]: +// LLVM-NEXT: ret void diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir index 1ea6dba49c98..177dfc98c8af 100644 --- a/clang/test/CIR/Transforms/switch.cir +++ b/clang/test/CIR/Transforms/switch.cir @@ -156,7 +156,7 @@ module { // CHECK: } - cir.func @shouldFlatNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i { + cir.func @shouldFlatNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i { %0 = cir.alloca !s32i, !cir.ptr, ["x", init] {alignment = 4 : i64} %1 = cir.alloca !s32i, !cir.ptr, ["y", init] {alignment = 4 : i64} %2 = cir.alloca !s32i, !cir.ptr, ["__retval"] {alignment = 4 : i64} @@ -204,4 +204,67 @@ module { // CHECK: cir.return %9 : !s32i // CHECK: } + + cir.func @flatCaseRange(%arg0: !s32i) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr, ["x", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, !cir.ptr, ["__retval"] {alignment = 4 : i64} + %2 = cir.alloca !s32i, !cir.ptr, ["y", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr + %3 = cir.const #cir.int<0> : !s32i + cir.store %3, %2 : !s32i, !cir.ptr + cir.scope { + %6 = cir.load %0 : !cir.ptr, !s32i + cir.switch (%6 : !s32i) [ + case (equal, -100) { + %7 = cir.const #cir.int<1> : !s32i + cir.store %7, %2 : !s32i, !cir.ptr + cir.break + }, + case (range, [1, 100] : !s32i) { + %7 = cir.const #cir.int<2> : !s32i + cir.store %7, %2 : !s32i, !cir.ptr + cir.break + }, + case (default) { + %7 = cir.const #cir.int<3> : !s32i + cir.store %7, %2 : !s32i, !cir.ptr + cir.break + } + ] + } + %4 = cir.load %2 : !cir.ptr, !s32i + cir.store %4, %1 : !s32i, !cir.ptr + %5 = cir.load %1 : !cir.ptr, !s32i + cir.return %5 : !s32i + } +// CHECK: cir.func @flatCaseRange(%arg0: !s32i) -> !s32i { +// CHECK: cir.switch.flat %[[X:[0-9]+]] : !s32i, ^[[JUDGE_RANGE:bb[0-9]+]] [ +// CHECK-NEXT: -100: ^[[CASE_EQUAL:bb[0-9]+]] +// CHECK-NEXT: ] +// CHECK-NEXT: ^[[CASE_EQUAL]]: +// CHECK-NEXT: cir.int<1> +// CHECK-NEXT: cir.store +// CHECK-NEXT: cir.br ^[[EPILOG:bb[0-9]+]] +// CHECK-NEXT: ^[[CASE_RANGE:bb[0-9]+]]: +// CHECK-NEXT: cir.int<2> +// CHECK-NEXT: cir.store +// CHECK-NEXT: cir.br ^[[EPILOG]] +// CHECK-NEXT: ^[[JUDGE_RANGE]]: +// CHECK-NEXT: %[[RANGE:[0-9]+]] = cir.const #cir.int<99> +// CHECK-NEXT: %[[LOWER_BOUND:[0-9]+]] = cir.const #cir.int<1> +// CHECK-NEXT: %[[DIFF:[0-9]+]] = cir.binop(sub, %[[X]], %[[LOWER_BOUND]]) +// CHECK-NEXT: %[[U_DIFF:[0-9]+]] = cir.cast(integral, %[[DIFF]] : !s32i), !u32i +// CHECK-NEXT: %[[U_RANGE:[0-9]+]] = cir.cast(integral, %[[RANGE]] : !s32i), !u32i +// CHECK-NEXT: %[[CMP_RESULT:[0-9]+]] = cir.cmp(le, %[[U_DIFF]], %[[U_RANGE]]) +// CHECK-NEXT: cir.brcond %[[CMP_RESULT]] ^[[CASE_RANGE]], ^[[CASE_DEFAULT:bb[0-9]+]] +// CHECK-NEXT: ^[[CASE_DEFAULT]]: +// CHECK-NEXT: cir.int<3> +// CHECK-NEXT: cir.store +// CHECK-NEXT: cir.br ^[[EPILOG]] +// CHECK-NEXT: ^[[EPILOG]]: +// CHECK-NEXT: cir.br ^[[EPILOG_END:bb[0-9]+]] +// CHECK-NEXT: ^[[EPILOG_END]]: +// CHECK: cir.return +// CHECK: } + }