Skip to content

Commit

Permalink
Check dependencies of reorder node all the way to input_layout node f…
Browse files Browse the repository at this point in the history
…or quantized model

Signed-off-by: yuan.xiong <[email protected]>
  • Loading branch information
yuanxion committed Dec 12, 2024
1 parent 618a835 commit 19d048e
Showing 1 changed file with 88 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "broadcast_inst.h"
#include "shape_of_inst.h"
#include "read_value_inst.h"
#include "reshape_inst.h"
Expand All @@ -16,26 +17,102 @@
using namespace cldnn;

namespace {
bool conv_reorder_with_fake_quantize(program_node& node) {
if (!node.is_type<reorder>()) {
return false;
bool has_input_layout_dep(const std::vector<std::pair<cldnn::program_node*, int>>& shape_of_deps) {
for (auto& shape_of_dep : shape_of_deps) {
// input_layout node
if (shape_of_dep.first->is_type<input_layout>()) {
return true;
}
}
return false;
}

std::string dequantize_name = "DequantizationMultiply";
for (auto& dependency : node.get_dependencies()) {
if (dependency.first->is_type<convolution>()) {
auto& conv_deps = dependency.first->get_dependencies();
bool has_shape_of_dep(const std::vector<std::pair<cldnn::program_node*, int>>& broadcast_deps) {
for (auto& broadcast_dep : broadcast_deps) {
// shape_of node
if (broadcast_dep.first->is_type<shape_of>()) {
auto& shape_of_deps = broadcast_dep.first->get_dependencies();
return has_input_layout_dep(shape_of_deps);
}
}
return false;
}

bool has_broadcast_dep(const std::vector<std::pair<cldnn::program_node*, int>>& reorder_deps) {
for (auto& reorder_dep : reorder_deps) {
// broadcast node
if (reorder_dep.first->is_type<broadcast>()) {
auto& broadcast_deps = reorder_dep.first->get_dependencies();
return has_shape_of_dep(broadcast_deps);
}
}
return false;
}

for (auto& conv_dep : conv_deps) {
if (conv_dep.first->id().find(dequantize_name) != std::string::npos) {
return true;
bool has_reorder_reoder_dep(const std::vector<std::pair<cldnn::program_node*, int>>& eltwise_deps) {
for (auto& eltwise_dep : eltwise_deps) {
// reorder node (reorder -> eltwise)
if (eltwise_dep.first->is_type<reorder>()) {
auto& eltwise_dep_reorder_deps = eltwise_dep.first->get_dependencies();

for (auto& eltwise_dep_reorder_dep : eltwise_dep_reorder_deps) {
// reorder node (broadcast -> reorder)
if (eltwise_dep_reorder_dep.first->is_type<reorder>()) {
auto& reorder_dep_reorder_deps = eltwise_dep_reorder_dep.first->get_dependencies();
return has_broadcast_dep(reorder_dep_reorder_deps);
}
}
}
}
return false;
}

bool has_eltwise_dep(const std::vector<std::pair<cldnn::program_node*, int>>& reorder_deps) {
for (auto& reorder_dep : reorder_deps) {
// eltwise node
if (reorder_dep.first->is_type<eltwise>()) {
auto& eltwise_deps = reorder_dep.first->get_dependencies();
return has_reorder_reoder_dep(eltwise_deps);
}
}
return false;
}

bool has_reorder_dep(const std::vector<std::pair<cldnn::program_node*, int>>& conv_deps) {
for (auto& conv_dep : conv_deps) {
//if (conv_dep.first->id().find(dequantize_name) != std::string::npos) {

// reorder node ( reorder -> convolution)
if (conv_dep.first->is_type<reorder>()) {
auto& reorder_deps = conv_dep.first->get_dependencies();
return has_eltwise_dep(reorder_deps);
}
}
return false;
}

bool has_convolution_dep(const std::vector<std::pair<cldnn::program_node*, int>>& dependencies) {
for (auto& dependency : dependencies) {
// convolution node
if (dependency.first->is_type<convolution>()) {
auto& conv_deps = dependency.first->get_dependencies();
return has_reorder_dep(conv_deps);
}
}
return false;
}

// check dependencies for reorder node added for convolution in quantized model
bool skip_quantization_conv_reorder(const program_node& node) {
// reorder -> convolution -> reorder -> eltwise -> reorder -> reorder -> broadcast -> shape_of -> input_layout
if (!node.is_type<reorder>()) {
return false;
}

auto& dependencies = node.get_dependencies();
return has_convolution_dep(dependencies);
}

} // namespace

void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
Expand All @@ -45,7 +122,7 @@ void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
}

// skip mark_node for reorder node (after convolution node) for quantized model
if (conv_reorder_with_fake_quantize(node)) {
if (skip_quantization_conv_reorder(node)) {
return;
}

Expand Down

0 comments on commit 19d048e

Please sign in to comment.