Skip to content

Commit

Permalink
Migrate lite interpreter usage to LiteRT in oss python build for ai-e…
Browse files Browse the repository at this point in the history
…dge-quantizer.

PiperOrigin-RevId: 679315082
  • Loading branch information
ai-edge-bot authored and copybara-github committed Sep 26, 2024
1 parent b6d8c4e commit 485fb02
Show file tree
Hide file tree
Showing 18 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import

_TFLOpName = qtyping.TFLOperationName
_QuantTransformation = qtyping.QuantTransformation
Expand Down
5 changes: 4 additions & 1 deletion ai_edge_quantizer/examples/mnist/quantize_toy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils
from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import

_OpExecutionMode = qtyping.OpExecutionMode
_OpName = qtyping.TFLOperationName
Expand Down Expand Up @@ -177,7 +178,9 @@ def _read_img(img_path: str) -> np.ndarray:
data = data / 255.0
return data.reshape((1, 28, 28, 1))

tflite_interpreter = tf.lite.Interpreter(model_content=quantized_tflite)
tflite_interpreter = tfl_interpreter.Interpreter(
model_content=quantized_tflite
)
tflite_interpreter.allocate_tensors()
data = _read_img(image_path)
tfl_interpreter_utils.invoke_interpreter_once(tflite_interpreter, [data])
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/model_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ai_edge_quantizer import transformation_instruction_generator
from ai_edge_quantizer import transformation_performer
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import


Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformation_instruction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import Optional
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


# When a tensor has no producer, we'll assign -1 to the producer field
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformation_performer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ai_edge_quantizer.transformations import quant_insert
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


class TransformationPerformer:
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/dequant_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


def insert_dequant(
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/dequant_insert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/emulated_subchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


def emulated_subchannel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/quant_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


def insert_quant(
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/quant_insert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/quantize_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import transformation_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


# TODO: b/335014051 - support distinguishing INT, FLOAT & UINT, BFLOAT
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/quantize_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/transformation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np

from ai_edge_quantizer import qtyping
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


@dataclasses.dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from tensorflow.lite.python import schema_py_generated # pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np

from ai_edge_quantizer import qtyping
from tensorflow.lite.python import schema_py_generated # pylint:disable=g-direct-tensorflow-import
from ai_edge_litert import schema_py_generated # pylint:disable=g-direct-tensorflow-import
from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import

Expand Down
16 changes: 8 additions & 8 deletions ai_edge_quantizer/utils/tfl_interpreter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from typing import Any, Optional, Union

import numpy as np
import tensorflow as tf

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
from ai_edge_litert import interpreter as tfl # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import


def create_tfl_interpreter(
tflite_model: Union[str, bytearray],
allocate_tensors: bool = True,
use_reference_kernel: bool = False,
) -> tf.lite.Interpreter:
) -> tfl.Interpreter:
"""Creates a TFLite interpreter from a model file.
Args:
Expand All @@ -47,12 +47,12 @@ def create_tfl_interpreter(
else:
tflite_model = bytes(tflite_model)
if use_reference_kernel:
op_resolver = tf.lite.experimental.OpResolverType.BUILTIN_REF
op_resolver = tfl.OpResolverType.BUILTIN_REF
else:
op_resolver = (
tf.lite.experimental.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
)
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = tfl.Interpreter(
model_content=tflite_model,
experimental_op_resolver_type=op_resolver,
experimental_preserve_all_tensors=True,
Expand All @@ -76,7 +76,7 @@ def is_tensor_quantized(tensor_detail: dict[str, Any]) -> bool:


def invoke_interpreter_signature(
tflite_interpreter: tf.lite.Interpreter,
tflite_interpreter: tfl.Interpreter,
signature_input_data: dict[str, Any],
signature_key: Optional[str] = None,
quantize_input: bool = True,
Expand Down Expand Up @@ -108,7 +108,7 @@ def invoke_interpreter_signature(


def invoke_interpreter_once(
tflite_interpreter: tf.lite.Interpreter,
tflite_interpreter: tfl.Interpreter,
input_data_list: list[Any],
quantize_input: bool = True,
):
Expand Down Expand Up @@ -295,7 +295,7 @@ def get_constant_tensor_names(


def get_signature_main_subgraph_index(
tflite_interpreter: tf.lite.Interpreter,
tflite_interpreter: tfl.Interpreter,
signature_key: Optional[str] = None,
) -> int:
"""Gets the main subgraph index of a signature.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
immutabledict
numpy
tf-nightly>=2.18.0.dev20240624
ai-edge-litert-nightly

0 comments on commit 485fb02

Please sign in to comment.