|
14 | 14 | * limitations under the License.
|
15 | 15 | */
|
16 | 16 |
|
| 17 | +#ifndef __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ |
| 18 | +#define __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ |
| 19 | + |
17 | 20 | #include "nnfw.h"
|
| 21 | +#include "nnfw_experimental.h" |
18 | 22 |
|
19 | 23 | #include <pybind11/stl.h>
|
20 | 24 | #include <pybind11/numpy.h>
|
21 | 25 |
|
| 26 | +namespace onert |
| 27 | +{ |
| 28 | +namespace api |
| 29 | +{ |
| 30 | +namespace python |
| 31 | +{ |
| 32 | + |
22 | 33 | namespace py = pybind11;
|
23 | 34 |
|
24 | 35 | /**
|
@@ -159,4 +170,62 @@ class NNFW_SESSION
|
159 | 170 | void set_output_layout(uint32_t index, const char *layout);
|
160 | 171 | tensorinfo input_tensorinfo(uint32_t index);
|
161 | 172 | tensorinfo output_tensorinfo(uint32_t index);
|
| 173 | + |
| 174 | + ////////////////////////////////////////////// |
| 175 | + // Experimental APIs for training |
| 176 | + ////////////////////////////////////////////// |
| 177 | + nnfw_train_info train_get_traininfo(); |
| 178 | + void train_set_traininfo(const nnfw_train_info *info); |
| 179 | + |
| 180 | + template <typename T> void train_set_input(uint32_t index, py::array_t<T> &buffer) |
| 181 | + { |
| 182 | + nnfw_tensorinfo tensor_info; |
| 183 | + nnfw_input_tensorinfo(this->session, index, &tensor_info); |
| 184 | + |
| 185 | + py::buffer_info buf_info = buffer.request(); |
| 186 | + const auto buf_shape = buf_info.shape; |
| 187 | + assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0); |
| 188 | + tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0)); |
| 189 | + |
| 190 | + ensure_status(nnfw_train_set_input(this->session, index, buffer.request().ptr, &tensor_info)); |
| 191 | + } |
| 192 | + template <typename T> void train_set_expected(uint32_t index, py::array_t<T> &buffer) |
| 193 | + { |
| 194 | + nnfw_tensorinfo tensor_info; |
| 195 | + nnfw_output_tensorinfo(this->session, index, &tensor_info); |
| 196 | + |
| 197 | + py::buffer_info buf_info = buffer.request(); |
| 198 | + const auto buf_shape = buf_info.shape; |
| 199 | + assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0); |
| 200 | + tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0)); |
| 201 | + |
| 202 | + ensure_status( |
| 203 | + nnfw_train_set_expected(this->session, index, buffer.request().ptr, &tensor_info)); |
| 204 | + } |
| 205 | + template <typename T> void train_set_output(uint32_t index, py::array_t<T> &buffer) |
| 206 | + { |
| 207 | + nnfw_tensorinfo tensor_info; |
| 208 | + nnfw_output_tensorinfo(this->session, index, &tensor_info); |
| 209 | + NNFW_TYPE type = tensor_info.dtype; |
| 210 | + uint32_t output_elements = num_elems(&tensor_info); |
| 211 | + size_t length = sizeof(T) * output_elements; |
| 212 | + |
| 213 | + ensure_status(nnfw_train_set_output(session, index, type, buffer.request().ptr, length)); |
| 214 | + } |
| 215 | + |
| 216 | + void train_prepare(); |
| 217 | + void train(bool update_weights); |
| 218 | + float train_get_loss(uint32_t index); |
| 219 | + |
| 220 | + void train_export_circle(const py::str &path); |
| 221 | + void train_import_checkpoint(const py::str &path); |
| 222 | + void train_export_checkpoint(const py::str &path); |
| 223 | + |
| 224 | + // TODO Add other apis |
162 | 225 | };
|
| 226 | + |
| 227 | +} // namespace python |
| 228 | +} // namespace api |
| 229 | +} // namespace onert |
| 230 | + |
| 231 | +#endif // __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ |
0 commit comments