diff --git a/src/frontends/tensorflow/src/op/case.cpp b/src/frontends/tensorflow/src/op/case.cpp new file mode 100644 index 00000000000000..42d13887efc085 --- /dev/null +++ b/src/frontends/tensorflow/src/op/case.cpp @@ -0,0 +1,72 @@ +#include "openvino/frontend/tensorflow/node_context.hpp" +#include "openvino/frontend/tensorflow/translate_session.hpp" +#include "openvino/opsets/opset11.hpp" +#include "openvino/util/log.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +using namespace ov::opsets; + +OutputVector translate_case_op(const NodeContext& node) { + // Validate the operation type + auto op_type = node.get_op_type(); + TENSORFLOW_OP_VALIDATION(node, op_type == "Case", + "Internal error: incorrect usage of translate_case_op."); + + // Retrieve the number of branches and inputs + auto num_branches = node.get_attribute("branches"); + TENSORFLOW_OP_VALIDATION(node, num_branches > 0, + "[TensorFlow Frontend] Case operation must have at least one branch."); + + // The first input is the condition for selecting the branch + auto cond = node.get_input(0); + + // Create a list to store sub-graphs for the branches + std::vector> branch_graphs; + for (int i = 0; i < num_branches; ++i) { + std::string branch_name = "branch_" + std::to_string(i); + auto branch_body = node.get_attribute(branch_name); + + // Ensure that the branch model is correctly loaded + auto branch_model = node.get_translate_session()->get_body_ov_model(branch_body, node.get_inputs()); + TENSORFLOW_OP_VALIDATION(node, branch_model, + "[TensorFlow Frontend] Failed to retrieve body graph for branch: " + branch_name); + branch_graphs.push_back(branch_model); + } + + // Create the nested If operation to represent the Case operation + std::shared_ptr current_model = nullptr; + for (int i = num_branches - 1; i >= 0; --i) { + auto if_op = std::make_shared(cond); + if_op->set_then_body(branch_graphs[i]); + + if (current_model) { + if_op->set_else_body(current_model); + } else { + // Default empty else body + auto placeholder_model = std::make_shared(OutputVector{}, ParameterVector{}); + if_op->set_else_body(placeholder_model); + } + + current_model = if_op->get_body_model(); + } + + // Set the outputs and names + auto outputs = current_model->get_results(); + OutputVector ov_outputs; + for (size_t i = 0; i < outputs.size(); ++i) { + auto tensor = outputs[i]->output(0).get_tensor(); + tensor.set_names({node.get_name() + ":" + std::to_string(i)}); + ov_outputs.push_back(outputs[i]->output(0)); + } + + return ov_outputs; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/op/case.hpp b/src/frontends/tensorflow/src/op/case.hpp new file mode 100644 index 00000000000000..05bb076cb84b2b --- /dev/null +++ b/src/frontends/tensorflow/src/op/case.hpp @@ -0,0 +1,18 @@ +#ifndef CASE_HPP +#define CASE_HPP + +#include "openvino/frontend/tensorflow/node_context.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_case_op(const ov::frontend::tensorflow::NodeContext& node); + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov + +#endif // CASE_HPP diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 26b665c275bb48..924b49a9c62e9b 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -19,6 +19,7 @@ #include "openvino/op/bitwise_or.hpp" #include "openvino/op/bitwise_right_shift.hpp" #include "openvino/op/bitwise_xor.hpp" +#include "case.hpp" #include "openvino/op/ceiling.hpp" #include "openvino/op/cos.hpp" #include "openvino/op/cosh.hpp" @@ -87,6 +88,7 @@ TF_OP_CONVERTER(translate_assignvariable_op); TF_OP_CONVERTER(translate_add_variable_op); TF_OP_CONVERTER(translate_sub_variable_op); TF_OP_CONVERTER(translate_block_lstm_op); +TF_OP_CONVERTER(translate_case_op); TF_OP_CONVERTER(translate_enter_op); TF_OP_CONVERTER(translate_exit_op); TF_OP_CONVERTER(translate_fifo_queue_op); @@ -140,6 +142,7 @@ const std::map get_supported_ops() { {"Asinh", CreatorFunction(translate_unary_op)}, {"Atan", CreatorFunction(translate_unary_op)}, {"Atanh", CreatorFunction(translate_unary_op)}, + {"Case", CreatorFunction(translate_case_op)}, {"Ceil", CreatorFunction(translate_unary_op)}, {"Cos", CreatorFunction(translate_unary_op)}, {"Cosh", CreatorFunction(translate_unary_op)}, diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py new file mode 100644 index 00000000000000..12a7815247a26a --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestCaseOp(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + """ + Prepares input data based on the given input shapes and data types. + """ + assert 'cond' in inputs_info + assert 'input_data' in inputs_info + inputs_data = { + 'cond': np.array(inputs_info['cond'], dtype=np.bool_), + 'input_data': np.array(inputs_info['input_data'], dtype=np.float32) + } + return inputs_data + + def create_case_net(self, input_shape, cond_value): + """ + Creates a TensorFlow model with a Case operation. + + Args: + input_shape: Shape of the input tensor. + cond_value: The condition value to select the branch. + + Returns: + TensorFlow graph definition and None. + """ + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + # Inputs + cond = tf.compat.v1.placeholder(dtype=tf.bool, shape=(), name="cond") + input_data = tf.compat.v1.placeholder(dtype=tf.float32, shape=input_shape, name="input_data") + + # Define branch functions + def branch_fn_1(): + return tf.add(input_data, tf.constant(1.0, dtype=tf.float32)) + + def branch_fn_2(): + return tf.multiply(input_data, tf.constant(2.0, dtype=tf.float32)) + + branches_fn = [branch_fn_1, branch_fn_2] + + # Create Case operation + case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32) + tf.identity(case_op, name="output") + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + # Test parameters + test_data_basic = [ + dict(input_shape=[1, 2], cond=True), + dict(input_shape=[3, 3], cond=False), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + """ + Executes the test for the Case operation. + """ + self._test(*self.create_case_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)