Skip to content

Commit

Permalink
fix some part of x86
Browse files Browse the repository at this point in the history
  • Loading branch information
allnes committed Sep 30, 2024
1 parent 2934d40 commit f4255d9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define UNSUPPORTED_SPARSE_WEIGHTS " sparse weights are not supported"
#define UNSUPPORTED_WEIGHTS_DECOMPRESSION " weights decompression is not supported"
#define UNSUPPORTED_POST_OPS " post ops are not supported"
#define UNSUPPORTED_LAYOUT " layout are not supported"
#define UNSUPPORTED_NUMBER_OF_POSTOPS " the number of post ops is not supported"
#define UNSUPPORTED_TYPE_OF_POSTOPS " the type of post ops is not supported"
#define UNSUPPORTED_SRC_PRECISIONS " unsupported src precisions"
Expand Down
85 changes: 75 additions & 10 deletions src/plugins/intel_cpu/src/nodes/executors/mvn_implementations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,29 @@ ov::optional<executor::Config<Attrs>> requiresFallbackCommon(const executor::Con
return ov::optional<executor::Config<Attrs>>(MVNConfig {optimalDescriptors, config.attrs, config.postOps});
}

OV_CPU_MAYBE_UNUSED_FUNCTION static inline bool noPostOps(const MVNConfig& config) {
return config.postOps.empty();
OV_CPU_MAYBE_UNUSED_FUNCTION static inline bool noLayout(const MVNConfig& config, const LayoutType& layoutType) {
return config.descs.at(ARG_SRC)->hasLayoutType(layoutType);
}

template <>
const std::vector<ExecutorImplementation<MVNAttrs>>& getImplementations() {
static const std::vector<ExecutorImplementation<MVNAttrs>> mvnImplementations {
OV_CPU_INSTANCE_X64(
"mvn_jit_x64_ncsp",
"mvn_jit_x64_nspc",
ExecutorType::jit_x64,
OperationType::MVN,
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
VERIFY(noLayout(config, LayoutType::nspc), UNSUPPORTED_LAYOUT);
VERIFY(one_of(srcRank(config), 4lu, 5lu), UNSUPPORTED_SRC_RANK);
return JITMVNExecutor::supports(config);
},
// requiresFallback
[](const MVNConfig& config) -> ov::optional<executor::Config<MVNAttrs>> {
return requiresFallbackCommon(config,
jitMVNTypeMapping,
{LayoutType::ncsp, LayoutType::ncsp},
{LayoutType::nspc, LayoutType::nspc},
mvnMappingNotation);
},
// acceptsShapes
Expand All @@ -160,19 +162,82 @@ const std::vector<ExecutorImplementation<MVNAttrs>>& getImplementations() {
return std::make_shared<JITMVNExecutor>(attrs, postOps, memory, context);
})
OV_CPU_INSTANCE_X64(
"mvn_jit_x64_nspc",
"mvn_jit_x64_nCsp16c",
ExecutorType::jit_x64,
OperationType::MVN,
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
VERIFY(noLayout(config, LayoutType::nCsp16c), UNSUPPORTED_LAYOUT);
VERIFY(one_of(srcRank(config), 4lu, 5lu), UNSUPPORTED_SRC_RANK);
VERIFY(mayiuse(cpu::x64::avx512_core), UNSUPPORTED_ISA);
return JITMVNExecutor::supports(config);
},
// requiresFallback
[](const MVNConfig& config) -> ov::optional<executor::Config<MVNAttrs>> {
return requiresFallbackCommon(config,
jitMVNTypeMapping,
{LayoutType::nCsp16c, LayoutType::nCsp16c},
mvnMappingNotation);
},
// acceptsShapes
[](const MemoryArgs& memory) -> bool {
// @todo create syntactic sugar (functor) for shape agnostic lambda
return true;
},
// create
[](const MVNAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
const ExecutorContext::CPtr context) {
return std::make_shared<JITMVNExecutor>(attrs, postOps, memory, context);
})
OV_CPU_INSTANCE_X64(
"mvn_jit_x64_nCsp8c",
ExecutorType::jit_x64,
OperationType::MVN,
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
VERIFY(noLayout(config, LayoutType::nCsp8c), UNSUPPORTED_LAYOUT);
VERIFY(one_of(srcRank(config), 4lu, 5lu), UNSUPPORTED_SRC_RANK);
VERIFY(mayiuse(cpu::x64::avx2) || mayiuse(cpu::x64::sse41), UNSUPPORTED_ISA);
return JITMVNExecutor::supports(config);
},
// requiresFallback
[](const MVNConfig& config) -> ov::optional<executor::Config<MVNAttrs>> {
return requiresFallbackCommon(config,
jitMVNTypeMapping,
{LayoutType::nspc, LayoutType::nspc},
{LayoutType::nCsp8c, LayoutType::nCsp8c},
mvnMappingNotation);
},
// acceptsShapes
[](const MemoryArgs& memory) -> bool {
// @todo create syntactic sugar (functor) for shape agnostic lambda
return true;
},
// create
[](const MVNAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
const ExecutorContext::CPtr context) {
return std::make_shared<JITMVNExecutor>(attrs, postOps, memory, context);
})
OV_CPU_INSTANCE_X64(
"mvn_jit_x64_ncsp",
ExecutorType::jit_x64,
OperationType::MVN,
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
VERIFY(noLayout(config, LayoutType::ncsp), UNSUPPORTED_LAYOUT);
return JITMVNExecutor::supports(config);
},
// requiresFallback
[](const MVNConfig& config) -> ov::optional<executor::Config<MVNAttrs>> {
return requiresFallbackCommon(config,
jitMVNTypeMapping,
{LayoutType::ncsp, LayoutType::ncsp},
mvnMappingNotation);
},
// acceptsShapes
Expand All @@ -194,7 +259,7 @@ const std::vector<ExecutorImplementation<MVNAttrs>>& getImplementations() {
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
if (!config.descs.at(ARG_SRC)->hasLayoutType(LayoutType::nspc)) return false;
VERIFY(noLayout(config, LayoutType::nspc), UNSUPPORTED_LAYOUT);
return ACLMVNExecutor::supports(config);
},
// requiresFallback
Expand Down Expand Up @@ -223,7 +288,7 @@ const std::vector<ExecutorImplementation<MVNAttrs>>& getImplementations() {
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
if (!config.descs.at(ARG_SRC)->hasLayoutType(LayoutType::ncsp)) return false;
VERIFY(noLayout(config, LayoutType::ncsp), UNSUPPORTED_LAYOUT);
return ACLMVNExecutor::supports(config);
},
// requiresFallback
Expand Down Expand Up @@ -252,7 +317,7 @@ const std::vector<ExecutorImplementation<MVNAttrs>>& getImplementations() {
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
if (!config.descs.at(ARG_SRC)->hasLayoutType(LayoutType::ncsp)) return false;
VERIFY(noLayout(config, LayoutType::ncsp), UNSUPPORTED_LAYOUT);
return CommonMVNExecutor::supports(config);
},
// requiresFallback
Expand Down Expand Up @@ -281,7 +346,7 @@ const std::vector<ExecutorImplementation<MVNAttrs>>& getImplementations() {
ShapeTolerance::Agnostic,
// supports
[](const MVNConfig& config) -> bool {
if (!config.descs.at(ARG_SRC)->hasLayoutType(LayoutType::nspc)) return false;
VERIFY(noLayout(config, LayoutType::nspc), UNSUPPORTED_LAYOUT);
return CommonMVNExecutor::supports(config);
},
// requiresFallback
Expand Down
66 changes: 7 additions & 59 deletions src/plugins/intel_cpu/src/nodes/mvn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ void MVN::initSupportedPrimitiveDescriptors() {
if (!fusedWith.empty())
dstTypes = fusedWith.back()->getOriginalOutputPrecisions();


VecMemoryDescs srcDescs;
const auto& creatorsMap = BlockedDescCreator::getCommonCreators();
for (size_t i = 0; i < srcTypes.size(); i++) {
Expand All @@ -239,6 +240,10 @@ void MVN::initSupportedPrimitiveDescriptors() {
{ARG_DST, dstDescs[0]},
};

if (one_of(descs.at(ARG_SRC)->getShape().getRank(), 1lu, 2lu) && mvnAttrs.initAcrossChannels_) {
mvnAttrs.execAcrossChannels_ = false;
}

mvnAttrs.srcIsNHWC = descs.at(ARG_SRC)->hasLayoutType(LayoutType::nspc);
mvnAttrs.src_prc = descs.at(ARG_SRC)->getPrecision();
mvnAttrs.dst_prc = descs.at(ARG_DST)->getPrecision();
Expand Down Expand Up @@ -270,53 +275,11 @@ void MVN::initSupportedPrimitiveDescriptors() {
}

// planar
if (canBeInplace)
if (descs.at(ARG_SRC)->hasLayoutType(LayoutType::nspc) && canBeInplace)
nodeConfig.inConfs[0].inPlace(0);

supportedPrimitiveDescriptors.emplace_back(nodeConfig, impl_desc_type::undef);
}
return;

auto pushDesc = [&](LayoutType format, impl_desc_type impl_type, bool useAclExecutor = false) {};

#if defined(OV_CPU_WITH_ACL)
pushDesc(LayoutType::nspc, acl, true);
pushDesc(LayoutType::ncsp, acl, true);
if (!supportedPrimitiveDescriptors.empty())
return;
else
// Reference MVN implementation does not support fp16, so set fp32 explicitly
inputPrecision = outputPrecision = ov::element::f32;
#endif // OV_CPU_WITH_ACL

impl_desc_type impl_type;
if (mayiuse(cpu::x64::avx512_core)) {
impl_type = impl_desc_type::jit_avx512;
} else if (mayiuse(cpu::x64::avx2)) {
impl_type = impl_desc_type::jit_avx2;
} else if (mayiuse(cpu::x64::sse41)) {
impl_type = impl_desc_type::jit_sse42;
} else {
impl_type = impl_desc_type::ref;
}

if (mayiuse(cpu::x64::sse41)) {
// nspc
if (getInputShapeAtPort(0).getRank() == 4 || getInputShapeAtPort(0).getRank() == 5) {
pushDesc(LayoutType::nspc, impl_type);
}
// blk
if (impl_desc_type::jit_avx512 == impl_type) {
if (getInputShapeAtPort(0).getRank() == 4 || getInputShapeAtPort(0).getRank() == 5) {
pushDesc(LayoutType::nCsp16c, impl_type);
}
} else if (impl_desc_type::jit_avx2 == impl_type || impl_desc_type::jit_sse42 == impl_type) {
if (getInputShapeAtPort(0).getRank() == 4 || getInputShapeAtPort(0).getRank() == 5) {
pushDesc(LayoutType::nCsp8c, impl_type);
}
}
}
pushDesc(LayoutType::ncsp, impl_type);
}

ExecutorPtr MVN::createExecutor() {
Expand All @@ -335,23 +298,8 @@ void MVN::prepareParams() {
if (getSelectedPrimitiveDescriptor() == nullptr)
OPENVINO_THROW("Preferable primitive descriptor is not set.");

const VectorDims in_dims = srcMemPtr->getStaticDims();
transformTo5DCase(in_dims);

#if defined(OPENVINO_ARCH_X86_64)
// New shape5D always need prepare via transformTo5DCase(), which is need in exec().
// MVN itself and unary post ops is totally shape agnostic, execPtr can be reused directly w/o recompilation and setPostOps when shape is changed.
// As key have not shape, if shape changes and new post ops attr is also the same, execPtr can still hit.
// If new shape(channel changes) impact post ops attr, such as entry.quantization.offset, entry.depthwise.offset, entry.quantization.per_channel,
// which is participate in compilation, even postOpsData is passed in runtime, still need recompilation.
if (executor != nullptr && (fusedWith.empty() || onlyUnaryPostOps)) {
return;
}
#endif
transformTo5DCase(srcMemPtr->getStaticDims());

auto selectedPD = getSelectedPrimitiveDescriptor();
mvnAttrs.src_prc = selectedPD->getConfig().inConfs[0].getMemDesc()->getPrecision();
mvnAttrs.dst_prc = selectedPD->getConfig().outConfs[0].getMemDesc()->getPrecision();
if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::ncsp)) {
mvnAttrs.layout = MVNLayoutType::mvn_planar;
} else if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc)) {
Expand Down

0 comments on commit f4255d9

Please sign in to comment.