diff --git a/include/xsimd/arch/xsimd_avx512f.hpp b/include/xsimd/arch/xsimd_avx512f.hpp
index c2b485a30..cc1394578 100644
--- a/include/xsimd/arch/xsimd_avx512f.hpp
+++ b/include/xsimd/arch/xsimd_avx512f.hpp
@@ -25,6 +25,12 @@ namespace xsimd
     {
         using namespace types;
 
+        // fwd
+        template <class A>
+        XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<avx2>) noexcept;
+        template <class A>
+        XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<avx2>) noexcept;
+
         namespace detail
         {
             XSIMD_INLINE void split_avx512(__m512 val, __m256& low, __m256& high) noexcept
@@ -2010,6 +2016,79 @@ namespace xsimd
             return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, avx512f {}));
         }
 
+        // transpose
+        template <class A>
+        XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<avx512f>) noexcept
+        {
+            assert((matrix_end - matrix_begin == batch<uint16_t, A>::size) && "correctly sized matrix");
+            (void)matrix_end;
+            batch<uint16_t, avx2> tmp_lo0[16];
+            for (int i = 0; i < 16; ++i)
+                tmp_lo0[i] = _mm512_castsi512_si256(matrix_begin[i]);
+            transpose(tmp_lo0 + 0, tmp_lo0 + 16, avx2 {});
+
+            batch<uint16_t, avx2> tmp_hi0[16];
+            for (int i = 0; i < 16; ++i)
+                tmp_hi0[i] = _mm512_castsi512_si256(matrix_begin[16 + i]);
+            transpose(tmp_hi0 + 0, tmp_hi0 + 16, avx2 {});
+
+            batch<uint16_t, avx2> tmp_lo1[16];
+            for (int i = 0; i < 16; ++i)
+                tmp_lo1[i] = _mm512_extracti64x4_epi64(matrix_begin[i], 1);
+            transpose(tmp_lo1 + 0, tmp_lo1 + 16, avx2 {});
+
+            batch<uint16_t, avx2> tmp_hi1[16];
+            for (int i = 0; i < 16; ++i)
+                tmp_hi1[i] = _mm512_extracti64x4_epi64(matrix_begin[16 + i], 1);
+            transpose(tmp_hi1 + 0, tmp_hi1 + 16, avx2 {});
+
+            for (int i = 0; i < 16; ++i)
+                matrix_begin[i] = detail::merge_avx(tmp_lo0[i], tmp_hi0[i]);
+            for (int i = 0; i < 16; ++i)
+                matrix_begin[i + 16] = detail::merge_avx(tmp_lo1[i], tmp_hi1[i]);
+        }
+        template <class A>
+        XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<avx512f>) noexcept
+        {
+            return transpose(reinterpret_cast<batch<uint16_t, A>*>(matrix_begin), reinterpret_cast<batch<uint16_t, A>*>(matrix_end), A {});
+        }
+
+        template <class A>
+        XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<avx512f>) noexcept
+        {
+            assert((matrix_end - matrix_begin == batch<uint8_t, A>::size) && "correctly sized matrix");
+            (void)matrix_end;
+            batch<uint8_t, avx2> tmp_lo0[32];
+            for (int i = 0; i < 32; ++i)
+                tmp_lo0[i] = _mm512_castsi512_si256(matrix_begin[i]);
+            transpose(tmp_lo0 + 0, tmp_lo0 + 32, avx2 {});
+
+            batch<uint8_t, avx2> tmp_hi0[32];
+            for (int i = 0; i < 32; ++i)
+                tmp_hi0[i] = _mm512_castsi512_si256(matrix_begin[32 + i]);
+            transpose(tmp_hi0 + 0, tmp_hi0 + 32, avx2 {});
+
+            batch<uint8_t, avx2> tmp_lo1[32];
+            for (int i = 0; i < 32; ++i)
+                tmp_lo1[i] = _mm512_extracti64x4_epi64(matrix_begin[i], 1);
+            transpose(tmp_lo1 + 0, tmp_lo1 + 32, avx2 {});
+
+            batch<uint8_t, avx2> tmp_hi1[32];
+            for (int i = 0; i < 32; ++i)
+                tmp_hi1[i] = _mm512_extracti64x4_epi64(matrix_begin[32 + i], 1);
+            transpose(tmp_hi1 + 0, tmp_hi1 + 32, avx2 {});
+
+            for (int i = 0; i < 32; ++i)
+                matrix_begin[i] = detail::merge_avx(tmp_lo0[i], tmp_hi0[i]);
+            for (int i = 0; i < 32; ++i)
+                matrix_begin[i + 32] = detail::merge_avx(tmp_lo1[i], tmp_hi1[i]);
+        }
+        template <class A>
+        XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<avx512f>) noexcept
+        {
+            return transpose(reinterpret_cast<batch<uint8_t, A>*>(matrix_begin), reinterpret_cast<batch<uint8_t, A>*>(matrix_end), A {});
+        }
+
         // trunc
         template <class A>
         XSIMD_INLINE batch
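
Usage sketch (not part of the diff): a minimal example of how these kernels would be exercised, assuming an AVX512F build where batch<uint16_t> has 32 lanes and assuming the public xsimd::transpose(batch*, batch*) entry point dispatches to the per-arch kernels above; transpose_32x32 is a hypothetical helper, not part of this change.

    #include <cstddef>
    #include <cstdint>
    #include <xsimd/xsimd.hpp>

    // In-place transpose of a 32x32 row-major uint16_t matrix.
    // Each row is loaded into one batch<uint16_t>; xsimd::transpose is then
    // expected to forward to the avx512f kernel added in this diff.
    void transpose_32x32(uint16_t* data)
    {
        using batch_type = xsimd::batch<uint16_t>; // 32 lanes under avx512f (assumption)
        constexpr std::size_t n = batch_type::size;
        batch_type rows[n];
        for (std::size_t i = 0; i < n; ++i)
            rows[i] = batch_type::load_unaligned(data + i * n);
        xsimd::transpose(rows, rows + n);
        for (std::size_t i = 0; i < n; ++i)
            rows[i].store_unaligned(data + i * n);
    }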