Skip to content

Commit fc31f55

Browse files
committed
[GPU] Moved RMSFusion higher in pipeline and added output type fuse
1 parent 4cd1512 commit fc31f55

File tree

5 files changed

+55
-16
lines changed

5 files changed

+55
-16
lines changed

src/common/transformations/include/ov_ops/rms.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
4343
m_epsilon = epsilon;
4444
}
4545

46+
void set_rms_output_type(const element::Type& output_type) {
47+
m_output_type = output_type;
48+
}
49+
4650
private:
4751
double m_epsilon{0};
4852
ov::element::Type m_output_type;

src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "openvino/op/reduce_mean.hpp"
1414
#include "openvino/op/sqrt.hpp"
1515
#include "openvino/pass/manager.hpp"
16+
#include "openvino/pass/pattern/op/or.hpp"
1617
#include "openvino/pass/pattern/op/wrap_type.hpp"
1718
#include "ov_ops/rms.hpp"
1819
#include "transformations/utils/utils.hpp"
@@ -57,11 +58,15 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
5758
auto sqrt = wrap_type<ov::op::v0::Sqrt>({add_eps});
5859

5960
// 1/Sqrt(ReduceMean(x^2,axes)+eps)
60-
auto const_div = wrap_type<ov::op::v0::Constant>(constant_value(-1));
61-
auto div = wrap_type<ov::op::v1::Power>({sqrt, const_div});
61+
auto const_pow = wrap_type<ov::op::v0::Constant>(constant_value(-1));
62+
auto pow = wrap_type<ov::op::v1::Power>({sqrt, const_pow});
63+
64+
auto const_div = wrap_type<ov::op::v0::Constant>(constant_value(1));
65+
auto div = wrap_type<ov::op::v1::Divide>({const_div, sqrt});
66+
auto div_or_pow = std::make_shared<pattern::op::Or>(OutputVector{div, pow});
6267

6368
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps)
64-
auto mul1 = wrap_type<ov::op::v1::Multiply>({x, div});
69+
auto mul1 = wrap_type<ov::op::v1::Multiply>({x, div_or_pow});
6570

6671
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
6772
auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));

