diff --git a/lib/nnc/_ccv_cnnp_model.h b/lib/nnc/_ccv_cnnp_model.h index d789ad161..f149be5f9 100644 --- a/lib/nnc/_ccv_cnnp_model.h +++ b/lib/nnc/_ccv_cnnp_model.h @@ -291,6 +291,7 @@ static inline ccv_nnc_tensor_symbol_t ccv_cnnp_parameter_from_indice(ccv_cnnp_mo } typedef struct { + int record; ccv_array_t* tensor_symbols; void* old_tensor_symbol_new_hook_context; ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook; @@ -301,7 +302,8 @@ typedef struct { static void _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name) { ccv_cnnp_model_gradient_checkpoint_build_context_t* const build_context = (ccv_cnnp_model_gradient_checkpoint_build_context_t*)context; - ccv_array_push(build_context->tensor_symbols, &symbol); + if (build_context->record) + ccv_array_push(build_context->tensor_symbols, &symbol); if (build_context->old_tensor_symbol_new_hook) build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name); } @@ -309,7 +311,8 @@ static void _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_new_hook(void* con static void _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name) { ccv_cnnp_model_gradient_checkpoint_build_context_t* const build_context = (ccv_cnnp_model_gradient_checkpoint_build_context_t*)context; - ccv_array_push(build_context->tensor_symbols, &symbol); + if (build_context->record) + ccv_array_push(build_context->tensor_symbols, &symbol); if (build_context->old_tensor_symbol_alias_new_hook) build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); } @@ -323,13 +326,14 @@ static inline void ccv_cnnp_model_build(ccv_cnnp_model_t* const self, ccv_nnc_sy build_data->is_trainable = self->is_trainable; if (self->name && self->name[0] != '\0') ccv_cnnp_model_push(self, build_data->model_sequence); - if (self->gradient_checkpointing && !build_data->is_gradient_checkpointing) + if (self->gradient_checkpointing == 1 && !build_data->is_gradient_checkpointing) { build_data->is_gradient_checkpointing = 1; // Prepare to record gradient checkpoint. We will log the build function, inputs, what are the tensors / graph execs we created. if (!build_data->gradient_checkpoints) build_data->gradient_checkpoints = ccv_array_new(sizeof(ccv_cnnp_model_gradient_checkpoint_t), 0, 0); ccv_cnnp_model_gradient_checkpoint_build_context_t build_context = { + .record = 1, .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), }; build_context.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_tensor_symbol_new_hook, &build_context, &build_context.old_tensor_symbol_new_hook); @@ -362,6 +366,28 @@ static inline void ccv_cnnp_model_build(ccv_cnnp_model_t* const self, ccv_nnc_sy ccv_array_push(build_data->gradient_checkpoints, &checkpoint); build_data->is_gradient_checkpointing = 0; } else { + // If we want to disable gradient checkpointing for this model, we simply not log any new tensors created here so there is no mapping for these. + int old_record; + ccv_cnnp_model_gradient_checkpoint_build_context_t* build_context = 0; + if (build_data->is_gradient_checkpointing) + { + if (self->gradient_checkpointing == -1) + { + ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook; + build_context = ccv_nnc_tensor_symbol_new_hook(graph, 0, 0, &old_tensor_symbol_new_hook); + // Set back the build_context. + ccv_nnc_tensor_symbol_new_hook(graph, old_tensor_symbol_new_hook, build_context, 0); + old_record = build_context->record; + build_context->record = 0; + } else if (self->gradient_checkpointing == 1) { // Force to turn on gradient checkpointing if it is inside a gradient checkpointing = -1. + ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook; + build_context = ccv_nnc_tensor_symbol_new_hook(graph, 0, 0, &old_tensor_symbol_new_hook); + // Set back the build_context. + ccv_nnc_tensor_symbol_new_hook(graph, old_tensor_symbol_new_hook, build_context, 0); + old_record = build_context->record; + build_context->record = 1; + } + } // No push checkpoint, easy. if (outputs && output_size) { @@ -370,6 +396,8 @@ static inline void ccv_cnnp_model_build(ccv_cnnp_model_t* const self, ccv_nnc_sy memcpy(self->outputs, outputs, sizeof(ccv_nnc_tensor_symbol_t) * output_size); } else self->isa->build(self, graph, inputs, input_size, self->outputs, self->output_size); + if (build_context) // Restore previous state even if our gradient checkpointing controlled whether to turn on recording or not. + build_context->record = old_record; } // Skip if there is none. This helps to load parameters to a different model when only changes non-parameterized settings (add reshapes, permutations etc). // If it is named, we have to push too. diff --git a/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c b/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c index 20351e998..e753a0328 100644 --- a/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c +++ b/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c @@ -36,30 +36,31 @@ typedef struct { } ccv_nnc_graph_exec_symbol_reverse_t; typedef struct { - ccv_array_t* tensor_symbols; - void* old_tensor_symbol_new_hook_context; - ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook; - void* old_tensor_symbol_alias_new_hook_context; - ccv_nnc_tensor_symbol_alias_new_hook_f old_tensor_symbol_alias_new_hook; + ccv_cnnp_model_gradient_checkpoint_build_context_t tensor_context; ccv_array_t* graph_exec_symbols; ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook; void* old_graph_exec_symbol_new_hook_context; + ccv_array_t* all_tensor_symbols; } ccv_cnnp_gradient_checkpoint_build_t; static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name) { ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; - ccv_array_push(build_context->tensor_symbols, &symbol); - if (build_context->old_tensor_symbol_new_hook) - build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name); + if (build_context->tensor_context.record) + ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol); + ccv_array_push(build_context->all_tensor_symbols, &symbol); + if (build_context->tensor_context.old_tensor_symbol_new_hook) + build_context->tensor_context.old_tensor_symbol_new_hook(build_context->tensor_context.old_tensor_symbol_new_hook_context, symbol, info, name); } static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name) { ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; - ccv_array_push(build_context->tensor_symbols, &symbol); - if (build_context->old_tensor_symbol_alias_new_hook) - build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); + if (build_context->tensor_context.record) + ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol); + ccv_array_push(build_context->all_tensor_symbols, &symbol); + if (build_context->tensor_context.old_tensor_symbol_alias_new_hook) + build_context->tensor_context.old_tensor_symbol_alias_new_hook(build_context->tensor_context.old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); } static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name) @@ -289,11 +290,15 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c ccv_array_add_unique_int(visited_backward_execs, ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, j))->d); #undef visitor ccv_cnnp_gradient_checkpoint_build_t build = { - .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), + .tensor_context = { + .record = 1, + .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), + }, .graph_exec_symbols = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0), + .all_tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), }; - build.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.old_tensor_symbol_new_hook); - build.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.old_tensor_symbol_alias_new_hook); + build.tensor_context.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.tensor_context.old_tensor_symbol_new_hook); + build.tensor_context.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.tensor_context.old_tensor_symbol_alias_new_hook); build.old_graph_exec_symbol_new_hook_context = ccv_nnc_graph_exec_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook, &build, &build.old_graph_exec_symbol_new_hook); ccv_array_clear(parameters); ccv_array_clear(parameter_ids); @@ -335,8 +340,8 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c kh_destroy(ccv_cnnp_model_name_bank, model_sequence.bank); if (model_sequence.sequences) ccv_array_free(model_sequence.sequences); - ccv_nnc_tensor_symbol_new_hook(graph, build.old_tensor_symbol_new_hook, build.old_tensor_symbol_new_hook_context, 0); - ccv_nnc_tensor_symbol_alias_new_hook(graph, build.old_tensor_symbol_alias_new_hook, build.old_tensor_symbol_alias_new_hook_context, 0); + ccv_nnc_tensor_symbol_new_hook(graph, build.tensor_context.old_tensor_symbol_new_hook, build.tensor_context.old_tensor_symbol_new_hook_context, 0); + ccv_nnc_tensor_symbol_alias_new_hook(graph, build.tensor_context.old_tensor_symbol_alias_new_hook, build.tensor_context.old_tensor_symbol_alias_new_hook_context, 0); ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0); for (j = 0; j < parameter_ids->rnum; j++) ccfree(*(char**)ccv_array_get(parameter_ids, j)); @@ -346,9 +351,9 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); // Reuse existing one. kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols); - for (j = 0; j < build.tensor_symbols->rnum; j++) + for (j = 0; j < build.tensor_context.tensor_symbols->rnum; j++) { - const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d; + const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d; if (idx < 0) continue; if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals)) @@ -442,10 +447,10 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c ccv_array_push(newly_input_execs, &symbol); } // Build a map between old tensor symbols and new tensor symbols. - assert(build.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum); + assert(build.tensor_context.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum); // Build a map to potentially map from old input to new input. kh_clear(ccv_cnnp_tensor_symbol_map, symbol_map); - for (j = 0, k = 0; j < build.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum;) + for (j = 0, k = 0; j < build.tensor_context.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum;) { const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d; if (from_d < 0) // This is removed, move to the next one. @@ -454,7 +459,7 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c ++k; continue; } - const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d; + const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d; assert(to_d >= 0); int from_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, from_d) != kh_end(parameters_or_internals); int to_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, to_d) != kh_end(parameters_or_internals); @@ -746,9 +751,9 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c } } // Free unused tensor symbols. - for (j = 0; j < build.tensor_symbols->rnum; j++) + for (j = 0; j < build.all_tensor_symbols->rnum; j++) { - const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j)); + const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.all_tensor_symbols, j)); if (ccv_array_contain_int(newly_used_outputs, symbol->d) || ccv_array_contain_int(forward_pass_inputs, symbol->d)) continue; if (tensor_symbol_info[symbol->d].alias_ref > 0) @@ -769,8 +774,9 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c ccv_nnc_graph_exec_symbol_set_flags(graph, *symbol, CCV_NNC_GRAPH_EXEC_DISABLE_OPT); } // Free these newly created execs and tensor symbols. - ccv_array_free(build.tensor_symbols); + ccv_array_free(build.tensor_context.tensor_symbols); ccv_array_free(build.graph_exec_symbols); + ccv_array_free(build.all_tensor_symbols); } kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map); kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);