From 238a3f0e9b862600f9c6791cf9c6c5351009c593 Mon Sep 17 00:00:00 2001 From: ynimmaga Date: Wed, 6 Mar 2024 15:03:45 -0800 Subject: [PATCH 1/2] Added embedding_bag and fixed unbind int --- .../pytorch/torchdynamo/op_support.py | 1 + .../pytorch/src/op/embedding_bag.cpp | 20 +++++++++++++++---- src/frontends/pytorch/src/op/split.cpp | 14 +++++++++---- src/frontends/pytorch/src/op_table.cpp | 2 ++ 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 0b4e69624c4aaa..d2fda2a67267c7 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -33,6 +33,7 @@ def __init__(self, options): "torch.ops.aten._adaptive_avg_pool2d.default": None, "torch.ops.aten._adaptive_avg_pool3d.default": None, "torch.ops.aten._convolution.default": None, + "torch.ops.aten._embedding_bag.default": None, "torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None, "torch.ops.aten._local_scalar_dense.default": None, "torch.ops.aten._log_softmax.default": None, diff --git a/src/frontends/pytorch/src/op/embedding_bag.cpp b/src/frontends/pytorch/src/op/embedding_bag.cpp index 4560ea2a09db4f..999f597f196324 100644 --- a/src/frontends/pytorch/src/op/embedding_bag.cpp +++ b/src/frontends/pytorch/src/op/embedding_bag.cpp @@ -15,10 +15,9 @@ namespace frontend { namespace pytorch { namespace op { -OutputVector translate_embedding_bag(const NodeContext& context) { +OutputVector translate_embedding_bag_common(const NodeContext& context) { // aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False, // per_sample_weights=None, include_last_offset=False, padding_idx=None) - num_inputs_check(context, 9, 9); // we have only EmbeddingBagSum case support, check it before translation auto mode = context.const_input(4); PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation"); @@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) { // with offsets case auto offsets = context.get_input(2); offsets = context.mark_node(std::make_shared(offsets, element::i32)); - auto include_last_offset = context.const_input(7); + bool include_last_offset = false; + if (!context.input_is_none(7)) + include_last_offset = context.const_input(7); PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported"); // no per_sample_wights if (context.input_is_none(6)) { @@ -63,7 +64,18 @@ OutputVector translate_embedding_bag(const NodeContext& context) { return {result, zero, zero, zero}; }; +OutputVector translate_embedding_bag(const NodeContext& context) { + num_inputs_check(context, 9, 9); + return translate_embedding_bag_common(context); +} + +OutputVector translate_embedding_bag_fx(const NodeContext& context) { + num_inputs_check(context, 7, 9); + ov::OutputVector output = translate_embedding_bag_common(context); + return {context.mark_node(make_list_construct(output))}; +} + } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op/split.cpp b/src/frontends/pytorch/src/op/split.cpp index b8345a0b4a9700..45689ccb695e59 100644 --- a/src/frontends/pytorch/src/op/split.cpp +++ b/src/frontends/pytorch/src/op/split.cpp @@ -37,12 +37,18 @@ OutputVector translate_chunk_fx(const NodeContext& context) { } OutputVector translate_unbind_int_fx(const NodeContext& context) { - num_inputs_check(context, 2, 3); + num_inputs_check(context, 1, 3); auto input = context.get_input(0); - auto dim = context.get_input(1); - auto dim_val = context.const_input(1); + Output dim; + int64_t dim_val = 0; + if (context.input_is_none(1)) { + dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + } else { + dim = context.get_input(1); + dim_val = context.const_input(1); + } + auto shape = input.get_shape(); - if (dim_val < 0) { dim_val = static_cast(shape.size()) + dim_val; } diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 3f91f55ae42272..16e879ead9cfe4 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -253,6 +253,7 @@ OP_CONVERTER(translate_constant_pad_nd_fx); OP_CONVERTER(translate_cumsum_fx); OP_CONVERTER(translate_chunk_fx); OP_CONVERTER(translate_div_fx); +OP_CONVERTER(translate_embedding_bag_fx); OP_CONVERTER(translate_expand_fx); OP_CONVERTER(translate_fake_quantize_per_channel_affine_fx); OP_CONVERTER(translate_fake_quantize_per_tensor_affine_fx); @@ -691,6 +692,7 @@ const std::map get_supported_ops_fx() { {"aten._adaptive_avg_pool2d.default", op::translate_adaptive_avg_pool2d}, {"aten._adaptive_avg_pool3d.default", op::translate_adaptive_avg_pool3d}, {"aten._convolution.default", op::translate_convolution}, + {"aten._embedding_bag.default", op::translate_embedding_bag_fx}, {"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default", op::translate_fake_quantize_per_tensor_affine_fx}, {"aten._local_scalar_dense.default", op::skip_node}, From 255a78b4e0d1dd1c718615845a2c06a9b814f285 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 6 Mar 2024 15:19:44 -0800 Subject: [PATCH 2/2] Code style fix src/frontends/pytorch/src/op/split.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/frontends/pytorch/src/op/split.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/split.cpp b/src/frontends/pytorch/src/op/split.cpp index 45689ccb695e59..f43d8258859dd4 100644 --- a/src/frontends/pytorch/src/op/split.cpp +++ b/src/frontends/pytorch/src/op/split.cpp @@ -47,7 +47,6 @@ OutputVector translate_unbind_int_fx(const NodeContext& context) { dim = context.get_input(1); dim_val = context.const_input(1); } - auto shape = input.get_shape(); if (dim_val < 0) { dim_val = static_cast(shape.size()) + dim_val;