Skip to content

Commit

Permalink
Generic, simple implementation fox xsimd::expand
Browse files Browse the repository at this point in the history
Also provide a specialization for avx512f.

Related to #975
  • Loading branch information
serge-sans-paille committed Nov 28, 2023
1 parent 8d3c67c commit b6827c1
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/xsimd/arch/generic/xsimd_generic_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,31 @@ namespace xsimd
return swizzle(z, compress_mask);
}

// expand
namespace detail
{
template <class IT, class A, class I, size_t... Is>
batch<IT, A> create_expand_swizzle_mask(I bitmask, ::xsimd::detail::index_sequence<Is...>)
{
batch<IT, A> swizzle_mask(IT(0));
IT j = 0;
(void)std::initializer_list<bool> { ((swizzle_mask = insert(swizzle_mask, j, index<Is>())), (j += ((bitmask >> Is) & 1u)), true)... };
return swizzle_mask;
}
}

template <typename A, typename T>
inline batch<T, A>
expand(batch<T, A> const& x, batch_bool<T, A> const& mask,
kernel::requires_arch<generic>) noexcept
{
constexpr std::size_t size = batch_bool<T, A>::size;
auto bitmask = mask.mask();
auto swizzle_mask = detail::create_expand_swizzle_mask<as_unsigned_integer_t<T>, A>(bitmask, ::xsimd::detail::make_index_sequence<size>());
auto z = swizzle(x, swizzle_mask);
return select(mask, z, batch<T, A>(T(0)));
}

// extract_pair
template <class A, class T>
inline batch<T, A> extract_pair(batch<T, A> const& self, batch<T, A> const& other, std::size_t i, requires_arch<generic>) noexcept
Expand Down
32 changes: 32 additions & 0 deletions include/xsimd/arch/xsimd_avx512f.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,38 @@ namespace xsimd
return register_type(~self.data ^ other.data);
}

// expand
template <class A>
inline batch<float, A> expand(batch<float, A> const& self, batch_bool<float, A> const& mask, requires_arch<avx512f>) noexcept
{
return _mm512_maskz_expand_ps(self, mask.mask());
}
template <class A>
inline batch<double, A> expand(batch<double, A> const& self, batch_bool<double, A> const& mask, requires_arch<avx512f>) noexcept
{
return _mm512_maskz_expand_pd(self, mask.mask());
}
template <class A>
inline batch<int32_t, A> expand(batch<int32_t, A> const& self, batch_bool<int32_t, A> const& mask, requires_arch<avx512f>) noexcept
{
return _mm512_maskz_expand_epi32(self, mask.mask());
}
template <class A>
inline batch<uint32_t, A> expand(batch<uint32_t, A> const& self, batch_bool<uint32_t, A> const& mask, requires_arch<avx512f>) noexcept
{
return _mm512_maskz_expand_epi32(self, mask.mask());
}
template <class A>
inline batch<int64_t, A> expand(batch<int64_t, A> const& self, batch_bool<int64_t, A> const& mask, requires_arch<avx512f>) noexcept
{
return _mm512_maskz_expand_epi64(self, mask.mask());
}
template <class A>
inline batch<uint64_t, A> expand(batch<uint64_t, A> const& self, batch_bool<uint64_t, A> const& mask, requires_arch<avx512f>) noexcept
{
return _mm512_maskz_expand_epi64(self, mask.mask());
}

// floor
template <class A>
inline batch<float, A> floor(batch<float, A> const& self, requires_arch<avx512f>) noexcept
Expand Down
13 changes: 13 additions & 0 deletions include/xsimd/types/xsimd_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,19 @@ namespace xsimd
return kernel::exp2<A>(x, A {});
}

/**
* @ingroup batch_data_transfer
*
* Load contiguous elements from \c x and place them in slots selected by \c
* mask, zeroing the other slots
*/
template <class T, class A>
inline batch<T, A> expand(batch<T, A> const& x, batch_bool<T, A> const& mask) noexcept
{
detail::static_check_supported_config<T, A>();
return kernel::expand<A>(x, mask, A {});
}

/**
* @ingroup batch_math
*
Expand Down
96 changes: 96 additions & 0 deletions test/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,102 @@ TEST_CASE_TEMPLATE("[compress]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, x
// }
}

template <class B>
struct expand_test
{
using batch_type = B;
using value_type = typename B::value_type;
using mask_batch_type = typename B::batch_bool_type;

static constexpr size_t size = B::size;
std::array<value_type, size> input;
std::array<bool, size> mask;
std::array<value_type, size> expected;

expand_test()
{
for (size_t i = 0; i < size; ++i)
{
input[i] = i;
}
}

void full()
{
std::fill(mask.begin(), mask.end(), true);

for (size_t i = 0; i < size; ++i)
expected[i] = input[i];

auto b = xsimd::expand(
batch_type::load_unaligned(input.data()),
mask_batch_type::load_unaligned(mask.data()));
CHECK_BATCH_EQ(b, expected);
}

void empty()
{
std::fill(mask.begin(), mask.end(), false);

for (size_t i = 0; i < size; ++i)
expected[i] = 0;

auto b = xsimd::expand(
batch_type::load_unaligned(input.data()),
mask_batch_type::load_unaligned(mask.data()));
CHECK_BATCH_EQ(b, expected);
}

void interleave()
{
for (size_t i = 0; i < size; ++i)
mask[i] = i % 2 == 0;

for (size_t i = 0, j = 0; i < size; ++i)
expected[i] = mask[i] ? input[j++] : 0;

auto b = xsimd::expand(
batch_type::load_unaligned(input.data()),
mask_batch_type::load_unaligned(mask.data()));
CHECK_BATCH_EQ(b, expected);
}

void generic()
{
for (size_t i = 0; i < size; ++i)
mask[i] = i % 3 == 0;

for (size_t i = 0, j = 0; i < size; ++i)
expected[i] = mask[i] ? input[j++] : 0;

auto b = xsimd::expand(
batch_type::load_unaligned(input.data()),
mask_batch_type::load_unaligned(mask.data()));
CHECK_BATCH_EQ(b, expected);
}
};

TEST_CASE_TEMPLATE("[expand]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, xsimd::batch<int32_t>, xsimd::batch<uint64_t>, xsimd::batch<int64_t>)
{
expand_test<B> Test;
SUBCASE("empty")
{
Test.empty();
}
SUBCASE("full")
{
Test.full();
}
SUBCASE("interleave")
{
Test.interleave();
}
SUBCASE("generic")
{
Test.generic();
}
}

template <class B>
struct shuffle_test
{
Expand Down

0 comments on commit b6827c1

Please sign in to comment.