From ded49387b022a754461aded73dcb22d70f36dc6e Mon Sep 17 00:00:00 2001 From: Abdulrahman Adel Ibrahim <57441828+Abdulrahman-Adel@users.noreply.github.com> Date: Sat, 30 Dec 2023 22:29:48 +0200 Subject: [PATCH] Support ConjugateTranspose for TF (#21586) * conj transpose * Update * Fix Mistakes * Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp Co-authored-by: Anastasiia Pnevskaia * Update src/frontends/tensorflow/src/op_table.cpp Co-authored-by: Roman Kazantsev * """Test the conjugate transpose layer.""" * u * Resolve Issues * Revert "conj transpose" This reverts commit 27833f348a691c4c243ecd38120c36460402c959. * Update ConjTranspose * Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp * update clang-format | fix test errors * Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp * fix test error * Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp --------- Co-authored-by: Anastasiia Pnevskaia Co-authored-by: Roman Kazantsev --- src/frontends/tensorflow/src/op_table.cpp | 1 + .../include/common_op_table.hpp | 1 + .../src/op/conj_transpose.cpp | 65 ++++++++++ .../test_tf_ConjugateTranspose.py | 122 ++++++++++++++++++ 4 files changed, 189 insertions(+) create mode 100644 src/frontends/tensorflow_common/src/op/conj_transpose.cpp create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_ConjugateTranspose.py diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 864619b57c0c63..062d29cf8b06ac 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -157,6 +157,7 @@ const std::map get_supported_ops() { {"ClipByValue", CreatorFunction(translate_clip_by_value_op)}, {"Complex", CreatorFunction(translate_complex_op)}, {"ComplexAbs", CreatorFunction(translate_complex_abs_op)}, + {"ConjugateTranspose", CreatorFunction(translate_conj_transpose_op)}, {"Concat", CreatorFunction(translate_concat_op)}, {"ConcatV2", CreatorFunction(translate_concat_op)}, {"Const", CreatorFunction(translate_const_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index 95e79f001ca28d..c2f756b4aecc62 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -50,6 +50,7 @@ OP_CONVERTER(translate_clip_by_value_op); OP_CONVERTER(translate_complex_op); OP_CONVERTER(translate_complex_abs_op); OP_CONVERTER(translate_concat_op); +OP_CONVERTER(translate_conj_transpose_op); OP_CONVERTER(translate_const_op); OP_CONVERTER(translate_conv_2d_op); OP_CONVERTER(translate_conv_2d_backprop_input_op); diff --git a/src/frontends/tensorflow_common/src/op/conj_transpose.cpp b/src/frontends/tensorflow_common/src/op/conj_transpose.cpp new file mode 100644 index 00000000000000..8b4a51bdcbc828 --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/conj_transpose.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/negative.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/transpose.hpp" + +using namespace std; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_conj_transpose_op(const NodeContext& node) { + default_op_checks(node, 2, {"ConjugateTranspose"}, true); + + auto x = node.get_input(0); + auto perm = node.get_input(1); + + auto complex_type_mark = as_type_ptr(x.get_node_shared_ptr()); + if (complex_type_mark) { + element::Type complex_part_type = complex_type_mark->get_complex_part_type(); + auto x = complex_type_mark->input_value(0); + + auto real_index = make_shared(element::i32, Shape{1}, 0); + auto imag_index = make_shared(element::i32, Shape{1}, 1); + + auto gather_axis = make_shared(element::i32, Shape{1}, -1); + + auto real = make_shared(x, real_index, gather_axis)->output(0); + auto imag = make_shared(x, imag_index, gather_axis)->output(0); + + imag = make_shared(imag); + + auto conj_tensor = make_shared(OutputVector{real, imag}, -1)->output(0); + + OutputVector concat_inputs; + concat_inputs.push_back(perm); + concat_inputs.push_back(make_shared(perm, perm.get_element_type())); + + auto concat = make_shared(concat_inputs, 0); + auto conj_transpose = make_shared(conj_tensor, concat); + + set_node_name(node.get_name(), conj_transpose); + auto complex_transpose = make_shared(conj_transpose, complex_part_type); + return {complex_transpose->output(0)}; + } + + auto conj_transpose = make_shared(x, perm); + set_node_name(node.get_name(), conj_transpose); + return {conj_transpose}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ConjugateTranspose.py b/tests/layer_tests/tensorflow_tests/test_tf_ConjugateTranspose.py new file mode 100644 index 00000000000000..a0f78096c02d50 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_ConjugateTranspose.py @@ -0,0 +1,122 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import numpy as np +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + +# Testing operation ConjugateTranspose +# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/ConjugateTranspose + + +class TestComplexConjugateTranspose(CommonTFLayerTest): + + def _prepare_input(self, inputs_info): + + rng = np.random.default_rng() + assert 'real_part' in inputs_info + real_part_shape = inputs_info['real_part'] + assert 'imag_part' in inputs_info + imag_part_shape = inputs_info['imag_part'] + + inputs_data = {} + inputs_data['real_part'] = 4 * rng.random(real_part_shape).astype(np.float32) - 2 + inputs_data['imag_part'] = 4 * rng.random(imag_part_shape).astype(np.float32) - 2 + + return inputs_data + + def create_complex_conjugate_transpose_net(self, input_shape, perm): + """ + TensorFlow net IR net + + Placeholder->ConjugateTranspose => Placeholder->Transpose->Conjugate->Transpose + """ + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + real_part = tf.compat.v1.placeholder(np.float32, input_shape, 'real_part') + imag_part = tf.compat.v1.placeholder(np.float32, input_shape, 'imag_part') + + complex_input = tf.raw_ops.Complex(real=real_part, imag=imag_part) + + conj_tranpose = tf.raw_ops.ConjugateTranspose(x=complex_input, perm=perm, name = "Operation") + real = tf.raw_ops.Real(input=conj_tranpose) + img = tf.raw_ops.Imag(input=conj_tranpose) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + + test_data = [ + (dict(input_shape=[1, 2], perm=[1, 0])), + (dict(input_shape=[1, 2, 3], perm=[2, 1, 0])), + (dict(input_shape=[1, 2, 3, 4], perm=[0, 3, 2, 1])), + (dict(input_shape=[1, 2, 3, 4, 5, 6], perm=[0, 2, 1, 3, 4, 5])), + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_conjugate_transpose(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_complex_conjugate_transpose_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestConjugateTranspose(CommonTFLayerTest): + + def _prepare_input(self, inputs_info): + + assert 'input' in inputs_info + input_shape = inputs_info['input'] + + inputs_data = {} + inputs_data['input'] = np.random.default_rng().random(input_shape).astype(np.float32) + + return inputs_data + + def create_conjugate_transpose_net(self, input_shape, perm): + """ + TensorFlow net IR net + + Placeholder->ConjugateTranspose => Placeholder->Transpose->Conjugate->Transpose + """ + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + input = tf.compat.v1.placeholder(np.float32, input_shape, 'input') + + tf.raw_ops.ConjugateTranspose(x=input, perm=perm, name = "Operation") + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data = [ + (dict(input_shape=[1, 2], perm=[1, 0])), + (dict(input_shape=[1, 2, 3], perm=[2, 1, 0])), + (dict(input_shape=[1, 2, 3, 4], perm=[0, 3, 2, 1])), + (dict(input_shape=[1, 2, 3, 4, 5, 6], perm=[0, 2, 1, 3, 4, 5])), + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_conjugate_transpose(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_conjugate_transpose_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) \ No newline at end of file