diff --git a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp index 5eb7d0ceea4..c2c794f26d1 100644 --- a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp @@ -492,7 +492,7 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; - const bool flat = one_of(jcp.ic, 1, 2, 3); + const bool flat = one_of(jcp.ic, 1, 2, 3) && (jcp.src_tag != dat_tag_nCx8c); const bool mimo = !flat; bool args_ok = true diff --git a/src/cpu/x64/jit_sse41_convolution.hpp b/src/cpu/x64/jit_sse41_convolution.hpp index 15511df7973..0184919259a 100644 --- a/src/cpu/x64/jit_sse41_convolution.hpp +++ b/src/cpu/x64/jit_sse41_convolution.hpp @@ -85,7 +85,7 @@ struct jit_sse41_convolution_fwd_t : public primitive_t { && IMPLICATION(curr_dst_tag != dat_tag_nxc, dst_d.format_kind() == format_kind::any) && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); - const bool flat = utils::one_of(IC(), 1, 2, 3); + const bool flat = utils::one_of(IC(), 1, 2, 3) && (curr_src_tag != dat_tag_nCx8c); auto src_tag = is_data_layout_nxc ? dat_tag_nxc : flat ? dat_tag_ncx : dat_tag_nCx8c;