Skip to content

Commit

Permalink
Implement [u]int8 and [u]int16 matrix transpose for 128 bit registers
Browse files Browse the repository at this point in the history
Related to #1054
  • Loading branch information
serge-sans-paille committed Dec 27, 2024
1 parent 5a9fae3 commit 1720471
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 1 deletion.
137 changes: 136 additions & 1 deletion include/xsimd/arch/generic/xsimd_generic_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ namespace xsimd
hi.store_aligned(buffer + real_batch::size);
}

// store_compelx_unaligned
// store_complex_unaligned
template <class A, class T_out, class T_in>
XSIMD_INLINE void store_complex_unaligned(std::complex<T_out>* dst, batch<std::complex<T_in>, A> const& src, requires_arch<generic>) noexcept
{
Expand Down Expand Up @@ -665,6 +665,141 @@ namespace xsimd
}
}

// transpose
template <class A, class = typename std::enable_if<batch<int16_t, A>::size == 8, void>::type>
XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<generic>) noexcept
{
assert((matrix_end - matrix_begin == batch<int16_t, A>::size) && "correctly sized matrix");
(void)matrix_end;
auto l0 = zip_lo(matrix_begin[0], matrix_begin[1]);
auto l1 = zip_lo(matrix_begin[2], matrix_begin[3]);
auto l2 = zip_lo(matrix_begin[4], matrix_begin[5]);
auto l3 = zip_lo(matrix_begin[6], matrix_begin[7]);

auto l4 = zip_lo(bit_cast<batch<int32_t, A>>(l0), bit_cast<batch<int32_t, A>>(l1));
auto l5 = zip_lo(bit_cast<batch<int32_t, A>>(l2), bit_cast<batch<int32_t, A>>(l3));

auto l6 = zip_hi(bit_cast<batch<int32_t, A>>(l0), bit_cast<batch<int32_t, A>>(l1));
auto l7 = zip_hi(bit_cast<batch<int32_t, A>>(l2), bit_cast<batch<int32_t, A>>(l3));

auto h0 = zip_hi(matrix_begin[0], matrix_begin[1]);
auto h1 = zip_hi(matrix_begin[2], matrix_begin[3]);
auto h2 = zip_hi(matrix_begin[4], matrix_begin[5]);
auto h3 = zip_hi(matrix_begin[6], matrix_begin[7]);

auto h4 = zip_lo(bit_cast<batch<int32_t, A>>(h0), bit_cast<batch<int32_t, A>>(h1));
auto h5 = zip_lo(bit_cast<batch<int32_t, A>>(h2), bit_cast<batch<int32_t, A>>(h3));

auto h6 = zip_hi(bit_cast<batch<int32_t, A>>(h0), bit_cast<batch<int32_t, A>>(h1));
auto h7 = zip_hi(bit_cast<batch<int32_t, A>>(h2), bit_cast<batch<int32_t, A>>(h3));

matrix_begin[0] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(l4), bit_cast<batch<int64_t, A>>(l5)));
matrix_begin[1] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(l4), bit_cast<batch<int64_t, A>>(l5)));
matrix_begin[2] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(l6), bit_cast<batch<int64_t, A>>(l7)));
matrix_begin[3] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(l6), bit_cast<batch<int64_t, A>>(l7)));

matrix_begin[4] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(h4), bit_cast<batch<int64_t, A>>(h5)));
matrix_begin[5] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(h4), bit_cast<batch<int64_t, A>>(h5)));
matrix_begin[6] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(h6), bit_cast<batch<int64_t, A>>(h7)));
matrix_begin[7] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(h6), bit_cast<batch<int64_t, A>>(h7)));
}

template <class A, class = typename std::enable_if<batch<uint16_t, A>::size == 8, void>::type>
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<generic>) noexcept
{
transpose(reinterpret_cast<batch<int16_t, A>*>(matrix_begin), reinterpret_cast<batch<int16_t, A>*>(matrix_end), A {});
}

