Skip to content

Commit

Permalink
[POT] Update for the Results quantization & tests (openvinotoolkit#8324)
Browse files Browse the repository at this point in the history
* Update rule for multi-results quantization

(cherry picked from commit d94fe7d)

* Update tests with lstm cases

(cherry picked from commit 16fe385)

* Update lstm reference

* Fix pylint issue
  • Loading branch information
nikita-malininn authored Nov 8, 2021
1 parent 763859d commit 8cb7824
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,12 @@ def _custom_broadcast(arrays_list):
def create_renamed_layers_mapping(model, stats_layout):
    """Build a mapping from current node names back to their original names.

    :param model: model graph to look nodes up in
    :param stats_layout: iterable of layer keys; a key is either a plain node
        name or a (node_name, port_id) tuple for multi-output (Result) nodes
    :return: dict {stats_layout key -> original name, or (original name, port_id)
        when the key carried a port id}
    """
    changed_names_map = {}
    for layer_name in stats_layout:
        # Multi-output nodes are keyed as (node_name, port_id); unpack so the
        # graph lookup always receives a plain node name.
        node_name = layer_name
        port_id = None
        if isinstance(layer_name, tuple):
            node_name, port_id = layer_name
        node = get_node_by_name(model, node_name)
        if node is not None and 'orig_node_name' in node:
            # Preserve the port id in the mapped value so callers can still
            # address the specific output port of the renamed node.
            name_change_to = node['orig_node_name'] if port_id is None else (node['orig_node_name'], port_id)
            changed_names_map[layer_name] = name_change_to
    return changed_names_map
14 changes: 11 additions & 3 deletions tools/pot/openvino/tools/pot/graph/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,12 +649,20 @@ def rename_fqs_in_the_end(self, graph: Graph):
def change_names(_, match):
    """Swap names between a trailing FakeQuantize and its producer node.

    The FQ takes over its input node's (original) name so statistics keep
    matching; the input node is renamed with a '/pre_fq_input' suffix and
    remembers its original name in 'orig_node_name'.

    :param _: unused graph argument supplied by apply_pattern
    :param match: pattern-match dict containing the 'fq' node
    """
    fq_node = match['fq']
    input_node = get_node_input(fq_node, 0)
    # Prefer the input's pre-rename name if it was already renamed earlier.
    new_fq_name = copy(input_node.name)
    if 'orig_node_name' in input_node:
        new_fq_name = copy(input_node['orig_node_name'])

    input_node_outputs = get_all_node_outputs(input_node)
    if all([op.type == 'FakeQuantize' for op in input_node_outputs]):
        # Several FQ consumers: disambiguate by the source output port index.
        new_fq_name += '.{}'.format(fq_node.in_port(0).get_source().idx)

    fq_node['orig_fq_name'] = copy(fq_node.name)
    fq_node.name = copy(new_fq_name)

    # Rename the producer only once — a node feeding multiple FQs must not
    # accumulate repeated '/pre_fq_input' suffixes.
    if 'orig_node_name' not in input_node:
        input_node['orig_node_name'] = copy(input_node.name)
        input_node.name = '{original_name}/pre_fq_input'.format(original_name=input_node.name)

pattern = get_fq_result_pattern()
apply_pattern(
Expand Down
3 changes: 3 additions & 0 deletions tools/pot/tests/data/models/lstm_example/lstm_example.json
Git LFS file not shown
3 changes: 3 additions & 0 deletions tools/pot/tests/data/models/lstm_example/lstm_example.onnx
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
56 changes: 26 additions & 30 deletions tools/pot/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
GNA_CONFIG_PATH = HARDWARE_CONFIG_PATH / 'gna.json'

# (model_name, framework, target_device) triples consumed by
# test_build_quantization_graph; target_device selects the hardware config
# ('GNA' models are quantized against the GNA config, the rest against CPU).
TEST_MODELS = [
    ('mobilenetv2_example', 'pytorch', 'ANY'),
    ('resnet_example', 'pytorch', 'ANY'),
    ('googlenet_example', 'pytorch', 'ANY'),
    ('mobilenetv2_ssd_example', 'pytorch', 'ANY'),
    ('densenet121_example', 'pytorch', 'ANY'),
    ('multiple_out_ports_net', 'tf', 'ANY'),  # multiple output ports in node case check
    ('lstm_example', 'pytorch', 'GNA'),
    ('multiple_outputs_net_example', 'dldt', 'GNA'),
]

CASCADE_MAP = Dict({
Expand All @@ -37,15 +38,16 @@


@pytest.mark.parametrize(
'model_name, model_framework', TEST_MODELS,
'model_name, model_framework, target_device', TEST_MODELS,
ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MODELS])
def test_build_quantization_graph(tmp_path, models, model_name, model_framework):
def test_build_quantization_graph(tmp_path, models, model_name, model_framework, target_device):
model = models.get(model_name, model_framework, tmp_path)
model = load_model(model.model_params)
model = load_model(model.model_params, target_device=target_device)

hardware_config = HardwareConfig.from_json(CPU_CONFIG_PATH.as_posix())
if model_framework == 'kaldi':
if target_device == 'GNA':
hardware_config = HardwareConfig.from_json(GNA_CONFIG_PATH.as_posix())
else:
hardware_config = HardwareConfig.from_json(CPU_CONFIG_PATH.as_posix())

quantization_model = GraphTransformer(hardware_config).insert_fake_quantize(model)

Expand Down Expand Up @@ -246,27 +248,21 @@ def test_multibranch_propagation_without_fq_moving():


# (model_name, framework, {variable id of an LSTM state -> expected names of
# the ReadValue/Assign sink nodes}) — reference data for test_lstm_ends.
MODELS_WITH_LSTM = [
    ('lstm_example', 'pytorch', {
        'LSTM_15/TensorIterator/22/variable_1':
            ['Assign_298'],
        'LSTM_15/TensorIterator/24/variable_2':
            ['Assign_305'],
        'LSTM_19/TensorIterator/22/variable_1':
            ['Assign_327'],
        'LSTM_19/TensorIterator/24/variable_2':
            ['Assign_334']
    })
]


# Module-scoped fixture that parametrizes dependent tests over every entry of
# MODELS_WITH_LSTM; ids combine model name and framework for readable test ids.
@pytest.fixture(scope='module', params=MODELS_WITH_LSTM,
                ids=['{}_{}'.format(m[0], m[1]) for m in MODELS_WITH_LSTM])
def _params(request):
    # Hands one (model_name, framework, lstm_ends_ref) tuple to the test.
    return request.param


def test_lstm_ends(_params, tmp_path, models):
model_name, model_framework, lstm_ends_ref = _params
def test_lstm_ends(tmp_path, models):
model_name, model_framework, lstm_ends_ref = MODELS_WITH_LSTM[0]
model = models.get(model_name, model_framework, tmp_path)
model = load_model(model.model_params)
read_values = get_nodes_by_type(model, ['ReadValue'])
Expand Down

0 comments on commit 8cb7824

Please sign in to comment.