Skip to content

Commit

Permalink
[CPU] Transpose insertion before FC in ConvertMatMulToFC (#28401)
Browse files Browse the repository at this point in the history
### Details (Updated):
- *Previously, if `MatMul` has `transposed_b=false` and decompressed
convert on weights, the pass `ConvertMatMulToFC` inserted `Transpose`
before this `Convert`. It means that if `Convert` has another consumer
(`Result` or even `MatMul` with `transposed_b=true`), the inserted
`Transpose` could break the shapes of `Convert` consumers (please see
details in the mentioned ticket).
The current PR inserts `Transpose` after existing `Convert` and updates
CPUGraph-pass `FuseFCAndConvertOnWeights`.*

### Tickets:
 - *160215*
  • Loading branch information
a-sidorova authored Jan 15, 2025
1 parent 3ee2339 commit 3caf7b2
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 35 deletions.
41 changes: 32 additions & 9 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,27 +712,50 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {

// This optimization fuses Convert (fp16 -> bf16/fp32) on weights directly to FC input to allow precision conversion
// handling based on internal logic (e.g. fuse conversion with weights reordering)

auto isSuitableTranspose = [](const NodePtr& node) {
return node->getType() == Type::Transpose && node->getChildEdges().size() == 1 && node->isConstant();
};
auto isSuitableConvert = [&](const NodePtr& node) {
return node->getType() == Type::Convert && node->isConstant() &&
one_of(node->getOriginalInputPrecisionAtPort(0), ov::element::f16, ov::element::bf16) &&
one_of(node->getOriginalOutputPrecisionAtPort(0), ov::element::f32, ov::element::bf16);
};

auto& graphNodes = graph.GetNodes();
for (const auto& fullyConnected : graphNodes) {
if (fullyConnected->getType() != Type::FullyConnected) {
continue;
}
const auto convert = fullyConnected->getParentEdgeAt(1)->getParent();
if (convert->getType() != Type::Convert ||
!one_of(convert->getOriginalInputPrecisionAtPort(0), ov::element::f16, ov::element::bf16) ||
!one_of(convert->getOriginalOutputPrecisionAtPort(0), ov::element::f32, ov::element::bf16) ||
!convert->isConstant()) {
continue;

NodePtr transpose = nullptr;
auto parent = fullyConnected->getParentEdgeAt(1)->getParent();
if (parent->getType() == Type::Transpose) {
if (!isSuitableTranspose(parent))
continue;

transpose = parent;
parent = transpose->getParentEdgeAt(0)->getParent();
}

const auto convert = parent;
if (!isSuitableConvert(convert))
continue;

const auto weights = convert->getParentEdgeAt(0)->getParent();
const auto weights_out_edge = weights->getChildEdges()[0].lock();
const auto fc_weights_path_edge = fullyConnected->getParentEdgeAt(1);
const auto fc_weights_path_edge =
transpose ? transpose->getParentEdgeAt(0) : fullyConnected->getParentEdgeAt(1);
const auto inNum = weights_out_edge->getInputNum();
const auto outNum = fc_weights_path_edge->getOutputNum();
fullyConnected->setOriginalInputPrecisionAtPort(1, convert->getOriginalInputPrecisionAtPort(0));
const auto originalPrecision = convert->getOriginalInputPrecisionAtPort(0);
fullyConnected->setOriginalInputPrecisionAtPort(1, originalPrecision);
if (transpose) {
transpose->setOriginalInputPrecisionAtPort(0, originalPrecision);
transpose->setOriginalOutputPrecisionAtPort(0, originalPrecision);
}
graph.RemoveEdge(fc_weights_path_edge);
graph.CreateEdge(weights, fullyConnected, inNum, outNum);
graph.CreateEdge(weights, transpose ? transpose : fullyConnected, inNum, outNum);
if (convert->getChildEdges().empty()) {
graph.DropNode(convert);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
// So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and fc_input_b.
auto fc_input_a = pattern_map.at(activations_m);
auto fc_input_b = pattern_map.at(weights_m);
bool is_convert = false;
if (auto convert_node = ov::as_type_ptr<ov::op::v0::Convert>(fc_input_b.get_node_shared_ptr())) {
if (is_decompression(convert_node)) {
is_convert = true;
fc_input_b = convert_node->get_input_node_shared_ptr(0);
} else {
if (!is_decompression(convert_node)) {
return false;
}
}
Expand Down Expand Up @@ -151,14 +147,6 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
}

// Connect Convert to new input if needed
if (is_convert) {
auto convert = pattern_map.at(weights_m).get_node_shared_ptr();
convert->input(0).replace_source_output(fc_input_b);
convert->validate_and_infer_types();
fc_input_b = convert;
}

auto bias = std::make_shared<ov::op::v0::Constant>(element::undefined, Shape{0});
new_ops.push_back(bias);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,18 @@ class MatMulDecompressConvertTest : public testing::WithParamInterface<MatMulDec
function = CPUTestsBase::makeNgraphFunction(netType, params, matMul, cpuNodeType);
}

void check_execution_graph() {
virtual void check_execution_graph() {
CheckPluginRelatedResults(compiledModel, "FullyConnected");
CheckNumberOfNodesWithType(compiledModel, "FullyConnected", fullyConnectedCount);
CheckNumberOfNodesWithType(compiledModel, "Transpose", transposeCount);
CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
CheckNumberOfNodesWithType(compiledModel, "Convert", convertCount);
CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
check_fc_weights_precision(expectedWeiConstElemType);
}

size_t fullyConnectedCount = 1;
size_t transposeCount = 0;
size_t convertCount = 0;
ElementType expectedWeiConstElemType = ElementType::f32;
};

Expand Down Expand Up @@ -410,11 +411,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_BF16,
|Output|
--------
*/
using MatMulDecompressConvertParams2 = std::tuple<std::vector<InputShape>, // input shapes
std::pair<bool, bool>, // transposeA, transposeB
ElementType, // weights precision
ov::AnyMap, // additional property
CPUSpecificParams>;

class MatMulDecompressConvertTest2 : public MatMulDecompressConvertTest {
protected:
Expand Down Expand Up @@ -519,5 +515,144 @@ INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16_2,

} // namespace


/* This test covers the NNCF case where the decompression Convert has consumers other than MatMul.
* Graph before:
------------ ---------------
|Input(f32)| |Constant(f16)|
------------ ---------------
| |
| ---------------------------------
| |Convert(decompression f16->f32)|
| ---------------------------------
| | |
---------------------------- -----------------------
|MatMul (transposed_b=true)| | Result |
---------------------------- -----------------------
|
-----------------------
| Result |
-----------------------
* Exec graph:
------------ -----------------------------
|Input(f32)| | Constant(f16) |
------------ -----------------------------
| | |
| ------------- ---------------------
| | Transpose | | Convert(f16->f32) |
| ------------- ---------------------
| | |
----------------------- -----------------------
| FullyConnected | | Result |
----------------------- -----------------------
|
-----------------------
| Result |
-----------------------
*/

// Test fixture for the case where the decompression Convert on weights feeds both the
// MatMul/FullyConnected and an extra Result. Verifies that the Transpose inserted for
// non-transposed weights does not corrupt the shape seen by the other Convert consumer.
class MatMulDecompressConvertTest3 : public MatMulDecompressConvertTest {
protected:
// Builds the test model: Param -> MatMul <- Convert(decompression) <- Constant(f16),
// with the Convert output additionally wired to a second Result ("ConstantResult").
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;

std::vector<InputShape> inputShapes;
std::pair<bool, bool> transpose;
ElementType weiConstElemType;
ov::AnyMap additionalConfig;
CPUSpecificParams cpuParams;

std::tie(inputShapes, transpose, weiConstElemType, additionalConfig, cpuParams) = this->GetParam();
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;

init_input_shapes(inputShapes);

bool transpA = transpose.first;
bool transpB = transpose.second;

// A Transpose node is expected in the exec graph for a transposed activation input,
// and also when weights are NOT transposed (the pass inserts one after the Convert).
if (transpA)
transposeCount++;
if (!transpB)
transposeCount++;

// Pre-transpose the declared input shapes so the MatMul transpose_a/transpose_b
// attributes see dimensions in the order they expect.
if (transpA) {
transpose_shape(inputDynamicShapes[0]);
for (auto& shapes : targetStaticShapes) {
transpose_shape(shapes[0]);
}
}
if (transpB) {
transpose_shape(inputDynamicShapes[1]);
for (auto& shapes : targetStaticShapes) {
transpose_shape(shapes[1]);
}
}

const auto& inShapeA = inputDynamicShapes[0];
const auto& inShapeB = inputDynamicShapes[1];

configuration.insert(additionalConfig.begin(), additionalConfig.end());

ElementType netType = ElementType::f32;
ElementType convertOutType = ElementType::f32;
inType = outType = netType;

std::string cpuNodeType = "FullyConnected";
selectedType = makeSelectedTypeStr(selectedType, outType);

ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inType, inShapeA)};
// Weights path: f16 Constant followed by a Convert marked as decompression so the
// ConvertMatMulToFC pass treats it as a decompressed-weights pattern.
std::shared_ptr<ov::Node> inputB = ov::test::utils::make_constant(weiConstElemType, inShapeB.get_shape());
inputB = std::make_shared<ov::op::v0::Convert>(inputB, convertOutType);
mark_as_decompression(inputB);
expectedWeiConstElemType = weiConstElemType;
// The Convert must survive in the exec graph because it still has the Result consumer.
convertCount = 1;

auto matMul = std::make_shared<ov::op::v0::MatMul>(params[0], inputB, transpA, transpB);
auto result0 = std::make_shared<ov::op::v0::Result>(matMul);
// Second consumer of the decompression Convert; named so check_execution_graph() can find it.
auto result1 = std::make_shared<ov::op::v0::Result>(inputB);
result1->set_friendly_name("ConstantResult");

modifyGraph(netType, params, matMul);
function = std::make_shared<ov::Model>(ov::ResultVector{result0, result1}, params, "MatMulDecompressed3");
}

// Extends the base checks: additionally asserts that the extra Result kept the
// original (un-transposed) Constant shape, i.e. the inserted Transpose did not leak
// into the Result branch.
void check_execution_graph() override {
MatMulDecompressConvertTest::check_execution_graph();

// Check that the Result connected to the Convert still has the origin Constant shape.
const auto results = compiledModel.outputs();
const auto result_it = std::find_if(results.cbegin(), results.cend(),
[](const ov::Output<const ov::Node>& out) {
return out.get_node()->get_friendly_name() == "ConstantResult";
});
ASSERT_NE(result_it, results.cend())
<< "Target Result has not been found!";
ASSERT_EQ(result_it->get_partial_shape(), inputDynamicShapes[1])
<< "Target Result has not origin shape. It has: " << result_it->get_partial_shape() << " but should have origin: " << inputDynamicShapes[1];
}
};

// Runs inference against the reference and then validates the execution graph
// structure (node counts, weights precision, and the "ConstantResult" shape).
TEST_P(MatMulDecompressConvertTest3, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
run();
check_execution_graph();
}

namespace {
// 2D f16-weights case with transposeA=false, transposeB=false: the non-transposed
// weights force the pass to insert a Transpose after the decompression Convert,
// which is exactly the scenario this fixture validates.
const auto testParams2D_FP16_3_smoke =
::testing::Combine(::testing::Values(static_shapes_to_test_representation({{1, 16, 32}, {32, 64}})),
::testing::Values(std::pair<bool, bool>{false, false}),
::testing::Values(ElementType::f16),
::testing::Values(emptyConfig),
::testing::ValuesIn(filter_specific_params(false)));

INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16_3,
MatMulDecompressConvertTest3,
testParams2D_FP16_3_smoke,
MatMulDecompressConvertTest3::getTestCaseName);

} // namespace

} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,13 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_0) {
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 2, 2});