template <class A, class = typename std::enable_if<batch<int8_t, A>::size == 8, void>::type>
XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<generic>) noexcept
{
assert((matrix_end - matrix_begin == batch<int8_t, A>::size) && "correctly sized matrix");
(void)matrix_end;
auto l0 = zip_lo(matrix_begin[0], matrix_begin[1]);
auto l1 = zip_lo(matrix_begin[2], matrix_begin[3]);
auto l2 = zip_lo(matrix_begin[4], matrix_begin[5]);
auto l3 = zip_lo(matrix_begin[6], matrix_begin[7]);
auto l4 = zip_lo(matrix_begin[8], matrix_begin[9]);
auto l5 = zip_lo(matrix_begin[10], matrix_begin[11]);
auto l6 = zip_lo(matrix_begin[12], matrix_begin[13]);
auto l7 = zip_lo(matrix_begin[14], matrix_begin[15]);

auto L0 = zip_lo(bit_cast<batch<int16_t, A>>(l0), bit_cast<batch<int16_t, A>>(l1));
auto L1 = zip_lo(bit_cast<batch<int16_t, A>>(l2), bit_cast<batch<int16_t, A>>(l3));
auto L2 = zip_lo(bit_cast<batch<int16_t, A>>(l4), bit_cast<batch<int16_t, A>>(l5));
auto L3 = zip_lo(bit_cast<batch<int16_t, A>>(l6), bit_cast<batch<int16_t, A>>(l7));

auto m0 = zip_lo(bit_cast<batch<int32_t, A>>(L0), bit_cast<batch<int32_t, A>>(L1));
auto m1 = zip_lo(bit_cast<batch<int32_t, A>>(L2), bit_cast<batch<int32_t, A>>(L3));
auto m2 = zip_hi(bit_cast<batch<int32_t, A>>(L0), bit_cast<batch<int32_t, A>>(L1));
auto m3 = zip_hi(bit_cast<batch<int32_t, A>>(L2), bit_cast<batch<int32_t, A>>(L3));

matrix_begin[0] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m0), bit_cast<batch<int64_t, A>>(m1)));
matrix_begin[1] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m0), bit_cast<batch<int64_t, A>>(m1)));
matrix_begin[2] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m2), bit_cast<batch<int64_t, A>>(m3)));
matrix_begin[3] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m2), bit_cast<batch<int64_t, A>>(m3)));

auto L4 = zip_hi(bit_cast<batch<int16_t, A>>(l0), bit_cast<batch<int16_t, A>>(l1));
auto L5 = zip_hi(bit_cast<batch<int16_t, A>>(l2), bit_cast<batch<int16_t, A>>(l3));
auto L6 = zip_hi(bit_cast<batch<int16_t, A>>(l4), bit_cast<batch<int16_t, A>>(l5));
auto L7 = zip_hi(bit_cast<batch<int16_t, A>>(l6), bit_cast<batch<int16_t, A>>(l7));

auto m4 = zip_lo(bit_cast<batch<int32_t, A>>(L4), bit_cast<batch<int32_t, A>>(L5));
auto m5 = zip_lo(bit_cast<batch<int32_t, A>>(L6), bit_cast<batch<int32_t, A>>(L7));
auto m6 = zip_hi(bit_cast<batch<int32_t, A>>(L4), bit_cast<batch<int32_t, A>>(L5));
auto m7 = zip_hi(bit_cast<batch<int32_t, A>>(L6), bit_cast<batch<int32_t, A>>(L7));

matrix_begin[4] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m4), bit_cast<batch<int64_t, A>>(m5)));
matrix_begin[5] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m4), bit_cast<batch<int64_t, A>>(m5)));
matrix_begin[6] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m6), bit_cast<batch<int64_t, A>>(m7)));
matrix_begin[7] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m6), bit_cast<batch<int64_t, A>>(m7)));

auto h0 = zip_hi(matrix_begin[0], matrix_begin[1]);
auto h1 = zip_hi(matrix_begin[2], matrix_begin[3]);
auto h2 = zip_hi(matrix_begin[4], matrix_begin[5]);
auto h3 = zip_hi(matrix_begin[6], matrix_begin[7]);
auto h4 = zip_hi(matrix_begin[8], matrix_begin[9]);
auto h5 = zip_hi(matrix_begin[10], matrix_begin[11]);
auto h6 = zip_hi(matrix_begin[12], matrix_begin[13]);
auto h7 = zip_hi(matrix_begin[14], matrix_begin[15]);

