Skip to content

Commit

Permalink
[ONNX] Update Range to accept single value tensor (openvinotoolkit#…
Browse files Browse the repository at this point in the history
…21741)

* Update range.cpp
* Update test_backend.py

---------

Co-authored-by: Georgy Krivoruchko <[email protected]>
  • Loading branch information
siddhant-0707 and gkrivor authored Dec 31, 2023
1 parent ded4938 commit 387f453
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
27 changes: 24 additions & 3 deletions src/frontends/onnx/frontend/src/op/range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,37 @@
#include <memory>

#include "default_opset.hpp"
#include "exceptions.hpp"

OPENVINO_SUPPRESS_DEPRECATED_START
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector range(const Node& node) {
const Output<ngraph::Node> start{node.get_ng_inputs().at(0)};
const Output<ngraph::Node> stop{node.get_ng_inputs().at(1)};
const Output<ngraph::Node> step{node.get_ng_inputs().at(2)};
const auto inputs = node.get_ng_inputs();
CHECK_VALID_NODE(node, inputs.size() >= 3, "Minimum 3 inputs are required. Got: ", inputs.size());

Output<ngraph::Node> start{inputs[0]};
Output<ngraph::Node> stop{inputs[1]};
Output<ngraph::Node> step{inputs[2]};

auto axes =
std::make_shared<default_opset::Constant>(ngraph::element::i64, ngraph::Shape{}, std::vector<int64_t>{0});

// Check if step is a tensor with a single value
if (start.get_shape().size() == 1 && start.get_shape()[0] == 1) {
start = std::make_shared<default_opset::Squeeze>(start, axes);
}

if (stop.get_shape().size() == 1 && stop.get_shape()[0] == 1) {
stop = std::make_shared<default_opset::Squeeze>(stop, axes);
}

if (step.get_shape().size() == 1 && step.get_shape()[0] == 1) {
step = std::make_shared<default_opset::Squeeze>(step, axes);
}

return {std::make_shared<default_opset::Range>(start, stop, step, start.get_element_type())};
}
} // namespace set_1
Expand Down
4 changes: 0 additions & 4 deletions src/frontends/onnx/tests/tests_python/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,13 +698,9 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None
(
xfail_issue_125485,
"OnnxBackendNodeModelTest.test_affine_grid_2d_align_corners_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_2d_align_corners_expanded_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_2d_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_2d_expanded_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_3d_align_corners_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_3d_align_corners_expanded_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_3d_cpu",
"OnnxBackendNodeModelTest.test_affine_grid_3d_expanded_cpu",
),
(
xfail_issue_125486,
Expand Down

0 comments on commit 387f453

Please sign in to comment.