Skip to content

Commit

Permalink
Allow to disable gradient checkpointing inside a gradient checkpointing node.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 21, 2024
1 parent b361b30 commit 20d998d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 27 deletions.
34 changes: 31 additions & 3 deletions lib/nnc/_ccv_cnnp_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ static inline ccv_nnc_tensor_symbol_t ccv_cnnp_parameter_from_indice(ccv_cnnp_mo
}

typedef struct {
int record;
ccv_array_t* tensor_symbols;
void* old_tensor_symbol_new_hook_context;
ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook;
Expand All @@ -301,15 +302,17 @@ typedef struct {
// Tensor-symbol hook that logs symbols into a gradient checkpoint build context and then
// chains to the hook that was installed before it.
//   context: opaque pointer, cast to ccv_cnnp_model_gradient_checkpoint_build_context_t*.
//   symbol:  the tensor symbol reported to the hook; appended to build_context->tensor_symbols.
//   info, name: forwarded unchanged to the previous hook.
static void _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name)
{
ccv_cnnp_model_gradient_checkpoint_build_context_t* const build_context = (ccv_cnnp_model_gradient_checkpoint_build_context_t*)context;
// NOTE(review): this listing appears to be a rendered diff without +/- markers. The
// unconditional push below is presumably the pre-change line and the record-guarded push
// its replacement; read literally, the symbol is pushed twice when record is set. Confirm
// against the repository.
ccv_array_push(build_context->tensor_symbols, &symbol);
// Log the symbol only while recording is enabled on this context.
if (build_context->record)
ccv_array_push(build_context->tensor_symbols, &symbol);
// Forward to the previously installed hook, if any, so it still observes the symbol.
if (build_context->old_tensor_symbol_new_hook)
build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name);
}

