[TF FE]: Support complex tensors for ScatterNd operation (openvinotoolkit#26821)

### Details:
 - Support complex tensors for the ScatterNd operation and add layer tests

### Tickets:
 - [None](openvinotoolkit#23240)

---------

Co-authored-by: Roman Kazantsev <[email protected]>
hub-bla and rkazants authored Sep 28, 2024
1 parent 8f8344a commit b383112
Showing 2 changed files with 75 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/frontends/tensorflow_common/src/op/scatter_nd.cpp
@@ -3,7 +3,9 @@
 //
 
 #include "common_op_table.hpp"
+#include "helper_ops/complex_type_mark.hpp"
 #include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/scatter_nd_update.hpp"
 #include "utils.hpp"
 
@@ -15,10 +17,29 @@ namespace frontend
 namespace tensorflow {
 namespace op {
 OutputVector translate_scatter_nd_op(const NodeContext& node) {
-    default_op_checks(node, 3, {"ScatterNd", "SCATTER_ND"});
+    default_op_checks(node, 3, {"ScatterNd", "SCATTER_ND"}, true);
     auto input_indices = node.get_input(0);
     auto updates = node.get_input(1);
     auto shape = node.get_input(2);
+    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(updates.get_node_shared_ptr());
+
+    if (complex_type_mark) {
+        element::Type complex_part_type = complex_type_mark->get_complex_part_type();
+        updates = complex_type_mark->input_value(0);
+
+        auto new_dim = create_same_type_const<int32_t>(shape, vector<int32_t>{2}, Shape{1});
+        auto new_shape = make_shared<v0::Concat>(OutputVector{shape, new_dim}, -1);
+
+        auto const_zero = create_same_type_const<int32_t>(updates, vector<int32_t>{0}, Shape{1});
+        auto broadcast = make_shared<v3::Broadcast>(const_zero, new_shape);
+
+        auto complex_scatter_nd = make_shared<v3::ScatterNDUpdate>(broadcast, input_indices, updates);
+
+        set_node_name(node.get_name(), complex_scatter_nd);
+        auto complex_result = make_shared<ComplexTypeMark>(complex_scatter_nd, complex_part_type);
+
+        return {complex_result};
+    }
+
     auto input_data = create_same_type_const<int32_t>(updates, vector<int32_t>{0}, Shape{1});
     auto broadcast = make_shared<v3::Broadcast>(input_data, shape);
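In short, the new branch unpacks the updates from ComplexTypeMark (a real tensor whose innermost dimension of size 2 holds the real and imaginary parts), appends a trailing 2 to the requested output shape via Concat, broadcasts zeros to that extended shape, and scatters the packed updates with ScatterNDUpdate before re-wrapping the result in ComplexTypeMark. A minimal NumPy sketch of the same decomposition (the helper name and the values below are illustrative, not part of the commit):

```python
import numpy as np

def scatter_nd_complex_reference(indices, updates, shape):
    """Hypothetical reference for the decomposition above, using NumPy only."""
    # ComplexTypeMark stores a complex tensor as a real tensor with a
    # trailing dimension of size 2 holding (real, imag).
    packed = np.stack([updates.real, updates.imag], axis=-1)

    # Extend the target shape with that trailing 2 and start from zeros,
    # mirroring the Concat + Broadcast nodes in the translation.
    out = np.zeros(list(shape) + [2], dtype=packed.dtype)

    # ScatterNDUpdate: each index row selects the slice of `out` that
    # receives the corresponding packed update.
    for idx, upd in zip(indices, packed):
        out[tuple(idx)] = upd

    # Re-interpret the packed result as a complex tensor for comparison.
    return out[..., 0] + 1j * out[..., 1]

indices = np.array([[4], [3], [1], [7]])
updates = np.array([1, 2, 3, 4], np.float32) + 1j * np.array([5, 6, 7, 8], np.float32)
print(scatter_nd_complex_reference(indices, updates, [8]))
# positions 4, 3, 1, 7 now hold 1+5j, 2+6j, 3+7j, 4+8j; the rest stays 0
```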
53 changes: 53 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py
@@ -3,9 +3,11 @@
 
 import platform
 
+import numpy as np
 import pytest
 from common.tf_layer_test_class import CommonTFLayerTest
 
+rng = np.random.default_rng(475912)
 
 class TestTFScatterND(CommonTFLayerTest):
     def create_tf_scatternd_placeholder_const_net(self, x_shape, indices, updates, ir_version,
@@ -73,3 +75,54 @@ def test_tf_scatter_nd(self, params, ie_device, precision, ir_version, temp_dir,
                                                                     use_legacy_frontend=use_legacy_frontend),
                    ie_device, precision, temp_dir=temp_dir, ir_version=ir_version,
                    use_legacy_frontend=use_legacy_frontend, **params)

+class TestTFScatterNDComplex(CommonTFLayerTest):
+    def _prepare_input(self, inputs_info):
+        assert 'param_real:0' in inputs_info, "Test error: inputs_info must contain `param_real`"
+        assert 'param_imag:0' in inputs_info, "Test error: inputs_info must contain `param_imag`"
+        updates_shape = inputs_info['param_real:0']
+        inputs_data = {}
+        inputs_data['param_real:0'] = rng.integers(-10, 10, updates_shape).astype(np.float32)
+        inputs_data['param_imag:0'] = rng.integers(-10, 10, updates_shape).astype(np.float32)
+
+        return inputs_data
+
+    def create_tf_scatternd_complex_placeholder_const_net(self, x_shape, indices, updates_shape, indices_type,
+                                                          ir_version, use_legacy_frontend):
+        import tensorflow as tf
+        tf.compat.v1.reset_default_graph()
+        with tf.compat.v1.Session() as sess:
+            param_real = tf.compat.v1.placeholder(tf.float32, updates_shape, 'param_real')
+            param_imag = tf.compat.v1.placeholder(tf.float32, updates_shape, 'param_imag')
+
+            tf_indices = tf.constant(indices, dtype=indices_type)
+            tf_shape = tf.constant(x_shape, dtype=indices_type)
+
+            complex = tf.raw_ops.Complex(real=param_real, imag=param_imag)
+
+            result = tf.scatter_nd(indices=tf_indices, updates=complex, shape=tf_shape, name="Operation")
+            tf.raw_ops.Real(input=result)
+            tf.raw_ops.Imag(input=result)
+
+            tf.compat.v1.global_variables_initializer()
+
+            tf_net = sess.graph_def
+
+        return tf_net, None
+
+    test_data = [
+        dict(x_shape=[8], indices=[[4], [3], [1], [7]], updates_shape=[4], indices_type=np.int32),
+        dict(x_shape=[10], indices=[[0], [2], [4], [6], [8]], updates_shape=[5], indices_type=np.int64),
+        dict(x_shape=[5, 5], indices=[[0, 0], [1, 1], [2, 2], [3, 3]], updates_shape=[4], indices_type=np.int64),
+        dict(x_shape=[3, 3, 3], indices=[[0, 0, 0], [1, 1, 1], [2, 2, 2]], updates_shape=[3], indices_type=np.int32),
+    ]
+
+    @pytest.mark.parametrize("params", test_data)
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_tf_scatter_nd_complex(self, params, ie_device, precision, ir_version, temp_dir,
+                                   use_legacy_frontend):
+        self._test(*self.create_tf_scatternd_complex_placeholder_const_net(**params, ir_version=ir_version,
+                                                                           use_legacy_frontend=use_legacy_frontend),
+                   ie_device, precision, temp_dir=temp_dir, ir_version=ir_version,
+                   use_legacy_frontend=use_legacy_frontend, **params)
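
The test builds a complex tensor from real and imaginary placeholders with tf.raw_ops.Complex, scatters it with tf.scatter_nd, and compares the Real/Imag outputs against TensorFlow. The property the decomposition relies on, that scattering complex updates equals scattering their real and imaginary parts independently, can also be checked standalone; a quick eager-mode sketch, assuming TensorFlow 2.x is installed (shapes mirror the first test_data case, values are arbitrary):

```python
import numpy as np
import tensorflow as tf

# Shapes taken from the first test_data case above; values are arbitrary.
indices = tf.constant([[4], [3], [1], [7]], dtype=tf.int32)
shape = tf.constant([8], dtype=tf.int32)
real = tf.constant([9.0, 10.0, 11.0, 12.0], dtype=tf.float32)
imag = tf.constant([1.0, 2.0, 3.0, 4.0], dtype=tf.float32)

# Scatter the complex updates in one go...
complex_updates = tf.raw_ops.Complex(real=real, imag=imag)
result = tf.scatter_nd(indices=indices, updates=complex_updates, shape=shape)

# ...and compare with scattering the real and imaginary parts separately.
np.testing.assert_allclose(tf.raw_ops.Real(input=result).numpy(),
                           tf.scatter_nd(indices=indices, updates=real, shape=shape).numpy())
np.testing.assert_allclose(tf.raw_ops.Imag(input=result).numpy(),
                           tf.scatter_nd(indices=indices, updates=imag, shape=shape).numpy())
```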
