diff --git a/src/core/include/openvino/op/hswish.hpp b/src/core/include/openvino/op/hswish.hpp index 34cff2955b5ab6..fc2130b56d9655 100644 --- a/src/core/include/openvino/op/hswish.hpp +++ b/src/core/include/openvino/op/hswish.hpp @@ -25,12 +25,8 @@ class OPENVINO_API HSwish : public util::UnaryElementwiseArithmetic { /// \param data Input tensor HSwish(const Output& arg); - bool visit_attributes(AttributeVisitor& visitor) override; - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - OPENVINO_SUPPRESS_DEPRECATED_START - bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override; - OPENVINO_SUPPRESS_DEPRECATED_END + bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override; bool has_evaluate() const override; }; } // namespace v4 diff --git a/src/core/src/op/hswish.cpp b/src/core/src/op/hswish.cpp index b509ecb95aabd1..fd2d89896c0460 100644 --- a/src/core/src/op/hswish.cpp +++ b/src/core/src/op/hswish.cpp @@ -2,78 +2,64 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/op/hswish.hpp" - -#include +#include "openvino/op/hswish.hpp" +#include "element_visitor.hpp" #include "itt.hpp" -#include "ngraph/attribute_visitor.hpp" -#include "ngraph/runtime/host_tensor.hpp" #include "openvino/reference/hswish.hpp" -using namespace std; -using namespace ngraph; - -op::v4::HSwish::HSwish(const Output& arg) : UnaryElementwiseArithmetic(arg) { +namespace ov { +namespace op { +namespace v4 { +HSwish::HSwish(const Output& arg) : UnaryElementwiseArithmetic(arg) { constructor_validate_and_infer_types(); } -bool op::v4::HSwish::visit_attributes(AttributeVisitor& visitor) { - OV_OP_SCOPE(v4_HSwish_visit_attributes); - return true; -} - -shared_ptr op::v4::HSwish::clone_with_new_inputs(const OutputVector& new_args) const { +std::shared_ptr HSwish::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v4_HSwish_clone_with_new_inputs); - return make_shared(new_args.at(0)); + return std::make_shared(new_args.at(0)); } -OPENVINO_SUPPRESS_DEPRECATED_START namespace hswish { namespace { -template -inline bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count) { - using T = typename element_type_traits::value_type; - - ov::reference::hswish(arg->get_data_ptr(), out->get_data_ptr(), count); - return true; -} +struct Evaluate : element::NoAction { + using element::NoAction::visit; -bool evaluate_hswish(const HostTensorPtr& arg, const HostTensorPtr& out) { - bool rc = true; - size_t count = shape_size(arg->get_shape()); - out->set_unary(arg); - - switch (arg->get_element_type()) { - OPENVINO_TYPE_CASE(evaluate_hswish, bf16, arg, out, count); - OPENVINO_TYPE_CASE(evaluate_hswish, f16, arg, out, count); - OPENVINO_TYPE_CASE(evaluate_hswish, f32, arg, out, count); - default: - rc = false; - break; + template > + static result_type visit(const Tensor& in, Tensor& out, const size_t count) { + ov::reference::hswish(in.data(), out.data(), count); + return true; } - return rc; -} +}; } // namespace } // namespace hswish -bool op::v4::HSwish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const { +bool HSwish::evaluate(TensorVector& outputs, const TensorVector& inputs) const { OV_OP_SCOPE(v4_HSwish_evaluate); - OPENVINO_SUPPRESS_DEPRECATED_START - OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 1)); - OPENVINO_SUPPRESS_DEPRECATED_END - return hswish::evaluate_hswish(inputs[0], outputs[0]); + OPENVINO_ASSERT(inputs.size() == 1); + OPENVINO_ASSERT(outputs.size() == 1); + + const auto& input_shape = inputs[0].get_shape(); + const auto count = shape_size(input_shape); + outputs[0].set_shape(input_shape); + using namespace ov::element; + return IfTypeOf::apply(inputs[0].get_element_type(), + inputs[0], + outputs[0], + count); } -bool op::v4::HSwish::has_evaluate() const { +bool HSwish::has_evaluate() const { OV_OP_SCOPE(v4_HSwish_has_evaluate); switch (get_input_element_type(0)) { - case ngraph::element::bf16: - case ngraph::element::f16: - case ngraph::element::f32: + case element::bf16: + case element::f16: + case element::f32: return true; default: - break; + return false; } - return false; } +} // namespace v4 +} // namespace op +} // namespace ov