[CPU] [EXPERIMENTAL] FullyConnected: enable sparse weights decompression #13775

Merged: 1 commit, Dec 4, 2022
@@ -71,6 +71,9 @@ void regmodule_properties(py::module m) {

// Submodule intel_cpu property
wrap_property_RW(m_intel_cpu, ov::intel_cpu::denormals_optimization, "denormals_optimization");
wrap_property_RW(m_intel_cpu,
ov::intel_cpu::sparse_weights_decompression_rate,
"sparse_weights_decompression_rate");

// Submodule device
py::module m_device =
2 changes: 2 additions & 0 deletions src/inference/include/ie/cpu/cpu_config.hpp
@@ -40,5 +40,7 @@ namespace CPUConfigParams {
*/
DECLARE_CPU_CONFIG_KEY(DENORMALS_OPTIMIZATION);

DECLARE_CPU_CONFIG_KEY(SPARSE_WEIGHTS_DECOMPRESSION_RATE);

} // namespace CPUConfigParams
} // namespace InferenceEngine
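
For reference, a minimal sketch of how this legacy key could be set through the InferenceEngine API (illustrative only; the exact key string comes from the DECLARE_CPU_CONFIG_KEY expansion, and 0.8 is an arbitrary threshold):

```cpp
#include <ie_core.hpp>
#include <cpu/cpu_config.hpp>

int main() {
    InferenceEngine::Core ie;
    // The legacy API passes values as strings; config.cpp below parses the
    // value with std::stof and rejects anything outside [0.0, 1.0].
    ie.SetConfig({{InferenceEngine::CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE, "0.8"}},
                 "CPU");
    return 0;
}
```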
@@ -47,5 +47,7 @@ namespace intel_cpu {
*/
static constexpr Property<bool> denormals_optimization{"CPU_DENORMALS_OPTIMIZATION"};

static constexpr Property<float> sparse_weights_decompression_rate{"SPARSE_WEIGHTS_DECOMPRESSION_RATE"};

} // namespace intel_cpu
} // namespace ov
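
And the equivalent through the OpenVINO 2.0 API, using the property declared above (a sketch; "model.xml" is a placeholder path):

```cpp
#include <openvino/runtime/core.hpp>
#include <openvino/runtime/intel_cpu/properties.hpp>

int main() {
    ov::Core core;
    // Request sparse decompression for FullyConnected weights that are at
    // least 80% zeros; the default of 1.0 keeps the feature switched off.
    core.set_property("CPU", ov::intel_cpu::sparse_weights_decompression_rate(0.8f));
    auto compiled = core.compile_model("model.xml", "CPU");
    return 0;
}
```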
14 changes: 14 additions & 0 deletions src/plugins/intel_cpu/src/config.cpp
@@ -74,6 +74,20 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
// zero and any negative value will be treated
// as default batch size
batchLimit = std::max(val_i, 0);
} else if (key == CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE) {
float val_f = 0.0f;
try {
val_f = std::stof(val);
} catch (const std::exception&) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
<< ". Expected only float numbers";
}
if (val_f < 0.f || val_f > 1.f) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
<< ". Sparse rate must be in range [0.0f,1.0f]";
} else {
fcSparseWeiDecompressionRate = val_f;
}
} else if (key == PluginConfigParams::KEY_PERF_COUNT) {
if (val == PluginConfigParams::YES) collectPerfCounters = true;
else if (val == PluginConfigParams::NO) collectPerfCounters = false;
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/config.h
@@ -34,6 +34,7 @@ struct Config {
bool enableDynamicBatch = false;
std::string dumpToDot = "";
int batchLimit = 0;
float fcSparseWeiDecompressionRate = 1.0f;
size_t rtCacheCapacity = 5000ul;
InferenceEngine::IStreamsExecutor::Config streamExecutorConfig;
InferenceEngine::PerfHintsConfig perfHintsConfig;
12 changes: 12 additions & 0 deletions src/plugins/intel_cpu/src/graph.cpp
@@ -26,6 +26,7 @@
#include <nodes/reorder.h>
#include "nodes/convert.h"
#include "nodes/subgraph.h"
#include "nodes/fullyconnected.h"

#include <ie_algorithm.hpp>
#include <blob_factory.hpp>
@@ -341,6 +342,9 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
if (config.enforceBF16)
EnforceBF16();

if (config.fcSparseWeiDecompressionRate < 1.0f)
setMinSparseRate(config.fcSparseWeiDecompressionRate);

auto hasSubgraphConsumers = [] (const NodePtr& node) -> bool {
const auto & childEdges = node->getChildEdges();
return std::any_of(childEdges.begin(), childEdges.end(),
@@ -1454,6 +1458,14 @@ void Graph::EnforceBF16() {
}
}

void Graph::setMinSparseRate(float minSparseRate) {
for (const auto &node : graphNodes) {
if (auto fcNodePtr = std::dynamic_pointer_cast<node::FullyConnected>(node)) {
fcNodePtr->setMinSparseRate(minSparseRate);
}
}
}

std::shared_ptr<ngraph::Function> Graph::dump() const {
return dump_graph_as_ie_ngraph_net(*this);
}
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/graph.h
@@ -268,6 +268,7 @@ class Graph {
DnnlScratchPadPtr rtScratchPad;

void EnforceBF16();
void setMinSparseRate(float minSparseRate);
};

} // namespace intel_cpu
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/node.cpp
@@ -418,6 +418,7 @@ std::string Node::getPrimitiveDescriptorType() {
SEARCH_TYPE(uni);

SEARCH_TYPE(winograd);
SEARCH_TYPE(sparse);
SEARCH_TYPE(_dw);
SEARCH_TYPE(_1x1);

82 changes: 78 additions & 4 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -4,6 +4,7 @@

#include "fullyconnected.h"
#include "eltwise.h"
#include "input.h"
#include "fake_quantize.h"
#include "input.h"
#include "reorder.h"
@@ -22,6 +23,7 @@
#include <common/primitive_desc.hpp>
#include <common/primitive_desc_iface.hpp>
#include "onednn/dnnl.h"
#include "cpu/x64/cpu_isa_traits.hpp"

using namespace dnnl;
using namespace InferenceEngine;
@@ -172,6 +174,8 @@ void FullyConnected::getSupportedDescriptors() {
if (getChildEdges().empty())
IE_THROW() << errorPrefix << " has incorrect number of output edges";

useSparseWeights = useSparseWeightsDecompression();

auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));

@@ -360,6 +364,10 @@ void FullyConnected::prepareParams() {
}
// changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType());
// WA: We update implType to know whether weights decompression was used inside the kernel
if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) {
selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
}
// maybe expected 1x1 conv is not created, update the flag depends on the real type
useConv1x1 = execPtr->getImplementationType() == brgconv_avx512_1x1;

@@ -503,6 +511,7 @@ bool FullyConnected::created() const {
const std::vector<impl_desc_type>& FullyConnected::getPrimitivesPriority() {
std::vector<impl_desc_type> priorities = {
impl_desc_type::unknown,
impl_desc_type::brgemm_sparse_avx512_amx,
impl_desc_type::brgemm_avx512_amx,
impl_desc_type::brgemm_avx512,
impl_desc_type::gemm_blas,
@@ -578,9 +587,15 @@ void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDes
DnnlExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
}

dnnl::memory::desc wgh_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, dnnl::memory::format_tag::any);

// We need to explicitly specify the memory descriptor to use sparse weights decompression
dnnl::memory::desc wgh_candidate;
if (useSparseWeights) {
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, memory::desc::packed(nnzCount) };
} else {
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, dnnl::memory::format_tag::any };
}
if (withBiases) {
dnnl::memory::desc bias_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(BIAS_ID).getStaticDims()), bdt,
dnnl::memory::format_tag::any);
@@ -634,7 +649,7 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
portConfig.inPlace(-1);
portConfig.constant(false);
auto desc = getSrcMemDesc(itpd, i);
if (supportsUndefStridesAndOffset()) {
if (supportsUndefStridesAndOffset() && !(i == WEIGHTS_ID && useSparseWeights)) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
portConfig.setMemDesc(desc);
@@ -868,6 +883,65 @@ MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
return ptr;
}

bool FullyConnected::useSparseWeightsDecompression() {
// minSparseRate == 1 means that sparse feature is switched off
if (minSparseRate == 1.f) {
return false;
}

if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx))
return false;

auto weiDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
if (weiDims.size() != 2 || weiDims[0] % 64 != 0 || weiDims[1] % 64 != 0) {
return false;
}

auto inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);
auto weightsPrecision = getOriginalInputPrecisionAtPort(WEIGHTS_ID);
if (!one_of(inputPrecision, Precision::U8, Precision::I8) || weightsPrecision != Precision::I8) {
return false;
}

// calculate sparse rate
const auto constNode = std::dynamic_pointer_cast<Input>(getParentEdgeAt(WEIGHTS_ID)->getParent());
if (!constNode) {
return false;
}
auto blb = constNode->getMemoryPtr();
if (blb == nullptr)
IE_THROW() << "Cannot get const blob for node " << getName() << ".";

auto weightsData = reinterpret_cast<const int8_t*>(blb->GetPtr());
auto elementsCount = blb->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
size_t zerosCounts = 0;
for (int i = 0; i < elementsCount; i++) {
if (weightsData[i] == 0) {
zerosCounts++;
}
}
nnzCount = elementsCount - zerosCounts;

DEBUG_LOG(getName(), ", weightsData.size() = ", elementsCount, ", zerosCounts = ",
zerosCounts, ", nnzCount = ", nnzCount);

weiSparseRate = static_cast<float>(zerosCounts) / static_cast<float>(elementsCount);

// [av] WA: there is no point in using sparse decompression when the sparse rate is low
// todo: add heuristic
if (minSparseRate < 0.5)
minSparseRate = 0.5;

DEBUG_LOG(getName(), " | sparse rate = ", weiSparseRate * 100, "%, min sparse rate = ",
minSparseRate * 100, "%, use sparse weights = ", weiSparseRate >= minSparseRate);

if (weiSparseRate < minSparseRate) {
return false;
}

return true;
}

} // namespace node
} // namespace intel_cpu
} // namespace ov
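
To make the heuristic above concrete, here is a hedged standalone sketch of the same decision logic (names are illustrative, not the plugin's API): count zero int8 weights, derive the sparse rate, clamp the requested threshold to the 0.5 floor, and compare.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Sketch of the rate check in FullyConnected::useSparseWeightsDecompression().
bool sparseRateQualifies(const int8_t* weights, size_t elementsCount, float minSparseRate) {
    size_t zerosCount = 0;
    for (size_t i = 0; i < elementsCount; i++)
        if (weights[i] == 0)
            zerosCount++;
    const float weiSparseRate = static_cast<float>(zerosCount) / static_cast<float>(elementsCount);
    // WA from above: rates below 0.5 are not profitable, so the threshold is floored.
    minSparseRate = std::max(minSparseRate, 0.5f);
    return weiSparseRate >= minSparseRate;
}
```

For example, a 64x64 int8 weight matrix with 3072 zeros has a sparse rate of 3072 / 4096 = 0.75, so it qualifies for any configured threshold up to 0.75.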
10 changes: 10 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
@@ -42,6 +42,7 @@ class FullyConnected : public Node {

void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
// void createPrimitive() override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;

@@ -58,6 +59,8 @@

void setDynamicBatchLim(int lim) override;

void setMinSparseRate(float sparseRate) { minSparseRate = sparseRate; }

private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc);
@@ -106,6 +109,13 @@

bool canBeExecutedInConv1x1() const;
MemoryPtr prepareWeightMemory(const DnnlMemoryDescPtr weightDesc);

// sparse weights
bool useSparseWeights = false;
int nnzCount = -1;
float minSparseRate = 1.f;
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
};

} // namespace node
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/onednn/iml_type_mapper.cpp
@@ -37,6 +37,7 @@ impl_desc_type parse_impl_name(std::string impl_desc_name) {
SEARCH_WORD(_1x1);
SEARCH_WORD(_dw);
SEARCH_WORD(reorder);
SEARCH_WORD(sparse);
if ((res & impl_desc_type::avx2) != impl_desc_type::avx2 &&
(res & impl_desc_type::avx512) != impl_desc_type::avx512)
SEARCH_WORD(avx);
@@ -108,6 +109,7 @@ const char* impl_type_to_string(impl_desc_type type) {
CASE(brgemm_sse42);
CASE(brgemm_uni);
CASE(brgemm_avx512_amx);
CASE(brgemm_sparse_avx512_amx);

#undef CASE
return "unknown";
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/onednn/iml_type_mapper.h
@@ -35,6 +35,8 @@ enum impl_desc_type {
reorder = 1<<22,
// winograd
winograd = 1<<23,
// sparse
sparse = 1<<24,

// real types
ref_any = ref | any,
@@ -90,6 +92,7 @@ enum impl_desc_type {
brgemm_sse42 = brgemm | sse42,
brgemm_uni = brgemm | uni,
brgemm_avx512_amx = brgemm | avx512 | amx,
brgemm_sparse_avx512_amx = brgemm | sparse | avx512 | amx,
};

const char * impl_type_to_string(impl_desc_type type);
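
Because the WA in prepareParams() reports the sparse kernel as brgemm_sparse_avx512_amx, one hedged way to confirm the decompression path was actually taken is to look for "sparse" in the executed-kernel names from profiling info (a sketch; "model.xml" is a placeholder):

```cpp
#include <iostream>
#include <string>
#include <openvino/runtime/core.hpp>

int main() {
    ov::Core core;
    auto compiled = core.compile_model("model.xml", "CPU", ov::enable_profiling(true));
    auto request = compiled.create_infer_request();
    request.infer();
    // exec_type carries the impl_desc_type string, e.g. "brgemm_sparse_avx512_amx".
    for (const auto& info : request.get_profiling_info()) {
        if (info.exec_type.find("sparse") != std::string::npos)
            std::cout << info.node_name << ": " << info.exec_type << "\n";
    }
    return 0;
}
```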
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/thirdparty/onednn
Submodule onednn updated 67 files
+95 −0 include/oneapi/dnnl/dnnl.h
+159 −3 include/oneapi/dnnl/dnnl.hpp
+1 −0 include/oneapi/dnnl/dnnl_debug.h
+44 −0 include/oneapi/dnnl/dnnl_types.h
+8 −1 scripts/generate_dnnl_debug.py
+13 −0 src/common/c_types_map.hpp
+13 −0 src/common/dnnl_debug_autogenerated.cpp
+205 −6 src/common/memory.cpp
+46 −6 src/common/memory.hpp
+111 −8 src/common/memory_desc_wrapper.hpp
+5 −1 src/common/primitive.hpp
+4 −3 src/common/primitive_exec_types.cpp
+3 −3 src/common/primitive_exec_types.hpp
+22 −0 src/common/primitive_hashing_utils.cpp
+9 −0 src/common/primitive_hashing_utils.hpp
+62 −8 src/common/type_helpers.hpp
+2 −2 src/common/utils.hpp
+6 −0 src/common/verbose.cpp
+3 −0 src/cpu/reorder/cpu_reorder.hpp
+4 −0 src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp
+4 −0 src/cpu/reorder/cpu_reorder_regular_s8.cpp
+248 −0 src/cpu/reorder/simple_sparse_reorder.hpp
+6 −0 src/cpu/x64/brgemm/brgemm_types.hpp
+4 −4 src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
+1 −1 src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
+1 −1 src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
+2 −2 src/cpu/x64/gemm_bf16_convolution.cpp
+615 −394 src/cpu/x64/injectors/jit_uni_binary_injector.cpp
+94 −115 src/cpu/x64/injectors/jit_uni_binary_injector.hpp
+1 −1 src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_common_conv_kernel.cpp
+4 −2 src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp
+1 −0 src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp
+5 −2 src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp
+1 −0 src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp
+1 −1 src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
+2 −2 src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp
+106 −0 src/cpu/x64/jit_brgemm_decompress_kernel.cpp
+80 −0 src/cpu/x64/jit_brgemm_decompress_kernel.hpp
+33 −6 src/cpu/x64/jit_brgemm_inner_product.cpp
+8 −0 src/cpu/x64/jit_brgemm_inner_product.hpp
+16 −1 src/cpu/x64/jit_brgemm_inner_product_utils.cpp
+1 −1 src/cpu/x64/jit_brgemm_post_ops.hpp
+3 −0 src/cpu/x64/jit_brgemm_primitive_conf.hpp
+1 −1 src/cpu/x64/jit_gemm_convolution_utils.cpp
+1 −1 src/cpu/x64/jit_gemm_inner_product_utils.cpp
+1 −1 src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp
+2 −2 src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp
+2 −2 src/cpu/x64/jit_sse41_conv_kernel_f32.cpp
+4 −3 src/cpu/x64/jit_uni_binary_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_uni_i8i8_pooling.cpp
+1 −1 src/cpu/x64/jit_uni_pool_kernel.cpp
+3 −3 src/cpu/x64/jit_uni_reduction_kernel.cpp
+1 −0 src/cpu/x64/jit_uni_reduction_kernel.hpp
+1 −1 src/cpu/x64/jit_uni_resampling_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp