Skip to content

Commit 765d7b7

Browse files
authored
[onert] Add wrapping CAPIs for training (#14524)
This commit adds wrapping CAPIs for training - Add the include guard for nnfw_api_wrapper.h - Introduce the namespace onert::api::python - Wrap CAPIs for training ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
1 parent d539375 commit 765d7b7

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed

runtime/onert/api/python/include/nnfw_api_wrapper.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,22 @@
1414
* limitations under the License.
1515
*/
1616

17+
#ifndef __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__
18+
#define __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__
19+
1720
#include "nnfw.h"
21+
#include "nnfw_experimental.h"
1822

1923
#include <pybind11/stl.h>
2024
#include <pybind11/numpy.h>
2125

26+
namespace onert
27+
{
28+
namespace api
29+
{
30+
namespace python
31+
{
32+
2233
namespace py = pybind11;
2334

2435
/**
@@ -159,4 +170,62 @@ class NNFW_SESSION
159170
void set_output_layout(uint32_t index, const char *layout);
160171
tensorinfo input_tensorinfo(uint32_t index);
161172
tensorinfo output_tensorinfo(uint32_t index);
173+
174+
//////////////////////////////////////////////
175+
// Experimental APIs for training
176+
//////////////////////////////////////////////
177+
nnfw_train_info train_get_traininfo();
178+
void train_set_traininfo(const nnfw_train_info *info);
179+
180+
template <typename T> void train_set_input(uint32_t index, py::array_t<T> &buffer)
181+
{
182+
nnfw_tensorinfo tensor_info;
183+
nnfw_input_tensorinfo(this->session, index, &tensor_info);
184+
185+
py::buffer_info buf_info = buffer.request();
186+
const auto buf_shape = buf_info.shape;
187+
assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0);
188+
tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0));
189+
190+
ensure_status(nnfw_train_set_input(this->session, index, buffer.request().ptr, &tensor_info));
191+
}
192+
template <typename T> void train_set_expected(uint32_t index, py::array_t<T> &buffer)
193+
{
194+
nnfw_tensorinfo tensor_info;
195+
nnfw_output_tensorinfo(this->session, index, &tensor_info);
196+
197+
py::buffer_info buf_info = buffer.request();
198+
const auto buf_shape = buf_info.shape;
199+
assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0);
200+
tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0));
201+
202+
ensure_status(
203+
nnfw_train_set_expected(this->session, index, buffer.request().ptr, &tensor_info));
204+
}
205+
template <typename T> void train_set_output(uint32_t index, py::array_t<T> &buffer)
206+
{
207+
nnfw_tensorinfo tensor_info;
208+
nnfw_output_tensorinfo(this->session, index, &tensor_info);
209+
NNFW_TYPE type = tensor_info.dtype;
210+
uint32_t output_elements = num_elems(&tensor_info);
211+
size_t length = sizeof(T) * output_elements;
212+
213+
ensure_status(nnfw_train_set_output(session, index, type, buffer.request().ptr, length));
214+
}
215+
216+
void train_prepare();
217+
void train(bool update_weights);
218+
float train_get_loss(uint32_t index);
219+
220+
void train_export_circle(const py::str &path);
221+
void train_import_checkpoint(const py::str &path);
222+
void train_export_checkpoint(const py::str &path);
223+
224+
// TODO Add other apis
162225
};
226+
227+
} // namespace python
228+
} // namespace api
229+
} // namespace onert
230+
231+
#endif // __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__

runtime/onert/api/python/src/nnfw_api_wrapper.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818

1919
#include <iostream>
2020

21+
namespace onert
22+
{
23+
namespace api
24+
{
25+
namespace python
26+
{
27+
28+
namespace py = pybind11;
29+
2130
void ensure_status(NNFW_STATUS status)
2231
{
2332
switch (status)
@@ -243,3 +252,54 @@ tensorinfo NNFW_SESSION::output_tensorinfo(uint32_t index)
243252
}
244253
return ti;
245254
}
255+
256+
//////////////////////////////////////////////
257+
// Experimental APIs for training
258+
//////////////////////////////////////////////
259+
nnfw_train_info NNFW_SESSION::train_get_traininfo()
260+
{
261+
nnfw_train_info train_info = nnfw_train_info();
262+
ensure_status(nnfw_train_get_traininfo(session, &train_info));
263+
return train_info;
264+
}
265+
266+
void NNFW_SESSION::train_set_traininfo(const nnfw_train_info *info)
267+
{
268+
ensure_status(nnfw_train_set_traininfo(session, info));
269+
}
270+
271+
void NNFW_SESSION::train_prepare() { ensure_status(nnfw_train_prepare(session)); }
272+
273+
void NNFW_SESSION::train(bool update_weights)
274+
{
275+
ensure_status(nnfw_train(session, update_weights));
276+
}
277+
278+
float NNFW_SESSION::train_get_loss(uint32_t index)
279+
{
280+
float loss = 0.f;
281+
ensure_status(nnfw_train_get_loss(session, index, &loss));
282+
return loss;
283+
}
284+
285+
void NNFW_SESSION::train_export_circle(const py::str &path)
286+
{
287+
const char *c_str_path = path.cast<std::string>().c_str();
288+
ensure_status(nnfw_train_export_circle(session, c_str_path));
289+
}
290+
291+
void NNFW_SESSION::train_import_checkpoint(const py::str &path)
292+
{
293+
const char *c_str_path = path.cast<std::string>().c_str();
294+
ensure_status(nnfw_train_import_checkpoint(session, c_str_path));
295+
}
296+
297+
void NNFW_SESSION::train_export_checkpoint(const py::str &path)
298+
{
299+
const char *c_str_path = path.cast<std::string>().c_str();
300+
ensure_status(nnfw_train_export_checkpoint(session, c_str_path));
301+
}
302+
303+
} // namespace python
304+
} // namespace api
305+
} // namespace onert

runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
namespace py = pybind11;
2020

21+
using namespace onert::api::python;
22+
2123
PYBIND11_MODULE(libnnfw_api_pybind, m)
2224
{
2325
m.doc() = "nnfw python plugin";

0 commit comments

Comments
 (0)