From 94be2dd65e404b59bb05e4e64246ff770641ffcd Mon Sep 17 00:00:00 2001 From: Vedant Paranjape <22630228+VedantParanjape@users.noreply.github.com> Date: Sun, 27 Aug 2023 02:12:56 +0530 Subject: [PATCH] WIP: Add infra to construct a while loop using the loop_info analysis --- include/blocks/basic_blocks.h | 1 + include/blocks/loops.h | 4 + src/blocks/basic_blocks.cpp | 8 ++ src/blocks/dominance.cpp | 2 +- src/blocks/loops.cpp | 211 +++++++++++++++++++++++++++++++- src/builder/builder_context.cpp | 23 +++- 6 files changed, 244 insertions(+), 5 deletions(-) diff --git a/include/blocks/basic_blocks.h b/include/blocks/basic_blocks.h index 828f2e7..937ccd2 100644 --- a/include/blocks/basic_blocks.h +++ b/include/blocks/basic_blocks.h @@ -15,6 +15,7 @@ class basic_block { block::expr::Ptr branch_expr; block::stmt::Ptr parent; unsigned int ast_index; + unsigned int ast_depth; unsigned int id; std::string name; }; diff --git a/include/blocks/loops.h b/include/blocks/loops.h index c3ec74e..1242e7c 100644 --- a/include/blocks/loops.h +++ b/include/blocks/loops.h @@ -23,6 +23,7 @@ class loop { stmt::Ptr entry_stmt; } loop_bounds; + unsigned int loop_id; basic_block::cfg_block blocks; std::unordered_set blocks_id_map; std::shared_ptr parent_loop; @@ -37,6 +38,8 @@ class loop_info { analyze(); } std::shared_ptr allocate_loop(std::shared_ptr header); + block::stmt_block::Ptr convert_to_ast(block::stmt_block::Ptr ast); + std::map> postorder_loops_map; std::vector> loops; std::vector> top_level_loops; @@ -44,6 +47,7 @@ class loop_info { basic_block::cfg_block parent_ast; dominator_analysis dta; std::map> bb_loop_map; + void postorder_dfs_helper(std::vector &postorder_loops_map, std::vector &visited_loops, int id); // discover loops during traversal of the abstract syntax tree void analyze(); }; diff --git a/src/blocks/basic_blocks.cpp b/src/blocks/basic_blocks.cpp index a9c8aeb..1261179 100644 --- a/src/blocks/basic_blocks.cpp +++ b/src/blocks/basic_blocks.cpp @@ -14,6 +14,7 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) { auto bb = std::make_shared(std::to_string(basic_block_count)); bb->parent = st; bb->ast_index = ast_index_counter++; + bb->ast_depth = 0; work_list.push_back(bb); basic_block_count++; } @@ -40,6 +41,7 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) { stmt_block_list.push_back(std::make_shared(std::to_string(basic_block_count++))); stmt_block_list.back()->parent = st; stmt_block_list.back()->ast_index = ast_index_counter++; + stmt_block_list.back()->ast_depth = bb->ast_depth + 1; } // set the basic block successors @@ -77,6 +79,8 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) { auto exit_bb = std::make_shared("exit" + std::to_string(basic_block_count)); // assign it a empty stmt_block as parent exit_bb->parent = std::make_shared(); + // set the ast depth of the basic block + exit_bb->ast_depth = bb->ast_depth; // check if this is the last block, if yes the successor will be empty if (bb->successor.size()) { // set the successor to the block that if_stmt successor pointer to earlier @@ -94,6 +98,8 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) { auto then_bb = std::make_shared(std::to_string(++basic_block_count)); // set the parent of this block as the then stmts then_bb->parent = if_stmt_->then_stmt; + // set the ast depth of the basic block + then_bb->ast_depth = bb->ast_depth; // set the successor of this block to be the exit block then_bb->successor.push_back(exit_bb); // set the successor of the original if_stmt block to be this then block @@ -106,6 +112,8 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) { auto else_bb = std::make_shared(std::to_string(++basic_block_count)); // set the parent of this block as the else stmts else_bb->parent = if_stmt_->else_stmt; + // set the ast depth of the basic block + else_bb->ast_depth = bb->ast_depth; // set the successor of this block to be the exit block else_bb->successor.push_back(exit_bb); // set the successor of the orignal if_stmt block to be this else block diff --git a/src/blocks/dominance.cpp b/src/blocks/dominance.cpp index ef37cf2..72726ad 100644 --- a/src/blocks/dominance.cpp +++ b/src/blocks/dominance.cpp @@ -16,7 +16,7 @@ dominator_analysis::dominator_analysis(basic_block::cfg_block &cfg) : cfg_(cfg) void dominator_analysis::postorder_idom_helper(std::vector &visited, int id) { for (int idom_id: idom_map[id]) { - std::cerr << idom_id << "\n"; + // std::cerr << idom_id << "\n"; if (idom_id != -1 && !visited[idom_id]) { visited[idom_id] = true; postorder_idom_helper(visited, idom_id); diff --git a/src/blocks/loops.cpp b/src/blocks/loops.cpp index 27adcf6..28cba63 100644 --- a/src/blocks/loops.cpp +++ b/src/blocks/loops.cpp @@ -10,6 +10,16 @@ std::shared_ptr loop_info::allocate_loop(std::shared_ptr head return loops.back(); } +void loop_info::postorder_dfs_helper(std::vector &postorder_loops_map, std::vector &visited_loops, int id) { + for (auto subloop: loops[id]->subloops) { + if (!visited_loops[subloop->loop_id]) { + visited_loops[subloop->loop_id] = true; + postorder_dfs_helper(postorder_loops_map, visited_loops, subloop->loop_id); + postorder_loops_map.push_back(subloop->loop_id); + } + } +} + void loop_info::analyze() { std::vector idom = dta.get_idom(); @@ -126,4 +136,203 @@ void loop_info::analyze() { } } } -} \ No newline at end of file + + // Assign id to the loops + for (unsigned int i = 0; i < loops.size(); i++) { + loops[i]->loop_id = i; + } + + // build a loop tree + std::vector visited_loops(loops.size()); + visited_loops.assign(visited_loops.size(), false); + for (auto loop: top_level_loops) { + std::vector postorder_loop_tree; + visited_loops[loop->loop_id] = true; + + postorder_dfs_helper(postorder_loop_tree, visited_loops, loop->loop_id); + postorder_loop_tree.push_back(loop->loop_id); + postorder_loops_map[loop->loop_id] = postorder_loop_tree; + } +} + +static stmt::Ptr get_loop_block(std::shared_ptr loop_header, block::stmt_block::Ptr ast) { + block::stmt::Ptr current_ast = to(ast); + std::vector current_block = to(current_ast)->stmts; + // unsigned int ast_index = loop_header->ast_index; + std::deque worklist; + std::map ast_parent_map; + + for (auto stmt: current_block) { + ast_parent_map[stmt] = current_ast; + } + worklist.insert(worklist.end(), current_block.begin(), current_block.end()); + + while (worklist.size()) { + stmt::Ptr worklist_top = worklist.front(); + worklist.pop_front(); + + if (isa(worklist_top)) { + stmt_block::Ptr wl_stmt_block = to(worklist_top); + for (auto stmt: wl_stmt_block->stmts) { + ast_parent_map[stmt] = worklist_top; + } + worklist.insert(worklist.end(), wl_stmt_block->stmts.begin(), wl_stmt_block->stmts.end()); + } + else if (isa(worklist_top)) { + if_stmt::Ptr wl_if_stmt = to(worklist_top); + + if (to(wl_if_stmt->then_stmt)->stmts.size() != 0) { + stmt_block::Ptr wl_if_then_stmt = to(wl_if_stmt->then_stmt); + for (auto stmt: wl_if_then_stmt->stmts) { + ast_parent_map[stmt] = worklist_top; + } + worklist.insert(worklist.end(), wl_if_then_stmt->stmts.begin(), wl_if_then_stmt->stmts.end()); + } + if (to(wl_if_stmt->else_stmt)->stmts.size() != 0) { + stmt_block::Ptr wl_if_else_stmt = to(wl_if_stmt->else_stmt); + for (auto stmt: wl_if_else_stmt->stmts) { + ast_parent_map[stmt] = worklist_top; + } + worklist.insert(worklist.end(), wl_if_else_stmt->stmts.begin(), wl_if_else_stmt->stmts.end()); + } + } + else if (isa(worklist_top)) { + label_stmt::Ptr wl_label_stmt = to(worklist_top); + if (worklist_top == loop_header->parent) + return ast_parent_map[worklist_top]; + } + else if (isa(worklist_top)) { + goto_stmt::Ptr wl_goto_stmt = to(worklist_top); + if (worklist_top == loop_header->parent) + return ast_parent_map[worklist_top]; + } + } + + return nullptr; +} + +static void replace_loop_latches(std::shared_ptr loop, block::stmt_block::Ptr ast) { + for (auto latch : loop->loop_latch_blocks) { + stmt::Ptr loop_latch_ast = get_loop_block(latch, ast); + if (isa(loop_latch_ast)) { + std::vector &temp_loop_ast = to(loop_latch_ast)->stmts; + std::replace(temp_loop_ast.begin(), temp_loop_ast.end(), temp_loop_ast[latch->ast_index], to(std::make_shared())); + } + else if (isa(loop_latch_ast)) { + stmt_block::Ptr if_then_block = to(to(loop_latch_ast)->then_stmt); + stmt_block::Ptr if_else_block = to(to(loop_latch_ast)->else_stmt); + + if (if_then_block->stmts.size() && if_then_block->stmts[latch->ast_index] == latch->parent) { + std::replace(if_then_block->stmts.begin(), if_then_block->stmts.end(), if_then_block->stmts[latch->ast_index], to(std::make_shared())); + } + else if (if_else_block->stmts.size() && if_else_block->stmts[latch->ast_index] == latch->parent) { + std::replace(if_else_block->stmts.begin(), if_else_block->stmts.end(), if_else_block->stmts[latch->ast_index], to(std::make_shared())); + } + } + } +} + +block::stmt_block::Ptr loop_info::convert_to_ast(block::stmt_block::Ptr ast) { + for (auto loop_map: postorder_loops_map) { + // std::cerr << "== top level loop tree ==\n"; + for (auto postorder: loop_map.second) { + // std::cerr << postorder <<"\n"; + block::stmt::Ptr loop_header_ast = get_loop_block(loops[postorder]->header_block, ast); + + while_stmt::Ptr while_block = std::make_shared(); + while_block->cond = std::make_shared(); + to(while_block->cond)->value = 1; + while_block->body = std::make_shared(); + + if (isa(loop_header_ast)) { + unsigned int ast_index = loops[postorder]->header_block->ast_index; + if (to(loop_header_ast)->stmts[ast_index] == loops[postorder]->header_block->parent) { + stmt_block::Ptr then_block = to(to(to(loop_header_ast)->stmts[ast_index + 1])->then_stmt); + stmt_block::Ptr else_block = to(to(to(loop_header_ast)->stmts[ast_index + 1])->else_stmt); + + // if (isa(then_block->stmts.back())) { + // then_block->stmts.pop_back(); + // then_block->stmts.push_back(std::make_shared()); + // } + replace_loop_latches(loops[postorder], ast); + + else_block->stmts.push_back(std::make_shared()); + to(while_block->body)->stmts.push_back(to(loop_header_ast)->stmts[ast_index + 1]); + // while_block->cond = to(to(loop_header_ast)->stmts[ast_index + 1])->cond; + // while_block->dump(std::cerr, 0); + // std::cerr << "found loop header in stmt block\n"; + + // if block to be replaced with while block + std::vector &temp_ast = to(loop_header_ast)->stmts; + std::replace(temp_ast.begin(), temp_ast.end(), temp_ast[ast_index + 1], to(while_block)); + temp_ast.erase(temp_ast.begin() + ast_index); + } + else { + // std::cerr << "not found loop header in stmt block\n"; + } + } + else if (isa(loop_header_ast)) { + unsigned int ast_index = loops[postorder]->header_block->ast_index; + stmt_block::Ptr if_then_block = to(to(loop_header_ast)->then_stmt); + stmt_block::Ptr if_else_block = to(to(loop_header_ast)->else_stmt); + + if (if_then_block->stmts.size() != 0) { + if (if_then_block->stmts[ast_index] == loops[postorder]->header_block->parent) { + stmt_block::Ptr then_block = to(to(if_then_block->stmts[ast_index + 1])->then_stmt); + stmt_block::Ptr else_block = to(to(if_then_block->stmts[ast_index + 1])->else_stmt); + + replace_loop_latches(loops[postorder], ast); + + else_block->stmts.push_back(std::make_shared()); + to(while_block->body)->stmts.push_back(if_then_block->stmts[ast_index + 1]); + // while_block->cond = to(loop_header_ast)->cond; + + // while_block->dump(std::cerr, 0); + // std::cerr << "found loop header in if-then stmt\n"; + + // if block to be replaced with while block + std::vector &temp_ast = if_then_block->stmts; + std::replace(temp_ast.begin(), temp_ast.end(), temp_ast[ast_index + 1], to(while_block)); + temp_ast.erase(temp_ast.begin() + ast_index); + } + else { + // loop_header_ast->dump(std::cerr, 0); + // std::cerr << "not found loop header in if-then stmt\n"; + } + } + else if (if_else_block->stmts.size() != 0) { + if (if_else_block->stmts[ast_index] == loops[postorder]->header_block->parent) { + stmt_block::Ptr then_block = to(to(if_else_block->stmts[ast_index + 1])->then_stmt); + stmt_block::Ptr else_block = to(to(if_else_block->stmts[ast_index + 1])->else_stmt); + + replace_loop_latches(loops[postorder], ast); + + else_block->stmts.push_back(std::make_shared()); + to(while_block->body)->stmts.push_back(if_else_block->stmts[ast_index + 1]); + // while_block->cond = to(loop_header_ast)->cond; + + // while_block->dump(std::cerr, 0); + // std::cerr << "found loop header in if-else stmt\n"; + + // if block to be replaced with while block + std::vector &temp_ast = if_else_block->stmts; + std::replace(temp_ast.begin(), temp_ast.end(), temp_ast[ast_index + 1], to(while_block)); + temp_ast.erase(temp_ast.begin() + ast_index); + } + else { + // loop_header_ast->dump(std::cerr, 0); + // std::cerr << "not found loop header in if-else stmt\n"; + } + } + } + else { + // std::cerr << "loop header not found\n"; + } + // insert into AST - std::replace + // set the ast to loop depth + 1 + // loops[loop_tree.first]->header_block->ast_index + } + } + + return ast; +} diff --git a/src/builder/builder_context.cpp b/src/builder/builder_context.cpp index 3d0de3b..269ccdf 100644 --- a/src/builder/builder_context.cpp +++ b/src/builder/builder_context.cpp @@ -300,6 +300,7 @@ block::stmt::Ptr builder_context::extract_ast_from_function_impl(void) { for (auto pred: bb->predecessor) { std::cerr << pred->name << ", "; } + std::cerr << bb->ast_depth; std::cerr << "\n"; if (bb->branch_expr) { std::cerr << " "; @@ -386,14 +387,30 @@ block::stmt::Ptr builder_context::extract_ast_from_function_impl(void) { for (auto subl: loop->subloops) std::cerr << "(loop header: " << subl->header_block->id << ") "; std::cerr << "\n"; } + + std::cerr << "++++++ top level loops ++++++ \n"; + for (auto top_level_loop: LI.top_level_loops) std::cerr << "(loop header: " << top_level_loop->header_block->id << ") "; + std::cerr << "\n"; + + std::cerr << "++++++ preorder loops tree ++++++ \n"; + for (auto loop_tree: LI.postorder_loops_map) { + std::cerr << "loop tree root: (loop header: " << LI.loops[loop_tree.first]->header_block->id << ")\n"; + std::cerr << "postorder: "; + for (auto node: loop_tree.second) std::cerr << node << " "; + std::cerr << "\n"; + } std::cerr << "++++++ loop info ++++++ \n"; + std::cerr << "++++++ convert to ast ++++++ \n"; + LI.convert_to_ast(block::to(ast)); + std::cerr << "++++++ convert to ast ++++++ \n"; + if (feature_unstructured) return ast; - block::loop_finder finder; - finder.ast = ast; - ast->accept(&finder); + // block::loop_finder finder; + // finder.ast = ast; + // ast->accept(&finder); block::for_loop_finder for_finder; for_finder.ast = ast;