Add model size reduction information to ComparisonResult.
PiperOrigin-RevId: 679319590
rewu93 authored and copybara-github committed Oct 7, 2024
1 parent a869b4c commit 5e3c285
Showing 3 changed files with 101 additions and 54 deletions.
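
The core of the change is a small piece of arithmetic: ComparisonResult now holds both serialized models and reports how much smaller the target (quantized) model is than the reference. A minimal sketch of that computation, with made-up buffer sizes standing in for real .tflite models:

# Sketch of the size-reduction stats introduced by this commit.
# reference_model and target_model are placeholders for real .tflite buffers.
reference_model = bytes(1000)  # e.g. a float model of 1000 bytes
target_model = bytes(400)      # e.g. its quantized counterpart, 400 bytes

reduced_size_bytes = len(reference_model) - len(target_model)
reduced_size_percentage = reduced_size_bytes / len(reference_model) * 100

print(reduced_size_bytes)       # 600
print(reduced_size_percentage)  # 60.0
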
43 changes: 30 additions & 13 deletions ai_edge_quantizer/model_validator.py
@@ -58,7 +58,17 @@ class ComparisonResult:
comparison_results: A dictionary of signature key and its comparison result.
"""

def __init__(self):
def __init__(self, reference_model: bytes, target_model: bytes):
"""Initialize the ComparisonResult object.
Args:
reference_model: Model which will be used as the reference.
target_model: Target model which will be compared against the reference.
We expect target_model and reference_model to have the same graph
structure.
"""
self._reference_model = reference_model
self._target_model = target_model
self._comparison_results: dict[str, SingleSignatureComparisonResult] = {}

def get_signature_comparison_result(
@@ -85,15 +95,13 @@ def available_signature_keys(self) -> list[str]:

def add_new_signature_results(
self,
source_model: Union[str, bytearray],
error_metric: str,
comparison_result: dict[str, float],
signature_key: str = _DEFAULT_SIGNATURE_KEY,
):
"""Add a new signature result to the comparison result.
Args:
source_model: The model to be validated.
error_metric: The name of the error metric used for comparison.
comparison_result: A dictionary of tensor name and its value.
signature_key: The model signature that the comparison_result belongs to.
@@ -107,20 +115,25 @@ def add_new_signature_results(
result = {key: float(value) for key, value in comparison_result.items()}

input_tensor_results = {}
for name in utils.get_input_tensor_names(source_model, signature_key):
for name in utils.get_input_tensor_names(
self._reference_model, signature_key
):
input_tensor_results[name] = result.pop(name)

output_tensor_results = {}
for name in utils.get_output_tensor_names(source_model, signature_key):
for name in utils.get_output_tensor_names(
self._reference_model, signature_key
):
output_tensor_results[name] = result.pop(name)

constant_tensor_results = {}
# Only get constant tensors from the main subgraph of the signature.
subgraph_index = utils.get_signature_main_subgraph_index(
utils.create_tfl_interpreter(source_model), signature_key
utils.create_tfl_interpreter(self._reference_model),
signature_key,
)
for name in utils.get_constant_tensor_names(
source_model,
self._reference_model,
subgraph_index,
):
constant_tensor_results[name] = result.pop(name)
@@ -157,7 +170,12 @@ def save(self, save_folder: str, model_name: str) -> None:
Raises:
RuntimeError: If no quantized model is available.
"""
result = {}
reduced_model_size = len(self._reference_model) - len(self._target_model)
reduction_ratio = reduced_model_size / len(self._reference_model) * 100
result = {
'reduced_size_bytes': reduced_model_size,
'reduced_size_percentage': reduction_ratio,
}
for signature, comparison_result in self._comparison_results.items():
result[str(signature)] = {
'error_metric': comparison_result.error_metric,
Expand Down Expand Up @@ -186,7 +204,7 @@ def save(self, save_folder: str, model_name: str) -> None:


def _setup_validation_interpreter(
model: Union[str, bytearray],
model: bytes,
signature_input: dict[str, Any],
signature_key: Optional[str],
use_reference_kernel: bool,
@@ -224,8 +242,8 @@ def _setup_validation_interpreter(

# TODO: b/330797129 - Enable multi-threaded evaluation.
def compare_model(
reference_model: Union[str, bytes],
target_model: Union[str, bytes],
reference_model: bytes,
target_model: bytes,
test_data: dict[str, Iterable[dict[str, Any]]],
error_metric: str,
compare_fn: Callable[[Any, Any], float],
@@ -254,7 +272,7 @@ def compare_model(
Returns:
A ComparisonResult object.
"""
model_comparion_result = ComparisonResult()
model_comparion_result = ComparisonResult(reference_model, target_model)
for signature_key, signature_inputs in test_data.items():
comparison_results = {}
for signature_input in signature_inputs:
@@ -296,7 +314,6 @@ def compare_model(
for tensor_name in comparison_results:
agregated_results[tensor_name] = np.mean(comparison_results[tensor_name])
model_comparion_result.add_new_signature_results(
reference_model,
error_metric,
agregated_results,
signature_key,
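
With this change, ComparisonResult (and compare_model, which constructs it) works on serialized model buffers rather than file paths, so the size statistics can be derived without touching the filesystem again. A hedged sketch of constructing the object directly, assuming two hypothetical .tflite files on disk; the tests below use tfl_flatbuffer_utils.get_model_buffer for the same purpose:

# Sketch only; 'float_model.tflite' and 'quantized_model.tflite' are hypothetical paths.
from ai_edge_quantizer import model_validator

with open('float_model.tflite', 'rb') as f:
  reference_model = f.read()
with open('quantized_model.tflite', 'rb') as f:
  target_model = f.read()

# Both buffers are captured at construction time, so
# add_new_signature_results no longer needs the source model passed in.
result = model_validator.ComparisonResult(reference_model, target_model)

compare_model follows the same pattern: it now receives the two buffers and forwards them into the ComparisonResult it returns.
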
98 changes: 64 additions & 34 deletions ai_edge_quantizer/model_validator_test.py
@@ -20,6 +20,7 @@
from tensorflow.python.platform import googletest
from ai_edge_quantizer import model_validator
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import validation_utils

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
@@ -34,6 +35,16 @@ def setUp(self):
self.test_model_path = os.path.join(
TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
)
self.test_model = tfl_flatbuffer_utils.get_model_buffer(
self.test_model_path
)
self.test_quantized_model_path = os.path.join(
TEST_DATA_PREFIX_PATH,
'tests/models/two_signatures_a8w8.tflite',
)
self.test_quantized_model = tfl_flatbuffer_utils.get_model_buffer(
self.test_quantized_model_path
)
self.test_data = {
'add': {'add_x:0': 1e-3, 'Add/y': 0.25, 'PartitionedCall:0': 1e-3},
'multiply': {
@@ -43,22 +54,23 @@ def setUp(self):
},
}
self.test_dir = self.create_tempdir()
self.comparison_result = model_validator.ComparisonResult(
self.test_model, self.test_quantized_model
)

def test_add_new_signature_results_succeeds(self):
comparison_result = model_validator.ComparisonResult()
for signature_key, test_result in self.test_data.items():
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
test_result,
signature_key,
)
self.assertLen(
comparison_result.available_signature_keys(), len(self.test_data)
self.comparison_result.available_signature_keys(), len(self.test_data)
)

for signature_key in self.test_data:
signature_result = comparison_result.get_signature_comparison_result(
signature_result = self.comparison_result.get_signature_comparison_result(
signature_key
)
input_tensors = signature_result.input_tensors
@@ -72,9 +84,7 @@ def test_add_new_signature_results_succeeds(self):
self.assertEmpty(intermediate_tensors)

def test_add_new_signature_results_fails_same_signature_key(self):
comparison_result = model_validator.ComparisonResult()
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
self.test_data['add'],
'add',
@@ -83,8 +93,7 @@ def test_add_new_signature_results_fails_same_signature_key(self):
with self.assertRaisesWithPredicateMatch(
ValueError, lambda err: error_message in str(err)
):
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
self.test_data['add'],
'add',
@@ -93,9 +102,7 @@ def test_get_signature_comparison_result_fails_with_invalid_signature_key(
def test_get_signature_comparison_result_fails_with_invalid_signature_key(
self,
):
comparison_result = model_validator.ComparisonResult()
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
self.test_data['add'],
'add',
@@ -104,18 +111,16 @@ def test_get_signature_comparison_result_fails_with_invalid_signature_key(
with self.assertRaisesWithPredicateMatch(
ValueError, lambda err: error_message in str(err)
):
comparison_result.get_signature_comparison_result('multiply')
self.comparison_result.get_signature_comparison_result('multiply')

def test_get_all_tensor_results_succeeds(self):
comparison_result = model_validator.ComparisonResult()
for signature_key, test_result in self.test_data.items():
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
test_result,
signature_key,
)
all_tensor_results = comparison_result.get_all_tensor_results()
all_tensor_results = self.comparison_result.get_all_tensor_results()
self.assertLen(all_tensor_results, 6)
self.assertIn('add_x:0', all_tensor_results)
self.assertIn('Add/y', all_tensor_results)
@@ -125,21 +130,34 @@ def test_get_all_tensor_results_succeeds(self):
self.assertIn('PartitionedCall_1:0', all_tensor_results)

def test_save_comparison_result_succeeds(self):
comparison_result = model_validator.ComparisonResult()
for signature_key, test_result in self.test_data.items():
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
test_result,
signature_key,
)
model_name = 'test_model'
comparison_result.save(self.test_dir.full_path, model_name)
self.comparison_result.save(self.test_dir.full_path, model_name)
test_json_path = os.path.join(
self.test_dir.full_path, model_name + '_comparison_result.json'
)
with open(test_json_path) as json_file:
json_dict = json.load(json_file)

# Check model size stats.
self.assertIn('reduced_size_bytes', json_dict)
self.assertEqual(
json_dict['reduced_size_bytes'],
len(self.test_model) - len(self.test_quantized_model),
)
self.assertIn('reduced_size_percentage', json_dict)
self.assertEqual(
json_dict['reduced_size_percentage'],
(len(self.test_model) - len(self.test_quantized_model))
/ len(self.test_model)
* 100,
)

for signature_key in self.test_data:
self.assertIn(signature_key, json_dict)
signature_result = json_dict[signature_key]
@@ -169,6 +187,12 @@ def setUp(self):
TEST_DATA_PREFIX_PATH,
'tests/models/single_fc_bias_sub_channel_weight_only_sym_weight.tflite',
)
self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
self.reference_model_path
)
self.target_model = tfl_flatbuffer_utils.get_model_buffer(
self.target_model_path
)
self.signature_key = 'serving_default' # single signature.
self.test_data = test_utils.create_random_normal_input_data(
self.reference_model_path
@@ -178,8 +202,8 @@ def setUp(self):
def test_model_validator_compare(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -205,8 +229,8 @@ def test_model_validator_compare(self):
def test_create_json_for_model_explorer(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -224,8 +248,8 @@ def test_create_json_for_model_explorer(self):
def test_create_json_for_model_explorer_no_thresholds(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -249,6 +273,12 @@ def setUp(self):
TEST_DATA_PREFIX_PATH,
'tests/models/two_signatures_a8w8.tflite',
)
self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
self.reference_model_path
)
self.target_model = tfl_flatbuffer_utils.get_model_buffer(
self.target_model_path
)
self.test_data = {
'add': [{'x': np.array([2.0]).astype(np.float32)}],
'multiply': [{'x': np.array([1.0]).astype(np.float32)}],
@@ -258,8 +288,8 @@ def test_model_validator_compare_succeeds(self):
def test_model_validator_compare_succeeds(self):
error_metric = 'mean_squared_difference'
result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -288,8 +318,8 @@ def test_model_validator_compare_succeeds(self):
def test_create_json_for_model_explorer(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -308,8 +338,8 @@ def test_create_json_for_model_explorer(self):
def test_create_json_for_model_explorer_no_thresholds(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
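
The new assertions above check that the saved report now leads with the two size fields. A small sketch of reading them back, assuming a report written by ComparisonResult.save() with model_name 'test_model'; the path is hypothetical:

# Sketch: inspecting the size stats in a saved comparison report.
import json

with open('test_model_comparison_result.json') as f:
  report = json.load(f)

print(report['reduced_size_bytes'])       # absolute size saved, in bytes
print(report['reduced_size_percentage'])  # relative reduction, as a percentage
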