Skip to content

Commit

Permalink
Use new exec_dep logic for gradient checkpointing.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 3, 2024
1 parent 787623a commit ce93a48
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 59 deletions.
6 changes: 4 additions & 2 deletions lib/nnc/_ccv_nnc_symbolic_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,16 @@ inline static int ccv_nnc_exec_dep_hop(const ccv_nnc_exec_dep_t exec_dep, const
return -1;
}

inline static int ccv_nnc_exec_dep_check(const ccv_nnc_exec_dep_t exec_dep, const int d, const int dd)
inline static int ccv_nnc_exec_dep_check(const ccv_nnc_exec_dep_t exec_dep, const int d, ccv_sparse_matrix_vector_t* const vector, const int dd)
{
// Check if dd is d's ancestor.
const int dd_chain_id = exec_dep.chain_ids[dd];
const int dd_chain_pos = exec_dep.chain_pos[dd];
if (exec_dep.chain_ids[d] == dd_chain_id)
return exec_dep.chain_pos[d] > dd_chain_pos;
const ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(exec_dep.deps, d, dd_chain_id);
if (vector == (ccv_sparse_matrix_vector_t*)1) // Special sentinel value to say that we don't have vector found.
return 0;
const ccv_numeric_data_t cell = vector ? ccv_get_sparse_matrix_cell_from_vector(exec_dep.deps, vector, dd_chain_id) : ccv_get_sparse_matrix_cell(exec_dep.deps, d, dd_chain_id);
// Check if the chain pos is greater than or equal to dd_chain_pos. If it is, it is an ancestor.
if (cell.i32 && cell.i32[0] > 0)
return cell.i32[0] >= dd_chain_pos;
Expand Down
205 changes: 149 additions & 56 deletions lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,137 @@ static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void*
KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int)
KHASH_SET_INIT_INT(ccv_cnnp_tensor_symbol_set)

