diff --git a/include/xsimd/arch/generic/xsimd_generic_memory.hpp b/include/xsimd/arch/generic/xsimd_generic_memory.hpp
index fbe1bbc13..1bc61a098 100644
--- a/include/xsimd/arch/generic/xsimd_generic_memory.hpp
+++ b/include/xsimd/arch/generic/xsimd_generic_memory.hpp
@@ -627,7 +627,7 @@ namespace xsimd
             hi.store_aligned(buffer + real_batch::size);
         }
 
-        // store_compelx_unaligned
+        // store_complex_unaligned
         template <class A, class T_out, class T_in>
         XSIMD_INLINE void store_complex_unaligned(std::complex<T_out>* dst, batch<std::complex<T_in>, A> const& src, requires_arch<generic>) noexcept
         {
@@ -665,6 +665,141 @@ namespace xsimd
             }
         }
 
+        // transpose
+        template <class A, class = typename std::enable_if<batch<uint16_t, A>::size == 8, void>::type>
+        XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<generic>) noexcept
+        {
+            assert((matrix_end - matrix_begin == batch<uint16_t, A>::size) && "correctly sized matrix");
+            (void)matrix_end;
+            auto l0 = zip_lo(matrix_begin[0], matrix_begin[1]);
+            auto l1 = zip_lo(matrix_begin[2], matrix_begin[3]);
+            auto l2 = zip_lo(matrix_begin[4], matrix_begin[5]);
+            auto l3 = zip_lo(matrix_begin[6], matrix_begin[7]);
+
+            auto l4 = zip_lo(bit_cast<batch<uint32_t, A>>(l0), bit_cast<batch<uint32_t, A>>(l1));
+            auto l5 = zip_lo(bit_cast<batch<uint32_t, A>>(l2), bit_cast<batch<uint32_t, A>>(l3));
+
+            auto l6 = zip_hi(bit_cast<batch<uint32_t, A>>(l0), bit_cast<batch<uint32_t, A>>(l1));
+            auto l7 = zip_hi(bit_cast<batch<uint32_t, A>>(l2), bit_cast<batch<uint32_t, A>>(l3));
+
+            auto h0 = zip_hi(matrix_begin[0], matrix_begin[1]);
+            auto h1 = zip_hi(matrix_begin[2], matrix_begin[3]);
+            auto h2 = zip_hi(matrix_begin[4], matrix_begin[5]);
+            auto h3 = zip_hi(matrix_begin[6], matrix_begin[7]);
+
+            auto h4 = zip_lo(bit_cast<batch<uint32_t, A>>(h0), bit_cast<batch<uint32_t, A>>(h1));
+            auto h5 = zip_lo(bit_cast<batch<uint32_t, A>>(h2), bit_cast<batch<uint32_t, A>>(h3));
+
+            auto h6 = zip_hi(bit_cast<batch<uint32_t, A>>(h0), bit_cast<batch<uint32_t, A>>(h1));
+            auto h7 = zip_hi(bit_cast<batch<uint32_t, A>>(h2), bit_cast<batch<uint32_t, A>>(h3));
+
+            matrix_begin[0] = bit_cast<batch<uint16_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(l4), bit_cast<batch<uint64_t, A>>(l5)));
+            matrix_begin[1] = bit_cast<batch<uint16_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(l4), bit_cast<batch<uint64_t, A>>(l5)));
+            matrix_begin[2] = bit_cast<batch<uint16_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(l6), bit_cast<batch<uint64_t, A>>(l7)));
+            matrix_begin[3] = bit_cast<batch<uint16_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(l6), bit_cast<batch<uint64_t, A>>(l7)));
+
+            matrix_begin[4] = bit_cast<batch<uint16_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(h4), bit_cast<batch<uint64_t, A>>(h5)));
+            matrix_begin[5] = bit_cast<batch<uint16_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(h4), bit_cast<batch<uint64_t, A>>(h5)));
+            matrix_begin[6] = bit_cast<batch<uint16_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(h6), bit_cast<batch<uint64_t, A>>(h7)));
+            matrix_begin[7] = bit_cast<batch<uint16_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(h6), bit_cast<batch<uint64_t, A>>(h7)));
+        }
+
+        template <class A, class = typename std::enable_if<batch<uint16_t, A>::size == 8, void>::type>
+        XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<generic>) noexcept
+        {
+            transpose(reinterpret_cast<batch<uint16_t, A>*>(matrix_begin), reinterpret_cast<batch<uint16_t, A>*>(matrix_end), A {});
+        }
+
+        template <class A, class = typename std::enable_if<batch<uint16_t, A>::size == 8, void>::type>
+        XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<generic>) noexcept
+        {
+            assert((matrix_end - matrix_begin == batch<uint8_t, A>::size) && "correctly sized matrix");
+            (void)matrix_end;
+            auto l0 = zip_lo(matrix_begin[0], matrix_begin[1]);
+            auto l1 = zip_lo(matrix_begin[2], matrix_begin[3]);
+            auto l2 = zip_lo(matrix_begin[4], matrix_begin[5]);
+            auto l3 = zip_lo(matrix_begin[6], matrix_begin[7]);
+            auto l4 = zip_lo(matrix_begin[8], matrix_begin[9]);
+            auto l5 = zip_lo(matrix_begin[10], matrix_begin[11]);
+            auto l6 = zip_lo(matrix_begin[12], matrix_begin[13]);
+            auto l7 = zip_lo(matrix_begin[14], matrix_begin[15]);
+
+            auto h0 = zip_hi(matrix_begin[0], matrix_begin[1]);
+            auto h1 = zip_hi(matrix_begin[2], matrix_begin[3]);
+            auto h2 = zip_hi(matrix_begin[4], matrix_begin[5]);
+            auto h3 = zip_hi(matrix_begin[6], matrix_begin[7]);
+            auto h4 = zip_hi(matrix_begin[8], matrix_begin[9]);
+            auto h5 = zip_hi(matrix_begin[10], matrix_begin[11]);
+            auto h6 = zip_hi(matrix_begin[12], matrix_begin[13]);
+            auto h7 = zip_hi(matrix_begin[14], matrix_begin[15]);
+
+            auto L0 = zip_lo(bit_cast<batch<uint16_t, A>>(l0), bit_cast<batch<uint16_t, A>>(l1));
+            auto L1 = zip_lo(bit_cast<batch<uint16_t, A>>(l2), bit_cast<batch<uint16_t, A>>(l3));
+            auto L2 = zip_lo(bit_cast<batch<uint16_t, A>>(l4), bit_cast<batch<uint16_t, A>>(l5));
+            auto L3 = zip_lo(bit_cast<batch<uint16_t, A>>(l6), bit_cast<batch<uint16_t, A>>(l7));
+
+            auto m0 = zip_lo(bit_cast<batch<uint32_t, A>>(L0), bit_cast<batch<uint32_t, A>>(L1));
+            auto m1 = zip_lo(bit_cast<batch<uint32_t, A>>(L2), bit_cast<batch<uint32_t, A>>(L3));
+            auto m2 = zip_hi(bit_cast<batch<uint32_t, A>>(L0), bit_cast<batch<uint32_t, A>>(L1));
+            auto m3 = zip_hi(bit_cast<batch<uint32_t, A>>(L2), bit_cast<batch<uint32_t, A>>(L3));
+
+            matrix_begin[0] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(m0), bit_cast<batch<uint64_t, A>>(m1)));
+            matrix_begin[1] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(m0), bit_cast<batch<uint64_t, A>>(m1)));
+            matrix_begin[2] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(m2), bit_cast<batch<uint64_t, A>>(m3)));
+            matrix_begin[3] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(m2), bit_cast<batch<uint64_t, A>>(m3)));
+
+            auto L4 = zip_hi(bit_cast<batch<uint16_t, A>>(l0), bit_cast<batch<uint16_t, A>>(l1));
+            auto L5 = zip_hi(bit_cast<batch<uint16_t, A>>(l2), bit_cast<batch<uint16_t, A>>(l3));
+            auto L6 = zip_hi(bit_cast<batch<uint16_t, A>>(l4), bit_cast<batch<uint16_t, A>>(l5));
+            auto L7 = zip_hi(bit_cast<batch<uint16_t, A>>(l6), bit_cast<batch<uint16_t, A>>(l7));
+
+            auto m4 = zip_lo(bit_cast<batch<uint32_t, A>>(L4), bit_cast<batch<uint32_t, A>>(L5));
+            auto m5 = zip_lo(bit_cast<batch<uint32_t, A>>(L6), bit_cast<batch<uint32_t, A>>(L7));
+            auto m6 = zip_hi(bit_cast<batch<uint32_t, A>>(L4), bit_cast<batch<uint32_t, A>>(L5));
+            auto m7 = zip_hi(bit_cast<batch<uint32_t, A>>(L6), bit_cast<batch<uint32_t, A>>(L7));
+
+            matrix_begin[4] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(m4), bit_cast<batch<uint64_t, A>>(m5)));
+            matrix_begin[5] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(m4), bit_cast<batch<uint64_t, A>>(m5)));
+            matrix_begin[6] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(m6), bit_cast<batch<uint64_t, A>>(m7)));
+            matrix_begin[7] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(m6), bit_cast<batch<uint64_t, A>>(m7)));
+
+            auto H0 = zip_lo(bit_cast<batch<uint16_t, A>>(h0), bit_cast<batch<uint16_t, A>>(h1));
+            auto H1 = zip_lo(bit_cast<batch<uint16_t, A>>(h2), bit_cast<batch<uint16_t, A>>(h3));
+            auto H2 = zip_lo(bit_cast<batch<uint16_t, A>>(h4), bit_cast<batch<uint16_t, A>>(h5));
+            auto H3 = zip_lo(bit_cast<batch<uint16_t, A>>(h6), bit_cast<batch<uint16_t, A>>(h7));
+
+            auto M0 = zip_lo(bit_cast<batch<uint32_t, A>>(H0), bit_cast<batch<uint32_t, A>>(H1));
+            auto M1 = zip_lo(bit_cast<batch<uint32_t, A>>(H2), bit_cast<batch<uint32_t, A>>(H3));
+            auto M2 = zip_hi(bit_cast<batch<uint32_t, A>>(H0), bit_cast<batch<uint32_t, A>>(H1));
+            auto M3 = zip_hi(bit_cast<batch<uint32_t, A>>(H2), bit_cast<batch<uint32_t, A>>(H3));
+
+            matrix_begin[8] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(M0), bit_cast<batch<uint64_t, A>>(M1)));
+            matrix_begin[9] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(M0), bit_cast<batch<uint64_t, A>>(M1)));
+            matrix_begin[10] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(M2), bit_cast<batch<uint64_t, A>>(M3)));
+            matrix_begin[11] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(M2), bit_cast<batch<uint64_t, A>>(M3)));
+
+            auto H4 = zip_hi(bit_cast<batch<uint16_t, A>>(h0), bit_cast<batch<uint16_t, A>>(h1));
+            auto H5 = zip_hi(bit_cast<batch<uint16_t, A>>(h2), bit_cast<batch<uint16_t, A>>(h3));
+            auto H6 = zip_hi(bit_cast<batch<uint16_t, A>>(h4), bit_cast<batch<uint16_t, A>>(h5));
+            auto H7 = zip_hi(bit_cast<batch<uint16_t, A>>(h6), bit_cast<batch<uint16_t, A>>(h7));
+
+            auto M4 = zip_lo(bit_cast<batch<uint32_t, A>>(H4), bit_cast<batch<uint32_t, A>>(H5));
+            auto M5 = zip_lo(bit_cast<batch<uint32_t, A>>(H6), bit_cast<batch<uint32_t, A>>(H7));
+            auto M6 = zip_hi(bit_cast<batch<uint32_t, A>>(H4), bit_cast<batch<uint32_t, A>>(H5));
+            auto M7 = zip_hi(bit_cast<batch<uint32_t, A>>(H6), bit_cast<batch<uint32_t, A>>(H7));
+
+            matrix_begin[12] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(M4), bit_cast<batch<uint64_t, A>>(M5)));
+            matrix_begin[13] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(M4), bit_cast<batch<uint64_t, A>>(M5)));
+            matrix_begin[14] = bit_cast<batch<uint8_t, A>>(zip_lo(bit_cast<batch<uint64_t, A>>(M6), bit_cast<batch<uint64_t, A>>(M7)));
+            matrix_begin[15] = bit_cast<batch<uint8_t, A>>(zip_hi(bit_cast<batch<uint64_t, A>>(M6), bit_cast<batch<uint64_t, A>>(M7)));
+        }
+
+        template <class A, class = typename std::enable_if<batch<uint16_t, A>::size == 8, void>::type>
+        XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<generic>) noexcept
+        {
+            transpose(reinterpret_cast<batch<uint8_t, A>*>(matrix_begin), reinterpret_cast<batch<uint8_t, A>*>(matrix_end), A {});
+        }
+
     }
 }
 
diff --git a/test/test_shuffle.cpp b/test/test_shuffle.cpp
index ed07095da..bcfc4aeb5 100644
--- a/test/test_shuffle.cpp
+++ b/test/test_shuffle.cpp
@@ -723,4 +723,13 @@ TEST_CASE_TEMPLATE("[shuffle]", B, BATCH_FLOAT_TYPES, xsimd::batch, xs
     }
 }
 
+TEST_CASE_TEMPLATE("[small integer transpose]", B, xsimd::batch<int8_t>, xsimd::batch<uint8_t>, xsimd::batch<int16_t>, xsimd::batch<uint16_t>)
+{
+    shuffle_test<B> Test;
+    SUBCASE("transpose")
+    {
+        Test.transpose();
+    }
+}
+
 #endif
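For reference, the kernels above are reached through the public xsimd::transpose(matrix_begin, matrix_end) entry point, which the shuffle_test fixture exercises in the new test case. Below is a minimal usage sketch, not part of the patch, assuming a 128-bit target (for example SSE2 or NEON) where xsimd::batch<uint8_t> holds 16 lanes so the 16x16 kernel applies; on wider targets the call is identical, but a different kernel or the generic fallback handles it. The data and rows names are purely illustrative.

#include <cstddef>
#include <cstdint>
#include <cstdio>

#include <xsimd/xsimd.hpp>

int main()
{
    using batch_t = xsimd::batch<std::uint8_t>;
    constexpr std::size_t n = batch_t::size; // 16 on a 128-bit target

    // n x n matrix stored row-major, entry (i, j) = i * n + j (modulo 256).
    std::uint8_t data[batch_t::size * batch_t::size];
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            data[i * n + j] = static_cast<std::uint8_t>(i * n + j);

    // One batch per row; transpose permutes the [begin, end) range in place.
    batch_t rows[batch_t::size];
    for (std::size_t i = 0; i < n; ++i)
        rows[i] = batch_t::load_unaligned(data + i * n);

    xsimd::transpose(&rows[0], &rows[0] + n);

    for (std::size_t i = 0; i < n; ++i)
        rows[i].store_unaligned(data + i * n);

    // Entry (0, 1) now holds the original entry (1, 0), i.e. the value n.
    std::printf("data[1] == %u\n", static_cast<unsigned>(data[1]));
    return 0;
}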