auto input2 = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 2, 2}, {1});
auto convert = std::make_shared<ov::opset1::Convert>(input2, ov::element::f32);
auto transpose_constant = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1});
auto transpose = std::make_shared<ov::opset1::Transpose>(input2, transpose_constant);
auto convert = std::make_shared<ov::opset1::Convert>(transpose, ov::element::f32);
auto transpose = std::make_shared<ov::opset1::Transpose>(convert, transpose_constant);

auto matmul = std::make_shared<ov::op::internal::FullyConnected>(
input1,
convert,
transpose,
std::make_shared<ov::op::v0::Constant>(ov::element::undefined, ov::Shape{0}));

model_ref = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{input1});
Expand All @@ -491,13 +491,13 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_1) {
auto transpose1 = std::make_shared<ov::opset1::Transpose>(input1, transpose_constant1);

auto input2 = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 2, 2}, {1});
auto convert = std::make_shared<ov::opset1::Convert>(input2, ov::element::f32);
auto transpose_constant2 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1});
auto transpose2 = std::make_shared<ov::opset1::Transpose>(input2, transpose_constant2);
auto convert = std::make_shared<ov::opset1::Convert>(transpose2, ov::element::f32);
auto transpose2 = std::make_shared<ov::opset1::Transpose>(convert, transpose_constant2);

auto matmul = std::make_shared<ov::op::internal::FullyConnected>(
transpose1,
convert,
transpose2,
std::make_shared<ov::op::v0::Constant>(ov::element::undefined, ov::Shape{0}));

model_ref = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{input1});
Expand Down

0 comments on commit 3caf7b2

Please sign in to comment.