Skip to content

Commit

Permalink
[CIR][CodeGen] Handle the case of 'case' after label statement after …
Browse files Browse the repository at this point in the history
…'case' (llvm#879)

Motivation example:

```
extern "C" void action1();
extern "C" void action2();

extern "C" void case_follow_label(int v) {
  switch (v) {
    case 1:
    label:
    case 2:
      action1();
      break;
    default:
      action2();
      goto label;
  }
}
```

When we compile it, we will meet:

```
  case Stmt::CaseStmtClass:
  case Stmt::DefaultStmtClass:
    assert(0 &&
           "Should not get here, currently handled directly from SwitchStmt");
    break;
```

in `buildStmt`. The cause is clear. We call `buildStmt` when we build
the label stmt.

To solve this, I think we should be able to build case stmt in
buildStmt. But the new problem is, we need to pass the information like
caseAttr and condType. So I tried to add such informations in
CIRGenFunction as data member.
  • Loading branch information
ChuanqiXu9 authored and smeenai committed Oct 9, 2024
1 parent 183567b commit a073bbb
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 20 deletions.
15 changes: 9 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,13 @@ class CIRGenFunction : public CIRGenTypeCache {
// applies to. nullptr if there is no 'musttail' on the current statement.
const clang::CallExpr *MustTailCall = nullptr;

/// The attributes of cases collected during emitting the body of a switch
/// stmt.
llvm::SmallVector<llvm::SmallVector<mlir::Attribute, 4>, 2> caseAttrsStack;

/// The type of the condition for the emitting switch statement.
llvm::SmallVector<mlir::Type, 2> condTypeStack;

clang::ASTContext &getContext() const;

CIRGenBuilderTy &getBuilder() { return builder; }
Expand Down Expand Up @@ -1210,13 +1217,9 @@ class CIRGenFunction : public CIRGenTypeCache {
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);
mlir::LogicalResult buildSwitchCase(const clang::SwitchCase &S);

mlir::LogicalResult
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);
mlir::LogicalResult buildSwitchBody(const clang::Stmt *S);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down
34 changes: 20 additions & 14 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,

case Stmt::CaseStmtClass:
case Stmt::DefaultStmtClass:
assert(0 &&
"Should not get here, currently handled directly from SwitchStmt");
return buildSwitchCase(cast<SwitchCase>(*S));
break;

case Stmt::BreakStmtClass:
Expand Down Expand Up @@ -715,14 +714,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
return buildCaseDefaultCascade(&S, condType, caseAttrs);
}

mlir::LogicalResult
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
mlir::LogicalResult CIRGenFunction::buildSwitchCase(const SwitchCase &S) {
assert(!caseAttrsStack.empty() &&
"build switch case without seeting case attrs");
assert(!condTypeStack.empty() &&
"build switch case without specifying the type of the condition");

if (S.getStmtClass() == Stmt::CaseStmtClass)
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);
return buildCaseStmt(cast<CaseStmt>(S), condTypeStack.back(),
caseAttrsStack.back());

if (S.getStmtClass() == Stmt::DefaultStmtClass)
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);
return buildDefaultStmt(cast<DefaultStmt>(S), condTypeStack.back(),
caseAttrsStack.back());

llvm_unreachable("expect case or default stmt");
}
Expand Down Expand Up @@ -987,15 +991,13 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildSwitchBody(
const Stmt *S, mlir::Type condType,
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
mlir::LogicalResult CIRGenFunction::buildSwitchBody(const Stmt *S) {
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
mlir::Block *lastCaseBlock = nullptr;
auto res = mlir::success();
for (auto *c : compoundStmt->body()) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
res = buildSwitchCase(*switchCase, condType, caseAttrs);
res = buildSwitchCase(*switchCase);
lastCaseBlock = builder.getBlock();
} else if (lastCaseBlock) {
// This means it's a random stmt following up a case, just
Expand Down Expand Up @@ -1045,12 +1047,16 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
currLexScope->setAsSwitch();

llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
caseAttrsStack.push_back({});
condTypeStack.push_back(condV.getType());

res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);
res = buildSwitchBody(S.getBody());

os.addRegions(currLexScope->getSwitchRegions());
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
os.addAttribute("cases", builder.getArrayAttr(caseAttrsStack.back()));

caseAttrsStack.pop_back();
condTypeStack.pop_back();
});

if (res.failed())
Expand Down
48 changes: 48 additions & 0 deletions clang/test/CIR/CodeGen/goto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,51 @@ extern "C" void multiple_non_case(int v) {
// NOFLAT: cir.label
// NOFLAT: cir.call @action2()
// NOFLAT: cir.break

extern "C" void case_follow_label(int v) {
switch (v) {
case 1:
label:
case 2:
action1();
break;
default:
action2();
goto label;
}
}

// NOFLAT: cir.func @case_follow_label
// NOFLAT: cir.switch
// NOFLAT: case (equal, 1)
// NOFLAT: cir.label "label"
// NOFLAT: cir.yield
// NOFLAT: case (equal, 2)
// NOFLAT: cir.call @action1()
// NOFLAT: cir.break
// NOFLAT: case (default)
// NOFLAT: cir.call @action2()
// NOFLAT: cir.goto "label"

extern "C" void default_follow_label(int v) {
switch (v) {
case 1:
case 2:
action1();
break;
label:
default:
action2();
goto label;
}
}

// NOFLAT: cir.func @default_follow_label
// NOFLAT: cir.switch
// NOFLAT: case (anyof, [1, 2] : !s32i)
// NOFLAT: cir.call @action1()
// NOFLAT: cir.break
// NOFLAT: cir.label "label"
// NOFLAT: case (default)
// NOFLAT: cir.call @action2()
// NOFLAT: cir.goto "label"

0 comments on commit a073bbb

Please sign in to comment.