Skip to content

Commit

Permalink
Refactor action sequence into action group plus extras (#1410)
Browse files Browse the repository at this point in the history
* Move action sequence out of detail namespace
* Fix tests to work with future geant4
* Refactor action sequence into action group plus Celeritas detail
* Add typedef and todos
* Fix accidental include of upstream header
  • Loading branch information
sethrj authored Sep 19, 2024
1 parent 417ee2f commit a7dc59d
Show file tree
Hide file tree
Showing 22 changed files with 299 additions and 166 deletions.
2 changes: 2 additions & 0 deletions app/celer-sim/Runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ size_type Runner::num_events() const
* Get the accumulated action times.
*
* This is a *mean* value over all streams.
*
* \todo Refactor action times gathering: see celeritas::ActionSequence .
*/
auto Runner::get_action_times() const -> MapStrDouble
{
Expand Down
8 changes: 5 additions & 3 deletions app/celer-sim/Transporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include "corecel/sys/Counter.hh"
#include "corecel/sys/ScopedSignalHandler.hh"
#include "celeritas/Types.hh"
#include "celeritas/global/ActionSequence.hh"
#include "celeritas/global/CoreParams.hh"
#include "celeritas/global/Stepper.hh"
#include "celeritas/global/detail/ActionSequence.hh"
#include "celeritas/phys/Model.hh"

#include "StepTimer.hh"
Expand Down Expand Up @@ -180,7 +180,9 @@ auto Transporter<M>::operator()(SpanConstPrimary primaries) -> TransporterResult

//---------------------------------------------------------------------------//
/*!
* Transport the input primaries and all secondaries produced.
* Merge times across all threads.
*
* \todo Action times are to be refactored as aux data.
*/
template<MemSpace M>
void Transporter<M>::accum_action_times(MapStrDouble* result) const
Expand All @@ -191,7 +193,7 @@ void Transporter<M>::accum_action_times(MapStrDouble* result) const
auto const& action_seq = step.actions();
if (action_seq.action_times())
{
auto const& action_ptrs = action_seq.actions();
auto const& action_ptrs = action_seq.actions().step();
auto const& times = action_seq.accum_time();

CELER_ASSERT(action_ptrs.size() == times.size());
Expand Down
2 changes: 1 addition & 1 deletion src/accel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ endif()
celeritas_get_g4libs(_g4_private digits_hits run)
list(APPEND PRIVATE_DEPS ${_g4_private})

celeritas_get_g4libs(_g4_public event global track tracking intercoms geometry)
celeritas_get_g4libs(_g4_public event intercoms geometry global track tracking)
list(APPEND PUBLIC_DEPS ${_g4_public})

#-----------------------------------------------------------------------------#
Expand Down
4 changes: 2 additions & 2 deletions src/accel/LocalTransporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "geocel/g4/Convert.geant.hh"
#include "celeritas/Quantities.hh"
#include "celeritas/ext/GeantUnits.hh"
#include "celeritas/global/detail/ActionSequence.hh"
#include "celeritas/global/ActionSequence.hh"
#include "celeritas/io/EventWriter.hh"
#include "celeritas/io/RootEventWriter.hh"
#include "celeritas/phys/PDGNumber.hh"
Expand Down Expand Up @@ -299,7 +299,7 @@ auto LocalTransporter::GetActionTime() const -> MapStrReal
if (action_seq.action_times())
{
// Save kernel timing if synchronization is enabled
auto const& action_ptrs = action_seq.actions();
auto const& action_ptrs = action_seq.actions().step();
auto const& time = action_seq.accum_time();

CELER_ASSERT(action_ptrs.size() == time.size());
Expand Down
8 changes: 5 additions & 3 deletions src/celeritas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ list(APPEND SOURCES
field/RZMapFieldInputIO.json.cc
field/RZMapFieldParams.cc
geo/GeoMaterialParams.cc
global/ActionGroups.cc
global/ActionSequence.cc
global/CoreParams.cc
global/CoreState.cc
global/CoreTrackData.cc
global/Debug.cc
global/DebugIO.json.cc
global/KernelContextException.cc
global/Stepper.cc
global/detail/ActionSequence.cc
global/detail/PinnedAllocator.cc
grid/GenericGridBuilder.cc
grid/TwodGridBuilder.cc
Expand Down Expand Up @@ -140,8 +141,9 @@ if(CELERITAS_USE_Geant4)
ext/detail/GeantProcessImporter.cc
ext/detail/MuHadEmStandardPhysics.cc
)
celeritas_get_g4libs(_cg4_libs global geometry materials processes run
physicslists tasking)
celeritas_get_g4libs(_cg4_libs
global geometry materials processes run physicslists tasking
)
list(APPEND _cg4_libs Celeritas::corecel XercesC::XercesC)

celeritas_add_object_library(celeritas_geant4 ${_cg4_sources})
Expand Down
22 changes: 22 additions & 0 deletions src/celeritas/global/ActionGroups.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/ActionGroups.cc
//---------------------------------------------------------------------------//
#include "ActionGroups.hh"

#include "corecel/sys/ActionGroups.t.hh"

#include "CoreParams.hh"
#include "CoreState.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//

template class ActionGroups<CoreParams, CoreState>;

//---------------------------------------------------------------------------//
} // namespace celeritas
21 changes: 21 additions & 0 deletions src/celeritas/global/ActionGroups.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/ActionGroups.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/sys/ActionGroups.hh"

#include "ActionInterface.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//

extern template class ActionGroups<CoreParams, CoreState>;

//---------------------------------------------------------------------------//
} // namespace celeritas
22 changes: 1 addition & 21 deletions src/celeritas/global/ActionInterface.hh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CoreState;
//---------------------------------------------------------------------------//
// TYPE ALIASES
//---------------------------------------------------------------------------//
//! Action interface for core stepping loop
//! Interface called at beginning of the core stepping loop
using CoreBeginRunActionInterface
= BeginRunActionInterface<CoreParams, CoreState>;

Expand Down Expand Up @@ -53,26 +53,6 @@ class [[deprecated]] ExplicitCoreActionInterface
virtual void execute(CoreParams const&, CoreStateDevice&) const = 0;
};

//---------------------------------------------------------------------------//
// HELPER STRUCTS
//---------------------------------------------------------------------------//
//! Action order/ID tuple for comparison in sorting
struct OrderedAction
{
StepActionOrder order;
ActionId id;

//! Ordering comparison for an action/ID
CELER_CONSTEXPR_FUNCTION bool operator<(OrderedAction const& other) const
{
if (this->order < other.order)
return true;
if (this->order > other.order)
return false;
return this->id < other.id;
}
};

//---------------------------------------------------------------------------//
// HELPER FUNCTIONS
//---------------------------------------------------------------------------//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/ActionSequence.cc
//! \file celeritas/global/ActionSequence.cc
//---------------------------------------------------------------------------//
#include "ActionSequence.hh"

Expand All @@ -21,65 +21,27 @@
#include "corecel/sys/ScopedProfiling.hh"
#include "corecel/sys/Stopwatch.hh"
#include "corecel/sys/Stream.hh"
#include "celeritas/global/CoreParams.hh"
#include "celeritas/global/CoreState.hh"
#include "celeritas/track/StatusChecker.hh"

#include "../ActionInterface.hh"
#include "../CoreState.hh"
#include "../Debug.hh"
#include "ActionInterface.hh"
#include "CoreParams.hh"
#include "CoreState.hh"
#include "Debug.hh"

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
/*!
* Construct from an action registry and sequence options.
*/
template<class P, template<MemSpace M> class S>
ActionSequence<P, S>::ActionSequence(ActionRegistry const& reg, Options options)
: options_(std::move(options))
ActionSequence::ActionSequence(ActionRegistry const& reg, Options options)
: actions_{reg}, options_{std::move(options)}
{
actions_.reserve(reg.num_actions());
// Loop over all action IDs
for (auto aidx : range(reg.num_actions()))
{
// Get abstract action shared pointer to determine type
auto const& base = reg.action(ActionId{aidx});
static_assert(std::is_same_v<StepActionT, CoreStepActionInterface>);
if (auto step_act = std::dynamic_pointer_cast<StepActionT const>(base))
{
// Add stepping action to our array
actions_.push_back(std::move(step_act));
}
}

begin_run_.reserve(reg.mutable_actions().size());
// Loop over all mutable actions
for (auto const& base : reg.mutable_actions())
{
if (auto brun
= std::dynamic_pointer_cast<CoreBeginRunActionInterface>(base))
{
// Add beginning-of-run to the array
begin_run_.emplace_back(std::move(brun));
}
}

// Sort actions by increasing order (and secondarily, increasing IDs)
std::sort(actions_.begin(),
actions_.end(),
[](SPConstStepAction const& a, SPConstStepAction const& b) {
return OrderedAction{a->order(), a->action_id()}
< OrderedAction{b->order(), b->action_id()};
});

// Initialize timing
accum_time_.resize(actions_.size());
accum_time_.resize(actions_.step().size());

// Get status checker if available
for (auto const& brun_sp : begin_run_)
for (auto const& brun_sp : actions_.begin_run())
{
if (auto sc = std::dynamic_pointer_cast<StatusChecker>(brun_sp))
{
Expand All @@ -91,18 +53,17 @@ ActionSequence<P, S>::ActionSequence(ActionRegistry const& reg, Options options)
}
}

CELER_ENSURE(actions_.size() == accum_time_.size());
CELER_ENSURE(actions_.step().size() == accum_time_.size());
}

//---------------------------------------------------------------------------//
/*!
* Initialize actions and states.
*/
template<class P, template<MemSpace M> class S>
template<MemSpace M>
void ActionSequence<P, S>::begin_run(P const& params, S<M>& state)
void ActionSequence::begin_run(CoreParams const& params, CoreState<M>& state)
{
for (auto const& sp_action : begin_run_)
for (auto const& sp_action : actions_.begin_run())
{
ScopedProfiling profile_this{sp_action->label()};
sp_action->begin_run(params, state);
Expand All @@ -113,26 +74,25 @@ void ActionSequence<P, S>::begin_run(P const& params, S<M>& state)
/*!
* Call all explicit actions with host or device data.
*/
template<class P, template<MemSpace M> class S>
template<MemSpace M>
void ActionSequence<P, S>::step(P const& params, S<M>& state)
void ActionSequence::step(CoreParams const& params, CoreState<M>& state)
{
[[maybe_unused]] Stream::StreamT stream = nullptr;
if (M == MemSpace::device && options_.action_times)
{
stream = celeritas::device().stream(state.stream_id()).get();
}

if constexpr (M == MemSpace::host && std::is_same_v<CoreParams, P>)
if constexpr (M == MemSpace::host)
{
if (status_checker_)
{
g_debug_executing_params = &params;
}
}

// Running a single track slot on host:
// Skip inapplicable post-step action
// When running a single track slot on host, we can preemptively skip
// inapplicable post-step actions
auto const skip_post_action = [&](auto const& action) {
if constexpr (M != MemSpace::host)
{
Expand All @@ -143,12 +103,14 @@ void ActionSequence<P, S>::step(P const& params, S<M>& state)
!= state.ref().sim.post_step_action[TrackSlotId{0}];
};

auto step_actions = make_span(actions_.step());
if (options_.action_times && !state.warming_up())
{
// Execute all actions and record the time elapsed
for (auto i : range(actions_.size()))
for (auto i : range(step_actions.size()))
{
if (auto const& action = *actions_[i]; !skip_post_action(action))
if (auto const& action = *step_actions[i];
!skip_post_action(action))
{
ScopedProfiling profile_this{action.label()};
Stopwatch get_time;
Expand All @@ -168,7 +130,7 @@ void ActionSequence<P, S>::step(P const& params, S<M>& state)
else
{
// Just loop over the actions
for (auto const& sp_action : actions_)
for (auto const& sp_action : actions_.step())
{
if (auto const& action = *sp_action; !skip_post_action(action))
{
Expand All @@ -182,34 +144,23 @@ void ActionSequence<P, S>::step(P const& params, S<M>& state)
}
}

if (M == MemSpace::host && std::is_same_v<CoreParams, P> && status_checker_)
if (M == MemSpace::host && status_checker_)
{
g_debug_executing_params = nullptr;
}
}

//---------------------------------------------------------------------------//
// Explicit template instantiation
//---------------------------------------------------------------------------//

template class ActionSequence<CoreParams, CoreState>;

template void
ActionSequence<CoreParams, CoreState>::begin_run(CoreParams const&,
CoreState<MemSpace::host>&);
ActionSequence::begin_run(CoreParams const&, CoreState<MemSpace::host>&);
template void
ActionSequence<CoreParams, CoreState>::begin_run(CoreParams const&,
CoreState<MemSpace::device>&);
ActionSequence::begin_run(CoreParams const&, CoreState<MemSpace::device>&);

template void
ActionSequence<CoreParams, CoreState>::step(CoreParams const&,
CoreState<MemSpace::host>&);
ActionSequence::step(CoreParams const&, CoreState<MemSpace::host>&);
template void
ActionSequence<CoreParams, CoreState>::step(CoreParams const&,
CoreState<MemSpace::device>&);

// TODO: add explicit template instantiation of execute for optical data
ActionSequence::step(CoreParams const&, CoreState<MemSpace::device>&);

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
Loading

0 comments on commit a7dc59d

Please sign in to comment.