Skip to content

Commit

Permalink
Dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Nov 12, 2024
1 parent 74ff9bb commit f79c865
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
32 changes: 32 additions & 0 deletions python/tvm/runtime/executor/aot_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __init__(self, module):
self._get_input_index = module["get_input_index"]
self._get_num_inputs = module["get_num_inputs"]
self._get_input_name = module["get_input_name"]
self._get_output_index = module["get_output_index"]
self._get_output_info = module["get_output_info"]

def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs
Expand Down Expand Up @@ -199,3 +201,33 @@ def get_input_info(self):
dtype_dict[input_name] = input_tensor.dtype

return shape_dict, dtype_dict

def get_output_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the graph.
Returns
-------
shape_dict : Map
Shape dictionary - {output_name: tuple}.
dtype_dict : Map
dtype dictionary - {output_name: dtype}.
"""
output_info = self._get_output_info()
assert "shape" in output_info
shape_dict = output_info["shape"]
assert "dtype" in output_info
dtype_dict = output_info["dtype"]

return shape_dict, dtype_dict

def get_output_index(self, name):
"""Get outputs index via output name.
Parameters
----------
name : str
The output key name
Returns
-------
index: int
The output index. -1 will be returned if the given output name is not found.
"""
return self._get_output_index(name)
6 changes: 5 additions & 1 deletion src/runtime/crt/aot_executor_module/aot_executor_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = {
&TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper)
&TVMAotExecutorModule_NotImplemented, // share_params (do not implement)
&TVMAotExecutorModule_GetInputName, // get_input_name
&TVMAotExecutorModule_NotImplemented, // get_output_index
&TVMAotExecutorModule_NotImplemented, // get_output_info
};

static const TVMFuncRegistry aot_executor_registry = {
Expand All @@ -223,7 +225,9 @@ static const TVMFuncRegistry aot_executor_registry = {
"run\0"
"set_input\0"
"share_params\0"
"get_input_name\0",
"get_input_name\0"
"get_output_index\0"
"get_output_info\0",
aot_executor_registry_funcs};

tvm_crt_error_t TVMAotExecutorModule_Register() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ static const TVMBackendPackedCFunc graph_executor_registry_funcs[] = {
&TVMGraphExecutorModule_Run,
&TVMGraphExecutorModule_SetInput,
&TVMGraphExecutorModule_NotImplemented, // share_params
&TVMGraphExecutorModule_NotImplemented, // get_output_index
&TVMGraphExecutorModule_NotImplemented, // get_output_info
};

static const TVMFuncRegistry graph_executor_registry = {
Expand All @@ -247,7 +249,9 @@ static const TVMFuncRegistry graph_executor_registry = {
"load_params\0"
"run\0"
"set_input\0"
"share_params\0",
"share_params\0"
"get_output_index\0"
"get_output_info\0",
graph_executor_registry_funcs};

tvm_crt_error_t TVMGraphExecutorModule_Register() {
Expand Down

0 comments on commit f79c865

Please sign in to comment.