Skip to content

Commit acc4d36

Browse files
committed
update mha_tokenization
1 parent 4422a81 commit acc4d36

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

src/common/snippets/src/pass/mha_tokenization.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
458458
/* ================================ */
459459

460460
/* ====== Subgraph creation ======= */
461-
/*
461+
462462
ov::OutputVector body_inputs, subgraph_inputs;
463463
ov::ParameterVector body_parameters;
464464
ov::ResultVector body_results;
@@ -471,7 +471,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
471471
const auto constant = ov::as_type_ptr<ov::op::v0::Constant>(parent);
472472
if (constant && (ov::shape_size(input.get_shape()) == 1 ||
473473
ov::is_type<ov::op::v0::FakeQuantize>(node) ||
474-
constant_input_should_be_inside_body(node))) {
474+
ov::snippets::utils::constant_input_should_be_inside_body(node))) {
475475
// If Constant has one consumer - target node, we add Constant to body_inputs
476476
// If Constant has several consumers, we should check that all these consumers are inside Subgraph body
477477
// and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we should
@@ -525,7 +525,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
525525
OPENVINO_THROW("body results and node results size mismatch during subgraph collapse");
526526
}
527527

528-
auto body = op::create_body(last_node->get_friendly_name(), body_results, body_parameters);
528+
auto body = ov::snippets::utils::create_body(last_node->get_friendly_name(), body_results, body_parameters);
529529
auto subgraph = std::make_shared<op::Subgraph>(subgraph_inputs, body);
530530
// Copy runtime info from last node to subgraph - to copy topological order
531531
copy_runtime_info(last_node, subgraph);
@@ -536,7 +536,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
536536
target_input.replace_source_output(subgraph->output(i));
537537
}
538538
}
539-
op::update_out_tensor_name(subgraph);
539+
ov::snippets::utils::update_out_tensor_name(subgraph);
540540

541541
subgraph->validate_and_infer_types();
542542

@@ -545,9 +545,6 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
545545
act_body->get_parameters()[i]->set_friendly_name(body_parameters[i]->get_friendly_name());
546546
}
547547
subgraph->get_rt_info()["originalLayersNames"] = fused_names;
548-
*/
549-
550-
auto subgraph = utils::wrap_nodes_as_subgraph(ordered_ops);
551548
subgraph->set_virtual_port_count(hidden_virtual_ports_count);
552549

553550
// mark the Subgraph as Completed to not allow Snippets to include any nodes into the MHA Subgraph in common Tokenization

0 commit comments

Comments
 (0)