src/common/transformations/src/transformations/convert_precision.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "openvino/pass/constant_folding.hpp"
1313
#include "openvino/pass/manager.hpp"
1414
#include "openvino/reference/convert.hpp"
15+
#include "ov_ops/rms.hpp"
1516
#include "ov_ops/type_relaxed.hpp"
1617
#include "transformations/fp16_compression/align_mixed_fp32_fp16_types.hpp"
1718
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
@@ -59,6 +60,7 @@ bool fuse_type_to_maxpool(const std::shared_ptr<ov::Node>& node, const precision
5960
bool fuse_type_to_nonzero(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
6061
bool fuse_type_to_bucketize(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
6162
bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
63+
bool fuse_type_to_rms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
6264

6365
bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
6466

@@ -465,7 +467,8 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
465467
{ov::op::v0::PriorBox::get_type_info_static(), fuse_type_to_prior_box<ov::op::v0::PriorBox>},
466468
{ov::op::v8::PriorBox::get_type_info_static(), fuse_type_to_prior_box<ov::op::v8::PriorBox>},
467469
{ov::op::v0::PriorBoxClustered::get_type_info_static(), fuse_type_to_prior_box<ov::op::v0::PriorBoxClustered>},
468-
{ov::op::v15::SearchSorted::get_type_info_static(), fuse_type_to_search_sorted_v15}};
470+
{ov::op::v15::SearchSorted::get_type_info_static(), fuse_type_to_search_sorted_v15},
471+
{ov::op::internal::RMS::get_type_info_static(), fuse_type_to_rms}};
469472

470473
for (const auto& it : m_additional_type_to_fuse_map) {
471474
type_to_fuse[it.first] = it.second;
@@ -858,6 +861,20 @@ bool fuse_type_to_nms_rotated(const std::shared_ptr<ov::Node>& node, const preci
858861
return res;
859862
}
860863

864+
bool fuse_type_to_rms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
865+
auto it = precisions.find(node->get_output_element_type(0));
866+
if (it == precisions.end())
867+
return false;
868+
const auto& to = it->second;
869+
if (auto rms = ov::as_type_ptr<ov::op::internal::RMS>(node)) {
870+
if (to.is_real()) {
871+
rms->set_rms_output_type(to);
872+
return true;
873+
}
874+
}
875+
return false;
876+
}
877+
861878
namespace {
862879

863880
bool update_type(size_t idx,

src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ TEST_F(TransformationTestsF, RMSNormFusionTest2) {
7777
model = std::make_shared<ov::Model>(ov::NodeVector{comp}, ov::ParameterVector{input});
7878
manager.register_pass<RMSFusion>();
7979
}
80+
{
81+
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{1, 2, 6});
82+
83+
auto rms_const = ov::opset10::Constant::create(ov::element::f32,
84+
ov::Shape{6},
85+
{0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
86+
auto rms = std::make_shared<ov::op::internal::RMS>(input, rms_const, 1e-5f, ov::element::f16);
87+
88+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
89+
}
90+
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
91+
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
92+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
8093
}
8194

8295
TEST_F(TransformationTestsF, RMSNormFusionTest3) {
@@ -113,7 +126,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest4) {
113126
auto eps = ov::opset10::Constant::create(ov::element::f32, {}, {1e-5f});
114127
auto add_eps = std::make_shared<ov::opset10::Add>(mean, eps);
115128
auto sqrt = std::make_shared<ov::opset10::Sqrt>(add_eps);
116-
auto div_const = ov::opset10::Constant::create(ov::element::f32, {}, {1});
129+
auto div_const = ov::opset10::Constant::create(ov::element::f32, {}, {-1});
117130
auto div = std::make_shared<ov::opset10::Divide>(div_const, sqrt);
118131
auto mul1 = std::make_shared<ov::opset10::Multiply>(input, div);
119132
auto gamma = ov::opset10::Constant::create(ov::element::f32,

src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
#include "transformations/op_conversions/convert_broadcast3.hpp"
114114
#include "transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp"
115115
#include "transformations/op_conversions/convert_depth_to_space.hpp"
116+
#include "transformations/op_conversions/convert_divide.hpp"
116117
#include "transformations/op_conversions/convert_gather_0d.hpp"
117118
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
118119
#include "transformations/op_conversions/convert_gelu.hpp"
@@ -377,6 +378,16 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
377378
return !is_decompression_multiply(node, device_info.supports_immad);
378379
});
379380

381+
pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
382+
if (!root->get_input_partial_shape(0).is_static()) {
383+
return false;
384+
}
385+
const auto& gamma_shape = root->get_input_partial_shape(0).to_shape();
386+
const int32_t vec_size = 8;
387+
return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
388+
});
389+
manager.register_pass<ov::pass::RMSFusion>(false);
390+
380391
const bool keep_precision_sensitive_in_fp32_1 = true;
381392
const bool convert_input_output_precision = false;
382393
const bool store_original_precision_as_rt_attribute = true;
@@ -922,16 +933,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
922933

923934
manager.register_pass<ov::pass::ConvertGatherToGatherCompressed>();
924935
auto pass_config = manager.get_pass_config();
925-
pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
926-
if (!root->get_input_node_ptr(0)->get_input_partial_shape(0).is_static()) {
927-
return false;
928-
}
929-
const auto& gamma_shape = root->get_input_node_ptr(0)->get_input_partial_shape(0).to_shape();
930-
const int32_t vec_size = 8;
931-
return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
932-
});
933-
934-
manager.register_pass<ov::pass::RMSFusion>();
935936
manager.register_pass<ov::intel_gpu::KVCacheFusion>();
936937
manager.register_pass<ov::intel_gpu::FullyConnectedConvertFusion>();
937938
manager.register_pass<ov::intel_gpu::TransposeFusion>(device_info.supports_immad);
@@ -997,7 +998,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
997998
GPU_DEBUG_IF(cldnn::debug_configuration::get_instance()->verbose >= 1) {
998999
manager.register_pass<ov::intel_gpu::PrintModelStatistics>();
9991000
}
1000-
10011001
manager.run_passes(func);
10021002
}
10031003
}

0 commit comments

Comments
 (0)