From 8d6f56dd1214253e14e55b035e4e22970ff26565 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 31 Oct 2023 17:22:09 +0400 Subject: [PATCH] [TF FE] Fix translators for multiple output operations (#20787) Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/op/sparse_reshape.cpp | 9 +++++++-- src/frontends/tensorflow/src/op_table.cpp | 4 ++-- .../tensorflow_common/include/common_op_table.hpp | 1 + src/frontends/tensorflow_common/src/op/max_pool.cpp | 7 +++---- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/frontends/tensorflow/src/op/sparse_reshape.cpp b/src/frontends/tensorflow/src/op/sparse_reshape.cpp index 1def5f4192f155..0d2e9de81e68c2 100644 --- a/src/frontends/tensorflow/src/op/sparse_reshape.cpp +++ b/src/frontends/tensorflow/src/op/sparse_reshape.cpp @@ -16,7 +16,7 @@ namespace ov { namespace frontend { namespace tensorflow { namespace op { -OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) { +NamedOutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) { // Currently, the translation for SparseReshape is possible only if new shape value is the same as the input shape // value or it is different just by one dynamic dimension of the new shape that can be replace with the // corresponding static dimension of the input shape. @@ -67,7 +67,12 @@ OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeCon "This case with SparseReshape is not possible to translate to OpenVINO opset. The number " "of dynamic shapes in new shape must be 1 at most."); */ - return {input_indices, input_shape}; + auto output_indices = input_indices; + auto output_shape = input_shape; + set_out_name(node.get_name() + ":0", output_indices); + set_out_name(node.get_name() + ":1", output_shape); + + return {{"output_indices", output_indices}, {"output_shape", output_shape}}; } NamedOutputVector translate_sparse_fill_empty_rows_op(const ov::frontend::tensorflow::NodeContext& node) { diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 4926ac159cecbb..e5f25dad31270a 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -43,7 +43,7 @@ TF_OP_CONVERTER(translate_queue_dequeue_many_op); TF_OP_CONVERTER(translate_readvariable_op); TF_OP_CONVERTER(translate_restorev2_op); TF_OP_CONVERTER_NAMED(translate_sparse_fill_empty_rows_op); -TF_OP_CONVERTER(translate_sparse_reshape_op); +TF_OP_CONVERTER_NAMED(translate_sparse_reshape_op); TF_OP_CONVERTER(translate_sparse_segment_sum_op); TF_OP_CONVERTER(translate_staticregexfullmatch_op); TF_OP_CONVERTER(translate_stringjoin_op); @@ -216,7 +216,7 @@ const std::map get_supported_ops() { {"MaxPool", CreatorFunction(translate_max_pool_op)}, {"MaxPoolV2", CreatorFunction(translate_max_pool_op)}, {"MaxPool3D", CreatorFunction(translate_max_pool_op)}, - {"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_op)}, + {"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_with_argmax)}, {"Merge", CreatorFunction(translate_merge_op)}, {"MirrorPad", CreatorFunction(translate_mirror_pad_op)}, {"MutableHashTable", CreatorFunction(translate_hash_table_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index 3601a07f6c45d0..6befa470761a45 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -93,6 +93,7 @@ OP_CONVERTER(translate_lrn_op); OP_CONVERTER(translate_mat_mul_op); OP_CONVERTER(translate_matrix_diag_op); OP_CONVERTER(translate_max_pool_op); +OP_CONVERTER_NAMED(translate_max_pool_with_argmax); OP_CONVERTER(translate_mirror_pad_op); OP_CONVERTER_NAMED(translate_non_max_suppression_op); OP_CONVERTER(translate_parallel_dynamic_stitch_op); diff --git a/src/frontends/tensorflow_common/src/op/max_pool.cpp b/src/frontends/tensorflow_common/src/op/max_pool.cpp index d64ac1a17fbafe..c693f1e7533554 100644 --- a/src/frontends/tensorflow_common/src/op/max_pool.cpp +++ b/src/frontends/tensorflow_common/src/op/max_pool.cpp @@ -128,7 +128,7 @@ OutputVector translate_max_pool_v2(const NodeContext& node) { return translate_max_pool_util(node, 2, ksize_vector, strides_vector); } -OutputVector translate_max_pool_with_argmax(const NodeContext& node) { +NamedOutputVector translate_max_pool_with_argmax(const NodeContext& node) { // MaxPoolWithArgmax has just one input. ksize and strides are attributes TENSORFLOW_OP_VALIDATION(node, node.get_input_size() > 0, @@ -199,8 +199,9 @@ OutputVector translate_max_pool_with_argmax(const NodeContext& node) { convert_nchw_to_nhwc(true, output_indices, 4); } + set_out_name(node_name + ":0", max_pool); set_out_name(node_name + ":1", output_indices); - return {max_pool, output_indices}; + return {{"output", max_pool}, {"argmax", output_indices}}; } OutputVector translate_max_pool_op(const NodeContext& node) { @@ -210,8 +211,6 @@ OutputVector translate_max_pool_op(const NodeContext& node) { return translate_max_pool_v2(node); } else if (node.get_op_type() == "MaxPool3D") { return translate_max_pool(node, 3); - } else if (node.get_op_type() == "MaxPoolWithArgmax") { - return translate_max_pool_with_argmax(node); } else { TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool2D, MaxPoolV2 and MaxPool3D are supported."); }