Showing 16 changed files with 517 additions and 399 deletions.
src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.hpp (49 changes: 16 additions & 33 deletions)
@@ -1,51 +1,34 @@
-// Copyright (C) 2023 Intel Corporation
+// Copyright (C) 2018-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //

 #pragma once

-#include "acl_utils.hpp"
-#include "nodes/executors/mvn.hpp"
-#include "arm_compute/runtime/NEON/NEFunctions.h"
-#include "utils/debug_capabilities.h"
+#include "acl_common_executor.hpp"
+#include "nodes/executors/mvn_config.hpp"

 namespace ov {
 namespace intel_cpu {

-class AclMVNExecutor : public MVNExecutor {
+class ACLMVNExecutor : public ACLCommonExecutor {
 public:
-    AclMVNExecutor(const ExecutorContext::CPtr context);
+    ACLMVNExecutor(const MVNAttrs& attrs,
+                   const PostOps& postOps,
+                   const MemoryArgs& memory,
+                   const ExecutorContext::CPtr context) : aclMVNAtrrs(attrs) {}

-    bool init(const MVNAttrs& mvnAttrs,
-              const std::vector<MemoryDescPtr>& srcDescs,
-              const std::vector<MemoryDescPtr>& dstDescs,
-              const dnnl::primitive_attr &attr) override;
-    void exec(const std::vector<MemoryCPtr>& src,
-              const std::vector<MemoryPtr>& dst,
-              const void *post_ops_data_) override;
+    static bool supports(const MVNConfig& config);

-    impl_desc_type getImplType() const override {
-        return implType;
-    }
+    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override;

 private:
-    impl_desc_type implType = impl_desc_type::acl;
+    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override;

-    arm_compute::Tensor srcTensor;
-    arm_compute::Tensor dstTensor;
-    std::unique_ptr<arm_compute::NEMeanStdDevNormalizationLayer> mvn = nullptr;
-};
+    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override;

-class AclMVNExecutorBuilder : public MVNExecutorBuilder {
-public:
-    bool isSupported(const MVNAttrs& mvnAttrs,
-                     const std::vector<MemoryDescPtr>& srcDescs,
-                     const std::vector<MemoryDescPtr>& dstDescs) const override;
-
-    MVNExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const override {
-        return std::make_shared<AclMVNExecutor>(context);
-    }
+    MVNAttrs aclMVNAtrrs;
 };

+using ACLMVNExecutorPtr = std::shared_ptr<ACLMVNExecutor>;
 }  // namespace intel_cpu
 }  // namespace ov
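For context, the ACL primitive behind this executor is arm_compute::NEMeanStdDevNormalizationLayer, the member removed in the hunk above. The following is a minimal standalone sketch of that layer's configure/run flow, independent of the OpenVINO executor interfaces; the tensor shape, epsilon value, and main() driver are illustrative assumptions, not code from this commit.

// Illustrative sketch only: exercising Arm Compute Library's
// NEMeanStdDevNormalizationLayer directly, outside of OpenVINO.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/NEFunctions.h"
#include "arm_compute/runtime/Tensor.h"

int main() {
    // 2D FP32 tensor: 8 values per row, 4 rows; the layer normalizes each row
    // to zero mean and unit variance along the innermost axis.
    arm_compute::TensorShape shape(8, 4);
    arm_compute::Tensor src, dst;
    src.allocator()->init(arm_compute::TensorInfo(shape, 1, arm_compute::DataType::F32));
    dst.allocator()->init(arm_compute::TensorInfo(shape, 1, arm_compute::DataType::F32));
    src.allocator()->allocate();
    dst.allocator()->allocate();

    // Epsilon (an assumed value here) is added for numerical stability.
    arm_compute::NEMeanStdDevNormalizationLayer mvn;
    mvn.configure(&src, &dst, 1e-8f);
    mvn.run();
    return 0;
}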
src/plugins/intel_cpu/src/nodes/executors/common/ref_mvn.cpp (120 changes: 120 additions & 0 deletions)
@@ -0,0 +1,120 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ref_mvn.hpp"
#include "openvino/core/parallel.hpp"

void ov::intel_cpu::CommonMVNExecutor::execute(const ov::intel_cpu::MemoryArgs& memory) {
    mvn_ref(reinterpret_cast<uint8_t*>(memory.at(ARG_SRC_0)->getData()),
            reinterpret_cast<uint8_t*>(memory.at(ARG_DST)->getData()),
            refMVNAttrs.shape5D);
}

bool ov::intel_cpu::CommonMVNExecutor::update(const ov::intel_cpu::MemoryArgs& memory) {
    return true;
}

bool ov::intel_cpu::CommonMVNExecutor::supports(const ov::intel_cpu::MVNConfig& config) {
    return true;
}

void ov::intel_cpu::CommonMVNExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data, const VectorDims& shape5d) {
    const float* src_data_ptr = reinterpret_cast<const float*>(src_data);
    float* dst_data_ptr = reinterpret_cast<float*>(dst_data);
    const size_t N = shape5d[0];
    const size_t C = shape5d[1];
    const size_t D = shape5d[2];
    const size_t H = shape5d[3];
    const size_t W = shape5d[4];

    size_t C1 = H * W;
    size_t C2 = C1 * D;
    size_t C3 = C2 * C;

    parallel_for(N, [&](int b) {
        size_t cb = b * C3;
        if (refMVNAttrs.execAcrossChannels_) {
            // Parallel sum over each channel for the mean
            float C3inv = 1.f / static_cast<float>(C3);
            float mean_temp = 0.0f;

            mean_temp = parallel_sum(C, mean_temp, [&](size_t c) -> float {
                float mean_internal = 0.0f;
                size_t cc = cb + c * C2;
                for (size_t sp = 0lu; sp < C2; sp++) {
                    mean_internal += src_data_ptr[cc + sp];
                }
                return mean_internal;
            });

            float mean = mean_temp * C3inv;

            if (refMVNAttrs.normalizeVariance_) {
                // Parallel sum over each channel for the variance
                float variance_temp = 0.0f;
                variance_temp = parallel_sum(C, variance_temp, [&](size_t c) -> float {
                    float variance_internal = 0.0f;
                    size_t cc = cb + c * C2;
                    for (size_t sp = 0lu; sp < C2; sp++) {
                        variance_internal += (src_data_ptr[cc + sp] - mean) * (src_data_ptr[cc + sp] - mean);
                    }
                    return variance_internal;
                });

                float variance = 1.f;
                if (refMVNAttrs.epsMode_ == INSIDE_SQRT)
                    variance = 1.f / sqrtf(variance_temp * C3inv + refMVNAttrs.epsValue_);
                else if (refMVNAttrs.epsMode_ == OUTSIDE_SQRT)
                    variance = 1.f / (sqrtf(variance_temp * C3inv) + refMVNAttrs.epsValue_);

                parallel_for(C, [&](int c) {
                    size_t cc = cb + c * C2;
                    for (size_t sp = 0lu; sp < C2; sp++) {
                        dst_data_ptr[cc + sp] = (src_data_ptr[cc + sp] - mean) * variance;
                    }
                });
            } else {
                parallel_for(C, [&](int c) {
                    size_t cc = cb + c * C2;
                    for (size_t sp = 0lu; sp < C2; sp++) {
                        dst_data_ptr[cc + sp] = src_data_ptr[cc + sp] - mean;
                    }
                });
            }
        } else {  // per channel
            float C2inv = 1.f / static_cast<float>(C2);
            parallel_for(C, [&](size_t c) {
                // Mean for this channel
                float mean = 0.f;
                size_t cc = cb + c * C2;
                for (size_t sp = 0lu; sp < C2; sp++) {
                    mean += src_data_ptr[cc + sp];
                }
                mean *= C2inv;

                if (refMVNAttrs.normalizeVariance_) {
                    // Variance for this channel
                    float variance = 0.f;
                    for (size_t sp = 0lu; sp < C2; sp++) {
                        variance += (src_data_ptr[cc + sp] - mean) * (src_data_ptr[cc + sp] - mean);
                    }

                    if (refMVNAttrs.epsMode_ == INSIDE_SQRT)
                        variance = 1.f / sqrtf(variance * C2inv + refMVNAttrs.epsValue_);
                    else if (refMVNAttrs.epsMode_ == OUTSIDE_SQRT)
                        variance = 1.f / (sqrtf(variance * C2inv) + refMVNAttrs.epsValue_);

                    // MVN for this channel
                    for (size_t sp = 0lu; sp < C2; sp++) {
                        dst_data_ptr[cc + sp] = (src_data_ptr[cc + sp] - mean) * variance;
                    }
                } else {
                    // MVN for this channel
                    for (size_t sp = 0lu; sp < C2; sp++) {
                        dst_data_ptr[cc + sp] = src_data_ptr[cc + sp] - mean;
                    }
                }
            });
        }
    });
}
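Written out, the normalization mvn_ref applies to each group of K values (K = C2 = D*H*W per channel, or K = C3 = C*D*H*W when execAcrossChannels_ is set) is:

\mu = \frac{1}{K}\sum_{i=1}^{K} x_i, \qquad \sigma^2 = \frac{1}{K}\sum_{i=1}^{K}(x_i - \mu)^2

y_i = \begin{cases} (x_i - \mu)\,/\,\sqrt{\sigma^2 + \varepsilon} & \text{epsMode\_ = INSIDE\_SQRT} \\ (x_i - \mu)\,/\,(\sqrt{\sigma^2} + \varepsilon) & \text{epsMode\_ = OUTSIDE\_SQRT} \\ x_i - \mu & \text{normalizeVariance\_ = false} \end{cases}

The two eps modes correspond to adding epsValue_ inside or outside the square root, mirroring the eps_mode attribute of OpenVINO's MVN operation; with normalizeVariance_ disabled, only the mean is subtracted.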
src/plugins/intel_cpu/src/nodes/executors/common/ref_mvn.hpp (37 changes: 37 additions & 0 deletions)
@@ -0,0 +1,37 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <memory>
#include "cpu_memory.h"
#include "nodes/executors/mvn_config.hpp"

namespace ov {
namespace intel_cpu {

class CommonMVNExecutor : public Executor {
public:
    CommonMVNExecutor(const MVNAttrs& attrs,
                      const PostOps& postOps,
                      const MemoryArgs& memory,
                      const ExecutorContext::CPtr context) : refMVNAttrs(attrs) {}

    void execute(const MemoryArgs& memory) override;

    impl_desc_type implType() const override {
        return impl_desc_type::ref;
    }

    // offloads execution data preparation from the exec call
    bool update(const MemoryArgs& memory) override;

    static bool supports(const MVNConfig& config);

private:
    const MVNAttrs& refMVNAttrs;
    void mvn_ref(const uint8_t* in_ptr_, uint8_t* out_ptr_, const VectorDims& shape5d);
};

}  // namespace intel_cpu
}  // namespace ov