Skip to content

Commit

Permalink
[onert/frontend] Add circle training info loader (#12400)
Browse files Browse the repository at this point in the history
This PR adds circle training information loader which creates a TrainingInfo from a buffer based on the flatbuffer schema.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
  • Loading branch information
zetwhite authored Jan 5, 2024
1 parent 7ad4079 commit f165eaf
Show file tree
Hide file tree
Showing 5 changed files with 1,051 additions and 0 deletions.
15 changes: 15 additions & 0 deletions runtime/onert/frontend/circle_traininfo/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
if(NOT BUILD_CIRCLE_LOADER)
return()
endif()

nnfw_find_package(FlatBuffers REQUIRED)

set(TRAININFO_SOURCES src/traininfo_loader.cc)

add_library(traininfo_loader STATIC ${TRAININFO_SOURCES})
set_target_properties(traininfo_loader PROPERTIES POSITION_INDEPENDENT_CODE ON)

target_include_directories(traininfo_loader PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)

target_link_libraries(traininfo_loader PRIVATE onert_core)
target_link_libraries(traininfo_loader PRIVATE flatbuffers::flatbuffers)
119 changes: 119 additions & 0 deletions runtime/onert/frontend/circle_traininfo/circle_traininfo.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// This file is from https://github.sec.samsung.net/one-project/circle-x/blob/91fc9550d914a3c62bf41e62e007e1a2f5f37b59/spec/circle_training.fbs

// Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

namespace circle;

// file_identifier for circle training.
// The last two characters is identifier version.
// This identifier number should be same if it does not break compatibility.
file_identifier "CTR1";

// Revision History
//
// Version Major.Minor
//
// Major version is schema version.
// We keep schema version if it is compatible.
// Minor version is for human communication.
// It will not be stored in circle_training.
//
// Note: The scheme version is bumped up as it gets fields
// while identifier version is not changed unless compatibility is broken.
//
// Version 0.1: Initial version

// File extension of any written files.
file_extension "circletr";

// --------
// Optimizer: It defines fields about optimizer for training.
//
// It uses the well-known names for optimizer and its parameters.
// If something is not clear, please refer to
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers
// --------

enum Optimizer : byte {
SGD = 0,
ADAM = 1,
}

union OptimizerOptions {
SGDOptions,
AdamOptions,
}

table SGDOptions {
learning_rate:float;
}

table AdamOptions {
learning_rate:float;
beta_1:float;
beta_2:float;
epsilon:float;
}

// --------
// Loss Function: It is about loss function used during training.
//
// It uses the well-known names for loss function and its parameters.
// If something is not clear, please refer to
// https://www.tensorflow.org/api_docs/python/tf/keras/losses
// --------

enum LossFn : byte {
SPARSE_CATEGORICAL_CROSSENTROPY = 0,
CATEGORICAL_CROSSENTROPY = 1,
MEAN_SQUARED_ERROR = 2,
}

union LossFnOptions {
SparseCategoricalCrossentropyOptions,
CategoricalCrossentropyOptions,
MeanSquaredErrorOptions,
}

table SparseCategoricalCrossentropyOptions {
from_logits: bool;
}

table CategoricalCrossentropyOptions {
from_logits: bool;
}

table MeanSquaredErrorOptions {
}

// --------
// Model Metadata
//
// --------

table ModelTraining {
// Version of the schema.
version:uint;
// For training
optimizer: Optimizer;
optimizer_opt: OptimizerOptions;
lossfn: LossFn;
lossfn_opt: LossFnOptions;
epochs: int;
batch_size: int;
}

root_type ModelTraining;
38 changes: 38 additions & 0 deletions runtime/onert/frontend/circle_traininfo/include/traininfo_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __CIRCLE_TRAININFO_LOADER_H__
#define __CIRCLE_TRAININFO_LOADER_H__

#include "ir/train/TrainingInfo.h"
#include "ir/Model.h"

namespace onert
{
namespace train
{
namespace traininfo_loader
{

static constexpr char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING";

std::unique_ptr<ir::train::TrainingInfo> loadTrainingInfo(const uint8_t *buffer, const size_t size);

} // namespace traininfo_loader
} // namespace train
} // namespace onert

#endif // __CIRCLE_TRAININFO_LOADER_H__
Loading

0 comments on commit f165eaf

Please sign in to comment.