Skip to content

Commit 80e94f4

Browse files
authored
[onert/python] Introduce BaseSession (#14527)
This commit introduces the class BaseSession. - Inroduce the class BaseSession to group common functionalities - Modify setup.py to make Python files easier to find. ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
1 parent 765d7b7 commit 80e94f4

File tree

7 files changed

+133
-66
lines changed

7 files changed

+133
-66
lines changed

infra/nnfw/python/setup.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,23 @@
5252

5353
# copy *py files to package_directory
5454
PY_DIR = os.path.join(THIS_FILE_DIR, '../../../runtime/onert/api/python/package')
55-
for py_file in os.listdir(PY_DIR):
56-
if py_file.endswith(".py"):
57-
src_path = os.path.join(PY_DIR, py_file)
58-
dest_path = os.path.join(THIS_FILE_DIR, package_directory)
59-
shutil.copy(src_path, dest_path)
60-
print(f"Copied '{src_path}' to '{dest_path}'")
55+
for root, dirs, files in os.walk(PY_DIR):
56+
# Calculate the relative path from the source directory
57+
rel_path = os.path.relpath(root, PY_DIR)
58+
dest_dir = os.path.join(THIS_FILE_DIR, package_directory)
59+
dest_sub_dir = os.path.join(dest_dir, rel_path)
60+
print(f"dest_sub_dir '{dest_sub_dir}'")
61+
62+
# Ensure the corresponding destination subdirectory exists
63+
os.makedirs(dest_sub_dir, exist_ok=True)
64+
65+
# Copy only .py files
66+
for py_file in files:
67+
if py_file.endswith(".py"):
68+
src_path = os.path.join(root, py_file)
69+
# dest_path = os.path.join(THIS_FILE_DIR, package_directory)
70+
shutil.copy(src_path, dest_sub_dir)
71+
print(f"Copied '{src_path}' to '{dest_sub_dir}'")
6172

6273
# remove architecture directory
6374
if os.path.exists(package_directory):
@@ -142,6 +153,6 @@ def get_directories():
142153
url='https://github.com/Samsung/ONE',
143154
license='Apache-2.0, MIT, BSD-2-Clause, BSD-3-Clause, Mozilla Public License 2.0',
144155
has_ext_modules=lambda: True,
145-
packages=[package_directory],
156+
packages=find_packages(),
146157
package_data={package_directory: so_list},
147158
install_requires=['numpy >= 1.19'])
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
1-
__all__ = ['infer']
1+
# Define the public API of the onert package
2+
__all__ = ["infer", "tensorinfo"]
3+
4+
# Import and expose the infer module's functionalities
25
from . import infer
6+
7+
# Import and expose tensorinfo
8+
from .common import tensorinfo as tensorinfo
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .basesession import BaseSession, tensorinfo
2+
3+
__all__ = ["BaseSession", "tensorinfo"]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
3+
from ..native import libnnfw_api_pybind
4+
5+
6+
def num_elems(tensor_info):
7+
"""Get the total number of elements in nnfw_tensorinfo.dims."""
8+
n = 1
9+
for x in range(tensor_info.rank):
10+
n *= tensor_info.dims[x]
11+
return n
12+
13+
14+
class BaseSession:
15+
"""
16+
Base class providing common functionality for inference and training sessions.
17+
"""
18+
def __init__(self, backend_session):
19+
"""
20+
Initialize the BaseSession with a backend session.
21+
Args:
22+
backend_session: A backend-specific session object (e.g., nnfw_session).
23+
"""
24+
self.session = backend_session
25+
self.inputs = []
26+
self.outputs = []
27+
28+
def __getattr__(self, name):
29+
"""
30+
Delegate attribute access to the bound NNFW_SESSION instance.
31+
Args:
32+
name (str): The name of the attribute or method to access.
33+
Returns:
34+
The attribute or method from the bound NNFW_SESSION instance.
35+
"""
36+
return getattr(self.session, name)
37+
38+
def set_inputs(self, size, inputs_array=[]):
39+
"""
40+
Set the input tensors for the session.
41+
Args:
42+
size (int): Number of input tensors.
43+
inputs_array (list): List of numpy arrays for the input data.
44+
"""
45+
for i in range(size):
46+
input_tensorinfo = self.session.input_tensorinfo(i)
47+
48+
if len(inputs_array) > i:
49+
input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype)
50+
else:
51+
print(
52+
f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0."
53+
)
54+
input_array = np.zeros((num_elems(input_tensorinfo)),
55+
dtype=input_tensorinfo.dtype)
56+
57+
self.session.set_input(i, input_array)
58+
self.inputs.append(input_array)
59+
60+
def set_outputs(self, size):
61+
"""
62+
Set the output tensors for the session.
63+
Args:
64+
size (int): Number of output tensors.
65+
"""
66+
for i in range(size):
67+
output_tensorinfo = self.session.output_tensorinfo(i)
68+
output_array = np.zeros((num_elems(output_tensorinfo)),
69+
dtype=output_tensorinfo.dtype)
70+
self.session.set_output(i, output_array)
71+
self.outputs.append(output_array)
72+
73+
74+
def tensorinfo():
75+
return libnnfw_api_pybind.infer.nnfw_tensorinfo()

runtime/onert/api/python/package/infer.py

Lines changed: 0 additions & 58 deletions
This file was deleted.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .session import session
2+
3+
__all__ = ["session"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from ..native import libnnfw_api_pybind
2+
from ..common.basesession import BaseSession
3+
4+
5+
class session(BaseSession):
6+
"""
7+
Class for inference using nnfw_session.
8+
"""
9+
def __init__(self, nnpackage_path, backends="cpu"):
10+
"""
11+
Initialize the inference session.
12+
Args:
13+
nnpackage_path (str): Path to the nnpackage file or directory.
14+
backends (str): Backends to use, default is "cpu".
15+
"""
16+
super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
17+
self.session.prepare()
18+
self.set_outputs(self.session.output_size())
19+
20+
def inference(self):
21+
"""
22+
Perform model and get outputs
23+
Returns:
24+
list: Outputs from the model.
25+
"""
26+
self.session.run()
27+
return self.outputs

0 commit comments

Comments
 (0)