Skip to content

Commit

Permalink
[onert/python] bind training APIs (#14532)
Browse files Browse the repository at this point in the history
This commit binds training APIs.
  - Bind trainings APIs related to session
  - Bind trainings APIs related to traininfo

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jan 10, 2025
1 parent b5bb382 commit d86db1e
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 0 deletions.
3 changes: 3 additions & 0 deletions runtime/onert/api/python/include/nnfw_session_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace python
// Declare binding common functions
void bind_nnfw_session(pybind11::module_ &m);

// Declare binding experimental functions
void bind_experimental_nnfw_session(pybind11::module_ &m);

} // namespace python
} // namespace api
} // namespace onert
Expand Down
34 changes: 34 additions & 0 deletions runtime/onert/api/python/include/nnfw_traininfo_bindings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__
#define __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Declare binding train enums
void bind_nnfw_train_enums(py::module_ &m);

// Declare binding loss info
void bind_nnfw_loss_info(py::module_ &m);

// Declare binding train info
void bind_nnfw_train_info(py::module_ &m);

#endif // __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__
17 changes: 17 additions & 0 deletions runtime/onert/api/python/src/bindings/nnfw_api_wrapper_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "nnfw_session_bindings.h"
#include "nnfw_tensorinfo_bindings.h"
#include "nnfw_traininfo_bindings.h"

using namespace onert::api::python;

Expand All @@ -33,6 +34,22 @@ PYBIND11_MODULE(libnnfw_api_pybind, m)
auto infer = m.def_submodule("infer", "Inference submodule");
infer.attr("nnfw_session") = m.attr("nnfw_session");

// Bind experimental `NNFW_SESSION` class
auto experimental = m.def_submodule("experimental", "Experimental submodule");
experimental.attr("nnfw_session") = m.attr("nnfw_session");
bind_experimental_nnfw_session(experimental);

// Bind common `tensorinfo` class
bind_tensorinfo(m);

m.doc() = "NNFW Python Bindings for Training";

// Bind training enums
bind_nnfw_train_enums(m);

// Bind training nnfw_loss_info
bind_nnfw_loss_info(m);

// Bind_train_info
bind_nnfw_train_info(m);
}
41 changes: 41 additions & 0 deletions runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,47 @@ void bind_nnfw_session(py::module_ &m)
"\ttensorinfo: Tensor info (shape, type, etc)");
}

// Bind the `NNFW_SESSION` class with experimental APIs
void bind_experimental_nnfw_session(py::module_ &m)
{
// Add experimental APIs for the `NNFW_SESSION` class
m.attr("nnfw_session")
.cast<py::class_<NNFW_SESSION>>()
.def("train_get_traininfo", &NNFW_SESSION::train_get_traininfo,
"Retrieve training information for the model.")
.def("train_set_traininfo", &NNFW_SESSION::train_set_traininfo, py::arg("info"),
"Set training information for the model.")
.def("train_prepare", &NNFW_SESSION::train_prepare, "Prepare for training")
.def("train", &NNFW_SESSION::train, py::arg("update_weights") = true,
"Run a training step, optionally updating weights.")
.def("train_get_loss", &NNFW_SESSION::train_get_loss, py::arg("index"),
"Retrieve the training loss for a specific index.")
.def("train_set_input", &NNFW_SESSION::train_set_input<float>, py::arg("index"),
py::arg("buffer"), "Set training input tensor for the given index (float).")
.def("train_set_input", &NNFW_SESSION::train_set_input<int>, py::arg("index"),
py::arg("buffer"), "Set training input tensor for the given index (int).")
.def("train_set_input", &NNFW_SESSION::train_set_input<uint8_t>, py::arg("index"),
py::arg("buffer"), "Set training input tensor for the given index (uint8).")
.def("train_set_expected", &NNFW_SESSION::train_set_expected<float>, py::arg("index"),
py::arg("buffer"), "Set expected output tensor for the given index (float).")
.def("train_set_expected", &NNFW_SESSION::train_set_expected<int>, py::arg("index"),
py::arg("buffer"), "Set expected output tensor for the given index (int).")
.def("train_set_expected", &NNFW_SESSION::train_set_expected<uint8_t>, py::arg("index"),
py::arg("buffer"), "Set expected output tensor for the given index (uint8).")
.def("train_set_output", &NNFW_SESSION::train_set_output<float>, py::arg("index"),
py::arg("buffer"), "Set output tensor for the given index (float).")
.def("train_set_output", &NNFW_SESSION::train_set_output<int>, py::arg("index"),
py::arg("buffer"), "Set output tensor for the given index (int).")
.def("train_set_output", &NNFW_SESSION::train_set_output<uint8_t>, py::arg("index"),
py::arg("buffer"), "Set output tensor for the given index (uint8).")
.def("train_export_circle", &NNFW_SESSION::train_export_circle, py::arg("path"),
"Export the trained model to a circle file.")
.def("train_import_checkpoint", &NNFW_SESSION::train_import_checkpoint, py::arg("path"),
"Import a training checkpoint from a file.")
.def("train_export_checkpoint", &NNFW_SESSION::train_export_checkpoint, py::arg("path"),
"Export the training checkpoint to a file.");
}

} // namespace python
} // namespace api
} // namespace onert
73 changes: 73 additions & 0 deletions runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "nnfw_traininfo_bindings.h"

