diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index d5f2e0e6bae..e2bfc23c13b 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -2499,13 +2499,6 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, mov(reg_rdb_loop, brg.rdb); L_aligned(rdb_loop_label, 64); { - const bool is_rd_tail = false; - gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, - is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); - - add(reg_aux_A, rdb_A_offset()); - add(reg_aux_B, rdb_B_offset()); - if (brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0)) { auto reg_local_ic = reg_aux_D; @@ -2529,10 +2522,6 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv - mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); - add(reg_local_ic, brg.rd_block); - mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic); - if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_, brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * sizeof(float)); @@ -2548,12 +2537,23 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, brg.src_scales_group_size, sizeof(float)); } + mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); + add(reg_local_ic, brg.rd_block); + mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic); + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]); mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]); } + const bool is_rd_tail = false; + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, + is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); + + add(reg_aux_A, rdb_A_offset()); + add(reg_aux_B, rdb_B_offset()); + dec(reg_rdb_loop); cmp(reg_rdb_loop, 0); } diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 3c4a19e6764..99c66c86569 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -397,9 +397,9 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_zero_points_offset = 0; int src_scales_offset = 0; if (jbgp.weights_decompression) { - wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc; - wei_zero_points_offset = ((ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc) * wei_zero_points_dt_size; - src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size) + (ic / jbgp.src_quant_group_size); + wei_scales_offset = wei_scales_oc_stride * oc; + wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; + src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); } auto ptr_D = dst + dst_off; @@ -423,10 +423,10 @@ status_t brgemm_inner_product_fwd_t::execute_forward( brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data, - scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0); + scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); } else { brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0); + (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); } } @@ -500,9 +500,9 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_zero_points_offset = 0; int src_scales_offset = 0; if (jbgp.weights_decompression) { - wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc; - wei_zero_points_offset = ((ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc) * wei_zero_points_dt_size; - src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size) + (ic / jbgp.src_quant_group_size); + wei_scales_offset = wei_scales_oc_stride * oc; + wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; + src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); } auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get(); @@ -524,10 +524,10 @@ status_t brgemm_inner_product_fwd_t::execute_forward( nullptr, false, 1, false, false, dst_scales}; brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0); + (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); } else { brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0); + (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); } } };