Skip to content

Commit

Permalink
Fix Or pattern behavior (#27721)
Browse files Browse the repository at this point in the history
### Details:
- Fixed Or pattern
- Added new unit tests for Or and Optional patterns

Or pattern have to point to the node from the selected branch.
It also affects Optional pattern behavior as it uses Or pattern inside.



### Tickets:
 - CVS-157939
  • Loading branch information
itikhono authored Nov 26, 2024
1 parent 611796c commit 90f64e0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/core/src/pattern/op/or.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ bool ov::pass::pattern::op::Or::match_value(Matcher* matcher,
auto saved = matcher->start_match();
if (matcher->match_value(input_value, graph_value)) {
auto& pattern_map = matcher->get_pattern_value_map();
pattern_map[input_value.get_node_shared_ptr()] = graph_value;
pattern_map[shared_from_this()] = graph_value;
return saved.finish(true);
}
}
Expand Down
57 changes: 57 additions & 0 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,63 @@ TEST(pattern, optional_match_node_with_single_input) {
}
}

TEST(pattern, or_pattern_points_the_selected_branch) {
using namespace ov::op;
using namespace ov::pass::pattern;

// Graph:
auto model_param = make_shared<v0::Parameter>();
auto model_sigmoid = make_shared<v0::Sigmoid>(model_param);

// Pattern:
auto option_1 = wrap_type<v0::Parameter>();
auto option_2 = wrap_type<v0::Sigmoid>();
auto or_pattern = std::make_shared<pattern::op::Or>(ov::OutputVector{option_1, option_2});

// Test:
TestMatcher matcher;
EXPECT_TRUE(matcher.match(or_pattern, model_sigmoid));

auto pattern_val_mp = matcher.get_pattern_value_map();
EXPECT_EQ(pattern_val_mp.count(or_pattern), 1);

// we expect that Or pattern points to the first node of the selected branch
EXPECT_NE(ov::as_type<v0::Sigmoid>(pattern_val_mp.at(or_pattern).get_node()), nullptr);
}

TEST(pattern, multiple_optionals_in_row) {
using namespace ov::op;
using namespace ov::pass::pattern;

// Graph:
Shape shape{1, 2, 3};
auto model_input_0 = make_shared<v0::Parameter>(element::f32, shape);
auto model_sigmoid = make_shared<v0::Sigmoid>(model_input_0);

// Pattern:
auto in = wrap_type<v0::Parameter>();
auto pattern_convert = optional<v0::Convert>(in);
auto pattern_relu = optional<v0::Relu>(pattern_convert);
auto pattern_sigmoid = wrap_type<v0::Sigmoid>({pattern_relu});

// Test:
TestMatcher matcher;
EXPECT_TRUE(matcher.match(pattern_sigmoid, model_sigmoid));

auto pattern_val_mp = matcher.get_pattern_value_map();

EXPECT_EQ(pattern_val_mp.count(in), 1);
EXPECT_NE(ov::as_type<v0::Parameter>(pattern_val_mp.at(in).get_node()), nullptr);

// as Convert and Relu ops are not present in the graph, so we expect the optional nodes
// do not point to the graph nodes, in other words, the optional nodes are not in the pattern map.
EXPECT_EQ(pattern_val_mp.count(pattern_convert), 0);
EXPECT_EQ(pattern_val_mp.count(pattern_relu), 0);

EXPECT_EQ(pattern_val_mp.count(pattern_sigmoid), 1);
EXPECT_NE(ov::as_type<v0::Sigmoid>(pattern_val_mp.at(pattern_sigmoid).get_node()), nullptr);
}

// match optional nodes with multi input where order in not important
TEST(pattern, optional_match_cumulative_node_with_multi_input) {
Shape shape{1, 2, 3};
Expand Down

0 comments on commit 90f64e0

Please sign in to comment.