diff --git a/runtime/onert/frontend/circle_traininfo/CMakeLists.txt b/runtime/onert/frontend/circle_traininfo/CMakeLists.txt
new file mode 100644
index 00000000000..3af8566b9a4
--- /dev/null
+++ b/runtime/onert/frontend/circle_traininfo/CMakeLists.txt
@@ -0,0 +1,15 @@
+if(NOT BUILD_CIRCLE_LOADER)
+  return()
+endif()
+
+nnfw_find_package(FlatBuffers REQUIRED)
+
+set(TRAININFO_SOURCES src/traininfo_loader.cc)
+
+add_library(traininfo_loader STATIC ${TRAININFO_SOURCES})
+set_target_properties(traininfo_loader PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+target_include_directories(traininfo_loader PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
+
+target_link_libraries(traininfo_loader PRIVATE onert_core)
+target_link_libraries(traininfo_loader PRIVATE flatbuffers::flatbuffers)
diff --git a/runtime/onert/frontend/circle_traininfo/circle_traininfo.fbs b/runtime/onert/frontend/circle_traininfo/circle_traininfo.fbs
new file mode 100644
index 00000000000..9e919776e57
--- /dev/null
+++ b/runtime/onert/frontend/circle_traininfo/circle_traininfo.fbs
@@ -0,0 +1,119 @@
+// This file is from https://github.sec.samsung.net/one-project/circle-x/blob/91fc9550d914a3c62bf41e62e007e1a2f5f37b59/spec/circle_training.fbs
+
+// Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace circle;
+
+// file_identifier for circle training.
+// The last two characters are the identifier version.
+// This identifier number should be same if it does not break compatibility.
+file_identifier "CTR1";
+
+// Revision History
+//
+// Version Major.Minor
+//
+// Major version is schema version.
+// We keep schema version if it is compatible.
+// Minor version is for human communication.
+// It will not be stored in circle_training.
+//
+// Note: The schema version is bumped up as it gets fields
+// while identifier version is not changed unless compatibility is broken.
+//
+// Version 0.1: Initial version
+
+// File extension of any written files.
+file_extension "circletr";
+
+// --------
+// Optimizer: It defines fields about optimizer for training.
+//
+// It uses the well-known names for optimizer and its parameters.
+// If something is not clear, please refer to
+// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers
+// --------
+
+enum Optimizer : byte {
+  SGD = 0,
+  ADAM = 1,
+}
+
+union OptimizerOptions {
+  SGDOptions,
+  AdamOptions,
+}
+
+table SGDOptions {
+  learning_rate:float;
+}
+
+table AdamOptions {
+  learning_rate:float;
+  beta_1:float;
+  beta_2:float;
+  epsilon:float;
+}
+
+// --------
+// Loss Function: It is about loss function used during training.
+//
+// It uses the well-known names for loss function and its parameters.
+// If something is not clear, please refer to
+// https://www.tensorflow.org/api_docs/python/tf/keras/losses
+// --------
+
+enum LossFn : byte {
+  SPARSE_CATEGORICAL_CROSSENTROPY = 0,
+  CATEGORICAL_CROSSENTROPY = 1,
+  MEAN_SQUARED_ERROR = 2,
+}
+
+union LossFnOptions {
+  SparseCategoricalCrossentropyOptions,
+  CategoricalCrossentropyOptions,
+  MeanSquaredErrorOptions,
+}
+
+table SparseCategoricalCrossentropyOptions {
+  from_logits: bool;
+}
+
+table CategoricalCrossentropyOptions {
+  from_logits: bool;
+}
+
+table MeanSquaredErrorOptions {
+}
+
+// --------
+// Model Metadata
+//
+// --------
+
+table ModelTraining {
+  // Version of the schema.
+  version:uint;
+  // For training
+  optimizer: Optimizer;
+  optimizer_opt: OptimizerOptions;
+  lossfn: LossFn;
+  lossfn_opt: LossFnOptions;
+  epochs: int;
+  batch_size: int;
+}
+
+root_type ModelTraining;
diff --git a/runtime/onert/frontend/circle_traininfo/include/traininfo_loader.h b/runtime/onert/frontend/circle_traininfo/include/traininfo_loader.h
new file mode 100644
index 00000000000..88e0d010adc
--- /dev/null
+++ b/runtime/onert/frontend/circle_traininfo/include/traininfo_loader.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __CIRCLE_TRAININFO_LOADER_H__ +#define __CIRCLE_TRAININFO_LOADER_H__ + +#include "ir/train/TrainingInfo.h" +#include "ir/Model.h" + +namespace onert +{ +namespace train +{ +namespace traininfo_loader +{ + +static constexpr char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING"; + +std::unique_ptr loadTrainingInfo(const uint8_t *buffer, const size_t size); + +} // namespace traininfo_loader +} // namespace train +} // namespace onert + +#endif // __CIRCLE_TRAININFO_LOADER_H__ diff --git a/runtime/onert/frontend/circle_traininfo/src/circle_traininfo_generated.h b/runtime/onert/frontend/circle_traininfo/src/circle_traininfo_generated.h new file mode 100644 index 00000000000..8da056c453d --- /dev/null +++ b/runtime/onert/frontend/circle_traininfo/src/circle_traininfo_generated.h @@ -0,0 +1,771 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// automatically generated by the FlatBuffers compiler, do not modify +// command: externals/FLATBUFFER-2.0/build/flatc -c ../circle_traininfo.fbs + +#ifndef FLATBUFFERS_GENERATED_CIRCLETRAINING_CIRCLE_H_ +#define FLATBUFFERS_GENERATED_CIRCLETRAINING_CIRCLE_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace circle +{ + +struct SGDOptions; +struct SGDOptionsBuilder; + +struct AdamOptions; +struct AdamOptionsBuilder; + +struct SparseCategoricalCrossentropyOptions; +struct SparseCategoricalCrossentropyOptionsBuilder; + +struct CategoricalCrossentropyOptions; +struct CategoricalCrossentropyOptionsBuilder; + +struct MeanSquaredErrorOptions; +struct MeanSquaredErrorOptionsBuilder; + +struct ModelTraining; +struct ModelTrainingBuilder; + +enum Optimizer : int8_t +{ + Optimizer_SGD = 0, + Optimizer_ADAM = 1, + Optimizer_MIN = Optimizer_SGD, + Optimizer_MAX = Optimizer_ADAM +}; + +inline const Optimizer (&EnumValuesOptimizer())[2] +{ + static const Optimizer values[] = {Optimizer_SGD, Optimizer_ADAM}; + return values; +} + +inline const char *const *EnumNamesOptimizer() +{ + static const char *const names[3] = {"SGD", "ADAM", nullptr}; + return names; +} + +inline const char *EnumNameOptimizer(Optimizer e) +{ + if (flatbuffers::IsOutRange(e, Optimizer_SGD, Optimizer_ADAM)) + return ""; + const size_t index = static_cast(e); + return EnumNamesOptimizer()[index]; +} + +enum OptimizerOptions : uint8_t +{ + OptimizerOptions_NONE = 0, + OptimizerOptions_SGDOptions = 1, + OptimizerOptions_AdamOptions = 2, + OptimizerOptions_MIN = OptimizerOptions_NONE, + OptimizerOptions_MAX = OptimizerOptions_AdamOptions +}; + +inline const OptimizerOptions (&EnumValuesOptimizerOptions())[3] +{ + static const OptimizerOptions values[] = {OptimizerOptions_NONE, OptimizerOptions_SGDOptions, + OptimizerOptions_AdamOptions}; + return values; +} + +inline const char *const *EnumNamesOptimizerOptions() +{ + static const char *const names[4] = {"NONE", "SGDOptions", "AdamOptions", nullptr}; 
+ return names; +} + +inline const char *EnumNameOptimizerOptions(OptimizerOptions e) +{ + if (flatbuffers::IsOutRange(e, OptimizerOptions_NONE, OptimizerOptions_AdamOptions)) + return ""; + const size_t index = static_cast(e); + return EnumNamesOptimizerOptions()[index]; +} + +template struct OptimizerOptionsTraits +{ + static const OptimizerOptions enum_value = OptimizerOptions_NONE; +}; + +template <> struct OptimizerOptionsTraits +{ + static const OptimizerOptions enum_value = OptimizerOptions_SGDOptions; +}; + +template <> struct OptimizerOptionsTraits +{ + static const OptimizerOptions enum_value = OptimizerOptions_AdamOptions; +}; + +bool VerifyOptimizerOptions(flatbuffers::Verifier &verifier, const void *obj, + OptimizerOptions type); +bool VerifyOptimizerOptionsVector(flatbuffers::Verifier &verifier, + const flatbuffers::Vector> *values, + const flatbuffers::Vector *types); + +enum LossFn : int8_t +{ + LossFn_SPARSE_CATEGORICAL_CROSSENTROPY = 0, + LossFn_CATEGORICAL_CROSSENTROPY = 1, + LossFn_MEAN_SQUARED_ERROR = 2, + LossFn_MIN = LossFn_SPARSE_CATEGORICAL_CROSSENTROPY, + LossFn_MAX = LossFn_MEAN_SQUARED_ERROR +}; + +inline const LossFn (&EnumValuesLossFn())[3] +{ + static const LossFn values[] = {LossFn_SPARSE_CATEGORICAL_CROSSENTROPY, + LossFn_CATEGORICAL_CROSSENTROPY, LossFn_MEAN_SQUARED_ERROR}; + return values; +} + +inline const char *const *EnumNamesLossFn() +{ + static const char *const names[4] = {"SPARSE_CATEGORICAL_CROSSENTROPY", + "CATEGORICAL_CROSSENTROPY", "MEAN_SQUARED_ERROR", nullptr}; + return names; +} + +inline const char *EnumNameLossFn(LossFn e) +{ + if (flatbuffers::IsOutRange(e, LossFn_SPARSE_CATEGORICAL_CROSSENTROPY, LossFn_MEAN_SQUARED_ERROR)) + return ""; + const size_t index = static_cast(e); + return EnumNamesLossFn()[index]; +} + +enum LossFnOptions : uint8_t +{ + LossFnOptions_NONE = 0, + LossFnOptions_SparseCategoricalCrossentropyOptions = 1, + LossFnOptions_CategoricalCrossentropyOptions = 2, + 
LossFnOptions_MeanSquaredErrorOptions = 3, + LossFnOptions_MIN = LossFnOptions_NONE, + LossFnOptions_MAX = LossFnOptions_MeanSquaredErrorOptions +}; + +inline const LossFnOptions (&EnumValuesLossFnOptions())[4] +{ + static const LossFnOptions values[] = { + LossFnOptions_NONE, LossFnOptions_SparseCategoricalCrossentropyOptions, + LossFnOptions_CategoricalCrossentropyOptions, LossFnOptions_MeanSquaredErrorOptions}; + return values; +} + +inline const char *const *EnumNamesLossFnOptions() +{ + static const char *const names[5] = {"NONE", "SparseCategoricalCrossentropyOptions", + "CategoricalCrossentropyOptions", "MeanSquaredErrorOptions", + nullptr}; + return names; +} + +inline const char *EnumNameLossFnOptions(LossFnOptions e) +{ + if (flatbuffers::IsOutRange(e, LossFnOptions_NONE, LossFnOptions_MeanSquaredErrorOptions)) + return ""; + const size_t index = static_cast(e); + return EnumNamesLossFnOptions()[index]; +} + +template struct LossFnOptionsTraits +{ + static const LossFnOptions enum_value = LossFnOptions_NONE; +}; + +template <> struct LossFnOptionsTraits +{ + static const LossFnOptions enum_value = LossFnOptions_SparseCategoricalCrossentropyOptions; +}; + +template <> struct LossFnOptionsTraits +{ + static const LossFnOptions enum_value = LossFnOptions_CategoricalCrossentropyOptions; +}; + +template <> struct LossFnOptionsTraits +{ + static const LossFnOptions enum_value = LossFnOptions_MeanSquaredErrorOptions; +}; + +bool VerifyLossFnOptions(flatbuffers::Verifier &verifier, const void *obj, LossFnOptions type); +bool VerifyLossFnOptionsVector(flatbuffers::Verifier &verifier, + const flatbuffers::Vector> *values, + const flatbuffers::Vector *types); + +struct SGDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef SGDOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_LEARNING_RATE = 4 + }; + float learning_rate() const { return GetField(VT_LEARNING_RATE, 0.0f); } + bool 
Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_LEARNING_RATE) && + verifier.EndTable(); + } +}; + +struct SGDOptionsBuilder +{ + typedef SGDOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_learning_rate(float learning_rate) + { + fbb_.AddElement(SGDOptions::VT_LEARNING_RATE, learning_rate, 0.0f); + } + explicit SGDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSGDOptions(flatbuffers::FlatBufferBuilder &_fbb, + float learning_rate = 0.0f) +{ + SGDOptionsBuilder builder_(_fbb); + builder_.add_learning_rate(learning_rate); + return builder_.Finish(); +} + +struct AdamOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef AdamOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_LEARNING_RATE = 4, + VT_BETA_1 = 6, + VT_BETA_2 = 8, + VT_EPSILON = 10 + }; + float learning_rate() const { return GetField(VT_LEARNING_RATE, 0.0f); } + float beta_1() const { return GetField(VT_BETA_1, 0.0f); } + float beta_2() const { return GetField(VT_BETA_2, 0.0f); } + float epsilon() const { return GetField(VT_EPSILON, 0.0f); } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_LEARNING_RATE) && + VerifyField(verifier, VT_BETA_1) && VerifyField(verifier, VT_BETA_2) && + VerifyField(verifier, VT_EPSILON) && verifier.EndTable(); + } +}; + +struct AdamOptionsBuilder +{ + typedef AdamOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_learning_rate(float learning_rate) + { + fbb_.AddElement(AdamOptions::VT_LEARNING_RATE, learning_rate, 0.0f); + } + void add_beta_1(float beta_1) 
{ fbb_.AddElement(AdamOptions::VT_BETA_1, beta_1, 0.0f); } + void add_beta_2(float beta_2) { fbb_.AddElement(AdamOptions::VT_BETA_2, beta_2, 0.0f); } + void add_epsilon(float epsilon) + { + fbb_.AddElement(AdamOptions::VT_EPSILON, epsilon, 0.0f); + } + explicit AdamOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAdamOptions(flatbuffers::FlatBufferBuilder &_fbb, + float learning_rate = 0.0f, + float beta_1 = 0.0f, float beta_2 = 0.0f, + float epsilon = 0.0f) +{ + AdamOptionsBuilder builder_(_fbb); + builder_.add_epsilon(epsilon); + builder_.add_beta_2(beta_2); + builder_.add_beta_1(beta_1); + builder_.add_learning_rate(learning_rate); + return builder_.Finish(); +} + +struct SparseCategoricalCrossentropyOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef SparseCategoricalCrossentropyOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FROM_LOGITS = 4 + }; + bool from_logits() const { return GetField(VT_FROM_LOGITS, 0) != 0; } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_FROM_LOGITS) && + verifier.EndTable(); + } +}; + +struct SparseCategoricalCrossentropyOptionsBuilder +{ + typedef SparseCategoricalCrossentropyOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_from_logits(bool from_logits) + { + fbb_.AddElement(SparseCategoricalCrossentropyOptions::VT_FROM_LOGITS, + static_cast(from_logits), 0); + } + explicit SparseCategoricalCrossentropyOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + 
return o; + } +}; + +inline flatbuffers::Offset +CreateSparseCategoricalCrossentropyOptions(flatbuffers::FlatBufferBuilder &_fbb, + bool from_logits = false) +{ + SparseCategoricalCrossentropyOptionsBuilder builder_(_fbb); + builder_.add_from_logits(from_logits); + return builder_.Finish(); +} + +struct CategoricalCrossentropyOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef CategoricalCrossentropyOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_FROM_LOGITS = 4 + }; + bool from_logits() const { return GetField(VT_FROM_LOGITS, 0) != 0; } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_FROM_LOGITS) && + verifier.EndTable(); + } +}; + +struct CategoricalCrossentropyOptionsBuilder +{ + typedef CategoricalCrossentropyOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_from_logits(bool from_logits) + { + fbb_.AddElement(CategoricalCrossentropyOptions::VT_FROM_LOGITS, + static_cast(from_logits), 0); + } + explicit CategoricalCrossentropyOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset +CreateCategoricalCrossentropyOptions(flatbuffers::FlatBufferBuilder &_fbb, bool from_logits = false) +{ + CategoricalCrossentropyOptionsBuilder builder_(_fbb); + builder_.add_from_logits(from_logits); + return builder_.Finish(); +} + +struct MeanSquaredErrorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef MeanSquaredErrorOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct MeanSquaredErrorOptionsBuilder +{ + typedef MeanSquaredErrorOptions Table; + 
flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit MeanSquaredErrorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset +CreateMeanSquaredErrorOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + MeanSquaredErrorOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ModelTraining FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef ModelTrainingBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_VERSION = 4, + VT_OPTIMIZER = 6, + VT_OPTIMIZER_OPT_TYPE = 8, + VT_OPTIMIZER_OPT = 10, + VT_LOSSFN = 12, + VT_LOSSFN_OPT_TYPE = 14, + VT_LOSSFN_OPT = 16, + VT_EPOCHS = 18, + VT_BATCH_SIZE = 20 + }; + uint32_t version() const { return GetField(VT_VERSION, 0); } + circle::Optimizer optimizer() const + { + return static_cast(GetField(VT_OPTIMIZER, 0)); + } + circle::OptimizerOptions optimizer_opt_type() const + { + return static_cast(GetField(VT_OPTIMIZER_OPT_TYPE, 0)); + } + const void *optimizer_opt() const { return GetPointer(VT_OPTIMIZER_OPT); } + template const T *optimizer_opt_as() const; + const circle::SGDOptions *optimizer_opt_as_SGDOptions() const + { + return optimizer_opt_type() == circle::OptimizerOptions_SGDOptions + ? static_cast(optimizer_opt()) + : nullptr; + } + const circle::AdamOptions *optimizer_opt_as_AdamOptions() const + { + return optimizer_opt_type() == circle::OptimizerOptions_AdamOptions + ? 
static_cast(optimizer_opt()) + : nullptr; + } + circle::LossFn lossfn() const + { + return static_cast(GetField(VT_LOSSFN, 0)); + } + circle::LossFnOptions lossfn_opt_type() const + { + return static_cast(GetField(VT_LOSSFN_OPT_TYPE, 0)); + } + const void *lossfn_opt() const { return GetPointer(VT_LOSSFN_OPT); } + template const T *lossfn_opt_as() const; + const circle::SparseCategoricalCrossentropyOptions * + lossfn_opt_as_SparseCategoricalCrossentropyOptions() const + { + return lossfn_opt_type() == circle::LossFnOptions_SparseCategoricalCrossentropyOptions + ? static_cast(lossfn_opt()) + : nullptr; + } + const circle::CategoricalCrossentropyOptions *lossfn_opt_as_CategoricalCrossentropyOptions() const + { + return lossfn_opt_type() == circle::LossFnOptions_CategoricalCrossentropyOptions + ? static_cast(lossfn_opt()) + : nullptr; + } + const circle::MeanSquaredErrorOptions *lossfn_opt_as_MeanSquaredErrorOptions() const + { + return lossfn_opt_type() == circle::LossFnOptions_MeanSquaredErrorOptions + ? 
static_cast(lossfn_opt()) + : nullptr; + } + int32_t epochs() const { return GetField(VT_EPOCHS, 0); } + int32_t batch_size() const { return GetField(VT_BATCH_SIZE, 0); } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_OPTIMIZER) && + VerifyField(verifier, VT_OPTIMIZER_OPT_TYPE) && + VerifyOffset(verifier, VT_OPTIMIZER_OPT) && + VerifyOptimizerOptions(verifier, optimizer_opt(), optimizer_opt_type()) && + VerifyField(verifier, VT_LOSSFN) && + VerifyField(verifier, VT_LOSSFN_OPT_TYPE) && + VerifyOffset(verifier, VT_LOSSFN_OPT) && + VerifyLossFnOptions(verifier, lossfn_opt(), lossfn_opt_type()) && + VerifyField(verifier, VT_EPOCHS) && + VerifyField(verifier, VT_BATCH_SIZE) && verifier.EndTable(); + } +}; + +template <> +inline const circle::SGDOptions *ModelTraining::optimizer_opt_as() const +{ + return optimizer_opt_as_SGDOptions(); +} + +template <> +inline const circle::AdamOptions *ModelTraining::optimizer_opt_as() const +{ + return optimizer_opt_as_AdamOptions(); +} + +template <> +inline const circle::SparseCategoricalCrossentropyOptions * +ModelTraining::lossfn_opt_as() const +{ + return lossfn_opt_as_SparseCategoricalCrossentropyOptions(); +} + +template <> +inline const circle::CategoricalCrossentropyOptions * +ModelTraining::lossfn_opt_as() const +{ + return lossfn_opt_as_CategoricalCrossentropyOptions(); +} + +template <> +inline const circle::MeanSquaredErrorOptions * +ModelTraining::lossfn_opt_as() const +{ + return lossfn_opt_as_MeanSquaredErrorOptions(); +} + +struct ModelTrainingBuilder +{ + typedef ModelTraining Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_version(uint32_t version) + { + fbb_.AddElement(ModelTraining::VT_VERSION, version, 0); + } + void add_optimizer(circle::Optimizer optimizer) + { + fbb_.AddElement(ModelTraining::VT_OPTIMIZER, static_cast(optimizer), 0); + } + void 
add_optimizer_opt_type(circle::OptimizerOptions optimizer_opt_type) + { + fbb_.AddElement(ModelTraining::VT_OPTIMIZER_OPT_TYPE, + static_cast(optimizer_opt_type), 0); + } + void add_optimizer_opt(flatbuffers::Offset optimizer_opt) + { + fbb_.AddOffset(ModelTraining::VT_OPTIMIZER_OPT, optimizer_opt); + } + void add_lossfn(circle::LossFn lossfn) + { + fbb_.AddElement(ModelTraining::VT_LOSSFN, static_cast(lossfn), 0); + } + void add_lossfn_opt_type(circle::LossFnOptions lossfn_opt_type) + { + fbb_.AddElement(ModelTraining::VT_LOSSFN_OPT_TYPE, + static_cast(lossfn_opt_type), 0); + } + void add_lossfn_opt(flatbuffers::Offset lossfn_opt) + { + fbb_.AddOffset(ModelTraining::VT_LOSSFN_OPT, lossfn_opt); + } + void add_epochs(int32_t epochs) { fbb_.AddElement(ModelTraining::VT_EPOCHS, epochs, 0); } + void add_batch_size(int32_t batch_size) + { + fbb_.AddElement(ModelTraining::VT_BATCH_SIZE, batch_size, 0); + } + explicit ModelTrainingBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset +CreateModelTraining(flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, + circle::Optimizer optimizer = circle::Optimizer_SGD, + circle::OptimizerOptions optimizer_opt_type = circle::OptimizerOptions_NONE, + flatbuffers::Offset optimizer_opt = 0, + circle::LossFn lossfn = circle::LossFn_SPARSE_CATEGORICAL_CROSSENTROPY, + circle::LossFnOptions lossfn_opt_type = circle::LossFnOptions_NONE, + flatbuffers::Offset lossfn_opt = 0, int32_t epochs = 0, + int32_t batch_size = 0) +{ + ModelTrainingBuilder builder_(_fbb); + builder_.add_batch_size(batch_size); + builder_.add_epochs(epochs); + builder_.add_lossfn_opt(lossfn_opt); + builder_.add_optimizer_opt(optimizer_opt); + builder_.add_version(version); + builder_.add_lossfn_opt_type(lossfn_opt_type); + builder_.add_lossfn(lossfn); + 
builder_.add_optimizer_opt_type(optimizer_opt_type); + builder_.add_optimizer(optimizer); + return builder_.Finish(); +} + +inline bool VerifyOptimizerOptions(flatbuffers::Verifier &verifier, const void *obj, + OptimizerOptions type) +{ + switch (type) + { + case OptimizerOptions_NONE: + { + return true; + } + case OptimizerOptions_SGDOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case OptimizerOptions_AdamOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: + return true; + } +} + +inline bool +VerifyOptimizerOptionsVector(flatbuffers::Verifier &verifier, + const flatbuffers::Vector> *values, + const flatbuffers::Vector *types) +{ + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) + { + if (!VerifyOptimizerOptions(verifier, values->Get(i), types->GetEnum(i))) + { + return false; + } + } + return true; +} + +inline bool VerifyLossFnOptions(flatbuffers::Verifier &verifier, const void *obj, + LossFnOptions type) +{ + switch (type) + { + case LossFnOptions_NONE: + { + return true; + } + case LossFnOptions_SparseCategoricalCrossentropyOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case LossFnOptions_CategoricalCrossentropyOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case LossFnOptions_MeanSquaredErrorOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: + return true; + } +} + +inline bool VerifyLossFnOptionsVector(flatbuffers::Verifier &verifier, + const flatbuffers::Vector> *values, + const flatbuffers::Vector *types) +{ + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) + { + if (!VerifyLossFnOptions(verifier, 
values->Get(i), types->GetEnum(i))) + { + return false; + } + } + return true; +} + +inline const circle::ModelTraining *GetModelTraining(const void *buf) +{ + return flatbuffers::GetRoot(buf); +} + +inline const circle::ModelTraining *GetSizePrefixedModelTraining(const void *buf) +{ + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline const char *ModelTrainingIdentifier() { return "CTR1"; } + +inline bool ModelTrainingBufferHasIdentifier(const void *buf) +{ + return flatbuffers::BufferHasIdentifier(buf, ModelTrainingIdentifier()); +} + +inline bool VerifyModelTrainingBuffer(flatbuffers::Verifier &verifier) +{ + return verifier.VerifyBuffer(ModelTrainingIdentifier()); +} + +inline bool VerifySizePrefixedModelTrainingBuffer(flatbuffers::Verifier &verifier) +{ + return verifier.VerifySizePrefixedBuffer(ModelTrainingIdentifier()); +} + +inline const char *ModelTrainingExtension() { return "circletr"; } + +inline void FinishModelTrainingBuffer(flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) +{ + fbb.Finish(root, ModelTrainingIdentifier()); +} + +inline void FinishSizePrefixedModelTrainingBuffer(flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) +{ + fbb.FinishSizePrefixed(root, ModelTrainingIdentifier()); +} + +} // namespace circle + +#endif // FLATBUFFERS_GENERATED_CIRCLETRAINING_CIRCLE_H_ diff --git a/runtime/onert/frontend/circle_traininfo/src/traininfo_loader.cc b/runtime/onert/frontend/circle_traininfo/src/traininfo_loader.cc new file mode 100644 index 00000000000..45a6c45ef25 --- /dev/null +++ b/runtime/onert/frontend/circle_traininfo/src/traininfo_loader.cc @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "traininfo_loader.h"
+#include "circle_traininfo_generated.h"
+#include "flatbuffers/flatbuffers.h"
+
+#include <cassert>
+#include <memory>
+#include <stdexcept>
+
+namespace onert
+{
+namespace train
+{
+namespace traininfo_loader
+{
+
+namespace
+{
+
+// Converts the optimizer fields of the verified flatbuffer table into an
+// ir::train::OptimizerInfo.
+// Throws std::runtime_error on an unknown optimizer code or when the
+// corresponding optimizer options table is absent/mismatched in the union.
+ir::train::OptimizerInfo loadOptimizerInfo(const circle::ModelTraining *circle_model)
+{
+  assert(circle_model != nullptr);
+
+  // fill ir_opt from circle_opt
+  ir::train::OptimizerInfo ir_opt;
+  const circle::Optimizer circle_opt = circle_model->optimizer();
+
+  switch (circle_opt)
+  {
+    case circle::Optimizer_SGD:
+    {
+      // optimizer_opt_as_SGDOptions() returns nullptr when optimizer_opt is
+      // missing or its union type is not SGDOptions; dereferencing it blindly
+      // would crash on such (verifier-accepted) buffers.
+      const auto *sgd_opt = circle_model->optimizer_opt_as_SGDOptions();
+      if (sgd_opt == nullptr)
+        throw std::runtime_error{"SGD optimizer options are missing"};
+      ir_opt.optim_code = ir::train::OptimizerCode::SGD;
+      ir_opt.learning_rate = sgd_opt->learning_rate();
+      break;
+    }
+    case circle::Optimizer_ADAM:
+    {
+      const auto *adam_opt = circle_model->optimizer_opt_as_AdamOptions();
+      if (adam_opt == nullptr)
+        throw std::runtime_error{"Adam optimizer options are missing"};
+      ir_opt.optim_code = ir::train::OptimizerCode::Adam;
+      ir_opt.learning_rate = adam_opt->learning_rate();
+      break;
+    }
+    default:
+      throw std::runtime_error("unknown optimizer");
+  }
+  return ir_opt;
+}
+
+// Converts the loss-function fields of the verified flatbuffer table into an
+// ir::train::LossInfo.
+// Throws std::runtime_error for unknown or not-yet-supported loss functions.
+ir::train::LossInfo loadLossInfo(const circle::ModelTraining *circle_model)
+{
+  assert(circle_model != nullptr);
+
+  // fill ir_loss from circle_loss
+  ir::train::LossInfo ir_loss;
+  const circle::LossFn circle_loss = circle_model->lossfn();
+
+  switch (circle_loss)
+  {
+    case circle::LossFn::LossFn_CATEGORICAL_CROSSENTROPY:
+      ir_loss.loss_code = ir::train::LossCode::CategoricalCrossentropy;
+      break;
+    case circle::LossFn::LossFn_MEAN_SQUARED_ERROR:
+      ir_loss.loss_code = ir::train::LossCode::MeanSquaredError;
+      break;
+    case circle::LossFn::LossFn_SPARSE_CATEGORICAL_CROSSENTROPY:
+      // TODO enable this conversion after core supports sparse_categorical_crossentropy
+      throw std::runtime_error{"'sparse_categorical_crossentropy' is not supported yet"};
+    default:
+      throw std::runtime_error{"unknown loss function"};
+  }
+
+  // TODO update circle schema to support loss reduction type
+  return ir_loss;
+}
+} // namespace
+
+std::unique_ptr<ir::train::TrainingInfo> loadTrainingInfo(const uint8_t *buffer, const size_t size)
+{
+  assert(buffer != nullptr);
+
+  // Verify the buffer (including the "CTR1" file identifier) before reading any field.
+  flatbuffers::Verifier v(buffer, size);
+  bool verified = circle::VerifyModelTrainingBuffer(v);
+  if (not verified)
+    throw std::runtime_error{"TrainingInfo buffer is not accessible"};
+
+  const circle::ModelTraining *circle_model =
+    circle::GetModelTraining(static_cast<const void *>(buffer));
+
+  assert(circle_model != nullptr);
+
+  auto tinfo = std::make_unique<ir::train::TrainingInfo>();
+  {
+    tinfo->setBatchSize(circle_model->batch_size());
+    tinfo->setOptimizerInfo(loadOptimizerInfo(circle_model));
+    tinfo->setLossInfo(loadLossInfo(circle_model));
+    // NOTE(review): circle_model->epochs() is not propagated here — confirm
+    // whether TrainingInfo is supposed to carry the epoch count.
+  }
+  return tinfo;
+}
+
+} // namespace traininfo_loader
+} // namespace train
+} // namespace onert