[JAX FE] Support square operation (#27978)
**Details:** The square operation appears in JAX starting from version 0.4.36, so the JAX frontend needs a dedicated converter for it.

**Ticket:** 158994

Signed-off-by: Kazantsev, Roman <[email protected]>
rkazants authored Dec 10, 2024
1 parent 0762993 commit be0ab30
Showing 4 changed files with 77 additions and 3 deletions.
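For context on the **Details** note above: a minimal sketch (assuming JAX 0.4.36 or newer is installed; older releases lower `jnp.square` through other primitives) that traces the operation and prints the jaxpr in which the new `square` primitive shows up:

import jax
from jax import numpy as jnp

def model(x):
    return jnp.square(x)

# With JAX >= 0.4.36 the traced program contains a `square` equation,
# which is what the new frontend converter handles.
print(jax.make_jaxpr(model)(jnp.ones((2, 3), dtype=jnp.float32)))
# Expected to resemble: { lambda ; a:f32[2,3]. let b:f32[2,3] = square a in (b,) }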
28 changes: 28 additions & 0 deletions src/frontends/jax/src/op/square.cpp
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/squeeze.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace jax {
namespace op {

using namespace ov::op;

OutputVector translate_square(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
auto const_two = create_same_type_const_scalar<int64_t>(x, 2);
return {std::make_shared<v1::Power>(x, const_two)};
};

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
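The converter builds a scalar constant 2 of the input's element type and raises the input to that power. A small Python sketch of the same decomposition, only to illustrate that a power of two matches `jnp.square` for floating and integer inputs (illustrative, not part of the commit):

import numpy as np
from jax import numpy as jnp

rng = np.random.default_rng(0)
x_f = jnp.array(rng.uniform(-8.0, 8.0, (3, 4)).astype(np.float32))
x_i = jnp.array(rng.integers(-8, 8, (3, 4)).astype(np.int32))

# Raise to an exponent of the same element type, mirroring Power(x, const_two).
assert np.allclose(jnp.square(x_f), jnp.power(x_f, np.float32(2)))
assert np.array_equal(jnp.square(x_i), jnp.power(x_i, np.int32(2)))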
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
@@ -53,6 +53,7 @@ OP_CONVERTER(translate_reduce_window_sum);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_square);
OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_transpose);

@@ -92,6 +93,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"rsqrt", op::translate_rsqrt},
{"reshape", op::translate_reshape},
{"slice", op::translate_slice},
{"square", op::translate_square},
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
{"squeeze", op::translate_squeeze},
{"stop_gradient", op::skip_node},
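Conceptually, `get_supported_ops_jaxpr` maps jaxpr primitive names to converter functions. A hypothetical Python sketch of that dispatch idea (the real table is C++; the `converters` dict below is illustrative only):

import jax
from jax import numpy as jnp

# Illustrative mapping from primitive name to a simple equivalent.
converters = {
    "square": lambda x: x ** 2,
    "sqrt": lambda x: x ** 0.5,
}

closed = jax.make_jaxpr(jnp.square)(jnp.float32(3.0))
for eqn in closed.jaxpr.eqns:
    # The frontend dispatches on each equation's primitive name, e.g. "square".
    print(eqn.primitive.name, eqn.primitive.name in converters)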
6 changes: 3 additions & 3 deletions tests/constraints.txt
@@ -21,11 +21,11 @@ pytest>=5.0,<8.4
pytest-dependency==0.5.1
pytest-html==4.1.1
pytest-timeout==2.3.1
-jax<=0.4.35
-jaxlib<=0.4.35
+jax<=0.4.36
+jaxlib<=0.4.36
 kornia==0.7.0
 networkx<=3.3
-flax<=0.10.0
+flax<=0.10.2

--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
44 changes: 44 additions & 0 deletions tests/layer_tests/jax_tests/test_square.py
@@ -0,0 +1,44 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(34455)


class TestSquare(JaxLayerTest):
def _prepare_input(self):
if np.issubdtype(self.input_type, np.floating):
x = rng.uniform(-8.0, 8.0, self.input_shape).astype(self.input_type)
elif np.issubdtype(self.input_type, np.signedinteger):
x = rng.integers(-8, 8, self.input_shape).astype(self.input_type)
else:
x = rng.integers(0, 8, self.input_shape).astype(self.input_type)
x = jnp.array(x)
return [x]

def create_model(self, input_shape, input_type):
self.input_shape = input_shape
self.input_type = input_type

def jax_square(x):
return jax.numpy.square(x)

return jax_square, None, None

@pytest.mark.parametrize("input_shape", [[2], [3, 4]])
@pytest.mark.parametrize("input_type", [np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64,
np.float16, np.float32, np.float64])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_square(self, ie_device, precision, ir_version, input_shape, input_type):
self._test(*self.create_model(input_shape, input_type),
ie_device, precision,
ir_version)
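For reference, a minimal sketch of invoking just this layer test through pytest (assuming the OpenVINO layer-test environment that these tests rely on is already configured):

import pytest

# Select only the precommit JAX FE subset of the new test module.
pytest.main(["tests/layer_tests/jax_tests/test_square.py", "-m", "precommit_jax_fe", "-v"])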
