From 888856902ac79eb1ab58b405d79d96f7a6b43037 Mon Sep 17 00:00:00 2001
From: ashwins990
Date: Fri, 10 Jan 2025 12:00:39 +0530
Subject: [PATCH] Aarch64 paged attention enablement (#27841)

This development is related to feature request: https://github.com/openvinotoolkit/openvino/issues/26422

## Benchmarking Results

Machine: Graviton 3 - 64 cores

#### vLLM serving benchmark on ShareGPT dataset

| Requests / sec | avg TTFT (sec) | avg TPOT (sec) | Output / Total Throughput (tokens/sec) |
| -------------- | -------------- | -------------- | -------------------------------------- |
| 0.2            | 1.153          | 0.186          | 38.73 / 77.79                          |
| 0.5            | 2.083          | 0.482          | 94.10 / 187.92                         |

#### vLLM Throughput benchmark on ShareGPT dataset

![plot_aarch64_sve](https://github.com/user-attachments/assets/54c49326-2500-4744-9b4b-72c61f1db2af)

## vLLM with OpenVINO backend

Clone the [vLLM repo](https://github.com/vllm-project/vllm).

Set the inference precision to f32 before model compilation by setting ```Execution Mode``` to ```ACCURACY```:

```
// file_path : vllm/model_executor/model_loader/openvino.py
import openvino.properties.hint as hints

ov_core.set_property(
    "CPU",
    {hints.execution_mode: hints.ExecutionMode.ACCURACY},
)
ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
self.ov_request = ov_compiled.create_infer_request()
```

Note: If the inference precision is not forced to f32, execution takes the f16 precision path. On AArch64 this can lead to a segmentation fault because of the optional Alibi parameter in the transformed graph; optional variables are graph nodes with empty shapes.

After the above change, follow this [link](https://docs.vllm.ai/en/v0.6.3.post1/getting_started/openvino-installation.html) and install vLLM from source.
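As an alternative to the ```ACCURACY``` execution mode, a minimal sketch of pinning the CPU inference precision to f32 via the standard OpenVINO properties API is shown below. This is an illustration only, not part of this patch; the property names come from the regular OpenVINO Python API, and the model/loader wiring is assumed to follow the snippet above.

```
# Sketch (assumption): force f32 inference precision on CPU directly,
# instead of (or in addition to) setting ExecutionMode.ACCURACY.
import openvino as ov
import openvino.properties.hint as hints

ov_core = ov.Core()
ov_core.set_property("CPU", {hints.inference_precision: ov.Type.f32})

# After compiling a model, the effective precision can be read back as a
# sanity check (hypothetical usage, adapt to the vLLM loader above):
# ov_compiled = ov_core.compile_model(pt_model.model, "CPU")
# print(ov_compiled.get_property(hints.inference_precision))
```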
--- .../compile_flags/os_flags.cmake | 2 +- src/plugins/intel_cpu/CMakeLists.txt | 6 +- .../nodes/kernels/aarch64/brgemm_kernel.cpp | 333 ++++++++++++++++++ .../nodes/kernels/aarch64/brgemm_kernel.hpp | 105 ++++++ .../kernels/scaled_attn/attn_quant_kernel.hpp | 28 +- .../nodes/kernels/scaled_attn/executor_pa.cpp | 37 +- .../kernels/scaled_attn/mha_single_token.cpp | 4 +- .../kernels/scaled_attn/transpose_kernel.hpp | 114 ++++++ .../intel_cpu/src/nodes/paged_attn.cpp | 2 +- src/plugins/intel_cpu/src/nodes_factory.cpp | 2 + 10 files changed, 620 insertions(+), 13 deletions(-) create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.hpp diff --git a/cmake/developer_package/compile_flags/os_flags.cmake b/cmake/developer_package/compile_flags/os_flags.cmake index 7c50f6f6e7eb6b..759a9080188639 100644 --- a/cmake/developer_package/compile_flags/os_flags.cmake +++ b/cmake/developer_package/compile_flags/os_flags.cmake @@ -104,6 +104,7 @@ macro(ov_check_compiler_supports_sve flags) int main() { svfloat64_t a; a = svdup_n_f64(0); + (void)a; // to avoid warnings return 0; }") @@ -259,7 +260,6 @@ endmacro() macro(ov_arm_sve_optimization_flags flags) # Check for compiler SVE support ov_check_compiler_supports_sve("-march=armv8-a+sve") - if(OV_COMPILER_IS_INTEL_LLVM) message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index 2eebfe88a2c803..c6ccbcf375746a 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -298,21 +298,21 @@ cross_compiled_file(${TARGET_NAME} NAMESPACE ov::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} - ARCH AVX512F AVX2 ANY + ARCH AVX512F AVX2 SVE ANY src/nodes/kernels/scaled_attn/executor_pa.cpp API src/nodes/kernels/scaled_attn/executor_pa.hpp NAME make_pa_executor NAMESPACE ov::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} - ARCH AVX512F AVX2 ANY + ARCH AVX512F AVX2 SVE ANY src/nodes/kernels/scaled_attn/attn_memcpy.cpp API src/nodes/kernels/scaled_attn/attn_memcpy.hpp NAME attn_memcpy paged_attn_memcpy attn_memcpy2d_kernel NAMESPACE ov::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} - ARCH AVX512F AVX2 ANY + ARCH AVX512F AVX2 SVE ANY src/nodes/kernels/scaled_attn/attn_quant.cpp API src/nodes/kernels/scaled_attn/attn_quant.hpp NAME attn_quantkv paged_attn_quantkv attn_quant_u8 attn_dequant_u8 diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.cpp new file mode 100644 index 00000000000000..59b54f47024adf --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.cpp @@ -0,0 +1,333 @@ +// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2024 FUJITSU LIMITED +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm_kernel.hpp" + +#include +#include + +#include "dnnl_extension_utils.h" +#include "utils/cpu_utils.hpp" + +using namespace dnnl::impl::cpu::aarch64; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::aarch64::matmul; + +#define THROW_ERROR(...) 
OPENVINO_THROW("brgemm executor Init Failure '", __VA_ARGS__) +namespace ov { +namespace intel_cpu { + +BrgemmKernel::BrgemmKernel(size_t M, + size_t N, + size_t K, + size_t lda, + size_t ldb, + size_t ldc, + bool b_transposed, + ov::element::Type inType, + bool b_accumulate) + : M(M), + K(K), + N(N), + lda(lda), + ldb(ldb), + ldc(ldc), + b_transposed(b_transposed), + inType(inType) { + // blocking M + M_blk = matmulOptimalM; + M_tail = M % M_blk; + kBlkStep = 4 / inType.size(); + size_t vlen; + vlen = mayiuse(sve_512) ? cpu_isa_traits::vlen + : mayiuse(sve_256) ? cpu_isa_traits::vlen + : cpu_isa_traits::vlen; + // blocking N + N_blk = std::max(N, vlen / inType.size()); + N_tail = N % N_blk; + + // blocking K + K_blk = K; + K_tail = K % K_blk; + // copied K must be round up by vlen / inType.size(), otherwise copy B kernel may access wrong memory + packedBSize = rnd_up(K, vlen / inType.size()) * rnd_up(N, N_blk) * inType.size(); + size_t brg0BaseIdx = std::numeric_limits::max(); + for (size_t m = 0; m < 2; m++) { + for (size_t k = 0; k < 2; k++) { + for (size_t n = 0; n < 2; n++) { + auto& brgemmCtx = brgCtxs[getBrgIdx(m, k, n)]; + + auto M_ = m ? M_tail : M < M_blk ? 0 : M_blk; + auto N_ = n ? N_tail : N - N_tail; + auto K_ = k ? K_tail : K - K % K_blk; + auto beta = (b_accumulate || (k && brgCtxs[getBrgIdx(m, 0, n)].K != 0)) ? 1.0f : 0.0f; + + brgemmCtx.M = M_; + brgemmCtx.N = N_; + brgemmCtx.K = K_; + brgemmCtx.LDA = k ? K_blk : lda; + brgemmCtx.LDB = b_transposed ? rnd_up(N, N_blk) : ldb; // b_transposed needs copy + brgemmCtx.LDC = ldc; + brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::ElementTypeToDataType(inType)); + brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::ElementTypeToDataType(inType)); + brgemmCtx.beta = beta; + + // don't create brgemm kernels for empty tiles + if (M_ != 0 && K_ != 0 && N_ != 0) { + if (brg0BaseIdx == std::numeric_limits::max()) + brg0BaseIdx = getBrgIdx(m, k, n); + init_brgemm(brgemmCtx, brgKernels[getBrgIdx(m, k, n)]); + } + } + } + } + + auto& brgemmCtx0 = brgCtxs[brg0BaseIdx]; + if (b_transposed) { + size_t b_stride = 0; + b_stride = ldb * inType.size(); + // K should use the original K + init_brgemm_copy_b(brgCopyBKernel, + N, + N_blk, + N_tail, + brgemmCtx0.LDB, + K, + brgemmCtx0.dt_in0, + brgemmCtx0.dt_in1, + b_transposed, + b_stride); + } +} + +const size_t BrgemmKernel::get_scratch_a_size() const { + return packedASize; +} + +const size_t BrgemmKernel::get_scratch_b_size() const { + return packedBSize; +} + +void BrgemmKernel::init_brgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel) { + brgemm_t brgDesc; + cpu_isa_t isa; + isa = mayiuse(sve_512) ? cpu_isa_t::sve_512 : mayiuse(sve_256) ? 
cpu_isa_t::sve_256 : cpu_isa_t::sve_128; + auto status = brgemm_desc_init(&brgDesc, + isa, + brgemm_addr, + ctx.dt_in0, + ctx.dt_in1, + ctx.transpose_a, + ctx.transpose_b, + brgemm_row_major, + 1.f, + ctx.beta, + ctx.LDA, + ctx.LDB, + ctx.LDC, + ctx.M, + ctx.N, + ctx.K, + nullptr); + if (status != dnnl_success) { + THROW_ERROR("cannot be executed due to invalid brgconv params"); + } + + brgemm_kernel_t* brgKernel_ = nullptr; + status = brgemm_kernel_create(&brgKernel_, brgDesc); + if (status != dnnl_success) { + THROW_ERROR("cannot be executed due to invalid brgconv params"); + } + brgKernel.reset(brgKernel_); +} +void BrgemmKernel::init_brgemm_copy_a( + std::unique_ptr& brgCopyKernel, + size_t K, + size_t K_blk, + size_t K_tail, + size_t LDA, + dnnl_data_type_t dt_in0, + bool transpose, + size_t copy_A_src_stride) { + brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_tag = dnnl_abcd; + brgCopyKernelConf.K = K; + brgCopyKernelConf.K_tail = K_tail; + brgCopyKernelConf.K_blk = K_blk; + brgCopyKernelConf.use_buffer_a_tail_only = false; + // padding K tail to K_blk, LDA is the stride for target tensor + brgCopyKernelConf.LDA = LDA; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.s8s8_compensation_required = false; + brgCopyKernelConf.wei_zp_type = dnnl::impl::cpu::aarch64::none; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::aarch64::none; + brgCopyKernelConf.src_dt = dt_in0; + brgCopyKernelConf.copy_A_src_stride = copy_A_src_stride; + brgCopyKernelConf.a_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); + // copied A has the same precision of original + brgCopyKernelConf.tr_a_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); + brgCopyKernelConf.transposed_A = transpose; + brgCopyKernelConf.isa = mayiuse(sve_512) ? cpu_isa_t::sve_512 + : mayiuse(sve_256) ? cpu_isa_t::sve_256 + : cpu_isa_t::sve_128; + + create_brgemm_matmul_copy_a(brgCopyKernel, &brgCopyKernelConf); +} + +void BrgemmKernel::init_brgemm_copy_b( + std::unique_ptr& brgCopyKernel, + size_t N, + size_t N_blk, + size_t N_tail, + size_t LDB, + size_t K, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + bool transpose, + size_t copy_B_wei_stride) { + brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_dt = dt_in0; + brgCopyKernelConf.wei_dt = dt_in1; + brgCopyKernelConf.wei_n_blk = N_blk; + brgCopyKernelConf.wei_tag = transpose ? dnnl_ba : dnnl_ab; + brgCopyKernelConf.copy_B_wei_stride = copy_B_wei_stride; + + // LDB here is for the target tensor, not source tensor + brgCopyKernelConf.LDB = LDB; + brgCopyKernelConf.N = N; + brgCopyKernelConf.N_tail = N_tail; + brgCopyKernelConf.N_blk = N_blk; + brgCopyKernelConf.K = K; + brgCopyKernelConf.K_blk = K; + brgCopyKernelConf.K_tail = 0; + brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; + brgCopyKernelConf.b_dt_sz = + DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.tr_b_dt_sz = + DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.req_wei_vnni_downconvert = false; + brgCopyKernelConf.isa = mayiuse(sve_512) ? cpu_isa_t::sve_512 + : mayiuse(sve_256) ? 
cpu_isa_t::sve_256 + : cpu_isa_t::sve_128; + + brgCopyKernelConf.has_zero_point_a = false; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::aarch64::none; + auto ret = create_brgemm_matmul_copy_b(brgCopyKernel, &brgCopyKernelConf); + if (ret != dnnl::impl::status_t::dnnl_success) + THROW_ERROR("cannot create_brgemm_matmul_copy_b kernel"); +} + +void BrgemmKernel::copy_buffer_b(void* b, void* scratch_b) { + auto ptr_b = reinterpret_cast(b); + auto ptr_scartch_b = reinterpret_cast(scratch_b); + if (brgCopyBKernel) { + for (size_t nb = 0; nb < div_up(N, N_blk); nb++) { + auto N_stride = b_transposed ? ldb : 1; + auto pCopyKernel0In = ptr_b + nb * N_blk * inType.size() * N_stride; + auto pCopyKernel0Out = ptr_scartch_b + nb * N_blk * kBlkStep * inType.size(); + + auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t(); + + const bool is_N_tail = (N - nb * N_blk < N_blk); + ctx.current_N_blk = is_N_tail ? N_tail : N_blk; + ctx.src = pCopyKernel0In; + ctx.tr_src = pCopyKernel0Out; + ctx.compensation_ptr = nullptr; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K; + (*brgCopyBKernel)(&ctx); + } + } +} + +void BrgemmKernel::executeGemm(bool is_M_tail, void* a, void* b, void* c, void* wsp, void* scratch_a) { + auto ptr_A = reinterpret_cast(a); + auto ptr_C = reinterpret_cast(c); + auto ptr_scartch_a = reinterpret_cast(scratch_a); + auto ptr_scartch_b = reinterpret_cast(b); + uint8_t* ptr_a_tail = nullptr; + + size_t brgIdx0 = getBrgIdx(0, 0, 0); + // The step for matrix A over main K dimension + size_t K0_step0 = brgCtxs[brgIdx0].K; + auto cur_M_blk = is_M_tail ? M_tail : M_blk; + if (brgCopyAKernel) { + // only copy tailed data; + size_t K_offset = K < K_blk ? 0 : K0_step0 * inType.size(); + auto pCopyKernelIn = ptr_A + K_offset; + auto pCopyKernelOut = ptr_scartch_a; + + auto ctx = jit_brgemm_matmul_copy_a_t::ctx_t(); + + ctx.current_M_blk = cur_M_blk; + ctx.zp_b_compensation_buffer_ptr = nullptr; + ctx.zp_a_compensation_result_ptr = nullptr; + ctx.zp_b_neg_value_ptr = nullptr; + ctx.zp_ab_comp_ptr = nullptr; + ctx.src = pCopyKernelIn; + ctx.tr_src = pCopyKernelOut; + ctx.current_K_start = 0; + ctx.current_K_blk = K % K_blk; + + (*brgCopyAKernel)(&ctx); + + ptr_a_tail = pCopyKernelOut; + } + size_t count_N = 0; + for (size_t n = 0; n < 2; n++) { + size_t count_K = 0; + for (size_t k = 0; k < 2; k++) { + size_t mIdx = is_M_tail ? 1 : 0; + auto& brgemmCtx = brgCtxs[getBrgIdx(mIdx, k, n)]; + if (brgemmCtx.K != 0 && brgemmCtx.N != 0 && brgemmCtx.M != 0) { + auto local_a_ptr = k > 0 ? ptr_a_tail : ptr_A; + auto B_stride = (k * count_K + n * count_N * kBlkStep) * inType.size(); + auto weight_ptr = ptr_scartch_b + B_stride; + auto C_stride = n * count_N * ov::element::f32.size(); + auto out_ptr = ptr_C + C_stride; + callBrgemm(brgemmCtx, brgKernels[getBrgIdx(mIdx, k, n)], local_a_ptr, weight_ptr, out_ptr, wsp); + // stride K, N if body kernel is executed. 
+ if (k == 0) { + count_K = brgemmCtx.K * brgemmCtx.LDB; + } + if (n == 0) { + count_N = brgemmCtx.N; + } + } + } + } +} + +void BrgemmKernel::executeGemm(void* a, void* b, void* c, void* wsp, void* scratch_a, void* scratch_b) { + auto ptr_A = reinterpret_cast(a); + auto ptr_B = reinterpret_cast(b); + auto ptr_C = reinterpret_cast(c); + + copy_buffer_b(ptr_B, scratch_b); + + for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { + const bool is_M_tail = (M - mb * M_blk < M_blk); + auto ptr_a = ptr_A + (mb * M_blk * lda) * inType.size(); + auto ptr_c = ptr_C + (mb * M_blk * ldc) * ov::element::f32.size(); + executeGemm(is_M_tail, ptr_a, scratch_b, wsp, ptr_c, scratch_a); + } +} +void BrgemmKernel::callBrgemm(brgemmCtx& ctx, + std::unique_ptr& brgKernel, + const void* pin0, + const void* pin1, + void* pout, + void* wsp) { + brgemm_batch_element_t addr_batch; + addr_batch.ptr.A = pin0; + addr_batch.ptr.B = pin1; + brgemm_kernel_execute(brgKernel.get(), 1, &addr_batch, pout, wsp); +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.hpp new file mode 100644 index 00000000000000..06236ec1a9b775 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/brgemm_kernel.hpp @@ -0,0 +1,105 @@ +// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2024 FUJITSU LIMITED +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include + +namespace ov { +namespace intel_cpu { + +class BrgemmKernel { +public: + // Construct brgemm kernel for matmul (M, K) * (K, N)/(N, K)^T + // FP32 * FP32 -> FP32 + // lda is the leading dimension for A matrix + // ldb is the leading dimension for B matrix + // ldc is the leading dimension for C matrix + // b_transpose indicates wheter B matrix is transposed. 
+ BrgemmKernel(size_t M, + size_t N, + size_t K, + size_t lda, + size_t ldb, + size_t ldc, + bool b_transposed = false, + ov::element::Type inType = ov::element::f32, + bool b_accumulate = false); + // execute all M + void executeGemm(void* a, void* b, void* c, void* wsp, void* scratch_a, void* scratch_b); + // execute by m_blk + void executeGemm(bool is_M_tail, void* a, void* b, void* c, void* wsp, void* scratch_a); + + void copy_buffer_b(void* b, void* scratch_b); + // bytes needed to place scratch buffer a + const size_t get_scratch_a_size() const; + // bytes needed to place scratch buffer b + const size_t get_scratch_b_size() const; + const size_t get_wsp_size() const { + return 4 * 1024; + } + +private: + size_t M = 0, M_blk = 0, M_tail = 0; + size_t K = 0, K_blk = 0, K_tail = 0, N = 0, N_blk = 0, N_tail = 0; + size_t lda = 0, ldb = 0, ldc = 0; + bool b_transposed = false; + size_t kBlkStep = 0; + size_t packedBSize = 0; + size_t packedASize = 0; + ov::element::Type inType; + static constexpr size_t MHA_BRGEMM_KERNELS_NUM = 8; + static constexpr size_t matmulOptimalM = 32; + struct brgemmCtx { + size_t M = 0, N = 0, K = 0, LDA = 0, LDB = 0, LDC = 0; + dnnl_data_type_t dt_in0 = dnnl_data_type_undef; + dnnl_data_type_t dt_in1 = dnnl_data_type_undef; + bool transpose_a = false; + bool transpose_b = false; + float beta = 0.0f; + }; + brgemmCtx brgCtxs[MHA_BRGEMM_KERNELS_NUM]; + std::unique_ptr brgKernels[MHA_BRGEMM_KERNELS_NUM]; + std::unique_ptr brgCopyAKernel; + std::unique_ptr brgCopyBKernel; + size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) { + return mIdx * 4 + kIdx * 2 + nIdx; + } + void init_brgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel); + // LDA, LDB is used for stride of target memory + void init_brgemm_copy_a( + std::unique_ptr& brgCopyKernel, + size_t K, + size_t K_blk, + size_t K_tail, + size_t LDA, + dnnl_data_type_t dt_in0, + bool transpose = false, + size_t copy_A_src_stride = 0); + + void init_brgemm_copy_b( + std::unique_ptr& brgCopyKernel, + size_t N, + size_t N_blk, + size_t N_tail, + size_t LDB, + size_t K, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + bool transpose = false, + size_t copy_B_wei_stride = 0); + + void callBrgemm(brgemmCtx& ctx, + std::unique_ptr& brgKernel, + const void* pin0, + const void* pin1, + void* pout, + void* wsp); +}; +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 761a136eda2997..c7b2b13123c7d5 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -11,6 +11,9 @@ #include #include +#if defined(HAVE_SVE) +# include "arm_sve.h" +#endif namespace ov { namespace Extensions { @@ -138,7 +141,30 @@ void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, f } } +#if defined(HAVE_SVE) +void inline attn_dequant_u8_kernel(const uint8_t* src, float* dst, size_t n, float scale, float zp) { + size_t i = 0; + uint8_t* src_nc = const_cast(src); + size_t nvec = n / svcntw(); + size_t lvec = svcntw(); + auto sve_pg = svptrue_b32(); + for (size_t j = 0; j < nvec; ++j) { + svuint32_t reg1 = svld1ub_u32(sve_pg, src_nc + j * lvec); + svfloat32_t reg2 = svcvt_f32_u32_z(sve_pg, reg1); + svfloat32_t reg3 = svsub_f32_z(sve_pg, reg2, svdup_n_f32(zp)); + svfloat32_t reg4 = svmul_f32_z(sve_pg, reg3, svdup_n_f32(scale)); + 
svst1_f32(sve_pg, dst + j * lvec, reg4); + } + i = n - n % svcntw(); + for (; i < n; ++i) { + float tmp = src_nc[i]; + tmp = (tmp - zp) * scale; + dst[i] = tmp; + } +} +#endif + } // namespace XARCH } // namespace Cpu } // namespace Extensions -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index ce95d825d44f50..dec4650dc548c1 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -20,13 +20,18 @@ #include "common.hpp" #include "executor_pa.hpp" #include "executor_pa_common.hpp" -#include "nodes/kernels/x64/brgemm_kernel.hpp" #include "openvino/core/parallel.hpp" #include "openvino/core/type/bfloat16.hpp" #include "openvino/core/type/float16.hpp" #include "softmax_kernel.hpp" #include "transpose_kernel.hpp" #include "utils/plain_tensor.hpp" +#if defined(OPENVINO_ARCH_X86_64) +# include "nodes/kernels/x64/brgemm_kernel.hpp" +#elif defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE) +# include "arm_sve.h" +# include "nodes/kernels/aarch64/brgemm_kernel.hpp" +#endif namespace ov { namespace Extensions { @@ -37,7 +42,7 @@ using namespace ov; using namespace ov::intel_cpu; // currently depends on brgemm which only support x64 -#ifdef OPENVINO_ARCH_X86_64 +#if defined(OPENVINO_ARCH_X86_64) || (defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE)) # if defined(HAVE_AVX2) || defined(HAVE_AVX512F) @@ -1241,8 +1246,10 @@ struct MHAHelper { std::vector> _wv_gemm; // will accumulate C buffer std::vector> _wv_gemm_acc; - // second token +// second token +# if defined(OPENVINO_ARCH_X86_64) std::shared_ptr _gemv; +# endif ov::element::Type _fastpath_valid_prec = ov::element::undefined; // second token for bhl loop PlainTensor _weight_bhl; @@ -1347,6 +1354,7 @@ struct MHAHelper { _wv_scratch_a.resize( {_nthr, _wv_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); +# if defined(OPENVINO_ARCH_X86_64) if ((S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && precision_of::value == ov::element::bf16 && @@ -1363,6 +1371,7 @@ struct MHAHelper { static_cast(block_size), _fastpath_valid_prec); } +# endif } if (init_alibi_lookup && (!_alibi_lookup || _alibi_lookup.m_dims[0] < kv_len)) { @@ -1564,6 +1573,7 @@ struct MHAHelper { size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) { +# if defined(OPENVINO_ARCH_X86_64) if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) { _gemv->tile_config(); for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) { @@ -1578,6 +1588,7 @@ struct MHAHelper { } _gemv->tile_release(); } else { +# endif for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) { auto block_number = block_table[i]; for (size_t pq = 0; pq < q_len; pq++) { @@ -1591,7 +1602,9 @@ struct MHAHelper { } } } +# if defined(OPENVINO_ARCH_X86_64) } +# endif for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { @@ -1711,6 +1724,7 @@ struct MHAHelper { auto pk = pk_in_blocks * _block_size; if (pk < context_len) { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pk_in_blocks]; +# if defined(OPENVINO_ARCH_X86_64) if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) { _gemv->tile_config(); for (size_t pq = 0; pq < q_len; pq++) { @@ -1722,6 
+1736,7 @@ struct MHAHelper { } _gemv->tile_release(); } else { +# endif for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(b, h, pq), @@ -1732,7 +1747,9 @@ struct MHAHelper { _key_group_size); } } +# if defined(OPENVINO_ARCH_X86_64) } +# endif } }; @@ -2409,7 +2426,7 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ size_t value_group_size) { std::shared_ptr executor; -#ifdef OPENVINO_ARCH_X86_64 +#if defined(OPENVINO_ARCH_X86_64) if (data_type == ov::element::bf16) { # if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { @@ -2479,8 +2496,18 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ } else { OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); } +#elif (defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE)) + if (data_type == ov::element::f32) { + if (key_cache_type == ov::element::u8 && value_cache_type == ov::element::u8) { + executor = + std::make_shared>(key_group_size, value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type and value_cache_type of u8 is only support"); + } + } + #else - OPENVINO_THROW("make_pa_executor: only support x64 platform"); + OPENVINO_THROW("make_pa_executor: only support x64 platform or ARM with SVE support"); #endif return executor; } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index f42f15ce1e065a..27782970323bdd 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -50,7 +50,7 @@ using namespace ov; #endif template -void cvt_copy(TA* dst, TB* src, size_t n) { +static void cvt_copy(TA* dst, TB* src, size_t n) { size_t i = 0; #if defined(HAVE_AVX512F) for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { @@ -1561,4 +1561,4 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, } // namespace XARCH } // namespace Cpu } // namespace Extensions -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp index 93d7db55107951..c89e807bae7fa5 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp @@ -11,6 +11,10 @@ #include "common.hpp" #include "openvino/core/type/element_type.hpp" +#if defined(HAVE_SVE) +# include "arm_sve.h" +#endif + namespace ov { namespace Extensions { namespace Cpu { @@ -593,6 +597,116 @@ inline void transpose_16xK_kernel(float* dst, T* src, size_t K, size_t dst_strid } } +#elif defined(HAVE_SVE) +template +inline void transpose_16x16_kernel(TDST* dst, TSRC* src, size_t dst_stride, size_t src_stride) { + for (size_t i = 0; i < 16; i++) { + for (size_t j = 0; j < 16; j++) { + dst[i * dst_stride + j] = static_cast(src[i + j * src_stride]); + } + } +} + +template +inline void transpose_16xK_kernel(TDST* dst, TSRC* src, size_t K, size_t dst_stride, size_t src_stride) { + for (size_t i = 0; i < K; i++) { + for (size_t j = 0; j < 16; j++) { + dst[i * dst_stride + j] = static_cast(src[i + j * src_stride]); + } + } +} + +inline void transpose_8x8_kernel(float* src, size_t ld_src, float* dst, size_t ld_dst) { + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 + // b: b0 b1 b2 b3 
b4 b5 b6 b7 + // c: c0 c1 c2 c3 c4 c5 c6 c7 + // d: d0 d1 d2 d3 d4 d5 d6 d7 + // e: e0 e1 e2 e3 e4 e5 e6 e7 + // f: f0 f1 f2 f3 f4 f5 f6 f7 + // g: g0 g1 g2 g3 g4 g5 g6 g7 + // h: h0 h1 h2 h3 h4 h5 h6 h7 + svfloat32_t a = svld1_f32(svptrue_b8(), &src[0 * ld_src]); + svfloat32_t b = svld1_f32(svptrue_b8(), &src[1 * ld_src]); + svfloat32_t c = svld1_f32(svptrue_b8(), &src[2 * ld_src]); + svfloat32_t d = svld1_f32(svptrue_b8(), &src[3 * ld_src]); + svfloat32_t e = svld1_f32(svptrue_b8(), &src[4 * ld_src]); + svfloat32_t f = svld1_f32(svptrue_b8(), &src[5 * ld_src]); + svfloat32_t g = svld1_f32(svptrue_b8(), &src[6 * ld_src]); + svfloat32_t h = svld1_f32(svptrue_b8(), &src[7 * ld_src]); + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 + // a2 b2 a3 b3 a6 b6 a7 b7 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + svfloat32_t ta = svtrn1_f32(a, b); + svfloat32_t tb = svtrn2_f32(a, b); + svfloat32_t tc = svtrn1_f32(c, d); + svfloat32_t td = svtrn2_f32(c, d); + svfloat32_t te = svtrn1_f32(e, f); + svfloat32_t tf = svtrn2_f32(e, f); + svfloat32_t tg = svtrn1_f32(g, h); + svfloat32_t th = svtrn2_f32(g, h); + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... + a = svreinterpret_f32_f64(svtrn1_f64(svreinterpret_f64_f32(ta), svreinterpret_f64_f32(tc))); + b = svreinterpret_f32_f64(svtrn2_f64(svreinterpret_f64_f32(ta), svreinterpret_f64_f32(tc))); + c = svreinterpret_f32_f64(svtrn1_f64(svreinterpret_f64_f32(tb), svreinterpret_f64_f32(td))); + d = svreinterpret_f32_f64(svtrn2_f64(svreinterpret_f64_f32(tb), svreinterpret_f64_f32(td))); + e = svreinterpret_f32_f64(svtrn1_f64(svreinterpret_f64_f32(te), svreinterpret_f64_f32(tg))); + f = svreinterpret_f32_f64(svtrn2_f64(svreinterpret_f64_f32(te), svreinterpret_f64_f32(tg))); + g = svreinterpret_f32_f64(svtrn1_f64(svreinterpret_f64_f32(tf), svreinterpret_f64_f32(th))); + h = svreinterpret_f32_f64(svtrn2_f64(svreinterpret_f64_f32(tf), svreinterpret_f64_f32(th))); + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 e0 f0 g0 h0 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... 
+ svfloat32_t t1a = svext_f32(a, a, 4); + svfloat32_t t1b = svext_f32(b, b, 4); + svfloat32_t t1c = svext_f32(c, c, 4); + svfloat32_t t1d = svext_f32(d, d, 4); + ta = svext_f32(t1a, e, 4); + tb = svext_f32(t1b, f, 4); + tc = svext_f32(t1c, g, 4); + td = svext_f32(t1d, h, 4); + te = svsel_f32(svptrue_pat_b32(SV_VL4), t1a, e); + tf = svsel_f32(svptrue_pat_b32(SV_VL4), t1b, f); + tg = svsel_f32(svptrue_pat_b32(SV_VL4), t1c, g); + th = svsel_f32(svptrue_pat_b32(SV_VL4), t1d, h); + // Store the transposed result in destination + svst1_f32(svptrue_b8(), &dst[0 * ld_dst], ta); + svst1_f32(svptrue_b8(), &dst[1 * ld_dst], tc); + svst1_f32(svptrue_b8(), &dst[2 * ld_dst], tb); + svst1_f32(svptrue_b8(), &dst[3 * ld_dst], td); + svst1_f32(svptrue_b8(), &dst[4 * ld_dst], te); + svst1_f32(svptrue_b8(), &dst[5 * ld_dst], tg); + svst1_f32(svptrue_b8(), &dst[6 * ld_dst], tf); + svst1_f32(svptrue_b8(), &dst[7 * ld_dst], th); +} +template <> +inline void transpose_16x16_kernel(float* dst, float* src, size_t dst_stride, size_t src_stride) { + transpose_8x8_kernel(src, src_stride, dst, dst_stride); + transpose_8x8_kernel(src + 8, src_stride, dst + 8 * dst_stride, dst_stride); + transpose_8x8_kernel(src + 8 * src_stride, src_stride, dst + 8, dst_stride); + transpose_8x8_kernel(src + 8 * src_stride + 8, src_stride, dst + 8 * dst_stride + 8, dst_stride); +} + #else template diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index b1632c34ff6fa2..98be1a12517441 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -158,7 +158,7 @@ void PagedAttention::createPrimitive() { PagedAttentionKey key = {rtPrecision}; auto builder = [&](const PagedAttentionKey& key) -> std::shared_ptr { -#ifdef OPENVINO_ARCH_X86_64 +#if defined(OPENVINO_ARCH_X86_64) || (defined(OPENVINO_ARCH_ARM64)) // Since we are quantize only last dim it's safe to use the last dim of KV. auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp index ff27a0e4246baf..400e4946312330 100644 --- a/src/plugins/intel_cpu/src/nodes_factory.cpp +++ b/src/plugins/intel_cpu/src/nodes_factory.cpp @@ -232,6 +232,8 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") { INTEL_CPU_NODE(MHA, Type::MHA); INTEL_CPU_NODE(PagedAttention, Type::PagedAttention); INTEL_CPU_NODE(RMSNorm, Type::RMS); +#elif defined(OPENVINO_ARCH_ARM64) + INTEL_CPU_NODE(PagedAttention, Type::PagedAttention); #endif }