-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TF FE] Support Case operation #28027
base: master
Are you sure you want to change the base?
Changes from all commits
ab5766a
ef01643
f71bcf9
a87341b
dcd5a16
ef61090
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#include "openvino/frontend/tensorflow/node_context.hpp" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
#include "openvino/frontend/tensorflow/translate_session.hpp" | ||
#include "openvino/opsets/opset11.hpp" | ||
#include "openvino/util/log.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace tensorflow { | ||
namespace op { | ||
|
||
using namespace ov::opsets; | ||
|
||
OutputVector translate_case_op(const NodeContext& node) { | ||
// Validate the operation type | ||
auto op_type = node.get_op_type(); | ||
TENSORFLOW_OP_VALIDATION(node, op_type == "Case", | ||
"Internal error: incorrect usage of translate_case_op."); | ||
Comment on lines
+14
to
+17
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please use |
||
|
||
// Retrieve the number of branches and inputs | ||
auto num_branches = node.get_attribute<int>("branches"); | ||
TENSORFLOW_OP_VALIDATION(node, num_branches > 0, | ||
"[TensorFlow Frontend] Case operation must have at least one branch."); | ||
|
||
// The first input is the condition for selecting the branch | ||
auto cond = node.get_input(0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let us rename it to |
||
|
||
// Create a list to store sub-graphs for the branches | ||
std::vector<std::shared_ptr<Model>> branch_graphs; | ||
for (int i = 0; i < num_branches; ++i) { | ||
std::string branch_name = "branch_" + std::to_string(i); | ||
auto branch_body = node.get_attribute<std::string>(branch_name); | ||
|
||
// Ensure that the branch model is correctly loaded | ||
auto branch_model = node.get_translate_session()->get_body_ov_model(branch_body, node.get_inputs()); | ||
TENSORFLOW_OP_VALIDATION(node, branch_model, | ||
"[TensorFlow Frontend] Failed to retrieve body graph for branch: " + branch_name); | ||
branch_graphs.push_back(branch_model); | ||
} | ||
|
||
// Create the nested If operation to represent the Case operation | ||
std::shared_ptr<Model> current_model = nullptr; | ||
for (int i = num_branches - 1; i >= 0; --i) { | ||
auto if_op = std::make_shared<If>(cond); | ||
if_op->set_then_body(branch_graphs[i]); | ||
|
||
if (current_model) { | ||
if_op->set_else_body(current_model); | ||
} else { | ||
// Default empty else body | ||
auto placeholder_model = std::make_shared<Model>(OutputVector{}, ParameterVector{}); | ||
if_op->set_else_body(placeholder_model); | ||
} | ||
Comment on lines
+45
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see that you properly set input parameters for Also, you need to have different conditions for each nested |
||
|
||
current_model = if_op->get_body_model(); | ||
} | ||
|
||
// Set the outputs and names | ||
auto outputs = current_model->get_results(); | ||
OutputVector ov_outputs; | ||
for (size_t i = 0; i < outputs.size(); ++i) { | ||
auto tensor = outputs[i]->output(0).get_tensor(); | ||
tensor.set_names({node.get_name() + ":" + std::to_string(i)}); | ||
ov_outputs.push_back(outputs[i]->output(0)); | ||
} | ||
|
||
return ov_outputs; | ||
} | ||
|
||
} // namespace op | ||
} // namespace tensorflow | ||
} // namespace frontend | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef CASE_HPP | ||
#define CASE_HPP | ||
|
||
#include "openvino/frontend/tensorflow/node_context.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace tensorflow { | ||
namespace op { | ||
|
||
OutputVector translate_case_op(const ov::frontend::tensorflow::NodeContext& node); | ||
|
||
} // namespace op | ||
} // namespace tensorflow | ||
} // namespace frontend | ||
} // namespace ov | ||
|
||
#endif // CASE_HPP | ||
Comment on lines
+1
to
+18
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed header file. Please check other operation translators for which we don't have header |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
#include "openvino/op/bitwise_or.hpp" | ||
#include "openvino/op/bitwise_right_shift.hpp" | ||
#include "openvino/op/bitwise_xor.hpp" | ||
#include "case.hpp" | ||
#include "openvino/op/ceiling.hpp" | ||
#include "openvino/op/cos.hpp" | ||
#include "openvino/op/cosh.hpp" | ||
|
@@ -87,6 +88,7 @@ TF_OP_CONVERTER(translate_assignvariable_op); | |
TF_OP_CONVERTER(translate_add_variable_op); | ||
TF_OP_CONVERTER(translate_sub_variable_op); | ||
TF_OP_CONVERTER(translate_block_lstm_op); | ||
TF_OP_CONVERTER(translate_case_op); | ||
TF_OP_CONVERTER(translate_enter_op); | ||
TF_OP_CONVERTER(translate_exit_op); | ||
TF_OP_CONVERTER(translate_fifo_queue_op); | ||
|
@@ -140,6 +142,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() { | |
{"Asinh", CreatorFunction(translate_unary_op<v3::Asinh>)}, | ||
{"Atan", CreatorFunction(translate_unary_op<v0::Atan>)}, | ||
{"Atanh", CreatorFunction(translate_unary_op<v3::Atanh>)}, | ||
{"Case", CreatorFunction(translate_case_op)}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add layer tests for this Operation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rkazants It would be greatly appreciated if you could guide me on the appropriate folder where I should add the layer tests for this operation. As I am new to open-source contributions, your guidance would be very helpful in ensuring I follow the correct structure and conventions for adding the test case for this operation. |
||
{"Ceil", CreatorFunction(translate_unary_op<v0::Ceiling>)}, | ||
{"Cos", CreatorFunction(translate_unary_op<v0::Cos>)}, | ||
{"Cosh", CreatorFunction(translate_unary_op<v0::Cosh>)}, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import pytest | ||
import tensorflow as tf | ||
from common.tf_layer_test_class import CommonTFLayerTest | ||
|
||
|
||
class TestCaseOp(CommonTFLayerTest): | ||
def _prepare_input(self, inputs_info): | ||
""" | ||
Prepares input data based on the given input shapes and data types. | ||
""" | ||
assert 'cond' in inputs_info | ||
assert 'input_data' in inputs_info | ||
inputs_data = { | ||
'cond': np.array(inputs_info['cond'], dtype=np.bool_), | ||
'input_data': np.array(inputs_info['input_data'], dtype=np.float32) | ||
} | ||
return inputs_data | ||
|
||
def create_case_net(self, input_shape, cond_value): | ||
""" | ||
Creates a TensorFlow model with a Case operation. | ||
|
||
Args: | ||
input_shape: Shape of the input tensor. | ||
cond_value: The condition value to select the branch. | ||
|
||
Returns: | ||
TensorFlow graph definition and None. | ||
""" | ||
tf.compat.v1.reset_default_graph() | ||
with tf.compat.v1.Session() as sess: | ||
# Inputs | ||
cond = tf.compat.v1.placeholder(dtype=tf.bool, shape=(), name="cond") | ||
input_data = tf.compat.v1.placeholder(dtype=tf.float32, shape=input_shape, name="input_data") | ||
|
||
# Define branch functions | ||
def branch_fn_1(): | ||
return tf.add(input_data, tf.constant(1.0, dtype=tf.float32)) | ||
|
||
def branch_fn_2(): | ||
return tf.multiply(input_data, tf.constant(2.0, dtype=tf.float32)) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let us have 4 branches |
||
branches_fn = [branch_fn_1, branch_fn_2] | ||
|
||
# Create Case operation | ||
case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32) | ||
tf.identity(case_op, name="output") | ||
|
||
tf.compat.v1.global_variables_initializer() | ||
tf_net = sess.graph_def | ||
|
||
return tf_net, None | ||
|
||
# Test parameters | ||
test_data_basic = [ | ||
dict(input_shape=[1, 2], cond=True), | ||
dict(input_shape=[3, 3], cond=False), | ||
] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let us check for four branch index values "0, 1, 2, 3" |
||
|
||
@pytest.mark.parametrize("params", test_data_basic) | ||
@pytest.mark.precommit_tf_fe | ||
@pytest.mark.nightly | ||
def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, | ||
use_new_frontend, use_old_api): | ||
""" | ||
Executes the test for the Case operation. | ||
""" | ||
self._test(*self.create_case_net(**params), | ||
ie_device, precision, ir_version, temp_dir=temp_dir, | ||
use_new_frontend=use_new_frontend, use_old_api=use_old_api) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add copyright as for other src files.