diff --git a/.github/workflows/code_style.yml b/.github/workflows/code_style.yml index a70d2641cb57f3..d4da2a16d38923 100644 --- a/.github/workflows/code_style.yml +++ b/.github/workflows/code_style.yml @@ -38,6 +38,36 @@ jobs: level: warning fail_on_error: true + clang-format-aarch64: + runs-on: ubuntu-22.04 + if: ${{ github.repository_owner == 'openvinotoolkit' }} + permissions: + pull-requests: write + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + submodules: 'true' + + - name: Install clang-format-15 and cross-compilation dependencies + run: | + sudo apt update + sudo apt --assume-yes install binutils-aarch64-linux-gnu gcc-aarch64-linux-gnu g++-aarch64-linux-gnu scons clang-format-15 + + # Run cmake with -DENABLE_PROFILING_ITT=ON -DSELECTIVE_BUILD=COLLECT in order to enable codestyle check for ITT collector + - name: CMake configure + run: cmake -DENABLE_CLANG_FORMAT=ON -DENABLE_TESTS=ON -DENABLE_PROFILING_ITT=ON -DSELECTIVE_BUILD=COLLECT -DCMAKE_TOOLCHAIN_FILE=cmake/arm64.toolchain.cmake -B build_arm64 + + - name: Create code style diff + run: cmake --build build_arm64 --target clang_format_fix_all -j8 + + - name: suggester / clang-format + if: startsWith(github.event_name, 'pull_request') + uses: reviewdog/action-suggester@db4abb16fbaabe386831e5addb7be1485d0d63d3 # v1.18.0 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + level: warning + fail_on_error: true + ShellCheck: runs-on: ubuntu-22.04 if: ${{ github.repository_owner == 'openvinotoolkit' }} diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index 7383b8b8f9ddab..43417942e8bc53 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_conversion_emitters.hpp" + #include "emitters/utils.hpp" using namespace dnnl::impl::cpu::aarch64; @@ -25,27 +26,27 @@ namespace aarch64 { // does not distinguish ARMv8.2 with ARMv8.2-A, conversion between f16 and i16 will still use three // instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16). template -inline void jit_convert_emitter::cvt_f16_to_f32(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_f16_to_f32(const TReg& src, const TReg& dst) const { h->fcvtl(dst.s4, src.h4); } template -inline void jit_convert_emitter::cvt_f32_to_f16(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_f32_to_f16(const TReg& src, const TReg& dst) const { h->fcvtn(dst.h4, src.s4); } template -inline void jit_convert_emitter::cvt_f32_to_i32(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_f32_to_i32(const TReg& src, const TReg& dst) const { h->fcvtzs(dst.s, src.s); } template -inline void jit_convert_emitter::cvt_i32_to_f32(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_i32_to_f32(const TReg& src, const TReg& dst) const { h->scvtf(dst.s, src.s); } template -inline void jit_convert_emitter::cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const { +inline void jit_convert_emitter::cvt_i32_to_i16(const TReg& src, const TReg& dst, bool is_saturated) const { if (is_saturated) { h->sqxtn(dst.h4, src.s4); } else { @@ -54,22 +55,25 @@ inline void jit_convert_emitter::cvt_i32_to_i16(const TReg &src, const TReg &dst } template -inline void jit_convert_emitter::cvt_i16_to_i32(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_i16_to_i32(const TReg& src, const TReg& dst) const { h->sxtl(dst.s4, src.h4); } template -inline void jit_convert_emitter::cvt_f16_to_i16(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_f16_to_i16(const TReg& src, const TReg& dst) const { h->fcvtzs(dst.h4, src.h4); } template -inline void jit_convert_emitter::cvt_i16_to_f16(const TReg &src, const TReg &dst) const { +inline void jit_convert_emitter::cvt_i16_to_f16(const TReg& src, const TReg& dst) const { h->scvtf(dst.h4, src.h4); } template -inline void jit_convert_emitter::cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const { +inline void jit_convert_emitter::cvt_i16_to_byte(const TReg& src, + const TReg& dst, + bool is_signed, + bool is_saturated) const { if (is_saturated) { if (is_signed) { h->sqxtn(dst.b8, src.h8); @@ -82,7 +86,7 @@ inline void jit_convert_emitter::cvt_i16_to_byte(const TReg &src, const TReg &ds } template -inline void jit_convert_emitter::cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const { +inline void jit_convert_emitter::cvt_byte_to_i16(const TReg& src, const TReg& dst, bool is_signed) const { if (is_signed) { h->sxtl(dst.h8, src.b8); } else { @@ -91,10 +95,13 @@ inline void jit_convert_emitter::cvt_byte_to_i16(const TReg &src, const TReg &ds } template -void jit_convert_emitter::jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type, +void jit_convert_emitter::jit_convert_process(const TReg& src, + const TReg& dst, + ov::element::Type input_type, + ov::element::Type output_type, bool is_saturated) const { - if (input_type == output_type || (!is_saturated && - one_of(input_type, ov::element::i8, ov::element::u8) && one_of(output_type, ov::element::i8, ov::element::u8))) { + if (input_type == output_type || (!is_saturated && one_of(input_type, ov::element::i8, ov::element::u8) && + one_of(output_type, ov::element::i8, ov::element::u8))) { if (src.getIdx() != dst.getIdx()) { h->mov(dst.b16, src.b16); } @@ -102,119 +109,130 @@ void jit_convert_emitter::jit_convert_process(const TReg &src, const TReg &dst, } switch (output_type) { + case ov::element::f32: + switch (input_type) { + case ov::element::i32: + cvt_i32_to_f32(src, dst); + break; + case ov::element::f16: + cvt_f16_to_f32(src, dst); + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_i32(dst, dst); + cvt_i32_to_f32(dst, dst); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + case ov::element::i32: + switch (input_type) { + case ov::element::f32: + cvt_f32_to_i32(src, dst); + break; + case ov::element::f16: + cvt_f16_to_f32(src, dst); + cvt_f32_to_i32(dst, dst); + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_i32(dst, dst); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + case ov::element::f16: + switch (input_type) { + case ov::element::f32: + cvt_f32_to_f16(src, dst); + break; + case ov::element::i32: + cvt_i32_to_f32(src, dst); + cvt_f32_to_f16(dst, dst); + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_i32(dst, dst); + cvt_i32_to_f32(dst, dst); + cvt_f32_to_f16(dst, dst); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + case ov::element::i8: + case ov::element::u8: + switch (input_type) { case ov::element::f32: - switch (input_type) { - case ov::element::i32: - cvt_i32_to_f32(src, dst); - break; - case ov::element::f16: - cvt_f16_to_f32(src, dst); - break; - case ov::element::i8: - case ov::element::u8: - cvt_byte_to_i16(src, dst, input_type.is_signed()); - cvt_i16_to_i32(dst, dst); - cvt_i32_to_f32(dst, dst); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); - } + cvt_f32_to_i32(src, dst); + cvt_i32_to_i16(dst, dst, is_saturated); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::i32: - switch (input_type) { - case ov::element::f32: - cvt_f32_to_i32(src, dst); - break; - case ov::element::f16: - cvt_f16_to_f32(src, dst); - cvt_f32_to_i32(dst, dst); - break; - case ov::element::i8: - case ov::element::u8: - cvt_byte_to_i16(src, dst, input_type.is_signed()); - cvt_i16_to_i32(dst, dst); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); - } + cvt_i32_to_i16(src, dst, is_saturated); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::f16: - switch (input_type) { - case ov::element::f32: - cvt_f32_to_f16(src, dst); - break; - case ov::element::i32: - cvt_i32_to_f32(src, dst); - cvt_f32_to_f16(dst, dst); - break; - case ov::element::i8: - case ov::element::u8: - cvt_byte_to_i16(src, dst, input_type.is_signed()); - cvt_i16_to_i32(dst, dst); - cvt_i32_to_f32(dst, dst); - cvt_f32_to_f16(dst, dst); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); - } + cvt_f16_to_f32(src, dst); + cvt_f32_to_i32(dst, dst); + cvt_i32_to_i16(dst, dst, is_saturated); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::i8: case ov::element::u8: - switch (input_type) { - case ov::element::f32: - cvt_f32_to_i32(src, dst); - cvt_i32_to_i16(dst, dst, is_saturated); - cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); - break; - case ov::element::i32: - cvt_i32_to_i16(src, dst, is_saturated); - cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); - break; - case ov::element::f16: - cvt_f16_to_f32(src, dst); - cvt_f32_to_i32(dst, dst); - cvt_i32_to_i16(dst, dst, is_saturated); - cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); - break; - case ov::element::i8: - case ov::element::u8: - cvt_byte_to_i16(src, dst, input_type.is_signed()); - cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); - } + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", output_type.get_type_name()); + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", output_type.get_type_name()); } } -jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) { +jit_convert_emitter::jit_convert_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { input_type = node->get_input_element_type(0); output_type = node->get_output_element_type(0); } void jit_convert_emitter::validate_types() const { - OV_CPU_JIT_EMITTER_ASSERT(one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported input type: ", input_type.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(one_of(output_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported output type: ", output_type.get_type_name()); + OV_CPU_JIT_EMITTER_ASSERT( + one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported input type: ", + input_type.get_type_name()); + OV_CPU_JIT_EMITTER_ASSERT( + one_of(output_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported output type: ", + output_type.get_type_name()); } -size_t jit_convert_emitter::get_inputs_count() const { return 1; } +size_t jit_convert_emitter::get_inputs_count() const { + return 1; +} void jit_convert_emitter::emit_data() const { jit_emitter::emit_data(); } -jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa, - const std::shared_ptr& node, ov::element::Type exec_prc) - : jit_convert_emitter(host, host_isa, node, exec_prc) { -} +jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, node, exec_prc) {} -void jit_convert_truncation_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_convert_truncation_emitter::emit_impl(const std::vector& in_idxs, + const std::vector& out_idxs) const { validate_types(); if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_idxs, out_idxs); @@ -224,19 +242,22 @@ void jit_convert_truncation_emitter::emit_impl(const std::vector &in_idx } template -void jit_convert_truncation_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_convert_truncation_emitter::emit_isa(const std::vector& in_idxs, + const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); jit_convert_process(src, dst, input_type, output_type, false); } -jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa, - const std::shared_ptr& node, ov::element::Type exec_prc) - : jit_convert_emitter(host, host_isa, node, exec_prc) { -} +jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, node, exec_prc) {} -void jit_convert_saturation_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_convert_saturation_emitter::emit_impl(const std::vector& in_idxs, + const std::vector& out_idxs) const { validate_types(); if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_idxs, out_idxs); @@ -246,13 +267,14 @@ void jit_convert_saturation_emitter::emit_impl(const std::vector &in_idx } template -void jit_convert_saturation_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_convert_saturation_emitter::emit_isa(const std::vector& in_idxs, + const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); jit_convert_process(src, dst, input_type, output_type, true); } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp index af0310f736e5c9..bc9bb1e5005672 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp @@ -12,8 +12,10 @@ namespace aarch64 { class jit_convert_emitter : public jit_emitter { public: - jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_count() const override; @@ -21,7 +23,10 @@ class jit_convert_emitter : public jit_emitter { void emit_data() const override; void validate_types() const; template - void jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type, + void jit_convert_process(const TReg& src, + const TReg& dst, + ov::element::Type input_type, + ov::element::Type output_type, bool is_saturated) const; ov::element::Type input_type; @@ -29,25 +34,25 @@ class jit_convert_emitter : public jit_emitter { private: template - inline void cvt_f16_to_f32(const TReg &src, const TReg &dst) const; + inline void cvt_f16_to_f32(const TReg& src, const TReg& dst) const; template - inline void cvt_f32_to_f16(const TReg &src, const TReg &dst) const; + inline void cvt_f32_to_f16(const TReg& src, const TReg& dst) const; template - inline void cvt_f32_to_i32(const TReg &src, const TReg &dst) const; + inline void cvt_f32_to_i32(const TReg& src, const TReg& dst) const; template - inline void cvt_i32_to_f32(const TReg &src, const TReg &dst) const; + inline void cvt_i32_to_f32(const TReg& src, const TReg& dst) const; template - inline void cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const; + inline void cvt_i32_to_i16(const TReg& src, const TReg& dst, bool is_saturated) const; template - inline void cvt_i16_to_i32(const TReg &src, const TReg &dst) const; + inline void cvt_i16_to_i32(const TReg& src, const TReg& dst) const; template - inline void cvt_f16_to_i16(const TReg &src, const TReg &dst) const; + inline void cvt_f16_to_i16(const TReg& src, const TReg& dst) const; template - inline void cvt_i16_to_f16(const TReg &src, const TReg &dst) const; + inline void cvt_i16_to_f16(const TReg& src, const TReg& dst) const; template - inline void cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const; + inline void cvt_i16_to_byte(const TReg& src, const TReg& dst, bool is_signed, bool is_saturated) const; template - inline void cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const; + inline void cvt_byte_to_i16(const TReg& src, const TReg& dst, bool is_signed) const; }; // This emitter is covered by specification of "Convert" operation. The implementation uses a "warp-around" conversion. @@ -56,13 +61,15 @@ class jit_convert_emitter : public jit_emitter { // 129 -> -127 class jit_convert_truncation_emitter : public jit_convert_emitter { public: - jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); private: void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; template - void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; + void emit_isa(const std::vector& in_idxs, const std::vector& out_idxs) const; }; // This emitter is covered by the common dnnl behavior. The implementation uses a "saturation" conversion. @@ -71,15 +78,17 @@ class jit_convert_truncation_emitter : public jit_convert_emitter { // 129 -> 127 class jit_convert_saturation_emitter : public jit_convert_emitter { public: - jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); private: void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; template - void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; + void emit_isa(const std::vector& in_idxs, const std::vector& out_idxs) const; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index 05a0e0a2cf6a0e..534470c746f2fe 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -3,11 +3,12 @@ // #include "jit_eltwise_emitters.hpp" -#include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include + #include "common/utils.hpp" #include "emitters/utils.hpp" +#include "transformations/cpu_opset/common/op/swish_cpu.hpp" namespace ov { namespace intel_cpu { @@ -21,34 +22,35 @@ namespace { ov::element::Type get_arithmetic_binary_exec_precision(const std::shared_ptr& n) { std::vector input_precisions; for (const auto& input : n->inputs()) { - input_precisions.push_back( - input.get_source_output().get_element_type()); + input_precisions.push_back(input.get_source_output().get_element_type()); } - assert(std::all_of( - input_precisions.begin(), - input_precisions.end(), - [&input_precisions](const ov::element::Type& precision) {return precision == input_precisions[0]; })); + assert(std::all_of(input_precisions.begin(), + input_precisions.end(), + [&input_precisions](const ov::element::Type& precision) { + return precision == input_precisions[0]; + })); return input_precisions[0]; } -} // namespace +} // namespace /// ABS /// jit_abs_emitter::jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_abs_emitter::jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_abs_emitter::get_inputs_count() const { return 1; } +size_t jit_abs_emitter::get_inputs_count() const { + return 1; +} -void jit_abs_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_abs_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -57,7 +59,7 @@ void jit_abs_emitter::emit_impl(const std::vector &in_vec_idxs, const st } template -void jit_abs_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_abs_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -75,17 +77,18 @@ std::set> jit_abs_emitter::get_supported_precisions(c jit_add_emitter::jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_add_emitter::jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_add_emitter::get_inputs_count() const { return 2; } +size_t jit_add_emitter::get_inputs_count() const { + return 2; +} -void jit_add_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_add_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -94,7 +97,7 @@ void jit_add_emitter::emit_impl(const std::vector &in_vec_idxs, const st } template -void jit_add_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_add_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -113,7 +116,7 @@ std::set> jit_add_emitter::get_supported_precisions(c jit_clamp_emitter::jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { const auto clamp = std::dynamic_pointer_cast(node); if (clamp == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v0::Clamp"); @@ -129,24 +132,31 @@ jit_clamp_emitter::jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator* ho const float min, const float max, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc), - min(min), - max(max) { + : jit_emitter(host, host_isa, exec_prc), + min(min), + max(max) { prepare_table(); } -size_t jit_clamp_emitter::get_inputs_count() const { return 1; } +size_t jit_clamp_emitter::get_inputs_count() const { + return 1; +} -size_t jit_clamp_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_clamp_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_clamp_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_clamp_emitter::get_aux_gprs_count() const { + return 1; +} void jit_clamp_emitter::register_table_entries() { push_arg_entry_of("min", dnnl::impl::float2int(min), true); push_arg_entry_of("max", dnnl::impl::float2int(max), true); } -void jit_clamp_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_clamp_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -155,7 +165,8 @@ void jit_clamp_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_clamp_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_clamp_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -169,24 +180,28 @@ void jit_clamp_emitter::emit_isa(const std::vector &in_vec_idxs, const s h->fmin(dst.s, dst.s, aux.s); } -std::set> jit_clamp_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_clamp_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } /// DIVIDE /// -jit_divide_emitter::jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} +jit_divide_emitter::jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -jit_divide_emitter::jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) {} +jit_divide_emitter::jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_divide_emitter::get_inputs_count() const { return 2; } +size_t jit_divide_emitter::get_inputs_count() const { + return 2; +} -void jit_divide_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_divide_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -195,7 +210,8 @@ void jit_divide_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_divide_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_divide_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -206,35 +222,44 @@ void jit_divide_emitter::emit_isa(const std::vector &in_vec_idxs, const h->uni_fdiv(dst.s, src0.s, src1.s); } -std::set> jit_divide_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_divide_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } /// EQUAL /// -jit_equal_emitter::jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_equal_emitter::jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } -jit_equal_emitter::jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_equal_emitter::jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_equal_emitter::get_inputs_count() const { return 2; } +size_t jit_equal_emitter::get_inputs_count() const { + return 2; +} -size_t jit_equal_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_equal_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_equal_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_equal_emitter::get_aux_gprs_count() const { + return 1; +} -std::set> jit_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_equal_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } -void jit_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_equal_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -243,7 +268,8 @@ void jit_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const } template -void jit_equal_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_equal_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -266,7 +292,7 @@ void jit_equal_emitter::register_table_entries() { jit_elu_emitter::jit_elu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { const auto elu = std::dynamic_pointer_cast(node); if (elu == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v0::Clamp"); @@ -280,12 +306,16 @@ jit_elu_emitter::jit_elu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, jit_elu_emitter::jit_elu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const float alpha, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc), alpha(alpha) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc), + alpha(alpha) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_elu_emitter::get_inputs_count() const { return 1; } +size_t jit_elu_emitter::get_inputs_count() const { + return 1; +} size_t jit_elu_emitter::get_aux_vecs_count() const { return std::max(exp_emitter->get_aux_vecs_count() + 1ull, 2ull); @@ -295,7 +325,7 @@ size_t jit_elu_emitter::get_aux_gprs_count() const { return exp_emitter->get_aux_gprs_count() + 1; } -void jit_elu_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_elu_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -304,7 +334,7 @@ void jit_elu_emitter::emit_impl(const std::vector &in_vec_idxs, const st } template -void jit_elu_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_elu_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -315,11 +345,7 @@ void jit_elu_emitter::emit_isa(const std::vector &in_vec_idxs, const std h->mov(vmm_aux1.b16, vmm_src.b16); // compute exponent - exp_emitter->emit_code( - { vmm_src.getIdx() }, - out_vec_idxs, - aux_vec_idxs, - aux_gpr_idxs); + exp_emitter->emit_code({vmm_src.getIdx()}, out_vec_idxs, aux_vec_idxs, aux_gpr_idxs); // alpha * (exp(x) - 1) const TReg vmm_aux0(aux_vec_idxs[0]); @@ -351,23 +377,30 @@ std::set> jit_elu_emitter::get_supported_precisions(c jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_exp_emitter::get_inputs_count() const { return 1; } +size_t jit_exp_emitter::get_inputs_count() const { + return 1; +} -size_t jit_exp_emitter::get_aux_vecs_count() const { return 4; } +size_t jit_exp_emitter::get_aux_vecs_count() const { + return 4; +} -size_t jit_exp_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_exp_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_exp_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_exp_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -376,7 +409,7 @@ void jit_exp_emitter::emit_impl(const std::vector &in_vec_idxs, const st } template -void jit_exp_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_exp_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (exec_prc_ != ov::element::f32) { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); } @@ -484,17 +517,19 @@ std::set> jit_exp_emitter::get_supported_precisions(c jit_floor_emitter::jit_floor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_floor_emitter::jit_floor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_floor_emitter::get_inputs_count() const { return 1; } +size_t jit_floor_emitter::get_inputs_count() const { + return 1; +} -void jit_floor_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_floor_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -503,7 +538,8 @@ void jit_floor_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_floor_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_floor_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -512,27 +548,32 @@ void jit_floor_emitter::emit_isa(const std::vector &in_vec_idxs, const s h->frintm(dst.s, src.s); } -std::set> jit_floor_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_floor_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } /// FLOOR_MOD /// -jit_floor_mod_emitter::jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_floor_mod_emitter::jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -jit_floor_mod_emitter::jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_floor_mod_emitter::jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc): jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_floor_mod_emitter::get_inputs_count() const { return 2; } +size_t jit_floor_mod_emitter::get_inputs_count() const { + return 2; +} -size_t jit_floor_mod_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_floor_mod_emitter::get_aux_vecs_count() const { + return 1; +} -void jit_floor_mod_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_floor_mod_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -541,7 +582,8 @@ void jit_floor_mod_emitter::emit_impl(const std::vector &in_vec_idxs, co } template -void jit_floor_mod_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_floor_mod_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -557,29 +599,32 @@ void jit_floor_mod_emitter::emit_isa(const std::vector &in_vec_idxs, con h->fsub(r.s, dividend.s, aux.s); } -std::set> jit_floor_mod_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_floor_mod_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } /// CEILING /// -//Initialization of the emitter, taking node as input +// Initialization of the emitter, taking node as input jit_ceiling_emitter::jit_ceiling_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -//Initialization of emitter, without taking node as input +// Initialization of emitter, without taking node as input jit_ceiling_emitter::jit_ceiling_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -//This will tell the JIT compiler that how many inputs the ceiling operation requires (here 1) -size_t jit_ceiling_emitter::get_inputs_count() const { return 1; } +// This will tell the JIT compiler that how many inputs the ceiling operation requires (here 1) +size_t jit_ceiling_emitter::get_inputs_count() const { + return 1; +} -//Main implementation method that emits the JIT code -void jit_ceiling_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +// Main implementation method that emits the JIT code +void jit_ceiling_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -590,7 +635,8 @@ void jit_ceiling_emitter::emit_impl(const std::vector &in_vec_idxs, cons // Template method that generates actual instruction sequence for ceiling operation // The h->frintp() method rounds up the floating value to the nearest integer. template -void jit_ceiling_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_ceiling_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -601,7 +647,8 @@ void jit_ceiling_emitter::emit_isa(const std::vector &in_vec_idxs, const // Template method that generates actual instruction sequence for ceiling operation // Currently only supports 32-bit floating point (f32) -std::set> jit_ceiling_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_ceiling_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } @@ -609,19 +656,22 @@ std::set> jit_ceiling_emitter::get_supported_precisio jit_gelu_erf_emitter::jit_gelu_erf_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, node); } jit_gelu_erf_emitter::jit_gelu_erf_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_gelu_erf_emitter::get_inputs_count() const { return 1; } +size_t jit_gelu_erf_emitter::get_inputs_count() const { + return 1; +} size_t jit_gelu_erf_emitter::get_aux_vecs_count() const { return std::max(exp_emitter->get_aux_vecs_count() + 3, 7); @@ -631,7 +681,8 @@ size_t jit_gelu_erf_emitter::get_aux_gprs_count() const { return exp_emitter->get_aux_gprs_count() + 1; } -void jit_gelu_erf_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_gelu_erf_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -640,7 +691,8 @@ void jit_gelu_erf_emitter::emit_impl(const std::vector &in_vec_idxs, con } template -void jit_gelu_erf_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_gelu_erf_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -673,11 +725,7 @@ void jit_gelu_erf_emitter::emit_isa(const std::vector &in_vec_idxs, cons h->fmul(vmm_aux.s, vmm_aux0.s, vmm_aux0.s); h->ld1r(vmm_aux2.s, table_val2("sign_mask")); h->orr(vmm_aux.b16, vmm_aux.b16, vmm_aux2.b16); - exp_emitter->emit_code( - { vmm_aux.getIdx() }, - { vmm_aux_dst.getIdx() }, - aux_vec_idxs, - aux_gpr_idxs); + exp_emitter->emit_code({vmm_aux.getIdx()}, {vmm_aux_dst.getIdx()}, aux_vec_idxs, aux_gpr_idxs); h->ld1r(vmm_aux2.s, table_val2("sign_mask")); // vmm_aux_dst = -exp(-x*x) h->orr(vmm_aux_dst.b16, vmm_aux_dst.b16, vmm_aux2.b16); @@ -723,11 +771,11 @@ void jit_gelu_erf_emitter::register_table_entries() { push_arg_entry_of("gelu_erf_one_over_sqrt_two", 0x3f3504f3, true); push_arg_entry_of("gelu_erf_one_over_sqrt_pi", 0x3f106eba, true); - push_arg_entry_of("erf_pol1", 0x3e827906, true); // p1 = 0.254829592f - push_arg_entry_of("erf_pol2", 0xbe91a98e, true); // p2 = -0.284496736f - push_arg_entry_of("erf_pol3", 0x3fb5f0e3, true); // p3 = 1.421413741f - push_arg_entry_of("erf_pol4", 0xbfba00e3, true); // p4 = -1.453152027f - push_arg_entry_of("erf_pol5", 0x3f87dc22, true); // p5 = 1.061405429f + push_arg_entry_of("erf_pol1", 0x3e827906, true); // p1 = 0.254829592f + push_arg_entry_of("erf_pol2", 0xbe91a98e, true); // p2 = -0.284496736f + push_arg_entry_of("erf_pol3", 0x3fb5f0e3, true); // p3 = 1.421413741f + push_arg_entry_of("erf_pol4", 0xbfba00e3, true); // p4 = -1.453152027f + push_arg_entry_of("erf_pol5", 0x3f87dc22, true); // p5 = 1.061405429f } void jit_gelu_erf_emitter::emit_data() const { @@ -735,7 +783,8 @@ void jit_gelu_erf_emitter::emit_data() const { exp_emitter->emit_data(); } -std::set> jit_gelu_erf_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_gelu_erf_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } @@ -743,19 +792,22 @@ std::set> jit_gelu_erf_emitter::get_supported_precisi jit_gelu_tanh_emitter::jit_gelu_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); tanh_emitter = std::make_unique(h, host_isa, node); } jit_gelu_tanh_emitter::jit_gelu_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); tanh_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_gelu_tanh_emitter::get_inputs_count() const { return 1; } +size_t jit_gelu_tanh_emitter::get_inputs_count() const { + return 1; +} size_t jit_gelu_tanh_emitter::get_aux_vecs_count() const { return std::max(tanh_emitter->get_aux_vecs_count() + 2, 3); @@ -765,7 +817,8 @@ size_t jit_gelu_tanh_emitter::get_aux_gprs_count() const { return tanh_emitter->get_aux_gprs_count() + 1; } -void jit_gelu_tanh_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_gelu_tanh_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -774,7 +827,8 @@ void jit_gelu_tanh_emitter::emit_impl(const std::vector &in_vec_idxs, co } template -void jit_gelu_tanh_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_gelu_tanh_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -794,11 +848,7 @@ void jit_gelu_tanh_emitter::emit_isa(const std::vector &in_vec_idxs, con h->ld1r(vmm_aux1.s, table_val2("gelu_tanh_sqrt_two_over_pi")); h->fmul(vmm_aux0.s, vmm_aux1.s, vmm_aux2.s); - tanh_emitter->emit_code( - { vmm_aux0.getIdx() }, - { vmm_aux2.getIdx() }, - aux_vec_idxs, - aux_gpr_idxs); + tanh_emitter->emit_code({vmm_aux0.getIdx()}, {vmm_aux2.getIdx()}, aux_vec_idxs, aux_gpr_idxs); // compute 0.5 * x * (1 + tanh(G(x))) h->ld1r(vmm_aux1.s, table_val2("one")); @@ -821,7 +871,8 @@ void jit_gelu_tanh_emitter::emit_data() const { tanh_emitter->emit_data(); } -std::set> jit_gelu_tanh_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_gelu_tanh_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } @@ -829,7 +880,7 @@ std::set> jit_gelu_tanh_emitter::get_supported_precis jit_greater_emitter::jit_greater_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } @@ -840,13 +891,20 @@ jit_greater_emitter::jit_greater_emitter(dnnl::impl::cpu::aarch64::jit_generator prepare_table(); } -size_t jit_greater_emitter::get_inputs_count() const { return 2; } +size_t jit_greater_emitter::get_inputs_count() const { + return 2; +} -size_t jit_greater_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_greater_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_greater_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_greater_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_greater_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_greater_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -855,7 +913,8 @@ void jit_greater_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_greater_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_greater_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -873,7 +932,8 @@ void jit_greater_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_greater_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_greater_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } @@ -881,7 +941,7 @@ std::set> jit_greater_emitter::get_supported_precisio jit_greater_equal_emitter::jit_greater_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } @@ -892,13 +952,20 @@ jit_greater_equal_emitter::jit_greater_equal_emitter(dnnl::impl::cpu::aarch64::j prepare_table(); } -size_t jit_greater_equal_emitter::get_inputs_count() const { return 2; } +size_t jit_greater_equal_emitter::get_inputs_count() const { + return 2; +} -size_t jit_greater_equal_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_greater_equal_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_greater_equal_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_greater_equal_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_greater_equal_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_greater_equal_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -907,7 +974,8 @@ void jit_greater_equal_emitter::emit_impl(const std::vector &in_vec_idxs } template -void jit_greater_equal_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_greater_equal_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -925,31 +993,40 @@ void jit_greater_equal_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_greater_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_greater_equal_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } /// HARD_SWISH /// jit_hswish_emitter::jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } jit_hswish_emitter::jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_hswish_emitter::get_inputs_count() const { return 1; } +size_t jit_hswish_emitter::get_inputs_count() const { + return 1; +} -size_t jit_hswish_emitter::get_aux_vecs_count() const { return 2; } +size_t jit_hswish_emitter::get_aux_vecs_count() const { + return 2; +} -size_t jit_hswish_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_hswish_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_hswish_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_hswish_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -958,7 +1035,8 @@ void jit_hswish_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_hswish_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_hswish_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -983,18 +1061,19 @@ void jit_hswish_emitter::register_table_entries() { push_arg_entry_of("zero", 0x00000000, true); push_arg_entry_of("three", 0x40400000, true); push_arg_entry_of("six", 0x40c00000, true); - push_arg_entry_of("one_sixth", dnnl::impl::float2int(1.f/6.f), true); + push_arg_entry_of("one_sixth", dnnl::impl::float2int(1.f / 6.f), true); } -std::set> jit_hswish_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_hswish_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } /// IS_FINITE /// jit_is_finite_emitter::jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { auto isNaN = ov::as_type_ptr(node); if (isNaN == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v10::IsNaN"); @@ -1004,23 +1083,31 @@ jit_is_finite_emitter::jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_gener } jit_is_finite_emitter::jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_is_finite_emitter::get_inputs_count() const { return 1; } +size_t jit_is_finite_emitter::get_inputs_count() const { + return 1; +} -size_t jit_is_finite_emitter::get_aux_vecs_count() const { return 2; } +size_t jit_is_finite_emitter::get_aux_vecs_count() const { + return 2; +} -size_t jit_is_finite_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_is_finite_emitter::get_aux_gprs_count() const { + return 1; +} -std::set> jit_is_finite_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_is_finite_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_is_finite_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_is_finite_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1029,7 +1116,8 @@ void jit_is_finite_emitter::emit_impl(const std::vector& in_vec_idxs, co } template -void jit_is_finite_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_finite_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1039,7 +1127,8 @@ void jit_is_finite_emitter::emit_isa(const std::vector &in_vec_idxs, con TReg aux0 = TReg(aux_vec_idxs[0]); TReg aux1 = TReg(aux_vec_idxs[1]); - // According to the IEEE standard, NaN values have the odd property that comparisons involving them are always false. + // According to the IEEE standard, NaN values have the odd property that comparisons involving them are always + // false. h->fcmeq(aux0.s, src.s, src.s); h->not_(aux0.b16, aux0.b16); @@ -1068,7 +1157,6 @@ jit_is_inf_emitter::jit_is_inf_emitter(dnnl::impl::cpu::aarch64::jit_generator* dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { - auto isInf = ov::as_type_ptr(node); if (isInf == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v10::IsInf"); @@ -1163,9 +1251,9 @@ void jit_is_inf_emitter::register_table_entries() { /// IS_NAN /// jit_is_nan_emitter::jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { auto isNaN = ov::as_type_ptr(node); if (isNaN == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v10::IsNaN"); @@ -1175,23 +1263,31 @@ jit_is_nan_emitter::jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* } jit_is_nan_emitter::jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_is_nan_emitter::get_inputs_count() const { return 1; } +size_t jit_is_nan_emitter::get_inputs_count() const { + return 1; +} -size_t jit_is_nan_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_is_nan_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_is_nan_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_is_nan_emitter::get_aux_gprs_count() const { + return 1; +} -std::set> jit_is_nan_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_is_nan_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_is_nan_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_is_nan_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1200,7 +1296,8 @@ void jit_is_nan_emitter::emit_impl(const std::vector& in_vec_idxs, const } template -void jit_is_nan_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_nan_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1209,7 +1306,8 @@ void jit_is_nan_emitter::emit_isa(const std::vector &in_vec_idxs, const TReg dst = TReg(out_vec_idxs[0]); TReg aux = TReg(aux_vec_idxs[0]); - // According to the IEEE standard, NaN values have the odd property that comparisons involving them are always false. + // According to the IEEE standard, NaN values have the odd property that comparisons involving them are always + // false. h->fcmeq(dst.s, src.s, src.s); h->ld1r(aux.s, table_val2("zero")); h->fcmeq(dst.s, dst.s, aux.s); @@ -1228,7 +1326,7 @@ void jit_is_nan_emitter::register_table_entries() { jit_less_equal_emitter::jit_less_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } @@ -1239,13 +1337,20 @@ jit_less_equal_emitter::jit_less_equal_emitter(dnnl::impl::cpu::aarch64::jit_gen prepare_table(); } -size_t jit_less_equal_emitter::get_inputs_count() const { return 2; } +size_t jit_less_equal_emitter::get_inputs_count() const { + return 2; +} -size_t jit_less_equal_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_less_equal_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_less_equal_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_less_equal_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_less_equal_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_less_equal_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1254,7 +1359,8 @@ void jit_less_equal_emitter::emit_impl(const std::vector &in_vec_idxs, c } template -void jit_less_equal_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_less_equal_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1273,7 +1379,8 @@ void jit_less_equal_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_less_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_less_equal_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } @@ -1281,7 +1388,7 @@ std::set> jit_less_equal_emitter::get_supported_preci jit_logical_and_emitter::jit_logical_and_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } @@ -1292,13 +1399,20 @@ jit_logical_and_emitter::jit_logical_and_emitter(dnnl::impl::cpu::aarch64::jit_g prepare_table(); } -size_t jit_logical_and_emitter::get_inputs_count() const { return 2; } +size_t jit_logical_and_emitter::get_inputs_count() const { + return 2; +} -size_t jit_logical_and_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_logical_and_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_logical_and_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_logical_and_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_logical_and_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_logical_and_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1307,7 +1421,8 @@ void jit_logical_and_emitter::emit_impl(const std::vector &in_vec_idxs, } template -void jit_logical_and_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_logical_and_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1325,32 +1440,40 @@ void jit_logical_and_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_logical_and_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_logical_and_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } /// LOGICAL_OR /// jit_logical_or_emitter::jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } jit_logical_or_emitter::jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_logical_or_emitter::get_inputs_count() const { return 2; } +size_t jit_logical_or_emitter::get_inputs_count() const { + return 2; +} -size_t jit_logical_or_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_logical_or_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_logical_or_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_logical_or_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_logical_or_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_logical_or_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1359,7 +1482,8 @@ void jit_logical_or_emitter::emit_impl(const std::vector &in_vec_idxs, c } template -void jit_logical_or_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_logical_or_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1377,24 +1501,25 @@ void jit_logical_or_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_logical_or_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_logical_or_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } /// LOGICAL_NOT /// jit_logical_not_emitter::jit_logical_not_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { - prepare_table(); - } + prepare_table(); +} jit_logical_not_emitter::jit_logical_not_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { - prepare_table(); - } + prepare_table(); +} size_t jit_logical_not_emitter::get_inputs_count() const { return 1; @@ -1409,7 +1534,7 @@ size_t jit_logical_not_emitter::get_aux_gprs_count() const { } void jit_logical_not_emitter::emit_impl(const std::vector& in_vec_idxs, - const std::vector& out_vec_idxs) const { + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1419,7 +1544,7 @@ void jit_logical_not_emitter::emit_impl(const std::vector& in_vec_idxs, template void jit_logical_not_emitter::emit_isa(const std::vector& in_vec_idxs, - const std::vector& out_vec_idxs) const { + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1447,7 +1572,7 @@ std::set> jit_logical_not_emitter::get_supported_prec jit_logical_xor_emitter::jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } @@ -1458,13 +1583,20 @@ jit_logical_xor_emitter::jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_g prepare_table(); } -size_t jit_logical_xor_emitter::get_inputs_count() const { return 2; } +size_t jit_logical_xor_emitter::get_inputs_count() const { + return 2; +} -size_t jit_logical_xor_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_logical_xor_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_logical_xor_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_logical_xor_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_logical_xor_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_logical_xor_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1473,7 +1605,8 @@ void jit_logical_xor_emitter::emit_impl(const std::vector &in_vec_idxs, } template -void jit_logical_xor_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_logical_xor_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1491,7 +1624,8 @@ void jit_logical_xor_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_logical_xor_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_logical_xor_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } @@ -1499,17 +1633,19 @@ std::set> jit_logical_xor_emitter::get_supported_prec jit_maximum_emitter::jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_maximum_emitter::jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_maximum_emitter::get_inputs_count() const { return 2; } +size_t jit_maximum_emitter::get_inputs_count() const { + return 2; +} -void jit_maximum_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_maximum_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1518,7 +1654,8 @@ void jit_maximum_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_maximum_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_maximum_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1529,7 +1666,8 @@ void jit_maximum_emitter::emit_isa(const std::vector &in_vec_idxs, const h->fmaxnm(dst.s, src1.s, src2.s); } -std::set> jit_maximum_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_maximum_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } @@ -1537,17 +1675,19 @@ std::set> jit_maximum_emitter::get_supported_precisio jit_minimum_emitter::jit_minimum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_minimum_emitter::jit_minimum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_minimum_emitter::get_inputs_count() const { return 2; } +size_t jit_minimum_emitter::get_inputs_count() const { + return 2; +} -void jit_minimum_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_minimum_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1556,7 +1696,8 @@ void jit_minimum_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_minimum_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_minimum_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1567,7 +1708,8 @@ void jit_minimum_emitter::emit_isa(const std::vector &in_vec_idxs, const h->fminnm(dst.s, src1.s, src2.s); } -std::set> jit_minimum_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_minimum_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } @@ -1575,19 +1717,22 @@ std::set> jit_minimum_emitter::get_supported_precisio jit_mish_emitter::jit_mish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, node); } jit_mish_emitter::jit_mish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_mish_emitter::get_inputs_count() const { return 1; } +size_t jit_mish_emitter::get_inputs_count() const { + return 1; +} size_t jit_mish_emitter::get_aux_vecs_count() const { return std::max(exp_emitter->get_aux_vecs_count() + 1, 2); @@ -1597,7 +1742,8 @@ size_t jit_mish_emitter::get_aux_gprs_count() const { return exp_emitter->get_aux_gprs_count() + 1; } -void jit_mish_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mish_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1606,7 +1752,7 @@ void jit_mish_emitter::emit_impl(const std::vector &in_vec_idxs, const s } template -void jit_mish_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mish_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); // An equation other than mish(x) = x*tanh(srelu(x)) was used @@ -1628,11 +1774,7 @@ void jit_mish_emitter::emit_isa(const std::vector &in_vec_idxs, const st h->ld1r(vmm_aux0.s, table_val2("fwd_mish_max_x_for_equation_f")); h->fminnm(vmm_aux2.s, vmm_src.s, vmm_aux0.s); - exp_emitter->emit_code( - { vmm_aux2.getIdx() }, - { vmm_aux2.getIdx() }, - aux_vec_idxs, - aux_gpr_idxs); + exp_emitter->emit_code({vmm_aux2.getIdx()}, {vmm_aux2.getIdx()}, aux_vec_idxs, aux_gpr_idxs); // (e^x+1)^2 h->fmov(vmm_aux0.s, 1.f); @@ -1665,22 +1807,25 @@ std::set> jit_mish_emitter::get_supported_precisions( } /// MOD /// -jit_mod_emitter::jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_mod_emitter::jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -jit_mod_emitter::jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc): jit_emitter(host, host_isa, exec_prc) { -} +jit_mod_emitter::jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_mod_emitter::get_inputs_count() const { return 2; } +size_t jit_mod_emitter::get_inputs_count() const { + return 2; +} -size_t jit_mod_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_mod_emitter::get_aux_vecs_count() const { + return 1; +} -void jit_mod_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mod_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1689,7 +1834,7 @@ void jit_mod_emitter::emit_impl(const std::vector &in_vec_idxs, const st } template -void jit_mod_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mod_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1713,20 +1858,23 @@ std::set> jit_mod_emitter::get_supported_precisions(c jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { -} + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_mul_add_emitter::get_inputs_count() const { return 3; } +size_t jit_mul_add_emitter::get_inputs_count() const { + return 3; +} -size_t jit_mul_add_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_mul_add_emitter::get_aux_vecs_count() const { + return 1; +} -void jit_mul_add_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mul_add_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1735,7 +1883,8 @@ void jit_mul_add_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_mul_add_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mul_add_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1765,24 +1914,28 @@ void jit_mul_add_emitter::emit_isa(const std::vector &in_vec_idxs, const h->fmla(dst.s, mul0.s, mul1.s); } -std::set> jit_mul_add_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_mul_add_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32, element::f32}}; } /// MULTIPLY /// -jit_multiply_emitter::jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_multiply_emitter::jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -jit_multiply_emitter::jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_multiply_emitter::jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) {} + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_multiply_emitter::get_inputs_count() const { return 2; } +size_t jit_multiply_emitter::get_inputs_count() const { + return 2; +} -void jit_multiply_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_multiply_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1791,7 +1944,8 @@ void jit_multiply_emitter::emit_impl(const std::vector &in_vec_idxs, con } template -void jit_multiply_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_multiply_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1802,16 +1956,17 @@ void jit_multiply_emitter::emit_isa(const std::vector &in_vec_idxs, cons h->uni_fmul(dst.s, src0.s, src1.s); } -std::set> jit_multiply_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_multiply_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } /// POWER /// -jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, node, exec_prc) { + : jit_emitter(host, host_isa, node, exec_prc) { auto powerStaticNode = ov::as_type_ptr(node); if (powerStaticNode == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to snippets::op::PowerStatic"); @@ -1824,24 +1979,30 @@ jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit prepare_table(); } -jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const float power, const float scale, const float shift, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc), - power(power), - scale(scale), - shift(shift) { + : jit_emitter(host, host_isa, exec_prc), + power(power), + scale(scale), + shift(shift) { prepare_table(); } -size_t jit_power_static_emitter::get_inputs_count() const { return 1; } +size_t jit_power_static_emitter::get_inputs_count() const { + return 1; +} -size_t jit_power_static_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_power_static_emitter::get_aux_vecs_count() const { + return 1; +} -size_t jit_power_static_emitter::get_aux_gprs_count() const { return 2; } +size_t jit_power_static_emitter::get_aux_gprs_count() const { + return 2; +} void jit_power_static_emitter::register_table_entries() { push_arg_entry_of("power", dnnl::impl::float2int(power), true); @@ -1849,11 +2010,13 @@ void jit_power_static_emitter::register_table_entries() { push_arg_entry_of("shift", dnnl::impl::float2int(shift), true); } -std::set> jit_power_static_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_power_static_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_power_static_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_power_static_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1862,7 +2025,8 @@ void jit_power_static_emitter::emit_impl(const std::vector& in_vec_idxs, } template -void jit_power_static_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_power_static_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1945,26 +2109,30 @@ void jit_power_static_emitter::emit_isa(const std::vector &in_vec_idxs, /// PRELU /// jit_prelu_emitter::jit_prelu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_prelu_emitter::jit_prelu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { -} + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_prelu_emitter::get_inputs_count() const { return 2; } +size_t jit_prelu_emitter::get_inputs_count() const { + return 2; +} -size_t jit_prelu_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_prelu_emitter::get_aux_vecs_count() const { + return 1; +} -std::set> jit_prelu_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_prelu_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_prelu_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_prelu_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -1973,7 +2141,8 @@ void jit_prelu_emitter::emit_impl(const std::vector& in_vec_idxs, const } template -void jit_prelu_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_prelu_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -1992,24 +2161,27 @@ void jit_prelu_emitter::emit_isa(const std::vector &in_vec_idxs, const s jit_relu_emitter::jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_relu_emitter::jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { -} + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_relu_emitter::get_inputs_count() const { return 1; } +size_t jit_relu_emitter::get_inputs_count() const { + return 1; +} -size_t jit_relu_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_relu_emitter::get_aux_vecs_count() const { + return 1; +} std::set> jit_relu_emitter::get_supported_precisions(const std::shared_ptr& node) { return {{element::f32}}; } -void jit_relu_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_relu_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2018,7 +2190,7 @@ void jit_relu_emitter::emit_impl(const std::vector& in_vec_idxs, const s } template -void jit_relu_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_relu_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2032,27 +2204,29 @@ void jit_relu_emitter::emit_isa(const std::vector &in_vec_idxs, const st } /// ROUND_HALF_AWAY_FROM_ZERO /// -jit_round_half_away_from_zero_emitter::jit_round_half_away_from_zero_emitter - (dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} - -jit_round_half_away_from_zero_emitter::jit_round_half_away_from_zero_emitter - (dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { +jit_round_half_away_from_zero_emitter::jit_round_half_away_from_zero_emitter( + dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} + +jit_round_half_away_from_zero_emitter::jit_round_half_away_from_zero_emitter( + dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} + +size_t jit_round_half_away_from_zero_emitter::get_inputs_count() const { + return 1; } -size_t jit_round_half_away_from_zero_emitter::get_inputs_count() const { return 1; } - -std::set> jit_round_half_away_from_zero_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_round_half_away_from_zero_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_round_half_away_from_zero_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_round_half_away_from_zero_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2061,7 +2235,8 @@ void jit_round_half_away_from_zero_emitter::emit_impl(const std::vector& } template -void jit_round_half_away_from_zero_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_round_half_away_from_zero_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2073,27 +2248,27 @@ void jit_round_half_away_from_zero_emitter::emit_isa(const std::vector & } /// ROUND_HALF_TO_EVEN /// -jit_round_half_to_even_emitter::jit_round_half_to_even_emitter - (dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} +jit_round_half_to_even_emitter::jit_round_half_to_even_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} -jit_round_half_to_even_emitter::jit_round_half_to_even_emitter - (dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { -} +jit_round_half_to_even_emitter::jit_round_half_to_even_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_round_half_to_even_emitter::get_inputs_count() const { return 1; } +size_t jit_round_half_to_even_emitter::get_inputs_count() const { + return 1; +} -std::set> jit_round_half_to_even_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_round_half_to_even_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_round_half_to_even_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_round_half_to_even_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2102,7 +2277,8 @@ void jit_round_half_to_even_emitter::emit_impl(const std::vector& in_vec } template -void jit_round_half_to_even_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_round_half_to_even_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2114,26 +2290,30 @@ void jit_round_half_to_even_emitter::emit_isa(const std::vector &in_vec_ } /// SELECT /// -jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { -} -jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { -} + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_select_emitter::get_inputs_count() const { return 3; } +size_t jit_select_emitter::get_inputs_count() const { + return 3; +} -size_t jit_select_emitter::get_aux_vecs_count() const { return 1; } +size_t jit_select_emitter::get_aux_vecs_count() const { + return 1; +} -std::set> jit_select_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_select_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32, element::f32}}; } -void jit_select_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_select_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2142,7 +2322,8 @@ void jit_select_emitter::emit_impl(const std::vector& in_vec_idxs, const } template -void jit_select_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_select_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2163,19 +2344,22 @@ void jit_select_emitter::emit_isa(const std::vector &in_vec_idxs, const jit_sigmoid_emitter::jit_sigmoid_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, node); } jit_sigmoid_emitter::jit_sigmoid_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); exp_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_sigmoid_emitter::get_inputs_count() const { return 1; } +size_t jit_sigmoid_emitter::get_inputs_count() const { + return 1; +} size_t jit_sigmoid_emitter::get_aux_vecs_count() const { return exp_emitter->get_aux_vecs_count() + 2; @@ -2185,7 +2369,8 @@ size_t jit_sigmoid_emitter::get_aux_gprs_count() const { return exp_emitter->get_aux_gprs_count() + 1; } -void jit_sigmoid_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_sigmoid_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2194,7 +2379,8 @@ void jit_sigmoid_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_sigmoid_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_sigmoid_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (exec_prc_ != ov::element::f32) { OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string()); } @@ -2217,11 +2403,7 @@ void jit_sigmoid_emitter::emit_isa(const std::vector &in_vec_idxs, const h->ld1r(vmm_aux0.s, table_val2("sign_mask")); h->orr(vmm_aux0.b16, vmm_src.b16, vmm_aux0.b16); - exp_emitter->emit_code( - { vmm_aux0.getIdx() }, - out_vec_idxs, - aux_vec_idxs, - aux_gpr_idxs); + exp_emitter->emit_code({vmm_aux0.getIdx()}, out_vec_idxs, aux_vec_idxs, aux_gpr_idxs); const TReg vmm_aux1(aux_vec_idxs[0]); const TReg vmm_aux2(aux_vec_idxs[1]); @@ -2251,7 +2433,8 @@ void jit_sigmoid_emitter::emit_data() const { exp_emitter->emit_data(); } -std::set> jit_sigmoid_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_sigmoid_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } @@ -2259,23 +2442,31 @@ std::set> jit_sigmoid_emitter::get_supported_precisio jit_soft_sign_emitter::jit_soft_sign_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); } jit_soft_sign_emitter::jit_soft_sign_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_soft_sign_emitter::get_inputs_count() const { return 1; } +size_t jit_soft_sign_emitter::get_inputs_count() const { + return 1; +} -size_t jit_soft_sign_emitter::get_aux_vecs_count() const { return 2; } +size_t jit_soft_sign_emitter::get_aux_vecs_count() const { + return 2; +} -size_t jit_soft_sign_emitter::get_aux_gprs_count() const { return 1; } +size_t jit_soft_sign_emitter::get_aux_gprs_count() const { + return 1; +} -void jit_soft_sign_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_soft_sign_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2284,7 +2475,8 @@ void jit_soft_sign_emitter::emit_impl(const std::vector &in_vec_idxs, co } template -void jit_soft_sign_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_soft_sign_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (exec_prc_ != ov::element::f32) { OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string()); } @@ -2305,7 +2497,8 @@ void jit_soft_sign_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } -std::set> jit_soft_sign_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_soft_sign_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } @@ -2314,15 +2507,15 @@ jit_sqrt_emitter::jit_sqrt_emitter(dnnl::impl::cpu::aarch64::jit_generator* host dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { - prepare_table(); - } + prepare_table(); +} jit_sqrt_emitter::jit_sqrt_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { - prepare_table(); - } + prepare_table(); +} size_t jit_sqrt_emitter::get_inputs_count() const { return 1; @@ -2338,8 +2531,7 @@ void jit_sqrt_emitter::emit_impl(const std::vector& in_vec_idxs, } template -void jit_sqrt_emitter::emit_isa(const std::vector& in_vec_idxs, - const std::vector& out_vec_idxs) const { +void jit_sqrt_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2349,8 +2541,7 @@ void jit_sqrt_emitter::emit_isa(const std::vector& in_vec_idxs, h->fsqrt(dst.s, src.s); } -std::set> jit_sqrt_emitter::get_supported_precisions( - const std::shared_ptr& node) { +std::set> jit_sqrt_emitter::get_supported_precisions(const std::shared_ptr& node) { return {{element::f32}}; } @@ -2358,17 +2549,19 @@ std::set> jit_sqrt_emitter::get_supported_precisions( jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { -} + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { -} + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_subtract_emitter::get_inputs_count() const { return 2; } +size_t jit_subtract_emitter::get_inputs_count() const { + return 2; +} -void jit_subtract_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_subtract_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2377,7 +2570,8 @@ void jit_subtract_emitter::emit_impl(const std::vector &in_vec_idxs, con } template -void jit_subtract_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_subtract_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2388,7 +2582,8 @@ void jit_subtract_emitter::emit_isa(const std::vector &in_vec_idxs, cons h->uni_fsub(dst.s, src0.s, src1.s); } -std::set> jit_subtract_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_subtract_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } @@ -2396,7 +2591,7 @@ std::set> jit_subtract_emitter::get_supported_precisi jit_swish_emitter::jit_swish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { const auto swish = std::dynamic_pointer_cast(node); if (swish == nullptr) { OV_CPU_JIT_EMITTER_THROW("Can't cast to SwishNode"); @@ -2411,12 +2606,15 @@ jit_swish_emitter::jit_swish_emitter(dnnl::impl::cpu::aarch64::jit_generator* ho dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const float beta, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc), beta(beta) { + : jit_emitter(host, host_isa, exec_prc), + beta(beta) { prepare_table(); sigmoid_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_swish_emitter::get_inputs_count() const {return 1; } +size_t jit_swish_emitter::get_inputs_count() const { + return 1; +} size_t jit_swish_emitter::get_aux_vecs_count() const { return sigmoid_emitter->get_aux_vecs_count() + 2; @@ -2426,7 +2624,8 @@ size_t jit_swish_emitter::get_aux_gprs_count() const { return sigmoid_emitter->get_aux_gprs_count() + 1; } -void jit_swish_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_swish_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2435,7 +2634,8 @@ void jit_swish_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_swish_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_swish_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2451,11 +2651,7 @@ void jit_swish_emitter::emit_isa(const std::vector &in_vec_idxs, const s h->fmul(vmm_aux.s, vmm_aux.s, vmm_src.s); // sigmoid(x*beta) - sigmoid_emitter->emit_code( - { vmm_aux.getIdx() }, - out_vec_idxs, - aux_vec_idxs, - aux_gpr_idxs); + sigmoid_emitter->emit_code({vmm_aux.getIdx()}, out_vec_idxs, aux_vec_idxs, aux_gpr_idxs); // x*sigmoid(x*beta) h->fmul(vmm_dst.s, vmm_dst.s, vmm_orig_src.s); @@ -2470,28 +2666,31 @@ void jit_swish_emitter::emit_data() const { sigmoid_emitter->emit_data(); } -std::set> jit_swish_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_swish_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } /// TANH /// -jit_tanh_emitter::jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_tanh_emitter::jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) - : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { prepare_table(); sigmoid_emitter = std::make_unique(h, host_isa, node); } -jit_tanh_emitter::jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, +jit_tanh_emitter::jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc) { + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); sigmoid_emitter = std::make_unique(h, host_isa, exec_prc); } -size_t jit_tanh_emitter::get_inputs_count() const { return 1; } +size_t jit_tanh_emitter::get_inputs_count() const { + return 1; +} size_t jit_tanh_emitter::get_aux_vecs_count() const { return sigmoid_emitter->get_aux_vecs_count() + 1; @@ -2501,7 +2700,8 @@ size_t jit_tanh_emitter::get_aux_gprs_count() const { return sigmoid_emitter->get_aux_gprs_count() + 1; } -void jit_tanh_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_tanh_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_vec_idxs, out_vec_idxs); } else { @@ -2510,7 +2710,7 @@ void jit_tanh_emitter::emit_impl(const std::vector &in_vec_idxs, const s } template -void jit_tanh_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_tanh_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -2522,11 +2722,7 @@ void jit_tanh_emitter::emit_isa(const std::vector &in_vec_idxs, const st h->ld1r(aux.s, table_val2("two")); h->uni_fmul(aux.s, src.s, aux.s); - sigmoid_emitter->emit_code( - { aux.getIdx() }, - out_vec_idxs, - aux_vec_idxs, - aux_gpr_idxs); + sigmoid_emitter->emit_code({aux.getIdx()}, out_vec_idxs, aux_vec_idxs, aux_gpr_idxs); h->ld1r(aux.s, table_val2("two")); h->uni_fmul(dst.s, aux.s, dst.s); @@ -2548,6 +2744,6 @@ std::set> jit_tanh_emitter::get_supported_precisions( return {{element::f32}}; } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp index be4e51cd0b759d..13567b6fbf7d64 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp @@ -12,55 +12,57 @@ namespace aarch64 { class jit_abs_emitter : public jit_emitter { public: - jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_add_emitter : public jit_emitter { public: - jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_clamp_emitter : public jit_emitter { public: - jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const float min, const float max, const ov::element::Type exec_prc = ov::element::f32); - jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -72,46 +74,48 @@ class jit_clamp_emitter : public jit_emitter { void register_table_entries() override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: float min; float max; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_divide_emitter : public jit_emitter { public: - jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc = ov::element::f32); + jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); - jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node); + jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_equal_emitter : public jit_emitter { public: - jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -122,13 +126,13 @@ class jit_equal_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; @@ -151,13 +155,14 @@ class jit_exp_emitter : public jit_emitter { void register_table_entries() override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_elu_emitter : public jit_emitter { @@ -181,45 +186,47 @@ class jit_elu_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr exp_emitter; float alpha; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_floor_emitter : public jit_emitter { public: - jit_floor_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_floor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_floor_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_floor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_floor_mod_emitter : public jit_emitter { public: - jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_floor_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -227,23 +234,24 @@ class jit_floor_mod_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_ceiling_emitter : public jit_emitter { public: // Constructor with explicit precision - jit_ceiling_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_ceiling_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); // Constructor from node - jit_ceiling_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_ceiling_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -256,13 +264,11 @@ class jit_ceiling_emitter : public jit_emitter { private: // Implementation of JIT code emission - void emit_impl(const std::vector &in_vec_idxs, - const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; // ISA-specific implementation template - void emit_isa(const std::vector &in_vec_idxs, - const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_gelu_erf_emitter : public jit_emitter { @@ -285,15 +291,16 @@ class jit_gelu_erf_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr exp_emitter; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_tanh_emitter; @@ -318,24 +325,25 @@ class jit_gelu_tanh_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr tanh_emitter; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_greater_emitter : public jit_emitter { public: - jit_greater_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_greater_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_greater_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_greater_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -346,24 +354,24 @@ class jit_greater_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_greater_equal_emitter : public jit_emitter { public: - jit_greater_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_greater_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_greater_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_greater_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -374,24 +382,24 @@ class jit_greater_equal_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_hswish_emitter : public jit_emitter { public: - jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -403,24 +411,25 @@ class jit_hswish_emitter : public jit_emitter { void register_table_entries() override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_is_finite_emitter : public jit_emitter { public: - jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc = ov::element::f32); + jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); - jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node); + jit_is_finite_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node); size_t get_inputs_count() const override; @@ -428,24 +437,25 @@ class jit_is_finite_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_is_nan_emitter : public jit_emitter { public: - jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -455,57 +465,60 @@ class jit_is_nan_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_maximum_emitter : public jit_emitter { public: - jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_minimum_emitter : public jit_emitter { public: - jit_minimum_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_minimum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_minimum_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_minimum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_mish_emitter : public jit_emitter { @@ -528,15 +541,16 @@ class jit_mish_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr exp_emitter; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_is_inf_emitter : public jit_emitter { @@ -574,11 +588,11 @@ class jit_is_inf_emitter : public jit_emitter { class jit_less_equal_emitter : public jit_emitter { public: - jit_less_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_less_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_less_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_less_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -589,24 +603,24 @@ class jit_less_equal_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_logical_and_emitter : public jit_emitter { public: - jit_logical_and_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_logical_and_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_logical_and_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_logical_and_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -617,26 +631,26 @@ class jit_logical_and_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_logical_or_emitter : public jit_emitter { public: - jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc = ov::element::f32); + jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); - jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& n); + jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n); size_t get_inputs_count() const override; @@ -645,13 +659,13 @@ class jit_logical_or_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; @@ -659,12 +673,12 @@ class jit_logical_or_emitter : public jit_emitter { class jit_logical_not_emitter : public jit_emitter { public: jit_logical_not_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc = ov::element::f32); + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); jit_logical_not_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node); + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node); size_t get_inputs_count() const override; @@ -686,11 +700,11 @@ class jit_logical_not_emitter : public jit_emitter { class jit_logical_xor_emitter : public jit_emitter { public: - jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -701,24 +715,24 @@ class jit_logical_xor_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_mod_emitter : public jit_emitter { public: - jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -726,13 +740,14 @@ class jit_mod_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_mul_add_emitter : public jit_emitter { @@ -749,49 +764,50 @@ class jit_mul_add_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; - class jit_multiply_emitter : public jit_emitter { public: - jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_power_static_emitter : public jit_emitter { public: jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const float power, - const float scale, - const float shift, - const ov::element::Type exec_prc = ov::element::f32); + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const float power, + const float scale, + const float shift, + const ov::element::Type exec_prc = ov::element::f32); jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node, - const ov::element::Type exec_prc = ov::element::f32); + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node, + const ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_count() const override; @@ -801,16 +817,17 @@ class jit_power_static_emitter : public jit_emitter { void register_table_entries() override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: float power; float scale; float shift; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_prelu_emitter : public jit_emitter { @@ -827,13 +844,14 @@ class jit_prelu_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_relu_emitter : public jit_emitter { @@ -850,13 +868,14 @@ class jit_relu_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_round_half_away_from_zero_emitter : public jit_emitter { @@ -871,13 +890,14 @@ class jit_round_half_away_from_zero_emitter : public jit_emitter { size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_round_half_to_even_emitter : public jit_emitter { @@ -892,22 +912,23 @@ class jit_round_half_to_even_emitter : public jit_emitter { size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_select_emitter : public jit_emitter { public: - jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc = ov::element::f32); - jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n); @@ -916,13 +937,13 @@ class jit_select_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; static std::set> get_supported_precisions( - const std::shared_ptr& node = nullptr); + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_sigmoid_emitter : public jit_emitter { @@ -945,15 +966,16 @@ class jit_sigmoid_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr exp_emitter; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_soft_sign_emitter : public jit_emitter { @@ -974,13 +996,14 @@ class jit_soft_sign_emitter : public jit_emitter { void register_table_entries() override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_sqrt_emitter : public jit_emitter { @@ -1007,23 +1030,24 @@ class jit_sqrt_emitter : public jit_emitter { class jit_subtract_emitter : public jit_emitter { public: - jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const ov::element::Type exec_prc = ov::element::f32); + jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); - jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, - dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - const std::shared_ptr& node); + jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node); size_t get_inputs_count() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_swish_emitter : public jit_emitter { @@ -1047,25 +1071,26 @@ class jit_swish_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr sigmoid_emitter; float beta; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_tanh_emitter : public jit_emitter { public: - jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node); @@ -1079,17 +1104,18 @@ class jit_tanh_emitter : public jit_emitter { void emit_data() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: std::unique_ptr sigmoid_emitter; - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp index bb7783b60a1b53..4c0b0f95f783c2 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp @@ -3,9 +3,11 @@ // #include "jit_emitter.hpp" + #include -#include "utils/general_utils.h" + #include "emitters/utils.hpp" +#include "utils/general_utils.h" using namespace dnnl::impl::cpu; using namespace dnnl::impl; @@ -16,26 +18,37 @@ namespace aarch64 { const std::vector jit_emitter::store_gpr_regs = { // Parameter/result registers - 0, 1, 2, 3, 4, 5, 6, 7, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, // r8: Indirect result location register // r9...r15: Temporary registers - 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, // r19...r28: Callee-saved registers - 29, 30 -}; - -static const std::vector vec_regs = { - 0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, - 24, 25, 26, 27, 28, 29, 30, 31 -}; - -void jit_emitter::emit_code(const std::vector &in_idxs, - const std::vector &out_idxs, - const std::vector &pool_vec_idxs, - const std::vector &pool_gpr_idxs) const { + 29, + 30}; + +static const std::vector vec_regs = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + +void jit_emitter::emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs); emit_impl(in_idxs, out_idxs); @@ -52,7 +65,7 @@ void jit_emitter::emit_data() const { // Run through the map and insert values stored there for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) { - const auto &te = (*it).second; // get map entry for a given key + const auto& te = (*it).second; // get map entry for a given key const auto len = te.bcast ? get_vec_length() : sizeof(table_entry_val_t); for (size_t d = 0; d < len; d += sizeof(table_entry_val_t)) h->dd(te.val); @@ -88,7 +101,7 @@ void jit_emitter::prepare_table() { // prepare_table. size_t off = 0; for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) { - auto &te = (*it).second; + auto& te = (*it).second; te.off = off; off += te.bcast ? get_vec_length() : sizeof(table_entry_val_t); } @@ -99,10 +112,10 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, const std::vector& pool_aux_vec_idxs, const std::vector& pool_aux_gpr_idxs) const { using namespace Xbyak_aarch64::util; - const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) || - (in_out_type_ == emitter_in_out_map::vec_to_gpr); - const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) || - (in_out_type_ == emitter_in_out_map::gpr_to_vec); + const bool is_vec_input = + (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::vec_to_gpr); + const bool is_vec_output = + (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::gpr_to_vec); // vector registers for (auto idx : pool_aux_vec_idxs) { @@ -110,20 +123,27 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, } for (size_t idx = 0; idx < get_max_vecs_count(); idx++) { - if (aux_vec_idxs.size() >= get_aux_vecs_count()) break; + if (aux_vec_idxs.size() >= get_aux_vecs_count()) + break; if (is_vec_input) { - if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) + continue; } if (is_vec_output) { - if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; + if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) + continue; } - if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; - if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) + continue; + if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) + continue; - if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) continue; - if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) continue; + if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) + continue; + if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) + continue; aux_vec_idxs.push_back(idx); preserved_vec_idxs.push_back(idx); @@ -138,23 +158,27 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30; for (size_t gpr_idx = 0; gpr_idx <= end_gpr_idx; ++gpr_idx) { - size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end + size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end - if (aux_gpr_idxs.size() >= get_aux_gprs_count()) break; - if ((_idx == Xbyak_aarch64::Operand::X18) || - (_idx == Xbyak_aarch64::Operand::X23) || - (_idx == Xbyak_aarch64::Operand::X24) || - (_idx == Xbyak_aarch64::Operand::X28)) continue; + if (aux_gpr_idxs.size() >= get_aux_gprs_count()) + break; + if ((_idx == Xbyak_aarch64::Operand::X18) || (_idx == Xbyak_aarch64::Operand::X23) || + (_idx == Xbyak_aarch64::Operand::X24) || (_idx == Xbyak_aarch64::Operand::X28)) + continue; if (!is_vec_input) { - if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue; + if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) + continue; } if (!is_vec_output) { - if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue; + if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) + continue; } - if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue; - if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue; + if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) + continue; + if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) + continue; aux_gpr_idxs.push_back(_idx); preserved_gpr_idxs.push_back(_idx); @@ -189,23 +213,21 @@ void jit_emitter::store_context(const std::unordered_set& ignore_registe store_context(store_gpr_regs, vec_regs, ignore_registers); } -void jit_emitter::store_context( - const std::vector& gpr_regs, - const std::vector& vec_regs, - const std::unordered_set& ignore_vec_regs) const { +void jit_emitter::store_context(const std::vector& gpr_regs, + const std::vector& vec_regs, + const std::unordered_set& ignore_vec_regs) const { // 1. General-purpose Registers // 1.1. store pair registers const auto store_gpr_regs_size = gpr_regs.size(); const auto last = store_gpr_regs_size % 2; for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) { - h->stp(Xbyak_aarch64::XReg(gpr_regs[i]), - Xbyak_aarch64::XReg(gpr_regs[i + 1]), - pre_ptr(h->sp, -get_gpr_length() * 2)); + h->stp(Xbyak_aarch64::XReg(gpr_regs[i]), + Xbyak_aarch64::XReg(gpr_regs[i + 1]), + pre_ptr(h->sp, -get_gpr_length() * 2)); } // 1.2. store the remaining register if (last != 0) { - h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]), - pre_ptr(h->sp, -get_gpr_length())); + h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]), pre_ptr(h->sp, -get_gpr_length())); } // 2. SIMD and Floating-Point registers @@ -221,17 +243,14 @@ void jit_emitter::store_context( prev_reg_idx = static_cast(reg_idx); continue; } - h->stp(Xbyak_aarch64::QReg(prev_reg_idx), - Xbyak_aarch64::QReg(reg_idx), - pre_ptr(h->sp, -get_vec_length() * 2)); + h->stp(Xbyak_aarch64::QReg(prev_reg_idx), Xbyak_aarch64::QReg(reg_idx), pre_ptr(h->sp, -get_vec_length() * 2)); prev_reg_idx = -1; } // 2.1. store the remaining register if (prev_reg_idx != -1) { if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) { - h->str(Xbyak_aarch64::QReg(prev_reg_idx), - pre_ptr(h->sp, -get_vec_length())); + h->str(Xbyak_aarch64::QReg(prev_reg_idx), pre_ptr(h->sp, -get_vec_length())); } else { ignore_registers_count++; } @@ -245,10 +264,9 @@ void jit_emitter::restore_context(const std::unordered_set& ignore_vec_r restore_context(store_gpr_regs, vec_regs, ignore_vec_regs); } -void jit_emitter::restore_context( - const std::vector& gpr_regs, - const std::vector& vec_regs, - const std::unordered_set& ignore_vec_regs) const { +void jit_emitter::restore_context(const std::vector& gpr_regs, + const std::vector& vec_regs, + const std::unordered_set& ignore_vec_regs) const { // 1. SIMD and Floating-Point registers // 1.1. restore the remaining register auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2; @@ -260,8 +278,7 @@ void jit_emitter::restore_context( continue; } - h->ldr(Xbyak_aarch64::QReg(reg_idx), - post_ptr(h->sp, get_vec_length())); + h->ldr(Xbyak_aarch64::QReg(reg_idx), post_ptr(h->sp, get_vec_length())); break; } } @@ -278,9 +295,7 @@ void jit_emitter::restore_context( prev_reg_idx = static_cast(reg_idx); continue; } - h->ldp(Xbyak_aarch64::QReg(reg_idx), - Xbyak_aarch64::QReg(prev_reg_idx), - post_ptr(h->sp, get_vec_length() * 2)); + h->ldp(Xbyak_aarch64::QReg(reg_idx), Xbyak_aarch64::QReg(prev_reg_idx), post_ptr(h->sp, get_vec_length() * 2)); prev_reg_idx = -1; } @@ -292,8 +307,7 @@ void jit_emitter::restore_context( const auto save_gpr_regs_size = gpr_regs.size(); const auto last = save_gpr_regs_size % 2; if (last != 0) { - h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]), - post_ptr(h->sp, get_gpr_length())); + h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]), post_ptr(h->sp, get_gpr_length())); } // 2.2. restore pair registers @@ -304,6 +318,6 @@ void jit_emitter::restore_context( } } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp index ba9b8c2d9cbdf1..9ce8203afe7783 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp @@ -4,14 +4,13 @@ #pragma once +#include #include #include -#include -#include "snippets/snippets_isa.hpp" -#include "snippets/generator.hpp" #include "node.h" - +#include "snippets/generator.hpp" +#include "snippets/snippets_isa.hpp" namespace ov { namespace intel_cpu { @@ -29,25 +28,32 @@ class jit_emitter : public ov::snippets::Emitter { jit_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32, - emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) : - Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc), - in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) { - } + emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) + : Emitter(), + h(host), + host_isa_(host_isa), + exec_prc_(exec_prc), + in_out_type_(in_out_type), + p_table(0), + l_table(new Xbyak_aarch64::Label()) {} jit_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32, - emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) : - Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc), - in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) { - } - - void emit_code( - const std::vector &in_idxs, - const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, - const std::vector &pool_gpr_idxs = {}) const override; + emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) + : Emitter(), + h(host), + host_isa_(host_isa), + exec_prc_(exec_prc), + in_out_type_(in_out_type), + p_table(0), + l_table(new Xbyak_aarch64::Label()) {} + + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; void emit_data() const override; @@ -60,7 +66,8 @@ class jit_emitter : public ov::snippets::Emitter { * Precisions are ordered, the first bigger bitness precision with the same type will be selected. * Empty collection means the emitter supports any input precisions. */ - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); protected: size_t get_max_vecs_count() const; @@ -78,12 +85,14 @@ class jit_emitter : public ov::snippets::Emitter { virtual void prepare_table(); virtual void register_table_entries() {} - void load_table_addr() const { h->adr(p_table, *l_table.get()); } + void load_table_addr() const { + h->adr(p_table, *l_table.get()); + } // we accept only 32bit hexadecimal table values to avoid any rounding using table_entry_val_t = uint32_t; - using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table - using table_entry_bcast_t = bool; // true => bcast value + using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table + using table_entry_bcast_t = bool; // true => bcast value struct table_entry_t { table_entry_val_t val; @@ -98,7 +107,7 @@ class jit_emitter : public ov::snippets::Emitter { mutable Xbyak_aarch64::XReg p_table; mutable std::shared_ptr l_table; - virtual void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const = 0; + virtual void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const = 0; virtual void emitter_preamble(const std::vector& in_idxs, const std::vector& out_idxs, @@ -128,14 +137,14 @@ class jit_emitter : public ov::snippets::Emitter { } void push_arg_entry_of(const std::string key, const table_entry_val_t val, const bool broadcast) { - mapped_table_entry_t te {0, val, broadcast}; + mapped_table_entry_t te{0, val, broadcast}; entry_map_.insert(std::make_pair(key, te)); } - void push_entries_of(const table_t &t) { + void push_entries_of(const table_t& t) { for (auto it = t.begin(); it != t.end(); it++) { auto key = (*it).first; - auto te = (*it).second; // copy values from table + auto te = (*it).second; // copy values from table push_arg_entry_of(key, te.val, te.bcast); } } @@ -150,9 +159,9 @@ class jit_emitter : public ov::snippets::Emitter { size_t table_off(const std::string& key, const size_t key_off_val_shift = 0) const { // assumption: all table entries sharing the same key also // share their broadcast property - const auto it = entry_map_.find(key); // search an entry for a key + const auto it = entry_map_.find(key); // search an entry for a key assert(it != entry_map_.end()); - const auto &te = (*it).second; + const auto& te = (*it).second; const auto scale = te.bcast ? get_vec_length() : sizeof(table_entry_val_t); return te.off + key_off_val_shift * scale; } @@ -176,6 +185,6 @@ class jit_emitter : public ov::snippets::Emitter { const std::unordered_set& ignore_vec_regs = {}) const; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index ca18bc5d4b575d..3ca77bdac53baf 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_load_store_emitters.hpp" + #include "cpu/aarch64/cpu_isa_traits.hpp" #include "emitters/utils.hpp" @@ -15,14 +16,23 @@ namespace aarch64 { using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; -jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, int byte_offset, - ov::element::Type exec_prc, emitter_in_out_map in_out_type) -: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset), prc_(src_prc) { +jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type src_prc, + ov::element::Type dst_prc, + int load_num, + int byte_offset, + ov::element::Type exec_prc, + emitter_in_out_map in_out_type) + : jit_emitter(host, host_isa, exec_prc, in_out_type), + name_("unknown"), + load_num_(load_num), + byte_offset_(byte_offset), + prc_(src_prc) { OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Unsupported precision pair."); } -void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_load_emitter::emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_idxs, out_idxs); } else { @@ -31,7 +41,7 @@ void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std:: } template -void jit_load_emitter::load_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_load_emitter::load_qbyte(const std::vector& in_idxs, const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); @@ -39,31 +49,31 @@ void jit_load_emitter::load_qbyte(const std::vector &in_idxs, const std: DReg dst_d = DReg(out_idxs[0]); switch (load_num_) { - case 0: - break; - case 1: - h->ldr(dst_s, ptr(src, byte_offset_)); - break; - case 2: - h->ldr(dst_d, ptr(src, byte_offset_)); - break; - case 3: { - XReg prc = XReg(aux_gpr_idxs[0]); - h->ldr(dst_d, ptr(src, byte_offset_)); - h->add_imm(prc, src, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); - h->ld1(dst.s[2], ptr(prc)); - break; - } - case 4: - h->uni_ldr(dst, src, byte_offset_); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); + case 0: + break; + case 1: + h->ldr(dst_s, ptr(src, byte_offset_)); + break; + case 2: + h->ldr(dst_d, ptr(src, byte_offset_)); + break; + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); + h->ldr(dst_d, ptr(src, byte_offset_)); + h->add_imm(prc, src, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); + h->ld1(dst.s[2], ptr(prc)); + break; + } + case 4: + h->uni_ldr(dst, src, byte_offset_); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); } } template -void jit_load_emitter::load_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_load_emitter::load_dbyte(const std::vector& in_idxs, const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); @@ -72,31 +82,31 @@ void jit_load_emitter::load_dbyte(const std::vector &in_idxs, const std: DReg dst_d = DReg(out_idxs[0]); switch (load_num_) { - case 0: - break; - case 1: - h->ldr(dst_h, ptr(src, byte_offset_)); - break; - case 2: - h->ldr(dst_s, ptr(src, byte_offset_)); - break; - case 3: { - XReg prc = XReg(aux_gpr_idxs[0]); - h->ldr(dst_s, ptr(src, byte_offset_)); - h->add_imm(prc, src, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); - h->ld1(dst.h[2], ptr(prc)); - break; - } - case 4: - h->ldr(dst_d, ptr(src, byte_offset_)); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); + case 0: + break; + case 1: + h->ldr(dst_h, ptr(src, byte_offset_)); + break; + case 2: + h->ldr(dst_s, ptr(src, byte_offset_)); + break; + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); + h->ldr(dst_s, ptr(src, byte_offset_)); + h->add_imm(prc, src, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); + h->ld1(dst.h[2], ptr(prc)); + break; + } + case 4: + h->ldr(dst_d, ptr(src, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); } } template -void jit_load_emitter::load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_load_emitter::load_byte(const std::vector& in_idxs, const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); @@ -105,49 +115,50 @@ void jit_load_emitter::load_byte(const std::vector &in_idxs, const std:: SReg dst_s = SReg(out_idxs[0]); switch (load_num_) { - case 0: - break; - case 1: - h->ldr(dst_b, ptr(src, byte_offset_)); - break; - case 2: - h->ldr(dst_h, ptr(src, byte_offset_)); - break; - case 3: { - XReg prc = XReg(aux_gpr_idxs[0]); - h->ldr(dst_h, ptr(src, byte_offset_)); - h->add_imm(prc, src, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); - h->ld1(dst.b[2], ptr(prc)); - break; - } - case 4: - h->ldr(dst_s, ptr(src, byte_offset_)); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); + case 0: + break; + case 1: + h->ldr(dst_b, ptr(src, byte_offset_)); + break; + case 2: + h->ldr(dst_h, ptr(src, byte_offset_)); + break; + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); + h->ldr(dst_h, ptr(src, byte_offset_)); + h->add_imm(prc, src, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); + h->ld1(dst.b[2], ptr(prc)); + break; + } + case 4: + h->ldr(dst_s, ptr(src, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); } } template -void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { - OV_CPU_JIT_EMITTER_ASSERT(one_of(prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported precision."); +void jit_load_emitter::emit_isa(const std::vector& in_idxs, const std::vector& out_idxs) const { + OV_CPU_JIT_EMITTER_ASSERT( + one_of(prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported precision."); OV_CPU_JIT_EMITTER_ASSERT(load_num_ <= 4, "Unexpected number of elements to load."); switch (prc_) { - case ov::element::f32: - case ov::element::i32: - load_qbyte(in_idxs, out_idxs); - break; - case ov::element::f16: - load_dbyte(in_idxs, out_idxs); - break; - case ov::element::i8: - case ov::element::u8: - load_byte(in_idxs, out_idxs); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported precision: ", prc_.get_type_name()); + case ov::element::f32: + case ov::element::i32: + load_qbyte(in_idxs, out_idxs); + break; + case ov::element::f16: + load_dbyte(in_idxs, out_idxs); + break; + case ov::element::i8: + case ov::element::u8: + load_byte(in_idxs, out_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision: ", prc_.get_type_name()); } } @@ -158,14 +169,24 @@ size_t jit_load_emitter::get_aux_gprs_count() const { return 0; } -jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, - arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type) - : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), prc_(dst_prc) { +jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type src_prc, + ov::element::Type dst_prc, + int store_num, + int byte_offset, + arithmetic_mode mode, + ov::element::Type exec_prc, + emitter_in_out_map in_out_type) + : jit_emitter(host, host_isa, exec_prc, in_out_type), + name_("unknown"), + store_num_(store_num), + byte_offset_(byte_offset), + prc_(dst_prc) { OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Unsupported precision pair."); } -void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_store_emitter::emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in_idxs, out_idxs); } else { @@ -174,7 +195,7 @@ void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std: } template -void jit_store_emitter::store_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_store_emitter::store_qbyte(const std::vector& in_idxs, const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); SReg src_s = SReg(in_idxs[0]); @@ -183,31 +204,31 @@ void jit_store_emitter::store_qbyte(const std::vector &in_idxs, const st XReg dst = XReg(out_idxs[0]); switch (store_num_) { - case 0: - break; - case 1: - h->str(src_s, ptr(dst, byte_offset_)); - break; - case 2: - h->str(src_d, ptr(dst, byte_offset_)); - break; - case 3: { - XReg prc = XReg(aux_gpr_idxs[0]); - h->str(src_d, ptr(dst, byte_offset_)); - h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); - h->st1(src.s[2], ptr(prc)); - break; - } - case 4: - h->str(src_q, ptr(dst, byte_offset_)); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); + case 0: + break; + case 1: + h->str(src_s, ptr(dst, byte_offset_)); + break; + case 2: + h->str(src_d, ptr(dst, byte_offset_)); + break; + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); + h->str(src_d, ptr(dst, byte_offset_)); + h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); + h->st1(src.s[2], ptr(prc)); + break; + } + case 4: + h->str(src_q, ptr(dst, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); } } template -void jit_store_emitter::store_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_store_emitter::store_dbyte(const std::vector& in_idxs, const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); HReg src_h = HReg(in_idxs[0]); @@ -216,31 +237,31 @@ void jit_store_emitter::store_dbyte(const std::vector &in_idxs, const st XReg dst = XReg(out_idxs[0]); switch (store_num_) { - case 0: - break; - case 1: - h->str(src_h, ptr(dst, byte_offset_)); - break; - case 2: - h->str(src_s, ptr(dst, byte_offset_)); - break; - case 3: { - XReg prc = XReg(aux_gpr_idxs[0]); - h->str(src_s, ptr(dst, byte_offset_)); - h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); - h->st1(src.h[2], ptr(prc)); - break; - } - case 4: - h->str(src_d, ptr(dst, byte_offset_)); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); + case 0: + break; + case 1: + h->str(src_h, ptr(dst, byte_offset_)); + break; + case 2: + h->str(src_s, ptr(dst, byte_offset_)); + break; + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); + h->str(src_s, ptr(dst, byte_offset_)); + h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); + h->st1(src.h[2], ptr(prc)); + break; + } + case 4: + h->str(src_d, ptr(dst, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); } } template -void jit_store_emitter::store_byte(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_store_emitter::store_byte(const std::vector& in_idxs, const std::vector& out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); BReg src_b = BReg(in_idxs[0]); @@ -249,49 +270,50 @@ void jit_store_emitter::store_byte(const std::vector &in_idxs, const std XReg dst = XReg(out_idxs[0]); switch (store_num_) { - case 0: - break; - case 1: - h->str(src_b, ptr(dst, byte_offset_)); - break; - case 2: - h->str(src_h, ptr(dst, byte_offset_)); - break; - case 3: { - XReg prc = XReg(aux_gpr_idxs[0]); - h->str(src_h, ptr(dst, byte_offset_)); - h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); - h->st1(src.b[2], ptr(prc)); - break; - } - case 4: - h->str(src_s, ptr(dst, byte_offset_)); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); + case 0: + break; + case 1: + h->str(src_b, ptr(dst, byte_offset_)); + break; + case 2: + h->str(src_h, ptr(dst, byte_offset_)); + break; + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); + h->str(src_h, ptr(dst, byte_offset_)); + h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); + h->st1(src.b[2], ptr(prc)); + break; + } + case 4: + h->str(src_s, ptr(dst, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); } } template -void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { - OV_CPU_JIT_EMITTER_ASSERT(one_of(prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported precision."); +void jit_store_emitter::emit_isa(const std::vector& in_idxs, const std::vector& out_idxs) const { + OV_CPU_JIT_EMITTER_ASSERT( + one_of(prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported precision."); OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= 4, "Unexpected number of elements to store."); switch (prc_) { - case ov::element::f32: - case ov::element::i32: - store_qbyte(in_idxs, out_idxs); - break; - case ov::element::f16: - store_dbyte(in_idxs, out_idxs); - break; - case ov::element::i8: - case ov::element::u8: - store_byte(in_idxs, out_idxs); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported precision: ", prc_.get_type_name()); + case ov::element::f32: + case ov::element::i32: + store_qbyte(in_idxs, out_idxs); + break; + case ov::element::f16: + store_dbyte(in_idxs, out_idxs); + break; + case ov::element::i8: + case ov::element::u8: + store_byte(in_idxs, out_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision: ", prc_.get_type_name()); } } @@ -302,6 +324,6 @@ size_t jit_store_emitter::get_aux_gprs_count() const { return 0; } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 887522ed1055e1..8c0983189f3083 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -4,38 +4,41 @@ #pragma once -#include "jit_emitter.hpp" #include "cpu/aarch64/jit_generator.hpp" +#include "jit_emitter.hpp" namespace ov { namespace intel_cpu { namespace aarch64 { // Arithmetic modes for data type conversion in store_emitter -enum class arithmetic_mode { - saturation, - truncation -}; +enum class arithmetic_mode { saturation, truncation }; class jit_load_emitter : public jit_emitter { public: - jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, int byte_offset, + jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type src_prc, + ov::element::Type dst_prc, + int load_num, + int byte_offset, ov::element::Type exec_prc = ov::element::f32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec); - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; - size_t get_inputs_count() const override { return 1; }; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; + size_t get_inputs_count() const override { + return 1; + }; private: template - void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; + void emit_isa(const std::vector& in_idxs, const std::vector& out_idxs) const; template - void load_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + void load_qbyte(const std::vector& in_idxs, const std::vector& out_idxs) const; template - void load_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + void load_dbyte(const std::vector& in_idxs, const std::vector& out_idxs) const; template - void load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; + void load_byte(const std::vector& in_idxs, const std::vector& out_idxs) const; size_t get_aux_gprs_count() const override; std::string name_; @@ -46,23 +49,30 @@ class jit_load_emitter : public jit_emitter { class jit_store_emitter : public jit_emitter { public: - jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset_, - arithmetic_mode mode = arithmetic_mode::saturation, ov::element::Type exec_prc = ov::element::f32, + jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type src_prc, + ov::element::Type dst_prc, + int store_num, + int byte_offset_, + arithmetic_mode mode = arithmetic_mode::saturation, + ov::element::Type exec_prc = ov::element::f32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; - size_t get_inputs_count() const override { return 1; } + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; + size_t get_inputs_count() const override { + return 1; + } private: template - void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; + void emit_isa(const std::vector& in_idxs, const std::vector& out_idxs) const; template - void store_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + void store_qbyte(const std::vector& in_idxs, const std::vector& out_idxs) const; template - void store_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + void store_dbyte(const std::vector& in_idxs, const std::vector& out_idxs) const; template - void store_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; + void store_byte(const std::vector& in_idxs, const std::vector& out_idxs) const; size_t get_aux_gprs_count() const override; std::string name_; @@ -71,6 +81,6 @@ class jit_store_emitter : public jit_emitter { ov::element::Type prc_; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp index a56c2316183643..95698f8ac78bb0 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp @@ -2,74 +2,76 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "snippets/snippets_isa.hpp" #include "cpu_generator.hpp" -#include "jit_snippets_emitters.hpp" -#include "emitters/utils.hpp" -#include "emitters/snippets/cpu_runtime_configurator.hpp" -#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp" + #include "emitters/plugin/aarch64/jit_conversion_emitters.hpp" +#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp" +#include "emitters/snippets/aarch64/jit_fill_emitter.hpp" #include "emitters/snippets/aarch64/jit_kernel_emitter.hpp" #include "emitters/snippets/aarch64/jit_loop_emitters.hpp" #include "emitters/snippets/aarch64/jit_memory_emitters.hpp" -#include "emitters/snippets/aarch64/jit_fill_emitter.hpp" - -#include "transformations/snippets/common/op/fused_mul_add.hpp" -#include "transformations/cpu_opset/common/op/swish_cpu.hpp" - +#include "emitters/snippets/cpu_runtime_configurator.hpp" +#include "emitters/utils.hpp" +#include "jit_snippets_emitters.hpp" #include "openvino/opsets/opset13.hpp" +#include "snippets/snippets_isa.hpp" +#include "transformations/cpu_opset/common/op/swish_cpu.hpp" +#include "transformations/snippets/common/op/fused_mul_add.hpp" namespace ov { -#define CREATE_SNIPPETS_EMITTER(e_type) { \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return std::make_shared(h.get(), isa, expr); \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return e_type::get_supported_precisions(n); \ - } \ -} +#define CREATE_SNIPPETS_EMITTER(e_type) \ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + return std::make_shared(h.get(), isa, expr); \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ + } -#define CREATE_CPU_EMITTER(e_type) { \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return std::make_shared(h.get(), isa, expr->get_node()); \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return e_type::get_supported_precisions(n); \ - } \ -} +#define CREATE_CPU_EMITTER(e_type) \ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + return std::make_shared(h.get(), isa, expr->get_node()); \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ + } -#define CREATE_GELU_V7_EMITTER(e_type_erf, e_type_tanh) { \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - const auto& n = expr->get_node(); \ - const auto& gelu = std::dynamic_pointer_cast(n); \ - if (gelu == nullptr) { \ - OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \ - } \ - const auto approximationMode = gelu->get_approximation_mode(); \ - if (approximationMode == ov::op::GeluApproximationMode::ERF) { \ - return std::make_shared(h.get(), isa, n); \ - } else if (approximationMode == ov::op::GeluApproximationMode::TANH) { \ - return std::make_shared(h.get(), isa, n); \ - } else { \ - OPENVINO_THROW("Unsupported Gelu approximation mode"); \ - } \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - const auto& gelu = std::dynamic_pointer_cast(n); \ - if (gelu == nullptr) { \ - OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \ - } \ - const auto approximationMode = gelu->get_approximation_mode(); \ - if (approximationMode == ov::op::GeluApproximationMode::ERF) { \ - return e_type_erf::get_supported_precisions(n); \ - } else if (approximationMode == ov::op::GeluApproximationMode::TANH) { \ - return e_type_tanh::get_supported_precisions(n); \ - } else { \ - OPENVINO_THROW("Unsupported Gelu approximation mode"); \ - } \ - } \ -} +#define CREATE_GELU_V7_EMITTER(e_type_erf, e_type_tanh) \ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + const auto& n = expr->get_node(); \ + const auto& gelu = std::dynamic_pointer_cast(n); \ + if (gelu == nullptr) { \ + OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \ + } \ + const auto approximationMode = gelu->get_approximation_mode(); \ + if (approximationMode == ov::op::GeluApproximationMode::ERF) { \ + return std::make_shared(h.get(), isa, n); \ + } else if (approximationMode == ov::op::GeluApproximationMode::TANH) { \ + return std::make_shared(h.get(), isa, n); \ + } else { \ + OPENVINO_THROW("Unsupported Gelu approximation mode"); \ + } \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + const auto& gelu = std::dynamic_pointer_cast(n); \ + if (gelu == nullptr) { \ + OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \ + } \ + const auto approximationMode = gelu->get_approximation_mode(); \ + if (approximationMode == ov::op::GeluApproximationMode::ERF) { \ + return e_type_erf::get_supported_precisions(n); \ + } else if (approximationMode == ov::op::GeluApproximationMode::TANH) { \ + return e_type_tanh::get_supported_precisions(n); \ + } else { \ + OPENVINO_THROW("Unsupported Gelu approximation mode"); \ + } \ + } \ + } class jit_snippet : public dnnl::impl::cpu::aarch64::jit_generator { public: @@ -85,7 +87,8 @@ class jit_snippet : public dnnl::impl::cpu::aarch64::jit_generator { namespace intel_cpu { namespace aarch64 { -CompiledSnippetCPU::CompiledSnippetCPU(std::unique_ptr h) : h_compiled(std::move(h)) { +CompiledSnippetCPU::CompiledSnippetCPU(std::unique_ptr h) + : h_compiled(std::move(h)) { OPENVINO_ASSERT(h_compiled && h_compiled->jit_ker(), "Got invalid jit generator or kernel was nopt compiled"); } @@ -102,15 +105,19 @@ bool CompiledSnippetCPU::empty() const { } CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa) - : TargetMachine(std::make_shared()), h(new jit_snippet()), isa(host_isa) { + : TargetMachine(std::make_shared()), + h(new jit_snippet()), + isa(host_isa) { // data movement jitters[op::v0::Parameter::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_nop_emitter); jitters[op::v0::Result::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_nop_emitter); jitters[snippets::op::VectorBuffer::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_nop_emitter); jitters[snippets::op::RankNormalization::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_nop_emitter); jitters[snippets::op::BroadcastMove::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_broadcast_move_emitter); - jitters[snippets::op::ConvertTruncation::get_type_info_static()] = CREATE_CPU_EMITTER(jit_convert_truncation_emitter); - jitters[snippets::op::ConvertSaturation::get_type_info_static()] = CREATE_CPU_EMITTER(jit_convert_saturation_emitter); + jitters[snippets::op::ConvertTruncation::get_type_info_static()] = + CREATE_CPU_EMITTER(jit_convert_truncation_emitter); + jitters[snippets::op::ConvertSaturation::get_type_info_static()] = + CREATE_CPU_EMITTER(jit_convert_saturation_emitter); // memory access jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter); @@ -136,7 +143,8 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa) jitters[ov::op::v0::Exp::get_type_info_static()] = CREATE_CPU_EMITTER(jit_exp_emitter); jitters[ov::op::v0::Floor::get_type_info_static()] = CREATE_CPU_EMITTER(jit_floor_emitter); jitters[ov::op::v0::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_gelu_erf_emitter); - jitters[ov::op::v7::Gelu::get_type_info_static()] = CREATE_GELU_V7_EMITTER(jit_gelu_erf_emitter, jit_gelu_tanh_emitter); + jitters[ov::op::v7::Gelu::get_type_info_static()] = + CREATE_GELU_V7_EMITTER(jit_gelu_erf_emitter, jit_gelu_tanh_emitter); jitters[ov::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_hswish_emitter); jitters[ov::op::v4::Mish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mish_emitter); jitters[ov::op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_relu_emitter); @@ -168,7 +176,8 @@ bool CPUTargetMachine::is_supported() const { snippets::CompiledSnippetPtr CPUTargetMachine::get_snippet() { OPENVINO_ASSERT(h->create_kernel() == dnnl::impl::status::success, "Failed to create jit_kernel in get_snippet()"); - const auto& result = std::make_shared(std::unique_ptr(h.release())); + const auto& result = + std::make_shared(std::unique_ptr(h.release())); // Note that we reset all the generated code, since it was copied into CompiledSnippetCPU h.reset(new jit_snippet()); return result; @@ -176,8 +185,10 @@ snippets::CompiledSnippetPtr CPUTargetMachine::get_snippet() { size_t CPUTargetMachine::get_lanes() const { switch (isa) { - case dnnl::impl::cpu::aarch64::asimd : return dnnl::impl::cpu::aarch64::cpu_isa_traits::vlen / sizeof(float); - default : OPENVINO_THROW("unknown isa ", isa); + case dnnl::impl::cpu::aarch64::asimd: + return dnnl::impl::cpu::aarch64::cpu_isa_traits::vlen / sizeof(float); + default: + OPENVINO_THROW("unknown isa ", isa); } } @@ -190,18 +201,19 @@ dnnl::impl::cpu::aarch64::cpu_isa_t CPUTargetMachine::get_isa() const { return isa; } -CPUGenerator::CPUGenerator(dnnl::impl::cpu::aarch64::cpu_isa_t isa_) : Generator(std::make_shared(isa_)) {} +CPUGenerator::CPUGenerator(dnnl::impl::cpu::aarch64::cpu_isa_t isa_) + : Generator(std::make_shared(isa_)) {} std::shared_ptr CPUGenerator::clone() const { const auto& cpu_target_machine = std::dynamic_pointer_cast(target); - OPENVINO_ASSERT(cpu_target_machine, "Failed to clone CPUGenerator: the instance contains incompatible TargetMachine type"); + OPENVINO_ASSERT(cpu_target_machine, + "Failed to clone CPUGenerator: the instance contains incompatible TargetMachine type"); return std::make_shared(cpu_target_machine->get_isa()); } ov::snippets::RegType CPUGenerator::get_specific_op_out_reg_type(const ov::Output& out) const { const auto op = out.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) + if (std::dynamic_pointer_cast(op) || std::dynamic_pointer_cast(op)) return ov::snippets::RegType::vec; else return ov::snippets::RegType::undefined; @@ -211,6 +223,6 @@ bool CPUGenerator::uses_precompiled_kernel(const std::shared_ptr& out) const override; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.cpp index 2b6056e92644a3..053cebe747e529 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.cpp @@ -3,6 +3,7 @@ // #include "jit_fill_emitter.hpp" + #include "cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_adr.h" #include "emitters/utils.hpp" @@ -21,7 +22,8 @@ jit_fill_emitter::jit_fill_emitter(jit_generator* h, cpu_isa_t isa, const Expres const auto fill = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(fill != nullptr, "Expects Fill expression"); OV_CPU_JIT_EMITTER_ASSERT(fill->get_element_type().size() == 4, - "Supports only 4 Byte element types but gets ", fill->get_element_type()); + "Supports only 4 Byte element types but gets ", + fill->get_element_type()); offset = fill->get_offset(); fill_value = fill->get_fill_value(); @@ -38,8 +40,7 @@ size_t jit_fill_emitter::get_aux_gprs_count() const { return 1; } -void jit_fill_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_fill_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in, out); } else { @@ -48,7 +49,7 @@ void jit_fill_emitter::emit_impl(const std::vector& in, } template -void jit_fill_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_fill_emitter::emit_isa(const std::vector& in, const std::vector& out) const { if (is_full_reg()) fill_full(out); else @@ -56,7 +57,7 @@ void jit_fill_emitter::emit_isa(const std::vector &in, const std::vector } template -void jit_fill_emitter::fill_full(const std::vector &out) const { +void jit_fill_emitter::fill_full(const std::vector& out) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg dst = TReg(out[0]); @@ -71,28 +72,28 @@ void jit_fill_emitter::fill_full(const std::vector &out) const { } template -void jit_fill_emitter::fill_tail(const std::vector &in, const std::vector &out) const { +void jit_fill_emitter::fill_tail(const std::vector& in, const std::vector& out) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg dst = TReg(out[0]); switch (offset) { - case 1: - h->ld1(dst.s[1], table_val2("value", sizeof(float))); - h->ld1(dst.d[1], table_val2("value", 2 * sizeof(float))); - break; - case 2: - h->ld1(dst.d[1], table_val2("value", 2 * sizeof(float))); - break; - case 3: - h->ld1(dst.s[3], table_val2("value", 3 * sizeof(float))); - break; - case 4: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Fill emitter has unexpected offset ", offset); + case 1: + h->ld1(dst.s[1], table_val2("value", sizeof(float))); + h->ld1(dst.d[1], table_val2("value", 2 * sizeof(float))); + break; + case 2: + h->ld1(dst.d[1], table_val2("value", 2 * sizeof(float))); + break; + case 3: + h->ld1(dst.s[3], table_val2("value", 3 * sizeof(float))); + break; + case 4: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Fill emitter has unexpected offset ", offset); } } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.hpp index 0ce0ac62d03979..7c827e4920d5eb 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_fill_emitter.hpp @@ -16,29 +16,34 @@ class jit_fill_emitter : public jit_emitter { dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 1;} + size_t get_inputs_count() const override { + return 1; + } protected: size_t get_aux_gprs_count() const override; private: - void emit_impl(const std::vector& in, - const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; template - void fill_full(const std::vector &out) const; + void fill_full(const std::vector& out) const; template - void fill_tail(const std::vector &in, const std::vector &out) const; + void fill_tail(const std::vector& in, const std::vector& out) const; - bool is_full_reg() const { return offset == 0; } - bool is_optimized() const { return is_full_reg() && fill_value == uint32_t(0x0); } + bool is_full_reg() const { + return offset == 0; + } + bool is_optimized() const { + return is_full_reg() && fill_value == uint32_t(0x0); + } size_t offset = 0; uint32_t fill_value = 0x0; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp index 8f7a54dc9ebdb3..32ed1a844b6724 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp @@ -19,18 +19,25 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; inline static std::vector transform_idxs_to_regs(const std::vector& idxs) { std::vector regs(idxs.size(), XReg(0)); - std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx){return XReg(idx);}); + std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx) { + return XReg(idx); + }); return regs; } inline static std::vector transform_snippets_regs_to_idxs(const std::vector& regs) { std::vector idxs(regs.size()); - std::transform(regs.cbegin(), regs.cend(), idxs.begin(), [](const snippets::Reg& reg) { return reg.idx; }); + std::transform(regs.cbegin(), regs.cend(), idxs.begin(), [](const snippets::Reg& reg) { + return reg.idx; + }); return idxs; } -jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), reg_runtime_params_idx(Operand::X0) { +jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, + cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr) + : jit_emitter(h, isa), + reg_runtime_params_idx(Operand::X0) { const auto kernel = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "Invoked with invalid op argument"); OV_CPU_JIT_EMITTER_ASSERT(!kernel->region->empty(), "Invoked with empty body"); @@ -113,35 +120,50 @@ void jit_kernel_emitter::init_reg_pools(const std::set& gpr_blacklist, c gp_regs_pool[i] = vec_regs_pool[i] = 31 - i; auto remove_regs_from_pool = [](std::vector& pool, const std::set& to_remove) { // It's important to keep the order of other elements - pool.erase(std::remove_if(pool.begin(), pool.end(), - [&](size_t x) {return to_remove.count(x) != 0;}), pool.end()); + pool.erase(std::remove_if(pool.begin(), + pool.end(), + [&](size_t x) { + return to_remove.count(x) != 0; + }), + pool.end()); }; - std::set gprs_blacklist_extended{Operand::X18, Operand::X23, Operand::X24, Operand::X28, Operand::X29, Operand::SP}; + std::set gprs_blacklist_extended{Operand::X18, + Operand::X23, + Operand::X24, + Operand::X28, + Operand::X29, + Operand::SP}; gprs_blacklist_extended.insert(gpr_blacklist.begin(), gpr_blacklist.end()); // Reserve reg_indexes_idx and reg_runtime_params_idx, since they'll be used to pass runtime call args to kernel remove_regs_from_pool(gp_regs_pool, gprs_blacklist_extended); remove_regs_from_pool(vec_regs_pool, vec_blacklist); } -void jit_kernel_emitter::emit_code(const std::vector &in, const std::vector &out, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_kernel_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); emit_impl(in, out); } -void jit_kernel_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { +void jit_kernel_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(in.empty() && out.empty(), ": Expects 0 registers on input and output"); const auto num_params = num_inputs + num_outputs + num_unique_buffers; // The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount OV_CPU_JIT_EMITTER_ASSERT(data_ptr_regs_idx.size() == num_params, - "Number of inputs and outputs is inconsistent with the number of allocated registers ", num_params, - " data_ptr_regs_idx.size() = ", data_ptr_regs_idx.size()); + "Number of inputs and outputs is inconsistent with the number of allocated registers ", + num_params, + " data_ptr_regs_idx.size() = ", + data_ptr_regs_idx.size()); } void jit_kernel_emitter::init_body_regs(const std::set& kernel_regs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) { + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) { // Initialize pools of gp and vec registers - // Reserve kernel regs (reg_indexes_idx and, if there is, reg_runtime_params_idx), since they'll be used to pass runtime call args to kernel + // Reserve kernel regs (reg_indexes_idx and, if there is, reg_runtime_params_idx), since they'll be used to pass + // runtime call args to kernel init_reg_pools(kernel_regs, {}); mapping_info gpr_map_pool({}, gp_regs_pool); @@ -175,9 +197,11 @@ void jit_kernel_emitter::emit_impl(const std::vector& in, const std::vec h->postamble(); } -jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, +jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_kernel_emitter(h, isa, expr), reg_indexes_idx(Operand::X1) { + : jit_kernel_emitter(h, isa, expr), + reg_indexes_idx(Operand::X1) { const auto kernel = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "Expectes KernelStatic expression"); jcp = *reinterpret_cast(kernel->compile_params); @@ -219,24 +243,29 @@ void jit_kernel_static_emitter::init_data_pointers(const std::vector& data // NOTE: Snippets Buffer Scratchpad has the common data pointer for all Buffers (even with different ID). // The accessing memory is covered by correct offsets in each Buffer and the corresponding MemoryAccess ops for (size_t i = 0; i < num_unique_buffers; i++) { - h->ldr(data_ptr_regs[num_params + i], ptr(reg_runtime_params, static_cast(GET_OFF(buffer_scratchpad_ptr)))); + h->ldr(data_ptr_regs[num_params + i], + ptr(reg_runtime_params, static_cast(GET_OFF(buffer_scratchpad_ptr)))); } for (size_t i = 0; i < num_params; i++) { if (i < num_inputs) - h->ldr(data_ptr_regs[i], ptr(reg_runtime_params, static_cast(GET_OFF(src_ptrs) + i * sizeof(void*)))); + h->ldr(data_ptr_regs[i], + ptr(reg_runtime_params, static_cast(GET_OFF(src_ptrs) + i * sizeof(void*)))); else - h->ldr(data_ptr_regs[i], ptr(reg_runtime_params, static_cast(GET_OFF(dst_ptrs) + (i - num_inputs) * sizeof(void*)))); + h->ldr(data_ptr_regs[i], + ptr(reg_runtime_params, static_cast(GET_OFF(dst_ptrs) + (i - num_inputs) * sizeof(void*)))); init_ptr_with_offset(data_ptr_regs[i], data_offsets[i]); } } -jit_kernel_dynamic_emitter::jit_kernel_dynamic_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, +jit_kernel_dynamic_emitter::jit_kernel_dynamic_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) : jit_kernel_emitter(h, isa, expr) { const auto kernel = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(kernel, "Expectes KernelDynamic expression"); - // - Reserve reg_runtime_params_idx, since it wll be used to pass runtime call args to all dynamic emitters that needs runtime args + // - Reserve reg_runtime_params_idx, since it wll be used to pass runtime call args to all dynamic emitters that + // needs runtime args // - We cannot assign this register to the body emitters since runtime params MUST be valid during whole execution // for all dynamic emitters init_body_regs({reg_runtime_params_idx}); @@ -247,16 +276,19 @@ void jit_kernel_dynamic_emitter::init_data_pointers(const std::vector& dat const auto num_params = num_inputs + num_outputs; for (size_t i = 0; i < num_unique_buffers; ++i) { - h->ldr(data_ptr_regs[num_params + i], ptr(reg_runtime_params, static_cast(GET_OFF(buffer_scratchpad_ptr)))); + h->ldr(data_ptr_regs[num_params + i], + ptr(reg_runtime_params, static_cast(GET_OFF(buffer_scratchpad_ptr)))); } for (size_t i = 0; i < num_params; i++) { if (i < num_inputs) - h->ldr(data_ptr_regs[i], ptr(reg_runtime_params, static_cast(GET_OFF(src_ptrs) + i * sizeof(void*)))); + h->ldr(data_ptr_regs[i], + ptr(reg_runtime_params, static_cast(GET_OFF(src_ptrs) + i * sizeof(void*)))); else - h->ldr(data_ptr_regs[i], ptr(reg_runtime_params, static_cast(GET_OFF(dst_ptrs) + (i - num_inputs) * sizeof(void*)))); + h->ldr(data_ptr_regs[i], + ptr(reg_runtime_params, static_cast(GET_OFF(dst_ptrs) + (i - num_inputs) * sizeof(void*)))); } } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp index 63bac54e5c1f26..0ede91f100f110 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp @@ -5,8 +5,8 @@ #pragma once #include "emitters/plugin/aarch64/jit_emitter.hpp" -#include "emitters/snippets/jit_snippets_call_args.hpp" #include "emitters/snippets/jit_container_emitter.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" namespace ov { namespace intel_cpu { @@ -15,8 +15,9 @@ namespace aarch64 { /// /// \brief Kernel is the only entry point to Codogen Jit compilation. Kernel perform abstract-to-physical register /// mapping and creates a pools of available gpr and vec registers. Kernel usually contains (at least one) -/// jit_loop_begin_emitter and jit_loop_end_emitter pair. In general the enclosed emitters should be organized in the following way: -/// jit_kernel_emitter { /* entry point, maps registers, creates pools of available registers */ +/// jit_loop_begin_emitter and jit_loop_end_emitter pair. In general the enclosed emitters should be organized in the +/// following way: jit_kernel_emitter { /* entry point, maps registers, creates pools of available +/// registers */ /// 1.S jit_loop_begin_emitter /* Scalar Loop over the outer dimension [START] */ /// 2.S jit_loop_begin_emitter /* inner vector loop [START] */ /// ... /* All the necessary Load/Strore/elementwise emitters */ @@ -31,21 +32,29 @@ namespace aarch64 { class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { public: - jit_kernel_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - - size_t get_inputs_count() const override {return 0;} - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + jit_kernel_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); + + size_t get_inputs_count() const override { + return 0; + } + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; protected: void validate_arguments(const std::vector& in, const std::vector& out) const override; - void init_body_regs(const std::set& kernel_regs, const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}); + void init_body_regs(const std::set& kernel_regs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}); /** - * @brief populates physical registers pools for x86 (both vec and gp). + * @brief populates physical registers pools for x86 (both vec and gp). * Skips stack-related gprs and extra gprs passed as arguments. * @arg gpr_blacklist - set of gp registers that should not be added to register pool * @arg vec_blacklist - set of vec registers should not be added to register pool - */ + */ void init_reg_pools(const std::set& gpr_blacklist, const std::set& vec_blacklist); virtual void init_data_pointers(const std::vector& data_ptr_regs) const = 0; @@ -69,14 +78,15 @@ class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { std::shared_ptr body; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_kernel_emitter(const jit_kernel_emitter *emitter); + friend std::string init_info_jit_kernel_emitter(const jit_kernel_emitter* emitter); #endif }; class jit_kernel_static_emitter : public jit_kernel_emitter { public: jit_kernel_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); private: void init_data_pointers(const std::vector& data_ptr_regs) const override; @@ -86,23 +96,24 @@ class jit_kernel_static_emitter : public jit_kernel_emitter { std::vector> data_offsets; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter *emitter); + friend std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter* emitter); #endif }; class jit_kernel_dynamic_emitter : public jit_kernel_emitter { public: jit_kernel_dynamic_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); private: void init_data_pointers(const std::vector& data_ptr_regs) const override; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter *emitter); + friend std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter* emitter); #endif }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp index 2b5b41fb912606..0666505a6d31ab 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp @@ -3,8 +3,9 @@ // #include "jit_loop_emitters.hpp" -#include "jit_kernel_emitter.hpp" + #include "emitters/utils.hpp" +#include "jit_kernel_emitter.hpp" using namespace Xbyak_aarch64; @@ -18,9 +19,11 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; /* ================== jit_loop_begin_emitter ====================== */ -jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, +jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), loop_begin_label{new Xbyak_aarch64::Label()} { + : jit_emitter(h, isa), + loop_begin_label{new Xbyak_aarch64::Label()} { const auto loop_begin = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(loop_begin, "expects LoopBegin expression"); const auto loop_end = loop_begin->get_loop_end(); @@ -31,15 +34,17 @@ jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::aarch64::jit_gen in_out_type_ = emitter_in_out_map::gpr_to_gpr; } -void jit_loop_begin_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { +void jit_loop_begin_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(in.empty(), "Invalid inputs size: expected 0 got " + std::to_string(in.size())); // Note: the only expected output is work amount register (communicated to jit_loop_end_emitter) OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Invalid outputs size: expected 1 got " + std::to_string(out.size())); OV_CPU_JIT_EMITTER_ASSERT(loop_begin_label != nullptr, "has not inited label!"); } -void jit_loop_begin_emitter::emit_code(const std::vector &in, const std::vector &out, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_loop_begin_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); emit_impl(in, out); } @@ -56,9 +61,11 @@ void jit_loop_begin_emitter::emit_impl(const std::vector& in, const std: /* ================== jit_loop_end_emitter ====================== */ -jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, +jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), loop_begin_label{nullptr} { + : jit_emitter(h, isa), + loop_begin_label{nullptr} { in_out_type_ = emitter_in_out_map::gpr_to_gpr; const auto loop_end = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(loop_end != nullptr, "expected LoopEnd expr"); @@ -79,27 +86,49 @@ jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generat loop_begin_label = loop_begin_emitter->get_begin_label(); } -ov::snippets::lowered::ExpressionPtr jit_loop_end_emitter::get_loop_begin_expr(const ov::snippets::lowered::ExpressionPtr& expr) { +ov::snippets::lowered::ExpressionPtr jit_loop_end_emitter::get_loop_begin_expr( + const ov::snippets::lowered::ExpressionPtr& expr) { const auto begin_expr = expr->get_input_port_connectors().back()->get_source().get_expr(); OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(begin_expr->get_node()), "LoopEnd expression must have th last port connector to LoopBegin"); return begin_expr; } -void jit_loop_end_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { -const auto io_size = num_inputs + num_outputs; +void jit_loop_end_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { + const auto io_size = num_inputs + num_outputs; OV_CPU_JIT_EMITTER_ASSERT(out.size() == 0, "Invalid number of out arguments: expected ", 0, " got ", out.size()); - OV_CPU_JIT_EMITTER_ASSERT(in.size() == io_size + 1, "Invalid number of in arguments: expected ", io_size + 1, " got ", in.size()); - OV_CPU_JIT_EMITTER_ASSERT(is_incremented.size() == io_size, "Invalid is_incremented size: expected ", io_size, " got ", is_incremented.size()); - OV_CPU_JIT_EMITTER_ASSERT(ptr_increments.size() == io_size, "Invalid ptr_increments size: expected ", io_size, " got ", ptr_increments.size()); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == io_size + 1, + "Invalid number of in arguments: expected ", + io_size + 1, + " got ", + in.size()); + OV_CPU_JIT_EMITTER_ASSERT(is_incremented.size() == io_size, + "Invalid is_incremented size: expected ", + io_size, + " got ", + is_incremented.size()); + OV_CPU_JIT_EMITTER_ASSERT(ptr_increments.size() == io_size, + "Invalid ptr_increments size: expected ", + io_size, + " got ", + ptr_increments.size()); OV_CPU_JIT_EMITTER_ASSERT(finalization_offsets.size() == io_size, - "Invalid finalization_offsets size: expected: ", io_size, " got ", finalization_offsets.size()); - OV_CPU_JIT_EMITTER_ASSERT(data_sizes.size() == io_size, "Invalid data_sizes size: expected: ", io_size, " got ", data_sizes.size()); + "Invalid finalization_offsets size: expected: ", + io_size, + " got ", + finalization_offsets.size()); + OV_CPU_JIT_EMITTER_ASSERT(data_sizes.size() == io_size, + "Invalid data_sizes size: expected: ", + io_size, + " got ", + data_sizes.size()); OV_CPU_JIT_EMITTER_ASSERT(loop_begin_label != nullptr, "has not inited begin label!"); } -void jit_loop_end_emitter::emit_code(const std::vector &in, const std::vector &out, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_loop_end_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); emit_impl(in, out); } @@ -118,7 +147,7 @@ void jit_loop_end_emitter::emit_impl(const std::vector& in, const std::v if (ptr_increments[idx] > 0) { h->add_imm(data_reg, data_reg, ptr_increments[idx] * wa_increment * data_sizes[idx], h->X_TMP_0); } else if (ptr_increments[idx] < 0) { - h->sub_imm(data_reg, data_reg, - ptr_increments[idx] * wa_increment * data_sizes[idx], h->X_TMP_0); + h->sub_imm(data_reg, data_reg, -ptr_increments[idx] * wa_increment * data_sizes[idx], h->X_TMP_0); } } h->sub_imm(reg_work_amount, reg_work_amount, wa_increment, h->X_TMP_0); @@ -133,13 +162,13 @@ void jit_loop_end_emitter::emit_impl(const std::vector& in, const std::v if (finalization_offsets[idx] > 0) { h->add_imm(data_reg, data_reg, finalization_offsets[idx] * data_sizes[idx], h->X_TMP_0); } else if (finalization_offsets[idx] < 0) { - h->sub_imm(data_reg, data_reg, - finalization_offsets[idx] * data_sizes[idx], h->X_TMP_0); + h->sub_imm(data_reg, data_reg, -finalization_offsets[idx] * data_sizes[idx], h->X_TMP_0); } } } /* ============================================================== */ -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.hpp index 6ec87835821df2..c89928353646cd 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.hpp @@ -5,7 +5,6 @@ #pragma once #include "emitters/plugin/aarch64/jit_emitter.hpp" - #include "snippets/op/loop.hpp" namespace ov { @@ -14,20 +13,27 @@ namespace aarch64 { /* ================== jit_loop_begin_emitter ====================== */ -class jit_loop_begin_emitter: public jit_emitter { +class jit_loop_begin_emitter : public jit_emitter { public: - jit_loop_begin_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, + jit_loop_begin_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override { return 0; } + size_t get_inputs_count() const override { + return 0; + } - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; - std::shared_ptr get_begin_label() { return loop_begin_label; } + std::shared_ptr get_begin_label() { + return loop_begin_label; + } protected: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const std::vector& in, const std::vector& out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; std::shared_ptr loop_begin_label; @@ -40,18 +46,23 @@ class jit_loop_begin_emitter: public jit_emitter { /* ================== jit_loop_end_emitter ====================== */ -class jit_loop_end_emitter: public jit_emitter { +class jit_loop_end_emitter : public jit_emitter { public: - jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override { return 0; } + size_t get_inputs_count() const override { + return 0; + } - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; protected: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const std::vector& in, const std::vector& out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; static ov::snippets::lowered::ExpressionPtr get_loop_begin_expr(const ov::snippets::lowered::ExpressionPtr& expr); @@ -70,6 +81,6 @@ class jit_loop_end_emitter: public jit_emitter { /* ============================================================== */ -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index d19843f395d2f3..9989f3431fb2a8 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_memory_emitters.hpp" + #include "emitters/utils.hpp" using namespace Xbyak_aarch64; @@ -15,15 +16,18 @@ using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; -jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_emitter(h, isa) { +jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_emitter(h, isa) { const auto n = expr->get_node(); src_prc = n->get_input_element_type(0); dst_prc = n->get_output_element_type(0); } -jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { - bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && - src_prc == dst_prc; +jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_memory_emitter(h, isa, expr) { + bool is_supported_precision = + one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + src_prc == dst_prc; OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); const auto load = std::dynamic_pointer_cast(expr->get_node()); @@ -34,8 +38,7 @@ jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa load_emitter.reset(new jit_load_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); } -void jit_load_memory_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_load_memory_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in, out); } else { @@ -44,7 +47,7 @@ void jit_load_memory_emitter::emit_impl(const std::vector& in, } template -void jit_load_memory_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_load_memory_emitter::emit_isa(const std::vector& in, const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(load_emitter != nullptr, "Load CPU emitter isn't initialized!"); load_emitter->emit_code(in, out, aux_vec_idxs, aux_gpr_idxs); @@ -56,7 +59,8 @@ void jit_load_memory_emitter::emit_data() const { jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { - OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Only support equal input and output types but gets ", + OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, + "Only support equal input and output types but gets ", src_prc.get_type_name(), " and ", dst_prc.get_type_name()); @@ -68,8 +72,7 @@ jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa in_out_type_ = emitter_in_out_map::gpr_to_vec; } -void jit_load_broadcast_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_load_broadcast_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in, out); } else { @@ -78,7 +81,7 @@ void jit_load_broadcast_emitter::emit_impl(const std::vector& in, } template -void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_load_broadcast_emitter::emit_isa(const std::vector& in, const std::vector& out) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in[0]); TReg dst = TReg(out[0]); @@ -86,9 +89,11 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const s h->uni_ld1rw(dst.s, src, byte_offset); } -jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { - bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && - src_prc == dst_prc; +jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_memory_emitter(h, isa, expr) { + bool is_supported_precision = + one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + src_prc == dst_prc; OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); const auto store = ov::as_type_ptr(expr->get_node()); @@ -99,8 +104,7 @@ jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t i store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); } -void jit_store_memory_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_store_memory_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in, out); } else { @@ -109,7 +113,7 @@ void jit_store_memory_emitter::emit_impl(const std::vector& in, } template -void jit_store_memory_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_store_memory_emitter::emit_isa(const std::vector& in, const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(store_emitter != nullptr, "Store CPU emitter isn't initialized!"); store_emitter->emit_code(in, out, aux_vec_idxs, aux_gpr_idxs); @@ -119,6 +123,6 @@ void jit_store_memory_emitter::emit_data() const { store_emitter->emit_data(); } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp index ba0b4e4acfedb4..edb85751f9086d 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp @@ -11,11 +11,11 @@ namespace ov { namespace intel_cpu { namespace aarch64 { -class jit_memory_emitter : public jit_emitter { +class jit_memory_emitter : public jit_emitter { public: jit_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); protected: ov::element::Type src_prc; @@ -28,17 +28,18 @@ class jit_memory_emitter : public jit_emitter { class jit_load_memory_emitter : public jit_memory_emitter { public: jit_load_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 1;} + size_t get_inputs_count() const override { + return 1; + } private: - void emit_impl(const std::vector& in, - const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; void emit_data() const override; private: @@ -48,39 +49,41 @@ class jit_load_memory_emitter : public jit_memory_emitter { class jit_load_broadcast_emitter : public jit_memory_emitter { public: jit_load_broadcast_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 1;} + size_t get_inputs_count() const override { + return 1; + } private: - void emit_impl(const std::vector& in, - const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; }; -class jit_store_memory_emitter : public jit_memory_emitter { +class jit_store_memory_emitter : public jit_memory_emitter { public: jit_store_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 1;} + size_t get_inputs_count() const override { + return 1; + } private: - void emit_impl(const std::vector& in, - const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; void emit_data() const override; private: std::unique_ptr store_emitter = nullptr; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.cpp index 6529312ae1095a..69fcc7a92fd259 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_snippets_emitters.hpp" + #include "cpu/aarch64/jit_generator.hpp" #include "cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_adr.h" #include "emitters/utils.hpp" @@ -17,7 +18,8 @@ using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; -jit_nop_emitter::jit_nop_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : aarch64::jit_emitter(h, isa) { +jit_nop_emitter::jit_nop_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : aarch64::jit_emitter(h, isa) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; } @@ -29,14 +31,12 @@ jit_broadcast_move_emitter::jit_broadcast_move_emitter(jit_generator* h, cpu_isa n->get_input_element_type(0), " and ", n->get_output_element_type(0)); - OV_CPU_JIT_EMITTER_ASSERT(n->get_input_element_type(0) == ov::element::f32, - "Only supports FP32 precision."); + OV_CPU_JIT_EMITTER_ASSERT(n->get_input_element_type(0) == ov::element::f32, "Only supports FP32 precision."); byte_size = n->get_input_element_type(0).size(); } -void jit_broadcast_move_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_broadcast_move_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in, out); } else { @@ -45,42 +45,42 @@ void jit_broadcast_move_emitter::emit_impl(const std::vector& in, } template -void jit_broadcast_move_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_broadcast_move_emitter::emit_isa(const std::vector& in, const std::vector& out) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in[0]); TReg dst = TReg(out[0]); switch (byte_size) { - case 4: - h->dup(dst.s, src.s[0]); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported data size ", byte_size); + case 4: + h->dup(dst.s, src.s[0]); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported data size ", byte_size); } } -jit_scalar_emitter::jit_scalar_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_emitter(h, isa) { +jit_scalar_emitter::jit_scalar_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_emitter(h, isa) { const auto n = expr->get_node(); const auto& precision = n->get_output_element_type(0); switch (precision) { - case element::i32: { - value = ov::as_type_ptr(n)->cast_vector()[0]; - break; - } - case element::f32: { - value = dnnl::impl::float2int(ov::as_type_ptr(n)->cast_vector()[0]); - break; - } - default: { - OV_CPU_JIT_EMITTER_THROW("Doesn't support precision ", precision); - } + case element::i32: { + value = ov::as_type_ptr(n)->cast_vector()[0]; + break; + } + case element::f32: { + value = dnnl::impl::float2int(ov::as_type_ptr(n)->cast_vector()[0]); + break; + } + default: { + OV_CPU_JIT_EMITTER_THROW("Doesn't support precision ", precision); + } } push_arg_entry_of("scalar", value, true); prepare_table(); } -void jit_scalar_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_scalar_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { emit_isa(in, out); } else { @@ -89,7 +89,7 @@ void jit_scalar_emitter::emit_impl(const std::vector& in, } template -void jit_scalar_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_scalar_emitter::emit_isa(const std::vector& in, const std::vector& out) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg dst = TReg(out[0]); AdrImm src = table_val("scalar"); @@ -97,6 +97,6 @@ void jit_scalar_emitter::emit_isa(const std::vector &in, const std::vect h->uni_ld1rw(dst.s, src.getXn(), src.getImm()); } -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.hpp index 0f05024ed12168..13f9aa70fb2c8e 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_snippets_emitters.hpp @@ -16,11 +16,12 @@ class jit_nop_emitter : public jit_emitter { dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 0;} + size_t get_inputs_count() const override { + return 0; + } private: - void emit_impl(const std::vector& in, - const std::vector& out) const override {} + void emit_impl(const std::vector& in, const std::vector& out) const override {} }; class jit_broadcast_move_emitter : public jit_emitter { @@ -29,14 +30,15 @@ class jit_broadcast_move_emitter : public jit_emitter { dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 1;} + size_t get_inputs_count() const override { + return 1; + } private: - void emit_impl(const std::vector& in, - const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; private: size_t byte_size = 0lu; @@ -48,22 +50,25 @@ class jit_scalar_emitter : public jit_emitter { dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_count() const override {return 0;} + size_t get_inputs_count() const override { + return 0; + } protected: - size_t get_aux_gprs_count() const override {return 1;} + size_t get_aux_gprs_count() const override { + return 1; + } private: - void emit_impl(const std::vector& in, - const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; private: int32_t value; }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp index 912fe23fcd1fcf..86d090a858fd7b 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp @@ -3,6 +3,7 @@ // #include "jit_eltwise.hpp" + #include namespace ov { @@ -10,13 +11,12 @@ namespace intel_cpu { namespace executors { namespace aarch64 { -bool JitEltwiseExecutor::isSupported( - const Algorithm& algorithm, - const std::vector& input_precisions, - const std::vector& output_precisions, - const float alpha, - const float beta, - const float gamma) { +bool JitEltwiseExecutor::isSupported(const Algorithm& algorithm, + const std::vector& input_precisions, + const std::vector& output_precisions, + const float alpha, + const float beta, + const float gamma) { const auto is_supported = one_of(algorithm, Algorithm::EltwiseAbs, Algorithm::EltwiseAdd, @@ -67,10 +67,9 @@ bool JitEltwiseExecutor::isSupported( return false; } - const auto check_precisions = []( - const std::vector& input_precisions, - const std::vector& output_precisions, - const std::set& supported_precisions) { + const auto check_precisions = [](const std::vector& input_precisions, + const std::vector& output_precisions, + const std::set& supported_precisions) { if (std::any_of(input_precisions.begin(), input_precisions.end(), [&supported_precisions](const ov::element::Type& precision) { @@ -92,15 +91,13 @@ bool JitEltwiseExecutor::isSupported( const std::set supported_precisions = // Divide and Floor (issue #138629) operations are supported for fp32 and fp16 only. - ((algorithm == Algorithm::EltwiseDivide) || (algorithm == Algorithm::EltwiseFloor)) ? - std::set { ov::element::f16, ov::element::f32 } : - std::set { - ov::element::f16, - ov::element::f32, - ov::element::i32, - ov::element::i8, - ov::element::u8 - }; + ((algorithm == Algorithm::EltwiseDivide) || (algorithm == Algorithm::EltwiseFloor)) + ? std::set{ov::element::f16, ov::element::f32} + : std::set{ov::element::f16, + ov::element::f32, + ov::element::i32, + ov::element::i8, + ov::element::u8}; if (!check_precisions(input_precisions, output_precisions, supported_precisions)) { return false; @@ -111,20 +108,20 @@ bool JitEltwiseExecutor::isSupported( JitEltwiseExecutor::JitEltwiseExecutor(const ExecutorContext::CPtr context) : EltwiseExecutor(context) {} -bool JitEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, - const std::vector &srcDescs, - const std::vector &dstDescs, - const std::vector &postOps) { +bool JitEltwiseExecutor::init(const EltwiseAttrs& eltwiseAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const std::vector& postOps) { return true; } -void JitEltwiseExecutor::exec(const std::vector &src, - const std::vector &dst, - const void *post_ops_data_) { +void JitEltwiseExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) { exec_func(); } -} // namespace aarch64 -} // namespace executors -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace executors +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.hpp index 5244a80b542fe5..adaaea6a738c7a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.hpp @@ -5,8 +5,8 @@ #pragma once #include "cpu_types.h" -#include "nodes/executors/eltwise.hpp" #include "node.h" +#include "nodes/executors/eltwise.hpp" namespace ov { namespace intel_cpu { @@ -17,13 +17,12 @@ class JitEltwiseExecutor : public EltwiseExecutor { public: explicit JitEltwiseExecutor(const ExecutorContext::CPtr context); - static bool isSupported( - const Algorithm& algorithm, - const std::vector& input_precisions, - const std::vector& output_precisions, - const float alpha, - const float beta, - const float gamma); + static bool isSupported(const Algorithm& algorithm, + const std::vector& input_precisions, + const std::vector& output_precisions, + const float alpha, + const float beta, + const float gamma); bool init(const EltwiseAttrs& eltwiseAttrs, const std::vector& srcDescs, @@ -32,7 +31,7 @@ class JitEltwiseExecutor : public EltwiseExecutor { void exec(const std::vector& src, const std::vector& dst, - const void *post_ops_data_) override; + const void* post_ops_data_) override; impl_desc_type getImplType() const override { return impl_desc_type::asimd; @@ -42,7 +41,7 @@ class JitEltwiseExecutor : public EltwiseExecutor { std::function exec_func; }; -} // namespace aarch64 -} // namespace executors -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace executors +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp index 646cf47c1bcf6c..4804c7b4efe252 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp @@ -3,6 +3,7 @@ // #include "acl_common_executor.hpp" + #include "acl_utils.hpp" #include "nodes/executors/memory_arguments.hpp" #include "utils/debug_capabilities.h" @@ -14,12 +15,12 @@ static const std::unordered_map argConvert = { {ARG_SRC_0, ACL_SRC_0}, {ARG_SRC_1, ACL_SRC_1}, {ARG_SRC_2, ACL_SRC_2}, - {ARG_BIAS, ACL_BIAS}, - {ARG_WEI, ACL_WEI}, - {ARG_DST, ACL_DST}, + {ARG_BIAS, ACL_BIAS}, + {ARG_WEI, ACL_WEI}, + {ARG_DST, ACL_DST}, }; -using ACLTypes = std::array; +using ACLTypes = std::array; using ACLLayouts = std::array; static void initACLTensorParams(const MemoryPtr& memoryPtr, @@ -39,14 +40,11 @@ static void initACLTensorParams(const MemoryPtr& memoryPtr, } static std::shared_ptr initTensorInfo(const arm_compute::TensorShape& tensorShape, - const arm_compute::DataType& dataType, - const arm_compute::DataLayout& dataLayout) { + const arm_compute::DataType& dataType, + const arm_compute::DataLayout& dataLayout) { std::shared_ptr aclMemoryInfo = nullptr; if (dataType != arm_compute::DataType::UNKNOWN) { - aclMemoryInfo = std::make_shared( - tensorShape, 1, - dataType, - dataLayout); + aclMemoryInfo = std::make_shared(tensorShape, 1, dataType, dataLayout); } return aclMemoryInfo; } @@ -66,14 +64,15 @@ ACLCommonExecutor::ACLCommonExecutor() { } } -bool ACLCommonExecutor::update(const MemoryArgs &memory) { +bool ACLCommonExecutor::update(const MemoryArgs& memory) { // Initialize ACL tensors params - ACLShapes aclMemoryShapes; - ACLTypes aclDataType{}; + ACLShapes aclMemoryShapes; + ACLTypes aclDataType{}; ACLLayouts aclDataLayout{}; for (auto& cpu_mem_ptr : memory) { const ACLArgs index = argConvert.at(cpu_mem_ptr.first); - initACLTensorParams(cpu_mem_ptr.second, aclTensorAttrs, + initACLTensorParams(cpu_mem_ptr.second, + aclTensorAttrs, aclMemoryShapes[index], aclDataType[index], aclDataLayout[index]); @@ -110,7 +109,7 @@ bool ACLCommonExecutor::update(const MemoryArgs &memory) { return true; } -void ACLCommonExecutor::execute(const MemoryArgs &memory) { +void ACLCommonExecutor::execute(const MemoryArgs& memory) { // TODO: Move import_memory() to update() function - CVS-145871 for (auto& cpu_mem_ptr : memory) { const ACLArgs index = argConvert.at(cpu_mem_ptr.first); @@ -129,5 +128,5 @@ ACLCommonExecutor::~ACLCommonExecutor() { } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp index 1a5a00c7a85f7a..94c5dbe219aae8 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp @@ -4,27 +4,19 @@ #pragma once +#include "arm_compute/runtime/NEON/NEFunctions.h" #include "cpu_memory.h" #include "nodes/executors/executor.hpp" -#include "arm_compute/runtime/NEON/NEFunctions.h" namespace ov { namespace intel_cpu { -enum ACLArgs { - ACL_SRC_0, - ACL_SRC_1, - ACL_SRC_2, - ACL_BIAS, - ACL_WEI, - ACL_DST, - COUNT_OF_ARGS -}; +enum ACLArgs { ACL_SRC_0, ACL_SRC_1, ACL_SRC_2, ACL_BIAS, ACL_WEI, ACL_DST, COUNT_OF_ARGS }; using ACLFunction = std::unique_ptr; -using ACLShapes = std::array; -using ACLInfos = std::array, ACLArgs::COUNT_OF_ARGS>; -using ACLTensors = std::array, ACLArgs::COUNT_OF_ARGS>; +using ACLShapes = std::array; +using ACLInfos = std::array, ACLArgs::COUNT_OF_ARGS>; +using ACLTensors = std::array, ACLArgs::COUNT_OF_ARGS>; struct ACLTensorAttrs { bool hasLayoutTypeNHWC = false; @@ -50,6 +42,7 @@ class ACLCommonExecutor : public Executor { protected: ACLTensorAttrs aclTensorAttrs; + private: ACLTensors aclMemoryTensors; ACLInfos aclMemoryInfos; diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.cpp index 1bc0585930387f..ed12b0b76a2c1e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.cpp @@ -3,6 +3,7 @@ // #include "acl_convert.hpp" + #include "acl_utils.hpp" namespace ov { @@ -10,7 +11,6 @@ namespace intel_cpu { using namespace arm_compute; - bool ACLConvertExecutor::init(const ConvertParams& convertParams, const MemoryDescPtr& srcDesc, const MemoryDescPtr& dstDesc, @@ -51,10 +51,14 @@ bool ACLConvertExecutor::init(const ConvertParams& convertParams, if (isCopyOp) { acl_copy = std::make_unique(); - configureThreadSafe([&] { acl_copy->configure(&srcTensor, &dstTensor); }); + configureThreadSafe([&] { + acl_copy->configure(&srcTensor, &dstTensor); + }); } else { acl_cast = std::make_unique(); - configureThreadSafe([&] { acl_cast->configure(&srcTensor, &dstTensor, ConvertPolicy::SATURATE); }); + configureThreadSafe([&] { + acl_cast->configure(&srcTensor, &dstTensor, ConvertPolicy::SATURATE); + }); } return true; } @@ -91,45 +95,34 @@ bool ACLConvertExecutorBuilder::isSupported(const ConvertParams& convertParams, DEBUG_LOG("NECopy does not support source precision: ", convertParams.srcPrc.to_string()); return false; } - if ((convertParams.srcPrc == ov::element::i8 && !one_of(convertParams.dstPrc, - ov::element::i16, - ov::element::i32, - ov::element::f16, - ov::element::f32)) || + if ((convertParams.srcPrc == ov::element::i8 && + !one_of(convertParams.dstPrc, ov::element::i16, ov::element::i32, ov::element::f16, ov::element::f32)) || (convertParams.srcPrc == ov::element::u8 && !one_of(convertParams.dstPrc, - ov::element::u16, - ov::element::i16, - ov::element::i32, - ov::element::f16, - ov::element::f32)) || - (convertParams.srcPrc == ov::element::u16 && !one_of(convertParams.dstPrc, - ov::element::u8, - ov::element::u32)) || - (convertParams.srcPrc == ov::element::i16 && !one_of(convertParams.dstPrc, - ov::element::i8, - ov::element::u8, - ov::element::i32)) || - (convertParams.srcPrc == ov::element::f16 && !one_of(convertParams.dstPrc, - ov::element::i8, - ov::element::f32, + ov::element::u16, + ov::element::i16, ov::element::i32, - ov::element::u8)) || - (convertParams.srcPrc == ov::element::i32 && !one_of(convertParams.dstPrc, - ov::element::i8, - ov::element::f16, - ov::element::f32, - ov::element::u8)) || - (convertParams.srcPrc == ov::element::f32 && !one_of(convertParams.dstPrc, - ov::element::bf16, ov::element::f16, - ov::element::i32))) { + ov::element::f32)) || + (convertParams.srcPrc == ov::element::u16 && + !one_of(convertParams.dstPrc, ov::element::u8, ov::element::u32)) || + (convertParams.srcPrc == ov::element::i16 && + !one_of(convertParams.dstPrc, ov::element::i8, ov::element::u8, ov::element::i32)) || + (convertParams.srcPrc == ov::element::f16 && + !one_of(convertParams.dstPrc, ov::element::i8, ov::element::f32, ov::element::i32, ov::element::u8)) || + (convertParams.srcPrc == ov::element::i32 && + !one_of(convertParams.dstPrc, ov::element::i8, ov::element::f16, ov::element::f32, ov::element::u8)) || + (convertParams.srcPrc == ov::element::f32 && + !one_of(convertParams.dstPrc, ov::element::bf16, ov::element::f16, ov::element::i32))) { DEBUG_LOG("NECopy does not support passed combination of source and destination precisions. ", - "source precision: ", convertParams.srcPrc.to_string(), " destination precsion: ", convertParams.dstPrc.to_string()); + "source precision: ", + convertParams.srcPrc.to_string(), + " destination precsion: ", + convertParams.dstPrc.to_string()); return false; } } return true; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.hpp index b81e34004f9f31..431f8ce6887cbe 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.hpp @@ -4,9 +4,9 @@ #pragma once +#include "arm_compute/runtime/NEON/NEFunctions.h" #include "nodes/executors/convert.hpp" #include "utils/debug_capabilities.h" -#include "arm_compute/runtime/NEON/NEFunctions.h" namespace ov { namespace intel_cpu { @@ -17,9 +17,12 @@ class ACLConvertExecutor : public ConvertExecutor { bool init(const ConvertParams& convertParams, const MemoryDescPtr& srcDesc, const MemoryDescPtr& dstDesc, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst) override; - impl_desc_type implType() const override { return impl_desc_type::acl; }; + impl_desc_type implType() const override { + return impl_desc_type::acl; + }; + protected: ConvertParams aclConvertParams; bool isCopyOp; @@ -38,5 +41,5 @@ class ACLConvertExecutorBuilder : public ConvertExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.cpp index fa40b3a27322c2..cd5e935b41a2d5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.cpp @@ -3,6 +3,7 @@ // #include "acl_deconv.hpp" + #include "openvino/core/parallel.hpp" namespace ov { @@ -13,7 +14,9 @@ using namespace arm_compute; ACLDeconvTensorInfo getACLDeconvTensorInfo(const DeconvAttrs& deconvAttrs, const std::vector& srcDescs, const std::vector& dstDescs) { - auto func_mod = [](long a) -> unsigned int { return a < 0 ? 0 : a; }; + auto func_mod = [](long a) -> unsigned int { + return a < 0 ? 0 : a; + }; auto pad_l = deconvAttrs.paddingL.size() > 1 ? deconvAttrs.paddingL.at(1) : deconvAttrs.paddingL.at(0); auto pad_r = deconvAttrs.paddingR.size() > 1 ? deconvAttrs.paddingR.at(1) : deconvAttrs.paddingR.at(0); auto pad_t = deconvAttrs.paddingL.at(0); @@ -21,18 +24,30 @@ ACLDeconvTensorInfo getACLDeconvTensorInfo(const DeconvAttrs& deconvAttrs, unsigned int stride_x = (deconvAttrs.stride.size() > 1) ? deconvAttrs.stride.at(1) : deconvAttrs.stride.at(0); unsigned int stride_y = deconvAttrs.stride.at(0); - auto deconv_info = PadStrideInfo(stride_x, stride_y, func_mod(pad_l), func_mod(pad_r), func_mod(pad_t), func_mod(pad_b), DimensionRoundingType::FLOOR); - - auto srcDims = srcDescs[0]->getShape().getDims(); - auto weiDims = srcDescs[1]->getShape().getDims(); - auto dstDims = dstDescs[0]->getShape().getDims(); + auto deconv_info = PadStrideInfo(stride_x, + stride_y, + func_mod(pad_l), + func_mod(pad_r), + func_mod(pad_t), + func_mod(pad_b), + DimensionRoundingType::FLOOR); + + auto srcDims = srcDescs[0]->getShape().getDims(); + auto weiDims = srcDescs[1]->getShape().getDims(); + auto dstDims = dstDescs[0]->getShape().getDims(); // ACL can't work with custom output shape, this we make WA for that problem if (pad_l < 0 || pad_r < 0 || pad_t < 0 || pad_b < 0) { auto out_dims = deconvolution_output_dimensions(srcDims[3], srcDims[2], weiDims[3], weiDims[2], deconv_info); stride_x += (out_dims.first - dstDims[3] - 2 * (pad_l + pad_r)) / (srcDims[3] - 1); stride_y += (out_dims.second - dstDims[2] - 2 * (pad_t + pad_b)) / (srcDims[2] - 1); - deconv_info = PadStrideInfo(stride_x, stride_y, func_mod(pad_l), func_mod(pad_r), func_mod(pad_t), func_mod(pad_b), DimensionRoundingType::FLOOR); + deconv_info = PadStrideInfo(stride_x, + stride_y, + func_mod(pad_l), + func_mod(pad_r), + func_mod(pad_t), + func_mod(pad_b), + DimensionRoundingType::FLOOR); } std::swap(weiDims[0], weiDims[1]); @@ -59,16 +74,18 @@ ACLDeconvTensorInfo getACLDeconvTensorInfo(const DeconvAttrs& deconvAttrs, weiLayout = arm_compute::DataLayout::NHWC; } - TensorInfo srcTensorInfo = TensorInfo(srcVecDims, 1, - precisionToAclDataType(srcDescs[0]->getPrecision()), srcLayout); - TensorInfo weiTensorInfo = TensorInfo(weiVecDims, 1, - precisionToAclDataType(srcDescs[1]->getPrecision()), weiLayout); - TensorInfo dstTensorInfo = TensorInfo(dstVecDims, 1, - precisionToAclDataType(dstDescs[0]->getPrecision()), dstLayout); + TensorInfo srcTensorInfo = + TensorInfo(srcVecDims, 1, precisionToAclDataType(srcDescs[0]->getPrecision()), srcLayout); + TensorInfo weiTensorInfo = + TensorInfo(weiVecDims, 1, precisionToAclDataType(srcDescs[1]->getPrecision()), weiLayout); + TensorInfo dstTensorInfo = + TensorInfo(dstVecDims, 1, precisionToAclDataType(dstDescs[0]->getPrecision()), dstLayout); TensorInfo biasTensorInfo; if (deconvAttrs.withBiasesParam) { - biasTensorInfo = TensorInfo(biasVecDims, 1, - precisionToAclDataType(srcDescs[2]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[2])); + biasTensorInfo = TensorInfo(biasVecDims, + 1, + precisionToAclDataType(srcDescs[2]->getPrecision()), + getAclDataLayoutByMemoryDesc(srcDescs[2])); } return ACLDeconvTensorInfo{srcTensorInfo, weiTensorInfo, biasTensorInfo, dstTensorInfo, deconv_info}; @@ -77,9 +94,9 @@ ACLDeconvTensorInfo getACLDeconvTensorInfo(const DeconvAttrs& deconvAttrs, AclDeconvExecutor::AclDeconvExecutor(const ExecutorContext::CPtr context) : DeconvExecutor(context) {} bool AclDeconvExecutor::init(const DeconvAttrs& deconvAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { this->deconvAttrs = deconvAttrs; ACLDeconvTensorInfo aclDeconvTensorInfo = getACLDeconvTensorInfo(deconvAttrs, srcDescs, dstDescs); TensorInfo srcTensorInfo = aclDeconvTensorInfo.srcTensorInfo; @@ -99,12 +116,17 @@ bool AclDeconvExecutor::init(const DeconvAttrs& deconvAttrs, deconv = std::make_unique(); configureThreadSafe([&] { - deconv->configure(&srcTensor, &weiTensor, deconvAttrs.withBiasesParam ? &biasTensor : nullptr, &dstTensor, deconv_info, deconvAttrs.aclFastMath); + deconv->configure(&srcTensor, + &weiTensor, + deconvAttrs.withBiasesParam ? &biasTensor : nullptr, + &dstTensor, + deconv_info, + deconvAttrs.aclFastMath); }); return true; } -template +template static void transpose_weights(const MemoryCPtr& srcMemPtr, MemoryPtr& newSrcMemPtr, bool isNCHW) { const auto src_data = srcMemPtr->getDataAs(); const auto new_src_data = newSrcMemPtr->getDataAs(); @@ -118,29 +140,17 @@ static void transpose_weights(const MemoryCPtr& srcMemPtr, MemoryPtr& newSrcMemP if (isNCHW) { parallel_for3d(DIM0, DIM1, DIM2, [&](const int dim0, const int dim1, const int dim2) { for (int dim3 = 0; dim3 < DIM3; ++dim3) { - const int src_off = dim0 * DIM1 * DIM2 * DIM3 + - dim1 * DIM2 * DIM3 + - dim2 * DIM3 + - dim3; - const int dst_off = dim1 * DIM0 * DIM2 * DIM3 + - dim0 * DIM2 * DIM3 + - dim2 * DIM3 + - dim3; + const int src_off = dim0 * DIM1 * DIM2 * DIM3 + dim1 * DIM2 * DIM3 + dim2 * DIM3 + dim3; + const int dst_off = dim1 * DIM0 * DIM2 * DIM3 + dim0 * DIM2 * DIM3 + dim2 * DIM3 + dim3; new_src_data[dst_off] = src_data[src_off]; } }); - // 0231 -> 1230 + // 0231 -> 1230 } else { parallel_for3d(DIM0, DIM1, DIM2, [&](const int dim0, const int dim1, const int dim2) { for (int dim3 = 0; dim3 < DIM3; ++dim3) { - const int src_off = dim0 * DIM1 * DIM2 * DIM3 + - dim1 * DIM2 * DIM3 + - dim2 * DIM3 + - dim3; - const int dst_off = dim1 * DIM2 * DIM3 * DIM0 + - dim2 * DIM3 * DIM0 + - dim3 * DIM0 + - dim0; + const int src_off = dim0 * DIM1 * DIM2 * DIM3 + dim1 * DIM2 * DIM3 + dim2 * DIM3 + dim3; + const int dst_off = dim1 * DIM2 * DIM3 * DIM0 + dim2 * DIM3 * DIM0 + dim3 * DIM0 + dim0; new_src_data[dst_off] = src_data[src_off]; } }); @@ -176,7 +186,9 @@ static MemoryPtr prepareWeightMemory(const std::vector& src, const E return create(); } -void AclDeconvExecutor::exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) { +void AclDeconvExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) { // TODO: Remove transpose from exec auto newWei = prepareWeightMemory(src, context); @@ -194,16 +206,19 @@ void AclDeconvExecutor::exec(const std::vector& src, const std::vect biasTensor.allocator()->free(); } -bool AclDeconvExecutorBuilder::customIsSupported(const DeconvAttrs &deconvAttrs, - const std::vector &srcDescs, - const std::vector &dstDescs) { +bool AclDeconvExecutorBuilder::customIsSupported(const DeconvAttrs& deconvAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) { if ((srcDescs[0]->getShape().getDims().size() != 3 && srcDescs[0]->getShape().getDims().size() != 4) || dstDescs[0]->getShape().getDims().size() != srcDescs[0]->getShape().getDims().size() || srcDescs[1]->getShape().getDims().size() != 4) { DEBUG_LOG("AclDeconvExecutor does not support dimension:", - " src[0]=", srcDescs[0]->getShape().getDims().size(), - " src[1]=", srcDescs[1]->getShape().getDims().size(), - " dst[0]=", dstDescs[0]->getShape().getDims().size()); + " src[0]=", + srcDescs[0]->getShape().getDims().size(), + " src[1]=", + srcDescs[1]->getShape().getDims().size(), + " dst[0]=", + dstDescs[0]->getShape().getDims().size()); return false; } @@ -211,62 +226,71 @@ bool AclDeconvExecutorBuilder::customIsSupported(const DeconvAttrs &deconvAttrs, srcDescs[0]->getPrecision() == srcDescs[1]->getPrecision() && srcDescs[1]->getPrecision() == dstDescs[0]->getPrecision())) { DEBUG_LOG("AclDeconvExecutor does not support precisions:", - " src[0]=", srcDescs[0]->getPrecision(), - " src[1]=", srcDescs[1]->getPrecision(), - " dst[0]=", dstDescs[0]->getPrecision()); + " src[0]=", + srcDescs[0]->getPrecision(), + " src[1]=", + srcDescs[1]->getPrecision(), + " dst[0]=", + dstDescs[0]->getPrecision()); return false; } if (deconvAttrs.withBiasesParam && srcDescs[2]->getPrecision() != srcDescs[0]->getPrecision()) { - DEBUG_LOG("AclDeconvExecutor does not support precisions:", - " src[2]=", srcDescs[2]->getPrecision()); + DEBUG_LOG("AclDeconvExecutor does not support precisions:", " src[2]=", srcDescs[2]->getPrecision()); return false; } - if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && - srcDescs[1]->hasLayoutType(LayoutType::ncsp) && + if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && srcDescs[1]->hasLayoutType(LayoutType::ncsp) && dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && // Check weights as ncsp because we remove reorder and will transform ncsp -> nspc in exec() function - srcDescs[1]->hasLayoutType(LayoutType::ncsp) && - dstDescs[0]->hasLayoutType(LayoutType::nspc))) { + srcDescs[1]->hasLayoutType(LayoutType::ncsp) && dstDescs[0]->hasLayoutType(LayoutType::nspc))) { DEBUG_LOG("AclDeconvExecutor does not support layouts:", - " src[0]=", srcDescs[0]->serializeFormat(), - " src[1]=", srcDescs[1]->serializeFormat(), - " dst=", dstDescs[0]->serializeFormat()); + " src[0]=", + srcDescs[0]->serializeFormat(), + " src[1]=", + srcDescs[1]->serializeFormat(), + " dst=", + dstDescs[0]->serializeFormat()); return false; } - if (deconvAttrs.withBiasesParam && - !(srcDescs[2]->hasLayoutType(LayoutType::ncsp)) && + if (deconvAttrs.withBiasesParam && !(srcDescs[2]->hasLayoutType(LayoutType::ncsp)) && !(srcDescs[2]->hasLayoutType(LayoutType::nspc))) { DEBUG_LOG("AclDeconvExecutor does not support layouts:", - " src[0]=", srcDescs[0]->serializeFormat(), - " src[1]=", srcDescs[1]->serializeFormat(), - " src[2]=", srcDescs[2]->serializeFormat(), - " dst=", dstDescs[0]->serializeFormat()); + " src[0]=", + srcDescs[0]->serializeFormat(), + " src[1]=", + srcDescs[1]->serializeFormat(), + " src[2]=", + srcDescs[2]->serializeFormat(), + " dst=", + dstDescs[0]->serializeFormat()); return false; } ACLDeconvTensorInfo aclDeconvTensorInfo = getACLDeconvTensorInfo(deconvAttrs, srcDescs, dstDescs); - auto srcTensorInfo = aclDeconvTensorInfo.srcTensorInfo; - auto weiTensorInfo = aclDeconvTensorInfo.weiTensorInfo; + auto srcTensorInfo = aclDeconvTensorInfo.srcTensorInfo; + auto weiTensorInfo = aclDeconvTensorInfo.weiTensorInfo; auto biasTensorInfo = aclDeconvTensorInfo.biasTensorInfo; - auto dstTensorInfo = aclDeconvTensorInfo.dstTensorInfo; - auto deconv_info = aclDeconvTensorInfo.deconv_info; + auto dstTensorInfo = aclDeconvTensorInfo.dstTensorInfo; + auto deconv_info = aclDeconvTensorInfo.deconv_info; - unsigned int dilation_x = (deconvAttrs.dilation.size() > 1) ? deconvAttrs.dilation.at(1) : deconvAttrs.dilation.at(0); + unsigned int dilation_x = + (deconvAttrs.dilation.size() > 1) ? deconvAttrs.dilation.at(1) : deconvAttrs.dilation.at(0); unsigned int dilation_y = deconvAttrs.dilation.at(0); - if (!one_of(dilation_x, static_cast(0), static_cast(1)) || - !one_of(dilation_y, static_cast(0), static_cast(1))) return false; + if (!one_of(dilation_x, static_cast(0), static_cast(1)) || + !one_of(dilation_y, static_cast(0), static_cast(1))) + return false; try { - arm_compute::Status status = arm_compute::NEDeconvolutionLayer::validate(&srcTensorInfo, - &weiTensorInfo, - deconvAttrs.withBiasesParam ? &biasTensorInfo : nullptr, - &dstTensorInfo, - deconv_info, - deconvAttrs.aclFastMath); + arm_compute::Status status = + arm_compute::NEDeconvolutionLayer::validate(&srcTensorInfo, + &weiTensorInfo, + deconvAttrs.withBiasesParam ? &biasTensorInfo : nullptr, + &dstTensorInfo, + deconv_info, + deconvAttrs.aclFastMath); if (!status) { DEBUG_LOG("NEDeconvolutionLayer validation failed: ", status.error_description()); return false; @@ -280,5 +304,5 @@ bool AclDeconvExecutorBuilder::customIsSupported(const DeconvAttrs &deconvAttrs, return true; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.hpp index ad743690f45c52..e27551ad48d44d 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_deconv.hpp @@ -4,11 +4,11 @@ #pragma once -#include "nodes/executors/deconv.hpp" -#include "arm_compute/runtime/NEON/NEFunctions.h" -#include "utils/debug_capabilities.h" #include "acl_utils.hpp" +#include "arm_compute/runtime/NEON/NEFunctions.h" +#include "nodes/executors/deconv.hpp" #include "src/cpu/CpuTypes.h" +#include "utils/debug_capabilities.h" namespace ov { namespace intel_cpu { @@ -22,8 +22,8 @@ struct ACLDeconvTensorInfo { }; ACLDeconvTensorInfo getACLDeconvTensorInfo(const DeconvAttrs& deconvAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs); + const std::vector& srcDescs, + const std::vector& dstDescs); class AclDeconvExecutor : public DeconvExecutor { public: @@ -31,10 +31,10 @@ class AclDeconvExecutor : public DeconvExecutor { bool init(const DeconvAttrs& deconvAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst, - const void *post_ops_data_) override; + const void* post_ops_data_) override; impl_desc_type getImplType() const override { return implType; @@ -68,5 +68,5 @@ class AclDeconvExecutorBuilder : public DeconvExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp index 942bacd91349ff..26d387c7659dc5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp @@ -3,6 +3,7 @@ // #include "acl_eltwise.hpp" + #include "acl_utils.hpp" #include "utils/debug_capabilities.h" @@ -34,40 +35,45 @@ inline void log_unsupported_prec(const std::vector& srcDescs, for (size_t i = 0; i < srcDescs.size(); i++) { srcPrec += srcDescs[i]->getPrecision().to_string() + " "; } - DEBUG_LOG(algToString(eltwiseAlgorithm), ": provided combination of src precisions: [", srcPrec, - "] and dst precision: ", dstDescs[0]->getPrecision().to_string(), " is not supported"); + DEBUG_LOG(algToString(eltwiseAlgorithm), + ": provided combination of src precisions: [", + srcPrec, + "] and dst precision: ", + dstDescs[0]->getPrecision().to_string(), + " is not supported"); } bool AclEltwiseExecutor::isEltwiseAlgorithmSupported(Algorithm algorithm) { - if (one_of(algorithm, Algorithm::EltwiseSqrt, - Algorithm::EltwiseDivide, - Algorithm::EltwiseRelu, + if (one_of(algorithm, + Algorithm::EltwiseSqrt, + Algorithm::EltwiseDivide, + Algorithm::EltwiseRelu, #ifdef OPENVINO_ARCH_ARM64 - Algorithm::EltwiseGeluErf, + Algorithm::EltwiseGeluErf, #endif - Algorithm::EltwiseElu, - Algorithm::EltwiseTanh, - Algorithm::EltwiseSigmoid, - Algorithm::EltwiseSoftRelu, - Algorithm::EltwiseClamp, - Algorithm::EltwiseSwish, - Algorithm::EltwisePrelu, - Algorithm::EltwiseHswish, - Algorithm::EltwiseAbs, - Algorithm::EltwiseExp, - Algorithm::EltwiseLog, - Algorithm::EltwiseMaximum, - Algorithm::EltwiseMinimum, - Algorithm::EltwiseSquaredDifference, - Algorithm::EltwiseAdd, - Algorithm::EltwiseSubtract, - Algorithm::EltwiseMultiply, - Algorithm::EltwiseEqual, - Algorithm::EltwiseNotEqual, - Algorithm::EltwiseGreater, - Algorithm::EltwiseGreaterEqual, - Algorithm::EltwiseLess, - Algorithm::EltwiseLessEqual)) { + Algorithm::EltwiseElu, + Algorithm::EltwiseTanh, + Algorithm::EltwiseSigmoid, + Algorithm::EltwiseSoftRelu, + Algorithm::EltwiseClamp, + Algorithm::EltwiseSwish, + Algorithm::EltwisePrelu, + Algorithm::EltwiseHswish, + Algorithm::EltwiseAbs, + Algorithm::EltwiseExp, + Algorithm::EltwiseLog, + Algorithm::EltwiseMaximum, + Algorithm::EltwiseMinimum, + Algorithm::EltwiseSquaredDifference, + Algorithm::EltwiseAdd, + Algorithm::EltwiseSubtract, + Algorithm::EltwiseMultiply, + Algorithm::EltwiseEqual, + Algorithm::EltwiseNotEqual, + Algorithm::EltwiseGreater, + Algorithm::EltwiseGreaterEqual, + Algorithm::EltwiseLess, + Algorithm::EltwiseLessEqual)) { return true; } return false; @@ -76,107 +82,111 @@ bool AclEltwiseExecutor::isEltwiseAlgorithmSupported(Algorithm algorithm) { bool AclEltwiseExecutorBuilder::isSupported(const EltwiseAttrs& eltwiseAttrs, const std::vector& srcDescs, const std::vector& dstDescs) const { - auto checkPrecision = [&srcDescs, &dstDescs](std::vector srcVecPrc, ov::element::Type dstPrc) -> bool { + auto checkPrecision = [&srcDescs, &dstDescs](std::vector srcVecPrc, + ov::element::Type dstPrc) -> bool { for (size_t i = 0; i < srcDescs.size(); i++) { - if (srcDescs[i]->getPrecision() != srcVecPrc[i]) return false; + if (srcDescs[i]->getPrecision() != srcVecPrc[i]) + return false; + } + if (dstDescs[0]->getPrecision() != dstPrc) { + return false; } - if (dstDescs[0]->getPrecision() != dstPrc) { return false; } return true; }; switch (eltwiseAttrs.algorithm) { - case Algorithm::EltwiseSqrt: - case Algorithm::EltwiseDivide: - case Algorithm::EltwiseRelu: + case Algorithm::EltwiseSqrt: + case Algorithm::EltwiseDivide: + case Algorithm::EltwiseRelu: #ifdef OPENVINO_ARCH_ARM64 - case Algorithm::EltwiseGeluErf: + case Algorithm::EltwiseGeluErf: #endif - case Algorithm::EltwiseElu: - case Algorithm::EltwiseTanh: - case Algorithm::EltwiseSigmoid: - case Algorithm::EltwiseSoftRelu: - case Algorithm::EltwiseClamp: - case Algorithm::EltwiseSwish: - case Algorithm::EltwisePrelu: - case Algorithm::EltwiseHswish: - if (!(checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || - checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { - log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); - return false; - } - break; - case Algorithm::EltwiseAbs: - case Algorithm::EltwiseExp: - case Algorithm::EltwiseLog: - if (!(checkPrecision({ov::element::i32, ov::element::i32}, ov::element::i32) || - checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || - checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { - log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); - return false; - } - break; - case Algorithm::EltwiseMaximum: - case Algorithm::EltwiseMinimum: - case Algorithm::EltwiseSquaredDifference: - if (!(checkPrecision({ov::element::i16, ov::element::i16}, ov::element::i16) || - checkPrecision({ov::element::i32, ov::element::i32}, ov::element::i32) || - checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || - checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { - log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); - return false; - } - break; - case Algorithm::EltwiseAdd: - case Algorithm::EltwiseSubtract: - if (!(checkPrecision({ov::element::u8, ov::element::u8}, ov::element::u8) || - checkPrecision({ov::element::i16, ov::element::i16}, ov::element::i16) || - checkPrecision({ov::element::i32, ov::element::i32}, ov::element::i32) || - checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || - checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { - log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); - return false; - } - break; - case Algorithm::EltwiseMultiply: - if (!(checkPrecision({ov::element::u8, ov::element::u8}, ov::element::u8) || - checkPrecision({ov::element::u8, ov::element::u8}, ov::element::i16) || - checkPrecision({ov::element::u8, ov::element::i16}, ov::element::i16) || - checkPrecision({ov::element::i16, ov::element::u8}, ov::element::i16) || - checkPrecision({ov::element::i16, ov::element::i16}, ov::element::i16) || - checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || - checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { - log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); - return false; - } - break; - // ACL supports only U8 precision on output for comparison operations - case Algorithm::EltwiseEqual: - case Algorithm::EltwiseNotEqual: - case Algorithm::EltwiseGreater: - case Algorithm::EltwiseGreaterEqual: - case Algorithm::EltwiseLess: - case Algorithm::EltwiseLessEqual: - if (!(checkPrecision({ov::element::u8, ov::element::u8}, ov::element::u8) || - checkPrecision({ov::element::i16, ov::element::i16}, ov::element::u8) || - checkPrecision({ov::element::i32, ov::element::i32}, ov::element::u8) || - checkPrecision({ov::element::f16, ov::element::f16}, ov::element::u8) || - checkPrecision({ov::element::f32, ov::element::f32}, ov::element::u8))) { - log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); - return false; - } - break; - default: - DEBUG_LOG("Eltwise algorithm ", algToString(eltwiseAttrs.algorithm), " is not supported"); + case Algorithm::EltwiseElu: + case Algorithm::EltwiseTanh: + case Algorithm::EltwiseSigmoid: + case Algorithm::EltwiseSoftRelu: + case Algorithm::EltwiseClamp: + case Algorithm::EltwiseSwish: + case Algorithm::EltwisePrelu: + case Algorithm::EltwiseHswish: + if (!(checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || + checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { + log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); + return false; + } + break; + case Algorithm::EltwiseAbs: + case Algorithm::EltwiseExp: + case Algorithm::EltwiseLog: + if (!(checkPrecision({ov::element::i32, ov::element::i32}, ov::element::i32) || + checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || + checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { + log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); + return false; + } + break; + case Algorithm::EltwiseMaximum: + case Algorithm::EltwiseMinimum: + case Algorithm::EltwiseSquaredDifference: + if (!(checkPrecision({ov::element::i16, ov::element::i16}, ov::element::i16) || + checkPrecision({ov::element::i32, ov::element::i32}, ov::element::i32) || + checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || + checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { + log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); + return false; + } + break; + case Algorithm::EltwiseAdd: + case Algorithm::EltwiseSubtract: + if (!(checkPrecision({ov::element::u8, ov::element::u8}, ov::element::u8) || + checkPrecision({ov::element::i16, ov::element::i16}, ov::element::i16) || + checkPrecision({ov::element::i32, ov::element::i32}, ov::element::i32) || + checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || + checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { + log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); return false; + } + break; + case Algorithm::EltwiseMultiply: + if (!(checkPrecision({ov::element::u8, ov::element::u8}, ov::element::u8) || + checkPrecision({ov::element::u8, ov::element::u8}, ov::element::i16) || + checkPrecision({ov::element::u8, ov::element::i16}, ov::element::i16) || + checkPrecision({ov::element::i16, ov::element::u8}, ov::element::i16) || + checkPrecision({ov::element::i16, ov::element::i16}, ov::element::i16) || + checkPrecision({ov::element::f16, ov::element::f16}, ov::element::f16) || + checkPrecision({ov::element::f32, ov::element::f32}, ov::element::f32))) { + log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); + return false; + } + break; + // ACL supports only U8 precision on output for comparison operations + case Algorithm::EltwiseEqual: + case Algorithm::EltwiseNotEqual: + case Algorithm::EltwiseGreater: + case Algorithm::EltwiseGreaterEqual: + case Algorithm::EltwiseLess: + case Algorithm::EltwiseLessEqual: + if (!(checkPrecision({ov::element::u8, ov::element::u8}, ov::element::u8) || + checkPrecision({ov::element::i16, ov::element::i16}, ov::element::u8) || + checkPrecision({ov::element::i32, ov::element::i32}, ov::element::u8) || + checkPrecision({ov::element::f16, ov::element::f16}, ov::element::u8) || + checkPrecision({ov::element::f32, ov::element::f32}, ov::element::u8))) { + log_unsupported_prec(srcDescs, dstDescs, eltwiseAttrs.algorithm); + return false; + } + break; + default: + DEBUG_LOG("Eltwise algorithm ", algToString(eltwiseAttrs.algorithm), " is not supported"); + return false; } - for (const auto & srcDesc : srcDescs) { + for (const auto& srcDesc : srcDescs) { if (getAclDataLayoutByMemoryDesc(srcDesc) == arm_compute::DataLayout::UNKNOWN) { DEBUG_LOG("src descriptor layout is unsupported by ACL: ", srcDesc->serializeFormat()); return false; } } - for (const auto & dstDesc : dstDescs) { + for (const auto& dstDesc : dstDescs) { if (getAclDataLayoutByMemoryDesc(dstDesc) == arm_compute::DataLayout::UNKNOWN) { DEBUG_LOG("dst descriptor layout is unsupported by ACL: ", dstDesc->serializeFormat()); return false; @@ -188,10 +198,13 @@ bool AclEltwiseExecutorBuilder::isSupported(const EltwiseAttrs& eltwiseAttrs, AclEltwiseExecutor::AclEltwiseExecutor(const ExecutorContext::CPtr context) : EltwiseExecutor(context) {} -bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vector &srcDescs, - const std::vector &dstDescs, - const std::vector &postOps) { - if (!postOps.empty()) { return false; } +bool AclEltwiseExecutor::init(const EltwiseAttrs& eltwiseAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const std::vector& postOps) { + if (!postOps.empty()) { + return false; + } aclEltwiseAttrs = eltwiseAttrs; std::vector srcVecDims(srcDescs.size()), dstVecDims(dstDescs.size()); @@ -209,15 +222,19 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto for (size_t i = 0; i < srcDescs.size(); i++) { srcDataLayout[i] = getAclDataLayoutByMemoryDesc(srcDescs[i]); - if (srcDataLayout[i] == arm_compute::DataLayout::UNKNOWN) { return false; } + if (srcDataLayout[i] == arm_compute::DataLayout::UNKNOWN) { + return false; + } } for (size_t i = 0; i < dstDescs.size(); i++) { dstDataLayout[i] = getAclDataLayoutByMemoryDesc(dstDescs[i]); - if (dstDataLayout[i] == arm_compute::DataLayout::UNKNOWN) { return false; } + if (dstDataLayout[i] == arm_compute::DataLayout::UNKNOWN) { + return false; + } } - if (srcDescs.size() == 2 && - srcDescs[0]->hasLayoutType(LayoutType::nspc) && srcDescs[1]->hasLayoutType(LayoutType::nspc) && + if (srcDescs.size() == 2 && srcDescs[0]->hasLayoutType(LayoutType::nspc) && + srcDescs[1]->hasLayoutType(LayoutType::nspc) && srcDescs[0]->getShape().getDims() != srcDescs[1]->getShape().getDims()) { if (srcVecDims[0].num_dimensions() < 5) { srcDataLayout[0] = srcDataLayout[1] = dstDataLayout[0] = DataLayout::NCHW; @@ -228,210 +245,248 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto } for (size_t i = 0; i < srcVecDims.size(); i++) { - srcTensorsInfo[i] = TensorInfo(srcVecDims[i], 1, - precisionToAclDataType(srcDescs[i]->getPrecision()), - srcDataLayout[i]); + srcTensorsInfo[i] = + TensorInfo(srcVecDims[i], 1, precisionToAclDataType(srcDescs[i]->getPrecision()), srcDataLayout[i]); srcTensors[i].allocator()->init(srcTensorsInfo[i]); } for (size_t i = 0; i < dstVecDims.size(); i++) { - dstTensorsInfo[i] = TensorInfo(dstVecDims[i], 1, - precisionToAclDataType(dstDescs[i]->getPrecision()), - dstDataLayout[i]); + dstTensorsInfo[i] = + TensorInfo(dstVecDims[i], 1, precisionToAclDataType(dstDescs[i]->getPrecision()), dstDataLayout[i]); dstTensors[i].allocator()->init(dstTensorsInfo[i]); } std::function(void)> exec_func; switch (aclEltwiseAttrs.algorithm) { - case Algorithm::EltwiseAdd: - if (!NEArithmeticAddition::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ConvertPolicy::SATURATE)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ConvertPolicy::SATURATE); - return acl_op; - }; - break; - case Algorithm::EltwiseMultiply: - if (!NEPixelWiseMultiplication::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], - 1.0f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], 1.0f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); - return acl_op; - }; - break; - case Algorithm::EltwiseSubtract: - if (!NEArithmeticSubtraction::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ConvertPolicy::SATURATE)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ConvertPolicy::SATURATE); - return acl_op; - }; - break; - case Algorithm::EltwiseDivide: - if (!NEElementwiseDivision::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwiseMaximum: - if (!NEElementwiseMax::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwiseMinimum: - if (!NEElementwiseMin::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwiseSquaredDifference: - if (!NEElementwiseSquaredDiff::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwiseEqual: - if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ComparisonOperation::Equal)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::Equal); - return acl_op; - }; - break; - case Algorithm::EltwiseNotEqual: - if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ComparisonOperation::NotEqual)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::NotEqual); - return acl_op; - }; - break; - case Algorithm::EltwiseGreater: - if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ComparisonOperation::Greater)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::Greater); - return acl_op; - }; - break; - case Algorithm::EltwiseGreaterEqual: - if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ComparisonOperation::GreaterEqual)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::GreaterEqual); - return acl_op; - }; - break; - case Algorithm::EltwiseLess: - if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ComparisonOperation::Less)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::Less); - return acl_op; - }; - break; - case Algorithm::EltwiseLessEqual: - if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0], ComparisonOperation::LessEqual)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::LessEqual); - return acl_op; - }; - break; - case Algorithm::EltwiseAbs: - if (!NEAbsLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwiseExp: - if (!NEExpLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwisePrelu: - if (!NEPReluLayer::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); - return acl_op; - }; - break; - case Algorithm::EltwiseRelu: - case Algorithm::EltwiseGeluErf: - case Algorithm::EltwiseElu: - case Algorithm::EltwiseTanh: - case Algorithm::EltwiseSigmoid: - case Algorithm::EltwiseSqrt: - case Algorithm::EltwiseSoftRelu: - case Algorithm::EltwiseClamp: - case Algorithm::EltwiseSwish: - case Algorithm::EltwiseHswish: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], getActivationLayerInfo(aclEltwiseAttrs.algorithm, - aclEltwiseAttrs.alpha, - aclEltwiseAttrs.beta, - aclEltwiseAttrs.gamma))) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], getActivationLayerInfo(aclEltwiseAttrs.algorithm, - aclEltwiseAttrs.alpha, - aclEltwiseAttrs.beta, - aclEltwiseAttrs.gamma)); - return acl_op; - }; - break; - case Algorithm::EltwiseLog: - if (!NELogLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0]); - return acl_op; - }; - break; - default: - OPENVINO_THROW("Unsupported operation type for ACL Eltwise executor: ", - static_cast(aclEltwiseAttrs.algorithm)); + case Algorithm::EltwiseAdd: + if (!NEArithmeticAddition::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ConvertPolicy::SATURATE)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ConvertPolicy::SATURATE); + return acl_op; + }; + break; + case Algorithm::EltwiseMultiply: + if (!NEPixelWiseMultiplication::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + 1.0f, + ConvertPolicy::SATURATE, + RoundingPolicy::TO_ZERO)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], + &srcTensors[1], + &dstTensors[0], + 1.0f, + ConvertPolicy::SATURATE, + RoundingPolicy::TO_ZERO); + return acl_op; + }; + break; + case Algorithm::EltwiseSubtract: + if (!NEArithmeticSubtraction::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ConvertPolicy::SATURATE)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ConvertPolicy::SATURATE); + return acl_op; + }; + break; + case Algorithm::EltwiseDivide: + if (!NEElementwiseDivision::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwiseMaximum: + if (!NEElementwiseMax::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwiseMinimum: + if (!NEElementwiseMin::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwiseSquaredDifference: + if (!NEElementwiseSquaredDiff::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwiseEqual: + if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ComparisonOperation::Equal)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::Equal); + return acl_op; + }; + break; + case Algorithm::EltwiseNotEqual: + if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ComparisonOperation::NotEqual)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::NotEqual); + return acl_op; + }; + break; + case Algorithm::EltwiseGreater: + if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ComparisonOperation::Greater)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::Greater); + return acl_op; + }; + break; + case Algorithm::EltwiseGreaterEqual: + if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ComparisonOperation::GreaterEqual)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::GreaterEqual); + return acl_op; + }; + break; + case Algorithm::EltwiseLess: + if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ComparisonOperation::Less)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::Less); + return acl_op; + }; + break; + case Algorithm::EltwiseLessEqual: + if (!NEElementwiseComparison::validate(&srcTensorsInfo[0], + &srcTensorsInfo[1], + &dstTensorsInfo[0], + ComparisonOperation::LessEqual)) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0], ComparisonOperation::LessEqual); + return acl_op; + }; + break; + case Algorithm::EltwiseAbs: + if (!NEAbsLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwiseExp: + if (!NEExpLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwisePrelu: + if (!NEPReluLayer::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &srcTensors[1], &dstTensors[0]); + return acl_op; + }; + break; + case Algorithm::EltwiseRelu: + case Algorithm::EltwiseGeluErf: + case Algorithm::EltwiseElu: + case Algorithm::EltwiseTanh: + case Algorithm::EltwiseSigmoid: + case Algorithm::EltwiseSqrt: + case Algorithm::EltwiseSoftRelu: + case Algorithm::EltwiseClamp: + case Algorithm::EltwiseSwish: + case Algorithm::EltwiseHswish: + if (!NEActivationLayer::validate(&srcTensorsInfo[0], + &dstTensorsInfo[0], + getActivationLayerInfo(aclEltwiseAttrs.algorithm, + aclEltwiseAttrs.alpha, + aclEltwiseAttrs.beta, + aclEltwiseAttrs.gamma))) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], + &dstTensors[0], + getActivationLayerInfo(aclEltwiseAttrs.algorithm, + aclEltwiseAttrs.alpha, + aclEltwiseAttrs.beta, + aclEltwiseAttrs.gamma)); + return acl_op; + }; + break; + case Algorithm::EltwiseLog: + if (!NELogLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) + return false; + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensors[0], &dstTensors[0]); + return acl_op; + }; + break; + default: + OPENVINO_THROW("Unsupported operation type for ACL Eltwise executor: ", + static_cast(aclEltwiseAttrs.algorithm)); } - configureThreadSafe([&] { ifunc = exec_func(); }); + configureThreadSafe([&] { + ifunc = exec_func(); + }); return true; } -void AclEltwiseExecutor::exec(const std::vector &src, const std::vector &dst, - const void *post_ops_data_) { +void AclEltwiseExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) { for (size_t i = 0; i < src.size(); i++) { srcTensors[i].allocator()->import_memory(src[i]->getData()); } @@ -448,5 +503,5 @@ void AclEltwiseExecutor::exec(const std::vector &src, const std::vec dstTensors[i].allocator()->free(); } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp index 6daf9e606c461b..1aae396f25a0fe 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp @@ -5,8 +5,8 @@ #pragma once #include "../eltwise.hpp" -#include "arm_compute/runtime/NEON/NEFunctions.h" #include "acl_utils.hpp" +#include "arm_compute/runtime/NEON/NEFunctions.h" namespace ov { namespace intel_cpu { @@ -23,11 +23,12 @@ class AclEltwiseExecutor : public EltwiseExecutor { void exec(const std::vector& src, const std::vector& dst, - const void *post_ops_data_) override; + const void* post_ops_data_) override; impl_desc_type getImplType() const override { return implType; } + private: EltwiseAttrs aclEltwiseAttrs{}; impl_desc_type implType = impl_desc_type::acl; @@ -46,5 +47,5 @@ class AclEltwiseExecutorBuilder : public EltwiseExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp index 9660178e1af4a4..74bdb97cdf2a8c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp @@ -2,23 +2,24 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include "acl_fullyconnected.hpp" + #include +#include -#include "ov_optional.hpp" -#include "acl_fullyconnected.hpp" #include "acl_utils.hpp" -#include "nodes/executors/executor.hpp" -#include "nodes/executors/memory_arguments.hpp" -#include "utils/debug_capabilities.h" -#include "utils/cpu_utils.hpp" -#include "nodes/executors/debug_messages.hpp" -#include "nodes/executors/implementation_utils.hpp" -#include "nodes/convert.h" +#include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/common/cpu_convert.h" #include "nodes/common/cpu_memcpy.h" #include "nodes/common/reorder_prim.h" -#include "memory_desc/cpu_memory_desc_utils.h" +#include "nodes/convert.h" +#include "nodes/executors/debug_messages.hpp" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/implementation_utils.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "ov_optional.hpp" +#include "utils/cpu_utils.hpp" +#include "utils/debug_capabilities.h" namespace ov { namespace intel_cpu { @@ -57,13 +58,16 @@ static VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDim static DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc, const DnnlMemoryDescPtr dstDesc) { const auto& weiDesc = srcDesc->getDnnlDesc(); - const auto reorderedWeiDesc = dnnl::memory::desc{weiDesc.get_dims(), weiDesc.get_data_type(), dnnl::memory::format_tag::ba}; + const auto reorderedWeiDesc = + dnnl::memory::desc{weiDesc.get_dims(), weiDesc.get_data_type(), dnnl::memory::format_tag::ba}; const auto transposedWeiDesc = reorderedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims()); return DnnlExtensionUtils::makeDescriptor(transposedWeiDesc); } -static ov::optional convertWeightPrecision(MemoryPtr input, MemoryPtr output, ov::element::Type weightPrecision) { +static ov::optional convertWeightPrecision(MemoryPtr input, + MemoryPtr output, + ov::element::Type weightPrecision) { MemoryArgs memoryArgs; memoryArgs[ARG_SRC] = input; memoryArgs[ARG_DST] = output; @@ -74,16 +78,18 @@ static ov::optional convertWeightPrecision(MemoryPtr input, MemoryPtr return ov::optional(memoryArgs.at(ARG_DST)); } - if (!node::Convert::isSupportedDesc(input->getDesc()) || - !node::Convert::isSupportedDesc(output->getDesc())) { + if (!node::Convert::isSupportedDesc(input->getDesc()) || !node::Convert::isSupportedDesc(output->getDesc())) { return {}; } - auto data = static_cast(input->getData()); + auto data = static_cast(input->getData()); std::vector tmpBuff; tmpBuff.resize(output->getSize()); - cpu_convert(data, tmpBuff.data(), DnnlExtensionUtils::DataTypeToElementType(input->getDataType()), - weightPrecision, input->getSize() / input->getDesc().getPrecision().size()); + cpu_convert(data, + tmpBuff.data(), + DnnlExtensionUtils::DataTypeToElementType(input->getDataType()), + weightPrecision, + input->getSize() / input->getDesc().getPrecision().size()); return ov::optional(std::make_shared(output->getPrimitive().get_engine(), output->getDesc().cloneWithNewPrecision(weightPrecision), @@ -96,12 +102,14 @@ static ov::optional reorderDataFallback(MemoryPtr input, MemoryPtr ou } const auto inPrc = DnnlExtensionUtils::DataTypeToElementType(input->getDataType()); auto convertedDstMemoryDesc = output->getDesc().cloneWithNewPrecision(inPrc); - dnnl::reorder reorderWithoutConvert = getReorderPrim(context->getRuntimeCache(), - output->getPrimitive().get_engine(), - input->getPrimitive().get_desc(), - MemoryDescUtils::convertToDnnlMemoryDesc(convertedDstMemoryDesc)->getDnnlDesc()); - - if (reorderWithoutConvert && parse_impl_name(reorderWithoutConvert.get_primitive_desc()->impl()->name()) != ref_any) { + dnnl::reorder reorderWithoutConvert = + getReorderPrim(context->getRuntimeCache(), + output->getPrimitive().get_engine(), + input->getPrimitive().get_desc(), + MemoryDescUtils::convertToDnnlMemoryDesc(convertedDstMemoryDesc)->getDnnlDesc()); + + if (reorderWithoutConvert && + parse_impl_name(reorderWithoutConvert.get_primitive_desc()->impl()->name()) != ref_any) { auto convertOutput = convertWeightPrecision(input, output, inPrc); if (!convertOutput) { return {}; @@ -110,7 +118,9 @@ static ov::optional reorderDataFallback(MemoryPtr input, MemoryPtr ou if (reorderWithoutConvert) { dnnl::stream loc_stream(output->getPrimitive().get_engine(), dnnl::stream::flags::in_order); - reorderWithoutConvert.execute(loc_stream, {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); + reorderWithoutConvert.execute( + loc_stream, + {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); return ov::optional(output); } } @@ -155,14 +165,15 @@ static MemoryPtr reorderData(DnnlMemoryDescPtr srcWeightDesc, // if precision conversion does not work then do direct reference reorder if (directReorder) { dnnl::stream loc_stream(engine, dnnl::stream::flags::in_order); - directReorder.execute(loc_stream, {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); + directReorder.execute(loc_stream, + {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); } else { OPENVINO_THROW("Could not make onednn reorder."); } return output; } -static MemoryPtr reorderWeights(const MemoryArgs &memory, +static MemoryPtr reorderWeights(const MemoryArgs& memory, const ExecutorContext::CPtr context, ACLFCAttrs& aclfcAttrs, DnnlMemoryDescPtr dnnlSrcDesc, @@ -192,16 +203,16 @@ static MemoryPtr reorderWeights(const MemoryArgs &memory, return create(); } -static MemoryPtr prepareWeightMemory(const MemoryArgs &memory, +static MemoryPtr prepareWeightMemory(const MemoryArgs& memory, const ExecutorContext::CPtr context, - const FCAttrs &attrs, + const FCAttrs& attrs, ACLFCAttrs& aclfcAttrs, - const PostOps &postOps, + const PostOps& postOps, arm_compute::WeightFormat& expectedWeightFormat, arm_compute::TensorInfo& weiTensorInfo) { MemoryArgs memoryArgs; - memoryArgs[ARG_BIAS] = memory.at(ARG_BIAS); - memoryArgs[ARG_WEI] = memory.at(ARG_WEI); + memoryArgs[ARG_BIAS] = memory.at(ARG_BIAS); + memoryArgs[ARG_WEI] = memory.at(ARG_WEI); auto originalWeightsDesc = memory.at(ARG_WEI)->getDescPtr(); @@ -219,21 +230,25 @@ static MemoryPtr prepareWeightMemory(const MemoryArgs &memory, const auto& inShape = memory.at(ARG_SRC_0)->getShape(); const auto& wShape = originalWeightsDesc->getShape(); const auto& inDymmyDims = makeDummyInputDims(inShape, wShape); - const auto& outDymmyDims = makeDummyOutputDims(inDymmyDims, wShape.getStaticDims(), memory.at(ARG_DST)->getShape().getRank()); - memoryArgs[ARG_SRC_0] = std::make_shared(context->getEngine(), - memory.at(ARG_SRC_0)->getDescPtr()->cloneWithNewDims(inDymmyDims)); - memoryArgs[ARG_DST] = std::make_shared(context->getEngine(), - memory.at(ARG_DST)->getDescPtr()->cloneWithNewDims(outDymmyDims)); + const auto& outDymmyDims = + makeDummyOutputDims(inDymmyDims, wShape.getStaticDims(), memory.at(ARG_DST)->getShape().getRank()); + memoryArgs[ARG_SRC_0] = + std::make_shared(context->getEngine(), + memory.at(ARG_SRC_0)->getDescPtr()->cloneWithNewDims(inDymmyDims)); + memoryArgs[ARG_DST] = + std::make_shared(context->getEngine(), + memory.at(ARG_DST)->getDescPtr()->cloneWithNewDims(outDymmyDims)); } else { memoryArgs[ARG_SRC_0] = memory.at(ARG_SRC_0); - memoryArgs[ARG_DST] = memory.at(ARG_DST); + memoryArgs[ARG_DST] = memory.at(ARG_DST); } // TODO: ACLWeightFormatGenerator should be replaced with Reorder executor // that calls ACL NEReorder + NETranspose or dnnl::reorder depending on backend availability auto aclWeightsRepack = std::make_shared(attrs, postOps, memoryArgs); bool isNeededReorder = aclWeightsRepack->update(memoryArgs); - expectedWeightFormat = isNeededReorder ? aclWeightsRepack->getOptImplWeightFormat() : arm_compute::WeightFormat::UNSPECIFIED; + expectedWeightFormat = + isNeededReorder ? aclWeightsRepack->getOptImplWeightFormat() : arm_compute::WeightFormat::UNSPECIFIED; weiTensorInfo = aclWeightsRepack->getTensorInfo(ACLArgs::ACL_WEI); if (isNeededReorder) { @@ -241,8 +256,13 @@ static MemoryPtr prepareWeightMemory(const MemoryArgs &memory, dnnl::impl::dim_t inner_dim = 1; std::vector remaining_dims = {}; auto weights_md_ = dnnlDstDesc->getDnnlDesc().get(); - dnnl::impl::cpu::acl::acl_utils::reorder_to_weight_format(weiTensorInfo, *weights_md_, expectedWeightFormat, - inner_dim, o_dim, remaining_dims, {}); + dnnl::impl::cpu::acl::acl_utils::reorder_to_weight_format(weiTensorInfo, + *weights_md_, + expectedWeightFormat, + inner_dim, + o_dim, + remaining_dims, + {}); if (aclfcAttrs.weightsNonTransposed) { dnnlSrcDesc = makeTransposedWeightDescriptor(dnnlSrcDesc, dnnlDstDesc); } @@ -256,7 +276,7 @@ static MemoryPtr prepareWeightMemory(const MemoryArgs &memory, return reorderWeights(memory, context, aclfcAttrs, dnnlSrcDesc, dnnlDstDesc); } -static bool checkPostOps(const PostOps &postOps) { +static bool checkPostOps(const PostOps& postOps) { if (postOps.empty()) { return true; } @@ -271,12 +291,12 @@ static bool checkPostOps(const PostOps &postOps) { return false; } -static void initFCAttrs(const FCAttrs &attrs, +static void initFCAttrs(const FCAttrs& attrs, ACLTensorAttrs& aclTensorAttrs, ACLFCAttrs& aclfcAttrs, - const MemoryArgs &memory, + const MemoryArgs& memory, arm_compute::FullyConnectedLayerInfo& fullyConnectedLayerInfo, - const PostOps &postOps) { + const PostOps& postOps) { aclTensorAttrs.hasLayoutTypeNHWC = memory.at(ARG_SRC)->getDescPtr()->hasLayoutType(LayoutType::nspc); fullyConnectedLayerInfo.weights_trained_layout = getAclDataLayoutByMemoryDesc(memory.at(ARG_WEI)->getDescPtr()); aclfcAttrs.inputPrecision = memory.at(ARG_SRC)->getDescPtr()->getPrecision(); @@ -285,9 +305,10 @@ static void initFCAttrs(const FCAttrs &attrs, if (!postOps.empty() && checkPostOps(postOps)) { auto activation = std::dynamic_pointer_cast(postOps[0]); - fullyConnectedLayerInfo.activation_info = getActivationLayerInfo( - convertToEltwiseAlgorithm(activation->type()), - activation->alpha(), activation->beta(), activation->gamma()); + fullyConnectedLayerInfo.activation_info = getActivationLayerInfo(convertToEltwiseAlgorithm(activation->type()), + activation->alpha(), + activation->beta(), + activation->gamma()); } if (memory.at(ARG_SRC)->getPrecision() != memory.at(ARG_WEI)->getPrecision()) { @@ -295,21 +316,22 @@ static void initFCAttrs(const FCAttrs &attrs, } } -ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs &attrs, - const PostOps &postOps, - const MemoryArgs &memory, +ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory, const ExecutorContext::CPtr context) { initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps); - packedWeights = prepareWeightMemory(memory, context, attrs, aclfcAttrs, postOps, expectedWeightFormat, weiTensorInfo); + packedWeights = + prepareWeightMemory(memory, context, attrs, aclfcAttrs, postOps, expectedWeightFormat, weiTensorInfo); } -bool ACLFullyConnectedExecutor::supports(const FCConfig &config) { +bool ACLFullyConnectedExecutor::supports(const FCConfig& config) { VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_SRC_PRECISIONS); VERIFY(one_of(weiType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_WEI_PRECISIONS); - VERIFY(postOpsNumbers(config) < 2, UNSUPPORTED_NUMBER_OF_POSTOPS); - VERIFY(checkPostOps(config.postOps), UNSUPPORTED_TYPE_OF_POSTOPS); - VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK); - VERIFY(one_of(weiRank(config), 2U, 3U), UNSUPPORTED_WEI_RANK); + VERIFY(postOpsNumbers(config) < 2, UNSUPPORTED_NUMBER_OF_POSTOPS); + VERIFY(checkPostOps(config.postOps), UNSUPPORTED_TYPE_OF_POSTOPS); + VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK); + VERIFY(one_of(weiRank(config), 2U, 3U), UNSUPPORTED_WEI_RANK); return true; } @@ -329,35 +351,34 @@ void ACLFullyConnectedExecutor::updateTensorsShapes(ACLShapes& aclMemoryShapes) updateFCTensorsShapes(aclMemoryShapes); } -arm_compute::Status ACLFullyConnectedExecutor::validateTensorsInfo(const ACLInfos & aclMemoryInfos) { +arm_compute::Status ACLFullyConnectedExecutor::validateTensorsInfo(const ACLInfos& aclMemoryInfos) { if (aclfcAttrs.isConvertedWeights) { aclMemoryInfos[ACLArgs::ACL_WEI]->set_data_type(aclMemoryInfos[ACLArgs::ACL_SRC_0]->data_type()); } int ic_total = aclMemoryInfos[ACLArgs::ACL_SRC_0]->dimension(0); return arm_compute::NEFullyConnectedLayer::validate( - aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), - &weiTensorInfo, - aclMemoryInfos[ACLArgs::ACL_BIAS].get(), - aclMemoryInfos[ACLArgs::ACL_DST].get(), - fullyConnectedLayerInfo, - expectedWeightFormat == arm_compute::WeightFormat::UNSPECIFIED ? - arm_compute::WeightsInfo() : - arm_compute::WeightsInfo(false, 1, 1, ic_total, false, expectedWeightFormat)); + aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + &weiTensorInfo, + aclMemoryInfos[ACLArgs::ACL_BIAS].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get(), + fullyConnectedLayerInfo, + expectedWeightFormat == arm_compute::WeightFormat::UNSPECIFIED + ? arm_compute::WeightsInfo() + : arm_compute::WeightsInfo(false, 1, 1, ic_total, false, expectedWeightFormat)); } -ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLTensors & aclMemoryTensors) { +ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLTensors& aclMemoryTensors) { auto neFC = std::make_unique(); aclMemoryTensors[ACLArgs::ACL_WEI]->allocator()->init(weiTensorInfo); int icTotal = aclMemoryTensors[ACLArgs::ACL_WEI]->info()->dimension(0); - neFC->configure( - aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), - aclMemoryTensors[ACLArgs::ACL_WEI].get(), - aclMemoryTensors[ACLArgs::ACL_BIAS].get(), - aclMemoryTensors[ACLArgs::ACL_DST].get(), - fullyConnectedLayerInfo, - expectedWeightFormat == arm_compute::WeightFormat::UNSPECIFIED ? - arm_compute::WeightsInfo() : - arm_compute::WeightsInfo(false, 1, 1, icTotal, false, expectedWeightFormat)); + neFC->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), + aclMemoryTensors[ACLArgs::ACL_WEI].get(), + aclMemoryTensors[ACLArgs::ACL_BIAS].get(), + aclMemoryTensors[ACLArgs::ACL_DST].get(), + fullyConnectedLayerInfo, + expectedWeightFormat == arm_compute::WeightFormat::UNSPECIFIED + ? arm_compute::WeightsInfo() + : arm_compute::WeightsInfo(false, 1, 1, icTotal, false, expectedWeightFormat)); // TODO: get rid of those flags and decide whether to import memory or not just based on input type if (aclfcAttrs.isWeightsRepacked || aclfcAttrs.isConvertedWeights) { aclTensorAttrs.memoryUsageIndicator[ACLArgs::ACL_WEI] = false; @@ -366,13 +387,13 @@ ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLTensors & aclM return neFC; } -arm_compute::Status acl_fc_executor::ACLWeightsConverter::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { +arm_compute::Status acl_fc_executor::ACLWeightsConverter::validateTensorsInfo(const ACLInfos& aclMemoryInfos) { return arm_compute::NECast::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), aclMemoryInfos[ACLArgs::ACL_DST].get(), arm_compute::ConvertPolicy::SATURATE); } -ACLFunction acl_fc_executor::ACLWeightsConverter::configureFunction(const ACLTensors &aclMemoryTensors) { +ACLFunction acl_fc_executor::ACLWeightsConverter::configureFunction(const ACLTensors& aclMemoryTensors) { auto neCast = std::make_unique(); neCast->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), aclMemoryTensors[ACLArgs::ACL_DST].get(), @@ -380,34 +401,34 @@ ACLFunction acl_fc_executor::ACLWeightsConverter::configureFunction(const ACLTen return neCast; } -acl_fc_executor::ACLWeightFormatGenerator::ACLWeightFormatGenerator(const FCAttrs &attrs, - const PostOps &postOps, - const MemoryArgs &memory) { +acl_fc_executor::ACLWeightFormatGenerator::ACLWeightFormatGenerator(const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory) { initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps); } -void acl_fc_executor::ACLWeightFormatGenerator::updateTensorsShapes(ACLShapes &aclMemoryShapes) { +void acl_fc_executor::ACLWeightFormatGenerator::updateTensorsShapes(ACLShapes& aclMemoryShapes) { updateFCTensorsShapes(aclMemoryShapes); } -arm_compute::Status acl_fc_executor::ACLWeightFormatGenerator::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { +arm_compute::Status acl_fc_executor::ACLWeightFormatGenerator::validateTensorsInfo(const ACLInfos& aclMemoryInfos) { if (aclfcAttrs.isConvertedWeights) { aclMemoryInfos[ACLArgs::ACL_WEI]->set_data_type(aclMemoryInfos[ACLArgs::ACL_SRC_0]->data_type()); } int icTotal = aclMemoryInfos[ACLArgs::ACL_SRC_0]->dimension(0); return arm_compute::NEFullyConnectedLayer::has_opt_impl( - expectedWeightFormat, - aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), - aclMemoryInfos[ACLArgs::ACL_WEI].get(), - aclMemoryInfos[ACLArgs::ACL_BIAS].get(), - aclMemoryInfos[ACLArgs::ACL_DST].get(), - fullyConnectedLayerInfo, - arm_compute::WeightsInfo(false, 1, 1, icTotal, false, arm_compute::WeightFormat::ANY)); + expectedWeightFormat, + aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_WEI].get(), + aclMemoryInfos[ACLArgs::ACL_BIAS].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get(), + fullyConnectedLayerInfo, + arm_compute::WeightsInfo(false, 1, 1, icTotal, false, arm_compute::WeightFormat::ANY)); } -ACLFunction acl_fc_executor::ACLWeightFormatGenerator::configureFunction(const ACLTensors &aclMemoryTensors) { +ACLFunction acl_fc_executor::ACLWeightFormatGenerator::configureFunction(const ACLTensors& aclMemoryTensors) { return std::make_unique(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp index fcbcb1475efa15..afeb4a5ce45c95 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp @@ -23,21 +23,20 @@ class ACLWeightsConverter : public ACLCommonExecutor { public: ACLWeightsConverter() = default; void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {} - arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; - ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; + arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override; + ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override; }; class ACLWeightFormatGenerator : public ACLCommonExecutor { public: - ACLWeightFormatGenerator(const FCAttrs& attrs, - const PostOps& postOps, - const MemoryArgs& memory); + ACLWeightFormatGenerator(const FCAttrs& attrs, const PostOps& postOps, const MemoryArgs& memory); void updateTensorsShapes(ACLShapes& aclMemoryShapes) override; - arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; - ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; + arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override; + ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override; arm_compute::WeightFormat getOptImplWeightFormat() { return expectedWeightFormat; } + private: arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo; ACLFCAttrs aclfcAttrs; @@ -49,17 +48,17 @@ class ACLWeightFormatGenerator : public ACLCommonExecutor { class ACLFullyConnectedExecutor : public ACLCommonExecutor { public: ACLFullyConnectedExecutor(const FCAttrs& attrs, - const PostOps& postOps, - const MemoryArgs& memory, - const ExecutorContext::CPtr context); + const PostOps& postOps, + const MemoryArgs& memory, + const ExecutorContext::CPtr context); static bool supports(const FCConfig& config); void updateTensorsShapes(ACLShapes& aclMemoryShapes) override; - arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; + arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override; - ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; + ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override; private: arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo; diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.cpp index b6fa129974107f..2af982024b5637 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.cpp @@ -22,23 +22,23 @@ unsigned int ACLScheduler::num_threads() const { void ACLScheduler::set_num_threads(unsigned int num_threads) {} -void ACLScheduler::schedule_custom(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors) { - const Window & max_window = window; +void ACLScheduler::schedule_custom(ICPPKernel* kernel, const Hints& hints, const Window& window, ITensorPack& tensors) { + const Window& max_window = window; const unsigned int num_iterations = max_window.num_iterations(hints.split_dimension()); #if OV_THREAD == OV_THREAD_OMP - //In OpenMP case parallel_get_num_threads() method returns 1 here because it's called outside parallel section - //This is the reason why this method isn't used to initialize _num_threads + // In OpenMP case parallel_get_num_threads() method returns 1 here because it's called outside parallel section + // This is the reason why this method isn't used to initialize _num_threads const auto _num_threads = num_iterations; #else const auto _num_threads = std::min(num_iterations, static_cast(parallel_get_num_threads())); #endif - std::function main_run; + std::function main_run; if (tensors.empty()) { - main_run = [&](const Window &window, const ThreadInfo &info) { + main_run = [&](const Window& window, const ThreadInfo& info) { kernel->run(window, info); }; } else { - main_run = [&](const Window &window, const ThreadInfo &info) { + main_run = [&](const Window& window, const ThreadInfo& info) { kernel->run_op(tensors, window, info); }; } @@ -59,20 +59,20 @@ void ACLScheduler::schedule_custom(ICPPKernel *kernel, const Hints &hints, const } } -void ACLScheduler::schedule(ICPPKernel *kernel, const Hints &hints) { +void ACLScheduler::schedule(ICPPKernel* kernel, const Hints& hints) { ITensorPack tensors; schedule_custom(kernel, hints, kernel->window(), tensors); } -void ACLScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors) { +void ACLScheduler::schedule_op(ICPPKernel* kernel, const Hints& hints, const Window& window, ITensorPack& tensors) { schedule_custom(kernel, hints, window, tensors); } -void ACLScheduler::run_workloads(std::vector &workloads) { +void ACLScheduler::run_workloads(std::vector& workloads) { ov::parallel_for(workloads.size(), [&](int wid) { workloads[wid]({wid, static_cast(parallel_get_num_threads()), &cpu_info()}); }); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.hpp index 1148f4ad5edd69..c94f0aa3abce3a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_ie_scheduler.hpp @@ -4,9 +4,10 @@ #pragma once -#include #include #include +#include + #include "support/Mutex.h" namespace ov { @@ -20,12 +21,14 @@ class ACLScheduler final : public IScheduler { ~ACLScheduler() override = default; std::uint32_t num_threads() const override; void set_num_threads(unsigned int num_threads) override; - void schedule(ICPPKernel *kernel, const Hints &hints) override; - void schedule_op(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors) override; + void schedule(ICPPKernel* kernel, const Hints& hints) override; + void schedule_op(ICPPKernel* kernel, const Hints& hints, const Window& window, ITensorPack& tensors) override; + protected: - void run_workloads(std::vector &workloads) override; + void run_workloads(std::vector& workloads) override; + private: - void schedule_custom(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors); + void schedule_custom(ICPPKernel* kernel, const Hints& hints, const Window& window, ITensorPack& tensors); }; } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.cpp index 33bd49e2f04d9b..077759193d1c30 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.cpp @@ -3,13 +3,14 @@ // #include "acl_interpolate.hpp" + #include "acl_utils.hpp" #include "utils/debug_capabilities.h" -bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpolateAttrs, - const std::vector &srcDescs, - const std::vector &dstDescs, - const dnnl::primitive_attr &attr) { +bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs& interpolateAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { aclInterpolateAttrs = interpolateAttrs; InterpolateExecutor::init(aclInterpolateAttrs, srcDescs, dstDescs, attr); acl_coord = arm_compute::SamplingPolicy::TOP_LEFT; @@ -17,22 +18,23 @@ bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpo static const size_t index_h = 2; static const size_t index_w = 3; - if ((aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::pytorch_half_pixel && out_shape[index_h] > 1 && out_shape[index_w] > 1) || + if ((aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::pytorch_half_pixel && + out_shape[index_h] > 1 && out_shape[index_w] > 1) || aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::half_pixel) { acl_coord = arm_compute::SamplingPolicy::CENTER; } switch (aclInterpolateAttrs.mode) { - case InterpolateMode::linear: - case InterpolateMode::linear_onnx: - acl_policy = arm_compute::InterpolationPolicy::BILINEAR; - break; - case InterpolateMode::nearest: - acl_policy = arm_compute::InterpolationPolicy::NEAREST_NEIGHBOR; - break; - default: - DEBUG_LOG("Unsupported interpolate mode: ", static_cast(aclInterpolateAttrs.mode)); - return false; + case InterpolateMode::linear: + case InterpolateMode::linear_onnx: + acl_policy = arm_compute::InterpolationPolicy::BILINEAR; + break; + case InterpolateMode::nearest: + acl_policy = arm_compute::InterpolationPolicy::NEAREST_NEIGHBOR; + break; + default: + DEBUG_LOG("Unsupported interpolate mode: ", static_cast(aclInterpolateAttrs.mode)); + return false; } auto srcDims = shapeCast(srcDescs[0]->getShape().getDims()); @@ -42,22 +44,25 @@ bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpo changeLayoutToNH_C({&srcDims, &dstDims}); } - auto srcTensorInfo = arm_compute::TensorInfo(srcDims, 1, + auto srcTensorInfo = arm_compute::TensorInfo(srcDims, + 1, precisionToAclDataType(srcDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[0])); - auto dstTensorInfo = arm_compute::TensorInfo(dstDims, 1, + auto dstTensorInfo = arm_compute::TensorInfo(dstDims, + 1, precisionToAclDataType(dstDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[0])); - arm_compute::Status status = arm_compute::NEScale::validate(&srcTensorInfo, - &dstTensorInfo, - arm_compute::ScaleKernelInfo(acl_policy, - arm_compute::BorderMode::REPLICATE, - arm_compute::PixelValue(), - acl_coord, - false, - aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::align_corners, - getAclDataLayoutByMemoryDesc(srcDescs[0]))); + arm_compute::Status status = arm_compute::NEScale::validate( + &srcTensorInfo, + &dstTensorInfo, + arm_compute::ScaleKernelInfo(acl_policy, + arm_compute::BorderMode::REPLICATE, + arm_compute::PixelValue(), + acl_coord, + false, + aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::align_corners, + getAclDataLayoutByMemoryDesc(srcDescs[0]))); if (!status) { DEBUG_LOG("NEScale validation failed: ", status.error_description()); return false; @@ -68,21 +73,25 @@ bool ov::intel_cpu::ACLInterpolateExecutor::init(const InterpolateAttrs &interpo acl_scale = std::make_unique(); configureThreadSafe([&] { - acl_scale->configure(&srcTensor, &dstTensor, arm_compute::ScaleKernelInfo(acl_policy, - arm_compute::BorderMode::REPLICATE, - arm_compute::PixelValue(), - acl_coord, - false, - aclInterpolateAttrs.coordTransMode == - InterpolateCoordTransMode::align_corners, - getAclDataLayoutByMemoryDesc(srcDescs[0]))); + acl_scale->configure( + &srcTensor, + &dstTensor, + arm_compute::ScaleKernelInfo(acl_policy, + arm_compute::BorderMode::REPLICATE, + arm_compute::PixelValue(), + acl_coord, + false, + aclInterpolateAttrs.coordTransMode == InterpolateCoordTransMode::align_corners, + getAclDataLayoutByMemoryDesc(srcDescs[0]))); }); return true; } -void ov::intel_cpu::ACLInterpolateExecutor::exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) { +void ov::intel_cpu::ACLInterpolateExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) { auto in_ptr_ = padPreprocess(src, dst); - srcTensor.allocator()->import_memory(const_cast(reinterpret_cast(in_ptr_))); + srcTensor.allocator()->import_memory(const_cast(reinterpret_cast(in_ptr_))); dstTensor.allocator()->import_memory(dst[0]->getData()); acl_scale->run(); @@ -92,8 +101,9 @@ void ov::intel_cpu::ACLInterpolateExecutor::exec(const std::vector& } bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration( - const ov::intel_cpu::InterpolateAttrs &interpolateAttrs, const std::vector &srcDescs, - const std::vector &dstDescs) { + const ov::intel_cpu::InterpolateAttrs& interpolateAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) { OPENVINO_ASSERT(srcDescs[0]->getShape().getDims().size() == 4); auto& inp_shape = srcDescs[0]->getShape().getDims(); @@ -116,7 +126,8 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration( if (coord_mode == InterpolateCoordTransMode::half_pixel && (nearest_mode == InterpolateNearestMode::simple || nearest_mode == InterpolateNearestMode::round_prefer_ceil)) { - DEBUG_LOG("InterpolateCoordTransMode half_pixel is not supported for InterpolateNearestMode simple and round_prefer_ceil"); + DEBUG_LOG("InterpolateCoordTransMode half_pixel is not supported for InterpolateNearestMode simple and " + "round_prefer_ceil"); return false; } @@ -129,15 +140,17 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration( if (is_upsample) { bool int_factor = scale_h == static_cast(scale_h) && scale_w == static_cast(scale_w); if (int_factor && coord_mode != InterpolateCoordTransMode::asymmetric && - (nearest_mode == InterpolateNearestMode::round_prefer_ceil - || nearest_mode == InterpolateNearestMode::round_prefer_floor)) { - DEBUG_LOG("upsample && int_factor && !asymmetric && (round_prefer_ceil || round_prefer_floor) case is supported"); + (nearest_mode == InterpolateNearestMode::round_prefer_ceil || + nearest_mode == InterpolateNearestMode::round_prefer_floor)) { + DEBUG_LOG( + "upsample && int_factor && !asymmetric && (round_prefer_ceil || round_prefer_floor) case is supported"); return true; } } else if (scale_h < 1 && scale_w < 1) { float down_scale_h = static_cast(inp_shape[index_h]) / out_shape[index_h]; float down_scale_w = static_cast(inp_shape[index_w]) / out_shape[index_w]; - bool int_factor = down_scale_h == static_cast(down_scale_h) && down_scale_w == static_cast(down_scale_w); + bool int_factor = + down_scale_h == static_cast(down_scale_h) && down_scale_w == static_cast(down_scale_w); if (int_factor && coord_mode != InterpolateCoordTransMode::align_corners && nearest_mode == InterpolateNearestMode::simple) { @@ -146,29 +159,45 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupportedConfiguration( } if (int_factor && nearest_mode == InterpolateNearestMode::round_prefer_ceil && - ((out_shape[index_h] > 1 && out_shape[index_w] > 1) || coord_mode != InterpolateCoordTransMode::half_pixel)) { - DEBUG_LOG("!upsample && int_factor && round_prefer_ceil && (out_shape > 1 || half_pixel) case is supported"); + ((out_shape[index_h] > 1 && out_shape[index_w] > 1) || + coord_mode != InterpolateCoordTransMode::half_pixel)) { + DEBUG_LOG( + "!upsample && int_factor && round_prefer_ceil && (out_shape > 1 || half_pixel) case is supported"); return true; } } - DEBUG_LOG("ACL Interpolate executor does not support such configuration: coord_mode=", static_cast(coord_mode), - " nearest_mode=", static_cast(nearest_mode), " upsample=", is_upsample, " scale_h=", scale_h, " scale_w=", scale_w); + DEBUG_LOG("ACL Interpolate executor does not support such configuration: coord_mode=", + static_cast(coord_mode), + " nearest_mode=", + static_cast(nearest_mode), + " upsample=", + is_upsample, + " scale_h=", + scale_h, + " scale_w=", + scale_w); return false; } -bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupported(const ov::intel_cpu::InterpolateAttrs &interpolateAttrs, - const std::vector &srcDescs, - const std::vector &dstDescs) const { +bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupported(const ov::intel_cpu::InterpolateAttrs& interpolateAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const { if (srcDescs[0]->getShape().getDims().size() != 4u) { DEBUG_LOG("ACL Interpolate does not support src shape rank: ", srcDescs[0]->getShape().getDims().size()); return false; } auto& pads_begin = interpolateAttrs.padBegin; - auto& pads_end = interpolateAttrs.padEnd; - - if (!std::all_of(pads_begin.begin(), pads_begin.end(), [](int i){return i == 0;}) || - !std::all_of(pads_end.begin(), pads_end.end(), [](int i){return i == 0;})) { + auto& pads_end = interpolateAttrs.padEnd; + + if (!std::all_of(pads_begin.begin(), + pads_begin.end(), + [](int i) { + return i == 0; + }) || + !std::all_of(pads_end.begin(), pads_end.end(), [](int i) { + return i == 0; + })) { DEBUG_LOG("ACL Interpolate does not support padding"); return false; } @@ -180,15 +209,16 @@ bool ov::intel_cpu::ACLInterpolateExecutorBuilder::isSupported(const ov::intel_c return false; } - if (interpolateAttrs.mode == InterpolateMode::cubic || - interpolateAttrs.mode == InterpolateMode::bilinear_pillow || + if (interpolateAttrs.mode == InterpolateMode::cubic || interpolateAttrs.mode == InterpolateMode::bilinear_pillow || interpolateAttrs.mode == InterpolateMode::bicubic_pillow) { DEBUG_LOG("ACL Interpolate does not support cubic, bilinear_pillow, bicubic_pillow modes"); return false; } if (interpolateAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales && - one_of(interpolateAttrs.coordTransMode, InterpolateCoordTransMode::half_pixel, InterpolateCoordTransMode::asymmetric) && + one_of(interpolateAttrs.coordTransMode, + InterpolateCoordTransMode::half_pixel, + InterpolateCoordTransMode::asymmetric) && one_of(interpolateAttrs.mode, InterpolateMode::linear, InterpolateMode::linear_onnx)) { DEBUG_LOG("ACL Interpolate does not support scales mode with linear/linear_onnx and half_pixel/asymmetric"); return false; diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.hpp index 17cdfec5928544..c141fa132a31ff 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_interpolate.hpp @@ -4,9 +4,9 @@ #pragma once -#include "nodes/executors/interpolate.hpp" #include "arm_compute/runtime/NEON/functions/NEScale.h" #include "arm_compute/runtime/Tensor.h" +#include "nodes/executors/interpolate.hpp" namespace ov { namespace intel_cpu { @@ -18,9 +18,11 @@ class ACLInterpolateExecutor : public InterpolateExecutor { bool init(const InterpolateAttrs& interpolateAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; - void exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) override; + void exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) override; impl_desc_type getImplType() const override { return implType; @@ -44,10 +46,11 @@ class ACLInterpolateExecutorBuilder : public InterpolateExecutorBuilder { InterpolateExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const override { return std::make_shared(context); } + private: static bool isSupportedConfiguration(const InterpolateAttrs& interpolateAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs); + const std::vector& srcDescs, + const std::vector& dstDescs); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.cpp index 6fde4bb0db5604..290cd3c9dbcce9 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.cpp @@ -14,7 +14,7 @@ AclMVNExecutor::AclMVNExecutor(const ExecutorContext::CPtr context) : MVNExecuto bool AclMVNExecutor::init(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const dnnl::primitive_attr& attr) { auto srcDims = srcDescs[0]->getShape().getStaticDims(); auto dstDims = dstDescs[0]->getShape().getStaticDims(); @@ -46,9 +46,14 @@ bool AclMVNExecutor::init(const MVNAttrs& mvnAttrs, } } - TensorInfo srcTensorInfo = TensorInfo(TensorShape(X, Y), 1, precisionToAclDataType(srcDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[0])); - TensorInfo dstTensorInfo = TensorInfo(TensorShape(X, Y), 1, precisionToAclDataType(dstDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[0])); - + TensorInfo srcTensorInfo = TensorInfo(TensorShape(X, Y), + 1, + precisionToAclDataType(srcDescs[0]->getPrecision()), + getAclDataLayoutByMemoryDesc(srcDescs[0])); + TensorInfo dstTensorInfo = TensorInfo(TensorShape(X, Y), + 1, + precisionToAclDataType(dstDescs[0]->getPrecision()), + getAclDataLayoutByMemoryDesc(dstDescs[0])); if (!arm_compute::NEMeanStdDevNormalizationLayer::validate(&srcTensorInfo, &dstTensorInfo, mvnAttrs.epsValue_)) return false; @@ -57,12 +62,16 @@ bool AclMVNExecutor::init(const MVNAttrs& mvnAttrs, dstTensor.allocator()->init(dstTensorInfo); mvn = std::make_unique(); - configureThreadSafe([&] { mvn->configure(&srcTensor, &dstTensor, mvnAttrs.epsValue_); }); + configureThreadSafe([&] { + mvn->configure(&srcTensor, &dstTensor, mvnAttrs.epsValue_); + }); return true; } -void AclMVNExecutor::exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) { +void AclMVNExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) { srcTensor.allocator()->import_memory(src[0]->getData()); dstTensor.allocator()->import_memory(dst[0]->getData()); @@ -75,41 +84,41 @@ void AclMVNExecutor::exec(const std::vector& src, const std::vector< bool AclMVNExecutorBuilder::isSupported(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs) const { - if ((srcDescs[0]->getPrecision() != ov::element::f32 && - srcDescs[0]->getPrecision() != ov::element::f16) || - srcDescs[0]->getPrecision() != dstDescs[0]->getPrecision()) { - DEBUG_LOG("NEMeanStdDevNormalizationLayer does not support precisions:", - " src[0]=", srcDescs[0]->getPrecision(), - " dst[0]=", dstDescs[0]->getPrecision()); - return false; - } - - if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && - dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && - !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && - dstDescs[0]->hasLayoutType(LayoutType::nspc))) { - DEBUG_LOG("NEMeanStdDevNormalizationLayer does not support layout:", - " src: ", srcDescs[0]->serializeFormat(), - " dst: ", dstDescs[0]->serializeFormat()); - return false; - } + if ((srcDescs[0]->getPrecision() != ov::element::f32 && srcDescs[0]->getPrecision() != ov::element::f16) || + srcDescs[0]->getPrecision() != dstDescs[0]->getPrecision()) { + DEBUG_LOG("NEMeanStdDevNormalizationLayer does not support precisions:", + " src[0]=", + srcDescs[0]->getPrecision(), + " dst[0]=", + dstDescs[0]->getPrecision()); + return false; + } - if (mvnAttrs.epsMode_ == MVNEpsMode::OUTSIDE_SQRT) { - DEBUG_LOG("NEMeanStdDevNormalizationLayer does not support OUTSIDE_SQRT mode"); - return false; - } - if (!mvnAttrs.normalizeVariance_) { - DEBUG_LOG("NEMeanStdDevNormalizationLayer supports normalize_variance=true only"); - return false; - } - if (!mvnAttrs.initAcrossChannels_ && - srcDescs[0]->hasLayoutType(LayoutType::nspc)) { - DEBUG_LOG("initAcrossChannels = false is not supported by ACL for NHWC layout"); - return false; - } + if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && + !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && dstDescs[0]->hasLayoutType(LayoutType::nspc))) { + DEBUG_LOG("NEMeanStdDevNormalizationLayer does not support layout:", + " src: ", + srcDescs[0]->serializeFormat(), + " dst: ", + dstDescs[0]->serializeFormat()); + return false; + } - return true; + if (mvnAttrs.epsMode_ == MVNEpsMode::OUTSIDE_SQRT) { + DEBUG_LOG("NEMeanStdDevNormalizationLayer does not support OUTSIDE_SQRT mode"); + return false; } + if (!mvnAttrs.normalizeVariance_) { + DEBUG_LOG("NEMeanStdDevNormalizationLayer supports normalize_variance=true only"); + return false; + } + if (!mvnAttrs.initAcrossChannels_ && srcDescs[0]->hasLayoutType(LayoutType::nspc)) { + DEBUG_LOG("initAcrossChannels = false is not supported by ACL for NHWC layout"); + return false; + } + + return true; +} -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.hpp index 7ba445253f8d02..02521551509366 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_mvn.hpp @@ -5,8 +5,8 @@ #pragma once #include "acl_utils.hpp" -#include "nodes/executors/mvn.hpp" #include "arm_compute/runtime/NEON/NEFunctions.h" +#include "nodes/executors/mvn.hpp" #include "utils/debug_capabilities.h" namespace ov { @@ -19,10 +19,10 @@ class AclMVNExecutor : public MVNExecutor { bool init(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst, - const void *post_ops_data_) override; + const void* post_ops_data_) override; impl_desc_type getImplType() const override { return implType; @@ -47,5 +47,5 @@ class AclMVNExecutorBuilder : public MVNExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.cpp index a49a5cea4ef26e..2e4aed30d7b33e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.cpp @@ -3,6 +3,7 @@ // #include "acl_pooling.hpp" + #include "acl_utils.hpp" namespace ov { @@ -22,10 +23,12 @@ bool AclPoolingExecutor::isSupported(const TensorInfo& srcTensorInfo, PoolingLayerInfo* pool_info, Pooling3dLayerInfo* pool3d_info, bool ignoreOutShapeErrors) { - unsigned int pad_left = (poolingAttrs.data_pad_begin.size() >= 2u) ? poolingAttrs.data_pad_begin[1] : poolingAttrs.data_pad_begin[0]; - unsigned int pad_right = (poolingAttrs.data_pad_end.size() >= 2u) ? poolingAttrs.data_pad_end[1] : poolingAttrs.data_pad_end[0]; - unsigned int pad_top = (poolingAttrs.data_pad_begin.size() >= 2u) ? poolingAttrs.data_pad_begin[0] : 0; - unsigned int pad_bottom = (poolingAttrs.data_pad_end.size() >= 2u) ? poolingAttrs.data_pad_end[0] : 0; + unsigned int pad_left = + (poolingAttrs.data_pad_begin.size() >= 2u) ? poolingAttrs.data_pad_begin[1] : poolingAttrs.data_pad_begin[0]; + unsigned int pad_right = + (poolingAttrs.data_pad_end.size() >= 2u) ? poolingAttrs.data_pad_end[1] : poolingAttrs.data_pad_end[0]; + unsigned int pad_top = (poolingAttrs.data_pad_begin.size() >= 2u) ? poolingAttrs.data_pad_begin[0] : 0; + unsigned int pad_bottom = (poolingAttrs.data_pad_end.size() >= 2u) ? poolingAttrs.data_pad_end[0] : 0; unsigned int kernel_w = (poolingAttrs.kernel.size() >= 2u) ? poolingAttrs.kernel[1] : poolingAttrs.kernel[0]; unsigned int kernel_h = (poolingAttrs.kernel.size() >= 2u) ? poolingAttrs.kernel[0] : 1; unsigned int stride_x = (poolingAttrs.stride.size() >= 2u) ? poolingAttrs.stride[1] : poolingAttrs.stride[0]; @@ -47,45 +50,48 @@ bool AclPoolingExecutor::isSupported(const TensorInfo& srcTensorInfo, // The combination of parameters: NCHW + CEIL gives an accuracy problem in AvgPool. // One workaround is to disable the ACL executor for these parameters. // Then OneDNN will run this case in ACL backend as reorder -> NHWC -> reorder - if (pool_type == PoolingType::AVG && - dataLayout == arm_compute::DataLayout::NCHW && + if (pool_type == PoolingType::AVG && dataLayout == arm_compute::DataLayout::NCHW && poolingAttrs.rounding == op::RoundingType::CEIL) { DEBUG_LOG("NCHW + CEIL gives an accuracy problem in ACL AvgPool. ACL executor will not be created."); return false; } - DimensionRoundingType round = (poolingAttrs.rounding == op::RoundingType::CEIL) ? - DimensionRoundingType::CEIL : DimensionRoundingType::FLOOR; + DimensionRoundingType round = + (poolingAttrs.rounding == op::RoundingType::CEIL) ? DimensionRoundingType::CEIL : DimensionRoundingType::FLOOR; if (srcDimsSize == 5) { if (dstDescsSize > 1) { DEBUG_LOG("NEPooling3dLayer does not support indices"); return false; } else { - unsigned int kernel_d = poolingAttrs.kernel[2]; - unsigned int stride_z = poolingAttrs.stride[2]; + unsigned int kernel_d = poolingAttrs.kernel[2]; + unsigned int stride_z = poolingAttrs.stride[2]; unsigned int pad_front = poolingAttrs.data_pad_begin[2]; - unsigned int pad_back = poolingAttrs.data_pad_end[2]; - pool3d_info->pool_type = pool_type; + unsigned int pad_back = poolingAttrs.data_pad_end[2]; + pool3d_info->pool_type = pool_type; pool3d_info->exclude_padding = exclude_padding; - pool3d_info->pool_size = arm_compute::Size3D(kernel_w, kernel_h, kernel_d); - pool3d_info->stride = arm_compute::Size3D(stride_x, stride_y, stride_z); - pool3d_info->padding = arm_compute::Padding3D(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back); - pool3d_info->round_type = round; - arm_compute::Status s = arm_compute::NEPooling3dLayer::validate(&srcTensorInfo, &dstTensorInfo, *pool3d_info); + pool3d_info->pool_size = arm_compute::Size3D(kernel_w, kernel_h, kernel_d); + pool3d_info->stride = arm_compute::Size3D(stride_x, stride_y, stride_z); + pool3d_info->padding = + arm_compute::Padding3D(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back); + pool3d_info->round_type = round; + arm_compute::Status s = + arm_compute::NEPooling3dLayer::validate(&srcTensorInfo, &dstTensorInfo, *pool3d_info); if (!s) { DEBUG_LOG("NEPooling3dLayer validation failed: ", s.error_description()); return false; } } } else { - pool_info->data_layout = dataLayout; - pool_info->pool_size = arm_compute::Size2D(kernel_w, kernel_h); - pool_info->pad_stride_info = arm_compute::PadStrideInfo(stride_x, stride_y, pad_left, pad_right, pad_top, pad_bottom, round); - pool_info->pool_type = pool_type; - pool_info->exclude_padding = exclude_padding; + pool_info->data_layout = dataLayout; + pool_info->pool_size = arm_compute::Size2D(kernel_w, kernel_h); + pool_info->pad_stride_info = + arm_compute::PadStrideInfo(stride_x, stride_y, pad_left, pad_right, pad_top, pad_bottom, round); + pool_info->pool_type = pool_type; + pool_info->exclude_padding = exclude_padding; if (dstDescsSize > 1) { TensorInfo indTensorInfo = TensorInfo(shapeCast(*indDims), 1, arm_compute::DataType::U32, dataLayout); - arm_compute::Status s = arm_compute::NEPoolingLayer::validate(&srcTensorInfo, &dstTensorInfo, *pool_info, &indTensorInfo); + arm_compute::Status s = + arm_compute::NEPoolingLayer::validate(&srcTensorInfo, &dstTensorInfo, *pool_info, &indTensorInfo); if (!s) { DEBUG_LOG("NEPoolingLayer validation with indices failed: ", s.error_description()); if (ignoreOutShapeErrors && @@ -112,9 +118,9 @@ bool AclPoolingExecutor::isSupported(const TensorInfo& srcTensorInfo, } bool AclPoolingExecutor::init(const PoolingAttrs& poolingAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { auto srcDims = srcDescs[0]->getShape().getStaticDims(); auto dstDims = dstDescs[0]->getShape().getStaticDims(); @@ -124,10 +130,14 @@ bool AclPoolingExecutor::init(const PoolingAttrs& poolingAttrs, changeLayoutToNH_C({&srcShape, &dstShape}); } - TensorInfo srcTensorInfo = TensorInfo(srcShape, 1, - precisionToAclDataType(srcDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[0])); - TensorInfo dstTensorInfo = TensorInfo(dstShape, 1, - precisionToAclDataType(dstDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[0])); + TensorInfo srcTensorInfo = TensorInfo(srcShape, + 1, + precisionToAclDataType(srcDescs[0]->getPrecision()), + getAclDataLayoutByMemoryDesc(srcDescs[0])); + TensorInfo dstTensorInfo = TensorInfo(dstShape, + 1, + precisionToAclDataType(dstDescs[0]->getPrecision()), + getAclDataLayoutByMemoryDesc(dstDescs[0])); srcTensor.allocator()->init(srcTensorInfo); dstTensor.allocator()->init(dstTensorInfo); @@ -166,7 +176,9 @@ bool AclPoolingExecutor::init(const PoolingAttrs& poolingAttrs, nullptr)) return false; auto indDims = dstDescs[1]->getShape().getStaticDims(); - TensorInfo indTensorInfo = TensorInfo(shapeCast(indDims), 1, precisionToAclDataType(dstDescs[1]->getPrecision()), + TensorInfo indTensorInfo = TensorInfo(shapeCast(indDims), + 1, + precisionToAclDataType(dstDescs[1]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[1])); indTensor.allocator()->init(indTensorInfo); exec_func = [this, pool_info]() -> std::unique_ptr { @@ -192,21 +204,27 @@ bool AclPoolingExecutor::init(const PoolingAttrs& poolingAttrs, }; } } - configureThreadSafe([&] { ifunc = exec_func(); }); + configureThreadSafe([&] { + ifunc = exec_func(); + }); return true; } -void AclPoolingExecutor::exec(const std::vector& src, const std::vector& dst, std::unordered_map postOpsArgs) { +void AclPoolingExecutor::exec(const std::vector& src, + const std::vector& dst, + std::unordered_map postOpsArgs) { srcTensor.allocator()->import_memory(src[0]->getData()); dstTensor.allocator()->import_memory(dst[0]->getData()); - if (dst.size() > 1u) indTensor.allocator()->import_memory(dst[1]->getData()); + if (dst.size() > 1u) + indTensor.allocator()->import_memory(dst[1]->getData()); ifunc->run(); srcTensor.allocator()->free(); dstTensor.allocator()->free(); - if (dst.size() > 1u) indTensor.allocator()->free(); + if (dst.size() > 1u) + indTensor.allocator()->free(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.hpp index 9f6b1bb0fcc668..75b3d28eecf4aa 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.hpp @@ -4,8 +4,8 @@ #pragma once -#include "nodes/executors/pooling.hpp" #include "arm_compute/runtime/NEON/NEFunctions.h" +#include "nodes/executors/pooling.hpp" #include "utils/debug_capabilities.h" namespace ov { @@ -18,7 +18,7 @@ class AclPoolingExecutor : public PoolingExecutor { bool init(const PoolingAttrs& poolingAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst, std::unordered_map postOpsArgs) override; @@ -54,70 +54,72 @@ class AclPoolingExecutorBuilder : public PoolingExecutorBuilder { bool isSupported(const PoolingAttrs& poolingAttrs, const std::vector& srcDescs, const std::vector& dstDescs) const override { - if ((srcDescs[0]->getPrecision() != ov::element::f32 && - dstDescs[0]->getPrecision() != ov::element::f32) && - (srcDescs[0]->getPrecision() != ov::element::f16 && - dstDescs[0]->getPrecision() != ov::element::f16)) { + if ((srcDescs[0]->getPrecision() != ov::element::f32 && dstDescs[0]->getPrecision() != ov::element::f32) && + (srcDescs[0]->getPrecision() != ov::element::f16 && dstDescs[0]->getPrecision() != ov::element::f16)) { DEBUG_LOG("AclPoolingExecutor does not support precisions:", - " src[0]=", srcDescs[0]->getPrecision(), - " dst[0]=", dstDescs[0]->getPrecision()); + " src[0]=", + srcDescs[0]->getPrecision(), + " dst[0]=", + dstDescs[0]->getPrecision()); return false; } if (srcDescs.size() == 2u && - (srcDescs[1]->getPrecision() != ov::element::f32 && - srcDescs[0]->getPrecision() != ov::element::f32 && + (srcDescs[1]->getPrecision() != ov::element::f32 && srcDescs[0]->getPrecision() != ov::element::f32 && dstDescs[0]->getPrecision() != ov::element::f32) && - (srcDescs[1]->getPrecision() != ov::element::f16 && - srcDescs[0]->getPrecision() != ov::element::f16 && + (srcDescs[1]->getPrecision() != ov::element::f16 && srcDescs[0]->getPrecision() != ov::element::f16 && dstDescs[0]->getPrecision() != ov::element::f16)) { DEBUG_LOG("AclPoolingExecutor does not support precisions:", - " src[0]=", srcDescs[0]->getPrecision(), - " src[1]=", srcDescs[1]->getPrecision(), - " dst[0]=", dstDescs[0]->getPrecision()); + " src[0]=", + srcDescs[0]->getPrecision(), + " src[1]=", + srcDescs[1]->getPrecision(), + " dst[0]=", + dstDescs[0]->getPrecision()); return false; } - if (dstDescs.size() == 2u && - dstDescs[1]->getPrecision() != ov::element::u32) { + if (dstDescs.size() == 2u && dstDescs[1]->getPrecision() != ov::element::u32) { DEBUG_LOG("AclPoolingExecutor supports U32 as indices precisions only. ", - "Passed indices precision: ", dstDescs[1]->getPrecision()); - return false; - } + "Passed indices precision: ", + dstDescs[1]->getPrecision()); + return false; + } if (srcDescs[0]->getShape().getRank() < 5) { - if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && - dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && - !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && - dstDescs[0]->hasLayoutType(LayoutType::nspc))) { - DEBUG_LOG("NEPoolingLayer does not support layouts:", - " src=", srcDescs[0]->serializeFormat(), - " dst=", dstDescs[0]->serializeFormat()); - return false; - } + if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && + !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && dstDescs[0]->hasLayoutType(LayoutType::nspc))) { + DEBUG_LOG("NEPoolingLayer does not support layouts:", + " src=", + srcDescs[0]->serializeFormat(), + " dst=", + dstDescs[0]->serializeFormat()); + return false; + } if (srcDescs.size() == 2u && - !(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && - srcDescs[1]->hasLayoutType(LayoutType::ncsp) && - dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && - !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && - srcDescs[1]->hasLayoutType(LayoutType::nspc) && - dstDescs[0]->hasLayoutType(LayoutType::nspc))) { - DEBUG_LOG("NEPoolingLayer does not support layouts:", - " src[0]=", srcDescs[0]->serializeFormat(), - " src[1]=", srcDescs[1]->serializeFormat(), - " dst=", dstDescs[0]->serializeFormat()); - return false; - } + !(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && srcDescs[1]->hasLayoutType(LayoutType::ncsp) && + dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && + !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && srcDescs[1]->hasLayoutType(LayoutType::nspc) && + dstDescs[0]->hasLayoutType(LayoutType::nspc))) { + DEBUG_LOG("NEPoolingLayer does not support layouts:", + " src[0]=", + srcDescs[0]->serializeFormat(), + " src[1]=", + srcDescs[1]->serializeFormat(), + " dst=", + dstDescs[0]->serializeFormat()); + return false; + } } else { - if (!(srcDescs[0]->hasLayoutType(LayoutType::nspc) && - dstDescs[0]->hasLayoutType(LayoutType::nspc)) && - !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && - dstDescs[0]->hasLayoutType(LayoutType::nspc))) { - DEBUG_LOG("Pooling3dLayer does not support layouts:", - " src=", srcDescs[0]->serializeFormat(), - " dst=", dstDescs[0]->serializeFormat()); - return false; - } + if (!(srcDescs[0]->hasLayoutType(LayoutType::nspc) && dstDescs[0]->hasLayoutType(LayoutType::nspc)) && + !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && dstDescs[0]->hasLayoutType(LayoutType::nspc))) { + DEBUG_LOG("Pooling3dLayer does not support layouts:", + " src=", + srcDescs[0]->serializeFormat(), + " dst=", + dstDescs[0]->serializeFormat()); + return false; + } } return true; @@ -128,5 +130,5 @@ class AclPoolingExecutorBuilder : public PoolingExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp index e99747121cb623..5973027a0376cb 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp @@ -11,28 +11,31 @@ using namespace arm_compute; static arm_compute::ReductionOperation getAclReductionOperationByAlgorithm(Algorithm algorithm) { switch (algorithm) { - case Algorithm::ReduceMax: return arm_compute::ReductionOperation::MAX; - case Algorithm::ReduceMin: return arm_compute::ReductionOperation::MIN; - case Algorithm::ReduceSum: return arm_compute::ReductionOperation::SUM; - case Algorithm::ReduceProd: return arm_compute::ReductionOperation::PROD; - default: OPENVINO_THROW("Unsupported reduction operation: ", static_cast(algorithm)); + case Algorithm::ReduceMax: + return arm_compute::ReductionOperation::MAX; + case Algorithm::ReduceMin: + return arm_compute::ReductionOperation::MIN; + case Algorithm::ReduceSum: + return arm_compute::ReductionOperation::SUM; + case Algorithm::ReduceProd: + return arm_compute::ReductionOperation::PROD; + default: + OPENVINO_THROW("Unsupported reduction operation: ", static_cast(algorithm)); } } AclReduceExecutor::AclReduceExecutor(const ExecutorContext::CPtr context) : ReduceExecutor(context) {} bool AclReduceExecutor::init(const ReduceAttrs& reduceAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { - if (reduceAttrs.operation != Algorithm::ReduceMax && - reduceAttrs.operation != Algorithm::ReduceMin && - reduceAttrs.operation != Algorithm::ReduceSum && - reduceAttrs.operation != Algorithm::ReduceProd && + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { + if (reduceAttrs.operation != Algorithm::ReduceMax && reduceAttrs.operation != Algorithm::ReduceMin && + reduceAttrs.operation != Algorithm::ReduceSum && reduceAttrs.operation != Algorithm::ReduceProd && reduceAttrs.operation != Algorithm::ReduceMean) { - DEBUG_LOG("Unknown reduce algorithm passed into AclReduceExecutor: ", static_cast(reduceAttrs.operation)); - return false; - } + DEBUG_LOG("Unknown reduce algorithm passed into AclReduceExecutor: ", static_cast(reduceAttrs.operation)); + return false; + } this->reduceAttrs = reduceAttrs; @@ -46,10 +49,14 @@ bool AclReduceExecutor::init(const ReduceAttrs& reduceAttrs, changeLayoutToNH_C({&srcShape, &dstShape}); } - TensorInfo srcTensorInfo = TensorInfo(srcShape, 1, - precisionToAclDataType(srcDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[0])); - TensorInfo dstTensorInfo = TensorInfo(dstShape, 1, - precisionToAclDataType(dstDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[0])); + TensorInfo srcTensorInfo = TensorInfo(srcShape, + 1, + precisionToAclDataType(srcDescs[0]->getPrecision()), + getAclDataLayoutByMemoryDesc(srcDescs[0])); + TensorInfo dstTensorInfo = TensorInfo(dstShape, + 1, + precisionToAclDataType(dstDescs[0]->getPrecision()), + getAclDataLayoutByMemoryDesc(dstDescs[0])); srcTensor.allocator()->init(srcTensorInfo); dstTensor.allocator()->init(dstTensorInfo); @@ -57,54 +64,69 @@ bool AclReduceExecutor::init(const ReduceAttrs& reduceAttrs, std::function(void)> exec_func; std::vector castedAxes; for (size_t i = 0; i < reduceAttrs.axes.size(); ++i) { - int axis = axisCast(reduceAttrs.axes[i], srcDims.size(), hasSrcNspcLayout ? NHWC_TO_NCHW : NO_LAYOUT_CONVERSION); - if (hasSrcNspcLayout && axis == -1) return false; + int axis = + axisCast(reduceAttrs.axes[i], srcDims.size(), hasSrcNspcLayout ? NHWC_TO_NCHW : NO_LAYOUT_CONVERSION); + if (hasSrcNspcLayout && axis == -1) + return false; castedAxes.push_back(axis); } switch (reduceAttrs.operation) { - case Algorithm::ReduceMean: { - for (size_t i = 0; i < reduceAttrs.axes.size(); ++i) { - auto pos = axisCast(i, reduceAttrs.axes.size()); - axesMean.set(pos, castedAxes[i]); - } - Status reduceMeanStatus = NEReduceMean::validate(&srcTensorInfo, axesMean, reduceAttrs.keepDims, &dstTensorInfo); - if (!reduceMeanStatus) { - DEBUG_LOG("NEReduceMean validation failed: ", reduceMeanStatus.error_description()); - return false; - } - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensor, axesMean, this->reduceAttrs.keepDims, &dstTensor); - return acl_op; - }; - break; + case Algorithm::ReduceMean: { + for (size_t i = 0; i < reduceAttrs.axes.size(); ++i) { + auto pos = axisCast(i, reduceAttrs.axes.size()); + axesMean.set(pos, castedAxes[i]); } - case Algorithm::ReduceMax: - case Algorithm::ReduceMin: - case Algorithm::ReduceSum: - case Algorithm::ReduceProd: { - Status reductionOperationStatus = NEReductionOperation::validate(&srcTensorInfo, &dstTensorInfo, castedAxes[0], - getAclReductionOperationByAlgorithm(reduceAttrs.operation), reduceAttrs.keepDims); - if (!reductionOperationStatus) { - DEBUG_LOG("NEReductionOperation validation with indices failed: ", reductionOperationStatus.error_description()); - return false; - } - exec_func = [this, castedAxes]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensor, &dstTensor, castedAxes[0], - getAclReductionOperationByAlgorithm(this->reduceAttrs.operation), this->reduceAttrs.keepDims); - return acl_op; - }; - break; + Status reduceMeanStatus = + NEReduceMean::validate(&srcTensorInfo, axesMean, reduceAttrs.keepDims, &dstTensorInfo); + if (!reduceMeanStatus) { + DEBUG_LOG("NEReduceMean validation failed: ", reduceMeanStatus.error_description()); + return false; } - default: - OPENVINO_THROW("Unsupported operation type for ACL Reduce executor: ", static_cast(reduceAttrs.operation)); + exec_func = [this]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensor, axesMean, this->reduceAttrs.keepDims, &dstTensor); + return acl_op; + }; + break; + } + case Algorithm::ReduceMax: + case Algorithm::ReduceMin: + case Algorithm::ReduceSum: + case Algorithm::ReduceProd: { + Status reductionOperationStatus = + NEReductionOperation::validate(&srcTensorInfo, + &dstTensorInfo, + castedAxes[0], + getAclReductionOperationByAlgorithm(reduceAttrs.operation), + reduceAttrs.keepDims); + if (!reductionOperationStatus) { + DEBUG_LOG("NEReductionOperation validation with indices failed: ", + reductionOperationStatus.error_description()); + return false; + } + exec_func = [this, castedAxes]() -> std::unique_ptr { + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensor, + &dstTensor, + castedAxes[0], + getAclReductionOperationByAlgorithm(this->reduceAttrs.operation), + this->reduceAttrs.keepDims); + return acl_op; + }; + break; + } + default: + OPENVINO_THROW("Unsupported operation type for ACL Reduce executor: ", static_cast(reduceAttrs.operation)); } - configureThreadSafe([&] { ifunc = exec_func(); }); + configureThreadSafe([&] { + ifunc = exec_func(); + }); return true; } -void AclReduceExecutor::exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) { +void AclReduceExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) { srcTensor.allocator()->import_memory(src[0]->getData()); dstTensor.allocator()->import_memory(dst[0]->getData()); @@ -114,5 +136,5 @@ void AclReduceExecutor::exec(const std::vector& src, const std::vect dstTensor.allocator()->free(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp index 69bf6062918963..a121868bf80ba3 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp @@ -20,10 +20,10 @@ class AclReduceExecutor : public ReduceExecutor { bool init(const ReduceAttrs& reduceAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst, - const void *post_ops_data_) override; + const void* post_ops_data_) override; impl_desc_type getImplType() const override { return implType; @@ -46,33 +46,38 @@ class AclReduceExecutorBuilder : public ReduceExecutorBuilder { const std::vector& dstDescs) const override { if (reduceAttrs.operation == Algorithm::ReduceMean) { if (srcDescs[0]->getPrecision() != dstDescs[0]->getPrecision() || - (srcDescs[0]->getPrecision() != ov::element::f32 && - srcDescs[0]->getPrecision() != ov::element::f16)) { + (srcDescs[0]->getPrecision() != ov::element::f32 && srcDescs[0]->getPrecision() != ov::element::f16)) { DEBUG_LOG("NEReduceMean does not support precisions:", - " src[0]=", srcDescs[0]->getPrecision(), - " dst[0]=", dstDescs[0]->getPrecision()); + " src[0]=", + srcDescs[0]->getPrecision(), + " dst[0]=", + dstDescs[0]->getPrecision()); return false; } } else { if (srcDescs[0]->getPrecision() != dstDescs[0]->getPrecision() || - (srcDescs[0]->getPrecision() != ov::element::f32 && - srcDescs[0]->getPrecision() != ov::element::f16 && - srcDescs[0]->getPrecision() != ov::element::i32)) { + (srcDescs[0]->getPrecision() != ov::element::f32 && srcDescs[0]->getPrecision() != ov::element::f16 && + srcDescs[0]->getPrecision() != ov::element::i32)) { DEBUG_LOG("NEReductionOperation does not support precisions:", - " src[0]=", srcDescs[0]->getPrecision(), - " dst[0]=", dstDescs[0]->getPrecision()); + " src[0]=", + srcDescs[0]->getPrecision(), + " dst[0]=", + dstDescs[0]->getPrecision()); return false; } } if (srcDescs[0]->getShape().getRank() >= arm_compute::MAX_DIMS) { - DEBUG_LOG("ACL supports ", arm_compute::MAX_DIMS, - " dimensions maximum. src[0] shape rank is ", srcDescs[0]->getShape().getRank()); + DEBUG_LOG("ACL supports ", + arm_compute::MAX_DIMS, + " dimensions maximum. src[0] shape rank is ", + srcDescs[0]->getShape().getRank()); return false; } auto srcShapeRank = srcDescs[0]->getShape().getRank(); bool hasSrcNspcLayout = srcDescs[0]->hasLayoutType(LayoutType::nspc); for (size_t i = 0; i < reduceAttrs.axes.size(); ++i) { - int axis = axisCast(reduceAttrs.axes[i], srcShapeRank, hasSrcNspcLayout ? NHWC_TO_NCHW : NO_LAYOUT_CONVERSION); + int axis = + axisCast(reduceAttrs.axes[i], srcShapeRank, hasSrcNspcLayout ? NHWC_TO_NCHW : NO_LAYOUT_CONVERSION); if (axis == -1) { DEBUG_LOG("Layout conversion to NHWC has failed"); return false; @@ -82,14 +87,12 @@ class AclReduceExecutorBuilder : public ReduceExecutorBuilder { return false; } } - if ((reduceAttrs.operation == Algorithm::ReduceSum || - reduceAttrs.operation == Algorithm::ReduceMax || - reduceAttrs.operation == Algorithm::ReduceMin || - reduceAttrs.operation == Algorithm::ReduceProd) && - reduceAttrs.axes.size() != 1) { - DEBUG_LOG("ACL supports single axes reduce only. Number of axes: ", reduceAttrs.axes.size()); - return false; - } + if ((reduceAttrs.operation == Algorithm::ReduceSum || reduceAttrs.operation == Algorithm::ReduceMax || + reduceAttrs.operation == Algorithm::ReduceMin || reduceAttrs.operation == Algorithm::ReduceProd) && + reduceAttrs.axes.size() != 1) { + DEBUG_LOG("ACL supports single axes reduce only. Number of axes: ", reduceAttrs.axes.size()); + return false; + } return true; } @@ -99,5 +102,5 @@ class AclReduceExecutorBuilder : public ReduceExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.cpp index 801e50831b3bb1..dd16b333cb6b32 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.cpp @@ -3,12 +3,13 @@ // #include "acl_transpose.hpp" + #include "acl_utils.hpp" -bool ov::intel_cpu::ACLTransposeExecutor::init(const ov::intel_cpu::TransposeParams &transposeParams, - const std::vector &srcDescs, - const std::vector &dstDescs, - const dnnl::primitive_attr &attr) { +bool ov::intel_cpu::ACLTransposeExecutor::init(const ov::intel_cpu::TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { auto inputOrder = transposeParams.permuteParams.order; if (inputOrder.empty()) { inputOrder.resize(srcDescs[0]->getShape().getRank()); @@ -24,7 +25,7 @@ bool ov::intel_cpu::ACLTransposeExecutor::init(const ov::intel_cpu::TransposePar }; auto srcDims = changeLayoutToNhwc(srcDescs[0]->getShape().getStaticDims()); auto dstDims = changeLayoutToNhwc(dstDescs[0]->getShape().getStaticDims()); - for (int i = inputOrder.size() - 1; i >= 0 ; --i) { + for (int i = inputOrder.size() - 1; i >= 0; --i) { auto it = find(srcDims.rbegin(), srcDims.rend(), dstDims[i]); int index = it - srcDims.rbegin(); vec.push_back(index); @@ -46,10 +47,12 @@ bool ov::intel_cpu::ACLTransposeExecutor::init(const ov::intel_cpu::TransposePar if (srcDescs[0]->hasLayoutType(LayoutType::nspc) && dstDescs[0]->hasLayoutType(LayoutType::nspc)) { changeLayoutToNH_C({&srcDims, &dstDims}); } - auto srcTensorInfo = arm_compute::TensorInfo(srcDims, 1, + auto srcTensorInfo = arm_compute::TensorInfo(srcDims, + 1, precisionToAclDataType(srcDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[0])); - auto dstTensorInfo = arm_compute::TensorInfo(dstDims, 1, + auto dstTensorInfo = arm_compute::TensorInfo(dstDims, + 1, precisionToAclDataType(dstDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[0])); arm_compute::Status status = arm_compute::NEPermute::validate(&srcTensorInfo, &dstTensorInfo, order); @@ -61,11 +64,13 @@ bool ov::intel_cpu::ACLTransposeExecutor::init(const ov::intel_cpu::TransposePar dstTensor.allocator()->init(dstTensorInfo); acl_permute = std::make_unique(); - configureThreadSafe([&] { acl_permute->configure(&srcTensor, &dstTensor, order); }); + configureThreadSafe([&] { + acl_permute->configure(&srcTensor, &dstTensor, order); + }); return true; } -void ov::intel_cpu::ACLTransposeExecutor::exec(const std::vector &src, const std::vector &dst) { +void ov::intel_cpu::ACLTransposeExecutor::exec(const std::vector& src, const std::vector& dst) { srcTensor.allocator()->import_memory(src[0]->getData()); dstTensor.allocator()->import_memory(dst[0]->getData()); diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.hpp index 02a190597531ea..c6765aa1ff25f0 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_transpose.hpp @@ -4,10 +4,9 @@ #pragma once -#include "nodes/executors/transpose.hpp" - -#include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/NEON/functions/NEPermute.h" +#include "arm_compute/runtime/Tensor.h" +#include "nodes/executors/transpose.hpp" #include "utils/debug_capabilities.h" namespace ov { @@ -20,9 +19,12 @@ class ACLTransposeExecutor : public TransposeExecutor { bool init(const TransposeParams& transposeParams, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst) override; - impl_desc_type implType() const override { return impl_desc_type::acl; } + impl_desc_type implType() const override { + return impl_desc_type::acl; + } + private: arm_compute::Tensor srcTensor, dstTensor; std::unique_ptr acl_permute; @@ -33,13 +35,13 @@ class ACLTransposeExecutorBuilder : public TransposeExecutorBuilder { bool isSupported(const TransposeParams& transposeParams, const std::vector& srcDescs, const std::vector& dstDescs) const override { - if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && - dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && - !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && - dstDescs[0]->hasLayoutType(LayoutType::nspc))) { + if (!(srcDescs[0]->hasLayoutType(LayoutType::ncsp) && dstDescs[0]->hasLayoutType(LayoutType::ncsp)) && + !(srcDescs[0]->hasLayoutType(LayoutType::nspc) && dstDescs[0]->hasLayoutType(LayoutType::nspc))) { DEBUG_LOG("NEPermute does not support layout:", - " src: ", srcDescs[0]->serializeFormat(), - " dst: ", dstDescs[0]->serializeFormat()); + " src: ", + srcDescs[0]->serializeFormat(), + " dst: ", + dstDescs[0]->serializeFormat()); return false; } if (srcDescs[0]->getShape().getRank() > 4) { @@ -59,5 +61,5 @@ class ACLTransposeExecutorBuilder : public TransposeExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp index df57d29f4a44ec..6c3799da70bfda 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp @@ -3,6 +3,7 @@ // #include "acl_utils.hpp" + #include "support/Mutex.h" #include "utils/debug_capabilities.h" @@ -18,55 +19,55 @@ void configureThreadSafe(const std::function& config) { arm_compute::ActivationLayerInfo getActivationLayerInfo(Algorithm algorithm, float alpha = 0.0, - float beta = 0.0, + float beta = 0.0, float gamma = 0.0) { switch (algorithm) { - case Algorithm::EltwiseRelu: - if (alpha == 0) { - return arm_compute::ActivationLayerInfo::ActivationFunction::RELU; - } else { - return {arm_compute::ActivationLayerInfo::ActivationFunction::LEAKY_RELU, alpha}; - } - case Algorithm::EltwiseGeluErf: - return arm_compute::ActivationLayerInfo::ActivationFunction::GELU; - case Algorithm::EltwiseElu: - return {arm_compute::ActivationLayerInfo::ActivationFunction::ELU, alpha}; - case Algorithm::EltwiseTanh: - return {arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f}; - case Algorithm::EltwiseSigmoid: - return arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC; - case Algorithm::EltwiseSqrt: - return arm_compute::ActivationLayerInfo::ActivationFunction::SQRT; - case Algorithm::EltwiseSoftRelu: - return arm_compute::ActivationLayerInfo::ActivationFunction::SOFT_RELU; - case Algorithm::EltwiseClamp: - return {arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, beta, alpha}; - case Algorithm::EltwiseSwish: - return {arm_compute::ActivationLayerInfo::ActivationFunction::SWISH, alpha}; - case Algorithm::EltwiseHswish: - return arm_compute::ActivationLayerInfo::ActivationFunction::HARD_SWISH; - default: - OPENVINO_THROW("Unsupported operation type for ACL Eltwise executor: ", static_cast(algorithm)); + case Algorithm::EltwiseRelu: + if (alpha == 0) { + return arm_compute::ActivationLayerInfo::ActivationFunction::RELU; + } else { + return {arm_compute::ActivationLayerInfo::ActivationFunction::LEAKY_RELU, alpha}; + } + case Algorithm::EltwiseGeluErf: + return arm_compute::ActivationLayerInfo::ActivationFunction::GELU; + case Algorithm::EltwiseElu: + return {arm_compute::ActivationLayerInfo::ActivationFunction::ELU, alpha}; + case Algorithm::EltwiseTanh: + return {arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f}; + case Algorithm::EltwiseSigmoid: + return arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC; + case Algorithm::EltwiseSqrt: + return arm_compute::ActivationLayerInfo::ActivationFunction::SQRT; + case Algorithm::EltwiseSoftRelu: + return arm_compute::ActivationLayerInfo::ActivationFunction::SOFT_RELU; + case Algorithm::EltwiseClamp: + return {arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, beta, alpha}; + case Algorithm::EltwiseSwish: + return {arm_compute::ActivationLayerInfo::ActivationFunction::SWISH, alpha}; + case Algorithm::EltwiseHswish: + return arm_compute::ActivationLayerInfo::ActivationFunction::HARD_SWISH; + default: + OPENVINO_THROW("Unsupported operation type for ACL Eltwise executor: ", static_cast(algorithm)); } } bool checkActivationLayerInfo(Algorithm algorithm) { switch (algorithm) { - case Algorithm::EltwiseRelu: - case Algorithm::EltwiseGeluErf: - case Algorithm::EltwiseElu: - case Algorithm::EltwiseTanh: - case Algorithm::EltwiseSigmoid: - case Algorithm::EltwiseSqrt: - case Algorithm::EltwiseSoftRelu: - case Algorithm::EltwiseClamp: - case Algorithm::EltwiseSwish: - case Algorithm::EltwiseHswish: - return true; - default: - return false; + case Algorithm::EltwiseRelu: + case Algorithm::EltwiseGeluErf: + case Algorithm::EltwiseElu: + case Algorithm::EltwiseTanh: + case Algorithm::EltwiseSigmoid: + case Algorithm::EltwiseSqrt: + case Algorithm::EltwiseSoftRelu: + case Algorithm::EltwiseClamp: + case Algorithm::EltwiseSwish: + case Algorithm::EltwiseHswish: + return true; + default: + return false; } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp index a3d151192e601b..1d30736353b878 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp @@ -3,19 +3,19 @@ // #pragma once -#include "memory_desc/cpu_memory_desc.h" #include "arm_compute/core/Types.h" #include "cpu_types.h" +#include "memory_desc/cpu_memory_desc.h" namespace ov { namespace intel_cpu { /** -* @brief ACL supports arm_compute::MAX_DIMS maximum. The method squashes the last -* dimensions in order to comply with this limitation -* @param dims vector of dimensions to squash -* @return vector of dimensions that complies to ACL -*/ + * @brief ACL supports arm_compute::MAX_DIMS maximum. The method squashes the last + * dimensions in order to comply with this limitation + * @param dims vector of dimensions to squash + * @return vector of dimensions that complies to ACL + */ inline VectorDims collapse_dims_to_max_rank(VectorDims dims, size_t max_num_shape = arm_compute::MAX_DIMS) { VectorDims result_dims(max_num_shape - 1); if (dims.size() >= max_num_shape) { @@ -32,17 +32,23 @@ inline VectorDims collapse_dims_to_max_rank(VectorDims dims, size_t max_num_shap } /** -* @brief ACL handles NH_C specifically, it thinks it is NC_W, so we need to change layout manually: -* e.g. NCHW (0, 1, 2, 3) -> NHWC (0, 2, 3, 1) -* @param _listDims list of dimensions to convert -* @return none -*/ - -inline void changeLayoutToNH_C(const std::vector &_listDims) { - auto mover = [](arm_compute::TensorShape &_shape) { - if (_shape.num_dimensions() > 4) { std::swap(_shape[2], _shape[3]); } - if (_shape.num_dimensions() > 3) { std::swap(_shape[1], _shape[2]); } - if (_shape.num_dimensions() > 2) { std::swap(_shape[0], _shape[1]); } + * @brief ACL handles NH_C specifically, it thinks it is NC_W, so we need to change layout manually: + * e.g. NCHW (0, 1, 2, 3) -> NHWC (0, 2, 3, 1) + * @param _listDims list of dimensions to convert + * @return none + */ + +inline void changeLayoutToNH_C(const std::vector& _listDims) { + auto mover = [](arm_compute::TensorShape& _shape) { + if (_shape.num_dimensions() > 4) { + std::swap(_shape[2], _shape[3]); + } + if (_shape.num_dimensions() > 3) { + std::swap(_shape[1], _shape[2]); + } + if (_shape.num_dimensions() > 2) { + std::swap(_shape[0], _shape[1]); + } }; for (auto& dims : _listDims) { @@ -51,10 +57,10 @@ inline void changeLayoutToNH_C(const std::vector &_li } /** -* @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL -* @param dims vector of dimensions to convert -* @return ComputeLibrary TensorShape object -*/ + * @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL + * @param dims vector of dimensions to convert + * @return ComputeLibrary TensorShape object + */ inline arm_compute::TensorShape shapeCast(const VectorDims& dims) { arm_compute::TensorShape tensorShape; for (std::size_t i = 0; i < dims.size(); ++i) { @@ -67,20 +73,18 @@ inline arm_compute::TensorShape shapeCast(const VectorDims& dims) { return tensorShape; } -enum ACLAxisCastMode { - NO_LAYOUT_CONVERSION, - NHWC_TO_NCHW, - NCHW_TO_NHWC -}; +enum ACLAxisCastMode { NO_LAYOUT_CONVERSION, NHWC_TO_NCHW, NCHW_TO_NHWC }; /** -* @brief Return reverted axis used in ACL. If axis cast mode is -* @param axis axis that needs to be converted -* @param shapeSize size of the shape, which axis needs to be converted -* @param axisCastMode specifies whether layout conversion is required or not -* @return reverted axis -*/ -inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxisCastMode axisCastMode = NO_LAYOUT_CONVERSION) { + * @brief Return reverted axis used in ACL. If axis cast mode is + * @param axis axis that needs to be converted + * @param shapeSize size of the shape, which axis needs to be converted + * @param axisCastMode specifies whether layout conversion is required or not + * @return reverted axis + */ +inline int axisCast(const std::size_t axis, + const std::size_t shapeSize, + ACLAxisCastMode axisCastMode = NO_LAYOUT_CONVERSION) { // CWHN (reverted NHWC) (0, 1, 2, 3) into WHCN (reverted NCHW) (1, 2, 0, 3) static const std::array nhwcToNchw = {1, 2, 0, 3}; // WHCN (reverted NCHW) (0, 1, 2, 3) into CWHN (reverted NHWC) (2, 0, 1, 3) @@ -92,80 +96,100 @@ inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxis size_t revertedAxis = shapeSize - axis - 1; switch (axisCastMode) { - case NO_LAYOUT_CONVERSION: - return revertedAxis; - case NHWC_TO_NCHW: - if (shapeSize == 4) return nhwcToNchw[revertedAxis]; - if (shapeSize == 5) return ndhwcToNcdhw[revertedAxis]; - case NCHW_TO_NHWC: - if (shapeSize == 4) return nchwToNhwc[revertedAxis]; - if (shapeSize == 5) return ncdhwToNdhwc[revertedAxis]; - default: - return -1; + case NO_LAYOUT_CONVERSION: + return revertedAxis; + case NHWC_TO_NCHW: + if (shapeSize == 4) + return nhwcToNchw[revertedAxis]; + if (shapeSize == 5) + return ndhwcToNcdhw[revertedAxis]; + case NCHW_TO_NHWC: + if (shapeSize == 4) + return nchwToNhwc[revertedAxis]; + if (shapeSize == 5) + return ncdhwToNdhwc[revertedAxis]; + default: + return -1; } } /** -* @brief Return ComputeLibrary DataType that corresponds to the given precision -* @param precision precision to be converted -* @return ComputeLibrary DataType or UNKNOWN if precision is not mapped to DataType -*/ + * @brief Return ComputeLibrary DataType that corresponds to the given precision + * @param precision precision to be converted + * @return ComputeLibrary DataType or UNKNOWN if precision is not mapped to DataType + */ inline arm_compute::DataType precisionToAclDataType(ov::element::Type precision) { switch (precision) { - case ov::element::i8: return arm_compute::DataType::S8; - case ov::element::u8: return arm_compute::DataType::U8; - case ov::element::i16: return arm_compute::DataType::S16; - case ov::element::u16: return arm_compute::DataType::U16; - case ov::element::i32: return arm_compute::DataType::S32; - case ov::element::u32: return arm_compute::DataType::U32; - case ov::element::f16: return arm_compute::DataType::F16; - case ov::element::f32: return arm_compute::DataType::F32; - case ov::element::f64: return arm_compute::DataType::F64; - case ov::element::i64: return arm_compute::DataType::S64; - case ov::element::bf16: return arm_compute::DataType::BFLOAT16; - default: return arm_compute::DataType::UNKNOWN; + case ov::element::i8: + return arm_compute::DataType::S8; + case ov::element::u8: + return arm_compute::DataType::U8; + case ov::element::i16: + return arm_compute::DataType::S16; + case ov::element::u16: + return arm_compute::DataType::U16; + case ov::element::i32: + return arm_compute::DataType::S32; + case ov::element::u32: + return arm_compute::DataType::U32; + case ov::element::f16: + return arm_compute::DataType::F16; + case ov::element::f32: + return arm_compute::DataType::F32; + case ov::element::f64: + return arm_compute::DataType::F64; + case ov::element::i64: + return arm_compute::DataType::S64; + case ov::element::bf16: + return arm_compute::DataType::BFLOAT16; + default: + return arm_compute::DataType::UNKNOWN; } } /** -* @brief Return ComputeLibrary DataLayout that corresponds to MemoryDecs layout -* @param desc MemoryDecs from which layout is retrieved -* @param treatAs4D the flag that treats MemoryDecs as 4D shape -* @return ComputeLibrary DataLayout or UNKNOWN if MemoryDecs layout is not mapped to DataLayout -*/ + * @brief Return ComputeLibrary DataLayout that corresponds to MemoryDecs layout + * @param desc MemoryDecs from which layout is retrieved + * @param treatAs4D the flag that treats MemoryDecs as 4D shape + * @return ComputeLibrary DataLayout or UNKNOWN if MemoryDecs layout is not mapped to DataLayout + */ inline arm_compute::DataLayout getAclDataLayoutByMemoryDesc(MemoryDescCPtr desc) { if (desc->hasLayoutType(LayoutType::ncsp)) { - if (desc->getShape().getRank() <= 4) return arm_compute::DataLayout::NCHW; - if (desc->getShape().getRank() == 5) return arm_compute::DataLayout::NCDHW; + if (desc->getShape().getRank() <= 4) + return arm_compute::DataLayout::NCHW; + if (desc->getShape().getRank() == 5) + return arm_compute::DataLayout::NCDHW; } else if (desc->hasLayoutType(LayoutType::nspc)) { - if (desc->getShape().getRank() <= 4) return arm_compute::DataLayout::NHWC; - if (desc->getShape().getRank() == 5) return arm_compute::DataLayout::NDHWC; + if (desc->getShape().getRank() <= 4) + return arm_compute::DataLayout::NHWC; + if (desc->getShape().getRank() == 5) + return arm_compute::DataLayout::NDHWC; } return arm_compute::DataLayout::UNKNOWN; } /** -* @brief run thread-safe configure for ComputeLibrary configuration function. -* Arm Compute Library 23.08 does not officially support thread-safe configure() calls. -* For example, calling configure for Eltwise operations from multiple streams leads to a data race and seg fault. -* @param config ComputeLibrary configuration function -*/ + * @brief run thread-safe configure for ComputeLibrary configuration function. + * Arm Compute Library 23.08 does not officially support thread-safe configure() calls. + * For example, calling configure for Eltwise operations from multiple streams leads to a data race and seg fault. + * @param config ComputeLibrary configuration function + */ void configureThreadSafe(const std::function& config); /** -* @brief get ARM Compute Library ActivationLayerInfo for Eltwise or PostOps. -* @param algorithm activation function of openvino representation -* @param alpha alpha coefficient for algorithm -* @param beta beta coefficient for algorithm -* @param gamma gamma coefficient for algorithm -*/ + * @brief get ARM Compute Library ActivationLayerInfo for Eltwise or PostOps. + * @param algorithm activation function of openvino representation + * @param alpha alpha coefficient for algorithm + * @param beta beta coefficient for algorithm + * @param gamma gamma coefficient for algorithm + */ arm_compute::ActivationLayerInfo getActivationLayerInfo(Algorithm algorithm, float alpha, float beta, float gamma); /** -* @brief check ARM Compute Library ActivationLayerInfo for Eltwise or PostOps. -* @param algorithm activation function of openvino representation -*/ + * @brief check ARM Compute Library ActivationLayerInfo for Eltwise or PostOps. + * @param algorithm activation function of openvino representation + */ bool checkActivationLayerInfo(Algorithm algorithm); -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp index b3fe7018d23677..4b4b07df572b4a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp @@ -12,9 +12,8 @@ using namespace Xbyak_aarch64; using namespace dnnl::impl::cpu; using namespace dnnl::impl::cpu::aarch64; -void jit_uni_eltwise_kernel::operator()( - const node::jit_eltwise_call_args_ptrs* const_args, - const jit_eltwise_call_args_indexes* indexes) { +void jit_uni_eltwise_kernel::operator()(const node::jit_eltwise_call_args_ptrs* const_args, + const jit_eltwise_call_args_indexes* indexes) { assert(ker_); ker_(const_args, indexes); } @@ -23,12 +22,12 @@ template jit_uni_eltwise_generic::jit_uni_eltwise_generic(const jit_eltwise_params& jep, const std::vector& eltwise_data, const std::vector& ops_list, - const dnnl::post_ops& post_ops) : - jit_uni_eltwise_kernel(jep), - jit_generator(), - eltwise_data_(eltwise_data), - ops_list_(ops_list), - post_ops_(post_ops) {} + const dnnl::post_ops& post_ops) + : jit_uni_eltwise_kernel(jep), + jit_generator(), + eltwise_data_(eltwise_data), + ops_list_(ops_list), + post_ops_(post_ops) {} template void jit_uni_eltwise_generic::generate() { @@ -41,7 +40,7 @@ void jit_uni_eltwise_generic::generate() { post_op_emitters.push_back(create_eltwise_emitter(eltwise_data_[i], exec_prc)); } - const auto &jep = jep_; + const auto& jep = jep_; XReg param2 = abi_param2; const int offset_count = jep.input_size - 1; @@ -49,10 +48,15 @@ void jit_uni_eltwise_generic::generate() { // ptrs initializing if (jep.use_runtime_ptrs) { for (size_t i = 0; i < jep.inputs_number; i++) { - ldr(start_to_offsets, ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, src_offsets) + i * sizeof(size_t)))); - ldr(get_src_reg(i), ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t)))); - XReg offset_reg = get_aux_gpr(0); // X_TMP_0; - XReg index_reg = get_aux_gpr(1); // X_TMP_1; + ldr(start_to_offsets, + ptr(reg_const_params, + static_cast(offsetof(node::jit_eltwise_call_args_ptrs, src_offsets) + + i * sizeof(size_t)))); + ldr(get_src_reg(i), + ptr(reg_const_params, + static_cast(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t)))); + XReg offset_reg = get_aux_gpr(0); // X_TMP_0; + XReg index_reg = get_aux_gpr(1); // X_TMP_1; for (int j = 0; j < offset_count; j++) { ldr(offset_reg, ptr(start_to_offsets, static_cast(j * sizeof(size_t)))); ldr(index_reg, ptr(reg_indexes, static_cast(j * sizeof(size_t)))); @@ -60,10 +64,11 @@ void jit_uni_eltwise_generic::generate() { } } - ldr(start_to_offsets, ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, dst_offsets)))); + ldr(start_to_offsets, + ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, dst_offsets)))); ldr(reg_dst, ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr)))); - XReg offset_reg = get_aux_gpr(0); // X_TMP_0; - XReg index_reg = get_aux_gpr(1); // X_TMP_1; + XReg offset_reg = get_aux_gpr(0); // X_TMP_0; + XReg index_reg = get_aux_gpr(1); // X_TMP_1; for (int j = 0; j < offset_count; j++) { ldr(offset_reg, ptr(start_to_offsets, static_cast(j * sizeof(size_t)))); ldr(index_reg, ptr(reg_indexes, static_cast(j * sizeof(size_t)))); @@ -72,7 +77,8 @@ void jit_uni_eltwise_generic::generate() { mov(reg_oc_off, 0); - ldr(reg_work_amount, ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, work_amount)))); + ldr(reg_work_amount, + ptr(reg_const_params, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, work_amount)))); } else { auto init_ptrs_with_offsets = [this, offset_count, param2](XReg pointer, const std::vector& offsets) { for (int j = 0; j < offset_count; j++) { @@ -88,7 +94,9 @@ void jit_uni_eltwise_generic::generate() { }; for (size_t i = 0; i < jep.inputs_number; i++) { - ldr(get_src_reg(i), ptr(param1, static_cast(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof(size_t)))); + ldr(get_src_reg(i), + ptr(param1, + static_cast(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof(size_t)))); init_ptrs_with_offsets(get_src_reg(i), jep.src_offsets[i]); } @@ -149,7 +157,12 @@ void jit_uni_eltwise_generic::generate() { for (size_t j = 0; j < min_src_size / vec_step; j++) { for (size_t i = 0; i < jep.inputs_number; i++) { if (jep.src_size[i] != 1) { - load_vector(get_vmm_reg(i), get_src_reg(i), jep.src_prc[i], exec_prc, false, j * vec_step * jep.src_prc[i].size()); + load_vector(get_vmm_reg(i), + get_src_reg(i), + jep.src_prc[i], + exec_prc, + false, + j * vec_step * jep.src_prc[i].size()); } } @@ -164,7 +177,11 @@ void jit_uni_eltwise_generic::generate() { for (size_t j = tail_start; j < min_src_size; j++) { for (size_t i = 0; i < jep.inputs_number; i++) { if (jep.src_size[i] != 1) { - load_scalar(get_scl_reg(i), get_src_reg(i), jep.src_prc[i], exec_prc, j * jep.src_prc[i].size()); + load_scalar(get_scl_reg(i), + get_src_reg(i), + jep.src_prc[i], + exec_prc, + j * jep.src_prc[i].size()); } } @@ -276,7 +293,7 @@ namespace utils { template void load_vector(const T1& data_lane, const T2& data_lanes, - const Xbyak_aarch64::XReg &ptr_reg, + const Xbyak_aarch64::XReg& ptr_reg, const int64_t offset, const bool broadcast, jit_generator* h) { @@ -296,7 +313,7 @@ void load_vector(const T1& data_lane, } } } -} // namespace utils +} // namespace utils template void jit_uni_eltwise_generic::load_vector(const TReg& data, @@ -306,62 +323,63 @@ void jit_uni_eltwise_generic::load_vector(const TReg& data, const bool broadcast, const int32_t ptr_offset) { switch (src_prc) { - case ov::element::f16: { - utils::load_vector(data.h, data.h4, ptr_reg, ptr_offset, broadcast, this); - break; - } - case ov::element::f32: - case ov::element::i32: { - if (broadcast) { - jit_generator::uni_ld1rw(data.s, ptr_reg, ptr_offset); - } else { - jit_generator::uni_ldr(data, ptr_reg, ptr_offset); - } - break; - } - case ov::element::i8: { - utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this); - sshll(data.h8, data.b8, 0); - sshll(data.s4, data.h4, 0); - break; - } - case ov::element::u8: { - utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this); - ushll(data.h8, data.b8, 0); - ushll(data.s4, data.h4, 0); - break; - } - default: { - OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); + case ov::element::f16: { + utils::load_vector(data.h, data.h4, ptr_reg, ptr_offset, broadcast, this); + break; + } + case ov::element::f32: + case ov::element::i32: { + if (broadcast) { + jit_generator::uni_ld1rw(data.s, ptr_reg, ptr_offset); + } else { + jit_generator::uni_ldr(data, ptr_reg, ptr_offset); } + break; + } + case ov::element::i8: { + utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this); + sshll(data.h8, data.b8, 0); + sshll(data.s4, data.h4, 0); + break; + } + case ov::element::u8: { + utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this); + ushll(data.h8, data.b8, 0); + ushll(data.s4, data.h4, 0); + break; + } + default: { + OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); + } } if (dst_prc != src_prc) { switch (dst_prc) { - case ov::element::f32: - switch (src_prc) { - case ov::element::f16: { - fcvtl(data.s4, data.h4); - break; - } - case ov::element::i32: { - scvtf(data.s, data.s); - break; - } - case ov::element::i8: { - scvtf(data.s, data.s); - break; - } - case ov::element::u8: { - ucvtf(data.s, data.s); - break; - } - default: - OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); - } + case ov::element::f32: + switch (src_prc) { + case ov::element::f16: { + fcvtl(data.s4, data.h4); + break; + } + case ov::element::i32: { + scvtf(data.s, data.s); break; + } + case ov::element::i8: { + scvtf(data.s, data.s); + break; + } + case ov::element::u8: { + ucvtf(data.s, data.s); + break; + } default: - OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); + OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + + dst_prc.to_string()); + } + break; + default: + OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); } } } @@ -373,61 +391,62 @@ void jit_uni_eltwise_generic::load_scalar(const SReg& data, const ov::element::Type& dst_prc, const int32_t ptr_offset) { switch (src_prc) { - case ov::element::f16: { - ldr(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); - break; - } - case ov::element::f32: - case ov::element::i32: { - ldr(data, Xbyak_aarch64::ptr(ptr, ptr_offset)); - break; - } - case ov::element::i8: { - ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + case ov::element::f16: { + ldr(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + case ov::element::f32: + case ov::element::i32: { + ldr(data, Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + case ov::element::i8: { + ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); - // scalar is loaded, operates with vector - TReg vec(data.getIdx()); - sshll(vec.h8, vec.b8, 0); - sshll(vec.s4, vec.h4, 0); - break; - } - case ov::element::u8: { - ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + // scalar is loaded, operates with vector + TReg vec(data.getIdx()); + sshll(vec.h8, vec.b8, 0); + sshll(vec.s4, vec.h4, 0); + break; + } + case ov::element::u8: { + ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); - // scalar is loaded, operates with vector - TReg vec(data.getIdx()); - ushll(vec.h8, vec.b8, 0); - ushll(vec.s4, vec.h4, 0); - break; - } - default: { - OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); - } + // scalar is loaded, operates with vector + TReg vec(data.getIdx()); + ushll(vec.h8, vec.b8, 0); + ushll(vec.s4, vec.h4, 0); + break; + } + default: { + OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); + } } if (dst_prc != src_prc) { switch (dst_prc) { - case ov::element::f32: - switch (src_prc) { - case ov::element::f16: { - fcvt(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::HReg(data.getIdx())); - break; - } - case ov::element::i32: - case ov::element::i8: { - scvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx())); - break; - } - case ov::element::u8: { - ucvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx())); - break; - } - default: - OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); - } + case ov::element::f32: + switch (src_prc) { + case ov::element::f16: { + fcvt(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::HReg(data.getIdx())); + break; + } + case ov::element::i32: + case ov::element::i8: { + scvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx())); + break; + } + case ov::element::u8: { + ucvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx())); break; + } default: - OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); + OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + + dst_prc.to_string()); + } + break; + default: + OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); } } } @@ -440,58 +459,59 @@ void jit_uni_eltwise_generic::store_vector(const XReg& ptr, const int32_t ptr_offset) { if (src_prc != dst_prc) { switch (src_prc) { - case ov::element::f32: { - switch (dst_prc) { - case ov::element::f16: { - fcvtn(data.h4, data.s4); - break; - } - case ov::element::i32: { - fcvtns(data.s, data.s); - break; - } - case ov::element::i8: { - fcvtms(data.s, data.s); - xtn(data.h4, data.s4); - xtn(data.b8, data.h8); - break; - } - case ov::element::u8: { - fcvtmu(data.s, data.s); - xtn(data.h4, data.s4); - xtn(data.b8, data.h8); - break; - } - default: { - OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); - } - } + case ov::element::f32: { + switch (dst_prc) { + case ov::element::f16: { + fcvtn(data.h4, data.s4); + break; + } + case ov::element::i32: { + fcvtns(data.s, data.s); + break; + } + case ov::element::i8: { + fcvtms(data.s, data.s); + xtn(data.h4, data.s4); + xtn(data.b8, data.h8); + break; + } + case ov::element::u8: { + fcvtmu(data.s, data.s); + xtn(data.h4, data.s4); + xtn(data.b8, data.h8); break; } default: { - OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); + OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + + src_prc.to_string()); + } } - } - } - - switch (dst_prc) { - case ov::element::f16: { - str(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); - break; - } - case ov::element::f32: - case ov::element::i32: { - str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); - break; - } - case ov::element::i8: - case ov::element::u8: { - str(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); break; } default: { - OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_ptr is " + src_prc.to_string()); + OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); } + } + } + + switch (dst_prc) { + case ov::element::f16: { + str(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + case ov::element::f32: + case ov::element::i32: { + str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + case ov::element::i8: + case ov::element::u8: { + str(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + default: { + OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_ptr is " + src_prc.to_string()); + } } } @@ -503,99 +523,94 @@ void jit_uni_eltwise_generic::store_scalar(const XReg& ptr, const int32_t ptr_offset) { if (src_prc != dst_prc) { switch (src_prc) { - case ov::element::f32: { - switch (dst_prc) { - case ov::element::f16: { - fcvt(Xbyak_aarch64::HReg(data.getIdx()), data); - break; - } - case ov::element::i32: { - fcvtns(data, data); - break; - } - case ov::element::i8: { - TReg vec_data(data.getIdx()); - fcvtms(vec_data.s, vec_data.s); - xtn(vec_data.h4, vec_data.s4); - xtn(vec_data.b8, vec_data.h8); - break; - } - case ov::element::u8: { - TReg vec_data(data.getIdx()); - fcvtmu(vec_data.s, vec_data.s); - xtn(vec_data.h4, vec_data.s4); - xtn(vec_data.b8, vec_data.h8); - break; - } - default: { - OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); - } - } + case ov::element::f32: { + switch (dst_prc) { + case ov::element::f16: { + fcvt(Xbyak_aarch64::HReg(data.getIdx()), data); + break; + } + case ov::element::i32: { + fcvtns(data, data); + break; + } + case ov::element::i8: { + TReg vec_data(data.getIdx()); + fcvtms(vec_data.s, vec_data.s); + xtn(vec_data.h4, vec_data.s4); + xtn(vec_data.b8, vec_data.h8); + break; + } + case ov::element::u8: { + TReg vec_data(data.getIdx()); + fcvtmu(vec_data.s, vec_data.s); + xtn(vec_data.h4, vec_data.s4); + xtn(vec_data.b8, vec_data.h8); break; } default: { - OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); + OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + + src_prc.to_string()); + } } - } - } - - switch (dst_prc) { - case ov::element::f16: { - str(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); - break; - } - case ov::element::i32: - case ov::element::f32: { - str(data, Xbyak_aarch64::ptr(ptr, ptr_offset)); - break; - } - case ov::element::i8: - case ov::element::u8: { - str(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); break; } default: { - OPENVINO_THROW("dst_prc " + src_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); + OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string()); } + } + } + + switch (dst_prc) { + case ov::element::f16: { + str(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + case ov::element::i32: + case ov::element::f32: { + str(data, Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + case ov::element::i8: + case ov::element::u8: { + str(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset)); + break; + } + default: { + OPENVINO_THROW("dst_prc " + src_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string()); + } } } struct EltwiseEmitterContext { std::shared_ptr emitter; - dnnl::impl::cpu::aarch64::jit_generator *host; + dnnl::impl::cpu::aarch64::jit_generator* host; dnnl::impl::cpu::aarch64::cpu_isa_t host_isa; const EltwiseData& opData; ov::element::Type exec_prc; }; -template +template struct EltwiseEmitter { void operator()(EltwiseEmitterContext& ctx) { ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, ctx.exec_prc); } }; -template<> +template <> struct EltwiseEmitter { void operator()(EltwiseEmitterContext& ctx) { - ctx.emitter = std::make_shared(ctx.host, - ctx.host_isa, - ctx.opData.alpha, - ctx.exec_prc); + ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, ctx.opData.alpha, ctx.exec_prc); } }; -template<> +template <> struct EltwiseEmitter { void operator()(EltwiseEmitterContext& ctx) { - ctx.emitter = std::make_shared(ctx.host, - ctx.host_isa, - ctx.opData.alpha, - ctx.exec_prc); + ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, ctx.opData.alpha, ctx.exec_prc); } }; -template<> +template <> struct EltwiseEmitter { void operator()(EltwiseEmitterContext& ctx) { ctx.emitter = std::make_shared(ctx.host, @@ -606,7 +621,7 @@ struct EltwiseEmitter { } }; -template<> +template <> struct EltwiseEmitter { void operator()(EltwiseEmitterContext& ctx) { ctx.emitter = std::make_shared(ctx.host, @@ -618,7 +633,7 @@ struct EltwiseEmitter { } }; -template<> +template <> struct EltwiseEmitter { void operator()(EltwiseEmitterContext& ctx) { ctx.emitter = std::make_shared(ctx.host, @@ -630,57 +645,56 @@ struct EltwiseEmitter { }; template -std::shared_ptr jit_uni_eltwise_generic::create_eltwise_emitter(const EltwiseData& data, const ov::element::Type& exec_prec) { - EltwiseEmitterContext ctx = { - nullptr, - this, - isa, - data, - exec_prec - }; - - OV_SWITCH(intel_cpu, EltwiseEmitter, ctx, data.algo, - OV_CASE(Algorithm::EltwiseAbs, ov::intel_cpu::aarch64::jit_abs_emitter), - OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter), - OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter), - OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter), - OV_CASE(Algorithm::EltwiseElu, ov::intel_cpu::aarch64::jit_elu_emitter), - OV_CASE(Algorithm::EltwiseEqual, ov::intel_cpu::aarch64::jit_equal_emitter), - OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter), - OV_CASE(Algorithm::EltwiseFloor, ov::intel_cpu::aarch64::jit_floor_emitter), - OV_CASE(Algorithm::EltwiseFloorMod, ov::intel_cpu::aarch64::jit_floor_mod_emitter), - OV_CASE(Algorithm::EltwiseCeiling, ov::intel_cpu::aarch64::jit_ceiling_emitter), - OV_CASE(Algorithm::EltwiseHswish, ov::intel_cpu::aarch64::jit_hswish_emitter), - OV_CASE(Algorithm::EltwiseIsFinite, ov::intel_cpu::aarch64::jit_is_finite_emitter), - OV_CASE(Algorithm::EltwiseIsInf, ov::intel_cpu::aarch64::jit_is_inf_emitter), - OV_CASE(Algorithm::EltwiseLessEqual, ov::intel_cpu::aarch64::jit_less_equal_emitter), - OV_CASE(Algorithm::EltwiseLogicalAnd, ov::intel_cpu::aarch64::jit_logical_and_emitter), - OV_CASE(Algorithm::EltwiseLogicalOr, ov::intel_cpu::aarch64::jit_logical_or_emitter), - OV_CASE(Algorithm::EltwiseLogicalNot, ov::intel_cpu::aarch64::jit_logical_not_emitter), - OV_CASE(Algorithm::EltwiseLogicalXor, ov::intel_cpu::aarch64::jit_logical_xor_emitter), - OV_CASE(Algorithm::EltwiseIsNaN, ov::intel_cpu::aarch64::jit_is_nan_emitter), - OV_CASE(Algorithm::EltwiseMaximum, ov::intel_cpu::aarch64::jit_maximum_emitter), - OV_CASE(Algorithm::EltwiseMinimum, ov::intel_cpu::aarch64::jit_minimum_emitter), - OV_CASE(Algorithm::EltwiseMish, ov::intel_cpu::aarch64::jit_mish_emitter), - OV_CASE(Algorithm::EltwiseGeluErf, ov::intel_cpu::aarch64::jit_gelu_erf_emitter), - OV_CASE(Algorithm::EltwiseGeluTanh, ov::intel_cpu::aarch64::jit_gelu_tanh_emitter), - OV_CASE(Algorithm::EltwiseGreater, ov::intel_cpu::aarch64::jit_greater_emitter), - OV_CASE(Algorithm::EltwiseGreaterEqual, ov::intel_cpu::aarch64::jit_greater_equal_emitter), - OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter), - OV_CASE(Algorithm::EltwiseMod, ov::intel_cpu::aarch64::jit_mod_emitter), - OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter), - OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter), - OV_CASE(Algorithm::EltwisePrelu, ov::intel_cpu::aarch64::jit_prelu_emitter), - OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, ov::intel_cpu::aarch64::jit_round_half_away_from_zero_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfToEven, ov::intel_cpu::aarch64::jit_round_half_to_even_emitter), - OV_CASE(Algorithm::EltwiseSelect, ov::intel_cpu::aarch64::jit_select_emitter), - OV_CASE(Algorithm::EltwiseSigmoid, ov::intel_cpu::aarch64::jit_sigmoid_emitter), - OV_CASE(Algorithm::EltwiseSoftSign, ov::intel_cpu::aarch64::jit_soft_sign_emitter), - OV_CASE(Algorithm::EltwiseSqrt, ov::intel_cpu::aarch64::jit_sqrt_emitter), - OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter), - OV_CASE(Algorithm::EltwiseSwish, ov::intel_cpu::aarch64::jit_swish_emitter), - OV_CASE(Algorithm::EltwiseTanh, ov::intel_cpu::aarch64::jit_tanh_emitter)); +std::shared_ptr jit_uni_eltwise_generic::create_eltwise_emitter(const EltwiseData& data, + const ov::element::Type& exec_prec) { + EltwiseEmitterContext ctx = {nullptr, this, isa, data, exec_prec}; + + OV_SWITCH( + intel_cpu, + EltwiseEmitter, + ctx, + data.algo, + OV_CASE(Algorithm::EltwiseAbs, ov::intel_cpu::aarch64::jit_abs_emitter), + OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter), + OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter), + OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter), + OV_CASE(Algorithm::EltwiseElu, ov::intel_cpu::aarch64::jit_elu_emitter), + OV_CASE(Algorithm::EltwiseEqual, ov::intel_cpu::aarch64::jit_equal_emitter), + OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter), + OV_CASE(Algorithm::EltwiseFloor, ov::intel_cpu::aarch64::jit_floor_emitter), + OV_CASE(Algorithm::EltwiseFloorMod, ov::intel_cpu::aarch64::jit_floor_mod_emitter), + OV_CASE(Algorithm::EltwiseCeiling, ov::intel_cpu::aarch64::jit_ceiling_emitter), + OV_CASE(Algorithm::EltwiseHswish, ov::intel_cpu::aarch64::jit_hswish_emitter), + OV_CASE(Algorithm::EltwiseIsFinite, ov::intel_cpu::aarch64::jit_is_finite_emitter), + OV_CASE(Algorithm::EltwiseIsInf, ov::intel_cpu::aarch64::jit_is_inf_emitter), + OV_CASE(Algorithm::EltwiseLessEqual, ov::intel_cpu::aarch64::jit_less_equal_emitter), + OV_CASE(Algorithm::EltwiseLogicalAnd, ov::intel_cpu::aarch64::jit_logical_and_emitter), + OV_CASE(Algorithm::EltwiseLogicalOr, ov::intel_cpu::aarch64::jit_logical_or_emitter), + OV_CASE(Algorithm::EltwiseLogicalNot, ov::intel_cpu::aarch64::jit_logical_not_emitter), + OV_CASE(Algorithm::EltwiseLogicalXor, ov::intel_cpu::aarch64::jit_logical_xor_emitter), + OV_CASE(Algorithm::EltwiseIsNaN, ov::intel_cpu::aarch64::jit_is_nan_emitter), + OV_CASE(Algorithm::EltwiseMaximum, ov::intel_cpu::aarch64::jit_maximum_emitter), + OV_CASE(Algorithm::EltwiseMinimum, ov::intel_cpu::aarch64::jit_minimum_emitter), + OV_CASE(Algorithm::EltwiseMish, ov::intel_cpu::aarch64::jit_mish_emitter), + OV_CASE(Algorithm::EltwiseGeluErf, ov::intel_cpu::aarch64::jit_gelu_erf_emitter), + OV_CASE(Algorithm::EltwiseGeluTanh, ov::intel_cpu::aarch64::jit_gelu_tanh_emitter), + OV_CASE(Algorithm::EltwiseGreater, ov::intel_cpu::aarch64::jit_greater_emitter), + OV_CASE(Algorithm::EltwiseGreaterEqual, ov::intel_cpu::aarch64::jit_greater_equal_emitter), + OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter), + OV_CASE(Algorithm::EltwiseMod, ov::intel_cpu::aarch64::jit_mod_emitter), + OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter), + OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter), + OV_CASE(Algorithm::EltwisePrelu, ov::intel_cpu::aarch64::jit_prelu_emitter), + OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, ov::intel_cpu::aarch64::jit_round_half_away_from_zero_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfToEven, ov::intel_cpu::aarch64::jit_round_half_to_even_emitter), + OV_CASE(Algorithm::EltwiseSelect, ov::intel_cpu::aarch64::jit_select_emitter), + OV_CASE(Algorithm::EltwiseSigmoid, ov::intel_cpu::aarch64::jit_sigmoid_emitter), + OV_CASE(Algorithm::EltwiseSoftSign, ov::intel_cpu::aarch64::jit_soft_sign_emitter), + OV_CASE(Algorithm::EltwiseSqrt, ov::intel_cpu::aarch64::jit_sqrt_emitter), + OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter), + OV_CASE(Algorithm::EltwiseSwish, ov::intel_cpu::aarch64::jit_swish_emitter), + OV_CASE(Algorithm::EltwiseTanh, ov::intel_cpu::aarch64::jit_tanh_emitter)); if (!ctx.emitter) OPENVINO_THROW("Unsupported operation type '" + algToString(data.algo) + "' for Eltwise emitter"); @@ -746,16 +760,16 @@ void jit_uni_eltwise_generic::apply_post_ops() { namespace { -template +template struct SupportedPrecisions { - void operator()(std::set> &precisions) { + void operator()(std::set>& precisions) { precisions = T::get_supported_precisions(); } }; static void set_intersection(const std::set>& precisions1, - const std::set>& precisions2, - std::set>& intersection) { + const std::set>& precisions2, + std::set>& intersection) { std::map intersection_types; for (auto it1 = precisions1.begin(); it1 != precisions1.end(); ++it1) { @@ -773,7 +787,7 @@ static void set_intersection(const std::set>& precisi intersection.insert(std::vector(it->second, it->first)); } } -} // namespace +} // namespace ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_number, const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS], @@ -792,16 +806,14 @@ ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_nu supported_precision_intersection = prcs_intersect; } - static const element::Type exec_precisions_priority[] = { - element::f16, - element::f32 - }; + static const element::Type exec_precisions_priority[] = {element::f16, element::f32}; for (const auto prc : exec_precisions_priority) { - if (std::any_of( - supported_precision_intersection.begin(), - supported_precision_intersection.end(), - [&prc](const std::vector& precisions) { return std::find(precisions.begin(), precisions.end(), prc) != precisions.end(); })) { + if (std::any_of(supported_precision_intersection.begin(), + supported_precision_intersection.end(), + [&prc](const std::vector& precisions) { + return std::find(precisions.begin(), precisions.end(), prc) != precisions.end(); + })) { exec_prc = prc; break; } @@ -824,48 +836,51 @@ ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_nu std::set> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) { std::set> precisions; - OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo, - OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter), - OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter), - OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), - OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter), - OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), - OV_CASE(Algorithm::EltwiseElu, jit_elu_emitter), - OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), - OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), - OV_CASE(Algorithm::EltwiseFloor, jit_floor_emitter), - OV_CASE(Algorithm::EltwiseFloorMod, jit_floor_mod_emitter), - OV_CASE(Algorithm::EltwiseCeiling, jit_ceiling_emitter), - OV_CASE(Algorithm::EltwiseGeluErf, jit_gelu_erf_emitter), - OV_CASE(Algorithm::EltwiseGeluTanh, jit_gelu_tanh_emitter), - OV_CASE(Algorithm::EltwiseGreater, jit_greater_emitter), - OV_CASE(Algorithm::EltwiseGreaterEqual, jit_greater_equal_emitter), - OV_CASE(Algorithm::EltwiseHswish, jit_hswish_emitter), - OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter), - OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter), - OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter), - OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter), - OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter), - OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter), - OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter), - OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter), - OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter), - OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter), - OV_CASE(Algorithm::EltwiseMish, jit_mish_emitter), - OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter), - OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), - OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), - OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), - OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, jit_round_half_away_from_zero_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfToEven, jit_round_half_to_even_emitter), - OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), - OV_CASE(Algorithm::EltwiseSigmoid, jit_sigmoid_emitter), - OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter), - OV_CASE(Algorithm::EltwiseSqrt, jit_sqrt_emitter), - OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter), - OV_CASE(Algorithm::EltwiseSwish, jit_swish_emitter), - OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter)); + OV_SWITCH(intel_cpu, + SupportedPrecisions, + precisions, + algo, + OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter), + OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter), + OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), + OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter), + OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), + OV_CASE(Algorithm::EltwiseElu, jit_elu_emitter), + OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), + OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), + OV_CASE(Algorithm::EltwiseFloor, jit_floor_emitter), + OV_CASE(Algorithm::EltwiseFloorMod, jit_floor_mod_emitter), + OV_CASE(Algorithm::EltwiseCeiling, jit_ceiling_emitter), + OV_CASE(Algorithm::EltwiseGeluErf, jit_gelu_erf_emitter), + OV_CASE(Algorithm::EltwiseGeluTanh, jit_gelu_tanh_emitter), + OV_CASE(Algorithm::EltwiseGreater, jit_greater_emitter), + OV_CASE(Algorithm::EltwiseGreaterEqual, jit_greater_equal_emitter), + OV_CASE(Algorithm::EltwiseHswish, jit_hswish_emitter), + OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter), + OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter), + OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter), + OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter), + OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter), + OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter), + OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter), + OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter), + OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter), + OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter), + OV_CASE(Algorithm::EltwiseMish, jit_mish_emitter), + OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter), + OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), + OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), + OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), + OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, jit_round_half_away_from_zero_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfToEven, jit_round_half_to_even_emitter), + OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), + OV_CASE(Algorithm::EltwiseSigmoid, jit_sigmoid_emitter), + OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter), + OV_CASE(Algorithm::EltwiseSqrt, jit_sqrt_emitter), + OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter), + OV_CASE(Algorithm::EltwiseSwish, jit_swish_emitter), + OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter)); if (precisions.empty()) OPENVINO_THROW("Unsupported operation type for Eltwise emitter"); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index 8f18a9815b4fe4..1bf64d096e4a84 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -4,10 +4,11 @@ #pragma once -#include - #include + #include +#include + #include "nodes/executors/eltwise.hpp" // TODO: handle x64 headers more accurate and remove undef later @@ -24,12 +25,11 @@ #include -#include "utils/general_utils.h" -#include "utils/cpu_utils.hpp" - -#include "emitters/plugin/aarch64/jit_emitter.hpp" #include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp" +#include "emitters/plugin/aarch64/jit_emitter.hpp" #include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp" +#include "utils/cpu_utils.hpp" +#include "utils/general_utils.h" namespace ov { namespace intel_cpu { @@ -154,7 +154,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { OPENVINO_THROW("source vector ptr register " + std::to_string(idx) + " is not supported"); } - static const std::vector src_gprs = { 19, 20, 21, 22, 25, 26, 27 }; + static const std::vector src_gprs = {19, 20, 21, 22, 25, 26, 27}; return XReg(src_gprs[idx]); } @@ -192,8 +192,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // 24 | src // 25-31 | [not used] - - TReg vmm_dst {9}; + TReg vmm_dst{9}; inline TReg get_vmm_reg(const uint32_t idx) { if (idx > MAX_ELTWISE_INPUTS) { @@ -230,10 +229,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { const int32_t ptr_offset = 0); void store_vector(const XReg& ptr, - const TReg& data, - const ov::element::Type& src_prc, - const ov::element::Type& dst_prc, - const int32_t ptr_offset = 0); + const TReg& data, + const ov::element::Type& src_prc, + const ov::element::Type& dst_prc, + const int32_t ptr_offset = 0); void store_scalar(const XReg& ptr, const SReg& data, @@ -264,6 +263,6 @@ class eltwise_precision_helper { static std::set> get_supported_precisions(const Algorithm& algo); }; -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp index 0ae6a4ad2c45ff..28e17854f46b08 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp @@ -6,102 +6,95 @@ namespace ov { namespace intel_cpu { - GemmKernel::GemmKernel(size_t M, - size_t N, - size_t K, - bool b_transposed, - ov::element::Type inType) +GemmKernel::GemmKernel(size_t M, size_t N, size_t K, bool b_transposed, ov::element::Type inType) : M(M), N(N), K(K), b_transposed(b_transposed) { - if (!one_of(inType, ov::element::f32, ov::element::f16, ov::element::bf16)) - THROW_ERROR("brgemm kernel only supports bf16, f16 and f32"); - - if (inType == ov::element::f32) { - format = arm_compute::Format::F32; - } else if (inType == ov::element::f16) { - format = arm_compute::Format::F16; - } else if (inType == ov::element::bf16) { - format = arm_compute::Format::BFLOAT16; - } - + if (!one_of(inType, ov::element::f32, ov::element::f16, ov::element::bf16)) + THROW_ERROR("brgemm kernel only supports bf16, f16 and f32"); + + if (inType == ov::element::f32) { + format = arm_compute::Format::F32; + } else if (inType == ov::element::f16) { + format = arm_compute::Format::F16; + } else if (inType == ov::element::bf16) { + format = arm_compute::Format::BFLOAT16; + } - aclGemmKernel = std::make_unique(); + aclGemmKernel = std::make_unique(); +} + +arm_compute::Status GemmKernel::executeGemm(void* a, + void* b, + arm_compute::TensorInfo& dstInfo, + arm_compute::Tensor& dstTensor, + arm_compute::Strides aStrides, + arm_compute::Strides bStrides, + void* c, + float alpha, + float beta, + arm_compute::Strides* outStrides, + void* out) { + aInfo.init(shapeCast({M, N}), + format, + aStrides, + size_t(0), + (size_t)(M * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format)))); + + arm_compute::TensorShape bShape; + if (b_transposed) + bShape = shapeCast({K, N}); + else + bShape = shapeCast({N, K}); + + bInfo.init(bShape, + format, + bStrides, + size_t(0), + (size_t)(K * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format)))); + + aTensor.allocator()->init(aInfo); + bTensor.allocator()->init(bInfo); + + if (c != nullptr) { + cInfo.init(shapeCast({M, K}), format); + cTensor.allocator()->init(cInfo); } - arm_compute::Status GemmKernel::executeGemm(void *a, - void *b, - arm_compute::TensorInfo& dstInfo, - arm_compute::Tensor& dstTensor, - arm_compute::Strides aStrides, - arm_compute::Strides bStrides, - void *c, - float alpha, - float beta, - arm_compute::Strides* outStrides, - void* out) { - aInfo.init( - shapeCast({M, N}), + if (outStrides != nullptr) + dstInfo.init( + shapeCast({M, K}), format, - aStrides, + *outStrides, size_t(0), - (size_t)(M * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format)))); - - arm_compute::TensorShape bShape; - if (b_transposed) - bShape = shapeCast({K, N}); - else - bShape = shapeCast({N, K}); - - bInfo.init( - bShape, - format, - bStrides, - size_t(0), - (size_t)(K * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format)))); - - aTensor.allocator()->init(aInfo); - bTensor.allocator()->init(bInfo); - - if (c != nullptr) { - cInfo.init(shapeCast({M, K}), format); - cTensor.allocator()->init(cInfo); - } - - if (outStrides != nullptr) - dstInfo.init( - shapeCast({M, K}), - format, - *outStrides, - size_t(0), - (size_t)(M * K * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format)))); - else - dstInfo.init(shapeCast({M, K}), format); - - dstTensor.allocator()->init(dstInfo); - - aTensor.allocator()->import_memory(reinterpret_cast(a)); - bTensor.allocator()->import_memory(reinterpret_cast(b)); - cTensor.allocator()->import_memory(reinterpret_cast(c)); - - if (out == nullptr) - dstTensor.allocator()->allocate(); - else - dstTensor.allocator()->import_memory(out); - - if (b_transposed) - aclGemmInfo.set_pretranspose_B(true); - - auto status = aclGemmKernel->validate(&aInfo, &bInfo, &cInfo, &dstInfo, 1.0, 0.0, aclGemmInfo); - - if (c == nullptr) - aclGemmKernel->configure(&aTensor, &bTensor, nullptr, &dstTensor, alpha, beta, aclGemmInfo); - else - aclGemmKernel->configure(&aTensor, &bTensor, &cTensor, &dstTensor, alpha, beta, aclGemmInfo); - aclGemmKernel->run(); - - return status; - } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file + (size_t)(M * K * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format)))); + else + dstInfo.init(shapeCast({M, K}), format); + + dstTensor.allocator()->init(dstInfo); + + aTensor.allocator()->import_memory(reinterpret_cast(a)); + bTensor.allocator()->import_memory(reinterpret_cast(b)); + cTensor.allocator()->import_memory(reinterpret_cast(c)); + + if (out == nullptr) + dstTensor.allocator()->allocate(); + else + dstTensor.allocator()->import_memory(out); + + if (b_transposed) + aclGemmInfo.set_pretranspose_B(true); + + auto status = aclGemmKernel->validate(&aInfo, &bInfo, &cInfo, &dstInfo, 1.0, 0.0, aclGemmInfo); + + if (c == nullptr) + aclGemmKernel->configure(&aTensor, &bTensor, nullptr, &dstTensor, alpha, beta, aclGemmInfo); + else + aclGemmKernel->configure(&aTensor, &bTensor, &cTensor, &dstTensor, alpha, beta, aclGemmInfo); + aclGemmKernel->run(); + + return status; +} +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp index 620f42f239cbbb..06a26743e0b2a4 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp @@ -4,21 +4,17 @@ #pragma once #include #include -#include "nodes/executors/acl/acl_utils.hpp" -#include "utils/general_utils.h" -#include "arm_compute/runtime/NEON/NEFunctions.h" #include "arm_compute/core/Types.h" +#include "arm_compute/runtime/NEON/NEFunctions.h" +#include "nodes/executors/acl/acl_utils.hpp" +#include "utils/general_utils.h" namespace ov { namespace intel_cpu { class GemmKernel { public: - GemmKernel(size_t M, - size_t N, - size_t K, - bool b_transposed = false, - ov::element::Type inType = ov::element::f32); + GemmKernel(size_t M, size_t N, size_t K, bool b_transposed = false, ov::element::Type inType = ov::element::f32); arm_compute::Status executeGemm(void* a, void* b, @@ -48,5 +44,5 @@ class GemmKernel { arm_compute::GEMMInfo aclGemmInfo; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.cpp index 746b556314ce3a..f63981078616d5 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.cpp @@ -1,20 +1,19 @@ // Copyright (C) 2020-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 - #include "convert_group_conv.hpp" #include - -#include "openvino/opsets/opset1.hpp" #include + #include "openvino/core/rt_info.hpp" +#include "openvino/opsets/opset1.hpp" ov::intel_cpu::ConvertGroupConvolution::ConvertGroupConvolution() { auto gconv = ov::pass::pattern::wrap_type(); ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { - enum Inputs {Data, Weights}; + enum Inputs { Data, Weights }; auto gconv = std::dynamic_pointer_cast(m.get_match_root()); if (!gconv) { return false; @@ -26,31 +25,32 @@ ov::intel_cpu::ConvertGroupConvolution::ConvertGroupConvolution() { // Weights layout GOIYX int64_t groups = gconv->get_input_shape(Inputs::Weights)[0]; - if (data_shape[channel_axis].is_dynamic() || - output_shape[channel_axis].is_dynamic()) { + if (data_shape[channel_axis].is_dynamic() || output_shape[channel_axis].is_dynamic()) { return false; } if (groups == data_shape[channel_axis].get_length() && - groups == output_shape[channel_axis].get_length()) { // depthwise case + groups == output_shape[channel_axis].get_length()) { // depthwise case return false; } ov::NodeVector replace_nodes; - auto split_weights = std::make_shared(gconv->input_value(Inputs::Weights), - ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0}), - groups); + auto split_weights = std::make_shared( + gconv->input_value(Inputs::Weights), + ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0}), + groups); replace_nodes.push_back(split_weights); - auto axis = ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {channel_axis}); + auto axis = ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {channel_axis}); auto split = std::make_shared(gconv->input_value(Inputs::Data), axis, groups); replace_nodes.push_back(split); ov::NodeVector concat_inputs; for (int64_t g = 0; g < groups; g++) { auto out = split->output(g); - auto filter = std::make_shared(split_weights->output(g), - ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0})); + auto filter = std::make_shared( + split_weights->output(g), + ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0})); auto conv = std::make_shared(out, filter, gconv->get_strides(), diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.hpp index eadbf5ff9cd7e2..55c1ecde2aae10 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.hpp @@ -3,8 +3,8 @@ #pragma once -#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" /* * Description: @@ -13,7 +13,7 @@ * equals to number of groups. * * Before: - * + * * +--------------+ +---------------+ * | Input tensor | | Kernel tensor | * +-----------+--+ +-+-------------+ @@ -25,9 +25,9 @@ * +------v------+ * | Result | * +-------------+ - * + * * After: - * + * * +--------------+ +--------------+ +---------------+ +--------------+ * | Input tensor | | Constant (1) | | Kernel tensor | | Constant (0) | * +-----------+--+ +-+------------+ +-----------+---+ +-+------------+ @@ -53,13 +53,13 @@ * +----------v------------+ * | Result | * +-----------------------+ - * + * */ namespace ov { namespace intel_cpu { -class ConvertGroupConvolution: public ov::pass::MatcherPass { +class ConvertGroupConvolution : public ov::pass::MatcherPass { public: OPENVINO_RTTI("ConvertGroupConvolution", "0"); ConvertGroupConvolution(); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.cpp index 72228d89cb6ab2..94771fa57ed2cd 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.cpp @@ -1,14 +1,13 @@ // Copyright (C) 2020-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 - #include "convert_group_conv1d.hpp" #include - -#include "openvino/opsets/opset1.hpp" #include + #include "openvino/core/rt_info.hpp" +#include "openvino/opsets/opset1.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" template @@ -26,17 +25,20 @@ ov::matcher_pass_callback ov::intel_cpu::ConvertConv1DBase::convert_conv1d_to_co return false; } - auto input = conv->input_value(0); + auto input = conv->input_value(0); auto weights = conv->input_value(1); auto weights2d_shape = weights.get_shape(); weights2d_shape.push_back(1); - auto w_shape = std::make_shared(ov::element::i64, ov::Shape{weights2d_shape.size()}, weights2d_shape); + auto w_shape = std::make_shared(ov::element::i64, + ov::Shape{weights2d_shape.size()}, + weights2d_shape); auto getUnsqueeze = [&](const ov::Output& node) { auto rank = node.get_partial_shape().rank().get_length(); - return std::make_shared(node, - ov::opset1::Constant::create(ov::element::i64, ov::Shape{1}, {rank})); + return std::make_shared( + node, + ov::opset1::Constant::create(ov::element::i64, ov::Shape{1}, {rank})); }; auto input2d = getUnsqueeze(input); @@ -63,16 +65,16 @@ ov::matcher_pass_callback ov::intel_cpu::ConvertConv1DBase::convert_conv1d_to_co ov::intel_cpu::ConvertConv1D::ConvertConv1D() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::any_input()}), - "ConvertConvolutionToArm"); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::any_input()}), + "ConvertConvolutionToArm"); register_matcher(m, convert_conv1d_to_conv2d()); } ov::intel_cpu::ConvertGroupConv1D::ConvertGroupConv1D() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::any_input()}), - "ConvertGroupConvolutionToArm"); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::any_input()}), + "ConvertGroupConvolutionToArm"); register_matcher(m, convert_conv1d_to_conv2d()); } \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp index f1edd9d363df10..5674514eeb8e64 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp @@ -9,11 +9,11 @@ * Description: * ConvertConv1DBase detects 1D Convolution / GroupConvolution and replaces * it with the sequence Unsqueeze - 2D Convolution / GroupConvolution - Squeeze. - * Unsqueeze adds the additional dimension to Convolution inputs and Squeeze + * Unsqueeze adds the additional dimension to Convolution inputs and Squeeze * removes the additional dimension from the Convolution output. * * Before: - * + * * +--------------+ +---------------+ * | Input tensor | | Kernel tensor | * +-----------+--+ +-+-------------+ @@ -25,9 +25,9 @@ * +------v------+ * | Result | * +-------------+ - * + * * After: - * + * * +--------------+ +--------------+ +---------------+ +--------------+ * | Input tensor | | Constant (1) | | Kernel tensor | | Constant (1) | * +-----------+--+ +-+------------+ +-----------+---+ +-+------------+ @@ -47,25 +47,25 @@ * +------v------+ * | Result | * +-------------+ - * + * */ namespace ov { namespace intel_cpu { -class ConvertConv1DBase: public ov::pass::MatcherPass { +class ConvertConv1DBase : public ov::pass::MatcherPass { protected: OPENVINO_RTTI("ConvertConv1DBase", "0"); template ov::matcher_pass_callback convert_conv1d_to_conv2d(); }; -class ConvertConv1D: public ConvertConv1DBase { +class ConvertConv1D : public ConvertConv1DBase { public: OPENVINO_RTTI("ConvertConv1D", "0"); ConvertConv1D(); }; -class ConvertGroupConv1D: public ConvertConv1DBase { +class ConvertGroupConv1D : public ConvertConv1DBase { public: OPENVINO_RTTI("ConvertGroupConv1D", "0"); ConvertGroupConv1D(); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp index ff5632cb0a5e8f..287aad96b14c08 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp @@ -1,7 +1,6 @@ // Copyright (C) 2020-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 - #include "convert_reduce_multi_axis.hpp" #include "openvino/core/rt_info.hpp" @@ -41,7 +40,7 @@ ov::matcher_pass_callback ov::intel_cpu::ConvertReduceMultiAxisBase::convert_red std::shared_ptr node = input0.get_node_shared_ptr(); auto output = input0; bool keepDims = reduce->get_keep_dims(); - //axes should be sorted in descending order if keepDims is false to be keep axis within data shape + // axes should be sorted in descending order if keepDims is false to be keep axis within data shape if (!keepDims) { sort(axes.begin(), axes.end(), std::greater()); } @@ -61,28 +60,32 @@ ov::matcher_pass_callback ov::intel_cpu::ConvertReduceMultiAxisBase::convert_red ov::intel_cpu::ConvertReduceProd::ConvertReduceProd() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::wrap_type()}), "ConvertReduceProd"); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::wrap_type()}), + "ConvertReduceProd"); register_matcher(m, convert_reduce()); } ov::intel_cpu::ConvertReduceMin::ConvertReduceMin() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::wrap_type()}), "ConvertReduceMin"); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::wrap_type()}), + "ConvertReduceMin"); register_matcher(m, convert_reduce()); } ov::intel_cpu::ConvertReduceMax::ConvertReduceMax() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::wrap_type()}), "ConvertReduceMax"); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::wrap_type()}), + "ConvertReduceMax"); register_matcher(m, convert_reduce()); } ov::intel_cpu::ConvertReduceSum::ConvertReduceSum() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::wrap_type()}), "ConvertReduceSum"); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::wrap_type()}), + "ConvertReduceSum"); register_matcher(m, convert_reduce()); } diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp index 38a1f8a9c0c601..8e5fd1e38b605a 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp @@ -3,8 +3,8 @@ #pragma once -#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" /* * Description: @@ -13,7 +13,7 @@ * is replaced with a sequence of single-axe Reduce operations. * * Before: - * + * * +--------------+ +-------------------+ * | Data | | Axes tensor [A,B] | * +-----------+--+ +-+-----------------+ @@ -25,9 +25,9 @@ * +------v------+ * | Result | * +-------------+ - * + * * After: - * + * * +-------------+ +---------------+ * | Data | | Axes scalar A | * +---------+---+ +----+----------+ @@ -43,44 +43,44 @@ * +-------v---------+ * | Result | * +-----------------+ - * + * */ namespace ov { namespace intel_cpu { -class ConvertReduceMultiAxisBase: public ov::pass::MatcherPass { +class ConvertReduceMultiAxisBase : public ov::pass::MatcherPass { public: OPENVINO_RTTI("ConvertReduceMultiAxisBase", "0"); template ov::matcher_pass_callback convert_reduce(); }; -class ConvertReduceProd: public ConvertReduceMultiAxisBase { +class ConvertReduceProd : public ConvertReduceMultiAxisBase { public: OPENVINO_RTTI("ConvertReduceProd", "0"); ConvertReduceProd(); }; -class ConvertReduceMin: public ConvertReduceMultiAxisBase { +class ConvertReduceMin : public ConvertReduceMultiAxisBase { public: OPENVINO_RTTI("ConvertReduceMin", "0"); ConvertReduceMin(); }; -class ConvertReduceMax: public ConvertReduceMultiAxisBase { +class ConvertReduceMax : public ConvertReduceMultiAxisBase { public: OPENVINO_RTTI("ConvertReduceMax", "0"); ConvertReduceMax(); }; -class ConvertReduceSum: public ConvertReduceMultiAxisBase { +class ConvertReduceSum : public ConvertReduceMultiAxisBase { public: OPENVINO_RTTI("ConvertReduceSum", "0"); ConvertReduceSum(); }; -class ConvertReduceMultiAxis: public ov::pass::GraphRewrite { +class ConvertReduceMultiAxis : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("ConvertReduceMultiAxis", "0"); ConvertReduceMultiAxis() { @@ -91,5 +91,5 @@ class ConvertReduceMultiAxis: public ov::pass::GraphRewrite { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.cpp index def7250dd5b938..50530159e0875e 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.cpp @@ -1,7 +1,6 @@ // Copyright (C) 2020-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 - #include "convert_reduce_no_keep_dims.hpp" #include "openvino/core/rt_info.hpp" @@ -29,9 +28,10 @@ ov::matcher_pass_callback ov::intel_cpu::ConvertReduceNoKeepDimsBase::convert_re template ov::intel_cpu::ConvertReduction::ConvertReduction() { auto m = std::make_shared( - ov::pass::pattern::wrap_type({ov::pass::pattern::any_input(), - ov::pass::pattern::wrap_type()}), "ConvertReduction"); - register_matcher(m, convert_reduce()); + ov::pass::pattern::wrap_type( + {ov::pass::pattern::any_input(), ov::pass::pattern::wrap_type()}), + "ConvertReduction"); + register_matcher(m, convert_reduce()); } template class ov::intel_cpu::ConvertReduction; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.hpp index 2f3c7d19726513..ea4128ea265e42 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.hpp @@ -3,10 +3,10 @@ #pragma once -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "openvino/pass/graph_rewrite.hpp" #include "openvino/op/util/arithmetic_reductions_keep_dims.hpp" #include "openvino/op/util/logical_reduction_keep_dims.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" /* * Description: @@ -15,7 +15,7 @@ * which removes undesired dimensions. * * Before: - * + * * +--------------+ +-----------------+ * | Data | | Axes tensor | * +-----------+--+ +-+---------------+ @@ -23,9 +23,9 @@ * +---------------------------+ * | Reduce (keepDims = false) | * +---------------------------+ - * + * * After: - * + * * +--------------+ +-----------------+ * | Data | | Axes tensor | * +-----------+--+ +-+------------+--+ @@ -37,13 +37,13 @@ * +--------v------v-+ * | Squeeze | * +-----------------+ - * + * */ namespace ov { namespace intel_cpu { -class ConvertReduceNoKeepDimsBase: public ov::pass::MatcherPass { +class ConvertReduceNoKeepDimsBase : public ov::pass::MatcherPass { public: OPENVINO_RTTI("ConvertReduceNoKeepDims", "0"); template @@ -51,14 +51,13 @@ class ConvertReduceNoKeepDimsBase: public ov::pass::MatcherPass { }; template -class ConvertReduction: public ConvertReduceNoKeepDimsBase { +class ConvertReduction : public ConvertReduceNoKeepDimsBase { public: OPENVINO_RTTI("ConvertReduction", "0"); ConvertReduction(); }; - -class ConvertReduceNoKeepDims: public ov::pass::GraphRewrite { +class ConvertReduceNoKeepDims : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("ConvertReduceNoKeepDims", "0"); ConvertReduceNoKeepDims() { @@ -67,5 +66,5 @@ class ConvertReduceNoKeepDims: public ov::pass::GraphRewrite { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.cpp index 76bc2c4f7c1df8..338bb1820fa7f4 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.cpp @@ -1,11 +1,10 @@ // Copyright (C) 2020-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 - #include "mish_decomposition.hpp" -#include "openvino/opsets/opset4.hpp" #include "openvino/core/rt_info.hpp" +#include "openvino/opsets/opset4.hpp" ov::intel_cpu::MishDecomposition::MishDecomposition() { auto mish = ov::pass::pattern::wrap_type(); @@ -17,7 +16,9 @@ ov::intel_cpu::MishDecomposition::MishDecomposition() { } auto exp = std::make_shared(mish->input_value(0)); - auto add = std::make_shared(exp, opset4::Constant::create(mish->get_output_element_type(0), ov::Shape{}, {1.0f})); + auto add = std::make_shared( + exp, + opset4::Constant::create(mish->get_output_element_type(0), ov::Shape{}, {1.0f})); auto log = std::make_shared(add); auto tanh = std::make_shared(log); auto mul = std::make_shared(mish->input_value(0), tanh); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.hpp index 0b93cea24053f5..75b45dca468dc7 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/mish_decomposition.hpp @@ -3,17 +3,17 @@ #pragma once -#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" namespace ov { namespace intel_cpu { -class MishDecomposition: public ov::pass::MatcherPass { +class MishDecomposition : public ov::pass::MatcherPass { public: OPENVINO_RTTI("MishDecomposition", "0"); MishDecomposition(); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp index c38d088ef95e7b..25b10d55ca8165 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp @@ -3,35 +3,32 @@ // #include "snippets_mark_skipped.hpp" -#include "snippets/pass/tokenization.hpp" +#include "itt.hpp" #include "snippets/op/subgraph.hpp" +#include "snippets/pass/tokenization.hpp" #include "snippets/utils/utils.hpp" - -#include "transformations/utils/utils.hpp" #include "transformations/utils.hpp" -#include "utils/general_utils.h" +#include "transformations/utils/utils.hpp" #include "utils/cpu_utils.hpp" - -#include "itt.hpp" - +#include "utils/general_utils.h" namespace ov { namespace intel_cpu { namespace { static const int DEFAULT_AXIS = 1; -NodeFusingType GetNodeFusingType(const std::shared_ptr &node) { - auto &rt = node->get_rt_info(); +NodeFusingType GetNodeFusingType(const std::shared_ptr& node) { + auto& rt = node->get_rt_info(); const auto rinfo = rt.find("MayBeFusedInPlugin"); if (rinfo == rt.end()) return NodeFusingType::NotSet; return rinfo->second.as(); } -void SetNodeFusingType(const std::shared_ptr &node, NodeFusingType nodeType) { - auto &rt = node->get_rt_info(); +void SetNodeFusingType(const std::shared_ptr& node, NodeFusingType nodeType) { + auto& rt = node->get_rt_info(); rt["MayBeFusedInPlugin"] = nodeType; } -std::vector getContinuableChains(const std::shared_ptr &node) { +std::vector getContinuableChains(const std::shared_ptr& node) { std::vector result; for (const auto& input : node->inputs()) { const auto parent = input.get_source_output().get_node_shared_ptr(); @@ -42,12 +39,12 @@ std::vector getContinuableChains(const std::shared_ptr &node) { +int getNumNonConstInputs(const std::shared_ptr& node) { int num_non_const_inputs = 0; - for (const auto &parent_out : node->input_values()) { + for (const auto& parent_out : node->input_values()) { const auto parent = parent_out.get_node_shared_ptr(); if (ov::is_type(parent)) { - for (const auto &grandparent_out : parent->input_values()) { + for (const auto& grandparent_out : parent->input_values()) { const auto grandparent = grandparent_out.get_node_shared_ptr(); if (!ov::is_type(grandparent)) num_non_const_inputs++; @@ -65,40 +62,35 @@ bool isFullyConnected(const std::shared_ptr& node) { const auto out_weights = node->input_value(1); const auto rank_a = out_activations.get_partial_shape().rank(); const auto rank_w = out_weights.get_partial_shape().rank(); - return out_weights.get_partial_shape().is_static() && - rank_a.is_static() && rank_w.is_static() && - rank_a.get_length() != 1 && rank_w.get_length() != 1 && - rank_w.get_length() <= 3 && + return out_weights.get_partial_shape().is_static() && rank_a.is_static() && rank_w.is_static() && + rank_a.get_length() != 1 && rank_w.get_length() != 1 && rank_w.get_length() <= 3 && ov::op::util::is_on_constant_path(out_weights); } -bool SupportsFusingWithConvolution_Simple(const std::shared_ptr &node) { +bool SupportsFusingWithConvolution_Simple(const std::shared_ptr& node) { // Note: some other operations support this fusing (SoftPlus, Sqrt). // Skip them here, when they are supported by Snippets ARM. Ticket: 141170. - return ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node); + return ov::is_type(node) || ov::is_type(node) || + ov::is_type(node) || ov::is_type(node) || + ov::is_type(node) || ov::is_type(node); } // Convolution is a special case, since it supports peculiar fusings -bool isSuitableConvolutionParent(const std::shared_ptr &node) { - const bool is_suitable_node = ov::is_type(node) || - ov::is_type(node); +bool isSuitableConvolutionParent(const std::shared_ptr& node) { + const bool is_suitable_node = + ov::is_type(node) || ov::is_type(node); // has a single output, connected to a single child const auto out = node->outputs(); const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1); return is_suitable_node && has_only_child; } -bool isSuitableBinaryConvolutionParent(const std::shared_ptr &node) { +bool isSuitableBinaryConvolutionParent(const std::shared_ptr& node) { const bool is_suitable_node = ov::is_type(node); // has a single output, connected to a single child const auto out = node->outputs(); const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1); return is_suitable_node && has_only_child; } -bool isSuitableMiscParent(const std::shared_ptr &node) { +bool isSuitableMiscParent(const std::shared_ptr& node) { const bool is_suitable_node = ov::is_type(node) || ov::is_type(node) || ov::is_type(node); @@ -108,35 +100,34 @@ bool isSuitableMiscParent(const std::shared_ptr &node) { return is_suitable_node && has_only_child; } // Matmul is a special case, since it supports simple + bias fusings -bool isSuitableMatMulParent(const std::shared_ptr &node) { +bool isSuitableMatMulParent(const std::shared_ptr& node) { const bool is_suitable_node = ov::is_type(node); // has a single output, connected to a single child const auto out = node->outputs(); const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1); return is_suitable_node && has_only_child; } -bool isSuitablePoolChild(const std::shared_ptr &node) { +bool isSuitablePoolChild(const std::shared_ptr& node) { const bool is_suitable_node = ov::is_type(node); // has a single output, connected to a single child const auto out = node->outputs(); const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1); return is_suitable_node && has_only_child; } -bool isSuitableChildForFusingSimple(const std::shared_ptr &node) { +bool isSuitableChildForFusingSimple(const std::shared_ptr& node) { // Note: Fusing child is allowed to have several users, but that must be the end of the chain return SupportsFusingWithConvolution_Simple(node) && getNumNonConstInputs(node) == 1; } -bool isSuitableChildForFusingBias(const std::shared_ptr &node, int fusingAxis) { +bool isSuitableChildForFusingBias(const std::shared_ptr& node, int fusingAxis) { if (!ov::is_type(node)) return false; - auto is_suitable_parent = [](const std::shared_ptr &node) { - return (ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node)); + auto is_suitable_parent = [](const std::shared_ptr& node) { + return (ov::is_type(node) || ov::is_type(node) || + ov::is_type(node)); }; - for (const auto &in : node->inputs()) { + for (const auto& in : node->inputs()) { const auto& parent_out = in.get_source_output(); const auto& parent = parent_out.get_node_shared_ptr(); const auto& parent_pshape = parent_out.get_partial_shape(); @@ -151,7 +142,8 @@ bool isSuitableChildForFusingBias(const std::shared_ptr &node, int f if (bias_pshape.is_dynamic()) break; const auto bias_shape_norm = getNormalizedDimsBySize(bias_pshape.get_shape(), parent_pshape.size()); - if (fusingAxis >= static_cast(bias_shape_norm.size()) || fusingAxis >= static_cast(parent_pshape.size()) || + if (fusingAxis >= static_cast(bias_shape_norm.size()) || + fusingAxis >= static_cast(parent_pshape.size()) || bias_shape_norm.size() != parent_pshape.size() || bias_shape_norm.size() < 2) break; if (parent_pshape[fusingAxis].is_dynamic()) @@ -165,21 +157,23 @@ bool isSuitableChildForFusingBias(const std::shared_ptr &node, int f } // Continue fusing chain of the passed type if the node has one child // Otherwise mark node as FusedTerminator (Fused, but fusing chain is interrupted) -void PropagateIfHasOnlyChild(const std::shared_ptr &node, NodeFusingType nodeType) { +void PropagateIfHasOnlyChild(const std::shared_ptr& node, NodeFusingType nodeType) { const auto out = node->outputs(); const bool has_only_child = out.size() == 1 && out[0].get_target_inputs().size() == 1; SetNodeFusingType(node, has_only_child ? nodeType : NodeFusingType::FusedTerminator); } -// todo: Skipping MultiSubGraphOp such as TensorIterator, Loop and If. Snippets might tokenize their bodies in the future. -// Note that the function is recurrent, since there might be multi-level MultiSubGraphOp, if(){if(){}}else{} for example. -void MarkSubgraphOpAsSkipped(const std::shared_ptr &node) { +// todo: Skipping MultiSubGraphOp such as TensorIterator, Loop and If. Snippets might tokenize their bodies in the +// future. +// Note that the function is recurrent, since there might be multi-level MultiSubGraphOp, if(){if(){}}else{} for +// example. +void MarkSubgraphOpAsSkipped(const std::shared_ptr& node) { if (ov::is_type(node)) { std::vector> models{}; // Covers TensorIterator and Loop if (auto s = ov::as_type_ptr(node)) { models.push_back(s->get_function()); - // Add new multi-body subgraph op here - } else if (auto if_op = ov::as_type_ptr(node)) { + // Add new multi-body subgraph op here + } else if (auto if_op = ov::as_type_ptr(node)) { models.push_back(if_op->get_then_body()); models.push_back(if_op->get_else_body()); } @@ -204,8 +198,8 @@ bool isSuitableConvert(const std::shared_ptr& node) { return true; }; auto isSuitableChild = [](const std::shared_ptr& node) { - for (const auto &out : node->outputs()) { - const auto &child = out.get_node_shared_ptr(); + for (const auto& out : node->outputs()) { + const auto& child = out.get_node_shared_ptr(); if (!ov::is_type(child)) return false; } @@ -215,16 +209,15 @@ bool isSuitableConvert(const std::shared_ptr& node) { } auto is_skipped_op(const std::shared_ptr& op) -> bool { - return ov::is_type(op) || - ov::is_type(op) || + return ov::is_type(op) || ov::is_type(op) || ov::is_type(op); } -} // namespace +} // namespace -bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { +bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr& m) { RUN_ON_MODEL_SCOPE(SnippetsMarkSkipped); int channelAxis = DEFAULT_AXIS; - for (auto &node : m->get_ordered_ops()) { + for (auto& node : m->get_ordered_ops()) { if (is_skipped_op(node)) continue; if (isSuitableConvolutionParent(node)) { @@ -255,14 +248,18 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { if (isSuitableChildForFusingBias(node, channelAxis)) { PropagateIfHasOnlyChild(node, fusingChainType); } else if (isSuitableChildForFusingSimple(node)) { -#if defined (OV_CPU_WITH_ACL) - if (one_of(fusingChainType, NodeFusingType::FusedWithConvolution, NodeFusingType::FusedWithBinaryConvolution)) { +#if defined(OV_CPU_WITH_ACL) + if (one_of(fusingChainType, + NodeFusingType::FusedWithConvolution, + NodeFusingType::FusedWithBinaryConvolution)) { PropagateIfHasOnlyChild(node, NodeFusingType::FusedTerminator); continue; } #endif PropagateIfHasOnlyChild(node, fusingChainType); - } else if (one_of(fusingChainType, NodeFusingType::FusedWithConvolution, NodeFusingType::FusedWithBinaryConvolution)) { + } else if (one_of(fusingChainType, + NodeFusingType::FusedWithConvolution, + NodeFusingType::FusedWithBinaryConvolution)) { if (isSuitablePoolChild(node)) { PropagateIfHasOnlyChild(node, fusingChainType); } @@ -279,5 +276,5 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { return true; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.hpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.hpp index 7fdc7244d21de2..2b17039d198bf3 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.hpp @@ -11,30 +11,34 @@ namespace intel_cpu { /** * @interface SnippetsMarkSkipped - * @brief Mark operations that should be ignored by snippets on tokenization stage. A typical example is eltwise operations - * that will be fused into convolutions on plugin side. + * @brief Mark operations that should be ignored by snippets on tokenization stage. A typical example is eltwise + * operations that will be fused into convolutions on plugin side. */ class SnippetsMarkSkipped : public ov::pass::ModelPass { public: OPENVINO_RTTI("SnippetsMarkSkipped", "0"); SnippetsMarkSkipped() : ModelPass() {} - bool run_on_model(const std::shared_ptr &) override; + bool run_on_model(const std::shared_ptr&) override; }; /* NotSet - not part of a fusing chain FusedTerminator - the node is fused, but the chain can't be continued FusedWithConvolution, FusedWithMisc - fusing chains with different continuation rules -IgnoredAfterInputs - node must be skipped, since can't be handled properly at this time. Also a continuable fusing chain. -Order of SnippetsNodeType is important!: +IgnoredAfterInputs - node must be skipped, since can't be handled properly at this time. Also a continuable fusing +chain. Order of SnippetsNodeType is important!: * SnippetsNodeType >= FusedTerminator is a Fused chain * SnippetsNodeType > FusedTerminator is a Fused chain that may be continued */ enum class NodeFusingType : int64_t { NotSet, FusedTerminator, - FusedWithConvolution, FusedWithBinaryConvolution, - FusedWithMatMul, FusedWithFC, FusedWithMisc}; + FusedWithConvolution, + FusedWithBinaryConvolution, + FusedWithMatMul, + FusedWithFC, + FusedWithMisc +}; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp index 967afe946e0793..a3c9a1c184d550 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp @@ -3,9 +3,10 @@ // #include "shape_inference.hpp" + #include "snippets/shape_inference/shape_infer_instances.hpp" -#include "transformations/snippets/common/op/fused_mul_add.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" +#include "transformations/snippets/common/op/fused_mul_add.hpp" namespace ov { namespace snippets { @@ -19,19 +20,31 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov return {}; } -#define SHAPE_INFER_PREDEFINED(OP, InferType) \ - { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared();} } -#define SHAPE_INFER_OP_SPECIFIC(OP) \ - { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} } -#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \ - { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} } +#define SHAPE_INFER_PREDEFINED(OP, InferType) \ + { \ + OP::get_type_info_static(), [](const std::shared_ptr& n) { \ + return std::make_shared(); \ + } \ + } +#define SHAPE_INFER_OP_SPECIFIC(OP) \ + { \ + OP::get_type_info_static(), [](const std::shared_ptr& n) { \ + return std::make_shared(n); \ + } \ + } +#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \ + { \ + OP::get_type_info_static(), [](const std::shared_ptr& n) { \ + return std::make_shared(n); \ + } \ + } -const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry { +const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry{ SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer), SHAPE_INFER_PREDEFINED(ov::intel_cpu::SwishNode, PassThroughShapeInfer), }; #undef SHAPE_INFER_OP_SPECIFIC #undef SHAPE_INFER_PREDEFINED -} // namespace snippets -} // namespace ov +} // namespace snippets +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.hpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.hpp index 3e3a417bd9164b..1028674a081102 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.hpp @@ -18,11 +18,12 @@ class CPUShapeInferSnippetsFactory : public IShapeInferSnippetsFactory { protected: /** - * @brief get shape infer instances for operations from backend-specific opset - * @return register ShapeInferPtr - */ - ShapeInferPtr get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, const std::shared_ptr& op) const override; + * @brief get shape infer instances for operations from backend-specific opset + * @return register ShapeInferPtr + */ + ShapeInferPtr get_specific_op_shape_infer(const ov::DiscreteTypeInfo& key, + const std::shared_ptr& op) const override; }; -} // namespace snippets -} // namespace ov +} // namespace snippets +} // namespace ov