diff --git a/src/common/transformations/src/transformations/op_conversions/mvn6_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/mvn6_decomposition.cpp index 9112f0602bc8ca..b1d491d5e3b7ae 100644 --- a/src/common/transformations/src/transformations/op_conversions/mvn6_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/mvn6_decomposition.cpp @@ -10,6 +10,7 @@ #include "openvino/core/rt_info.hpp" #include "openvino/op/add.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert_like.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/mvn.hpp" #include "openvino/op/power.hpp" @@ -48,7 +49,8 @@ ov::pass::MVN6Decomposition::MVN6Decomposition() { } else { // (x - ReduceMean(x, axes)) ^ 2 auto sqr_const = ov::op::v0::Constant::create(data.get_element_type(), ov::Shape{1}, {2}); - auto sqr = std::make_shared(mean_normalization, sqr_const); + auto sqr_const_conv_like = std::make_shared(sqr_const, mean_normalization); + auto sqr = std::make_shared(mean_normalization, sqr_const_conv_like); // ReduceMean((x - ReduceMean(x, axes)) ^ 2) auto mean2 = std::make_shared(sqr, axes, true); @@ -59,17 +61,20 @@ ov::pass::MVN6Decomposition::MVN6Decomposition() { std::shared_ptr eps_add; std::shared_ptr sqrt; std::shared_ptr div; + std::shared_ptr eps_node_conv_like; if (eps_mode == op::MVNEpsMode::INSIDE_SQRT) { // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) - eps_add = std::make_shared(mean2, eps_node); + eps_node_conv_like = std::make_shared(eps_node, mean2); + eps_add = std::make_shared(mean2, eps_node_conv_like); sqrt = std::make_shared(eps_add); // (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) div = std::make_shared(mean_normalization, sqrt); } else if (eps_mode == op::MVNEpsMode::OUTSIDE_SQRT) { // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps sqrt = std::make_shared(mean2); - eps_add = std::make_shared(sqrt, eps_node); + eps_node_conv_like = std::make_shared(eps_node, mean2); + eps_add = std::make_shared(sqrt, eps_node_conv_like); // (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) div = std::make_shared(mean_normalization, eps_add); } else { @@ -77,8 +82,10 @@ ov::pass::MVN6Decomposition::MVN6Decomposition() { } div->set_friendly_name(mvn_node->get_friendly_name()); - ov::copy_runtime_info(mvn_node, {mean, mean_normalization, sqr, mean2, eps_node, eps_add, sqrt, div}); + ov::copy_runtime_info(mvn_node, {mean, mean_normalization, sqr, mean2, eps_node, eps_add, sqrt, div, sqr_const_conv_like, eps_node_conv_like}); ov::replace_node(mvn_node, div); + + std::cout << "MVN6Decomposition: Power node: " << sqr->get_name() << std::endl; } return true; }; diff --git a/src/common/transformations/tests/op_conversions/mvn6_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/mvn6_decomposition_test.cpp index e362b08b5780b2..7250fb75883d86 100644 --- a/src/common/transformations/tests/op_conversions/mvn6_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/mvn6_decomposition_test.cpp @@ -56,13 +56,15 @@ TEST_F(TransformationTestsF, MVN6Decomposition_Inside_Sqrt) { auto mean = std::make_shared(input0, axes_const, true); auto mean_normalization = std::make_shared(input0, mean); - auto sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2}); - auto sqr = std::make_shared(mean_normalization, sqr_const); + std::shared_ptr sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2}); + auto sqr_const_conv_like = std::make_shared(sqr_const, mean_normalization); + auto sqr = std::make_shared(mean_normalization, sqr_const_conv_like); auto mean2 = std::make_shared(sqr, axes_const, true); - auto eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5}); + std::shared_ptr eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5}); + auto eps_node_conv_like = std::make_shared(eps_node, mean2); - auto eps_add = std::make_shared(mean2, eps_node); + auto eps_add = std::make_shared(mean2, eps_node_conv_like); auto sqrt = std::make_shared(eps_add); auto div = std::make_shared(mean_normalization, sqrt); @@ -87,14 +89,16 @@ TEST_F(TransformationTestsF, MVN6Decomposition_Outside_Sqrt) { auto mean = std::make_shared(input0, axes_const, true); auto mean_normalization = std::make_shared(input0, mean); - auto sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2}); - auto sqr = std::make_shared(mean_normalization, sqr_const); + std::shared_ptr sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2}); + auto sqr_const_conv_like = std::make_shared(sqr_const, mean_normalization); + auto sqr = std::make_shared(mean_normalization, sqr_const_conv_like); auto mean2 = std::make_shared(sqr, axes_const, true); - auto eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5}); + std::shared_ptr eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5}); + auto eps_node_conv_like = std::make_shared(eps_node, mean2); auto sqrt = std::make_shared(mean2); - auto eps_add = std::make_shared(sqrt, eps_node); + auto eps_add = std::make_shared(sqrt, eps_node_conv_like); auto div = std::make_shared(mean_normalization, eps_add); model_ref = std::make_shared(NodeVector{div}, ParameterVector{input0});