From aad89fb9981984853a35ef5f78c2a242a98c9e05 Mon Sep 17 00:00:00 2001
From: Ivan Tikhonov
Date: Wed, 6 Mar 2024 15:40:34 +0400
Subject: [PATCH] Support dyn shapes in BatchNormDecomposition transformation
 (#23290)

### Details:
 - Support dynamic shapes in the BatchNormDecomposition transformation: the matcher no longer
   requires static shapes on gamma/beta/mean/variance, only a static rank on the data input.

### Tickets:
 - *CVS-133609*
---
 .../op_conversions/batch_norm_decomposition.cpp      | 41 ++++++----
 .../op_conversions/batch_norm_decomposition_test.cpp | 75 +++++++++++++++++++
 2 files changed, 102 insertions(+), 14 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
index 2efac27b933dd2..31e53513b9c21e 100644
--- a/src/common/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
@@ -28,16 +28,16 @@ using namespace ov;
 
 ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
     MATCHER_SCOPE(BatchNormDecomposition);
-    auto bn_1 = pattern::wrap_type<ov::op::v0::BatchNormInference>({pattern::any_input(pattern::has_static_shape()),
-                                                                    pattern::any_input(pattern::has_static_shape()),
+    auto bn_1 = pattern::wrap_type<ov::op::v0::BatchNormInference>({pattern::any_input(),
+                                                                    pattern::any_input(),
                                                                     pattern::any_input(pattern::has_static_rank()),
-                                                                    pattern::any_input(pattern::has_static_shape()),
-                                                                    pattern::any_input(pattern::has_static_shape())});
+                                                                    pattern::any_input(),
+                                                                    pattern::any_input()});
     auto bn_5 = pattern::wrap_type<ov::op::v5::BatchNormInference>({pattern::any_input(pattern::has_static_rank()),
-                                                                    pattern::any_input(pattern::has_static_shape()),
-                                                                    pattern::any_input(pattern::has_static_shape()),
-                                                                    pattern::any_input(pattern::has_static_shape()),
-                                                                    pattern::any_input(pattern::has_static_shape())});
+                                                                    pattern::any_input(),
+                                                                    pattern::any_input(),
+                                                                    pattern::any_input(),
+                                                                    pattern::any_input()});
     auto bn = std::make_shared<pattern::op::Or>(OutputVector{bn_1, bn_5});
 
     matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) {
@@ -83,9 +83,8 @@ ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
             std::make_shared<ov::op::v1::Reshape>(gamma_div_scale, new_shape, true);
         std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(m_beta, new_shape, true);
         std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(m_mean, new_shape, true);
-        std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(
-            mean_aligned,
-            ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1}));
+        auto mul_const = ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1});
+        std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(mean_aligned, mul_const);
 
         if (auto constant = ov::util::get_constant_from_source(beta_aligned))
             beta_aligned = constant;
@@ -103,9 +102,23 @@
 
         add->set_friendly_name(m_bn->get_friendly_name());
 
-        copy_runtime_info(
-            m_bn,
-            {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned, beta_aligned, input_sub_mean, mul, add});
+        copy_runtime_info(m_bn,
+                          {scale_add,
+                           scale,
+                           gamma_div_scale,
+                           gamma_div_scale_aligned,
+                           beta_aligned,
+                           input_sub_mean,
+                           mul,
+                           add,
+                           mean_negative,
+                           mean_aligned,
+                           new_shape,
+                           tail_shape,
+                           tail_shape_rank,
+                           one,
+                           mul_const,
+                           C_dim});
 
         replace_node(m_bn, add);
 
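For context (not part of the patch): the change above relaxes the matcher and extends runtime-info
copying to the shape-computation nodes, so the existing decomposition
`(x - mean) * gamma / sqrt(variance + eps) + beta` is now built even when the BatchNorm inputs carry
dynamic dimensions. A minimal sketch of that scenario follows; it assumes the public
`ov::pass::Manager` API and the dev-API header path of the file touched above, and the shapes and
epsilon values are illustrative only.

```cpp
// Sketch: apply BatchNormDecomposition to a model whose shapes are dynamic.
// Assumes OpenVINO's transformations dev API is linked (header path taken from
// the file modified in this patch); shapes and epsilon are example values.
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset5.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/batch_norm_decomposition.hpp"

int main() {
    using namespace ov;
    // Data input and per-channel statistics with dynamic dimensions.
    auto input = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    auto gamma = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1});
    auto beta = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1});
    auto mean = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1});
    auto var = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1});
    auto bn = std::make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
    auto model = std::make_shared<Model>(NodeVector{bn}, ParameterVector{input, gamma, beta, mean, var});

    // Before this patch the matcher required static shapes on gamma/beta/mean/variance,
    // so the pass would not fire here; with the patch the BatchNormInference is decomposed.
    pass::Manager manager;
    manager.register_pass<pass::BatchNormDecomposition>();
    manager.run_passes(model);
    return 0;
}
```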
diff --git a/src/common/transformations/tests/op_conversions/batch_norm_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/batch_norm_decomposition_test.cpp
index 595975082e6f60..51fc081353e158 100644
--- a/src/common/transformations/tests/op_conversions/batch_norm_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/batch_norm_decomposition_test.cpp
@@ -18,6 +18,45 @@
 using namespace ov;
 using namespace testing;
 
+std::shared_ptr<ov::Model> get_ref_model_with_dyn_shapes(ov::element::Type precision, const PartialShape& input_shape) {
+    auto input = std::make_shared<ov::op::v0::Parameter>(precision, input_shape);
+    auto gamma = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+    auto beta = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+    auto mean = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+    auto var = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+    // scale_add = variance + eps
+    auto scale_add = std::make_shared<ov::op::v1::Add>(var, ov::op::v0::Constant::create(precision, Shape{}, {0.001}));
+    // scale = sqrt(variance + eps)
+    auto scale = std::make_shared<ov::op::v0::Sqrt>(scale_add);
+    // Divide `gamma` by `sqrt(variance + eps)`
+    auto gamma_div_scale = std::make_shared<ov::op::v1::Divide>(gamma, scale);
+
+    int64_t dims_to_add = input->get_partial_shape().rank().get_length() - 2;
+    const auto one = ov::op::v0::Constant::create(element::i64, Shape{1}, {1});
+    const auto tail_shape_rank = ov::op::v0::Constant::create(element::i64, Shape{1}, {dims_to_add});
+    const auto tail_shape = std::make_shared<ov::op::v3::Broadcast>(one, tail_shape_rank);
+    const auto C_dim = std::make_shared<ov::op::v3::ShapeOf>(gamma);
+    // create new shape [1, C, 1, 1, ...]
+    const auto new_shape = std::make_shared<ov::op::v0::Concat>(OutputVector{one, C_dim, tail_shape}, 0);
+
+    std::shared_ptr<Node> gamma_div_scale_aligned =
+        std::make_shared<ov::op::v1::Reshape>(gamma_div_scale, new_shape, true);
+    std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(beta, new_shape, true);
+    std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(mean, new_shape, true);
+    std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(
+        mean_aligned,
+        ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1}));
+
+    // input_sub_mean = input + mean * -1
+    auto input_sub_mean = std::make_shared<ov::op::v1::Add>(input, mean_negative);
+    // Multiply `input - mean` and `gamma / sqrt(variance + eps)`
+    auto mul = std::make_shared<ov::op::v1::Multiply>(input_sub_mean, gamma_div_scale_aligned);
+    // Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
+    auto add = std::make_shared<ov::op::v1::Add>(mul, beta_aligned);
+
+    return std::make_shared<ov::Model>(NodeVector{add}, ParameterVector{input, gamma, beta, mean, var});
+}
+
 TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset1) {
     const PartialShape input_shape{-1, -1, -1, -1};
     const auto precision = element::f32;
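The reference model above mirrors the decomposition graph node for node; element-wise it is just the
identity y = (x - mean) * gamma / sqrt(variance + eps) + beta, computed as
add(mul(add(x, -mean), gamma / sqrt(var + eps)), beta). A standalone scalar sanity check of that
identity (not part of the patch, plain C++ with arbitrary example values, ignoring the reshape and
broadcasting of the per-channel tensors) is:

```cpp
// Scalar view of the decomposition encoded by the reference model:
//   y = (x - mean) * gamma / sqrt(var + eps) + beta
// The reshape of gamma/beta/mean to [1, C, 1, 1, ...] is omitted; values are
// arbitrary examples, eps matches the 0.001 used in the tests.
#include <cassert>
#include <cmath>

int main() {
    const double x = 2.0, gamma = 1.5, beta = 0.25, mean = 0.5, var = 4.0, eps = 0.001;

    // Direct batch-norm formula.
    const double reference = gamma * (x - mean) / std::sqrt(var + eps) + beta;

    // Decomposed form: (x + mean * -1) * (gamma / sqrt(var + eps)) + beta.
    const double gamma_div_scale = gamma / std::sqrt(var + eps);
    const double decomposed = (x + mean * -1.0) * gamma_div_scale + beta;

    assert(std::abs(reference - decomposed) < 1e-12);
    return 0;
}
```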
@@ -74,6 +113,42 @@ TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset5) {
     }
 }
 
+TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset1) {
+    const PartialShape input_shape{-1, -1, -1, -1};
+    const auto precision = element::f32;
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(precision, input_shape);
+        auto gamma = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto beta = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto mean = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto var = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto batch_norm = std::make_shared<ov::op::v0::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
+
+        model = std::make_shared<ov::Model>(NodeVector{batch_norm}, ParameterVector{input, gamma, beta, mean, var});
+        manager.register_pass<ov::pass::BatchNormDecomposition>();
+        comparator.enable(FunctionsComparator::CONST_VALUES);
+    }
+    { model_ref = get_ref_model_with_dyn_shapes(precision, input_shape); }
+}
+
+TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset5) {
+    const PartialShape input_shape{-1, -1, -1, -1};
+    const auto precision = element::f32;
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(precision, input_shape);
+        auto gamma = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto beta = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto mean = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto var = std::make_shared<ov::op::v0::Parameter>(precision, PartialShape{-1});
+        auto batch_norm = std::make_shared<ov::op::v5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
+
+        model = std::make_shared<ov::Model>(NodeVector{batch_norm}, ParameterVector{input, gamma, beta, mean, var});
+        manager.register_pass<ov::pass::BatchNormDecomposition>();
+        comparator.enable(FunctionsComparator::CONST_VALUES);
+    }
+    { model_ref = get_ref_model_with_dyn_shapes(precision, input_shape); }
+}
+
 TEST_F(TransformationTestsF, BatchNormDecompositionDynamicRank) {
     {
         auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());