Skip to content

Commit

Permalink
several fixes for mpi-openshmem and ucx-openshmem
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Taylor committed May 8, 2024
1 parent d26b61d commit ab462c3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 23 deletions.
1 change: 0 additions & 1 deletion cmake/HPX_SetupOpenSHMEM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,6 @@ separate_arguments(OPENSHMEM_CFLAGS UNIX_COMMAND "${OPENSHMEM_CFLAGS}")
separate_arguments(OPENSHMEM_LDFLAGS UNIX_COMMAND "${OPENSHMEM_LDFLAGS}")
separate_arguments(OPENSHMEM_LIBRARY_DIRS UNIX_COMMAND "${OPENSHMEM_LIBRARY_DIRS}")


set_target_properties(
PkgConfig::OPENSHMEM PROPERTIES INTERFACE_COMPILE_OPTIONS
${OPENSHMEM_CFLAGS}
Expand Down
86 changes: 64 additions & 22 deletions libs/core/openshmem_base/src/openshmem_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <hpx/modules/threading_base.hpp>
#include <hpx/modules/util.hpp>

#include <shmemx.h>

#include <atomic>
#include <unistd.h>
#include <cstddef>
Expand Down Expand Up @@ -340,8 +342,8 @@ namespace hpx::util {
using arg_type = std::conditional<
std::is_same<typename signature<decltype(shmem_int_wait_until)>::type,
std::tuple<volatile int*, int, int>>::value,
volatile int *,
int *
volatile unsigned int *,
unsigned int *
>::type;

union {
Expand All @@ -350,33 +352,73 @@ namespace hpx::util {
} tmp;
tmp.uaddr = sigaddr;

shmem_int_wait_until(tmp.iaddr, SHMEM_CMP_EQ, static_cast<int>(value));
shmem_uint_wait_until(tmp.iaddr, SHMEM_CMP_EQ, static_cast<int>(value));
}

std::size_t openshmem_environment::wait_until_any(const std::uint8_t value, std::uint8_t * sigaddr, const std::size_t count) {
constexpr bool vendor_strings_equal(char const * a, char const * b) {
return *a == *b && (*a == '\0' || vendor_strings_equal(a + 1, b + 1));
}

#if defined(SHMEM_MAJOR_VERSION) && defined(SHMEM_MINOR_VERSION) && \
defined(SHMEM_VENDOR_STRING) && defined(SHMEM_MAX_NAME_LEN) && \
SHMEM_MAX_NAME_LEN == 256
struct ucx {};
struct mpi {};
struct err {};

int rc = 0;
for(std::size_t i = 0; i < count; ++i) {
rc = shmem_uint_test(reinterpret_cast<unsigned int*>(sigaddr+i), SHMEM_CMP_EQ, value);
if(rc) { return i; }
}
template<typename tag>
struct wait_until_any_wrapper {
std::size_t operator()(unsigned int * addr, const std::size_t count, const std::uint8_t value) { return -1; }
};

#if defined(OSHMEM_SHMEMX_H)
template<>
struct wait_until_any_wrapper<mpi> {
std::size_t operator()(unsigned int * addr, const std::size_t count, const std::uint8_t value) {
int rc = 0;
for(std::size_t i = 0; i < count; ++i) {
rc = shmem_uint_test(addr+i, SHMEM_CMP_EQ, value);
if(rc) { return i; }
}

return -1;
}
};
#endif

#if defined(_SHMEM_H)
template<>
struct wait_until_any_wrapper<ucx> {
std::size_t operator()(unsigned int * addr, const std::size_t count, const std::uint8_t value) {
const std::size_t sig_idx =
#if defined(_SHMEMX_H) && _SHMEMX_H == 1
shmemx_uint_wait_until_any
#else
shmem_uint_wait_until_any
#endif
(
addr,
count,
nullptr,
SHMEM_CMP_EQ,
value
);

return sig_idx;
}
};
#endif

return -1;
std::size_t openshmem_environment::wait_until_any(const std::uint8_t value, std::uint8_t * sigaddr, const std::size_t count) {
#if defined(SHMEM_VENDOR_STRING)
using tag = std::conditional< std::integral_constant<bool, vendor_strings_equal(SHMEM_VENDOR_STRING,"osss-ucx")>::value, ucx,
std::conditional< std::integral_constant<bool, vendor_strings_equal(SHMEM_VENDOR_STRING,"http://www.open-mpi.org/")>::value, mpi, err>::type
>::type;
using wait_until_any_type = wait_until_any_wrapper<tag>;
#else
const std::size_t sig_idx = shmem_wait_until_any(
reinterpret_cast<unsigned int *>(sigaddr),
count,
nullptr,
SHMEM_CMP_EQ,
value
);

return sig_idx;
#define SHMEM_VENDOR_STRING "SHMEM_VENDOR_STRING not defined"
using tag = err;
using wait_until_any_type = wait_until_any_wrapper<tag>;
#endif
wait_until_any_type t{};
return t(reinterpret_cast<unsigned int *>(sigaddr), count, value);
}

void openshmem_environment::get(std::uint8_t* addr, const int node,
Expand Down

0 comments on commit ab462c3

Please sign in to comment.