ccv_nnc_exec_dep_t _ccv_nnc_exec_dep_new(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_visit_t* const visit, const int exec_rnum, uint32_t* const maskbit)
{
int* chain_ids = ccmalloc(sizeof(int) * exec_rnum * 2);
int* chain_pos = chain_ids + exec_rnum;
int* buf = (int*)ccmalloc(sizeof(int) * exec_rnum);
int* reversed_depth = buf;
const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
int i, j;
// Go reverse order to generate the distance from sink.
ccv_nnc_graph_visit_for_reversed(visit, exec_symbol_info, node, idx, term) {
if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
continue;
chain_ids[idx] = -1;
if (!node->outgoings || node->outgoings->rnum == 0)
{
reversed_depth[idx] = 0;
continue;
}
int depth = -1;
for (i = 0; i < node->outgoings->rnum; i++)
{
const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
if (outgoing >= exec_rnum)
continue;
depth = ccv_max(depth, reversed_depth[outgoing]);
}
reversed_depth[idx] = depth + 1;
} ccv_nnc_graph_visit_endfor
// Go in order to generate chain ids (if there are multiple exits, we use the reverse depth to break the tie).
// Note that we cannot use depth so-far because then multiple exit nodes are equally good to "inherit" the chain selection.
int chain_count = 0;
ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
continue;
int chain_id = chain_ids[idx];
if (chain_ids[idx] < 0)
{
chain_id = chain_count;
chain_ids[idx] = chain_id;
chain_pos[idx] = 1; // The first one in this chain. 1-based index because in sparse matrix, 0 is the default value.
chain_count += 1;
}
if (!node->outgoings || node->outgoings->rnum == 0)
continue;
int depth = -1;
int next_idx = -1;
for (i = 0; i < node->outgoings->rnum; i++)
{
const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
if (outgoing >= exec_rnum)
continue;
if (chain_ids[outgoing] < 0 && reversed_depth[outgoing] > depth)
depth = reversed_depth[outgoing], next_idx = outgoing;
}
if (next_idx >= 0)
{
chain_ids[next_idx] = chain_id;
assert(reversed_depth[idx] - depth >= 1);
chain_pos[next_idx] = chain_pos[idx] + (reversed_depth[idx] - depth);
}
} ccv_nnc_graph_visit_endfor
if (exec_rnum < chain_count * 2) // Be more conservative on RAM usage.
buf = ccrealloc(buf, sizeof(int) * chain_count * 2);
ccv_sparse_matrix_t* deps = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, chain_count, CCV_32S | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0);
// It logs which pos on that chain we depend on. We can simply compare that with the chain_pos for a node to know if they are ancestors.
#define for_block(x, val) \
do { \
if (((int32_t*)val)[0] > 0) \
{ \
buf[buf_size * 2] = x; \
buf[buf_size * 2 + 1] = ((int32_t*)val)[0]; \
++buf_size; \
} \
} while (0)
int buf_size;
ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
continue;
buf_size = 0; /* save all its parent deps to this buffer */
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, idx);
if (vector)
CCV_SPARSE_VECTOR_FOREACH(deps, vector, for_block);
if (!node->outgoings)
continue;
const int chain_id = chain_ids[idx];
const int pos = chain_pos[idx];
for (i = 0; i < node->outgoings->rnum; i++)
{
const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
if (outgoing >= exec_rnum)
continue;
const int outgoing_chain_id = chain_ids[outgoing];
if (outgoing_chain_id != chain_id)
{
ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(deps, outgoing, chain_id);
/* If not found, set, if the current node is the destination node, no need
* set itself as parent of subsequent nodes because its terminal nature. */
if (!cell.i32 || cell.i32[0] == 0 || cell.i32[0] < pos)
ccv_set_sparse_matrix_cell(deps, outgoing, chain_id, &pos);
}
if (buf_size > 0)
{
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, outgoing);
for (j = 0; j < buf_size; j++) /* set with all idx's dependencies as well */
{
if (outgoing_chain_id == buf[j * 2]) // We don't need to add as dependency for the same chain.
continue;
if (!vector)
{
ccv_set_sparse_matrix_cell(deps, outgoing, buf[j * 2], &buf[j * 2 + 1]);
vector = ccv_get_sparse_matrix_vector(deps, outgoing);
continue;
}
ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2]);
/* If not found, set. Otherwise, set to the latest one only if it is later. */
if (!cell.i32 || cell.i32[0] == 0 || cell.i32[0] <= buf[j * 2 + 1])
ccv_set_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2], &buf[j * 2 + 1]);
}
}
}
} ccv_nnc_graph_visit_endfor
#undef for_block
ccfree(buf);
ccv_nnc_exec_dep_t exec_dep = {
.chain_ids = chain_ids,
.chain_pos = chain_pos,
.deps = deps
};
return exec_dep;
}