auto H0 = zip_lo(bit_cast<batch<int16_t, A>>(h0), bit_cast<batch<int16_t, A>>(h1));
auto H1 = zip_lo(bit_cast<batch<int16_t, A>>(h2), bit_cast<batch<int16_t, A>>(h3));
auto H2 = zip_lo(bit_cast<batch<int16_t, A>>(h4), bit_cast<batch<int16_t, A>>(h5));
auto H3 = zip_lo(bit_cast<batch<int16_t, A>>(h6), bit_cast<batch<int16_t, A>>(h7));

auto M0 = zip_lo(bit_cast<batch<int32_t, A>>(H0), bit_cast<batch<int32_t, A>>(H1));
auto M1 = zip_lo(bit_cast<batch<int32_t, A>>(H2), bit_cast<batch<int32_t, A>>(H3));
auto M2 = zip_hi(bit_cast<batch<int32_t, A>>(H0), bit_cast<batch<int32_t, A>>(H1));
auto M3 = zip_hi(bit_cast<batch<int32_t, A>>(H2), bit_cast<batch<int32_t, A>>(H3));

matrix_begin[8] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M0), bit_cast<batch<int64_t, A>>(M1)));
matrix_begin[9] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M0), bit_cast<batch<int64_t, A>>(M1)));
matrix_begin[10] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M2), bit_cast<batch<int64_t, A>>(M3)));
matrix_begin[11] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M2), bit_cast<batch<int64_t, A>>(M3)));

auto H4 = zip_hi(bit_cast<batch<int16_t, A>>(h0), bit_cast<batch<int16_t, A>>(h1));
auto H5 = zip_hi(bit_cast<batch<int16_t, A>>(h2), bit_cast<batch<int16_t, A>>(h3));
auto H6 = zip_hi(bit_cast<batch<int16_t, A>>(h4), bit_cast<batch<int16_t, A>>(h5));
auto H7 = zip_hi(bit_cast<batch<int16_t, A>>(h6), bit_cast<batch<int16_t, A>>(h7));

auto M4 = zip_lo(bit_cast<batch<int32_t, A>>(H4), bit_cast<batch<int32_t, A>>(H5));
auto M5 = zip_lo(bit_cast<batch<int32_t, A>>(H6), bit_cast<batch<int32_t, A>>(H7));
auto M6 = zip_hi(bit_cast<batch<int32_t, A>>(H4), bit_cast<batch<int32_t, A>>(H5));
auto M7 = zip_hi(bit_cast<batch<int32_t, A>>(H6), bit_cast<batch<int32_t, A>>(H7));

matrix_begin[12] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M4), bit_cast<batch<int64_t, A>>(M5)));
matrix_begin[13] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M4), bit_cast<batch<int64_t, A>>(M5)));
matrix_begin[14] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M6), bit_cast<batch<int64_t, A>>(M7)));
matrix_begin[15] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M6), bit_cast<batch<int64_t, A>>(M7)));
}

template <class A, class = typename std::enable_if<batch<uint8_t, A>::size == 8, void>::type>
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<generic>) noexcept
{
transpose(reinterpret_cast<batch<int8_t, A>*>(matrix_begin), reinterpret_cast<batch<int8_t, A>*>(matrix_end), A {});
}

}

}
Expand Down
9 changes: 9 additions & 0 deletions test/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,13 @@ TEST_CASE_TEMPLATE("[shuffle]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, xs
}
}

TEST_CASE_TEMPLATE("[small integer transpose]", B, xsimd::batch<uint16_t>, xsimd::batch<int16_t>, xsimd::batch<uint8_t>, xsimd::batch<int8_t>)
{
shuffle_test<B> Test;
SUBCASE("transpose")
{
Test.transpose();
}
}

#endif

0 comments on commit 1720471

Please sign in to comment.