Skip to content

Commit

Permalink
Added Updated the the source_medium load_on_device and store_on_devic…
Browse files Browse the repository at this point in the history
…e functions
  • Loading branch information
lsawade committed Jan 24, 2025
1 parent 468b9ef commit f7d8c2f
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 53 deletions.
42 changes: 24 additions & 18 deletions include/compute/sources/source_medium.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,36 +95,42 @@ struct source_medium {
SourceArrayView source_array; ///< Lagrange interpolants for every source
SourceArrayView::HostMirror h_source_array; ///< Host mirror of source_array

template <typename IndexType, typename PointSourceType>
template <typename IteratorIndexType, typename PointSourceType>
KOKKOS_INLINE_FUNCTION void
load_on_device(const int timestep, const IndexType index,
load_on_device(const int timestep, const IteratorIndexType &iterator_index,
PointSourceType &point_source) const {
// For the source it is important to remember that we are using the
// mapped index to access the element and source indices
// that means that index actually is a mapped_chunk_index
// and we need to use index.ispec to access the element index
// and index.imap to access the source index
/* For the source it is important to remember that we are using the
* mapped index to access the element and source indices
* that means that index actually is a mapped_chunk_index
* and we need to use index.ispec to access the element index
* and index.imap to access the source index
*/
const auto index = iterator_index.index;
const auto isource = iterator_index.imap;
for (int component = 0; component < components; component++) {
point_source.stf(component) =
source_time_function(timestep, index.imap, component);
source_time_function(timestep, isource, component);
point_source.lagrange_interpolant(component) =
source_array(index.imap, component, index.iz, index.ix);
source_array(isource, component, index.iz, index.ix);
}
}

template <typename IndexType, typename PointSourceType>
template <typename IteratorIndexType, typename PointSourceType>
KOKKOS_INLINE_FUNCTION void
// For the source it is important to remember that we are using the
// mapped index to access the element and source indices
// that means that index actually is a mapped_chunk_index
// and we need to use index.ispec to access the element index
// and index.imap to access the source index
store_on_device(const int timestep, const IndexType index,
store_on_device(const int timestep, const IteratorIndexType &iterator_index,
const PointSourceType &point_source) const {
/* For the source it is important to remember that we are using the
* mapped index to access the element and source indices
* that means that index actually is a mapped_chunk_index
* and we need to use index.ispec to access the element index
* and index.imap to access the source index
*/
const auto index = iterator_index.index;
const auto isource = iterator_index.imap;
for (int component = 0; component < components; component++) {
source_time_function(timestep, index.imap, component) =
source_time_function(timestep, isource, component) =
point_source.stf(component);
source_array(index.imap, component, index.iz, index.ix) =
source_array(isource, component, index.iz, index.ix) =
point_source.lagrange_interpolant(component);
}
}
Expand Down
32 changes: 16 additions & 16 deletions include/compute/sources/sources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,15 @@ struct sources {
* @param sources Source information for the domain
* @param point_source Point source object to load source information into
*/
template <typename IndexType, typename PointSourceType>
template <typename IteratorIndexType, typename PointSourceType>
KOKKOS_INLINE_FUNCTION void
load_on_device(const IndexType index, const specfem::compute::sources &sources,
load_on_device(const IteratorIndexType iterator_index,
const specfem::compute::sources &sources,
PointSourceType &point_source) {

static_assert(IndexType::using_simd == false,
const auto index = iterator_index.index;

static_assert(index.using_simd == false,
"IndexType must not use SIMD when loading sources");

static_assert(
Expand All @@ -265,7 +268,7 @@ load_on_device(const IndexType index, const specfem::compute::sources &sources,
static_assert(PointSourceType::dimension == specfem::dimension::type::dim2,
"PointSourceType must be a 2D point source type");

static_assert(IndexType::dimension == specfem::dimension::type::dim2,
static_assert(index.dimension == specfem::dimension::type::dim2,
"IndexType must be a 2D index type");

static_assert(
Expand Down Expand Up @@ -295,16 +298,13 @@ load_on_device(const IndexType index, const specfem::compute::sources &sources,
}
#endif

IndexType lcoord = index;
// lcoord.ispec = index.imap;

#define SOURCE_MEDIUM_LOAD_ON_DEVICE(DIMENSION_TAG, MEDIUM_TAG) \
if constexpr (GET_TAG(DIMENSION_TAG) == specfem::dimension::type::dim2) { \
if constexpr (GET_TAG(MEDIUM_TAG) == PointSourceType::medium_tag) { \
sources \
.CREATE_VARIABLE_NAME(source, GET_NAME(DIMENSION_TAG), \
GET_NAME(MEDIUM_TAG)) \
.load_on_device(sources.timestep, lcoord, point_source); \
.load_on_device(sources.timestep, iterator_index, point_source); \
} \
}

Expand Down Expand Up @@ -412,12 +412,15 @@ void load_on_host(const IndexType index,
* @param point_source Point source object to load source information into
* @param sources Source information for the domain
*/
template <typename IndexType, typename PointSourceType>
template <typename IteratorIndexType, typename PointSourceType>
KOKKOS_INLINE_FUNCTION void
store_on_device(const IndexType index, const PointSourceType &point_source,
store_on_device(const IteratorIndexType iterator_index,
const PointSourceType &point_source,
const specfem::compute::sources &sources) {

static_assert(IndexType::using_simd == false,
const auto index = iterator_index.index;

static_assert(index.using_simd == false,
"IndexType must not use SIMD when storing sources");

static_assert(
Expand All @@ -427,7 +430,7 @@ store_on_device(const IndexType index, const PointSourceType &point_source,
static_assert(PointSourceType::dimension == specfem::dimension::type::dim2,
"PointSourceType must be a 2D point source type");

static_assert(IndexType::dimension == specfem::dimension::type::dim2,
static_assert(index.dimension == specfem::dimension::type::dim2,
"IndexType must be a 2D index type");

static_assert(
Expand All @@ -454,16 +457,13 @@ store_on_device(const IndexType index, const PointSourceType &point_source,
}
#endif

IndexType lcoord = index;
lcoord.ispec = sources.source_domain_index_mapping(index.ispec);

#define SOURCE_MEDIUM_STORE_ON_DEVICE(DIMENSION_TAG, MEDIUM_TAG) \
if constexpr (GET_TAG(DIMENSION_TAG) == specfem::dimension::type::dim2) { \
if constexpr (GET_TAG(MEDIUM_TAG) == PointSourceType::medium_tag) { \
sources \
.CREATE_VARIABLE_NAME(source, GET_NAME(DIMENSION_TAG), \
GET_NAME(MEDIUM_TAG)) \
.store_on_device(sources.timestep, lcoord, point_source); \
.store_on_device(sources.timestep, iterator_index, point_source); \
} \
}

Expand Down
15 changes: 10 additions & 5 deletions include/kokkos_kernels/impl/compute_source_interaction.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,23 @@ Kokkos::parallel_for(
break;
}

const auto iterator =
// This is a mapped_chunk iterator
const auto mapped_chunk_iterator =
policy.mapped_league_iterator(starting_element_index);

Kokkos::parallel_for(
Kokkos::TeamThreadRange(team, iterator.chunk_size()),
Kokkos::TeamThreadRange(team, mapped_chunk_iterator.chunk_size()),
[&](const int i) {
const auto mapped_iterator_index = iterator(i);
const auto mapped_index = mapped_iterator_index.index;
// mapped_chunk_index_type
const auto mapped_chunked_index = mapped_chunk_iterator(i);

// mapped_index is specfem::point::index
const auto mapped_index = mapped_chunked_index.index;

// need mapped_chunk_index here to get the imap=isource
PointSourcesType point_source;
specfem::compute::load_on_device(mapped_iterator_index, sources, point_source);
specfem::compute::load_on_device(mapped_chunked_index, sources,
point_source);

PointPropertiesType point_property;
specfem::compute::load_on_device(mapped_index, properties,
Expand Down
37 changes: 27 additions & 10 deletions include/policies/chunk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,26 @@ struct chunk_index_type<false, DimensionType> {
const specfem::point::index<dimension> index)
: ielement(ielement), index(index){};
};

/**
* @brief Struct to store the index of a quadrature point generated by chunk
* policy.
*
* @tparam UseSIMD Indicates whether SIMD is used or not.
* @tparam DimensionType Dimension type of the elements within this iterator.
*/
template <bool UseSIMD, specfem::dimension::type DimensionType>
struct mapped_chunk_index_type
: public chunk_index_type<UseSIMD, DimensionType> {
using Base = chunk_index_type<UseSIMD, DimensionType>;
int imap; ///< Index of the mapped element

mapped_chunk_index_type(const int ielement,
const specfem::point::index<DimensionType> index,
const int imap)
: Base(ielement, index), imap(imap) {}
};

} // namespace impl

/**
Expand Down Expand Up @@ -232,23 +252,20 @@ template <typename ViewType, typename SIMD>
class mapped_chunk<ViewType, specfem::dimension::type::dim2, SIMD>
: public chunk<ViewType, specfem::dimension::type::dim2, SIMD> {
using Base = chunk<ViewType, specfem::dimension::type::dim2, SIMD>;
using mapped_index_type =
typename impl::mapped_chunk_index_type<Base::using_simd,
Base::dimension>; ///< Index

public:
mapped_chunk(const ViewType &indices, const ViewType &mapping,
const int ngllz, const int ngllx)
: Base(indices, ngllz, ngllx), mapping(mapping) {}

KOKKOS_INLINE_FUNCTION
const int imap(const int i) const {
#ifdef KOKKOS_ENABLE_CUDA
const int ielement = i % num_elements;
return mapping(ielement);
#else
const int ix = i % Base::ngllx;
const int iz = (i / Base::ngllx) % Base::ngllz;
const int ielement = i / (Base::ngllz * Base::ngllx);
return mapping(ielement);
#endif
mapped_index_type operator()(const int i) const {
const auto base_index = Base::operator()(i);
return mapped_index_type(base_index.ielement, base_index.index,
mapping(base_index.ielement));
}

private:
Expand Down
16 changes: 12 additions & 4 deletions tests/unit-tests/assembly/sources/sources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "enumerations/medium.hpp"
#include "enumerations/wavefield.hpp"
#include "point/sources.hpp"
#include "policies/chunk.hpp"
#include "gtest/gtest.h"
#include <Kokkos_Core.hpp>

Expand Down Expand Up @@ -43,17 +44,23 @@ void check_store(specfem::compute::assembly &assembly) {

Kokkos::deep_copy(values_to_store, h_values_to_store);

using PointType = specfem::point::source<Dimension, MediumTag, WavefieldType>;

using PointSourceType =
specfem::point::source<Dimension, MediumTag, WavefieldType>;
using mapped_chunk_index_type =
specfem::iterator::impl::mapped_chunk_index_type<
false, specfem::dimension::type::dim2>;
Kokkos::parallel_for(
"check_store_on_device",
Kokkos::MDRangePolicy<Kokkos::DefaultExecutionSpace, Kokkos::Rank<3> >(
{ 0, 0, 0 }, { nelements, ngllz, ngllx }),
KOKKOS_LAMBDA(const int &i, const int &iz, const int &ix) {
const int ielement = element_indices(i);
const int isource = source_indices(i);

const auto index =
specfem::point::index<Dimension, false>(ielement, iz, ix);
const auto mapped_iterator_index =
mapped_chunk_index_type(ielement, index, isource);
specfem::datatype::ScalarPointViewType<type_real, num_components, false>
stf;
specfem::datatype::ScalarPointViewType<type_real, num_components, false>
Expand All @@ -62,8 +69,9 @@ void check_store(specfem::compute::assembly &assembly) {
stf(ic) = values_to_store(i);
lagrange_interpolant(ic) = values_to_store(i);
}
PointType point(stf, lagrange_interpolant);
specfem::compute::store_on_device(index, point, sources);
PointSourceType point(stf, lagrange_interpolant);
specfem::compute::store_on_device(mapped_iterator_index, point,
sources);
});

Kokkos::fence();
Expand Down

0 comments on commit f7d8c2f

Please sign in to comment.