void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
{
ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints;
Expand Down Expand Up @@ -120,7 +251,6 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
ccv_array_t* const parameter_trainables = ccv_array_new(sizeof(int), 0, 0);
ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0);
ccv_array_t* const buf = ccv_array_new(sizeof(int), 0, 0);
int max_output_size = 0;
for (i = 0; i < gradient_checkpoints->rnum; i++)
{
Expand Down Expand Up @@ -560,45 +690,13 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
}
}
// Find parents to visited_backward_execs, and use that as the starting point of all newly added graph_exec_symbols. Use the visited backward execs as the source, use all its parents as destination, go through with graph visit.
ccv_sparse_matrix_t* const exec_dep = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, graph->exec_symbol_info->rnum, CCV_8U | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0);
#define for_block(x, val) \
do { \
if (((uint8_t*)val)[0] != 0) \
ccv_array_push(buf, &x); \
} while (0)
const uint8_t one = 1;
// Now go from outputs to inputs, unmark visited ones.
ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f)))
{
ccv_array_clear(buf);
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
if (vector)
CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
if (node->outgoings && node->outgoings->rnum > 0)
{
ccv_array_t* const outgoings = node->outgoings;
for (k = 0; k < outgoings->rnum; k++)
{
const int outgoing_d = *(int*)ccv_array_get(outgoings, k);
if (outgoing_d >= exec_rnum)
continue;
int l;
// We cannot avoid the ones that visited, because these may not contain all the deps.
ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, idx, &one);
for (l = 0; l < buf->rnum; l++)
ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, *(int*)ccv_array_get(buf, l), &one);
}
}
}
} ccv_nnc_graph_visit_endfor
ccv_nnc_exec_dep_t exec_dep = _ccv_nnc_exec_dep_new(graph, visit, exec_rnum, maskbit);
// Now go from outputs to inputs, unmark visited ones.
ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
if (idx < exec_rnum)
maskbit[idx >> 5] &= ~(1u << (idx & 0x1f));
} ccv_nnc_graph_visit_endfor
ccv_nnc_graph_visit_free(visit);
#undef for_block
// Go through visited backward execs, remove the ones that has no dependency on any replaced backward execs.
for (j = 0; j < visited_backward_execs->rnum;)
{
Expand All @@ -608,17 +706,15 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
++j;
continue;
}
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx);
if (!vector)
vector = (ccv_sparse_matrix_vector_t*)1; // Mark it as we tried but cannot find.
int flag = 0;
#define for_block(x, val) \
do { \
if (((uint8_t*)val)[0] != 0) \
if (ccv_array_contain_int(replaced_backward_execs, x)) \
flag = 1; \
} while (0)
if (vector)
CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
#undef for_block
for (k = 0; !flag && k < replaced_backward_execs->rnum; k++)
{
const int d = *(int*)ccv_array_get(replaced_backward_execs, k);
flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d);
}
if (!flag)
{
if (j < visited_backward_execs->rnum - 1)
Expand All @@ -632,17 +728,15 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
for (j = 0; j < replaced_backward_execs->rnum; j++)
{
const int idx = *(int*)ccv_array_get(replaced_backward_execs, j);
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx);
if (!vector)
vector = (ccv_sparse_matrix_vector_t*)1; // Mark it as we tried but cannot find.
int flag = 0;
#define for_block(x, val) \
do { \
if (((uint8_t*)val)[0] != 0) \
if (ccv_array_contain_int(visited_backward_execs, x)) \
flag = 1; \
} while (0)
if (vector)
CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
#undef for_block
for (k = 0; !flag && k < visited_backward_execs->rnum; k++)
{
const int d = *(int*)ccv_array_get(visited_backward_execs, k);
flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d);
}
// If this one has no parents that is within the visited_backward_execs, it is a good place for us to add all its parents as dependency for input_execs.
if (!flag)
{
Expand All @@ -662,7 +756,7 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
}
}
}
ccv_matrix_free(exec_dep);
ccv_nnc_exec_dep_free(exec_dep);
// Go through all exec, free ones that doesn't have output used.
// Reuse this array because it is not useful any more.
ccv_array_t* forward_pass_inputs = visited_backward_execs;
Expand Down Expand Up @@ -784,7 +878,6 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
kh_destroy(ccv_cnnp_tensor_symbol_set, parameters_or_internals);
ccfree(max_outputs);
ccv_array_free(buf);
ccv_array_free(newly_used_outputs);
ccv_array_free(parameters);
ccv_array_free(parameter_ids);
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/ccv_nnc_symbolic_graph_chain_decomposition.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ ccv_nnc_exec_dep_t ccv_nnc_exec_dep_new(const ccv_nnc_symbolic_graph_t* const gr
} while (0)
int buf_size;
ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
if (node->flags & CCV_NNC_GRAPH_EXEC_DEAD)
continue;
buf_size = 0; /* save all its parent deps to this buffer */
ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, idx);
if (vector)
Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/ccv_nnc_symbolic_graph_memory_reduction.c
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void ccv_nnc_symbolic_graph_memory_reduction(ccv_nnc_symbolic_graph_t* const gra
continue;
}
// Check dependencies, if there is a dependency from y node to dd, dd cannot be source.
const int checked = ccv_nnc_exec_dep_check(exec_deps, dd, ddd);
const int checked = ccv_nnc_exec_dep_check(exec_deps, dd, 0, ddd);
if (checked)
flag = 1;
}
Expand Down

0 comments on commit ce93a48

Please sign in to comment.