Skip to content

Commit

Permalink
fix arm server
Browse files Browse the repository at this point in the history
  • Loading branch information
allnes committed Sep 26, 2024
1 parent 1faac75 commit 8c7c540
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,37 @@ bool ACLMVNExecutor::supports(const MVNConfig &config) {
}

void ACLMVNExecutor::updateTensorsShapes(ACLShapes& aclMemoryShapes) {
arm_compute::TensorShape srcDims;
const auto src_num_dim = aclMemoryShapes[ACLArgs::ACL_SRC_0].num_dimensions();
for (size_t i = 0; i < src_num_dim; i++) {
srcDims.set(i, aclMemoryShapes[ACLArgs::ACL_SRC_0][src_num_dim - i - 1]);
}
const auto srcDims = aclMemoryShapes[ACLArgs::ACL_SRC_0];
const auto srcNumDim = aclMemoryShapes[ACLArgs::ACL_SRC_0].num_dimensions();

size_t X, Y;
if (aclMVNAtrrs.initAcrossChannels_) {
if (srcDims.num_dimensions() >= 2u) {
Y = srcDims[0];
X = srcDims[1];
Y = srcDims[srcNumDim - 1];
X = srcDims[srcNumDim - 2];
for (size_t i = 2; i < srcDims.num_dimensions(); i++) {
X *= srcDims[i];
X *= srcDims[srcNumDim - i - 1];
}
} else {
Y = 1;
X = srcDims[0];
X = srcDims[srcNumDim - 1];
}
} else {
if (srcDims.num_dimensions() > 2u) {
Y = srcDims[0] * srcDims[1];
X = srcDims[2];
Y = srcDims[srcNumDim - 1] * srcDims[srcNumDim - 2];
X = srcDims[srcNumDim - 3];
for (size_t i = 3; i < srcDims.num_dimensions(); i++) {
X *= srcDims[i];
X *= srcDims[srcNumDim - i - 1];
}
} else if (srcDims.num_dimensions() == 2u) {
Y = srcDims[0] * srcDims[1];
Y = srcDims[srcNumDim - 1] * srcDims[srcNumDim - 2];
X = 1;
} else {
Y = srcDims[0];
Y = srcDims[srcNumDim - 1];
X = 1;
}
}

aclMemoryShapes[ACLArgs::ACL_SRC_0].set(0, X);
aclMemoryShapes[ACLArgs::ACL_SRC_0].set(1, Y);
aclMemoryShapes[ACLArgs::ACL_DST].set(0, X);
aclMemoryShapes[ACLArgs::ACL_DST].set(1, Y);
aclMemoryShapes[ACLArgs::ACL_SRC_0] = aclMemoryShapes[ACLArgs::ACL_DST] = arm_compute::TensorShape(X, Y);
}

arm_compute::Status ACLMVNExecutor::validateTensorsInfo(const ACLInfos &aclMemoryInfos) {
Expand Down

0 comments on commit 8c7c540

Please sign in to comment.