-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[onert/python] Introduce BaseSession (#14527)
This commit introduces the class BaseSession. - Inroduce the class BaseSession to group common functionalities - Modify setup.py to make Python files easier to find. ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
- Loading branch information
Showing
7 changed files
with
133 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,8 @@ | ||
__all__ = ['infer'] | ||
# Define the public API of the onert package | ||
__all__ = ["infer", "tensorinfo"] | ||
|
||
# Import and expose the infer module's functionalities | ||
from . import infer | ||
|
||
# Import and expose tensorinfo | ||
from .common import tensorinfo as tensorinfo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .basesession import BaseSession, tensorinfo | ||
|
||
__all__ = ["BaseSession", "tensorinfo"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import numpy as np | ||
|
||
from ..native import libnnfw_api_pybind | ||
|
||
|
||
def num_elems(tensor_info): | ||
"""Get the total number of elements in nnfw_tensorinfo.dims.""" | ||
n = 1 | ||
for x in range(tensor_info.rank): | ||
n *= tensor_info.dims[x] | ||
return n | ||
|
||
|
||
class BaseSession: | ||
""" | ||
Base class providing common functionality for inference and training sessions. | ||
""" | ||
def __init__(self, backend_session): | ||
""" | ||
Initialize the BaseSession with a backend session. | ||
Args: | ||
backend_session: A backend-specific session object (e.g., nnfw_session). | ||
""" | ||
self.session = backend_session | ||
self.inputs = [] | ||
self.outputs = [] | ||
|
||
def __getattr__(self, name): | ||
""" | ||
Delegate attribute access to the bound NNFW_SESSION instance. | ||
Args: | ||
name (str): The name of the attribute or method to access. | ||
Returns: | ||
The attribute or method from the bound NNFW_SESSION instance. | ||
""" | ||
return getattr(self.session, name) | ||
|
||
def set_inputs(self, size, inputs_array=[]): | ||
""" | ||
Set the input tensors for the session. | ||
Args: | ||
size (int): Number of input tensors. | ||
inputs_array (list): List of numpy arrays for the input data. | ||
""" | ||
for i in range(size): | ||
input_tensorinfo = self.session.input_tensorinfo(i) | ||
|
||
if len(inputs_array) > i: | ||
input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype) | ||
else: | ||
print( | ||
f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0." | ||
) | ||
input_array = np.zeros((num_elems(input_tensorinfo)), | ||
dtype=input_tensorinfo.dtype) | ||
|
||
self.session.set_input(i, input_array) | ||
self.inputs.append(input_array) | ||
|
||
def set_outputs(self, size): | ||
""" | ||
Set the output tensors for the session. | ||
Args: | ||
size (int): Number of output tensors. | ||
""" | ||
for i in range(size): | ||
output_tensorinfo = self.session.output_tensorinfo(i) | ||
output_array = np.zeros((num_elems(output_tensorinfo)), | ||
dtype=output_tensorinfo.dtype) | ||
self.session.set_output(i, output_array) | ||
self.outputs.append(output_array) | ||
|
||
|
||
def tensorinfo(): | ||
return libnnfw_api_pybind.infer.nnfw_tensorinfo() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .session import session | ||
|
||
__all__ = ["session"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from ..native import libnnfw_api_pybind | ||
from ..common.basesession import BaseSession | ||
|
||
|
||
class session(BaseSession): | ||
""" | ||
Class for inference using nnfw_session. | ||
""" | ||
def __init__(self, nnpackage_path, backends="cpu"): | ||
""" | ||
Initialize the inference session. | ||
Args: | ||
nnpackage_path (str): Path to the nnpackage file or directory. | ||
backends (str): Backends to use, default is "cpu". | ||
""" | ||
super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends)) | ||
self.session.prepare() | ||
self.set_outputs(self.session.output_size()) | ||
|
||
def inference(self): | ||
""" | ||
Perform model and get outputs | ||
Returns: | ||
list: Outputs from the model. | ||
""" | ||
self.session.run() | ||
return self.outputs |