#include "nnfw_api_wrapper.h"

namespace py = pybind11;

using namespace onert::api::python;

// Declare binding train enums
void bind_nnfw_train_enums(py::module_ &m)
{
// Bind NNFW_TRAIN_LOSS
py::enum_<NNFW_TRAIN_LOSS>(m, "loss", py::module_local())
.value("UNDEFINED", NNFW_TRAIN_LOSS_UNDEFINED)
.value("MEAN_SQUARED_ERROR", NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
.value("CATEGORICAL_CROSSENTROPY", NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY);

// Bind NNFW_TRAIN_LOSS_REDUCTION
py::enum_<NNFW_TRAIN_LOSS_REDUCTION>(m, "loss_reduction", py::module_local())
.value("UNDEFINED", NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED)
.value("SUM_OVER_BATCH_SIZE", NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
.value("SUM", NNFW_TRAIN_LOSS_REDUCTION_SUM);

// Bind NNFW_TRAIN_OPTIMIZER
py::enum_<NNFW_TRAIN_OPTIMIZER>(m, "optimizer", py::module_local())
.value("UNDEFINED", NNFW_TRAIN_OPTIMIZER_UNDEFINED)
.value("SGD", NNFW_TRAIN_OPTIMIZER_SGD)
.value("ADAM", NNFW_TRAIN_OPTIMIZER_ADAM);

// Bind NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES
py::enum_<NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES>(m, "trainable_ops", py::module_local())
.value("INCORRECT_STATE", NNFW_TRAIN_TRAINABLE_INCORRECT_STATE)
.value("ALL", NNFW_TRAIN_TRAINABLE_ALL)
.value("NONE", NNFW_TRAIN_TRAINABLE_NONE);
}

// Declare binding loss info
void bind_nnfw_loss_info(py::module_ &m)
{
py::class_<nnfw_loss_info>(m, "lossinfo", py::module_local())
.def(py::init<>()) // Default constructor
.def_readwrite("loss", &nnfw_loss_info::loss, "Loss type")
.def_readwrite("reduction_type", &nnfw_loss_info::reduction_type, "Reduction type");
}

// Declare binding train info
void bind_nnfw_train_info(py::module_ &m)
{
py::class_<nnfw_train_info>(m, "traininfo", py::module_local())
.def(py::init<>()) // Default constructor
.def_readwrite("learning_rate", &nnfw_train_info::learning_rate, "Learning rate")
.def_readwrite("batch_size", &nnfw_train_info::batch_size, "Batch size")
.def_readwrite("loss_info", &nnfw_train_info::loss_info, "Loss information")
.def_readwrite("opt", &nnfw_train_info::opt, "Optimizer type")
.def_readwrite("num_of_trainable_ops", &nnfw_train_info::num_of_trainable_ops,
"Number of trainable operations");
}

0 comments on commit d86db1e

Please sign in to comment.