From b6827c125ee00482bc81a5739363cdd1030c5d41 Mon Sep 17 00:00:00 2001
From: serge-sans-paille
Date: Tue, 21 Nov 2023 23:56:14 +0100
Subject: [PATCH] Generic, simple implementation for xsimd::expand

Also provide a specialization for avx512f.

Related to #975
---
 .../arch/generic/xsimd_generic_memory.hpp | 25 +++++
 include/xsimd/arch/xsimd_avx512f.hpp      | 32 +++++++
 include/xsimd/types/xsimd_api.hpp         | 13 +++
 test/test_shuffle.cpp                     | 96 +++++++++++++++++++
 4 files changed, 166 insertions(+)

diff --git a/include/xsimd/arch/generic/xsimd_generic_memory.hpp b/include/xsimd/arch/generic/xsimd_generic_memory.hpp
index ee65b0d35..33daffe6d 100644
--- a/include/xsimd/arch/generic/xsimd_generic_memory.hpp
+++ b/include/xsimd/arch/generic/xsimd_generic_memory.hpp
@@ -62,6 +62,31 @@ namespace xsimd
             return swizzle(z, compress_mask);
         }
 
+        // expand
+        namespace detail
+        {
+            template <class IT, class A, class I, size_t... Is>
+            batch<IT, A> create_expand_swizzle_mask(I bitmask, ::xsimd::detail::index_sequence<Is...>)
+            {
+                batch<IT, A> swizzle_mask(IT(0));
+                IT j = 0;
+                (void)std::initializer_list<bool> { ((swizzle_mask = insert(swizzle_mask, j, index<Is>())), (j += ((bitmask >> Is) & 1u)), true)... };
+                return swizzle_mask;
+            }
+        }
+
+        template <class A, class T>
+        inline batch<T, A>
+        expand(batch<T, A> const& x, batch_bool<T, A> const& mask,
+               kernel::requires_arch<generic>) noexcept
+        {
+            constexpr std::size_t size = batch_bool<T, A>::size;
+            auto bitmask = mask.mask();
+            auto swizzle_mask = detail::create_expand_swizzle_mask<as_unsigned_integer_t<T>, A>(bitmask, ::xsimd::detail::make_index_sequence<size>());
+            auto z = swizzle(x, swizzle_mask);
+            return select(mask, z, batch<T, A>(T(0)));
+        }
+
         // extract_pair
         template <class A, class T>
         inline batch<T, A> extract_pair(batch<T, A> const& self, batch<T, A> const& other, std::size_t i, requires_arch<generic>) noexcept
diff --git a/include/xsimd/arch/xsimd_avx512f.hpp b/include/xsimd/arch/xsimd_avx512f.hpp
index b25280cb3..1243f7516 100644
--- a/include/xsimd/arch/xsimd_avx512f.hpp
+++ b/include/xsimd/arch/xsimd_avx512f.hpp
@@ -788,6 +788,38 @@ namespace xsimd
             return register_type(~self.data ^ other.data);
         }
 
+        // expand
+        template <class A>
+        inline batch<float, A> expand(batch<float, A> const& self, batch_bool<float, A> const& mask, requires_arch<avx512f>) noexcept
+        {
+            return _mm512_maskz_expand_ps(mask.mask(), self);
+        }
+        template <class A>
+        inline batch<double, A> expand(batch<double, A> const& self, batch_bool<double, A> const& mask, requires_arch<avx512f>) noexcept
+        {
+            return _mm512_maskz_expand_pd(mask.mask(), self);
+        }
+        template <class A>
+        inline batch<int32_t, A> expand(batch<int32_t, A> const& self, batch_bool<int32_t, A> const& mask, requires_arch<avx512f>) noexcept
+        {
+            return _mm512_maskz_expand_epi32(mask.mask(), self);
+        }
+        template <class A>
+        inline batch<uint32_t, A> expand(batch<uint32_t, A> const& self, batch_bool<uint32_t, A> const& mask, requires_arch<avx512f>) noexcept
+        {
+            return _mm512_maskz_expand_epi32(mask.mask(), self);
+        }
+        template <class A>
+        inline batch<int64_t, A> expand(batch<int64_t, A> const& self, batch_bool<int64_t, A> const& mask, requires_arch<avx512f>) noexcept
+        {
+            return _mm512_maskz_expand_epi64(mask.mask(), self);
+        }
+        template <class A>
+        inline batch<uint64_t, A> expand(batch<uint64_t, A> const& self, batch_bool<uint64_t, A> const& mask, requires_arch<avx512f>) noexcept
+        {
+            return _mm512_maskz_expand_epi64(mask.mask(), self);
+        }
+
         // floor
         template <class A>
         inline batch<float, A> floor(batch<float, A> const& self, requires_arch<avx512f>) noexcept
diff --git a/include/xsimd/types/xsimd_api.hpp b/include/xsimd/types/xsimd_api.hpp
index b4a84a66d..0420f0a09 100644
--- a/include/xsimd/types/xsimd_api.hpp
+++ b/include/xsimd/types/xsimd_api.hpp
@@ -718,6 +718,13 @@ namespace xsimd
         return kernel::exp2<A>(x, A {});
     }
 
+    /**
+     * @ingroup batch_data_transfer
+     *
+     * Load contiguous elements from \c x and place them in the slots selected
+     * by \c mask, zeroing the other slots.
+     */
+    template <class T, class A>
+    inline batch<T, A> expand(batch<T, A> const& x, batch_bool<T, A> const& mask) noexcept
+    {
+        detail::static_check_supported_config<T, A>();
+        return kernel::expand<A>(x, mask, A {});
+    }
+
     /**
      * @ingroup batch_math
      *
diff --git a/test/test_shuffle.cpp b/test/test_shuffle.cpp
index beafeb1f9..d7e8458fa 100644
--- a/test/test_shuffle.cpp
+++ b/test/test_shuffle.cpp
@@ -368,6 +368,102 @@ TEST_CASE_TEMPLATE("[compress]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, x
     // }
 }
 
+template <class B>
+struct expand_test
+{
+    using batch_type = B;
+    using value_type = typename B::value_type;
+    using mask_batch_type = typename B::batch_bool_type;
+
+    static constexpr size_t size = B::size;
+    std::array<value_type, size> input;
+    std::array<bool, size> mask;
+    std::array<value_type, size> expected;
+
+    expand_test()
+    {
+        for (size_t i = 0; i < size; ++i)
+        {
+            input[i] = i;
+        }
+    }
+
+    void full()
+    {
+        std::fill(mask.begin(), mask.end(), true);
+
+        for (size_t i = 0; i < size; ++i)
+            expected[i] = input[i];
+
+        auto b = xsimd::expand(
+            batch_type::load_unaligned(input.data()),
+            mask_batch_type::load_unaligned(mask.data()));
+        CHECK_BATCH_EQ(b, expected);
+    }
+
+    void empty()
+    {
+        std::fill(mask.begin(), mask.end(), false);
+
+        for (size_t i = 0; i < size; ++i)
+            expected[i] = 0;
+
+        auto b = xsimd::expand(
+            batch_type::load_unaligned(input.data()),
+            mask_batch_type::load_unaligned(mask.data()));
+        CHECK_BATCH_EQ(b, expected);
+    }
+
+    void interleave()
+    {
+        for (size_t i = 0; i < size; ++i)
+            mask[i] = i % 2 == 0;
+
+        for (size_t i = 0, j = 0; i < size; ++i)
+            expected[i] = mask[i] ? input[j++] : 0;
+
+        auto b = xsimd::expand(
+            batch_type::load_unaligned(input.data()),
+            mask_batch_type::load_unaligned(mask.data()));
+        CHECK_BATCH_EQ(b, expected);
+    }
+
+    void generic()
+    {
+        for (size_t i = 0; i < size; ++i)
+            mask[i] = i % 3 == 0;
+
+        for (size_t i = 0, j = 0; i < size; ++i)
+            expected[i] = mask[i] ? input[j++] : 0;
+
+        auto b = xsimd::expand(
+            batch_type::load_unaligned(input.data()),
+            mask_batch_type::load_unaligned(mask.data()));
+        CHECK_BATCH_EQ(b, expected);
+    }
+};
+
+TEST_CASE_TEMPLATE("[expand]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, xsimd::batch<int32_t>, xsimd::batch<uint64_t>, xsimd::batch<int64_t>)
+{
+    expand_test<B> Test;
+    SUBCASE("empty")
+    {
+        Test.empty();
+    }
+    SUBCASE("full")
+    {
+        Test.full();
+    }
+    SUBCASE("interleave")
+    {
+        Test.interleave();
+    }
+    SUBCASE("generic")
+    {
+        Test.generic();
+    }
+}
+
 template <class B>
 struct shuffle_test
 {
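
For reference, a minimal sketch of how the new xsimd::expand entry point could be exercised. This is illustrative only, not part of the patch; it assumes xsimd is on the include path and pins the arch to sse2 so the batch has exactly four lanes (the arch and lane count are assumptions for the example):

    // Illustrative sketch, not part of the patch. Assumes an SSE2-capable
    // x86 target, so batch<int32_t, sse2> has four lanes.
    #include <cstdint>
    #include <iostream>
    #include <xsimd/xsimd.hpp>

    int main()
    {
        using batch = xsimd::batch<int32_t, xsimd::sse2>;
        alignas(16) int32_t data[] = { 10, 20, 30, 40 };
        auto x = batch::load_aligned(data);
        // Lanes 0 and 2 are selected: the first two contiguous elements of x
        // (10 and 20) are placed there; the unselected lanes are zeroed.
        xsimd::batch_bool<int32_t, xsimd::sse2> mask(true, false, true, false);
        auto r = xsimd::expand(x, mask);
        r.store_aligned(data);
        std::cout << data[0] << ' ' << data[1] << ' '
                  << data[2] << ' ' << data[3] << '\n'; // prints: 10 0 20 0
    }

On avx512f builds the call lowers to a single _mm512_maskz_expand_* instruction; everywhere else it goes through the generic swizzle-and-select path added in xsimd_generic_memory.hpp.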