Skip to content

Commit

Permalink
Implement human-maintainable policy for operators
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679313946
  • Loading branch information
v-dziuba authored and copybara-github committed Sep 27, 2024
1 parent b6d8c4e commit 585588d
Show file tree
Hide file tree
Showing 10 changed files with 605 additions and 10,326 deletions.
5 changes: 5 additions & 0 deletions ai_edge_quantizer/algorithm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
)


# Expose instance itself.
def get_algorithm_manager() -> algorithm_manager_api.AlgorithmManagerApi:
return _alg_manager_instance


# Quantization algorithms.
class AlgorithmName(str, enum.Enum):
NO_QUANTIZE = "no_quantize"
Expand Down
31 changes: 1 addition & 30 deletions ai_edge_quantizer/algorithm_manager_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,11 @@

"""The Python API for Algorithm Manager of Quantizer."""

import collections
from collections.abc import Callable
import dataclasses
import functools
import json
from typing import Any, Optional
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.utils import test_utils
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import


_DEFAULT_CONFIG_CHECK_POLICY_PATH = test_utils.get_path_to_datafile(
"policies/default_config_check_policy.json"
)


@dataclasses.dataclass
Expand Down Expand Up @@ -268,33 +259,13 @@ def _get_target_func(
return quantized_op_info[tfl_op_name].materialize_func
return None

def _load_config_check_policy(
self,
) -> qtyping.ConfigCheckPolicyDict:
"""Loads the config check policy for all algorithms."""

with gfile.Open(_DEFAULT_CONFIG_CHECK_POLICY_PATH) as json_file:
policy_content = json.load(json_file)

# Convert the config check policy content to a dict of OpQuantizationConfig.
policy = collections.OrderedDict()
for op_name, op_configs in policy_content.items():
policy[qtyping.TFLOperationName(op_name)] = [
qtyping.OpQuantizationConfig.from_dict(op_config)
for op_config in op_configs
]

return policy

# TODO: b/53780772 - Merge this function with
# register_op_quant_config_validation_func after full transition to new check
# mechanism.
def register_config_check_policy(
self,
algorithm_key: str,
config_check_policy: Optional[qtyping.ConfigCheckPolicyDict] = None,
config_check_policy: Optional[qtyping.ConfigCheckPolicyDict],
):
"""Registers a policy to check the op quantization config."""
if config_check_policy is None:
config_check_policy = self._load_config_check_policy()
self._config_check_policy_registry[algorithm_key] = config_check_policy
13 changes: 12 additions & 1 deletion ai_edge_quantizer/algorithm_manager_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,18 @@ def test_get_init_qsv_func(self):
def test_register_config_check_policy_succeeds(self):
self.assertEmpty(self._alg_manager._config_check_policy_registry)
test_algorithm_name = "test_algorithm"
self._alg_manager.register_config_check_policy(test_algorithm_name)
test_config_check_policy = qtyping.ConfigCheckPolicyDict({
_TFLOpName.FULLY_CONNECTED: {
qtyping.OpQuantizationConfig(
weight_tensor_config=qtyping.TensorQuantizationConfig(
num_bits=1
)
)
}
})
self._alg_manager.register_config_check_policy(
test_algorithm_name, test_config_check_policy
)
self.assertIn(
test_algorithm_name, self._alg_manager._config_check_policy_registry
)
Expand Down
Loading

0 comments on commit 585588d

Please sign in to comment.