Skip to content

Commit

Permalink
[GPU] Review fixes 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Dec 12, 2024
1 parent 944a29d commit f092cf2
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,23 @@ LoRAHorizontalFusion::LoRAHorizontalFusion() {
auto is_lora_pattern = [](const std::shared_ptr<Node>& node) {
#define check(node) if (!node) return false;

const auto& add = std::dynamic_pointer_cast<ov::op::v1::Add>(node); check(add)
const auto& matmul2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(0)) ?
std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(0)) :
std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(1)); check(matmul2)
const auto& multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(matmul2->get_input_node_shared_ptr(0)); check(multiply)
const auto& variable_b = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul2->get_input_node_shared_ptr(1)); check(variable_b)
const auto& matmul1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(multiply->get_input_node_shared_ptr(0)); check(matmul1)
const auto& variable_alpha = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(multiply->get_input_node_shared_ptr(1)); check(variable_alpha)
const auto& variable_a = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul1->get_input_node_shared_ptr(1)); check(variable_a)
const auto& add = std::dynamic_pointer_cast<ov::op::v1::Add>(node); check(add)

size_t matmul2_idx = ov::is_type<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(0)) ? 0 : 1;
const auto& matmul2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(matmul2_idx)); check(matmul2)

const auto& multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(matmul2->get_input_node_shared_ptr(0)); check(multiply)

const auto& variable_b = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul2->get_input_node_shared_ptr(1)); check(variable_b)

size_t matmul1_idx = ov::is_type<ov::op::v0::MatMul>(multiply->get_input_node_shared_ptr(0)) ? 0 : 1;
const auto& matmul1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(multiply->get_input_node_shared_ptr(matmul1_idx)); check(matmul1)

size_t alpha_idx = (matmul1_idx + 1) % 2;
const auto& variable_alpha =
std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(multiply->get_input_node_shared_ptr(alpha_idx)); check(variable_alpha)

const auto& variable_a = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul1->get_input_node_shared_ptr(1)); check(variable_a)

#undef check
return true;
Expand Down Expand Up @@ -68,17 +76,19 @@ LoRAHorizontalFusion::LoRAHorizontalFusion() {
for (const auto& add : split->get_users()) {
add_nodes.emplace_back(add);

bool first_input_matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(0)) != nullptr;
matmul2_nodes.emplace_back(first_input_matmul ? add->get_input_node_shared_ptr(0)
: add->get_input_node_shared_ptr(1));
size_t matmul2_idx = ov::is_type<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(0)) ? 0 : 1;
matmul2_nodes.emplace_back(add->get_input_node_shared_ptr(matmul2_idx));
}
for (const auto& matmul2 : matmul2_nodes) {
multiply_nodes.emplace_back(matmul2->get_input_node_shared_ptr(0));
variable_b_nodes.emplace_back(matmul2->get_input_node_shared_ptr(1));
}
for (const auto& multiply : multiply_nodes) {
matmul1_nodes.emplace_back(multiply->get_input_node_shared_ptr(0));
variable_alpha_nodes.emplace_back(multiply->get_input_node_shared_ptr(1));
size_t matmul1_idx = ov::is_type<ov::op::v0::MatMul>(multiply->get_input_node_shared_ptr(0)) ? 0 : 1;
matmul1_nodes.emplace_back(multiply->get_input_node_shared_ptr(matmul1_idx));

size_t alpha_idx = (matmul1_idx + 1) % 2;
variable_alpha_nodes.emplace_back(multiply->get_input_node_shared_ptr(alpha_idx));
}
for (const auto& matmul1 : matmul1_nodes) {
variable_a_nodes.emplace_back(matmul1->get_input_node_shared_ptr(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,75 @@
namespace ov {
namespace intel_gpu {

// Before:
// ┌─────────┐ ┌─────────┐
// │ReadValue│ │ReadValue│
// └────┬────┘ └────┬────┘
// │ ┌───────────┐ │
// │ ┌───────────────────────┼ LoraInput ┼───────────────────┐ │
// │ │ └─────┬─────┘ │ │
// │ ┌────▼───┐ │ ┌────▼───┐ │
// └────► Gemm │ │ │ Gemm ◄──────┘
// ┌─────────┐ └────┬───┘ │ └────┬───┘ ┌─────────┐
// │ReadValue│ │ │ │ │ReadValue│
// └────┬────┘ │ ┌───────────▼────────────┐ │ └────┬────┘
// │ ┌────▼───┐ │FullyConnectedCompressed│ ┌────▼───┐ │
// └─────────────►Multiply│ └───────────┬────────────┘ │Multiply◄────────────┘
// └────┬───┘ │ └────────┘
// ┌─────────┐ │ │ │ ┌─────────┐
// │ReadValue│ │ │ │ │ReadValue│
// └────┬────┘ │ │ │ └────┬────┘
// │ ┌────▼───┐ ┌──────▼──────┐ ┌────▼───┐ │
// └─────────────► Gemm │ ┌───────────┼VariadicSplit┼──────────┐ │ Gemm ◄────────────────┘
// └────┬───┘ │ └──────┬──────┘ │ └────┬───┘
// │ │ │ │ │
// │ │ │ │ │
// │ │ │ │ │
// │ ┌──▼──┐ ▼ ┌──▼──┐ │
// └───────► Add │ ... │ Add ◄────┘
// └─────┘ └─────┘
// After:
// ┌─────────┐
// ┌────┼ReadValue│
// ┌──────────┐ ┌──────┐ │ └─────────┘
// │LoRA_Input┼────────────────────────────┐ ┌─────────────┼Concat◄─────┤ ...
// └────┬─────┘ │ │ └──────┘ │ ┌─────────┐
// │ │ │ └────┼ReadValue│
// │ │ │ └─────────┘
// │ ┌────▼──▼───┐
// │ │MatMulFused│
// │ └───────────┘
// │ │ ┌─────────┐
// │ │ ┌────┼ReadValue│
// │ │ ┌──────┐ │ └─────────┘
// │ │ ┌────────┼Concat◄─────┤ ...
// │ │ │ └──────┘ │ ┌─────────┐
// │ │ │ └────┼ReadValue│
// ┌───────────▼────────────┐ ┌───▼──────▼──┐ └─────────┘
// │FullyConnectedCompressed│ │MultiplyFused│
// └───────────┬────────────┘ └──────┬──────┘
// │ │
// │ ┌─────────┐ │ ┌─────────┐
// │ │ReadValue│ ┌──▼──┐ │ReadValue│
// │ └────┬────┘ │Split│ └────┬────┘
// │ │ └──┬──┘ │
// │ │ │ │
// │ │ ┌────────┼────────┐ │
// │ │ │ │ │
// │ ┌──▼──▼──┐ ┌──▼──▼──┐
// │ │ MatMul │ ... │ MatMul │
// │ └────┬───┘ └────┬───┘
// │ └──────┐ ┌────────┘
// │ │ │
// │ ┌─────┐ ┌─▼────▼─┐
// └─────────────► Add ◄─────────────┼ Concat │
// └──┬──┘ └────────┘
//
//
// ┌──────▼──────┐
// │VariadicSplit│
// └─────────────┘

class LoRAHorizontalFusion: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("LoRAHorizontalFusion", "0");
Expand Down
Loading

0 comments on commit f092cf2

Please sign in to comment.