-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[onert-micro] Introduce training configure tool (#13593)
This pr supports training configure tool in onert-micro. ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
- Loading branch information
1 parent
aa86e3d
commit 1a40ee3
Showing
14 changed files
with
1,638 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
message(STATUS "START Training Config Tool") | ||
|
||
add_definitions(-DOM_MEMORY_ESTIMATE) | ||
|
||
set(TRAIN_CONFIG_TOOL_SRC | ||
TrainingConfigureTool.cpp | ||
src/SparseBackpropagationHandler.cpp | ||
src/TensorRankSparseBackpropagationHandler.cpp | ||
src/TrainingConfigureFileHandler.cpp | ||
src/TrainingDriverHandler.cpp | ||
src/SparseBackpropagationHelper.cpp) | ||
|
||
add_executable(train_config_tool ${TRAIN_CONFIG_TOOL_SRC}) | ||
|
||
# This variable is needed to separate standalone interpreter libraries from the libraries used in tool | ||
set(CUSTOM_OM_SUFFIX "_train_config_tool") | ||
add_subdirectory(${NNAS_PROJECT_SOURCE_DIR}/onert-micro/onert-micro ${CMAKE_CURRENT_BINARY_DIR}/onert-micro) | ||
|
||
target_include_directories(train_config_tool PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/onert_micro/include") | ||
target_include_directories(train_config_tool PUBLIC "include") | ||
target_link_libraries(train_config_tool PUBLIC onert_micro_interpreter) | ||
target_include_directories(train_config_tool PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/onert_micro/include") | ||
target_link_libraries(train_config_tool PUBLIC onert_micro_training_interpreter) | ||
|
||
install(TARGETS train_config_tool DESTINATION bin) | ||
|
||
message(STATUS "DONE Training Config Tool") |
143 changes: 143 additions & 0 deletions
143
onert-micro/training-configure-tool/TrainingConfigureTool.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "include/SparseBackpropagationHandler.h" | ||
#include "include/TensorRankSparseBackpropagationHandler.h" | ||
|
||
#include "TrainingDriverHandler.h" | ||
|
||
#include <iostream> | ||
|
||
int entry(int argc, char **argv) | ||
{ | ||
if (argc != 9 and argc != 10) | ||
{ | ||
std::cerr << "Two variant of usage with and without wof file: " << argv[0] | ||
<< " <path/to/circle/model> " | ||
" optional(<path/to/wof/file>) <path/to/save/train/config/result> " | ||
"<path/to/input/train_data> " | ||
"<path/to/input/target_train_data> " | ||
"<path/to/input/test_data> " | ||
"<path/to/input/target_test_data>" | ||
"num_of_train_smpl " | ||
"num_of_test_smpl\n"; | ||
return EXIT_FAILURE; | ||
} | ||
|
||
training_configure_tool::TrainData train_data; | ||
|
||
if (argc == 10) | ||
{ | ||
train_data.circle_model_path = argv[1]; | ||
train_data.wof_file_path = argv[2]; | ||
train_data.output_tool_file_path = argv[3]; | ||
train_data.input_input_train_data_path = argv[4]; | ||
train_data.input_target_train_data_path = argv[5]; | ||
train_data.input_input_test_data_path = argv[6]; | ||
train_data.input_target_test_data_path = argv[7]; | ||
train_data.num_train_data_samples = atoi(argv[8]); | ||
train_data.num_test_data_samples = atoi(argv[9]); | ||
} | ||
else if (argc == 9) | ||
{ | ||
train_data.circle_model_path = argv[1]; | ||
train_data.output_tool_file_path = argv[2]; | ||
train_data.input_input_train_data_path = argv[3]; | ||
train_data.input_target_train_data_path = argv[4]; | ||
train_data.input_input_test_data_path = argv[5]; | ||
train_data.input_target_test_data_path = argv[6]; | ||
train_data.num_train_data_samples = atoi(argv[7]); | ||
train_data.num_test_data_samples = atoi(argv[8]); | ||
} | ||
else | ||
{ | ||
throw std::runtime_error("Unknown commands number\n"); | ||
} | ||
|
||
// Configure training mode | ||
onert_micro::OMConfig config; | ||
|
||
// Set user defined training settings | ||
const uint32_t training_epochs = 25; | ||
const float lambda = 0.001f; | ||
const uint32_t BATCH_SIZE = 64; | ||
const uint32_t num_train_layers = 0; | ||
const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY; | ||
const onert_micro::OMTrainOptimizer train_optimizer = onert_micro::ADAM; | ||
const float beta = 0.9; | ||
const float beta_squares = 0.999; | ||
const float epsilon = 1e-07; | ||
|
||
config.train_mode = true; | ||
{ | ||
onert_micro::OMTrainingContext train_context; | ||
train_context.batch_size = BATCH_SIZE; | ||
train_context.num_of_train_layers = num_train_layers; | ||
train_context.learning_rate = lambda; | ||
train_context.loss = loss; | ||
train_context.optimizer = train_optimizer; | ||
train_context.beta = beta; | ||
train_context.beta_squares = beta_squares; | ||
train_context.epsilon = epsilon; | ||
train_context.epochs = training_epochs; | ||
|
||
config.training_context = train_context; | ||
} | ||
|
||
train_data.metrics_to_check_best_config = onert_micro::CROSS_ENTROPY_METRICS; | ||
train_data.memory_above_restriction = 300000; | ||
train_data.acceptable_diff = 0.02; | ||
// Find sparse backpropagation best configure | ||
std::unordered_set<uint16_t> best_trainable_op_indexes; | ||
training_configure_tool::findBestTrainableOpIndexes(config, train_data, | ||
best_trainable_op_indexes); | ||
|
||
// Find the best train tensors ranks | ||
training_configure_tool::TrainConfigFileData config_result; | ||
auto res = training_configure_tool::findBestSparseBackpropagationTensorsRanks( | ||
config, train_data, best_trainable_op_indexes, config_result.trainable_op_indexes_with_ranks); | ||
|
||
// Save result into file | ||
assert(!config_result.trainable_op_indexes_with_ranks.empty()); | ||
training_configure_tool::createResultFile(config_result, train_data.output_tool_file_path); | ||
|
||
return EXIT_SUCCESS; | ||
} | ||
|
||
int entry(int argc, char **argv); | ||
|
||
#ifdef NDEBUG | ||
int main(int argc, char **argv) | ||
{ | ||
try | ||
{ | ||
return entry(argc, argv); | ||
} | ||
catch (const std::exception &e) | ||
{ | ||
std::cerr << "ERROR: " << e.what() << std::endl; | ||
} | ||
|
||
return 255; | ||
} | ||
#else // NDEBUG | ||
int main(int argc, char **argv) | ||
{ | ||
// NOTE main does not catch internal exceptions for debug build to make it easy to | ||
// check the stacktrace with a debugger | ||
return entry(argc, argv); | ||
} | ||
#endif // !NDEBUG |
39 changes: 39 additions & 0 deletions
39
onert-micro/training-configure-tool/include/SparseBackpropagationHandler.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef ONERT_MICRO_TRAINING_CONFIG_TOOL_SPARSE_BACKPROPAGATION_HANDLER | ||
#define ONERT_MICRO_TRAINING_CONFIG_TOOL_SPARSE_BACKPROPAGATION_HANDLER | ||
|
||
#include "OMStatus.h" | ||
#include "OMConfig.h" | ||
#include "TrainConfigData.h" | ||
#include "TrainingConfigureFileHandler.h" | ||
|
||
#include <vector> | ||
|
||
namespace training_configure_tool | ||
{ | ||
|
||
/* | ||
* Method to find the most trainable (which gets the best metric result) operators indexes. | ||
*/ | ||
onert_micro::OMStatus | ||
findBestTrainableOpIndexes(onert_micro::OMConfig &config, TrainData &train_data, | ||
std::unordered_set<uint16_t> &best_trainable_op_indexes); | ||
|
||
} // namespace training_configure_tool | ||
|
||
#endif // ONERT_MICRO_TRAINING_CONFIG_TOOL_SPARSE_BACKPROPAGATION_HANDLER |
61 changes: 61 additions & 0 deletions
61
onert-micro/training-configure-tool/include/SparseBackpropagationHelper.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef ONERT_MICRO_TRAINING_CONFIG_TOOL_SPARSE_BACKPROPAGATION_HELPER | ||
#define ONERT_MICRO_TRAINING_CONFIG_TOOL_SPARSE_BACKPROPAGATION_HELPER | ||
|
||
#include "OMStatus.h" | ||
#include "OMConfig.h" | ||
#include "TrainConfigData.h" | ||
|
||
#include <vector> | ||
#include <unordered_set> | ||
|
||
namespace training_configure_tool | ||
{ | ||
|
||
// Find is left train result is better then right in terms of metric result and memory consumptions. | ||
// acceptable_diff - acceptable difference in metric values in order to select the best result in | ||
// memory. | ||
bool cmpTrainResults(const training_configure_tool::TrainResult &left, | ||
const training_configure_tool::TrainResult &right, | ||
const float acceptable_diff); | ||
|
||
// To find all trainable ops indexes in the model - initial_train_op_indexes | ||
std::unordered_set<uint16_t> findAllTrainableOps(const char *circle_model_path); | ||
|
||
// To generate all possible sets from initial_train_op_indexes | ||
std::vector<std::unordered_set<uint16_t>> | ||
generateAllPossibleOpIndexesSets(const std::unordered_set<uint16_t> &initial_train_op_indexes); | ||
|
||
// Remove operations indexes sets with peak memory footprint greater then given restriction: | ||
// 1 - Run train interpreter with all this sets with single train sample and single test sample | ||
// to obtain approximately peak memory footprint for each set. | ||
// 2 - Cut according to max peak memory. | ||
std::vector<std::unordered_set<uint16_t>> selectOpIndexesSetsAccordingToMemoryRestriction( | ||
const std::vector<std::unordered_set<uint16_t>> &op_indexes_sets, onert_micro::OMConfig config, | ||
training_configure_tool::TrainData train_data); | ||
|
||
// Find All combinations with ranks for current selected op indexes. | ||
// Return vector of all possible combinations of train rank for every op. | ||
std::vector<std::unordered_map<uint16_t, OpTrainableRank>> | ||
findAllTensorsRanksCombinations(const std::unordered_set<uint16_t> &selected_op_indexes, | ||
onert_micro::OMConfig config, | ||
training_configure_tool::TrainData train_data); | ||
|
||
} // namespace training_configure_tool | ||
|
||
#endif // ONERT_MICRO_TRAINING_CONFIG_TOOL_SPARSE_BACKPROPAGATION_HELPER |
44 changes: 44 additions & 0 deletions
44
onert-micro/training-configure-tool/include/TensorRankSparseBackpropagationHandler.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef ONERT_MICRO_TRAINING_CONFIG_TOOL_TENSOR_RANK_SPARSE_BACKPROPAGATION_HANDLER | ||
#define ONERT_MICRO_TRAINING_CONFIG_TOOL_TENSOR_RANK_SPARSE_BACKPROPAGATION_HANDLER | ||
|
||
#include "OMStatus.h" | ||
#include "OMConfig.h" | ||
#include "TrainConfigData.h" | ||
#include "TrainingConfigureFileHandler.h" | ||
|
||
#include <vector> | ||
#include <unordered_map> | ||
|
||
namespace training_configure_tool | ||
{ | ||
|
||
/* | ||
* Method to find the most trainable (which gets the best metric result and less peak memory) train | ||
* ranks for every operation in selected operators indexes. Note: Train rank - this is an indicator | ||
* of how much data of the current operation we will train (for example, the entire operation, only | ||
* the bias, only the upper half, and so on) | ||
*/ | ||
onert_micro::OMStatus findBestSparseBackpropagationTensorsRanks( | ||
onert_micro::OMConfig &config, TrainData &train_data, | ||
const std::unordered_set<uint16_t> &selected_op_indexes, | ||
std::unordered_map<uint16_t, OpTrainableRank> &best_train_ranks); | ||
|
||
} // namespace training_configure_tool | ||
|
||
#endif // ONERT_MICRO_TRAINING_CONFIG_TOOL_TENSOR_RANK_SPARSE_BACKPROPAGATION_HANDLER |
Oops, something went wrong.