-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX FE] Support square operation (#27978)
**Details:** It appears since JAX 0.4.36 **Ticket:** 158994 Signed-off-by: Kazantsev, Roman <[email protected]>
- Loading branch information
Showing
4 changed files
with
77 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/power.hpp" | ||
#include "openvino/op/squeeze.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_square(const NodeContext& context) { | ||
num_inputs_check(context, 1, 1); | ||
auto x = context.get_input(0); | ||
auto const_two = create_same_type_const_scalar<int64_t>(x, 2); | ||
return {std::make_shared<v1::Power>(x, const_two)}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(34455) | ||
|
||
|
||
class TestSquare(JaxLayerTest): | ||
def _prepare_input(self): | ||
if np.issubdtype(self.input_type, np.floating): | ||
x = rng.uniform(-8.0, 8.0, self.input_shape).astype(self.input_type) | ||
elif np.issubdtype(self.input_type, np.signedinteger): | ||
x = rng.integers(-8, 8, self.input_shape).astype(self.input_type) | ||
else: | ||
x = rng.integers(0, 8, self.input_shape).astype(self.input_type) | ||
x = jnp.array(x) | ||
return [x] | ||
|
||
def create_model(self, input_shape, input_type): | ||
self.input_shape = input_shape | ||
self.input_type = input_type | ||
|
||
def jax_square(x): | ||
return jax.numpy.square(x) | ||
|
||
return jax_square, None, None | ||
|
||
@pytest.mark.parametrize("input_shape", [[2], [3, 4]]) | ||
@pytest.mark.parametrize("input_type", [np.int8, np.uint8, np.int16, np.uint16, | ||
np.int32, np.uint32, np.int64, np.uint64, | ||
np.float16, np.float32, np.float64]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
@pytest.mark.precommit_jax_fe | ||
def test_square(self, ie_device, precision, ir_version, input_shape, input_type): | ||
self._test(*self.create_model(input_shape, input_type), | ||
ie_device, precision, | ||
ir_version) |