// Alias-symbol counterpart of the tensor-symbol hook: logs the alias symbol into the gradient
// checkpoint build context and then chains to the previously installed alias hook.
//   context: opaque pointer, cast to ccv_cnnp_model_gradient_checkpoint_build_context_t*.
//   symbol:  the alias symbol reported to the hook; appended to build_context->tensor_symbols.
//   from_symbol, ofs, inc, info, name: forwarded unchanged to the previous hook.
static void _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name)
{
ccv_cnnp_model_gradient_checkpoint_build_context_t* const build_context = (ccv_cnnp_model_gradient_checkpoint_build_context_t*)context;
// NOTE(review): this listing appears to be a rendered diff without +/- markers. The
// unconditional push below is presumably the pre-change line and the record-guarded push
// its replacement; read literally, the symbol is pushed twice when record is set. Confirm
// against the repository.
ccv_array_push(build_context->tensor_symbols, &symbol);
// Log the alias only while recording is enabled on this context.
if (build_context->record)
ccv_array_push(build_context->tensor_symbols, &symbol);
// Forward to the previously installed alias hook, if any.
if (build_context->old_tensor_symbol_alias_new_hook)
build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name);
}
Expand All @@ -323,13 +326,14 @@ static inline void ccv_cnnp_model_build(ccv_cnnp_model_t* const self, ccv_nnc_sy
build_data->is_trainable = self->is_trainable;
if (self->name && self->name[0] != '\0')
ccv_cnnp_model_push(self, build_data->model_sequence);
if (self->gradient_checkpointing && !build_data->is_gradient_checkpointing)
if (self->gradient_checkpointing == 1 && !build_data->is_gradient_checkpointing)
{
build_data->is_gradient_checkpointing = 1;
// Prepare to record the gradient checkpoint. We will log the build function, the inputs, and which tensors / graph execs we created.
if (!build_data->gradient_checkpoints)
build_data->gradient_checkpoints = ccv_array_new(sizeof(ccv_cnnp_model_gradient_checkpoint_t), 0, 0);
ccv_cnnp_model_gradient_checkpoint_build_context_t build_context = {
.record = 1,
.tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
};
build_context.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_new_hook, &build_context, &build_context.old_tensor_symbol_new_hook);
Expand Down Expand Up @@ -362,6 +366,28 @@ static inline void ccv_cnnp_model_build(ccv_cnnp_model_t* const self, ccv_nnc_sy
ccv_array_push(build_data->gradient_checkpoints, &checkpoint);
build_data->is_gradient_checkpointing = 0;
} else {
// If we want to disable gradient checkpointing for this model, we simply do not log any new tensors created here, so there is no mapping for them.
int old_record;
ccv_cnnp_model_gradient_checkpoint_build_context_t* build_context = 0;
if (build_data->is_gradient_checkpointing)
{
if (self->gradient_checkpointing == -1)
{
ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook;
build_context = ccv_nnc_tensor_symbol_new_hook(graph, 0, 0, &old_tensor_symbol_new_hook);
// Set back the build_context.
ccv_nnc_tensor_symbol_new_hook(graph, old_tensor_symbol_new_hook, build_context, 0);
old_record = build_context->record;
build_context->record = 0;
} else if (self->gradient_checkpointing == 1) { // Force to turn on gradient checkpointing if it is inside a gradient checkpointing = -1.
ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook;
build_context = ccv_nnc_tensor_symbol_new_hook(graph, 0, 0, &old_tensor_symbol_new_hook);
// Set back the build_context.
ccv_nnc_tensor_symbol_new_hook(graph, old_tensor_symbol_new_hook, build_context, 0);
old_record = build_context->record;
build_context->record = 1;
}
}
// No push checkpoint, easy.
if (outputs && output_size)
{
Expand All @@ -370,6 +396,8 @@ static inline void ccv_cnnp_model_build(ccv_cnnp_model_t* const self, ccv_nnc_sy
memcpy(self->outputs, outputs, sizeof(ccv_nnc_tensor_symbol_t) * output_size);
} else
self->isa->build(self, graph, inputs, input_size, self->outputs, self->output_size);
if (build_context) // Restore previous state even if our gradient checkpointing controlled whether to turn on recording or not.
build_context->record = old_record;
}
// Skip if there is none. This helps to load parameters into a different model when it only changes non-parameterized settings (adds reshapes, permutations etc).
// If it is named, we have to push too.
Expand Down
54 changes: 30 additions & 24 deletions lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,31 @@ typedef struct {
} ccv_nnc_graph_exec_symbol_reverse_t;

// State shared by the hooks installed while gradient checkpoints are applied: it collects the
// tensor symbols and graph exec symbols created during that time and remembers the previously
// installed hooks so they can be chained to and restored.
typedef struct {
// NOTE(review): this listing appears to be a rendered diff without +/- markers. The five
// flat fields below are presumably the pre-change layout, superseded by the embedded
// tensor_context member that carries the same fields plus a record flag. Confirm against
// the repository.
ccv_array_t* tensor_symbols;
void* old_tensor_symbol_new_hook_context;
ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook;
void* old_tensor_symbol_alias_new_hook_context;
ccv_nnc_tensor_symbol_alias_new_hook_f old_tensor_symbol_alias_new_hook;
// Recorded tensor symbols (only pushed while tensor_context.record is set) and the saved
// previous tensor-symbol / alias hooks.
ccv_cnnp_model_gradient_checkpoint_build_context_t tensor_context;
// Graph exec symbols collected while the hooks are installed.
ccv_array_t* graph_exec_symbols;
// Previously installed graph-exec-symbol hook and its context.
ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook;
void* old_graph_exec_symbol_new_hook_context;
// Every tensor symbol seen by the hooks, regardless of the record flag.
ccv_array_t* all_tensor_symbols;
} ccv_cnnp_gradient_checkpoint_build_t;

// Tensor-symbol hook installed while gradient checkpoints are applied. It logs the symbol
// into the build state and then chains to the previously installed hook.
//   context: opaque pointer, cast to ccv_cnnp_gradient_checkpoint_build_t*.
//   symbol:  the tensor symbol reported to the hook.
//   info, name: forwarded unchanged to the previous hook.
static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name)
{
ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context;
// NOTE(review): this listing appears to be a rendered diff without +/- markers. The next
// three lines, which use the flat build_context fields, are presumably the pre-change
// version of the tensor_context-based lines that follow. Confirm against the repository.
ccv_array_push(build_context->tensor_symbols, &symbol);
if (build_context->old_tensor_symbol_new_hook)
build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name);
// Log into the recorded list only while recording is enabled.
if (build_context->tensor_context.record)
ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol);
// Always track the symbol in all_tensor_symbols, whether or not it was recorded.
ccv_array_push(build_context->all_tensor_symbols, &symbol);
// Forward to the previously installed hook, if any.
if (build_context->tensor_context.old_tensor_symbol_new_hook)
build_context->tensor_context.old_tensor_symbol_new_hook(build_context->tensor_context.old_tensor_symbol_new_hook_context, symbol, info, name);
}

// Alias-symbol hook installed while gradient checkpoints are applied. It logs the alias
// symbol into the build state and then chains to the previously installed alias hook.
//   context: opaque pointer, cast to ccv_cnnp_gradient_checkpoint_build_t*.
//   symbol:  the alias symbol reported to the hook.
//   from_symbol, ofs, inc, info, name: forwarded unchanged to the previous hook.
static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name)
{
ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context;
// NOTE(review): this listing appears to be a rendered diff without +/- markers. The next
// three lines, which use the flat build_context fields, are presumably the pre-change
// version of the tensor_context-based lines that follow. Confirm against the repository.
ccv_array_push(build_context->tensor_symbols, &symbol);
if (build_context->old_tensor_symbol_alias_new_hook)
build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name);
// Log into the recorded list only while recording is enabled.
if (build_context->tensor_context.record)
ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol);
// Always track the alias in all_tensor_symbols, whether or not it was recorded.
ccv_array_push(build_context->all_tensor_symbols, &symbol);
// Forward to the previously installed alias hook, if any.
if (build_context->tensor_context.old_tensor_symbol_alias_new_hook)
build_context->tensor_context.old_tensor_symbol_alias_new_hook(build_context->tensor_context.old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name);
}

