Add model size reduction information to ComparisonResult. #150

Merged
merged 1 commit on Oct 7, 2024
43 changes: 30 additions & 13 deletions ai_edge_quantizer/model_validator.py
@@ -58,7 +58,17 @@ class ComparisonResult:
comparison_results: A dictionary of signature key and its comparison result.
"""

def __init__(self):
def __init__(self, reference_model: bytes, target_model: bytes):
"""Initialize the ComparisonResult object.

Args:
reference_model: Model which will be used as the reference.
target_model: Target model which will be compared against the reference.
We expect target_model and reference_model to have the same graph
structure.
"""
self._reference_model = reference_model
self._target_model = target_model
self._comparison_results: dict[str, SingleSignatureComparisonResult] = {}

def get_signature_comparison_result(
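For orientation, a minimal usage sketch of the new constructor, assuming the models are already loaded as TFLite flatbuffer bytes (for example with the tfl_flatbuffer_utils.get_model_buffer helper used in the tests below); the file names are placeholders:

from ai_edge_quantizer import model_validator
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

# Load both models as raw flatbuffer bytes (paths are illustrative).
reference_model = tfl_flatbuffer_utils.get_model_buffer('float_model.tflite')
target_model = tfl_flatbuffer_utils.get_model_buffer('quantized_model.tflite')

# The result object now owns both buffers, so later calls no longer need a
# source model argument and size statistics can be derived at save time.
comparison_result = model_validator.ComparisonResult(reference_model, target_model)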
@@ -85,15 +95,13 @@ def available_signature_keys(self) -> list[str]:

def add_new_signature_results(
self,
source_model: Union[str, bytearray],
error_metric: str,
comparison_result: dict[str, float],
signature_key: str = _DEFAULT_SIGNATURE_KEY,
):
"""Add a new signature result to the comparison result.

Args:
source_model: The model to be validated.
error_metric: The name of the error metric used for comparison.
comparison_result: A dictionary of tensor name and its value.
signature_key: The model signature that the comparison_result belongs to.
@@ -107,20 +115,25 @@ def add_new_signature_results(
result = {key: float(value) for key, value in comparison_result.items()}

input_tensor_results = {}
for name in utils.get_input_tensor_names(source_model, signature_key):
for name in utils.get_input_tensor_names(
self._reference_model, signature_key
):
input_tensor_results[name] = result.pop(name)

output_tensor_results = {}
for name in utils.get_output_tensor_names(source_model, signature_key):
for name in utils.get_output_tensor_names(
self._reference_model, signature_key
):
output_tensor_results[name] = result.pop(name)

constant_tensor_results = {}
# Only get constant tensors from the main subgraph of the signature.
subgraph_index = utils.get_signature_main_subgraph_index(
utils.create_tfl_interpreter(source_model), signature_key
utils.create_tfl_interpreter(self._reference_model),
signature_key,
)
for name in utils.get_constant_tensor_names(
source_model,
self._reference_model,
subgraph_index,
):
constant_tensor_results[name] = result.pop(name)
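Continuing the sketch above, a hedged example of the updated call: the source model argument is gone because the reference model is now held by the instance, so callers pass only the metric name, the per-tensor results, and the signature key (the values here are taken from the test data below):

comparison_result.add_new_signature_results(
    'mean_squared_difference',
    {'add_x:0': 1e-3, 'Add/y': 0.25, 'PartitionedCall:0': 1e-3},  # tensor name -> metric value
    'add',  # signature key
)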
@@ -157,7 +170,12 @@ def save(self, save_folder: str, model_name: str) -> None:
Raises:
RuntimeError: If no quantized model is available.
"""
result = {}
reduced_model_size = len(self._reference_model) - len(self._target_model)
reduction_ratio = reduced_model_size / len(self._reference_model) * 100
result = {
'reduced_size_bytes': reduced_model_size,
'reduced_size_percentage': reduction_ratio,
}
for signature, comparison_result in self._comparison_results.items():
result[str(signature)] = {
'error_metric': comparison_result.error_metric,
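The two new entries are plain byte arithmetic over the stored buffers. A worked example with hypothetical sizes: a 4,000,000-byte float reference model and a 1,000,000-byte quantized target yield the following values in the saved JSON:

# Hypothetical sizes, for illustration only.
reference_size = 4_000_000  # len(self._reference_model)
target_size = 1_000_000     # len(self._target_model)
reduced_size_bytes = reference_size - target_size                    # 3_000_000
reduced_size_percentage = reduced_size_bytes / reference_size * 100  # 75.0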
@@ -186,7 +204,7 @@ def save(self, save_folder: str, model_name: str) -> None:


def _setup_validation_interpreter(
model: Union[str, bytearray],
model: bytes,
signature_input: dict[str, Any],
signature_key: Optional[str],
use_reference_kernel: bool,
@@ -224,8 +242,8 @@ def _setup_validation_interpreter(

# TODO: b/330797129 - Enable multi-threaded evaluation.
def compare_model(
reference_model: Union[str, bytes],
target_model: Union[str, bytes],
reference_model: bytes,
target_model: bytes,
test_data: dict[str, Iterable[dict[str, Any]]],
error_metric: str,
compare_fn: Callable[[Any, Any], float],
@@ -254,7 +272,7 @@ def compare_model(
Returns:
A ComparisonResult object.
"""
model_comparion_result = ComparisonResult()
model_comparion_result = ComparisonResult(reference_model, target_model)
for signature_key, signature_inputs in test_data.items():
comparison_results = {}
for signature_input in signature_inputs:
@@ -296,7 +314,6 @@ def compare_model(
for tensor_name in comparison_results:
agregated_results[tensor_name] = np.mean(comparison_results[tensor_name])
model_comparion_result.add_new_signature_results(
reference_model,
error_metric,
agregated_results,
signature_key,
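Taken together, compare_model now accepts raw model bytes instead of file paths and hands both buffers to ComparisonResult, so the size statistics come along whenever the result is saved. A hedged end-to-end sketch: the metric and compare function mirror the tests below, the input data shape is illustrative, and any remaining optional parameters of compare_model are left at their defaults.

import numpy as np

from ai_edge_quantizer import model_validator
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import validation_utils

reference_model = tfl_flatbuffer_utils.get_model_buffer('float_model.tflite')
target_model = tfl_flatbuffer_utils.get_model_buffer('quantized_model.tflite')

# One list of input dicts per signature key.
test_data = {'add': [{'x': np.array([2.0]).astype(np.float32)}]}

result = model_validator.compare_model(
    reference_model,
    target_model,
    test_data,
    'mean_squared_difference',
    validation_utils.mean_squared_difference,
)
result.save('/tmp/validation', 'my_model')  # The saved JSON now includes the size stats.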
98 changes: 64 additions & 34 deletions ai_edge_quantizer/model_validator_test.py
@@ -20,6 +20,7 @@
from tensorflow.python.platform import googletest
from ai_edge_quantizer import model_validator
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import validation_utils

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
@@ -34,6 +35,16 @@ def setUp(self):
self.test_model_path = os.path.join(
TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
)
self.test_model = tfl_flatbuffer_utils.get_model_buffer(
self.test_model_path
)
self.test_quantized_model_path = os.path.join(
TEST_DATA_PREFIX_PATH,
'tests/models/two_signatures_a8w8.tflite',
)
self.test_quantized_model = tfl_flatbuffer_utils.get_model_buffer(
self.test_quantized_model_path
)
self.test_data = {
'add': {'add_x:0': 1e-3, 'Add/y': 0.25, 'PartitionedCall:0': 1e-3},
'multiply': {
@@ -43,22 +54,23 @@
},
}
self.test_dir = self.create_tempdir()
self.comparison_result = model_validator.ComparisonResult(
self.test_model, self.test_quantized_model
)

def test_add_new_signature_results_succeeds(self):
comparison_result = model_validator.ComparisonResult()
for signature_key, test_result in self.test_data.items():
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
test_result,
signature_key,
)
self.assertLen(
comparison_result.available_signature_keys(), len(self.test_data)
self.comparison_result.available_signature_keys(), len(self.test_data)
)

for signature_key in self.test_data:
signature_result = comparison_result.get_signature_comparison_result(
signature_result = self.comparison_result.get_signature_comparison_result(
signature_key
)
input_tensors = signature_result.input_tensors
@@ -72,9 +84,7 @@ def test_add_new_signature_results_succeeds(self):
self.assertEmpty(intermediate_tensors)

def test_add_new_signature_results_fails_same_signature_key(self):
comparison_result = model_validator.ComparisonResult()
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
self.test_data['add'],
'add',
@@ -83,8 +93,7 @@ def test_add_new_signature_results_fails_same_signature_key(self):
with self.assertRaisesWithPredicateMatch(
ValueError, lambda err: error_message in str(err)
):
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
self.test_data['add'],
'add',
@@ -93,9 +102,7 @@ def test_get_signature_comparison_result_fails_with_invalid_signature_key(
def test_get_signature_comparison_result_fails_with_invalid_signature_key(
self,
):
comparison_result = model_validator.ComparisonResult()
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
self.test_data['add'],
'add',
@@ -104,18 +111,16 @@ def test_get_signature_comparison_result_fails_with_invalid_signature_key(
with self.assertRaisesWithPredicateMatch(
ValueError, lambda err: error_message in str(err)
):
comparison_result.get_signature_comparison_result('multiply')
self.comparison_result.get_signature_comparison_result('multiply')

def test_get_all_tensor_results_succeeds(self):
comparison_result = model_validator.ComparisonResult()
for signature_key, test_result in self.test_data.items():
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
test_result,
signature_key,
)
all_tensor_results = comparison_result.get_all_tensor_results()
all_tensor_results = self.comparison_result.get_all_tensor_results()
self.assertLen(all_tensor_results, 6)
self.assertIn('add_x:0', all_tensor_results)
self.assertIn('Add/y', all_tensor_results)
@@ -125,21 +130,34 @@
self.assertIn('PartitionedCall_1:0', all_tensor_results)

def test_save_comparison_result_succeeds(self):
comparison_result = model_validator.ComparisonResult()
for signature_key, test_result in self.test_data.items():
comparison_result.add_new_signature_results(
self.test_model_path,
self.comparison_result.add_new_signature_results(
'mean_squared_difference',
test_result,
signature_key,
)
model_name = 'test_model'
comparison_result.save(self.test_dir.full_path, model_name)
self.comparison_result.save(self.test_dir.full_path, model_name)
test_json_path = os.path.join(
self.test_dir.full_path, model_name + '_comparison_result.json'
)
with open(test_json_path) as json_file:
json_dict = json.load(json_file)

# Check model size stats.
self.assertIn('reduced_size_bytes', json_dict)
self.assertEqual(
json_dict['reduced_size_bytes'],
len(self.test_model) - len(self.test_quantized_model),
)
self.assertIn('reduced_size_percentage', json_dict)
self.assertEqual(
json_dict['reduced_size_percentage'],
(len(self.test_model) - len(self.test_quantized_model))
/ len(self.test_model)
* 100,
)

for signature_key in self.test_data:
self.assertIn(signature_key, json_dict)
signature_result = json_dict[signature_key]
@@ -169,6 +187,12 @@ def setUp(self):
TEST_DATA_PREFIX_PATH,
'tests/models/single_fc_bias_sub_channel_weight_only_sym_weight.tflite',
)
self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
self.reference_model_path
)
self.target_model = tfl_flatbuffer_utils.get_model_buffer(
self.target_model_path
)
self.signature_key = 'serving_default' # single signature.
self.test_data = test_utils.create_random_normal_input_data(
self.reference_model_path
@@ -178,8 +202,8 @@ def setUp(self):
def test_model_validator_compare(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -205,8 +229,8 @@ def test_model_validator_compare(self):
def test_create_json_for_model_explorer(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -224,8 +248,8 @@ def test_create_json_for_model_explorer(self):
def test_create_json_for_model_explorer_no_thresholds(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -249,6 +273,12 @@ def setUp(self):
TEST_DATA_PREFIX_PATH,
'tests/models/two_signatures_a8w8.tflite',
)
self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
self.reference_model_path
)
self.target_model = tfl_flatbuffer_utils.get_model_buffer(
self.target_model_path
)
self.test_data = {
'add': [{'x': np.array([2.0]).astype(np.float32)}],
'multiply': [{'x': np.array([1.0]).astype(np.float32)}],
@@ -258,8 +288,8 @@
def test_model_validator_compare_succeeds(self):
error_metric = 'mean_squared_difference'
result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -288,8 +318,8 @@ def test_model_validator_compare_succeeds(self):
def test_create_json_for_model_explorer(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,
@@ -308,8 +338,8 @@ def test_create_json_for_model_explorer(self):
def test_create_json_for_model_explorer_no_thresholds(self):
error_metric = 'mean_squared_difference'
comparison_result = model_validator.compare_model(
self.reference_model_path,
self.target_model_path,
self.reference_model,
self.target_model,
self.test_data,
error_metric,
validation_utils.mean_squared_difference,