|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +from ..native import libnnfw_api_pybind |
| 4 | + |
| 5 | + |
| 6 | +def num_elems(tensor_info): |
| 7 | + """Get the total number of elements in nnfw_tensorinfo.dims.""" |
| 8 | + n = 1 |
| 9 | + for x in range(tensor_info.rank): |
| 10 | + n *= tensor_info.dims[x] |
| 11 | + return n |
| 12 | + |
| 13 | + |
| 14 | +class BaseSession: |
| 15 | + """ |
| 16 | + Base class providing common functionality for inference and training sessions. |
| 17 | + """ |
| 18 | + def __init__(self, backend_session): |
| 19 | + """ |
| 20 | + Initialize the BaseSession with a backend session. |
| 21 | + Args: |
| 22 | + backend_session: A backend-specific session object (e.g., nnfw_session). |
| 23 | + """ |
| 24 | + self.session = backend_session |
| 25 | + self.inputs = [] |
| 26 | + self.outputs = [] |
| 27 | + |
| 28 | + def __getattr__(self, name): |
| 29 | + """ |
| 30 | + Delegate attribute access to the bound NNFW_SESSION instance. |
| 31 | + Args: |
| 32 | + name (str): The name of the attribute or method to access. |
| 33 | + Returns: |
| 34 | + The attribute or method from the bound NNFW_SESSION instance. |
| 35 | + """ |
| 36 | + return getattr(self.session, name) |
| 37 | + |
| 38 | + def set_inputs(self, size, inputs_array=[]): |
| 39 | + """ |
| 40 | + Set the input tensors for the session. |
| 41 | + Args: |
| 42 | + size (int): Number of input tensors. |
| 43 | + inputs_array (list): List of numpy arrays for the input data. |
| 44 | + """ |
| 45 | + for i in range(size): |
| 46 | + input_tensorinfo = self.session.input_tensorinfo(i) |
| 47 | + |
| 48 | + if len(inputs_array) > i: |
| 49 | + input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype) |
| 50 | + else: |
| 51 | + print( |
| 52 | + f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0." |
| 53 | + ) |
| 54 | + input_array = np.zeros((num_elems(input_tensorinfo)), |
| 55 | + dtype=input_tensorinfo.dtype) |
| 56 | + |
| 57 | + self.session.set_input(i, input_array) |
| 58 | + self.inputs.append(input_array) |
| 59 | + |
| 60 | + def set_outputs(self, size): |
| 61 | + """ |
| 62 | + Set the output tensors for the session. |
| 63 | + Args: |
| 64 | + size (int): Number of output tensors. |
| 65 | + """ |
| 66 | + for i in range(size): |
| 67 | + output_tensorinfo = self.session.output_tensorinfo(i) |
| 68 | + output_array = np.zeros((num_elems(output_tensorinfo)), |
| 69 | + dtype=output_tensorinfo.dtype) |
| 70 | + self.session.set_output(i, output_array) |
| 71 | + self.outputs.append(output_array) |
| 72 | + |
| 73 | + |
| 74 | +def tensorinfo(): |
| 75 | + return libnnfw_api_pybind.infer.nnfw_tensorinfo() |
0 commit comments