Skip to content

Commit

Permalink
[onert/python] Allow calling compilation by infer session
Browse files Browse the repository at this point in the history
This commit allows calling compilation(prepare) by infer session.
  - Introduce recreating internal session(nnfw_session) into BaseSession
  - Introduce compile function that can recreate internal session
  - Modify __getattr__ of BaseSession strictly

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani committed Jan 7, 2025
1 parent d978911 commit cbc046f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
21 changes: 19 additions & 2 deletions runtime/onert/api/python/package/common/basesession.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BaseSession:
"""
Base class providing common functionality for inference and training sessions.
"""
def __init__(self, backend_session):
def __init__(self, backend_session=None):
"""
Initialize the BaseSession with a backend session.
Args:
Expand All @@ -33,7 +33,24 @@ def __getattr__(self, name):
Returns:
The attribute or method from the bound NNFW_SESSION instance.
"""
return getattr(self.session, name)
if name in self.__dict__:
# First, try to get the attribute from the instance's own dictionary
return self.__dict__[name]
elif hasattr(self.session, name):
# If not found, delegate to the session object
return getattr(self.session, name)
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'")

def _recreate_session(self, backend_session):
"""
Protected method to recreate the session.
Subclasses can override this method to provide custom session recreation logic.
"""
if self.session is not None:
del self.session # Clean up the existing session
self.session = backend_session

def set_inputs(self, size, inputs_array=[]):
"""
Expand Down
25 changes: 23 additions & 2 deletions runtime/onert/api/python/package/infer/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,35 @@ class session(BaseSession):
"""
Class for inference using nnfw_session.
"""
def __init__(self, nnpackage_path, backends="cpu"):
def __init__(self, nnpackage_path: str = None, backends: str = "cpu"):
"""
Initialize the inference session.
Args:
nnpackage_path (str): Path to the nnpackage file or directory.
backends (str): Backends to use, default is "cpu".
"""
super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
if nnpackage_path is not None:
super().__init__(
libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
self.session.prepare()
self.set_outputs(self.session.output_size())
else:
super().__init__()

def compile(self, nnpackage_path: str, backends: str = "cpu"):
"""
Prepare the session by recreating it with new parameters.
Args:
nnpackage_path (str): Path to the nnpackage file or directory. Defaults to the existing path.
backends (str): Backends to use. Defaults to the existing backends.
"""
# Update parameters if provided
if nnpackage_path is None:
raise ValueError("nnpackage_path must not be None.")
# Recreate the session with updated parameters
self._recreate_session(
libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
# Prepare the new session
self.session.prepare()
self.set_outputs(self.session.output_size())

Expand Down

0 comments on commit cbc046f

Please sign in to comment.