Skip to content

Commit

Permalink
batch with half
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 7, 2024
1 parent c443777 commit b1f71bf
Show file tree
Hide file tree
Showing 49 changed files with 526 additions and 295 deletions.
13 changes: 7 additions & 6 deletions common/cuda_hip/base/batch_multi_vector_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL);


Expand All @@ -81,7 +81,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL);


Expand All @@ -101,7 +101,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
x_ub, y_ub, res_ub, [] __device__(auto val) { return val; });
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL);


Expand All @@ -121,7 +121,7 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); });
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL);


Expand All @@ -139,7 +139,7 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
x_ub, res_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL);


Expand All @@ -156,7 +156,8 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
x_ub, result_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
Expand Down
8 changes: 4 additions & 4 deletions common/cuda_hip/matrix/batch_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}


GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL);


Expand All @@ -72,7 +72,7 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, b_ub, beta_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL);


Expand All @@ -91,7 +91,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_SCALE_KERNEL);


Expand All @@ -110,7 +110,7 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, beta_ub, mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);


Expand Down
12 changes: 7 additions & 5 deletions common/cuda_hip/matrix/batch_dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
mat_ub, b_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL);


Expand All @@ -71,7 +71,7 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, b_ub, beta_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL);


Expand All @@ -90,7 +90,8 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);


template <typename ValueType>
Expand All @@ -108,7 +109,8 @@ void scale_add(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, in_out_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);


template <typename ValueType>
Expand All @@ -126,7 +128,7 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, beta_ub, mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);


Expand Down
8 changes: 4 additions & 4 deletions common/cuda_hip/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}


GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);


Expand All @@ -72,7 +72,7 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, b_ub, beta_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);


Expand All @@ -91,7 +91,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);


Expand All @@ -110,7 +110,7 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, beta_ub, mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);


Expand Down
27 changes: 24 additions & 3 deletions core/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ void MultiVector<ValueType>::compute_norm2(

template <typename ValueType>
void MultiVector<ValueType>::convert_to(
MultiVector<next_precision<ValueType>>* result) const
MultiVector<next_precision_with_half<ValueType>>* result) const
{
result->values_ = this->values_;
result->set_size(this->get_size());
Expand All @@ -290,14 +290,35 @@ void MultiVector<ValueType>::convert_to(

template <typename ValueType>
void MultiVector<ValueType>::move_to(
MultiVector<next_precision<ValueType>>* result)
MultiVector<next_precision_with_half<ValueType>>* result)
{
this->convert_to(result);
}


#if GINKGO_ENABLE_HALF
template <typename ValueType>
void MultiVector<ValueType>::convert_to(
MultiVector<next_precision_with_half<next_precision_with_half<ValueType>>>*
result) const
{
result->values_ = this->values_;
result->set_size(this->get_size());
}


template <typename ValueType>
void MultiVector<ValueType>::move_to(
MultiVector<next_precision_with_half<next_precision_with_half<ValueType>>>*
result)
{
this->convert_to(result);
}
#endif


#define GKO_DECLARE_BATCH_MULTI_VECTOR(_type) class MultiVector<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR);


} // namespace batch
Expand Down
57 changes: 34 additions & 23 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,15 @@ GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL);
namespace batch_multi_vector {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
Expand All @@ -355,10 +358,13 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);
namespace batch_csr {


GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CSR_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);


} // namespace batch_csr
Expand All @@ -367,11 +373,12 @@ GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);
namespace batch_dense {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);


} // namespace batch_dense
Expand All @@ -380,10 +387,13 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);
namespace batch_ell {


GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);


} // namespace batch_ell
Expand Down Expand Up @@ -506,7 +516,7 @@ GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(
namespace batch_bicgstab {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);


} // namespace batch_bicgstab
Expand All @@ -515,7 +525,7 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL);
namespace batch_cg {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CG_APPLY_KERNEL);


} // namespace batch_cg
Expand Down Expand Up @@ -916,9 +926,10 @@ namespace batch_jacobi {
GKO_STUB_INDEX_TYPE(
GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_CUMULATIVE_BLOCK_STORAGE);
GKO_STUB_INDEX_TYPE(GKO_DECLARE_BATCH_BLOCK_JACOBI_FIND_ROW_BLOCK_MAP);
GKO_STUB_VALUE_AND_INT32_TYPE(
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL);


} // namespace batch_jacobi
Expand Down
4 changes: 2 additions & 2 deletions core/log/batch_logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ log_data<ValueType>::log_data(std::shared_ptr<const Executor> exec,

#define GKO_DECLARE_LOG_DATA(_type) class log_data<_type>

GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(GKO_DECLARE_LOG_DATA);
GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF(GKO_DECLARE_LOG_DATA);

#undef GKO_DECLARE_LOG_DATA

Expand All @@ -92,7 +92,7 @@ void BatchConvergence<ValueType>::on_batch_solver_completed(


#define GKO_DECLARE_BATCH_CONVERGENCE(_type) class BatchConvergence<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CONVERGENCE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CONVERGENCE);


} // namespace log
Expand Down
29 changes: 26 additions & 3 deletions core/matrix/batch_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ void Csr<ValueType, IndexType>::add_scaled_identity(

template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::convert_to(
Csr<next_precision<ValueType>, IndexType>* result) const
Csr<next_precision_with_half<ValueType>, IndexType>* result) const
{
result->values_ = this->values_;
result->col_idxs_ = this->col_idxs_;
Expand All @@ -257,14 +257,37 @@ void Csr<ValueType, IndexType>::convert_to(

template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::move_to(
Csr<next_precision<ValueType>, IndexType>* result)
Csr<next_precision_with_half<ValueType>, IndexType>* result)
{
this->convert_to(result);
}


#if GINKGO_ENABLE_HALF
template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::convert_to(
Csr<next_precision_with_half<next_precision_with_half<ValueType>>,
IndexType>* result) const
{
result->values_ = this->values_;
result->col_idxs_ = this->col_idxs_;
result->row_ptrs_ = this->row_ptrs_;
result->set_size(this->get_size());
}


template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::move_to(
Csr<next_precision_with_half<next_precision_with_half<ValueType>>,
IndexType>* result)
{
this->convert_to(result);
}
#endif


#define GKO_DECLARE_BATCH_CSR_MATRIX(ValueType) class Csr<ValueType, int32>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CSR_MATRIX);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CSR_MATRIX);


} // namespace matrix
Expand Down
Loading

0 comments on commit b1f71bf

Please sign in to comment.