Skip to content

Commit

Permalink
Add scalar case if "then" and "else" subgraph output are scalar. (#28444
Browse files Browse the repository at this point in the history
)

### Details:
- *if "then" and "else" subgraph output shape are scalar, the if node
output shape should be static and a scalar*
 - *...*

### Tickets:
 - *[28235](#28235
  • Loading branch information
mangguo321 authored Jan 16, 2025
1 parent 2b442b4 commit ed470e7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/core/src/op/if.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const
return ov::PartialShape::dynamic();
}
if (then_rank.get_length() != else_rank.get_length()) {
auto is_one_element = [](const ov::PartialShape& pshape) {
return pshape.size() == 0 || (pshape.is_static() && pshape[0].get_length() == 1);
};
// Union of scalar and 1D case
if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) {
return ov::PartialShape::dynamic(1);
return (is_one_element(then_pshape) && is_one_element(else_pshape)) ? ov::PartialShape{1}
: ov::PartialShape::dynamic(1);
} else {
return ov::PartialShape::dynamic();
}
}
ov::PartialShape new_dims;

ov::PartialShape new_dims;
// If ranges are equal each dimension of then_body output is union with each dimension of
// else_body
for (auto then_it = then_pshape.cbegin(), else_it = else_pshape.cbegin(); then_it != then_pshape.cend();
Expand Down
30 changes: 30 additions & 0 deletions src/core/tests/type_prop/if.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,36 @@ TEST(type_prop, if_scalar_and_1d_static_union) {
EXPECT_EQ(sh, out_shape);
}

TEST(type_prop, if_output_one_element) {
// That which we iterate over
auto X = make_shared<ov::op::v0::Parameter>(element::f32, Shape{});
auto Y = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1});
auto cond = make_shared<ov::op::v0::Parameter>(element::boolean, Shape{});

// Body parameters
auto Xt = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto Ye = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
// Body
auto then_op = std::make_shared<op::v1::Add>(Xt, Xt);
auto then_body_res = make_shared<ov::op::v0::Result>(then_op);
auto then_body = make_shared<ov::Model>(OutputVector{then_body_res}, ParameterVector{Xt});

auto else_op = std::make_shared<op::v1::Maximum>(Ye, Ye);
auto else_body_res = make_shared<ov::op::v0::Result>(else_op);
auto else_body = make_shared<ov::Model>(OutputVector{else_body_res}, ParameterVector{Ye});

auto if_op = make_shared<op::v8::If>(cond);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, nullptr);
if_op->set_input(Y, nullptr, Ye);
auto res = if_op->set_output(then_body_res, else_body_res);
auto result0 = make_shared<ov::op::v0::Result>(res);
PartialShape out_shape{1};
auto sh = result0->get_output_partial_shape(0);
EXPECT_EQ(sh, out_shape);
}

TEST(type_prop, if_element_type_dynamic) {
// That which we iterate over
auto X = make_shared<ov::op::v0::Parameter>(element::f16, Shape{32, 40, 10});
Expand Down

0 comments on commit ed470e7

Please sign in to comment.