Skip to content

Commit

Permalink
[TF FE] Fix translators for multiple output operations (openvinotoolk…
Browse files Browse the repository at this point in the history
…it#20787)

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Oct 31, 2023
1 parent 48c9598 commit 8d6f56d
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 8 deletions.
9 changes: 7 additions & 2 deletions src/frontends/tensorflow/src/op/sparse_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) {
NamedOutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) {
// Currently, the translation for SparseReshape is possible only if new shape value is the same as the input shape
// value or it is different just by one dynamic dimension of the new shape that can be replace with the
// corresponding static dimension of the input shape.
Expand Down Expand Up @@ -67,7 +67,12 @@ OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeCon
"This case with SparseReshape is not possible to translate to OpenVINO opset. The number "
"of dynamic shapes in new shape must be 1 at most.");
*/
return {input_indices, input_shape};
auto output_indices = input_indices;
auto output_shape = input_shape;
set_out_name(node.get_name() + ":0", output_indices);
set_out_name(node.get_name() + ":1", output_shape);

return {{"output_indices", output_indices}, {"output_shape", output_shape}};
}

NamedOutputVector translate_sparse_fill_empty_rows_op(const ov::frontend::tensorflow::NodeContext& node) {
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TF_OP_CONVERTER(translate_queue_dequeue_many_op);
TF_OP_CONVERTER(translate_readvariable_op);
TF_OP_CONVERTER(translate_restorev2_op);
TF_OP_CONVERTER_NAMED(translate_sparse_fill_empty_rows_op);
TF_OP_CONVERTER(translate_sparse_reshape_op);
TF_OP_CONVERTER_NAMED(translate_sparse_reshape_op);
TF_OP_CONVERTER(translate_sparse_segment_sum_op);
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
TF_OP_CONVERTER(translate_stringjoin_op);
Expand Down Expand Up @@ -216,7 +216,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"MaxPool", CreatorFunction(translate_max_pool_op)},
{"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
{"MaxPool3D", CreatorFunction(translate_max_pool_op)},
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_op)},
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_with_argmax)},
{"Merge", CreatorFunction(translate_merge_op)},
{"MirrorPad", CreatorFunction(translate_mirror_pad_op)},
{"MutableHashTable", CreatorFunction(translate_hash_table_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ OP_CONVERTER(translate_lrn_op);
OP_CONVERTER(translate_mat_mul_op);
OP_CONVERTER(translate_matrix_diag_op);
OP_CONVERTER(translate_max_pool_op);
OP_CONVERTER_NAMED(translate_max_pool_with_argmax);
OP_CONVERTER(translate_mirror_pad_op);
OP_CONVERTER_NAMED(translate_non_max_suppression_op);
OP_CONVERTER(translate_parallel_dynamic_stitch_op);
Expand Down
7 changes: 3 additions & 4 deletions src/frontends/tensorflow_common/src/op/max_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ OutputVector translate_max_pool_v2(const NodeContext& node) {
return translate_max_pool_util(node, 2, ksize_vector, strides_vector);
}

OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
NamedOutputVector translate_max_pool_with_argmax(const NodeContext& node) {
// MaxPoolWithArgmax has just one input. ksize and strides are attributes
TENSORFLOW_OP_VALIDATION(node,
node.get_input_size() > 0,
Expand Down Expand Up @@ -199,8 +199,9 @@ OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
convert_nchw_to_nhwc(true, output_indices, 4);
}

set_out_name(node_name + ":0", max_pool);
set_out_name(node_name + ":1", output_indices);
return {max_pool, output_indices};
return {{"output", max_pool}, {"argmax", output_indices}};
}

OutputVector translate_max_pool_op(const NodeContext& node) {
Expand All @@ -210,8 +211,6 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
return translate_max_pool_v2(node);
} else if (node.get_op_type() == "MaxPool3D") {
return translate_max_pool(node, 3);
} else if (node.get_op_type() == "MaxPoolWithArgmax") {
return translate_max_pool_with_argmax(node);
} else {
TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool2D, MaxPoolV2 and MaxPool3D are supported.");
}
Expand Down

0 comments on commit 8d6f56d

Please sign in to comment.