Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Nov 6, 2024
1 parent 75f1abb commit a44e199
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tripy/tests/backend/api/test_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
assert param.annotation == tp.Tensor

assert signature.return_annotation == Sequence[tp.Tensor]
assert signature.return_annotation == tp.Tensor

def test_signature_multiple_return_values(self, multiple_return_executable):
signature = inspect.signature(multiple_return_executable)
Expand Down
10 changes: 2 additions & 8 deletions tripy/tripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, executable, arg_names):
for name in self._arg_names:
params.append(inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Tensor))

return_annotation = Tensor if self._executable_signature.get_num_output_args() == 1 else Sequence[Tensor]
return_annotation = Tensor if self._executable_signature.get_num_results() == 1 else Sequence[Tensor]

self.__signature__ = inspect.Signature(params, return_annotation=return_annotation)

Expand Down Expand Up @@ -227,15 +227,9 @@ def add(a, b):
print(compiled_add.get_output_info())
"""
num_input_args = self._executable_signature.get_num_input_args()
num_output_args = self._executable_signature.get_num_output_args()
num_results = self._executable_signature.get_num_results()

assert not (num_output_args and num_results), "Cannot have both output arguments and results"

if num_output_args:
return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
else:
return [self._get_result_info(idx) for idx in range(num_results)]
return [self._get_result_info(idx) for idx in range(num_results)]

def save(self, path: str) -> None:
"""
Expand Down

0 comments on commit a44e199

Please sign in to comment.