Scaled dot product attention (openvinotoolkit#20492)
* Added experimental ScaledDotProductAttention operation in opset12. Supported in the PyTorch frontend (PT FE) for aten::scaled_dot_product_attention translation. Decomposed in the common optimizations as a functional reference.

* Better ScaledDotProductAttention

- Moved decomposition to the decomposing transformation
- Implemented more ctors for the op
- Renamed is_causal to causal
- Shape/type inference native code instead of using decomposition
- Moved the op from opset12 to opset13
- Added Python wrapper for ScaledDotProductAttention

* Fix test that counts ops in the opsets

* Update src/core/src/op/scaled_dot_product_attention.cpp

Co-authored-by: Katarzyna Mitrus <[email protected]>

* Update src/core/src/op/scaled_dot_product_attention.cpp

Co-authored-by: Katarzyna Mitrus <[email protected]>

* Move ScaledDotProductAttentionDecomposition from fusions to decompositions.

* Remove not used legacy shape inference in ScaledDotProductAttention

* Better namespace usage

* Register all nodes in ScaledDotProductAttentionDecomposition for correct tracking of nodes and for running subsequent matcher passes on all new nodes.

* Don't use register_new_node_

* ScaledDotProductAttention specification (with an extra scale argument)

* Code style fix

* Scale input implementation for ScaledDotProductAttention

* Handle attention_mask=0 case in the op spec

* Better description of scale input

* N->M in scale description

* Code style fix, remove debug print.

* Apply suggestions from code review

Co-authored-by: Katarzyna Mitrus <[email protected]>
Co-authored-by: Mateusz Mikolajczyk <[email protected]>

* Fix for case when is_causal is not passed

* Extended description of ScaledDotProduct op

* Better description in py op wrapper

* Basic shape propagation tests for ScaledDotProductAttention

* Added ScaledDotProductAttention to toc.

* Add op impl check

---------

Co-authored-by: Katarzyna Mitrus <[email protected]>
Co-authored-by: Mateusz Mikolajczyk <[email protected]>
3 people authored Nov 8, 2023
1 parent f627172 commit 8541586
Showing 17 changed files with 688 additions and 98 deletions.
@@ -167,6 +167,7 @@ Table of Contents
* :doc:`ROIPooling <openvino_docs_ops_detection_ROIPooling_1>`
* :doc:`Roll <openvino_docs_ops_movement_Roll_7>`
* :doc:`Round <openvino_docs_ops_arithmetic_Round_5>`
* :doc:`ScaledDotProductAttention <openvino_docs_ops_sequence_ScaledDotProductAttention_13>`
* :doc:`ScatterElementsUpdate <openvino_docs_ops_movement_ScatterElementsUpdate_12>`
* :doc:`ScatterNDUpdate <openvino_docs_ops_movement_ScatterNDUpdate_3>`
* :doc:`ScatterUpdate <openvino_docs_ops_movement_ScatterUpdate_3>`
@@ -182,6 +182,7 @@
ROIPooling-1 <openvino_docs_ops_detection_ROIPooling_1>
Roll-7 <openvino_docs_ops_movement_Roll_7>
Round-5 <openvino_docs_ops_arithmetic_Round_5>
ScaledDotProductAttention-13 <openvino_docs_ops_sequence_ScaledDotProductAttention_13>
ScatterElementsUpdate-3 <openvino_docs_ops_movement_ScatterElementsUpdate_3>
ScatterElementsUpdate-12 <openvino_docs_ops_movement_ScatterElementsUpdate_12>
ScatterNDUpdate-3 <openvino_docs_ops_movement_ScatterNDUpdate_3>
@@ -0,0 +1,137 @@
# ScaledDotProductAttention {#openvino_docs_ops_sequence_ScaledDotProductAttention_13}

@sphinxdirective

.. meta::
   :description: Learn about ScaledDotProductAttention-13 - a basic block for the transformer attention mechanism.

**Versioned name**: *ScaledDotProductAttention-13*

**Category**: *Sequence processing*

**Short description**: *ScaledDotProductAttention* partially implements
`torch.nn.functional.scaled_dot_product_attention <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`__,
omitting training-related parameters.

**Detailed description**:

*ScaledDotProductAttention* provides functionality equivalent to the following pseudo-code, expressed with other operations from the OpenVINO opset and ``numpy``:

.. code-block:: py

    def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, causal):
        L, S = Gather(ShapeOf(query), -2), Gather(ShapeOf(key), -2)
        if scale is None:
            scale = 1.0 / Sqrt(ConvertLike(Gather(ShapeOf(query), -1), query))
        attn_bias = Broadcast(ConvertLike(0, query), [L, S])
        if causal:
            attn_bias = numpy.triu(Broadcast(ConvertLike(-inf, query), [L, S]), k=1)
        elif attn_mask is not None:
            if attn_mask.element_type == boolean:
                attn_bias = Select(LogicalNot(attn_mask), ConvertLike(-inf, query), ConvertLike(0, query))
            else:
                attn_bias += attn_mask
        attn_weight = MatMul(query, Transpose(key, [-2, -1])) * scale
        attn_weight += attn_bias
        attn_weight = Softmax(attn_weight, axis=-1)
        return MatMul(attn_weight, value)

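For illustration only (not part of the specification), the same computation can be written as a self-contained NumPy sketch; the helper name ``sdpa_reference`` and its softmax formulation are ours and simply mirror the pseudo-code above:

.. code-block:: py

    import numpy as np

    def sdpa_reference(query, key, value, attn_mask=None, scale=None, *, causal=False):
        # Illustrative NumPy mirror of the pseudo-code; not a normative definition.
        L, S = query.shape[-2], key.shape[-2]
        if scale is None:
            scale = 1.0 / np.sqrt(query.shape[-1])
        attn_bias = np.zeros((L, S), dtype=query.dtype)
        if causal:
            # -inf strictly above the main diagonal hides future positions.
            attn_bias = np.triu(np.full((L, S), -np.inf, dtype=query.dtype), k=1)
        elif attn_mask is not None:
            if attn_mask.dtype == bool:
                # Boolean mask: True means "attend", False is replaced with -inf.
                attn_bias = np.where(attn_mask, 0.0, -np.inf).astype(query.dtype)
            else:
                # Float mask: added directly to the attention scores.
                attn_bias = attn_bias + attn_mask
        attn_weight = query @ np.swapaxes(key, -2, -1) * scale + attn_bias
        # Numerically stable softmax over the last axis.
        attn_weight = np.exp(attn_weight - attn_weight.max(axis=-1, keepdims=True))
        attn_weight = attn_weight / attn_weight.sum(axis=-1, keepdims=True)
        return attn_weight @ value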

**Attributes**

* *causal*

  * **Description**: If true, applies causal attention masking as shown in the pseudo-code above. In this case, the ``attention_mask`` input described below is ignored.
  * **Range of values**: a boolean value
  * **Type**: ``bool``
  * **Required**: *yes*


**Inputs**

* **1**: ``query`` - at least 3 dimensional tensor of type *T* and shape ``[N, ..., L, E]``. **Required.**

* **2**: ``key`` - at least 3 dimensional tensor of type *T* and shape ``[N, ..., S, E]``. **Required.**

* **3**: ``value`` - at least 3 dimensional tensor of type *T* and shape ``[N, ..., S, Ev]``. **Required.**

* **4**: ``attention_mask`` - two options are supported:

  * at least 3 dimensional tensor of type *T* or ``boolean`` and shape ``[M, ..., L, S]``, or
  * a scalar of type *T* with value ``0``. The scalar zero indicates that ``attention_mask`` is not applied (``attention_mask=None`` in the pseudo-code above) while ``scale`` still needs to be provided.

  ``attention_mask`` is ignored if ``causal`` is set to ``True``. **Optional.**

* **5**: ``scale`` - a scalar tensor of type *T*; an alternative scale factor that replaces the default ``1/sqrt(query.shape[-1])`` used in the pseudo-code above. **Optional.**


**Outputs**

* **1**: the result of scaled dot-product attention, a tensor of type *T* and shape ``[N, ..., L, Ev]``.

**Types**

* *T*: any supported floating-point type.


**Dimensions**

* ``N, ...`` - one or more batch dimensions

* ``S`` - source sequence length

* ``L`` - target sequence length

* ``E`` - embedding dimension of the query and key

* ``Ev`` - embedding dimension of the value

* ``M, ...`` - one or more batch dimensions of the mask; must be broadcastable to ``N, ...``

At least one batch dimension ``N`` is required and must match across the ``query``, ``key`` and ``value`` inputs.
The remaining batch dimensions ``...`` are optional; if present, they must also match across the ``query``, ``key`` and ``value`` inputs.


**Example**

.. code-block:: xml
   :force:

    <layer id="285" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
        <data causal="false" />
        <input>
            <port id="0" precision="FP32">
                <dim>1</dim>
                <dim>32</dim>
                <dim>-1</dim>
                <dim>80</dim>
            </port>
            <port id="1" precision="FP32">
                <dim>1</dim>
                <dim>32</dim>
                <dim>-1</dim>
                <dim>80</dim>
            </port>
            <port id="2" precision="FP32">
                <dim>1</dim>
                <dim>32</dim>
                <dim>-1</dim>
                <dim>80</dim>
            </port>
            <port id="3" precision="FP32">
                <dim>1</dim>
                <dim>1</dim>
                <dim>-1</dim>
                <dim>-1</dim>
            </port>
        </input>
        <output>
            <port id="4" precision="FP32">
                <dim>1</dim>
                <dim>32</dim>
                <dim>-1</dim>
                <dim>80</dim>
            </port>
        </output>
    </layer>

@endsphinxdirective
@@ -151,6 +151,7 @@
from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset13.ops import scaled_dot_product_attention
from openvino.runtime.opset12.ops import scatter_elements_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
39 changes: 37 additions & 2 deletions src/bindings/python/src/openvino/runtime/opset13/ops.py
@@ -136,7 +136,8 @@ def multinomial(
    inputs = as_nodes(probs, num_samples)

    if global_seed < 0:
-        raise RuntimeError(f"global_seed should be positive or 0. Got: {global_seed}")
+        raise RuntimeError(
+            f"global_seed should be positive or 0. Got: {global_seed}")

    if op_seed < 0:
        raise RuntimeError(f"op_seed should be positive or 0. Got: {op_seed}")
@@ -178,7 +179,8 @@ def nms_rotated(
    :param clockwise: Flag that specifies direction of the box rotation.
    :return: The new node which performs NMSRotated
    """
-    inputs = as_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
+    inputs = as_nodes(boxes, scores, max_output_boxes_per_class,
+                      iou_threshold, score_threshold)

    attributes = {
        "sort_result_descending": sort_result_descending,
@@ -187,3 +189,36 @@
    }

    return _get_node_factory_opset13().create("NMSRotated", inputs, attributes)


@nameable_op
def scaled_dot_product_attention(
    query: NodeInput,
    key: NodeInput,
    value: NodeInput,
    attention_mask: Optional[NodeInput] = None,
    scale: Optional[NodeInput] = None,
    causal: bool = False,
    name: Optional[str] = None,
) -> Node:
    """Return a node which implements Scaled Dot Product Attention.

    :param query: Query tensor of shape [N, ..., L, E] and floating-point datatype.
    :param key: Key tensor of shape [N, ..., S, E] and floating-point datatype.
    :param value: Value tensor of shape [N, ..., S, Ev] and floating-point datatype.
    :param attention_mask: Optional attention mask tensor of shape [N, ..., L, S] or scalar float type zero value.
                           Refer to the operation specification for a complete description.
    :param scale: Optional alternative scale, a floating-point type scalar.
    :param causal: If true, then autogenerates causal attention mask instead of using attention_mask input.
                   In this case attention_mask input is ignored.
    :param name: The optional new name for output node.
    :return: The new node performing Scaled Dot Product Attention operation.
    """
    inputs = as_nodes(query, key, value, attention_mask) if attention_mask is not None else as_nodes(
        query, key, value, scale)

    attributes = {
        "causal": causal,
    }
    return _get_node_factory_opset13().create("ScaledDotProductAttention", inputs, attributes)
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ScaledDotProductAttentionDecomposition;

} // namespace pass
} // namespace ov

class ov::pass::ScaledDotProductAttentionDecomposition : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ScaledDotProductAttentionDecomposition", "0");
ScaledDotProductAttentionDecomposition();
std::shared_ptr<ov::Node> decompose(std::shared_ptr<ov::op::v13::ScaledDotProductAttention> node);
};
@@ -107,6 +107,7 @@
#include "transformations/op_conversions/normalize_l2_decomposition.hpp"
#include "transformations/op_conversions/reduce_l1_decomposition.hpp"
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp"
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
#include "transformations/op_conversions/softmax_decomposition.hpp"
#include "transformations/op_conversions/softsign_decomposition.hpp"
@@ -145,6 +146,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_DISABLED_PASS(manager, ConvertInterpolate1ToInterpolate4)

auto decomp = manager.register_pass<GraphRewrite>();
ADD_MATCHER(decomp, ScaledDotProductAttentionDecomposition)
ADD_MATCHER(decomp, Gelu7Downgrade)
ADD_MATCHER(decomp, BidirectionalSequenceDecomposition)
ADD_MATCHER(decomp, ReduceL1Decomposition)
@@ -0,0 +1,136 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp"

#include <memory>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::ScaledDotProductAttentionDecomposition::ScaledDotProductAttentionDecomposition() {
MATCHER_SCOPE(ScaledDotProductAttentionDecomposition);
auto pattern_node = ov::pass::pattern::wrap_type<ov::op::v13::ScaledDotProductAttention>();

matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto node = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(
pattern_to_output.at(pattern_node).get_node_shared_ptr());

if (node == nullptr || transformation_callback(node)) {
return false;
}

auto new_output_node = decompose(node);
ov::replace_node(node, new_output_node);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, matcher_name);
register_matcher(m, callback);
}

std::shared_ptr<ov::Node> ov::pass::ScaledDotProductAttentionDecomposition::decompose(
std::shared_ptr<ov::op::v13::ScaledDotProductAttention> node) {
using namespace ov::op;
auto query = node->input_value(0);
auto key = node->input_value(1);
auto value = node->input_value(2);
auto q_shape = register_new_node<v3::ShapeOf>(query, element::i32);
auto k_shape = register_new_node<v3::ShapeOf>(key, element::i32);
auto minus_one = register_new_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto minus_two = register_new_node(v0::Constant::create(element::i32, Shape{}, {-2}));
auto zero_i = register_new_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one_i = register_new_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto one_f = register_new_node<v1::ConvertLike>(one_i, query);
auto zero_f = register_new_node<v1::ConvertLike>(zero_i, query);

Output<Node> scale;
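// If no scale input is provided, compute the default scale 1 / sqrt(E), where E is the last dimension of query.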
if (node->get_input_size() < 5) {
scale = register_new_node<v8::Gather>(q_shape, minus_one, zero_i)->output(0);
scale = register_new_node<v1::ConvertLike>(scale, query);
auto sqrt_scale = register_new_node<v0::Sqrt>(scale);
scale = register_new_node<v1::Divide>(one_f, sqrt_scale);
} else {
scale = node->input_value(4);
}

auto q_scaled = register_new_node<v1::Multiply>(query, scale);
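// Build a permutation that swaps the last two dimensions of key (Transpose(key, [-2, -1]) in the pseudo-code), derived dynamically from the rank of key.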
auto k_rank = register_new_node<v3::ShapeOf>(k_shape, element::i32)->output(0);
auto k_last_dim = register_new_node<v1::Add>(k_rank, minus_one);
auto k_next_dim = register_new_node<v1::Add>(k_rank, minus_two)->output(0);
k_rank = register_new_node<v0::Squeeze>(k_rank, zero_i);
auto minus_inf =
register_new_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}))
->output(0);
auto keep_dim_last = register_new_node<v0::Squeeze>(k_next_dim, zero_i);
auto k_dims_before_transpose = register_new_node<v4::Range>(zero_i, keep_dim_last, one_i, element::i32);

auto transpose_dims =
register_new_node<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0);
auto k_transposed = register_new_node<v1::Transpose>(key, transpose_dims);
auto scaled_atten = register_new_node<v0::MatMul>(q_scaled, k_transposed)->output(0);
minus_inf = register_new_node<v1::ConvertLike>(minus_inf, scaled_atten);

if (node->get_causal() || node->get_input_size() > 3) {
Output<Node> mask;
Output<Node> atten_mask;
if (!node->get_causal()) {
mask = node->input_value(3);

// two types of masks are supported. A boolean mask where a value of True indicates that the element should
// take part in attention. A float mask of the same type as query, key, value that is added to the attention
// score.
if (mask.get_element_type() == element::boolean) {
atten_mask = register_new_node<v1::ConvertLike>(mask, scaled_atten);
auto inv_mask = register_new_node<v1::LogicalNot>(mask);
atten_mask = register_new_node<v1::Select>(inv_mask, atten_mask, minus_inf);
} else {
atten_mask = mask;
}
} else {
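// Causal case: build an [L, S] mask with -inf strictly above the main diagonal (future positions) and zeros elsewhere.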
auto target_s_len = register_new_node<v8::Gather>(q_shape, minus_two, zero_i);
auto source_s_len = register_new_node<v8::Gather>(k_shape, minus_two, zero_i);
auto ssl = register_new_node<v0::Unsqueeze>(source_s_len, zero_i);
auto tsl = register_new_node<v0::Unsqueeze>(target_s_len, zero_i);
auto mask_shape = register_new_node<v0::Concat>(OutputVector{tsl, ssl}, 0);
mask = register_new_node<v1::Broadcast>(minus_inf, mask_shape);
auto horizontal_range = register_new_node<v4::Range>(zero_i, source_s_len, one_i, element::i32)->output(0);
horizontal_range = register_new_node<v0::Unsqueeze>(horizontal_range, zero_i);
auto stop = register_new_node<v1::Add>(target_s_len, one_i);
auto vertical_range = register_new_node<v4::Range>(one_i, stop, one_i, element::i32)->output(0);
vertical_range = register_new_node<v0::Unsqueeze>(vertical_range, one_i);
auto triu = register_new_node<v1::GreaterEqual>(horizontal_range, vertical_range);
atten_mask = register_new_node<v1::Select>(triu, mask, zero_f);
}
scaled_atten = register_new_node<v1::Add>(scaled_atten, atten_mask);
}

scaled_atten = register_new_node<v8::Softmax>(scaled_atten, -1);
auto result = register_new_node<v0::MatMul>(scaled_atten, value);
result->set_friendly_name(node->get_friendly_name());
copy_runtime_info(node, get_new_nodes());
return result;
}