-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TF FE] Support Case operation #28027
base: master
Are you sure you want to change the base?
[TF FE] Support Case operation #28027
Conversation
@@ -140,6 +142,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() { | |||
{"Asinh", CreatorFunction(translate_unary_op<v3::Asinh>)}, | |||
{"Atan", CreatorFunction(translate_unary_op<v0::Atan>)}, | |||
{"Atanh", CreatorFunction(translate_unary_op<v3::Atanh>)}, | |||
{"Case", CreatorFunction(translate_case_op)}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add layer tests for this Operation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rkazants It would be greatly appreciated if you could guide me on the appropriate folder where I should add the layer tests for this operation. As I am new to open-source contributions, your guidance would be very helpful in ensuring I follow the correct structure and conventions for adding the test case for this operation.
…://github.com/shubdas9902/openvino into feature/issue-20534-add-case-op-support-tf-fe
@rkazants I added the layer tests what should I do next? |
bump @rkazants |
#ifndef CASE_HPP | ||
#define CASE_HPP | ||
|
||
#include "openvino/frontend/tensorflow/node_context.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace tensorflow { | ||
namespace op { | ||
|
||
OutputVector translate_case_op(const ov::frontend::tensorflow::NodeContext& node); | ||
|
||
} // namespace op | ||
} // namespace tensorflow | ||
} // namespace frontend | ||
} // namespace ov | ||
|
||
#endif // CASE_HPP |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needed header file. Please check other operation translators for which we don't have header
@@ -0,0 +1,72 @@ | |||
#include "openvino/frontend/tensorflow/node_context.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add copyright as for other src files.
@@ -0,0 +1,72 @@ | |||
#include "openvino/frontend/tensorflow/node_context.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Validate the operation type | ||
auto op_type = node.get_op_type(); | ||
TENSORFLOW_OP_VALIDATION(node, op_type == "Case", | ||
"Internal error: incorrect usage of translate_case_op."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please use default_op_checks
instead
"[TensorFlow Frontend] Case operation must have at least one branch."); | ||
|
||
// The first input is the condition for selecting the branch | ||
auto cond = node.get_input(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let us rename it to branch_index
|
||
if (current_model) { | ||
if_op->set_else_body(current_model); | ||
} else { | ||
// Default empty else body | ||
auto placeholder_model = std::make_shared<Model>(OutputVector{}, ParameterVector{}); | ||
if_op->set_else_body(placeholder_model); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see that you properly set input parameters for then
and else
bodies. please check implementation of If
translators. The same question for outputs of bodies.
Also, you need to have different conditions for each nested If
. It is like a sort of branch_index == i
where i is an index of branch.
|
||
def branch_fn_2(): | ||
return tf.multiply(input_data, tf.constant(2.0, dtype=tf.float32)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let us have 4 branches
test_data_basic = [ | ||
dict(input_shape=[1, 2], cond=True), | ||
dict(input_shape=[3, 3], cond=False), | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let us check for four branch index values "0, 1, 2, 3"
Details:
case.hpp
to define the translation function for the Case operation.translate_case_op
function incase.cpp
to handle the conversion of the TensorFlow Case operation to OpenVINO.TF_OP_CONVERTER
macro in the operation converter framework.Tickets: