forked from openvinotoolkit/openvino
Scaled dot product attention (openvinotoolkit#20492)
* Added experimental ScaledDotProductAttention operation in opset12. Supported in PT FE for aten::scaled_dot_product_attention translation. Decomposed in the common optimizations as a functional reference.
* Better ScaledDotProductAttention:
  - Moved decomposition to the decomposing transformation
  - Implemented more ctors for the op
  - Renamed is_causal to causal
  - Shape/type inference in native code instead of using decomposition
  - Moved the op from opset12 to opset13
  - Added Python wrapper for ScaledDotProductAttention
* Fix test that counts ops in the opsets
* Update src/core/src/op/scaled_dot_product_attention.cpp (review suggestions)
* Move ScaledDotProductAttentionDecomposition from fusions to decompositions
* Remove unused legacy shape inference in ScaledDotProductAttention
* Better namespace usage
* Register all nodes in ScaledDotProductAttentionDecomposition for correct tracking of nodes and running subsequent matcher passes on all new nodes
* Don't use register_new_node_
* ScaledDotProductAttention specification (with an extra scale argument)
* Code style fix
* Scale input implementation for ScaledDotProductAttention
* Handle attention_mask=0 case in the op spec
* Better description of scale input
* N->M in scale description
* Code style fix, remove debug print
* Apply suggestions from code review
* Fix for case when is_causal is not passed
* Extended description of ScaledDotProductAttention op
* Better description in py op wrapper
* Basic shape propagation tests for ScaledDotProductAttention
* Added ScaledDotProductAttention to toc
* Add op impl check

Co-authored-by: Katarzyna Mitrus <[email protected]>
Co-authored-by: Mateusz Mikolajczyk <[email protected]>
1 parent: f627172
Commit: 8541586
Showing 17 changed files with 688 additions and 98 deletions.
.../operation_sets/operations_specifications/sequence/ScaledDotProductAttention.md (137 changes: 137 additions & 0 deletions)
@@ -0,0 +1,137 @@
# ScaledDotProductAttention {#openvino_docs_ops_sequence_ScaledDotProductAttention_13}

@sphinxdirective

.. meta::
   :description: Learn about ScaledDotProductAttention-13 - a basic block for the transformer attention mechanism.

**Versioned name**: *ScaledDotProductAttention-13*

**Category**: *Sequence processing*

**Short description**: *ScaledDotProductAttention* partially implements
`torch.nn.functional.scaled_dot_product_attention <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`__,
omitting the training-related parameter.

**Detailed description**:

*ScaledDotProductAttention* provides functionality according to the following pseudo-code, written in terms of other operations from the OpenVINO opset and ``numpy``:

.. code-block:: py

   def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, causal):
       L, S = Gather(ShapeOf(query), -2), Gather(ShapeOf(key), -2)
       if scale is None:
           scale = 1.0 / Sqrt(ConvertLike(Gather(ShapeOf(query), -1), query))
       attn_bias = Broadcast(ConvertLike(0, query), [L, S])
       if causal:
           attn_bias = numpy.triu(Broadcast(ConvertLike(-inf, query), [L, S]), k=1)
       elif attn_mask is not None:
           if attn_mask.element_type == boolean:
               attn_bias = Select(LogicalNot(attn_mask), ConvertLike(-inf, query), ConvertLike(0, query))
           else:
               attn_bias += attn_mask
       attn_weight = MatMul(query, Transpose(key, [-2, -1])) * scale
       attn_weight += attn_bias
       attn_weight = Softmax(attn_weight, axis=-1)
       return MatMul(attn_weight, value)

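As an additional illustration, the same behaviour can be written as a small self-contained ``numpy`` function. This is only a runnable sketch of the pseudo-code above; the helper name ``sdpa_reference`` and the omission of broadcasting corner cases are assumptions of the sketch, not part of the operation definition:

.. code-block:: py

   import numpy as np

   def sdpa_reference(query, key, value, attn_mask=None, scale=None, *, causal=False):
       """Runnable numpy sketch of the ScaledDotProductAttention pseudo-code."""
       L, S = query.shape[-2], key.shape[-2]
       if scale is None:
           scale = 1.0 / np.sqrt(query.shape[-1])
       attn_bias = np.zeros((L, S), dtype=query.dtype)
       if causal:
           # -inf strictly above the diagonal: each position attends only to itself and earlier positions
           attn_bias = np.triu(np.full((L, S), -np.inf, dtype=query.dtype), k=1)
       elif attn_mask is not None:
           if attn_mask.dtype == np.bool_:
               # True means "take part in attention"; False positions are suppressed with -inf
               attn_bias = np.where(attn_mask, 0.0, -np.inf).astype(query.dtype)
           else:
               attn_bias = attn_bias + attn_mask
       attn_weight = np.matmul(query, np.swapaxes(key, -2, -1)) * scale
       attn_weight = attn_weight + attn_bias
       # numerically stable softmax over the last axis
       attn_weight = np.exp(attn_weight - attn_weight.max(axis=-1, keepdims=True))
       attn_weight = attn_weight / attn_weight.sum(axis=-1, keepdims=True)
       return np.matmul(attn_weight, value)
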
**Attributes**

* *causal*

  * **Description**: If true, assumes causal attention masking according to the pseudo-code. In this case the ``attention_mask`` input described below is ignored (illustrated in the example below).
  * **Range of values**: a boolean value
  * **Type**: ``bool``
  * **Required**: *yes*

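For illustration, using the ``sdpa_reference`` sketch above, enabling *causal* behaves like passing an explicit lower-triangular boolean mask whenever ``L == S`` (shapes chosen arbitrarily for the example):

.. code-block:: py

   import numpy as np

   q = k = v = np.random.rand(1, 1, 5, 4).astype(np.float32)  # L == S == 5, E == Ev == 4
   causal_out = sdpa_reference(q, k, v, causal=True)
   # Equivalent lower-triangular boolean mask: True marks positions allowed to attend
   tril_mask = np.tril(np.ones((5, 5), dtype=bool))[None, ...]
   masked_out = sdpa_reference(q, k, v, attn_mask=tril_mask)
   print(np.allclose(causal_out, masked_out))  # True
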
**Inputs**

* **1**: ``query`` - at least 3-dimensional tensor of type *T* and shape ``[N, ..., L, E]``. **Required.**

* **2**: ``key`` - at least 3-dimensional tensor of type *T* and shape ``[N, ..., S, E]``. **Required.**

* **3**: ``value`` - at least 3-dimensional tensor of type *T* and shape ``[N, ..., S, Ev]``. **Required.**

* **4**: ``attention_mask`` - two options:

  * at least 3-dimensional tensor of type *T* or ``boolean`` and shape ``[M, ..., L, S]``, or
  * a scalar of type *T* with value ``0``. The scalar zero value signals that ``attention_mask`` should not be applied (``attention_mask=None`` in the pseudo-code above) while ``scale`` still needs to be set.

  ``attention_mask`` is ignored if ``causal`` is set to ``True``. Both options are illustrated in the example after this list. **Optional.**

* **5**: ``scale`` - a scalar tensor of type *T*, an alternative scale factor used instead of the default ``1/sqrt(query.shape[-1])`` from the pseudo-code above. **Optional.**

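To make the two ``attention_mask`` options concrete, here is how they map onto the ``sdpa_reference`` sketch above (shapes are hypothetical; at the operation level the second option corresponds to passing a scalar ``0`` mask together with ``scale``):

.. code-block:: py

   import numpy as np

   q = np.random.rand(1, 2, 4, 8).astype(np.float32)   # [N, ..., L, E]
   k = np.random.rand(1, 2, 6, 8).astype(np.float32)   # [N, ..., S, E]
   v = np.random.rand(1, 2, 6, 16).astype(np.float32)  # [N, ..., S, Ev]

   # Option 1: boolean mask of shape [M, ..., L, S]; True marks positions that take part in attention.
   bool_mask = np.tril(np.ones((4, 6), dtype=bool))[None, ...]
   out1 = sdpa_reference(q, k, v, attn_mask=bool_mask)

   # Option 2: no masking (attention_mask=0 at the operation level), but an explicit scale is supplied.
   out2 = sdpa_reference(q, k, v, attn_mask=None, scale=0.25)

   print(out1.shape, out2.shape)  # (1, 2, 4, 16) (1, 2, 4, 16)
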
**Outputs**

* **1**: the result of scaled dot-product attention, a tensor of type *T* and shape ``[N, ..., L, Ev]``.

**Types**

* *T*: any supported floating-point type.

**Dimensions**

* ``N, ...`` - one or more batch dimensions

* ``S`` - source sequence length

* ``L`` - target sequence length

* ``E`` - embedding dimension of the query and key

* ``Ev`` - embedding dimension of the value

* ``M, ...`` - one or more batch dimensions of the mask, which should be broadcastable to ``N, ...``

At least one batch dimension ``N`` is required and should match among the ``query``, ``key`` and ``value`` inputs.
Other batch dimensions ``...`` are optional; if present, they should match among the ``query``, ``key`` and ``value`` inputs as well.
A concrete shape walk-through is given below.

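A concrete (hypothetical) shape walk-through with the ``sdpa_reference`` sketch above, using ``N = 2``, one extra batch dimension of size ``8``, ``L = 10``, ``S = 12``, ``E = 64`` and ``Ev = 32``:

.. code-block:: py

   import numpy as np

   N, H, L, S, E, Ev = 2, 8, 10, 12, 64, 32
   query = np.zeros((N, H, L, E), dtype=np.float32)   # [N, ..., L, E]
   key = np.zeros((N, H, S, E), dtype=np.float32)     # [N, ..., S, E]
   value = np.zeros((N, H, S, Ev), dtype=np.float32)  # [N, ..., S, Ev]
   mask = np.zeros((1, 1, L, S), dtype=np.float32)    # [M, ..., L, S], broadcastable to [N, ...]

   out = sdpa_reference(query, key, value, attn_mask=mask)
   print(out.shape)  # (2, 8, 10, 32), i.e. [N, ..., L, Ev]
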
**Example**

.. code-block:: xml
   :force:

   <layer id="285" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
       <data causal="false" />
       <input>
           <port id="0" precision="FP32">
               <dim>1</dim>
               <dim>32</dim>
               <dim>-1</dim>
               <dim>80</dim>
           </port>
           <port id="1" precision="FP32">
               <dim>1</dim>
               <dim>32</dim>
               <dim>-1</dim>
               <dim>80</dim>
           </port>
           <port id="2" precision="FP32">
               <dim>1</dim>
               <dim>32</dim>
               <dim>-1</dim>
               <dim>80</dim>
           </port>
           <port id="3" precision="FP32">
               <dim>1</dim>
               <dim>1</dim>
               <dim>-1</dim>
               <dim>-1</dim>
           </port>
       </input>
       <output>
           <port id="4" precision="FP32">
               <dim>1</dim>
               <dim>32</dim>
               <dim>-1</dim>
               <dim>80</dim>
           </port>
       </output>
   </layer>

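The commit also adds a Python wrapper for this operation. Assuming it is exposed through the ``opset13`` Python API under the same name as the operation (the module path, function name and argument order here are assumptions, not taken from the wrapper code itself), building a node matching the XML example could look roughly like this:

.. code-block:: py

   import openvino.runtime.opset13 as ops
   from openvino.runtime import PartialShape, Type

   # Parameters mirroring the XML example above; dynamic dimensions are written as -1.
   query = ops.parameter(PartialShape([1, 32, -1, 80]), Type.f32, name="query")
   key = ops.parameter(PartialShape([1, 32, -1, 80]), Type.f32, name="key")
   value = ops.parameter(PartialShape([1, 32, -1, 80]), Type.f32, name="value")
   mask = ops.parameter(PartialShape([1, 1, -1, -1]), Type.f32, name="attention_mask")

   # Assumed signature: scaled_dot_product_attention(query, key, value, attention_mask=None, scale=None, causal=False)
   sdpa = ops.scaled_dot_product_attention(query, key, value, mask, causal=False)
   print(sdpa.get_output_partial_shape(0))  # expected: [1, 32, ?, 80]
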
@endsphinxdirective
...ons/include/transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp (24 changes: 24 additions & 0 deletions)
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ScaledDotProductAttentionDecomposition;

}  // namespace pass
}  // namespace ov

/// Matcher pass that replaces a ScaledDotProductAttention-13 node with an equivalent sub-graph
/// built from basic opset operations (see the accompanying .cpp for the decomposition itself).
class ov::pass::ScaledDotProductAttentionDecomposition : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ScaledDotProductAttentionDecomposition", "0");
    ScaledDotProductAttentionDecomposition();
    std::shared_ptr<ov::Node> decompose(std::shared_ptr<ov::op::v13::ScaledDotProductAttention> node);
};
...mations/src/transformations/op_conversions/scaled_dot_product_attention_decomposition.cpp (136 changes: 136 additions & 0 deletions)
@@ -0,0 +1,136 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp"

#include <memory>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::ScaledDotProductAttentionDecomposition::ScaledDotProductAttentionDecomposition() {
    MATCHER_SCOPE(ScaledDotProductAttentionDecomposition);
    auto pattern_node = ov::pass::pattern::wrap_type<ov::op::v13::ScaledDotProductAttention>();

    matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto node = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(
            pattern_to_output.at(pattern_node).get_node_shared_ptr());

        if (node == nullptr || transformation_callback(node)) {
            return false;
        }

        auto new_output_node = decompose(node);
        ov::replace_node(node, new_output_node);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, matcher_name);
    register_matcher(m, callback);
}

std::shared_ptr<ov::Node> ov::pass::ScaledDotProductAttentionDecomposition::decompose(
    std::shared_ptr<ov::op::v13::ScaledDotProductAttention> node) {
    using namespace ov::op;
    auto query = node->input_value(0);
    auto key = node->input_value(1);
    auto value = node->input_value(2);
    auto q_shape = register_new_node<v3::ShapeOf>(query, element::i32);
    auto k_shape = register_new_node<v3::ShapeOf>(key, element::i32);
    auto minus_one = register_new_node(v0::Constant::create(element::i32, Shape{}, {-1}));
    auto minus_two = register_new_node(v0::Constant::create(element::i32, Shape{}, {-2}));
    auto zero_i = register_new_node(v0::Constant::create(element::i32, Shape{}, {0}));
    auto one_i = register_new_node(v0::Constant::create(element::i32, Shape{}, {1}));
    auto one_f = register_new_node<v1::ConvertLike>(one_i, query);
    auto zero_f = register_new_node<v1::ConvertLike>(zero_i, query);

    Output<Node> scale;
    if (node->get_input_size() < 5) {
        // No explicit scale input: default to 1 / sqrt(E), where E is the last dimension of query.
        scale = register_new_node<v8::Gather>(q_shape, minus_one, zero_i)->output(0);
        scale = register_new_node<v1::ConvertLike>(scale, query);
        auto sqrt_scale = register_new_node<v0::Sqrt>(scale);
        scale = register_new_node<v1::Divide>(one_f, sqrt_scale);
    } else {
        scale = node->input_value(4);
    }

    auto q_scaled = register_new_node<v1::Multiply>(query, scale);
    auto k_rank = register_new_node<v3::ShapeOf>(k_shape, element::i32)->output(0);
    auto k_last_dim = register_new_node<v1::Add>(k_rank, minus_one);
    auto k_next_dim = register_new_node<v1::Add>(k_rank, minus_two)->output(0);
    k_rank = register_new_node<v0::Squeeze>(k_rank, zero_i);
    auto minus_inf =
        register_new_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}))
            ->output(0);
    auto keep_dim_last = register_new_node<v0::Squeeze>(k_next_dim, zero_i);
    auto k_dims_before_transpose = register_new_node<v4::Range>(zero_i, keep_dim_last, one_i, element::i32);

    // Build a permutation [0, ..., rank-3, rank-1, rank-2] that swaps the last two dimensions of key.
    auto transpose_dims =
        register_new_node<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0);
    auto k_transposed = register_new_node<v1::Transpose>(key, transpose_dims);
    auto scaled_atten = register_new_node<v0::MatMul>(q_scaled, k_transposed)->output(0);
    minus_inf = register_new_node<v1::ConvertLike>(minus_inf, scaled_atten);

    if (node->get_causal() || node->get_input_size() > 3) {
        Output<Node> mask;
        Output<Node> atten_mask;
        if (!node->get_causal()) {
            mask = node->input_value(3);

            // Two types of masks are supported: a boolean mask, where a value of True indicates that the element
            // should take part in attention, and a float mask of the same type as query, key and value, which is
            // added to the attention score.
            if (mask.get_element_type() == element::boolean) {
                atten_mask = register_new_node<v1::ConvertLike>(mask, scaled_atten);
                auto inv_mask = register_new_node<v1::LogicalNot>(mask);
                atten_mask = register_new_node<v1::Select>(inv_mask, atten_mask, minus_inf);
            } else {
                atten_mask = mask;
            }
        } else {
            // Causal case: build an [L, S] bias that is -inf strictly above the diagonal and 0 elsewhere,
            // so every query position attends only to itself and to earlier key positions.
            auto target_s_len = register_new_node<v8::Gather>(q_shape, minus_two, zero_i);
            auto source_s_len = register_new_node<v8::Gather>(k_shape, minus_two, zero_i);
            auto ssl = register_new_node<v0::Unsqueeze>(source_s_len, zero_i);
            auto tsl = register_new_node<v0::Unsqueeze>(target_s_len, zero_i);
            auto mask_shape = register_new_node<v0::Concat>(OutputVector{tsl, ssl}, 0);
            mask = register_new_node<v1::Broadcast>(minus_inf, mask_shape);
            auto horizontal_range = register_new_node<v4::Range>(zero_i, source_s_len, one_i, element::i32)->output(0);
            horizontal_range = register_new_node<v0::Unsqueeze>(horizontal_range, zero_i);
            auto stop = register_new_node<v1::Add>(target_s_len, one_i);
            auto vertical_range = register_new_node<v4::Range>(one_i, stop, one_i, element::i32)->output(0);
            vertical_range = register_new_node<v0::Unsqueeze>(vertical_range, one_i);
            auto triu = register_new_node<v1::GreaterEqual>(horizontal_range, vertical_range);
            atten_mask = register_new_node<v1::Select>(triu, mask, zero_f);
        }
        scaled_atten = register_new_node<v1::Add>(scaled_atten, atten_mask);
    }

    scaled_atten = register_new_node<v8::Softmax>(scaled_atten, -1);
    auto result = register_new_node<v0::MatMul>(scaled_atten, value);
    result->set_friendly_name(node->get_friendly_name());
    copy_runtime_info(node, get_new_nodes());
    return result;
}