From 5112c2f561aea16890e4491954372c6d95242c62 Mon Sep 17 00:00:00 2001 From: mangguo321 Date: Wed, 15 Jan 2025 02:11:28 +0100 Subject: [PATCH 1/4] Add scalar case if "then" and "else" subgraph outpt are scalar. --- src/core/src/op/if.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index d8ac09096e714f..58e5ebe2cd8b41 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -33,15 +33,21 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const return ov::PartialShape::dynamic(); } if (then_rank.get_length() != else_rank.get_length()) { + auto isScalar = [](const ov::PartialShape& pshape) { + return ((pshape.rank() == 0) || (pshape.rank() == 1 && pshape.get_shape()[0] == 1)); + }; // Union of scalar and 1D case if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) { + if (isScalar(then_pshape) && isScalar(else_pshape)) { + return ov::PartialShape{1}; + } return ov::PartialShape::dynamic(1); } else { return ov::PartialShape::dynamic(); } } - ov::PartialShape new_dims; + ov::PartialShape new_dims; // If ranges are equal each dimension of then_body output is union with each dimension of // else_body for (auto then_it = then_pshape.cbegin(), else_it = else_pshape.cbegin(); then_it != then_pshape.cend(); From ccd38ba172531685993d037b30ece7367b318764 Mon Sep 17 00:00:00 2001 From: mangguo321 Date: Wed, 15 Jan 2025 15:29:41 +0100 Subject: [PATCH 2/4] Fix review comments --- src/core/src/op/if.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index 58e5ebe2cd8b41..17a1a991e1f803 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -33,15 +33,13 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const return ov::PartialShape::dynamic(); } if (then_rank.get_length() != else_rank.get_length()) { - auto isScalar = [](const ov::PartialShape& pshape) { - return ((pshape.rank() == 0) || (pshape.rank() == 1 && pshape.get_shape()[0] == 1)); + auto is_one_element = [](const ov::PartialShape& pshape) { + return pshape.size() == 0 || pshape[0].get_max_length() == 1; }; // Union of scalar and 1D case if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) { - if (isScalar(then_pshape) && isScalar(else_pshape)) { - return ov::PartialShape{1}; - } - return ov::PartialShape::dynamic(1); + return (is_one_element(then_pshape) && is_one_element(else_pshape)) ? + ov::PartialShape{1} : ov::PartialShape::dynamic(1); } else { return ov::PartialShape::dynamic(); } From 8ae608d3fe0b7b51d1ec1d46631e71d4ca5d859e Mon Sep 17 00:00:00 2001 From: mangguo321 Date: Wed, 15 Jan 2025 15:42:25 +0100 Subject: [PATCH 3/4] Add test case --- src/core/tests/type_prop/if.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index d893b34472ad2b..12b27f0ad00965 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -340,6 +340,36 @@ TEST(type_prop, if_scalar_and_1d_static_union) { EXPECT_EQ(sh, out_shape); } +TEST(type_prop, if_output_one_element) { + // That which we iterate over + auto X = make_shared(element::f32, Shape{}); + auto Y = make_shared(element::f32, Shape{1}); + auto cond = make_shared(element::boolean, Shape{}); + + // Body parameters + auto Xt = make_shared(element::f32, PartialShape::dynamic()); + auto Ye = make_shared(element::f32, PartialShape::dynamic()); + // Body + auto then_op = std::make_shared(Xt, Xt); + auto then_body_res = make_shared(then_op); + auto then_body = make_shared(OutputVector{then_body_res}, ParameterVector{Xt}); + + auto else_op = std::make_shared(Ye, Ye); + auto else_body_res = make_shared(else_op); + auto else_body = make_shared(OutputVector{else_body_res}, ParameterVector{Ye}); + + auto if_op = make_shared(cond); + if_op->set_then_body(then_body); + if_op->set_else_body(else_body); + if_op->set_input(X, Xt, nullptr); + if_op->set_input(Y, nullptr, Ye); + auto res = if_op->set_output(then_body_res, else_body_res); + auto result0 = make_shared(res); + PartialShape out_shape{PartialShape{1}}; + auto sh = result0->get_output_partial_shape(0); + EXPECT_EQ(sh, out_shape); +} + TEST(type_prop, if_element_type_dynamic) { // That which we iterate over auto X = make_shared(element::f16, Shape{32, 40, 10}); From afcc6fee8c6a5962a59a665d3dc20efda5696a8b Mon Sep 17 00:00:00 2001 From: mangguo321 Date: Thu, 16 Jan 2025 03:32:08 +0100 Subject: [PATCH 4/4] Apply review comments --- src/core/src/op/if.cpp | 6 +++--- src/core/tests/type_prop/if.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index 17a1a991e1f803..ba90cb4431e191 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -34,12 +34,12 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const } if (then_rank.get_length() != else_rank.get_length()) { auto is_one_element = [](const ov::PartialShape& pshape) { - return pshape.size() == 0 || pshape[0].get_max_length() == 1; + return pshape.size() == 0 || (pshape.is_static() && pshape[0].get_length() == 1); }; // Union of scalar and 1D case if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) { - return (is_one_element(then_pshape) && is_one_element(else_pshape)) ? - ov::PartialShape{1} : ov::PartialShape::dynamic(1); + return (is_one_element(then_pshape) && is_one_element(else_pshape)) ? ov::PartialShape{1} + : ov::PartialShape::dynamic(1); } else { return ov::PartialShape::dynamic(); } diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index 12b27f0ad00965..0b97bd05cfa5bc 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -365,7 +365,7 @@ TEST(type_prop, if_output_one_element) { if_op->set_input(Y, nullptr, Ye); auto res = if_op->set_output(then_body_res, else_body_res); auto result0 = make_shared(res); - PartialShape out_shape{PartialShape{1}}; + PartialShape out_shape{1}; auto sh = result0->get_output_partial_shape(0); EXPECT_EQ(sh, out_shape); }