Skip to content

Commit

Permalink
[PT FE] Support aten::log_sigmoid (openvinotoolkit#23200)
Browse files Browse the repository at this point in the history
### Details:
 - *Support `aten::log_sigmoid` in TS*
 - *Support `aten.expm1.default` and `aten.erfc.default` in FX*
 - *Unify unary ops testing*

### Tickets:
 - *CVS-134328*
  • Loading branch information
mvafin authored Mar 4, 2024
1 parent ca22b50 commit e269340
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 462 deletions.
50 changes: 27 additions & 23 deletions src/frontends/pytorch/src/op/log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/reduce_sum.hpp"
Expand All @@ -21,44 +22,47 @@ namespace op {

using namespace ov::op;

OutputVector translate_log(const NodeContext& context) {
// torch.log returns a tensor with the natural logarithm of the elements of input.
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
auto log = context.mark_node(std::make_shared<v0::Log>(x));
return {log};
};

OutputVector translate_log_sigmoid(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
auto sigmoid = context.mark_node(std::make_shared<v0::Sigmoid>(x));
auto op_vector = op::translate_1to1_match_1_inputs_with_fp32_type_alignment<v0::Sigmoid>(context);
PYTORCH_OP_CONVERSION_CHECK(op_vector.size() == 1,
"Expected exactly one element in the vector. Got: ",
op_vector.size());
auto sigmoid = op_vector[0];
auto log = context.mark_node(std::make_shared<v0::Log>(sigmoid));
return {log};
};

OutputVector translate_log2(const NodeContext& context) {
// torch.log2 returns a tensor with the logarithm to the base 2 of the elements of input.
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
num_inputs_check(context, 1, 2);
auto op_vector = op::translate_1to1_match_1_inputs_with_fp32_type_alignment<v0::Log>(context);
PYTORCH_OP_CONVERSION_CHECK(op_vector.size() == 1,
"Expected exactly one element in the vector. Got: ",
op_vector.size());
auto log = op_vector[0];

auto two = context.mark_node(v0::Constant::create(element::f32, Shape{}, {2}));
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
two = context.mark_node(std::make_shared<v1::ConvertLike>(two, log));
auto log2 = context.mark_node(std::make_shared<v0::Log>(two));
auto log = context.mark_node(std::make_shared<v0::Log>(x));

auto res = context.mark_node(std::make_shared<v1::Divide>(log, log2));
return {res};
};

OutputVector translate_log10(const NodeContext& context) {
// torch.log10 returns a tensor with the logarithm to the base 10 of the elements of input.
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
num_inputs_check(context, 1, 2);
auto op_vector = op::translate_1to1_match_1_inputs_with_fp32_type_alignment<v0::Log>(context);
PYTORCH_OP_CONVERSION_CHECK(op_vector.size() == 1,
"Expected exactly one element in the vector. Got: ",
op_vector.size());
auto log = op_vector[0];

auto ten = context.mark_node(v0::Constant::create(element::f32, Shape{}, {10}));
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
ten = context.mark_node(std::make_shared<v1::ConvertLike>(ten, log));
auto log10 = context.mark_node(std::make_shared<v0::Log>(ten));
auto log = context.mark_node(std::make_shared<v0::Log>(x));

auto res = context.mark_node(std::make_shared<v1::Divide>(log, log10));
return {res};
};
Expand All @@ -80,10 +84,10 @@ OutputVector translate_logsumexp(const NodeContext& context) {

OutputVector translate_log1p(const NodeContext& context) {
// torch.log1p returns a tensor with the natural logarithm of the elements of input + 1.
num_inputs_check(context, 1, 1);
num_inputs_check(context, 1, 2);
auto x = context.get_input(0);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}))->output(0);
align_eltwise_input_types(context, x, one);
auto x_plus_one = context.mark_node(std::make_shared<v1::Add>(x, one));
auto log = context.mark_node(std::make_shared<v0::Log>(x_plus_one));
return {log};
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/reciprocal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace op {
using namespace ov::op;

OutputVector translate_reciprocal(const NodeContext& context) {
num_inputs_check(context, 1, 1);
num_inputs_check(context, 1, 2);
auto x = context.get_input(0);
auto const_neg_1 = context.mark_node(v0::Constant::create(element::f32, Shape{}, {-1}))->output(0);
align_eltwise_input_types(context, x, const_neg_1, true);
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/rsqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace op {
using namespace ov::op;

OutputVector translate_rsqrt(const NodeContext& context) {
num_inputs_check(context, 1, 1);
num_inputs_check(context, 1, 2);
auto data = context.get_input(0);
auto one_const = context.mark_node(v0::Constant::create(element::f32, Shape({}), {1}));
Output<Node> fake_const_for_type = context.mark_node(v0::Constant::create(element::f32, Shape({}), {.5}));
Expand Down
62 changes: 34 additions & 28 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ OP_CONVERTER(translate_linear);
OP_CONVERTER(translate_linspace);
OP_CONVERTER(translate_list_construct);
OP_CONVERTER(translate_list_unpack);
OP_CONVERTER(translate_log);
OP_CONVERTER(translate_log1p);
OP_CONVERTER(translate_log_sigmoid);
OP_CONVERTER(translate_log_softmax);
Expand Down Expand Up @@ -305,11 +304,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
{"aten::_upsample_bilinear2d_aa", op::translate_upsample_bilinear2d_aa},
{"aten::_weight_norm", op::translate_weight_norm},
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"aten::abs", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Abs>, 1>},
{"aten::abs_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Abs>>},
{"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
{"aten::acos", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>, 1>},
{"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
{"aten::acosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
{"aten::acosh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>, 1>},
{"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acosh>>},
{"aten::adaptive_avg_pool1d", op::quantizable_op<op::translate_adaptive_avg_pool1d>},
{"aten::adaptive_avg_pool2d", op::quantizable_op<op::translate_adaptive_avg_pool2d>},
Expand All @@ -333,13 +333,15 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::argsort", op::translate_argsort},
{"aten::as_strided", op::translate_as_strided},
{"aten::as_tensor", op::translate_as_tensor},
{"aten::asin", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asin>},
{"aten::asin", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asin>, 1>},
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asin>>},
{"aten::asinh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
{"aten::asinh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>, 1>},
{"aten::asinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asinh>>},
{"aten::atan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
{"aten::atan", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>, 1>},
{"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atan>>},
{"aten::atanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
{"aten::atanh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
Expand All @@ -356,7 +358,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::broadcast_to", op::translate_expand},
{"aten::cat", op::translate_cat},
{"aten::cdist", op::translate_cdist},
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
{"aten::ceil", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Ceiling>, 1>},
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
{"aten::channel_shuffle", op::translate_channel_shuffle},
// aten::chunk - Supported in limited set of patterns
Expand All @@ -380,9 +382,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::convolution", op::translate_convolution},
{"aten::copy", op::skip_node},
{"aten::copy_", op::translate_copy_},
{"aten::cos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cos>},
{"aten::cos", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cos>, 1>},
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cos>>},
{"aten::cosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cosh>},
{"aten::cosh", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cosh>, 1>},
{"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cosh>>},
{"aten::cross", op::translate_cross},
{"aten::cumsum", op::translate_cumsum},
Expand All @@ -404,7 +406,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::erf_", op::inplace_op<op::translate_erf>},
{"aten::erfc", op::translate_erfc},
{"aten::erfc_", op::inplace_op<op::translate_erfc>},
{"aten::exp", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
{"aten::exp", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>, 1>},
{"aten::exp_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>>},
{"aten::expand", op::translate_expand},
{"aten::expand_as", op::translate_expand_as},
Expand All @@ -421,7 +423,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::fill_diagonal_", op::inplace_op<op::translate_fill_diagonal>},
{"aten::flatten", op::quantizable_op<op::translate_flatten>},
{"aten::flip", op::translate_flip},
{"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
{"aten::floor", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Floor>, 1>},
{"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
{"aten::floor_divide", op::translate_floor_divide},
{"aten::floordiv", op::translate_floor_divide},
Expand Down Expand Up @@ -475,18 +477,19 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::linalg_vector_norm", op::translate_linalg_vector_norm},
{"aten::linear", op::translate_linear},
{"aten::linspace", op::translate_linspace},
{"aten::log", op::translate_log},
{"aten::log_", op::inplace_op<op::translate_log>},
{"aten::log", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Log>, 1>},
{"aten::log_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Log>>},
{"aten::logical_and", op::translate_and},
{"aten::logical_or", op::translate_or},
{"aten::logical_not", op::translate_not},
{"aten::logical_xor", op::translate_xor},
{"aten::log_sigmoid", op::translate_log_sigmoid},
{"aten::log_softmax", op::translate_log_softmax},
{"aten::log1p", op::translate_log1p},
{"aten::log1p", op::optional_out<op::translate_log1p, 1>},
{"aten::log1p_", op::inplace_op<op::translate_log1p>},
{"aten::log2", op::translate_log2},
{"aten::log2", op::optional_out<op::translate_log2, 1>},
{"aten::log2_", op::inplace_op<op::translate_log2>},
{"aten::log10", op::translate_log10},
{"aten::log10", op::optional_out<op::translate_log10, 1>},
{"aten::log10_", op::inplace_op<op::translate_log10>},
{"aten::lstm", op::translate_lstm},
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
Expand Down Expand Up @@ -548,10 +551,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::randn", op::translate_randn},
{"aten::randn_like", op::translate_randn_like},
// aten::real - Supported in limited set of patterns
{"aten::reciprocal", op::translate_reciprocal},
{"aten::reciprocal", op::optional_out<op::translate_reciprocal, 1>},
{"aten::reciprocal_", op::inplace_op<op::translate_reciprocal>},
// aten::reflection_pad2d - Supported in limited set of patterns
{"aten::relu", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten::relu", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Relu>, 1>},
{"aten::relu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten::relu6", op::translate_relu6},
{"aten::relu6_", op::inplace_op<op::translate_relu6>},
Expand All @@ -569,7 +572,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::rnn_tanh", op::translate_rnn},
{"aten::roll", op::translate_roll},
{"aten::round", op::translate_round},
{"aten::rsqrt", op::translate_rsqrt},
{"aten::rsqrt", op::optional_out<op::translate_rsqrt, 1>},
{"aten::rsub", op::translate_rsub},
{"aten::ScalarImplicit", op::skip_node},
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
Expand All @@ -582,14 +585,15 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::select", op::quantizable_op<op::translate_select>},
{"aten::selu", op::translate_selu},
{"aten::selu_", op::inplace_op<op::translate_selu>},
{"aten::sigmoid", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
{"aten::sigmoid",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>, 1>},
{"aten::sigmoid_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sigmoid>>},
{"aten::sign", op::translate_sign},
{"aten::silu", op::translate_1to1_match_1_inputs<opset10::Swish>},
{"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
{"aten::sin", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sin>},
{"aten::sin", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sin>, 1>},
{"aten::sin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sin>>},
{"aten::sinh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sinh>},
{"aten::sinh", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sinh>, 1>},
{"aten::sinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sinh>>},
{"aten::size", op::translate_size},
{"aten::slice", op::quantizable_op<op::translate_slice>},
Expand All @@ -598,7 +602,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::sort", op::translate_sort},
// aten::split - Supported in limited set of patterns
// aten::split_with_sizes - Supported in limited set of patterns
{"aten::sqrt", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>},
{"aten::sqrt", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>, 1>},
{"aten::square", op::translate_square},
{"aten::squeeze", op::quantizable_op<op::translate_squeeze>},
// aten::stack - Supported in limited set of patterns
Expand All @@ -611,9 +615,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::t", op::translate_t},
{"aten::t_", op::inplace_op<op::translate_t>},
{"aten::take_along_dim", op::translate_take_along_dim},
{"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>},
{"aten::tan", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>, 1>},
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tan>>},
{"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>},
{"aten::tanh", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>, 1>},
{"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tanh>>},
{"aten::tensor", op::translate_as_tensor},
// aten::tensor_split - Supported in limited set of patterns
Expand Down Expand Up @@ -755,7 +759,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.eq.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten.eq.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten.erf.default", op::translate_erf},
{"aten.erfc.default", op::translate_erfc},
{"aten.exp.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
{"aten.expm1.default", op::translate_expm1},
{"aten.expand.default", op::translate_expand_fx},
{"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx},
{"aten.fill.Scalar", op::translate_fill},
Expand Down Expand Up @@ -788,7 +794,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.leaky_relu_.default", op::inplace_op<op::translate_leaky_relu_fx>},
{"aten.lift_fresh_copy.default", op::skip_node},
{"aten.linalg_vector_norm.default", op::translate_linalg_vector_norm},
{"aten.log.default", op::translate_log},
{"aten.log.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Log>},
{"aten.log_sigmoid_forward.default", op::translate_log_sigmoid},
{"aten.log10.default", op::translate_log10},
{"aten.log1p.default", op::translate_log1p},
Expand Down
11 changes: 11 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ OutputVector inplace_op(const NodeContext& context) {
return translation_res;
}

template <OutputVector (*T)(const NodeContext&), size_t idx>
OutputVector optional_out(const NodeContext& context) {
auto translation_res = T(context);
if (!context.input_is_none(idx)) {
FRONT_END_OP_CONVERSION_CHECK(translation_res.size() == 1,
"inplace_op function must be used on single output translators");
context.mutate_input(idx, translation_res[0]);
}
return translation_res;
}

template <typename T>
OutputVector translate_1to1_match_1_inputs(const NodeContext& context) {
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None.");
Expand Down
Loading

0 comments on commit e269340

Please sign in to comment.