Skip to content

Commit

Permalink
[PyOV] Allow single inputs in form of lists of simple types (#21734)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Iwaszkiewicz authored Dec 18, 2023
1 parent 0e496fa commit c59498b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 13 deletions.
12 changes: 8 additions & 4 deletions src/bindings/python/src/openvino/runtime/ie_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def infer(
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
* scalar inputs (i.e. `np.float_`/`str`/`bytes`/`int`/`float`)
* lists of simple data types (i.e. `str`/`bytes`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Expand Down Expand Up @@ -192,7 +193,8 @@ def start_async(
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
* scalar inputs (i.e. `np.float_`/`str`/`bytes`/`int`/`float`)
* lists of simple data types (i.e. `str`/`bytes`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Expand Down Expand Up @@ -346,7 +348,8 @@ def __call__(
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
* scalar inputs (i.e. `np.float_`/`str`/`bytes`/`int`/`float`)
* lists of simple data types (i.e. `str`/`bytes`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Expand Down Expand Up @@ -464,7 +467,8 @@ def start_async(
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
* scalar inputs (i.e. `np.float_`/`str`/`bytes`/`int`/`float`)
* lists of simple data types (i.e. `str`/`bytes`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
ValidKeys = Union[str, int, ConstOutput]


def is_list_simple_type(input_list: list) -> bool:
for sublist in input_list:
if isinstance(sublist, list):
for element in sublist:
if not isinstance(element, (str, float, int, bytes)):
return False
else:
if not isinstance(sublist, (str, float, int, bytes)):
return False
return True


def get_request_tensor(
request: _InferRequestWrapper,
key: Optional[ValidKeys] = None,
Expand Down Expand Up @@ -198,17 +210,28 @@ def create_shared(


@create_shared.register(dict)
@create_shared.register(list)
@create_shared.register(tuple)
@create_shared.register(OVDict)
def _(
inputs: ContainerTypes,
inputs: Union[dict, tuple, OVDict],
request: _InferRequestWrapper,
) -> dict:
request._inputs_data = normalize_arrays(inputs, is_shared=True)
return {k: value_to_tensor(v, request=request, is_shared=True, key=k) for k, v in request._inputs_data.items()}


# Special override to perform list-related dispatch
@create_shared.register(list)
def _(
inputs: list,
request: _InferRequestWrapper,
) -> dict:
# If list is passed to single input model and consists only of simple types
# i.e. str/bytes/float/int, wrap around it and pass into the dispatcher.
request._inputs_data = normalize_arrays([inputs] if request._is_single_input() and is_list_simple_type(inputs) else inputs, is_shared=True)
return {k: value_to_tensor(v, request=request, is_shared=True, key=k) for k, v in request._inputs_data.items()}


@create_shared.register(np.ndarray)
def _(
inputs: np.ndarray,
Expand Down Expand Up @@ -348,16 +371,26 @@ def create_copied(


@create_copied.register(dict)
@create_copied.register(list)
@create_copied.register(tuple)
@create_copied.register(OVDict)
def _(
inputs: ContainerTypes,
inputs: Union[dict, tuple, OVDict],
request: _InferRequestWrapper,
) -> dict:
return update_inputs(normalize_arrays(inputs, is_shared=False), request)


# Special override to perform list-related dispatch
@create_copied.register(list)
def _(
inputs: list,
request: _InferRequestWrapper,
) -> dict:
# If list is passed to single input model and consists only of simple types
# i.e. str/bytes/float/int, wrap around it and pass into the dispatcher.
return update_inputs(normalize_arrays([inputs] if request._is_single_input() and is_list_simple_type(inputs) else inputs, is_shared=False), request)


@create_copied.register(np.ndarray)
def _(
inputs: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def __init__(self, other: InferRequestBase) -> None:
self._inputs_data = None
super().__init__(other)

def _is_single_input(self) -> bool:
return len(self.input_tensors) == 1


class OVDict(Mapping):
"""Custom OpenVINO dictionary with inference results.
Expand Down
10 changes: 5 additions & 5 deletions src/bindings/python/tests/test_utils/test_data_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,11 @@ def test_string_array_dispatcher(device, input_data, data_type, input_shape, is_
@pytest.mark.parametrize(
("input_data", "input_shape"),
[
([["śćżóąę", "data_dispatcher_test"]], [2]),
([[b"abcdef", b"data_dispatcher_test"]], [2]),
([[bytes("abc", encoding="utf-8"), bytes("zzzz", encoding="utf-8")]], [2]),
([[["śćżóąę", "data_dispatcher_test"]]], [1, 2]),
([[["śćżóąę"], ["data_dispatcher_test"]]], [2, 1]),
(["śćżóąę", "data_dispatcher_test"], [2]),
([b"abcdef", b"data_dispatcher_test"], [2]),
([bytes("abc", encoding="utf-8"), bytes("zzzz", encoding="utf-8")], [2]),
([["śćżóąę", "data_dispatcher_test"]], [1, 2]),
([["śćżóąę"], ["data_dispatcher_test"]], [2, 1]),
],
)
@pytest.mark.parametrize("data_type", [Type.string, str, bytes, np.str_, np.bytes_])
Expand Down

0 comments on commit c59498b

Please sign in to comment.