diff --git a/lib/nnc/_ccv_nnc_symbolic_graph.h b/lib/nnc/_ccv_nnc_symbolic_graph.h
index ad5d76835..819121dcd 100644
--- a/lib/nnc/_ccv_nnc_symbolic_graph.h
+++ b/lib/nnc/_ccv_nnc_symbolic_graph.h
@@ -241,14 +241,16 @@ inline static int ccv_nnc_exec_dep_hop(const ccv_nnc_exec_dep_t exec_dep, const
 	return -1;
 }
 
-inline static int ccv_nnc_exec_dep_check(const ccv_nnc_exec_dep_t exec_dep, const int d, const int dd)
+inline static int ccv_nnc_exec_dep_check(const ccv_nnc_exec_dep_t exec_dep, const int d, ccv_sparse_matrix_vector_t* const vector, const int dd)
 {
 	// Check if dd is d's ancestor.
 	const int dd_chain_id = exec_dep.chain_ids[dd];
 	const int dd_chain_pos = exec_dep.chain_pos[dd];
 	if (exec_dep.chain_ids[d] == dd_chain_id)
 		return exec_dep.chain_pos[d] > dd_chain_pos;
-	const ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(exec_dep.deps, d, dd_chain_id);
+	if (vector == (ccv_sparse_matrix_vector_t*)1) // Special sentinel value meaning we already tried and could not find the vector.
+		return 0;
+	const ccv_numeric_data_t cell = vector ? ccv_get_sparse_matrix_cell_from_vector(exec_dep.deps, vector, dd_chain_id) : ccv_get_sparse_matrix_cell(exec_dep.deps, d, dd_chain_id);
 	// Check if the chain pos is greater than or equal to dd_chain_pos. If it is, it is an ancestor.
 	if (cell.i32 && cell.i32[0] > 0)
 		return cell.i32[0] >= dd_chain_pos;
diff --git a/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c b/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
index 617ab64e7..6213068d2 100644
--- a/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
+++ b/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
@@ -74,6 +74,137 @@ static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void*
 KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int)
 KHASH_SET_INIT_INT(ccv_cnnp_tensor_symbol_set)
 
+ccv_nnc_exec_dep_t _ccv_nnc_exec_dep_new(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_visit_t* const visit, const int exec_rnum, uint32_t* const maskbit)
+{
+	int* chain_ids = ccmalloc(sizeof(int) * exec_rnum * 2);
+	int* chain_pos = chain_ids + exec_rnum;
+	int* buf = (int*)ccmalloc(sizeof(int) * exec_rnum);
+	int* reversed_depth = buf;
+	const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
+	int i, j;
+	// Go in reverse order to generate the distance from the sink.
+	ccv_nnc_graph_visit_for_reversed(visit, exec_symbol_info, node, idx, term) {
+		if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
+			continue;
+		chain_ids[idx] = -1;
+		if (!node->outgoings || node->outgoings->rnum == 0)
+		{
+			reversed_depth[idx] = 0;
+			continue;
+		}
+		int depth = -1;
+		for (i = 0; i < node->outgoings->rnum; i++)
+		{
+			const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
+			if (outgoing >= exec_rnum)
+				continue;
+			depth = ccv_max(depth, reversed_depth[outgoing]);
+		}
+		reversed_depth[idx] = depth + 1;
+	} ccv_nnc_graph_visit_endfor
+	// Go in order to generate chain ids (if there are multiple exits, we use the reverse depth to break the tie).
+	// Note that we cannot use the depth so far, because then multiple exit nodes would be equally good to "inherit" the chain selection.
+	int chain_count = 0;
+	ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
+		if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
+			continue;
+		int chain_id = chain_ids[idx];
+		if (chain_ids[idx] < 0)
+		{
+			chain_id = chain_count;
+			chain_ids[idx] = chain_id;
+			chain_pos[idx] = 1; // The first one in this chain. 1-based index because in the sparse matrix, 0 is the default value.
+			chain_count += 1;
+		}
+		if (!node->outgoings || node->outgoings->rnum == 0)
+			continue;
+		int depth = -1;
+		int next_idx = -1;
+		for (i = 0; i < node->outgoings->rnum; i++)
+		{
+			const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
+			if (outgoing >= exec_rnum)
+				continue;
+			if (chain_ids[outgoing] < 0 && reversed_depth[outgoing] > depth)
+				depth = reversed_depth[outgoing], next_idx = outgoing;
+		}
+		if (next_idx >= 0)
+		{
+			chain_ids[next_idx] = chain_id;
+			assert(reversed_depth[idx] - depth >= 1);
+			chain_pos[next_idx] = chain_pos[idx] + (reversed_depth[idx] - depth);
+		}
+	} ccv_nnc_graph_visit_endfor
+	if (exec_rnum < chain_count * 2) // Be more conservative on RAM usage.
+		buf = ccrealloc(buf, sizeof(int) * chain_count * 2);
+	ccv_sparse_matrix_t* deps = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, chain_count, CCV_32S | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0);
+	// It records which pos on each chain we depend on. We can simply compare that with the chain_pos of a node to know if they are ancestors.
+#define for_block(x, val) \
+	do { \
+		if (((int32_t*)val)[0] > 0) \
+		{ \
+			buf[buf_size * 2] = x; \
+			buf[buf_size * 2 + 1] = ((int32_t*)val)[0]; \
+			++buf_size; \
+		} \
+	} while (0)
+	int buf_size;
+	ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
+		if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
+			continue;
+		buf_size = 0; /* save all its parent deps to this buffer */
+		ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, idx);
+		if (vector)
+			CCV_SPARSE_VECTOR_FOREACH(deps, vector, for_block);
+		if (!node->outgoings)
+			continue;
+		const int chain_id = chain_ids[idx];
+		const int pos = chain_pos[idx];
+		for (i = 0; i < node->outgoings->rnum; i++)
+		{
+			const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
+			if (outgoing >= exec_rnum)
+				continue;
+			const int outgoing_chain_id = chain_ids[outgoing];
+			if (outgoing_chain_id != chain_id)
+			{
+				ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(deps, outgoing, chain_id);
+				/* If not found, set it. If the current node is the destination node, there is no need
+				 * to set itself as the parent of subsequent nodes because of its terminal nature. */
+				if (!cell.i32 || cell.i32[0] == 0 || cell.i32[0] < pos)
+					ccv_set_sparse_matrix_cell(deps, outgoing, chain_id, &pos);
+			}
+			if (buf_size > 0)
+			{
+				ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, outgoing);
+				for (j = 0; j < buf_size; j++) /* set with all idx's dependencies as well */
+				{
+					if (outgoing_chain_id == buf[j * 2]) // We don't need to add a dependency for the same chain.
+						continue;
+					if (!vector)
+					{
+						ccv_set_sparse_matrix_cell(deps, outgoing, buf[j * 2], &buf[j * 2 + 1]);
+						vector = ccv_get_sparse_matrix_vector(deps, outgoing);
+						continue;
+					}
+					ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2]);
+					/* If not found, set. Otherwise, set to the latest one only if it is later. */
+					if (!cell.i32 || cell.i32[0] == 0 || cell.i32[0] <= buf[j * 2 + 1])
+						ccv_set_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2], &buf[j * 2 + 1]);
+				}
+			}
+		}
+	} ccv_nnc_graph_visit_endfor
+#undef for_block
+	ccfree(buf);
+	ccv_nnc_exec_dep_t exec_dep = {
+		.chain_ids = chain_ids,
+		.chain_pos = chain_pos,
+		.deps = deps
+	};
+	return exec_dep;
+}
+
 void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
 {
 	ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints;
@@ -120,7 +251,6 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
 	ccv_array_t* const parameter_trainables = ccv_array_new(sizeof(int), 0, 0);
 	ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
 	ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0);
-	ccv_array_t* const buf = ccv_array_new(sizeof(int), 0, 0);
 	int max_output_size = 0;
 	for (i = 0; i < gradient_checkpoints->rnum; i++)
 	{
@@ -560,45 +690,13 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
 		}
 	}
 	// Find parents to visited_backward_execs, and use that as the starting point of all newly added graph_exec_symbols. Use the visited backward execs as the source, use all its parents as destination, go through with graph visit.
-	ccv_sparse_matrix_t* const exec_dep = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, graph->exec_symbol_info->rnum, CCV_8U | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0);
-#define for_block(x, val) \
-	do { \
-		if (((uint8_t*)val)[0] != 0) \
-			ccv_array_push(buf, &x); \
-	} while (0)
-	const uint8_t one = 1;
-	// Now go from outputs to inputs, unmark visited ones.
-	ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
-		if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f)))
-		{
-			ccv_array_clear(buf);
-			ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
-			if (vector)
-				CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
-			if (node->outgoings && node->outgoings->rnum > 0)
-			{
-				ccv_array_t* const outgoings = node->outgoings;
-				for (k = 0; k < outgoings->rnum; k++)
-				{
-					const int outgoing_d = *(int*)ccv_array_get(outgoings, k);
-					if (outgoing_d >= exec_rnum)
-						continue;
-					int l;
-					// We cannot avoid the ones that visited, because these may not contain all the deps.
-					ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, idx, &one);
-					for (l = 0; l < buf->rnum; l++)
-						ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, *(int*)ccv_array_get(buf, l), &one);
-				}
-			}
-		}
-	} ccv_nnc_graph_visit_endfor
+	ccv_nnc_exec_dep_t exec_dep = _ccv_nnc_exec_dep_new(graph, visit, exec_rnum, maskbit);
 	// Now go from outputs to inputs, unmark visited ones.
 	ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
 		if (idx < exec_rnum)
 			maskbit[idx >> 5] &= ~(1u << (idx & 0x1f));
 	} ccv_nnc_graph_visit_endfor
 	ccv_nnc_graph_visit_free(visit);
-#undef for_block
 	// Go through visited backward execs, remove the ones that has no dependency on any replaced backward execs.
 	for (j = 0; j < visited_backward_execs->rnum;)
 	{
@@ -608,17 +706,15 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
 			++j;
 			continue;
 		}
-		ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
+		ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx);
+		if (!vector)
+			vector = (ccv_sparse_matrix_vector_t*)1; // Sentinel: we tried but could not find the vector.
 		int flag = 0;
-#define for_block(x, val) \
-	do { \
-		if (((uint8_t*)val)[0] != 0) \
-			if (ccv_array_contain_int(replaced_backward_execs, x)) \
-				flag = 1; \
-	} while (0)
-		if (vector)
-			CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
-#undef for_block
+		for (k = 0; !flag && k < replaced_backward_execs->rnum; k++)
+		{
+			const int d = *(int*)ccv_array_get(replaced_backward_execs, k);
+			flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d);
+		}
 		if (!flag)
 		{
 			if (j < visited_backward_execs->rnum - 1)
@@ -632,17 +728,15 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
 	for (j = 0; j < replaced_backward_execs->rnum; j++)
 	{
 		const int idx = *(int*)ccv_array_get(replaced_backward_execs, j);
-		ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
+		ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx);
+		if (!vector)
+			vector = (ccv_sparse_matrix_vector_t*)1; // Sentinel: we tried but could not find the vector.
 		int flag = 0;
-#define for_block(x, val) \
-	do { \
-		if (((uint8_t*)val)[0] != 0) \
-			if (ccv_array_contain_int(visited_backward_execs, x)) \
-				flag = 1; \
-	} while (0)
-		if (vector)
-			CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
-#undef for_block
+		for (k = 0; !flag && k < visited_backward_execs->rnum; k++)
+		{
+			const int d = *(int*)ccv_array_get(visited_backward_execs, k);
+			flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d);
+		}
 		// If this one has no parents that is within the visited_backward_execs, it is a good place for us to add all its parents as dependency for input_execs.
 		if (!flag)
 		{
@@ -662,7 +756,7 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
 			}
 		}
 	}
-	ccv_matrix_free(exec_dep);
+	ccv_nnc_exec_dep_free(exec_dep);
 	// Go through all exec, free ones that doesn't have output used.
 	// Reuse this array because it is not useful any more.
 	ccv_array_t* forward_pass_inputs = visited_backward_execs;
@@ -784,7 +878,6 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
 	kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
 	kh_destroy(ccv_cnnp_tensor_symbol_set, parameters_or_internals);
 	ccfree(max_outputs);
-	ccv_array_free(buf);
 	ccv_array_free(newly_used_outputs);
 	ccv_array_free(parameters);
 	ccv_array_free(parameter_ids);
diff --git a/lib/nnc/ccv_nnc_symbolic_graph_chain_decomposition.c b/lib/nnc/ccv_nnc_symbolic_graph_chain_decomposition.c
index 3772a9641..e6f3c6720 100644
--- a/lib/nnc/ccv_nnc_symbolic_graph_chain_decomposition.c
+++ b/lib/nnc/ccv_nnc_symbolic_graph_chain_decomposition.c
@@ -80,6 +80,8 @@ ccv_nnc_exec_dep_t ccv_nnc_exec_dep_new(const ccv_nnc_symbolic_graph_t* const gr
 	} while (0)
 	int buf_size;
 	ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
+		if (node->flags & CCV_NNC_GRAPH_EXEC_DEAD)
+			continue;
 		buf_size = 0; /* save all its parent deps to this buffer */
 		ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, idx);
 		if (vector)
diff --git a/lib/nnc/ccv_nnc_symbolic_graph_memory_reduction.c b/lib/nnc/ccv_nnc_symbolic_graph_memory_reduction.c
index 52232a665..ad487bdbb 100644
--- a/lib/nnc/ccv_nnc_symbolic_graph_memory_reduction.c
+++ b/lib/nnc/ccv_nnc_symbolic_graph_memory_reduction.c
@@ -163,7 +163,7 @@ void ccv_nnc_symbolic_graph_memory_reduction(ccv_nnc_symbolic_graph_t* const gra
 				continue;
 			}
 			// Check dependencies, if there is a dependency from y node to dd, dd cannot be source.
-			const int checked = ccv_nnc_exec_dep_check(exec_deps, dd, ddd);
+			const int checked = ccv_nnc_exec_dep_check(exec_deps, dd, 0, ddd);
 			if (checked)
 				flag = 1;
 		}
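
For reviewers, here is a minimal, self-contained sketch (not part of the patch) of the ancestry test that the chain decomposition enables, mirroring the comparison logic visible in ccv_nnc_exec_dep_check above: nodes on the same chain compare 1-based positions directly, and cross-chain queries consult the recorded deepest position on the foreign chain. Every name in it (is_ancestor, the dense deps table, NODE_COUNT, CHAIN_COUNT) is an illustrative stand-in, not a ccv API.

#include <assert.h>
#include <stdio.h>

/* Hypothetical stand-ins for what _ccv_nnc_exec_dep_new produces: every node
 * gets a chain id and a 1-based position on that chain, plus a per-(node, chain)
 * table recording the deepest position on a foreign chain the node depends on
 * (0 = no entry, mirroring the sparse matrix default). */
#define NODE_COUNT 5
#define CHAIN_COUNT 2

/* Chain 0: nodes 0 -> 1 -> 2; chain 1: nodes 3 -> 4. */
static const int chain_ids[NODE_COUNT] = { 0, 0, 0, 1, 1 };
static const int chain_pos[NODE_COUNT] = { 1, 2, 3, 1, 2 };
/* deps[d][c] = deepest position on chain c that node d depends on.
 * Node 4 (chain 1, pos 2) depends on chain 0 up to position 2 (node 1). */
static const int deps[NODE_COUNT][CHAIN_COUNT] = {
	{ 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 }, { 2, 0 }
};

/* Same shape as ccv_nnc_exec_dep_check, minus the sparse-matrix plumbing:
 * returns nonzero if dd is an ancestor of d. */
static int is_ancestor(const int d, const int dd)
{
	const int dd_chain_id = chain_ids[dd];
	const int dd_chain_pos = chain_pos[dd];
	if (chain_ids[d] == dd_chain_id) // Same chain: compare positions directly.
		return chain_pos[d] > dd_chain_pos;
	const int cell = deps[d][dd_chain_id]; // Cross-chain: consult the dep table.
	return cell > 0 && cell >= dd_chain_pos;
}

int main(void)
{
	assert(is_ancestor(2, 0));  // 0 -> 1 -> 2 on chain 0.
	assert(is_ancestor(4, 1));  // Node 4 depends on chain 0 up to pos 2, so node 1 (pos 2) is an ancestor.
	assert(!is_ancestor(4, 2)); // ...but node 2 (pos 3) is not.
	assert(!is_ancestor(0, 3)); // No recorded dependency across chains.
	printf("chain decomposition ancestry checks passed\n");
	return 0;
}

The sentinel value (ccv_sparse_matrix_vector_t*)1 used in the patch has no counterpart here: it only short-circuits the cross-chain lookup when the caller already knows the node's row in the deps matrix is empty, so the answer is "not an ancestor" without touching the matrix again.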