Skip to content

Commit

Permalink
Add avx512 support for tranpose operator
Browse files Browse the repository at this point in the history
  • Loading branch information
serge-sans-paille committed Dec 27, 2024
1 parent bbe95c1 commit 9bf4781
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions include/xsimd/arch/xsimd_avx512f.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ namespace xsimd
{
using namespace types;

// fwd
template <class A>
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<generic>) noexcept;
template <class A>
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<generic>) noexcept;

namespace detail
{
XSIMD_INLINE void split_avx512(__m512 val, __m256& low, __m256& high) noexcept
Expand Down Expand Up @@ -2010,6 +2016,79 @@ namespace xsimd
return bitwise_cast<int16_t>(swizzle(bitwise_cast<uint16_t>(self), mask, avx512f {}));
}

// transpose
template <class A>
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<avx512f>) noexcept
{
assert((matrix_end - matrix_begin == batch<uint16_t, A>::size) && "correctly sized matrix");
(void)matrix_end;
batch<uint16_t, avx2> tmp_lo0[16];
for (int i = 0; i < 16; ++i)
tmp_lo0[i] = _mm512_castsi512_si256(matrix_begin[i]);
transpose(tmp_lo0 + 0, tmp_lo0 + 16, avx2 {});

batch<uint16_t, avx2> tmp_hi0[16];
for (int i = 0; i < 16; ++i)
tmp_hi0[i] = _mm512_castsi512_si256(matrix_begin[16 + i]);
transpose(tmp_hi0 + 0, tmp_hi0 + 16, avx2 {});

batch<uint16_t, avx2> tmp_lo1[16];
for (int i = 0; i < 16; ++i)
tmp_lo1[i] = _mm512_extracti64x4_epi64(matrix_begin[i], 1);
transpose(tmp_lo1 + 0, tmp_lo1 + 16, avx2 {});

batch<uint16_t, avx2> tmp_hi1[16];
for (int i = 0; i < 16; ++i)
tmp_hi1[i] = _mm512_extracti64x4_epi64(matrix_begin[16 + i], 1);
transpose(tmp_hi1 + 0, tmp_hi1 + 16, avx2 {});

for (int i = 0; i < 16; ++i)
matrix_begin[i] = detail::merge_avx(tmp_lo0[i], tmp_hi0[i]);
for (int i = 0; i < 16; ++i)
matrix_begin[i + 16] = detail::merge_avx(tmp_lo1[i], tmp_hi1[i]);
}
template <class A>
XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<avx512f>) noexcept
{
return transpose(reinterpret_cast<batch<uint16_t, A>*>(matrix_begin), reinterpret_cast<batch<uint16_t, A>*>(matrix_end), A {});
}

template <class A>
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<avx512F>) noexcept
{
assert((matrix_end - matrix_begin == batch<uint8_t, A>::size) && "correctly sized matrix");
(void)matrix_end;
batch<uint8_t, avx2> tmp_lo0[32];
for (int i = 0; i < 32; ++i)
tmp_lo0[i] = _mm512_castsi512_si256(matrix_begin[i]);
transpose(tmp_lo0 + 0, tmp_lo0 + 32, avx2 {});

batch<uint8_t, avx2> tmp_hi0[32];
for (int i = 0; i < 32; ++i)
tmp_hi0[i] = _mm512_castsi512_si256(matrix_begin[32 + i]);
transpose(tmp_hi0 + 0, tmp_hi0 + 32, avx2 {});

batch<uint8_t, avx2> tmp_lo1[32];
for (int i = 0; i < 32; ++i)
tmp_lo1[i] = _mm512_extracti64x4_epi64(matrix_begin[i], 1);
transpose(tmp_lo1 + 0, tmp_lo1 + 32, avx2 {});

batch<uint8_t, avx2> tmp_hi1[32];
for (int i = 0; i < 32; ++i)
tmp_hi1[i] = _mm512_extracti64x4_epi64(matrix_begin[32 + i], 1);
transpose(tmp_hi1 + 0, tmp_hi1 + 32, avx2 {});

for (int i = 0; i < 32; ++i)
matrix_begin[i] = detail::merge_avx(tmp_lo0[i], tmp_hi0[i]);
for (int i = 0; i < 32; ++i)
matrix_begin[i + 32] = detail::merge_avx(tmp_lo1[i], tmp_hi1[i]);
}
template <class A>
XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<avx512f>) noexcept
{
return transpose(reinterpret_cast<batch<uint8_t, A>*>(matrix_begin), reinterpret_cast<batch<uint8_t, A>*>(matrix_end), A {});
}

// trunc
template <class A>
XSIMD_INLINE batch<float, A>
Expand Down

0 comments on commit 9bf4781

Please sign in to comment.