Skip to content

Commit

Permalink
[TRANSFORMATIONS] Add ConvertLike to MVN decomposition
Browse files Browse the repository at this point in the history
It may happen that a MultiSubGraph body may contain an MVN node that is
usually decomposed into simpler operations.

If during ConvertPrecision transformation the inputs of the MultiSubGraph body are
converted to a different precision than the original model and
MultiSubGraph had (as it's done in the GPU pipeline f32 -> f16), the
desired precision for all the nodes appeared during decomposition will be
propagated through. However this will not apply to Constants resulting into
the error of mixed-precision for binary element-wise operations.

Fix it by adding a ConvertLike for the Constant that will be folded
and posses the correct desired precision.

- Ticket:
* CVS-158631

Signed-off-by: Andrii Staikov <[email protected]>
  • Loading branch information
CuriousPanCake committed Dec 12, 2024
1 parent 45bf77b commit 4d4d400
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/power.hpp"
Expand Down Expand Up @@ -48,7 +49,8 @@ ov::pass::MVN6Decomposition::MVN6Decomposition() {
} else {
// (x - ReduceMean(x, axes)) ^ 2
auto sqr_const = ov::op::v0::Constant::create(data.get_element_type(), ov::Shape{1}, {2});
auto sqr = std::make_shared<ov::op::v1::Power>(mean_normalization, sqr_const);
auto sqr_const_conv_like = std::make_shared<ov::op::v1::ConvertLike>(sqr_const, mean_normalization);
auto sqr = std::make_shared<ov::op::v1::Power>(mean_normalization, sqr_const_conv_like);
// ReduceMean((x - ReduceMean(x, axes)) ^ 2)
auto mean2 = std::make_shared<ov::op::v1::ReduceMean>(sqr, axes, true);

Expand All @@ -59,26 +61,31 @@ ov::pass::MVN6Decomposition::MVN6Decomposition() {
std::shared_ptr<ov::op::v1::Add> eps_add;
std::shared_ptr<ov::op::v0::Sqrt> sqrt;
std::shared_ptr<ov::op::v1::Divide> div;
std::shared_ptr<ov::op::v1::ConvertLike> eps_node_conv_like;

if (eps_mode == op::MVNEpsMode::INSIDE_SQRT) {
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
eps_add = std::make_shared<ov::op::v1::Add>(mean2, eps_node);
eps_node_conv_like = std::make_shared<ov::op::v1::ConvertLike>(eps_node, mean2);
eps_add = std::make_shared<ov::op::v1::Add>(mean2, eps_node_conv_like);
sqrt = std::make_shared<ov::op::v0::Sqrt>(eps_add);
// (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
div = std::make_shared<ov::op::v1::Divide>(mean_normalization, sqrt);
} else if (eps_mode == op::MVNEpsMode::OUTSIDE_SQRT) {
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps
sqrt = std::make_shared<ov::op::v0::Sqrt>(mean2);
eps_add = std::make_shared<ov::op::v1::Add>(sqrt, eps_node);
eps_node_conv_like = std::make_shared<ov::op::v1::ConvertLike>(eps_node, mean2);
eps_add = std::make_shared<ov::op::v1::Add>(sqrt, eps_node_conv_like);
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
div = std::make_shared<ov::op::v1::Divide>(mean_normalization, eps_add);
} else {
return false;
}

div->set_friendly_name(mvn_node->get_friendly_name());
ov::copy_runtime_info(mvn_node, {mean, mean_normalization, sqr, mean2, eps_node, eps_add, sqrt, div});
ov::copy_runtime_info(mvn_node, {mean, mean_normalization, sqr, mean2, eps_node, eps_add, sqrt, div, sqr_const_conv_like, eps_node_conv_like});
ov::replace_node(mvn_node, div);

std::cout << "MVN6Decomposition: Power node: " << sqr->get_name() << std::endl;
}
return true;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ TEST_F(TransformationTestsF, MVN6Decomposition_Inside_Sqrt) {
auto mean = std::make_shared<opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<opset6::Subtract>(input0, mean);

auto sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2});
auto sqr = std::make_shared<opset6::Power>(mean_normalization, sqr_const);
std::shared_ptr<ov::Node> sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2});
auto sqr_const_conv_like = std::make_shared<opset6::ConvertLike>(sqr_const, mean_normalization);
auto sqr = std::make_shared<opset6::Power>(mean_normalization, sqr_const_conv_like);
auto mean2 = std::make_shared<opset6::ReduceMean>(sqr, axes_const, true);

auto eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5});
std::shared_ptr<ov::Node> eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5});
auto eps_node_conv_like = std::make_shared<opset6::ConvertLike>(eps_node, mean2);

auto eps_add = std::make_shared<opset6::Add>(mean2, eps_node);
auto eps_add = std::make_shared<opset6::Add>(mean2, eps_node_conv_like);
auto sqrt = std::make_shared<opset6::Sqrt>(eps_add);
auto div = std::make_shared<opset6::Divide>(mean_normalization, sqrt);

Expand All @@ -87,14 +89,16 @@ TEST_F(TransformationTestsF, MVN6Decomposition_Outside_Sqrt) {
auto mean = std::make_shared<opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<opset6::Subtract>(input0, mean);

auto sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2});
auto sqr = std::make_shared<opset6::Power>(mean_normalization, sqr_const);
std::shared_ptr<ov::Node> sqr_const = opset6::Constant::create(element::f32, Shape{1}, {2});
auto sqr_const_conv_like = std::make_shared<opset6::ConvertLike>(sqr_const, mean_normalization);
auto sqr = std::make_shared<opset6::Power>(mean_normalization, sqr_const_conv_like);
auto mean2 = std::make_shared<opset6::ReduceMean>(sqr, axes_const, true);

auto eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5});
std::shared_ptr<ov::Node> eps_node = opset6::Constant::create(element::f32, Shape{1}, {1e-5});
auto eps_node_conv_like = std::make_shared<ov::op::v1::ConvertLike>(eps_node, mean2);

auto sqrt = std::make_shared<opset6::Sqrt>(mean2);
auto eps_add = std::make_shared<opset6::Add>(sqrt, eps_node);
auto eps_add = std::make_shared<opset6::Add>(sqrt, eps_node_conv_like);
auto div = std::make_shared<opset6::Divide>(mean_normalization, eps_add);

model_ref = std::make_shared<ov::Model>(NodeVector{div}, ParameterVector{input0});
Expand Down

0 comments on commit 4d4d400

Please sign in to comment.