From 742cf254bb303402783b38f2e84b902fa36dfd75 Mon Sep 17 00:00:00 2001 From: Mang Guo Date: Thu, 16 Jan 2025 13:59:18 +0800 Subject: [PATCH] Add scalar case if "then" and "else" subgraph output are scalar. (#28444) ### Details: - *if "then" and "else" subgraph output shape are scalar, the if node output shape should be static and a scalar* - *...* ### Tickets: - *[28235](https://github.com/openvinotoolkit/openvino/issues/28235)* --- src/core/src/op/if.cpp | 8 ++++++-- src/core/tests/type_prop/if.cpp | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index 5b43210774c311..7e3a0515e77ad5 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -33,15 +33,19 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const return ov::PartialShape::dynamic(); } if (then_rank.get_length() != else_rank.get_length()) { + auto is_one_element = [](const ov::PartialShape& pshape) { + return pshape.size() == 0 || (pshape.is_static() && pshape[0].get_length() == 1); + }; // Union of scalar and 1D case if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) { - return ov::PartialShape::dynamic(1); + return (is_one_element(then_pshape) && is_one_element(else_pshape)) ? ov::PartialShape{1} + : ov::PartialShape::dynamic(1); } else { return ov::PartialShape::dynamic(); } } - ov::PartialShape new_dims; + ov::PartialShape new_dims; // If ranges are equal each dimension of then_body output is union with each dimension of // else_body for (auto then_it = then_pshape.cbegin(), else_it = else_pshape.cbegin(); then_it != then_pshape.cend(); diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index 6a66e13a66cb39..855971bd3358e8 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -340,6 +340,36 @@ TEST(type_prop, if_scalar_and_1d_static_union) { EXPECT_EQ(sh, out_shape); } +TEST(type_prop, if_output_one_element) { + // That which we iterate over + auto X = make_shared(element::f32, Shape{}); + auto Y = make_shared(element::f32, Shape{1}); + auto cond = make_shared(element::boolean, Shape{}); + + // Body parameters + auto Xt = make_shared(element::f32, PartialShape::dynamic()); + auto Ye = make_shared(element::f32, PartialShape::dynamic()); + // Body + auto then_op = std::make_shared(Xt, Xt); + auto then_body_res = make_shared(then_op); + auto then_body = make_shared(OutputVector{then_body_res}, ParameterVector{Xt}); + + auto else_op = std::make_shared(Ye, Ye); + auto else_body_res = make_shared(else_op); + auto else_body = make_shared(OutputVector{else_body_res}, ParameterVector{Ye}); + + auto if_op = make_shared(cond); + if_op->set_then_body(then_body); + if_op->set_else_body(else_body); + if_op->set_input(X, Xt, nullptr); + if_op->set_input(Y, nullptr, Ye); + auto res = if_op->set_output(then_body_res, else_body_res); + auto result0 = make_shared(res); + PartialShape out_shape{1}; + auto sh = result0->get_output_partial_shape(0); + EXPECT_EQ(sh, out_shape); +} + TEST(type_prop, if_element_type_dynamic) { // That which we iterate over auto X = make_shared(element::f16, Shape{32, 40, 10});