Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Support Case operation #28027

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/frontends/tensorflow/src/op/case.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "openvino/frontend/tensorflow/node_context.hpp"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add copyright as for other src files.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "openvino/frontend/tensorflow/translate_session.hpp"
#include "openvino/opsets/opset11.hpp"
#include "openvino/util/log.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

using namespace ov::opsets;

OutputVector translate_case_op(const NodeContext& node) {
// Validate the operation type
auto op_type = node.get_op_type();
TENSORFLOW_OP_VALIDATION(node, op_type == "Case",
"Internal error: incorrect usage of translate_case_op.");
Comment on lines +14 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use default_op_checks instead


// Retrieve the number of branches and inputs
auto num_branches = node.get_attribute<int>("branches");
TENSORFLOW_OP_VALIDATION(node, num_branches > 0,
"[TensorFlow Frontend] Case operation must have at least one branch.");

// The first input is the condition for selecting the branch
auto cond = node.get_input(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us rename it to branch_index


// Create a list to store sub-graphs for the branches
std::vector<std::shared_ptr<Model>> branch_graphs;
for (int i = 0; i < num_branches; ++i) {
std::string branch_name = "branch_" + std::to_string(i);
auto branch_body = node.get_attribute<std::string>(branch_name);

// Ensure that the branch model is correctly loaded
auto branch_model = node.get_translate_session()->get_body_ov_model(branch_body, node.get_inputs());
TENSORFLOW_OP_VALIDATION(node, branch_model,
"[TensorFlow Frontend] Failed to retrieve body graph for branch: " + branch_name);
branch_graphs.push_back(branch_model);
}

// Create the nested If operation to represent the Case operation
std::shared_ptr<Model> current_model = nullptr;
for (int i = num_branches - 1; i >= 0; --i) {
auto if_op = std::make_shared<If>(cond);
if_op->set_then_body(branch_graphs[i]);

if (current_model) {
if_op->set_else_body(current_model);
} else {
// Default empty else body
auto placeholder_model = std::make_shared<Model>(OutputVector{}, ParameterVector{});
if_op->set_else_body(placeholder_model);
}
Comment on lines +45 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see that you properly set input parameters for then and else bodies. please check implementation of If translators. The same question for outputs of bodies.

Also, you need to have different conditions for each nested If. It is like a sort of branch_index == i where i is an index of branch.


current_model = if_op->get_body_model();
}

// Set the outputs and names
auto outputs = current_model->get_results();
OutputVector ov_outputs;
for (size_t i = 0; i < outputs.size(); ++i) {
auto tensor = outputs[i]->output(0).get_tensor();
tensor.set_names({node.get_name() + ":" + std::to_string(i)});
ov_outputs.push_back(outputs[i]->output(0));
}

return ov_outputs;
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
18 changes: 18 additions & 0 deletions src/frontends/tensorflow/src/op/case.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef CASE_HPP
#define CASE_HPP

#include "openvino/frontend/tensorflow/node_context.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_case_op(const ov::frontend::tensorflow::NodeContext& node);

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

#endif // CASE_HPP
Comment on lines +1 to +18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed header file. Please check other operation translators for which we don't have header

3 changes: 3 additions & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_right_shift.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "case.hpp"
#include "openvino/op/ceiling.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/cosh.hpp"
Expand Down Expand Up @@ -87,6 +88,7 @@ TF_OP_CONVERTER(translate_assignvariable_op);
TF_OP_CONVERTER(translate_add_variable_op);
TF_OP_CONVERTER(translate_sub_variable_op);
TF_OP_CONVERTER(translate_block_lstm_op);
TF_OP_CONVERTER(translate_case_op);
TF_OP_CONVERTER(translate_enter_op);
TF_OP_CONVERTER(translate_exit_op);
TF_OP_CONVERTER(translate_fifo_queue_op);
Expand Down Expand Up @@ -140,6 +142,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Asinh", CreatorFunction(translate_unary_op<v3::Asinh>)},
{"Atan", CreatorFunction(translate_unary_op<v0::Atan>)},
{"Atanh", CreatorFunction(translate_unary_op<v3::Atanh>)},
{"Case", CreatorFunction(translate_case_op)},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add layer tests for this Operation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkazants It would be greatly appreciated if you could guide me on the appropriate folder where I should add the layer tests for this operation. As I am new to open-source contributions, your guidance would be very helpful in ensuring I follow the correct structure and conventions for adding the test case for this operation.

{"Ceil", CreatorFunction(translate_unary_op<v0::Ceiling>)},
{"Cos", CreatorFunction(translate_unary_op<v0::Cos>)},
{"Cosh", CreatorFunction(translate_unary_op<v0::Cosh>)},
Expand Down
71 changes: 71 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Case_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


class TestCaseOp(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
"""
Prepares input data based on the given input shapes and data types.
"""
assert 'cond' in inputs_info
assert 'input_data' in inputs_info
inputs_data = {
'cond': np.array(inputs_info['cond'], dtype=np.bool_),
'input_data': np.array(inputs_info['input_data'], dtype=np.float32)
}
return inputs_data

def create_case_net(self, input_shape, cond_value):
"""
Creates a TensorFlow model with a Case operation.

Args:
input_shape: Shape of the input tensor.
cond_value: The condition value to select the branch.

Returns:
TensorFlow graph definition and None.
"""
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
# Inputs
cond = tf.compat.v1.placeholder(dtype=tf.bool, shape=(), name="cond")
input_data = tf.compat.v1.placeholder(dtype=tf.float32, shape=input_shape, name="input_data")

# Define branch functions
def branch_fn_1():
return tf.add(input_data, tf.constant(1.0, dtype=tf.float32))

def branch_fn_2():
return tf.multiply(input_data, tf.constant(2.0, dtype=tf.float32))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us have 4 branches

branches_fn = [branch_fn_1, branch_fn_2]

# Create Case operation
case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32)
tf.identity(case_op, name="output")

tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

# Test parameters
test_data_basic = [
dict(input_shape=[1, 2], cond=True),
dict(input_shape=[3, 3], cond=False),
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us check for four branch index values "0, 1, 2, 3"


@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_case_op(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
"""
Executes the test for the Case operation.
"""
self._test(*self.create_case_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
Loading