Skip to content

Commit

Permalink
Adding synchronous collective operations
Browse files Browse the repository at this point in the history
- adding predefined world_comunicator
  • Loading branch information
hkaiser committed Jan 5, 2025
1 parent 64b1c0d commit b6da4e2
Show file tree
Hide file tree
Showing 7 changed files with 833 additions and 2 deletions.
88 changes: 87 additions & 1 deletion libs/full/collectives/include/hpx/collectives/broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ namespace hpx { namespace collectives {
#include <hpx/assert.hpp>
#include <hpx/async_base/launch_policy.hpp>
#include <hpx/async_distributed/async.hpp>
#include <hpx/async_local/dataflow.hpp>
#include <hpx/collectives/argument_types.hpp>
#include <hpx/collectives/create_communicator.hpp>
#include <hpx/components_base/agas_interface.hpp>
Expand Down Expand Up @@ -334,6 +333,39 @@ namespace hpx::collectives {
HPX_FORWARD(T, local_result), this_site);
}

////////////////////////////////////////////////////////////////////////////
template <typename T>
decltype(auto) broadcast_to(hpx::launch::sync_policy, communicator fid,
T&& local_result, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return broadcast_to(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
decltype(auto) broadcast_to(hpx::launch::sync_policy, communicator fid,
T&& local_result, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return broadcast_to(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
decltype(auto) broadcast_to(hpx::launch::sync_policy, char const* basename,
T&& local_result, num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return broadcast_to(hpx::launch::sync,
create_communicator(basename, num_sites, this_site, generation,
root_site_arg(this_site.argument_)),
HPX_FORWARD(T, local_result), this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
hpx::future<T> broadcast_from(communicator fid,
Expand Down Expand Up @@ -392,6 +424,60 @@ namespace hpx::collectives {
this_site, generation, root_site),
this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
T broadcast_from(hpx::launch::sync_policy, communicator fid,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return broadcast_from<T>(HPX_MOVE(fid), this_site, generation).get();
}

template <typename T>
T broadcast_from(hpx::launch::sync_policy, communicator fid,
generation_arg generation, this_site_arg this_site = this_site_arg())
{
return broadcast_from<T>(HPX_MOVE(fid), this_site, generation).get();
}

template <typename T>
T broadcast_from(hpx::launch::sync_policy, char const* basename,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
HPX_ASSERT(this_site != root_site);
return broadcast_from<T>(create_communicator(basename, num_sites_arg(),
this_site, generation, root_site),
this_site)
.get();
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
void broadcast(communicator fid, T& value,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
if (this_site == static_cast<std::size_t>(-1))
{
this_site = static_cast<std::size_t>(agas::get_locality_id());
}

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
{
broadcast_to(
hpx::launch::sync, HPX_MOVE(fid), value, this_site, generation);
}
else
{
value = broadcast_from<T>(
hpx::launch::sync, HPX_MOVE(fid), this_site, generation);
}
}
} // namespace hpx::collectives

////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ namespace hpx::collectives {
}
};

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
HPX_EXPORT communicator get_world_communicator();

///////////////////////////////////////////////////////////////////////////
HPX_EXPORT communicator create_communicator(char const* basename,
num_sites_arg num_sites = num_sites_arg(),
Expand Down
93 changes: 93 additions & 0 deletions libs/full/collectives/include/hpx/collectives/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,39 @@ namespace hpx::collectives {
HPX_FORWARD(T, result), HPX_FORWARD(F, op), this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T, typename F>
decltype(auto) reduce_here(hpx::launch::sync_policy, communicator fid,
T&& local_result, F&& op, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return reduce_here(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
HPX_FORWARD(F, op), this_site, generation)
.get();
}

template <typename T, typename F>
decltype(auto) reduce_here(hpx::launch::sync_policy, communicator fid,
T&& local_result, F&& op, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return reduce_here(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
HPX_FORWARD(F, op), this_site, generation)
.get();
}

template <typename T, typename F>
decltype(auto) reduce_here(hpx::launch::sync_policy, char const* basename,
T&& result, F&& op, num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return reduce_here(create_communicator(basename, num_sites, this_site,
generation, root_site_arg(this_site.argument_)),
HPX_FORWARD(T, result), HPX_FORWARD(F, op), this_site)
.get();
}

///////////////////////////////////////////////////////////////////////////
// reduce plain values
template <typename T>
Expand Down Expand Up @@ -443,6 +476,66 @@ namespace hpx::collectives {
this_site, generation, root_site),
HPX_FORWARD(T, local_result), this_site);
}

////////////////////////////////////////////////////////////////////////////
template <typename T>
void reduce_there(hpx::launch::sync_policy, communicator fid,
T&& local_result, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
reduce_there(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
void reduce_there(hpx::launch::sync_policy, communicator fid,
T&& local_result, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
reduce_there(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
void reduce_there(hpx::launch::sync_policy, char const* basename,
T&& local_result, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
HPX_ASSERT(this_site != root_site);
reduce_there(create_communicator(basename, num_sites_arg(), this_site,
generation, root_site),
HPX_FORWARD(T, local_result), this_site)
.get();
}

////////////////////////////////////////////////////////////////////////////
template <typename T, typename F>
void reduce(communicator fid, T&& local_result, F&& op,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
if (this_site == static_cast<std::size_t>(-1))
{
this_site = static_cast<std::size_t>(agas::get_locality_id());
}

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
{
local_result = reduce_here(hpx::launch::sync, HPX_MOVE(fid),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site,
generation);
}
else
{
reduce_there(hpx::launch::sync, HPX_MOVE(fid),
HPX_FORWARD(T, local_result), this_site, generation);
}
}
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
Expand Down
18 changes: 18 additions & 0 deletions libs/full/collectives/src/create_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,24 @@ namespace hpx::collectives {
// find existing communicator
return hpx::find_from_basename<communicator>(HPX_MOVE(name), root_site);
}

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
namespace {
communicator world_communicator;
hpx::mutex world_communicator_mtx;
} // namespace

communicator get_world_communicator()
{
{
std::lock_guard<hpx::mutex> l(world_communicator_mtx);
if (!world_communicator)
world_communicator =
create_communicator("hpx::collectives::world_communicator");
}
return world_communicator;
}
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
4 changes: 3 additions & 1 deletion libs/full/collectives/tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2024 Hartmut Kaiser
# Copyright (c) 2019-2025 Hartmut Kaiser
#
# SPDX-License-Identifier: BSL-1.0
# Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand All @@ -12,6 +12,7 @@ set(tests
broadcast
broadcast_component
broadcast_post
broadcast_sync
channel_communicator
fold
global_spmd_block
Expand All @@ -28,6 +29,7 @@ if(HPX_WITH_NETWORKING)
gather
inclusive_scan_
reduce
reduce_sync
scatter
)

Expand Down
Loading

0 comments on commit b6da4e2

Please sign in to comment.