Skip to content

Commit

Permalink
[onert/python] Introduce BaseSession (#14527)
Browse files Browse the repository at this point in the history
This commit introduces the class BaseSession.
  - Inroduce the class BaseSession to group common functionalities
  - Modify setup.py to make Python files easier to find.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jan 6, 2025
1 parent 765d7b7 commit 80e94f4
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 66 deletions.
25 changes: 18 additions & 7 deletions infra/nnfw/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,23 @@

# copy *py files to package_directory
PY_DIR = os.path.join(THIS_FILE_DIR, '../../../runtime/onert/api/python/package')
for py_file in os.listdir(PY_DIR):
if py_file.endswith(".py"):
src_path = os.path.join(PY_DIR, py_file)
dest_path = os.path.join(THIS_FILE_DIR, package_directory)
shutil.copy(src_path, dest_path)
print(f"Copied '{src_path}' to '{dest_path}'")
for root, dirs, files in os.walk(PY_DIR):
# Calculate the relative path from the source directory
rel_path = os.path.relpath(root, PY_DIR)
dest_dir = os.path.join(THIS_FILE_DIR, package_directory)
dest_sub_dir = os.path.join(dest_dir, rel_path)
print(f"dest_sub_dir '{dest_sub_dir}'")

# Ensure the corresponding destination subdirectory exists
os.makedirs(dest_sub_dir, exist_ok=True)

# Copy only .py files
for py_file in files:
if py_file.endswith(".py"):
src_path = os.path.join(root, py_file)
# dest_path = os.path.join(THIS_FILE_DIR, package_directory)
shutil.copy(src_path, dest_sub_dir)
print(f"Copied '{src_path}' to '{dest_sub_dir}'")

# remove architecture directory
if os.path.exists(package_directory):
Expand Down Expand Up @@ -142,6 +153,6 @@ def get_directories():
url='https://github.com/Samsung/ONE',
license='Apache-2.0, MIT, BSD-2-Clause, BSD-3-Clause, Mozilla Public License 2.0',
has_ext_modules=lambda: True,
packages=[package_directory],
packages=find_packages(),
package_data={package_directory: so_list},
install_requires=['numpy >= 1.19'])
8 changes: 7 additions & 1 deletion runtime/onert/api/python/package/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
__all__ = ['infer']
# Define the public API of the onert package
__all__ = ["infer", "tensorinfo"]

# Import and expose the infer module's functionalities
from . import infer

# Import and expose tensorinfo
from .common import tensorinfo as tensorinfo
3 changes: 3 additions & 0 deletions runtime/onert/api/python/package/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .basesession import BaseSession, tensorinfo

__all__ = ["BaseSession", "tensorinfo"]
75 changes: 75 additions & 0 deletions runtime/onert/api/python/package/common/basesession.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np

from ..native import libnnfw_api_pybind


def num_elems(tensor_info):
"""Get the total number of elements in nnfw_tensorinfo.dims."""
n = 1
for x in range(tensor_info.rank):
n *= tensor_info.dims[x]
return n


class BaseSession:
"""
Base class providing common functionality for inference and training sessions.
"""
def __init__(self, backend_session):
"""
Initialize the BaseSession with a backend session.
Args:
backend_session: A backend-specific session object (e.g., nnfw_session).
"""
self.session = backend_session
self.inputs = []
self.outputs = []

def __getattr__(self, name):
"""
Delegate attribute access to the bound NNFW_SESSION instance.
Args:
name (str): The name of the attribute or method to access.
Returns:
The attribute or method from the bound NNFW_SESSION instance.
"""
return getattr(self.session, name)

def set_inputs(self, size, inputs_array=[]):
"""
Set the input tensors for the session.
Args:
size (int): Number of input tensors.
inputs_array (list): List of numpy arrays for the input data.
"""
for i in range(size):
input_tensorinfo = self.session.input_tensorinfo(i)

if len(inputs_array) > i:
input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype)
else:
print(
f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0."
)
input_array = np.zeros((num_elems(input_tensorinfo)),
dtype=input_tensorinfo.dtype)

self.session.set_input(i, input_array)
self.inputs.append(input_array)

def set_outputs(self, size):
"""
Set the output tensors for the session.
Args:
size (int): Number of output tensors.
"""
for i in range(size):
output_tensorinfo = self.session.output_tensorinfo(i)
output_array = np.zeros((num_elems(output_tensorinfo)),
dtype=output_tensorinfo.dtype)
self.session.set_output(i, output_array)
self.outputs.append(output_array)


def tensorinfo():
return libnnfw_api_pybind.infer.nnfw_tensorinfo()
58 changes: 0 additions & 58 deletions runtime/onert/api/python/package/infer.py

This file was deleted.

3 changes: 3 additions & 0 deletions runtime/onert/api/python/package/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .session import session

__all__ = ["session"]
27 changes: 27 additions & 0 deletions runtime/onert/api/python/package/infer/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from ..native import libnnfw_api_pybind
from ..common.basesession import BaseSession


class session(BaseSession):
"""
Class for inference using nnfw_session.
"""
def __init__(self, nnpackage_path, backends="cpu"):
"""
Initialize the inference session.
Args:
nnpackage_path (str): Path to the nnpackage file or directory.
backends (str): Backends to use, default is "cpu".
"""
super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
self.session.prepare()
self.set_outputs(self.session.output_size())

def inference(self):
"""
Perform model and get outputs
Returns:
list: Outputs from the model.
"""
self.session.run()
return self.outputs

0 comments on commit 80e94f4

Please sign in to comment.