static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name)
Expand Down Expand Up @@ -289,11 +290,15 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
ccv_array_add_unique_int(visited_backward_execs, ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, j))->d);
#undef visitor
ccv_cnnp_gradient_checkpoint_build_t build = {
.tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
.tensor_context = {
.record = 1,
.tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
},
.graph_exec_symbols = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0),
.all_tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
};
build.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.old_tensor_symbol_new_hook);
build.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.old_tensor_symbol_alias_new_hook);
build.tensor_context.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.tensor_context.old_tensor_symbol_new_hook);
build.tensor_context.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.tensor_context.old_tensor_symbol_alias_new_hook);
build.old_graph_exec_symbol_new_hook_context = ccv_nnc_graph_exec_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook, &build, &build.old_graph_exec_symbol_new_hook);
ccv_array_clear(parameters);
ccv_array_clear(parameter_ids);
Expand Down Expand Up @@ -335,8 +340,8 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
kh_destroy(ccv_cnnp_model_name_bank, model_sequence.bank);
if (model_sequence.sequences)
ccv_array_free(model_sequence.sequences);
ccv_nnc_tensor_symbol_new_hook(graph, build.old_tensor_symbol_new_hook, build.old_tensor_symbol_new_hook_context, 0);
ccv_nnc_tensor_symbol_alias_new_hook(graph, build.old_tensor_symbol_alias_new_hook, build.old_tensor_symbol_alias_new_hook_context, 0);
ccv_nnc_tensor_symbol_new_hook(graph, build.tensor_context.old_tensor_symbol_new_hook, build.tensor_context.old_tensor_symbol_new_hook_context, 0);
ccv_nnc_tensor_symbol_alias_new_hook(graph, build.tensor_context.old_tensor_symbol_alias_new_hook, build.tensor_context.old_tensor_symbol_alias_new_hook_context, 0);
ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0);
for (j = 0; j < parameter_ids->rnum; j++)
ccfree(*(char**)ccv_array_get(parameter_ids, j));
Expand All @@ -346,9 +351,9 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
// Reuse existing one.
kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
for (j = 0; j < build.tensor_symbols->rnum; j++)
for (j = 0; j < build.tensor_context.tensor_symbols->rnum; j++)
{
const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d;
const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d;
if (idx < 0)
continue;
if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals))
Expand Down Expand Up @@ -442,10 +447,10 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
ccv_array_push(newly_input_execs, &symbol);
}
// Build a map between old tensor symbols and new tensor symbols.
assert(build.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum);
assert(build.tensor_context.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum);
// Build a map to potentially map from old input to new input.
kh_clear(ccv_cnnp_tensor_symbol_map, symbol_map);
for (j = 0, k = 0; j < build.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum;)
for (j = 0, k = 0; j < build.tensor_context.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum;)
{
const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d;
if (from_d < 0) // This is removed, move to the next one.
Expand All @@ -454,7 +459,7 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
++k;
continue;
}
const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d;
const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d;
assert(to_d >= 0);
int from_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, from_d) != kh_end(parameters_or_internals);
int to_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, to_d) != kh_end(parameters_or_internals);
Expand Down Expand Up @@ -746,9 +751,9 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
}
}
// Free unused tensor symbols.
for (j = 0; j < build.tensor_symbols->rnum; j++)
for (j = 0; j < build.all_tensor_symbols->rnum; j++)
{
const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j));
const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.all_tensor_symbols, j));
if (ccv_array_contain_int(newly_used_outputs, symbol->d) || ccv_array_contain_int(forward_pass_inputs, symbol->d))
continue;
if (tensor_symbol_info[symbol->d].alias_ref > 0)
Expand All @@ -769,8 +774,9 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
ccv_nnc_graph_exec_symbol_set_flags(graph, *symbol, CCV_NNC_GRAPH_EXEC_DISABLE_OPT);
}
// Free these newly created execs and tensor symbols.
ccv_array_free(build.tensor_symbols);
ccv_array_free(build.tensor_context.tensor_symbols);
ccv_array_free(build.graph_exec_symbols);
ccv_array_free(build.all_tensor_symbols);
}
kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map);
kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
Expand Down

0 comments on commit 20d998d

Please sign in to comment.