diff --git a/CMakeLists.txt b/CMakeLists.txt index 59131fac4f8..c5d13a5c0f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,7 +94,13 @@ if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj") endif() if(MINGW OR CYGWIN) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj") + if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # Otherwise, dynamic_cast to a class marked final will fail. + # https://reviews.llvm.org/D154658 should be relevant + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-assume-unique-vtables") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj") + endif() endif() # For now, PGI/NVHPC nvc++ compiler doesn't seem to support diff --git a/common/cuda_hip/solver/batch_bicgstab_launch.hpp b/common/cuda_hip/solver/batch_bicgstab_launch.hpp index 696e11b5899..3886c33bcd5 100644 --- a/common/cuda_hip/solver/batch_bicgstab_launch.hpp +++ b/common/cuda_hip/solver/batch_bicgstab_launch.hpp @@ -11,6 +11,7 @@ #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_bicgstab_kernels.hpp" +#include "core/solver/batch_dispatch.hpp" namespace gko { @@ -50,32 +51,28 @@ void launch_apply_kernel( device_type<_vtype>* const __restrict__ workspace_data, \ const int& block_size, const size_t& shared_size) -#define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH(...) 
\ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS( \ - GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_0_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 0, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 0, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_1_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 1, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 1, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_2_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 2, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 2, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_3_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 3, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 3, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_4_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 4, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 4, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_5_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 5, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 5, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_6_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 6, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 6, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_7_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 7, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 7, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_8_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 8, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 8, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_9_FALSE \ - 
GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 9, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 9, false) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_9_TRUE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 9, true) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 9, true) } // namespace batch_bicgstab diff --git a/common/cuda_hip/solver/batch_cg_launch.hpp b/common/cuda_hip/solver/batch_cg_launch.hpp index fe5d96c8a21..4306dc2bfab 100644 --- a/common/cuda_hip/solver/batch_cg_launch.hpp +++ b/common/cuda_hip/solver/batch_cg_launch.hpp @@ -11,6 +11,7 @@ #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_cg_kernels.hpp" +#include "core/solver/batch_dispatch.hpp" namespace gko { @@ -50,24 +51,20 @@ void launch_apply_kernel( device_type<_vtype>* const __restrict__ workspace_data, \ const int& block_size, const size_t& shared_size) -#define GKO_INSTANTIATE_BATCH_CG_LAUNCH(...) 
\ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS(GKO_DECLARE_BATCH_CG_LAUNCH, \ - __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_0_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 0, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 0, false) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_1_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 1, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 1, false) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_2_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 2, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 2, false) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_3_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 3, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 3, false) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_4_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 4, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 4, false) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_5_FALSE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 5, false) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 5, false) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_5_TRUE \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 5, true) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 5, true) } // namespace batch_cg diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index ddd8937c44f..801ba46d248 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -10,11 +10,13 @@ set(config_source if(GINKGO_BUILD_MPI) list(APPEND config_source config/schwarz_config.cpp) endif() -# MSVC: To solve LNK1189, we separate the library as a workaround +# MSVC: LNK1189 issue +# CLANG in MSYS2 (MINGW): too many exported symbols +# We separate the library as a workaround to solve this issue # To make ginkgo still be the major library, we make the original to ginkgo_core in MSVC/shared # TODO: should 
think another way to solve it like dllexport or def file set(ginkgo_core "ginkgo") -if(MSVC AND BUILD_SHARED_LIBS) +if((MSVC OR MINGW) AND BUILD_SHARED_LIBS) set(ginkgo_core "ginkgo_core") endif() @@ -142,8 +144,8 @@ if(GINKGO_BUILD_MPI) distributed/preconditioner/schwarz.cpp) endif() -# MSVC/shared: make ginkgo be the major library -if(MSVC AND BUILD_SHARED_LIBS) +# MSVC or CLANG/msys2 with shared: make ginkgo be the major library +if((MSVC OR MINGW) AND BUILD_SHARED_LIBS) add_library(ginkgo "") target_sources(ginkgo PRIVATE ${config_source}) ginkgo_compile_features(ginkgo) @@ -161,7 +163,7 @@ ginkgo_compile_features(${ginkgo_core}) # add a namespace alias so Ginkgo can always be included as Ginkgo::ginkgo # regardless of whether it is installed or added as a subdirectory add_library(Ginkgo::ginkgo ALIAS ginkgo) -if(MSVC AND BUILD_SHARED_LIBS) +if((MSVC OR MINGW) AND BUILD_SHARED_LIBS) target_link_libraries(ginkgo PUBLIC ${ginkgo_core}) endif() target_link_libraries(${ginkgo_core} diff --git a/core/base/batch_instantiation.hpp b/core/base/batch_instantiation.hpp new file mode 100644 index 00000000000..dbcccefb469 --- /dev/null +++ b/core/base/batch_instantiation.hpp @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_PUBLIC_CORE_BASE_BATCH_INSTANTIATION_HPP_ +#define GKO_PUBLIC_CORE_BASE_BATCH_INSTANTIATION_HPP_ + +#include +#include +#include +#include +#include +#include + + +namespace gko { +namespace batch { + + +// just make the call list more consistent +#define GKO_CALL(_macro, ...) GKO_INDIRECT(_macro(__VA_ARGS__)) + +#define GKO_BATCH_INSTANTIATE_PRECONDITIONER(_next, ...) \ + GKO_INDIRECT(_next(__VA_ARGS__, gko::batch::matrix::Identity)); \ + GKO_INDIRECT(_next(__VA_ARGS__, gko::batch::preconditioner::Jacobi)) + +#define GKO_BATCH_INSTANTIATE_MATRIX(_next, ...) 
\ + GKO_INDIRECT(_next(__VA_ARGS__, gko::batch::matrix::Ell)); \ + GKO_INDIRECT(_next(__VA_ARGS__, gko::batch::matrix::Dense)); \ + GKO_INDIRECT(_next(__VA_ARGS__, gko::batch::matrix::Csr)) + +/** + * Instantiates a template for each valid combination of value type, batch + * matrix type, and batch preconditioner type. This only allows batch matrix + * and preconditioner types that use the same value type. + * + * @param args the first should be a macro which expands the template + * instantiation (not including the leading `template` specifier). + * Should take three arguments, where the first is replaced by the + * value type, the second by the matrix, and the third by the + * preconditioner. + * + * @note the second and third arguments only accept the base types. + */ +#define GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER(...) \ + GKO_CALL(GKO_BATCH_INSTANTIATE_MATRIX, \ + GKO_BATCH_INSTANTIATE_PRECONDITIONER, \ + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS, __VA_ARGS__) + + +} // namespace batch +} // namespace gko + +#endif // GKO_PUBLIC_CORE_BASE_BATCH_INSTANTIATION_HPP_ diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 26de8531741..290b5afd907 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -5,6 +5,7 @@ #include #include +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_multi_vector_kernels.hpp" #include "core/base/device_matrix_data_kernels.hpp" #include "core/base/index_set_kernels.hpp" @@ -168,6 +169,13 @@ _macro(ValueType, ValueTypeKrylovBases) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_CB_GMRES_CONST_TYPE(_macro) +#define GKO_STUB_BATCH_VALUE_MATRIX_PRECONDITIONER(_declare, _wrapper) \ + template \ + _declare(ValueType, BatchMatrixType, PrecType) \ + GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER(_wrapper) + + namespace gko { namespace kernels { 
namespace GKO_HOOK_MODULE { @@ -421,7 +429,9 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DIAGONAL_FILL_IN_MATRIX_DATA_KERNEL); namespace batch_bicgstab { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_STUB_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL, + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER); } // namespace batch_bicgstab @@ -430,7 +440,9 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); namespace batch_cg { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_STUB_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL, + GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER); } // namespace batch_cg diff --git a/core/solver/batch_bicgstab.cpp b/core/solver/batch_bicgstab.cpp index c22c712b411..73fc0a2c852 100644 --- a/core/solver/batch_bicgstab.cpp +++ b/core/solver/batch_bicgstab.cpp @@ -7,8 +7,14 @@ #include #include #include +#include +#include +#include +#include +#include #include "core/base/batch_multi_vector_kernels.hpp" +#include "core/base/dispatch_helper.hpp" #include "core/solver/batch_bicgstab_kernels.hpp" @@ -45,14 +51,19 @@ void Bicgstab::solver_apply( const MultiVector* b, MultiVector* x, log::detail::log_data>* log_data) const { - using MVec = MultiVector; const kernels::batch_bicgstab::settings> settings{ this->max_iterations_, static_cast(this->residual_tol_), parameters_.tolerance_type}; auto exec = this->get_executor(); - exec->run(bicgstab::make_apply(settings, this->system_matrix_.get(), - this->preconditioner_.get(), b, x, - *log_data)); + + run, matrix::Csr, + matrix::Ell>(this->system_matrix_.get(), [&](auto matrix) { + run, preconditioner::Jacobi>( + this->preconditioner_.get(), [&](auto preconditioner) { + exec->run(bicgstab::make_apply(settings, matrix, preconditioner, + b, x, *log_data)); + }); + }); } diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index 615ed472597..2485e7e454e 100644 --- 
a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -174,19 +174,25 @@ storage_config compute_shared_storage(const int available_shared_mem, } // namespace batch_bicgstab -#define GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL(_type) \ +#define GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL(_type, _matrix, _prec) \ void apply( \ std::shared_ptr exec, \ const gko::kernels::batch_bicgstab::settings>& \ options, \ - const batch::BatchLinOp* a, const batch::BatchLinOp* preconditioner, \ + const _matrix* a, const _prec* preconditioner, \ const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \ gko::batch::log::detail::log_data>& logdata) +#define GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER(_vtype, _matrix, \ + _precond) \ + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL(_vtype, _matrix<_vtype>, \ + _precond<_vtype>) -#define GKO_DECLARE_ALL_AS_TEMPLATES \ - template \ - GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL(ValueType) + +#define GKO_DECLARE_ALL_AS_TEMPLATES \ + template \ + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL(ValueType, BatchMatrixType, \ + PrecType) GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(batch_bicgstab, diff --git a/core/solver/batch_cg.cpp b/core/solver/batch_cg.cpp index 0ab1ca8564f..13a5afffcaa 100644 --- a/core/solver/batch_cg.cpp +++ b/core/solver/batch_cg.cpp @@ -7,11 +7,16 @@ #include #include #include +#include +#include +#include +#include +#include #include "core/base/batch_multi_vector_kernels.hpp" +#include "core/base/dispatch_helper.hpp" #include "core/solver/batch_cg_kernels.hpp" - namespace gko { namespace batch { namespace solver { @@ -49,8 +54,17 @@ void Cg::solver_apply( this->max_iterations_, static_cast(this->residual_tol_), parameters_.tolerance_type}; auto exec = this->get_executor(); - exec->run(cg::make_apply(settings, this->system_matrix_.get(), - this->preconditioner_.get(), b, x, *log_data)); + + run, batch::matrix::Csr, + batch::matrix::Ell>( + this->system_matrix_.get(), [&](auto matrix) { + run, + 
batch::preconditioner::Jacobi>( + this->preconditioner_.get(), [&](auto preconditioner) { + exec->run(cg::make_apply(settings, matrix, preconditioner, + b, x, *log_data)); + }); + }); } diff --git a/core/solver/batch_cg_kernels.hpp b/core/solver/batch_cg_kernels.hpp index b21a2c07d3e..79e5e6c397d 100644 --- a/core/solver/batch_cg_kernels.hpp +++ b/core/solver/batch_cg_kernels.hpp @@ -162,19 +162,22 @@ storage_config compute_shared_storage(const int available_shared_mem, } // namespace batch_cg -#define GKO_DECLARE_BATCH_CG_APPLY_KERNEL(_type) \ - void apply( \ - std::shared_ptr exec, \ - const gko::kernels::batch_cg::settings>& \ - options, \ - const batch::BatchLinOp* mat, const batch::BatchLinOp* preconditioner, \ - const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \ +#define GKO_DECLARE_BATCH_CG_APPLY_KERNEL(_type, _matrix, _prec) \ + void apply( \ + std::shared_ptr exec, \ + const gko::kernels::batch_cg::settings>& \ + options, \ + const _matrix* mat, const _prec* preconditioner, \ + const batch::MultiVector<_type>* b, batch::MultiVector<_type>* x, \ gko::batch::log::detail::log_data>& logdata) +#define GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER(_vtype, _matrix, _precond) \ + GKO_DECLARE_BATCH_CG_APPLY_KERNEL(_vtype, _matrix<_vtype>, _precond<_vtype>) -#define GKO_DECLARE_ALL_AS_TEMPLATES \ - template \ - GKO_DECLARE_BATCH_CG_APPLY_KERNEL(ValueType) + +#define GKO_DECLARE_ALL_AS_TEMPLATES \ + template \ + GKO_DECLARE_BATCH_CG_APPLY_KERNEL(ValueType, BatchMatrixType, PrecType) GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(batch_cg, GKO_DECLARE_ALL_AS_TEMPLATES); diff --git a/core/solver/batch_dispatch.hpp b/core/solver/batch_dispatch.hpp index 5a37b12cf11..d76bc72d489 100644 --- a/core/solver/batch_dispatch.hpp +++ b/core/solver/batch_dispatch.hpp @@ -17,6 +17,7 @@ #include #include +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" @@ -164,35 +165,34 @@ enum class log_type { 
simple_convergence_completion }; } // namespace log -#define GKO_BATCH_INSTANTIATE_STOP(macro, ...) \ - macro(__VA_ARGS__, \ - ::gko::batch::solver::device::batch_stop::SimpleAbsResidual); \ - macro(__VA_ARGS__, \ - ::gko::batch::solver::device::batch_stop::SimpleRelResidual) - -#define GKO_BATCH_INSTANTIATE_PRECONDITIONER(macro, ...) \ - GKO_BATCH_INSTANTIATE_STOP( \ - macro, __VA_ARGS__, \ - ::gko::batch::solver::device::batch_preconditioner::Identity); \ - GKO_BATCH_INSTANTIATE_STOP( \ - macro, __VA_ARGS__, \ - ::gko::batch::solver::device::batch_preconditioner::ScalarJacobi); \ - GKO_BATCH_INSTANTIATE_STOP( \ - macro, __VA_ARGS__, \ - ::gko::batch::solver::device::batch_preconditioner::BlockJacobi) - -#define GKO_BATCH_INSTANTIATE_LOGGER(macro, ...) \ - GKO_BATCH_INSTANTIATE_PRECONDITIONER( \ - macro, __VA_ARGS__, \ - ::gko::batch::solver::device::batch_log::SimpleFinalLogger) - -#define GKO_BATCH_INSTANTIATE_MATRIX_VARGS(macro, ...) \ - GKO_BATCH_INSTANTIATE_LOGGER(macro, __VA_ARGS__, \ - batch::matrix::ell::uniform_batch); \ - GKO_BATCH_INSTANTIATE_LOGGER(macro, __VA_ARGS__, \ - batch::matrix::dense::uniform_batch); \ - GKO_BATCH_INSTANTIATE_LOGGER(macro, __VA_ARGS__, \ - batch::matrix::csr::uniform_batch) +#define GKO_BATCH_INSTANTIATE_STOP(_next, ...) \ + GKO_INDIRECT( \ + _next(__VA_ARGS__, \ + ::gko::batch::solver::device::batch_stop::SimpleAbsResidual)); \ + GKO_INDIRECT( \ + _next(__VA_ARGS__, \ + ::gko::batch::solver::device::batch_stop::SimpleRelResidual)) + +#define GKO_BATCH_INSTANTIATE_DEVICE_PRECONDITIONER(_next, ...) \ + GKO_INDIRECT( \ + _next(__VA_ARGS__, \ + ::gko::batch::solver::device::batch_preconditioner::Identity)); \ + GKO_INDIRECT(_next( \ + __VA_ARGS__, \ + ::gko::batch::solver::device::batch_preconditioner::ScalarJacobi)); \ + GKO_INDIRECT(_next( \ + __VA_ARGS__, \ + ::gko::batch::solver::device::batch_preconditioner::BlockJacobi)) + +#define GKO_BATCH_INSTANTIATE_LOGGER(_next, ...) 
\ + GKO_INDIRECT( \ + _next(__VA_ARGS__, \ + ::gko::batch::solver::device::batch_log::SimpleFinalLogger)) + +#define GKO_BATCH_INSTANTIATE_MATRIX_BATCH(_next, ...) \ + GKO_INDIRECT(_next(__VA_ARGS__, batch::matrix::ell::uniform_batch)); \ + GKO_INDIRECT(_next(__VA_ARGS__, batch::matrix::dense::uniform_batch)); \ + GKO_INDIRECT(_next(__VA_ARGS__, batch::matrix::csr::uniform_batch)) /** * Passes each valid configuration of batch solver template parameter to a @@ -201,21 +201,11 @@ enum class log_type { simple_convergence_completion }; * GKO_BATCH_INSTANTIATE will be prepended to the batch solver template * parameters. */ -#define GKO_BATCH_INSTANTIATE_VARGS(macro, ...) \ - GKO_BATCH_INSTANTIATE_MATRIX_VARGS(macro, __VA_ARGS__) - - -/** - * Passes each valid configuration of batch solver template parameter to a - * macro. The order of template parameters is: macro(, , - * , ) - */ -#define GKO_BATCH_INSTANTIATE_MATRIX(macro, ...) \ - GKO_BATCH_INSTANTIATE_LOGGER(macro, batch::matrix::ell::uniform_batch); \ - GKO_BATCH_INSTANTIATE_LOGGER(macro, batch::matrix::dense::uniform_batch); \ - GKO_BATCH_INSTANTIATE_LOGGER(macro, batch::matrix::csr::uniform_batch) - -#define GKO_BATCH_INSTANTIATE(macro) GKO_BATCH_INSTANTIATE_MATRIX(macro) +#define GKO_BATCH_INSTANTIATE(...) \ + GKO_CALL(GKO_BATCH_INSTANTIATE_MATRIX_BATCH, GKO_BATCH_INSTANTIATE_LOGGER, \ + GKO_BATCH_INSTANTIATE_DEVICE_PRECONDITIONER, \ + GKO_BATCH_INSTANTIATE_STOP, \ + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS, __VA_ARGS__) /** @@ -229,7 +219,8 @@ enum class log_type { simple_convergence_completion }; * @tparam SettingsType Structure type of options for the particular solver to * be used. 
*/ -template +template class batch_solver_dispatch { public: using value_type = ValueType; @@ -238,7 +229,8 @@ class batch_solver_dispatch { batch_solver_dispatch( const KernelCaller& kernel_caller, const SettingsType& settings, - const BatchLinOp* const matrix, const BatchLinOp* const preconditioner, + const BatchMatrixType* const matrix, + const PrecType* const preconditioner, const log::detail::log_type logger_type = log::detail::log_type::simple_convergence_completion) : caller_{kernel_caller}, @@ -248,21 +240,21 @@ class batch_solver_dispatch { logger_type_{logger_type} {} - template + template void dispatch_on_stop( - const LogType& logger, const BatchMatrixType& mat_item, - PrecType precond, + const LogType& logger, const BatchMatrixEntry& mat_item, + PrecEntry precond, const multi_vector::uniform_batch& b_item, const multi_vector::uniform_batch& x_item) { if (settings_.tol_type == stop::tolerance_type::absolute) { caller_.template call_kernel< - BatchMatrixType, PrecType, + BatchMatrixEntry, PrecEntry, device::batch_stop::SimpleAbsResidual, LogType>(logger, mat_item, precond, b_item, x_item); } else if (settings_.tol_type == stop::tolerance_type::relative) { caller_.template call_kernel< - BatchMatrixType, PrecType, + BatchMatrixEntry, PrecEntry, device::batch_stop::SimpleRelResidual, LogType>(logger, mat_item, precond, b_item, x_item); } else { @@ -270,37 +262,37 @@ class batch_solver_dispatch { } } - template + template void dispatch_on_preconditioner( - const LogType& logger, const BatchMatrixType& mat_item, + const LogType& logger, const BatchMatrixEntry& mat_item, const multi_vector::uniform_batch& b_item, const multi_vector::uniform_batch& x_item) { - if (!precond_ || - dynamic_cast*>(precond_)) { + if constexpr (std::is_same_v>) { dispatch_on_stop( logger, mat_item, device::batch_preconditioner::Identity(), b_item, x_item); - } else if (auto prec = dynamic_cast< - const batch::preconditioner::Jacobi*>( - precond_)) { - const auto max_block_size = 
prec->get_max_block_size(); + } else if constexpr (std::is_same_v< + PrecType, + batch::preconditioner::Jacobi>) { + const auto max_block_size = precond_->get_max_block_size(); if (max_block_size == 1) { dispatch_on_stop(logger, mat_item, device::batch_preconditioner::ScalarJacobi< device_value_type>(), b_item, x_item); } else { - const auto num_blocks = prec->get_num_blocks(); - const auto block_ptrs_arr = prec->get_const_block_pointers(); + const auto num_blocks = precond_->get_num_blocks(); + const auto block_ptrs_arr = + precond_->get_const_block_pointers(); const auto row_block_map_arr = - prec->get_const_map_block_to_row(); + precond_->get_const_map_block_to_row(); const auto blocks_arr = reinterpret_cast>( - prec->get_const_blocks()); + precond_->get_const_blocks()); const auto blocks_cumul_storage = - prec->get_const_blocks_cumulative_offsets(); + precond_->get_const_blocks_cumulative_offsets(); dispatch_on_stop( logger, mat_item, @@ -315,9 +307,9 @@ class batch_solver_dispatch { } } - template + template void dispatch_on_logger( - const BatchMatrixType& amat, + const BatchMatrixEntry& amat, const multi_vector::uniform_batch& b_item, const multi_vector::uniform_batch& x_item, batch::log::detail::log_data& log_data) @@ -337,23 +329,8 @@ class batch_solver_dispatch { const multi_vector::uniform_batch& x_item, batch::log::detail::log_data& log_data) { - if (auto batch_mat = - dynamic_cast*>( - mat_)) { - auto mat_item = device::get_batch_struct(batch_mat); - dispatch_on_logger(mat_item, b_item, x_item, log_data); - } else if (auto batch_mat = - dynamic_cast*>( - mat_)) { - auto mat_item = device::get_batch_struct(batch_mat); - dispatch_on_logger(mat_item, b_item, x_item, log_data); - } else if (auto batch_mat = dynamic_cast< - const batch::matrix::Csr*>(mat_)) { - auto mat_item = device::get_batch_struct(batch_mat); - dispatch_on_logger(mat_item, b_item, x_item, log_data); - } else { - GKO_NOT_SUPPORTED(mat_); - } + auto mat_item = 
device::get_batch_struct(mat_); + dispatch_on_logger(mat_item, b_item, x_item, log_data); } /** @@ -375,8 +352,8 @@ class batch_solver_dispatch { private: const KernelCaller caller_; const SettingsType settings_; - const BatchLinOp* mat_; - const BatchLinOp* precond_; + const BatchMatrixType* mat_; + const PrecType* precond_; const log::detail::log_type logger_type_; }; @@ -384,14 +361,19 @@ class batch_solver_dispatch { /** * Convenient function to create a dispatcher. Infers most template arguments. */ -template -batch_solver_dispatch create_dispatcher( - const KernelCaller& kernel_caller, const SettingsType& settings, - const BatchLinOp* const matrix, const BatchLinOp* const preconditioner, - const log::detail::log_type logger_type = - log::detail::log_type::simple_convergence_completion) +template +batch_solver_dispatch +create_dispatcher(const KernelCaller& kernel_caller, + const SettingsType& settings, + const BatchMatrixType* const matrix, + const PrecType* const preconditioner, + const log::detail::log_type logger_type = + log::detail::log_type::simple_convergence_completion) { - return batch_solver_dispatch( + return batch_solver_dispatch( kernel_caller, settings, matrix, preconditioner, logger_type); } diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index eff6626de31..15c4d7560d9 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "core/test/utils/assertions.hpp" @@ -334,7 +335,8 @@ ResultWithLogData solve_linear_system( if (precond_factory) { precond = precond_factory->generate(sys.matrix); } else { - precond = nullptr; + precond = gko::batch::matrix::Identity::create( + exec, sys.matrix->get_size()); } solve_lambda(settings, precond.get(), sys.matrix.get(), sys.rhs.get(), diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index e8052637763..74d312c95ef 100644 --- 
a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -9,6 +9,7 @@ #include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" #include "common/cuda_hip/matrix/batch_struct.hpp" #include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp" +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" #include "cuda/solver/batch_bicgstab_launch.cuh" @@ -138,20 +139,21 @@ private: }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precon, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) { auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precond); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER); } // namespace batch_bicgstab diff --git a/cuda/solver/batch_bicgstab_launch.cuh b/cuda/solver/batch_bicgstab_launch.cuh index 737f2a923b0..b4e8753ccca 100644 --- a/cuda/solver/batch_bicgstab_launch.cuh +++ b/cuda/solver/batch_bicgstab_launch.cuh @@ -31,7 +31,7 @@ template exec, const int num_rows); -#define GKO_DECLARE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK_( \ +#define GKO_DECLARE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK( \ _vtype, mat_t, log_t, pre_t, stop_t) \ int get_num_threads_per_block< \ stop_t>, pre_t>, \ @@ -39,34 +39,24 @@ int get_num_threads_per_block(std::shared_ptr exec, cuda_type<_vtype>>(std::shared_ptr exec, \ const int num_rows) -#define GKO_INSTANTIATE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK_(...) 
\ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS( \ - GKO_DECLARE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK_, __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK \ - GKO_BATCH_INSTANTIATE( \ - GKO_INSTANTIATE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK_) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_GET_NUM_THREADS_PER_BLOCK) template int get_max_dynamic_shared_memory(std::shared_ptr exec); -#define GKO_DECLARE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY_( \ +#define GKO_DECLARE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY( \ _vtype, mat_t, log_t, pre_t, stop_t) \ int get_max_dynamic_shared_memory< \ stop_t>, pre_t>, \ log_t>, mat_t>, \ cuda_type<_vtype>>(std::shared_ptr exec) -#define GKO_INSTANTIATE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY_(...) \ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS( \ - GKO_DECLARE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY_, \ - __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY \ GKO_BATCH_INSTANTIATE( \ - GKO_INSTANTIATE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY_) + GKO_DECLARE_BATCH_BICGSTAB_GET_MAX_DYNAMIC_SHARED_MEMORY) } // namespace batch_bicgstab diff --git a/cuda/solver/batch_cg_kernels.cu b/cuda/solver/batch_cg_kernels.cu index e45e1baf03b..e1aec94852b 100644 --- a/cuda/solver/batch_cg_kernels.cu +++ b/cuda/solver/batch_cg_kernels.cu @@ -9,6 +9,7 @@ #include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" #include "common/cuda_hip/matrix/batch_struct.hpp" #include "common/cuda_hip/solver/batch_cg_kernels.hpp" +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" #include "cuda/solver/batch_cg_launch.cuh" @@ -119,20 +120,21 @@ private: }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precon, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* 
x, batch::log::detail::log_data>& logdata) { auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precond); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER); } // namespace batch_cg diff --git a/cuda/solver/batch_cg_launch.cuh b/cuda/solver/batch_cg_launch.cuh index e803e15fe80..94d948cf202 100644 --- a/cuda/solver/batch_cg_launch.cuh +++ b/cuda/solver/batch_cg_launch.cuh @@ -31,41 +31,31 @@ template exec, const int num_rows); -#define GKO_DECLARE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK_(_vtype, mat_t, log_t, \ - pre_t, stop_t) \ - int get_num_threads_per_block< \ - stop_t>, pre_t>, \ - log_t>>, \ - mat_t>, cuda_type<_vtype>>( \ +#define GKO_DECLARE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK(_vtype, mat_t, log_t, \ + pre_t, stop_t) \ + int get_num_threads_per_block< \ + stop_t>, pre_t>, \ + log_t>>, \ + mat_t>, cuda_type<_vtype>>( \ std::shared_ptr exec, const int num_rows) -#define GKO_INSTANTIATE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK_(...) \ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS( \ - GKO_DECLARE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK_, __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK \ - GKO_BATCH_INSTANTIATE(GKO_INSTANTIATE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK_) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_GET_NUM_THREADS_PER_BLOCK) template int get_max_dynamic_shared_memory(std::shared_ptr exec); -#define GKO_DECLARE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY_( \ +#define GKO_DECLARE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY( \ _vtype, mat_t, log_t, pre_t, stop_t) \ int get_max_dynamic_shared_memory< \ stop_t>, pre_t>, \ log_t>, mat_t>, \ cuda_type<_vtype>>(std::shared_ptr exec) - -#define GKO_INSTANTIATE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY_(...) 
\ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS( \ - GKO_DECLARE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY_, __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY \ - GKO_BATCH_INSTANTIATE( \ - GKO_INSTANTIATE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY_) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_GET_MAX_DYNAMIC_SHARED_MEMORY) } // namespace batch_cg diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index c02ca02e1d8..e86eec5f21b 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -8,6 +8,7 @@ #include +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" @@ -168,10 +169,10 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precond, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) @@ -181,8 +182,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, logdata); } - -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER); } // namespace batch_bicgstab diff --git a/dpcpp/solver/batch_bicgstab_launch.hpp b/dpcpp/solver/batch_bicgstab_launch.hpp index 06ba8531b42..a9c78b9df45 100644 --- a/dpcpp/solver/batch_bicgstab_launch.hpp +++ b/dpcpp/solver/batch_bicgstab_launch.hpp @@ -53,34 +53,30 @@ void launch_apply_kernel( _vtype* const __restrict__ workspace_data, const int& block_size, \ const int& shared_size) -#define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH(...) 
\ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS( \ - GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_0 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 0) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 0) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_1 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 1) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 1) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_2 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 2) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 2) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_3 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 3) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 3) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_4 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 4) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 4) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_5 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 5) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 5) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_6 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 6) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 6) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_7 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 7) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 7) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_8 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 8) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 8) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_9 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 9) + 
GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 9) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_10 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 32, 10) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 32, 10) #define GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH_10_16 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_BICGSTAB_LAUNCH, 16, 10) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_BICGSTAB_LAUNCH, 16, 10) } // namespace batch_bicgstab diff --git a/dpcpp/solver/batch_cg_kernels.dp.cpp b/dpcpp/solver/batch_cg_kernels.dp.cpp index d94019125b1..5ded4a53978 100644 --- a/dpcpp/solver/batch_cg_kernels.dp.cpp +++ b/dpcpp/solver/batch_cg_kernels.dp.cpp @@ -8,6 +8,7 @@ #include +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" @@ -145,10 +146,10 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precond, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) @@ -158,8 +159,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, logdata); } - -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER); } // namespace batch_cg diff --git a/dpcpp/solver/batch_cg_launch.hpp b/dpcpp/solver/batch_cg_launch.hpp index 3fe1e704963..c5f8e0d5dba 100644 --- a/dpcpp/solver/batch_cg_launch.hpp +++ b/dpcpp/solver/batch_cg_launch.hpp @@ -6,6 +6,7 @@ #include +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_cg_kernels.hpp" @@ -50,26 +51,22 @@ void launch_apply_kernel(std::shared_ptr exec, _vtype* const 
__restrict__ workspace_data, const int& block_size, \ const int& shared_size) -#define GKO_INSTANTIATE_BATCH_CG_LAUNCH(...) \ - GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS(GKO_DECLARE_BATCH_CG_LAUNCH, \ - __VA_ARGS__) - #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_0 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 0) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 0) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_1 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 1) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 1) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_2 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 2) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 2) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_3 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 3) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 3) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_4 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 4) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 4) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_5 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 5) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 5) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_6 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 32, 6) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 32, 6) #define GKO_INSTANTIATE_BATCH_CG_LAUNCH_6_16 \ - GKO_BATCH_INSTANTIATE_VARGS(GKO_INSTANTIATE_BATCH_CG_LAUNCH, 16, 6) + GKO_BATCH_INSTANTIATE(GKO_DECLARE_BATCH_CG_LAUNCH, 16, 6) } // namespace batch_cg diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 3e019fd3ad1..66d6130cfd0 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -10,6 +10,7 @@ #include "common/cuda_hip/matrix/batch_struct.hpp" #include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp" 
#include "common/cuda_hip/solver/batch_bicgstab_launch.hpp" +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" @@ -162,20 +163,21 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precon, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) { auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precond); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER); } // namespace batch_bicgstab diff --git a/hip/solver/batch_cg_kernels.hip.cpp b/hip/solver/batch_cg_kernels.hip.cpp index 7f6f7ffe1db..f36974aae06 100644 --- a/hip/solver/batch_cg_kernels.hip.cpp +++ b/hip/solver/batch_cg_kernels.hip.cpp @@ -10,6 +10,7 @@ #include "common/cuda_hip/matrix/batch_struct.hpp" #include "common/cuda_hip/solver/batch_cg_kernels.hpp" #include "common/cuda_hip/solver/batch_cg_launch.hpp" +#include "core/base/batch_instantiation.hpp" #include "core/base/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" @@ -144,20 +145,21 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precon, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) { auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precond); dispatcher.apply(b, x, logdata); 
} -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER); } // namespace batch_cg diff --git a/include/ginkgo/core/base/types.hpp b/include/ginkgo/core/base/types.hpp index 4b06b494707..72dd8a93584 100644 --- a/include/ginkgo/core/base/types.hpp +++ b/include/ginkgo/core/base/types.hpp @@ -443,6 +443,8 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, // Helper macro to make Windows builds work +// In MSVC, __VA_ARGS__ behave like one argument by default. +// with this, we can expand the __VA_ARGS__ properly #define GKO_INDIRECT(...) __VA_ARGS__ diff --git a/omp/solver/batch_bicgstab_kernels.cpp b/omp/solver/batch_bicgstab_kernels.cpp index 5e069806f60..f8a4dbb8172 100644 --- a/omp/solver/batch_bicgstab_kernels.cpp +++ b/omp/solver/batch_bicgstab_kernels.cpp @@ -8,6 +8,7 @@ #include +#include "core/base/batch_instantiation.hpp" #include "core/solver/batch_dispatch.hpp" #include "reference/base/batch_multi_vector_kernels.hpp" #include "reference/matrix/batch_csr_kernels.hpp" @@ -77,10 +78,10 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precond, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) @@ -90,7 +91,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER); } // namespace batch_bicgstab diff --git a/omp/solver/batch_cg_kernels.cpp b/omp/solver/batch_cg_kernels.cpp index 0664c0244b6..26a7046a176 100644 --- a/omp/solver/batch_cg_kernels.cpp +++ b/omp/solver/batch_cg_kernels.cpp @@ -8,6 +8,7 @@ #include +#include 
"core/base/batch_instantiation.hpp" #include "core/solver/batch_dispatch.hpp" #include "reference/base/batch_multi_vector_kernels.hpp" #include "reference/matrix/batch_csr_kernels.hpp" @@ -83,10 +84,10 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precond, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, batch::log::detail::log_data>& logdata) @@ -96,7 +97,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER); } // namespace batch_cg diff --git a/reference/solver/batch_bicgstab_kernels.cpp b/reference/solver/batch_bicgstab_kernels.cpp index 5bc75c5ebdb..3f105f27c48 100644 --- a/reference/solver/batch_bicgstab_kernels.cpp +++ b/reference/solver/batch_bicgstab_kernels.cpp @@ -4,6 +4,7 @@ #include "core/solver/batch_bicgstab_kernels.hpp" +#include "core/base/batch_instantiation.hpp" #include "core/solver/batch_dispatch.hpp" #include "reference/base/batch_multi_vector_kernels.hpp" #include "reference/matrix/batch_csr_kernels.hpp" @@ -37,10 +38,10 @@ class kernel_caller { : exec_{std::move(exec)}, settings_{settings} {} - template void call_kernel( - const LogType& logger, const BatchMatrixType& mat, PrecType prec, + const LogType& logger, const BatchMatrixEntry& mat, PrecEntry prec, const gko::batch::multi_vector::uniform_batch& b, const gko::batch::multi_vector::uniform_batch& x) const { @@ -55,13 +56,13 @@ class kernel_caller { const size_type local_size_bytes = gko::kernels::batch_bicgstab::local_memory_requirement( num_rows, num_rhs) + - PrecType::dynamic_work_size(num_rows, - mat.get_single_item_num_nnz()); + PrecEntry::dynamic_work_size(num_rows, + mat.get_single_item_num_nnz()); array 
local_space(exec_, local_size_bytes); for (size_type batch_id = 0; batch_id < num_batch_items; batch_id++) { batch_single_kernels::batch_entry_bicgstab_impl< - StopType, PrecType, LogType, BatchMatrixType, ValueType>( + StopType, PrecEntry, LogType, BatchMatrixEntry, ValueType>( settings_, logger, prec, mat, b, x, batch_id, local_space.get_data()); } @@ -73,20 +74,21 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precon, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, - batch::log::detail::log_data>& log_data) + batch::log::detail::log_data>& logdata) { auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); - dispatcher.apply(b, x, log_data); + kernel_caller(exec, settings), settings, mat, precond); + dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL_WRAPPER); } // namespace batch_bicgstab diff --git a/reference/solver/batch_cg_kernels.cpp b/reference/solver/batch_cg_kernels.cpp index ba54329c31a..3acc49fc524 100644 --- a/reference/solver/batch_cg_kernels.cpp +++ b/reference/solver/batch_cg_kernels.cpp @@ -4,6 +4,7 @@ #include "core/solver/batch_cg_kernels.hpp" +#include "core/base/batch_instantiation.hpp" #include "core/solver/batch_dispatch.hpp" #include "reference/base/batch_multi_vector_kernels.hpp" #include "reference/matrix/batch_csr_kernels.hpp" @@ -11,7 +12,6 @@ #include "reference/matrix/batch_ell_kernels.hpp" #include "reference/solver/batch_cg_kernels.hpp" - namespace gko { namespace kernels { namespace reference { @@ -37,10 +37,10 @@ class kernel_caller { : exec_{std::move(exec)}, settings_{settings} {} - template void call_kernel( - const LogType& logger, const 
BatchMatrixType& mat, PrecType prec, + const LogType& logger, const BatchMatrixEntry& mat, PrecEntry prec, const gko::batch::multi_vector::uniform_batch& b, const gko::batch::multi_vector::uniform_batch& x) const { @@ -55,13 +55,13 @@ class kernel_caller { const size_type local_size_bytes = gko::kernels::batch_cg::local_memory_requirement( num_rows, num_rhs) + - PrecType::dynamic_work_size(num_rows, - mat.get_single_item_num_nnz()); + PrecEntry::dynamic_work_size(num_rows, + mat.get_single_item_num_nnz()); array local_space(exec_, local_size_bytes); for (size_type batch_id = 0; batch_id < num_batch_items; batch_id++) { batch_single_kernels::batch_entry_cg_impl< - StopType, PrecType, LogType, BatchMatrixType, ValueType>( + StopType, PrecEntry, LogType, BatchMatrixEntry, ValueType>( settings_, logger, prec, mat, b, x, batch_id, local_space.get_data()); } @@ -73,20 +73,21 @@ class kernel_caller { }; -template +template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* mat, const batch::BatchLinOp* precon, + const BatchMatrixType* mat, const PrecType* precond, const batch::MultiVector* b, batch::MultiVector* x, - batch::log::detail::log_data>& log_data) + batch::log::detail::log_data>& logdata) { auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); - dispatcher.apply(b, x, log_data); + kernel_caller(exec, settings), settings, mat, precond); + dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL_WRAPPER); } // namespace batch_cg diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp index ddb6d09e12a..c7b36ba875c 100644 --- a/reference/test/solver/batch_bicgstab_kernels.cpp +++ b/reference/test/solver/batch_bicgstab_kernels.cpp @@ -14,9 +14,12 @@ #include #include #include 
+#include +#include #include #include "core/base/batch_utilities.hpp" +#include "core/base/dispatch_helper.hpp" #include "core/matrix/batch_dense_kernels.hpp" #include "core/test/utils.hpp" #include "core/test/utils/batch_helpers.hpp" @@ -49,9 +52,12 @@ class BatchBicgstab : public ::testing::Test { const gko::batch::BatchLinOp* prec, const Mtx* mtx, const MVec* b, MVec* x, LogData& log_data) { - gko::kernels::reference::batch_bicgstab::apply< - typename Mtx::value_type>(executor, opts, mtx, prec, b, x, - log_data); + gko::run, + gko::batch::preconditioner::Jacobi>( + prec, [&](auto preconditioner) { + gko::kernels::reference::batch_bicgstab::apply( + executor, opts, mtx, preconditioner, b, x, log_data); + }); }; } diff --git a/reference/test/solver/batch_cg_kernels.cpp b/reference/test/solver/batch_cg_kernels.cpp index 4ccabfb8849..86efa158fb5 100644 --- a/reference/test/solver/batch_cg_kernels.cpp +++ b/reference/test/solver/batch_cg_kernels.cpp @@ -14,9 +14,12 @@ #include #include #include +#include +#include #include #include "core/base/batch_utilities.hpp" +#include "core/base/dispatch_helper.hpp" #include "core/matrix/batch_dense_kernels.hpp" #include "core/test/utils.hpp" #include "core/test/utils/batch_helpers.hpp" @@ -49,8 +52,12 @@ class BatchCg : public ::testing::Test { const gko::batch::BatchLinOp* prec, const Mtx* mtx, const MVec* b, MVec* x, LogData& log_data) { - gko::kernels::reference::batch_cg::apply( - executor, opts, mtx, prec, b, x, log_data); + gko::run, + gko::batch::preconditioner::Jacobi>( + prec, [&](auto preconditioner) { + gko::kernels::reference::batch_cg::apply( + executor, opts, mtx, preconditioner, b, x, log_data); + }); }; } diff --git a/test/preconditioner/batch_jacobi_kernels.cpp b/test/preconditioner/batch_jacobi_kernels.cpp index 62e309361c9..fe013cee9aa 100644 --- a/test/preconditioner/batch_jacobi_kernels.cpp +++ b/test/preconditioner/batch_jacobi_kernels.cpp @@ -14,10 +14,12 @@ #include #include #include +#include #include 
#include #include +#include "core/base/dispatch_helper.hpp" #include "core/solver/batch_bicgstab_kernels.hpp" #include "core/test/utils.hpp" #include "core/test/utils/assertions.hpp" @@ -113,9 +115,13 @@ class BatchJacobi : public CommonTestFixture { const gko::batch::BatchLinOp* prec, const Mtx* mtx, const MVec* b, MVec* x, LogData& log_data) { - gko::kernels::GKO_DEVICE_NAMESPACE::batch_bicgstab::apply< - typename Mtx::value_type>(executor, settings, mtx, prec, b, x, - log_data); + gko::run, + gko::batch::preconditioner::Jacobi>( + prec, [&](auto preconditioner) { + gko::kernels::GKO_DEVICE_NAMESPACE::batch_bicgstab::apply( + executor, settings, mtx, preconditioner, b, x, + log_data); + }); }; solver_settings = Settings{max_iters, tol, gko::batch::stop::tolerance_type::relative}; diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp index 1a852eacfe9..c5eb3996926 100644 --- a/test/solver/batch_bicgstab_kernels.cpp +++ b/test/solver/batch_bicgstab_kernels.cpp @@ -14,9 +14,12 @@ #include #include #include +#include +#include #include #include "core/base/batch_utilities.hpp" +#include "core/base/dispatch_helper.hpp" #include "core/matrix/batch_dense_kernels.hpp" #include "core/test/utils.hpp" #include "core/test/utils/batch_helpers.hpp" @@ -48,9 +51,13 @@ class BatchBicgstab : public CommonTestFixture { const gko::batch::BatchLinOp* prec, const Mtx* mtx, const MVec* b, MVec* x, LogData& log_data) { - gko::kernels::GKO_DEVICE_NAMESPACE::batch_bicgstab::apply< - typename Mtx::value_type>(executor, settings, mtx, prec, b, x, - log_data); + gko::run, + gko::batch::preconditioner::Jacobi>( + prec, [&](auto preconditioner) { + gko::kernels::GKO_DEVICE_NAMESPACE::batch_bicgstab::apply( + executor, settings, mtx, preconditioner, b, x, + log_data); + }); }; solver_settings = Settings{max_iters, tol, gko::batch::stop::tolerance_type::relative}; diff --git a/test/solver/batch_cg_kernels.cpp b/test/solver/batch_cg_kernels.cpp index 
4c6de9004c9..582f26ec497 100644 --- a/test/solver/batch_cg_kernels.cpp +++ b/test/solver/batch_cg_kernels.cpp @@ -13,9 +13,12 @@ #include #include #include +#include +#include #include #include "core/base/batch_utilities.hpp" +#include "core/base/dispatch_helper.hpp" #include "core/matrix/batch_dense_kernels.hpp" #include "core/test/utils.hpp" #include "core/test/utils/batch_helpers.hpp" @@ -46,9 +49,13 @@ class BatchCg : public CommonTestFixture { const gko::batch::BatchLinOp* prec, const Mtx* mtx, const MVec* b, MVec* x, LogData& log_data) { - gko::kernels::GKO_DEVICE_NAMESPACE::batch_cg::apply< - typename Mtx::value_type>(executor, settings, mtx, prec, b, x, - log_data); + gko::run, + gko::batch::preconditioner::Jacobi>( + prec, [&](auto preconditioner) { + gko::kernels::GKO_DEVICE_NAMESPACE::batch_cg::apply( + executor, settings, mtx, preconditioner, b, x, + log_data); + }); }; solver_settings = Settings{max_iters, tol, gko::batch::stop::tolerance_type::relative};