Skip to content

Commit

Permalink
batch test with half
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 7, 2024
1 parent b1f71bf commit 9de27d8
Show file tree
Hide file tree
Showing 14 changed files with 56 additions and 29 deletions.
3 changes: 2 additions & 1 deletion core/test/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class MultiVector : public ::testing::Test {
std::unique_ptr<gko::matrix::Dense<value_type>> dense_mtx;
};

TYPED_TEST_SUITE(MultiVector, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(MultiVector, gko::test::ValueTypesWithHalf,
TypenameNameGenerator);


TYPED_TEST(MultiVector, CanBeEmpty)
Expand Down
2 changes: 1 addition & 1 deletion core/test/matrix/batch_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Csr : public ::testing::Test {
std::unique_ptr<CsrMtx> sp_csr_mtx;
};

TYPED_TEST_SUITE(Csr, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Csr, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(Csr, KnowsItsSizeAndValues)
Expand Down
2 changes: 1 addition & 1 deletion core/test/matrix/batch_dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Dense : public ::testing::Test {
std::unique_ptr<gko::matrix::Dense<value_type>> dense_mtx;
};

TYPED_TEST_SUITE(Dense, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Dense, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(Dense, KnowsItsSizeAndValues)
Expand Down
2 changes: 1 addition & 1 deletion core/test/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Ell : public ::testing::Test {
std::unique_ptr<EllMtx> sp_ell_mtx;
};

TYPED_TEST_SUITE(Ell, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Ell, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(Ell, KnowsItsSizeAndValues)
Expand Down
3 changes: 2 additions & 1 deletion core/test/matrix/batch_identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class Identity : public ::testing::Test {
std::unique_ptr<gko::batch::MultiVector<value_type>> mvec;
};

TYPED_TEST_SUITE(Identity, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Identity, gko::test::ValueTypesWithHalf,
TypenameNameGenerator);


TYPED_TEST(Identity, KnowsItsSizeAndValues)
Expand Down
3 changes: 2 additions & 1 deletion core/test/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class BatchBicgstab : public ::testing::Test {
std::unique_ptr<gko::batch::BatchLinOp> solver;
};

TYPED_TEST_SUITE(BatchBicgstab, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(BatchBicgstab, gko::test::ValueTypesWithHalf,
TypenameNameGenerator);


TYPED_TEST(BatchBicgstab, FactoryKnowsItsExecutor)
Expand Down
2 changes: 1 addition & 1 deletion core/test/solver/batch_cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class BatchCg : public ::testing::Test {
std::unique_ptr<gko::batch::BatchLinOp> solver;
};

TYPED_TEST_SUITE(BatchCg, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(BatchCg, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(BatchCg, FactoryKnowsItsExecutor)
Expand Down
2 changes: 1 addition & 1 deletion core/test/utils/batch_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ std::unique_ptr<MatrixType> generate_diag_dominant_batch_matrix(
static_cast<size_type>(num_cols)},
{}};
auto engine = std::default_random_engine(42);
auto rand_diag_dist = std::normal_distribution<real_type>(20.0, 1.0);
auto rand_diag_dist = std::normal_distribution<>(20.0, 1.0);
for (int row = 0; row < num_rows; ++row) {
std::uniform_int_distribution<index_type> rand_nnz_dist{1, row + 1};
const auto k = rand_nnz_dist(engine);
Expand Down
11 changes: 6 additions & 5 deletions reference/test/base/batch_multi_vector_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class MultiVector : public ::testing::Test {
std::default_random_engine rand_engine;
};

TYPED_TEST_SUITE(MultiVector, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(MultiVector, gko::test::ValueTypesWithHalf,
TypenameNameGenerator);


TYPED_TEST(MultiVector, ScalesData)
Expand Down Expand Up @@ -342,7 +343,7 @@ TYPED_TEST(MultiVector, ConvertsToPrecision)
{
using MultiVector = typename TestFixture::Mtx;
using T = typename TestFixture::value_type;
using OtherT = typename gko::next_precision<T>;
using OtherT = typename gko::next_precision_with_half<T>;
using OtherMultiVector = typename gko::batch::MultiVector<OtherT>;
auto tmp = OtherMultiVector::create(this->exec);
auto res = MultiVector::create(this->exec);
Expand All @@ -366,7 +367,7 @@ TYPED_TEST(MultiVector, MovesToPrecision)
{
using MultiVector = typename TestFixture::Mtx;
using T = typename TestFixture::value_type;
using OtherT = typename gko::next_precision<T>;
using OtherT = typename gko::next_precision_with_half<T>;
using OtherMultiVector = typename gko::batch::MultiVector<OtherT>;
auto tmp = OtherMultiVector::create(this->exec);
auto res = MultiVector::create(this->exec);
Expand All @@ -390,7 +391,7 @@ TYPED_TEST(MultiVector, ConvertsEmptyToPrecision)
{
using MultiVector = typename TestFixture::Mtx;
using T = typename TestFixture::value_type;
using OtherT = typename gko::next_precision<T>;
using OtherT = typename gko::next_precision_with_half<T>;
using OtherMultiVector = typename gko::batch::MultiVector<OtherT>;
auto empty = OtherMultiVector::create(this->exec);
auto res = MultiVector::create(this->exec);
Expand All @@ -405,7 +406,7 @@ TYPED_TEST(MultiVector, MovesEmptyToPrecision)
{
using MultiVector = typename TestFixture::Mtx;
using T = typename TestFixture::value_type;
using OtherT = typename gko::next_precision<T>;
using OtherT = typename gko::next_precision_with_half<T>;
using OtherMultiVector = typename gko::batch::MultiVector<OtherT>;
auto empty = OtherMultiVector::create(this->exec);
auto res = MultiVector::create(this->exec);
Expand Down
2 changes: 1 addition & 1 deletion reference/test/matrix/batch_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Csr : public ::testing::Test {
std::ranlux48 rand_engine;
};

TYPED_TEST_SUITE(Csr, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Csr, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(Csr, AppliesToBatchMultiVector)
Expand Down
2 changes: 1 addition & 1 deletion reference/test/matrix/batch_dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Dense : public ::testing::Test {
};


TYPED_TEST_SUITE(Dense, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Dense, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(Dense, AppliesToBatchMultiVector)
Expand Down
2 changes: 1 addition & 1 deletion reference/test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Ell : public ::testing::Test {
};


TYPED_TEST_SUITE(Ell, gko::test::ValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(Ell, gko::test::ValueTypesWithHalf, TypenameNameGenerator);


TYPED_TEST(Ell, AppliesToBatchMultiVector)
Expand Down
27 changes: 18 additions & 9 deletions reference/test/solver/batch_bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class BatchBicgstab : public ::testing::Test {
solve_lambda;
};

TYPED_TEST_SUITE(BatchBicgstab, gko::test::RealValueTypes,
TYPED_TEST_SUITE(BatchBicgstab, gko::test::RealValueTypesWithHalf,
TypenameNameGenerator);


Expand Down Expand Up @@ -105,8 +105,13 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsResidual)
ASSERT_LE(
res_log_array[i] / this->linear_system.host_rhs_norm->at(i, 0, 0),
this->solver_settings.residual_tol);
ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i],
10 * this->eps);
if (!std::is_same<real_type, gko::half>::value) {
// There is no guarantee of this condition. We disable this check in
// half.
ASSERT_NEAR(res_log_array[i],
res.host_res_norm->get_const_values()[i],
10 * this->eps);
}
}
}

Expand All @@ -125,7 +130,7 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsIterations)

auto iter_array = res.log_data->iter_counts.get_const_data();
for (size_t i = 0; i < this->num_batch_items; i++) {
ASSERT_EQ(iter_array[i], ref_iters);
ASSERT_LE(iter_array[i], ref_iters);
}
}

Expand All @@ -136,7 +141,7 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem)
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::Mtx;
const real_type tol = 1e-5;
const real_type tol = 1e-4;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
Expand All @@ -161,7 +166,7 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem)
for (size_t i = 0; i < num_batch_items; i++) {
ASSERT_LE(res.host_res_norm->get_const_values()[i] /
linear_system.host_rhs_norm->get_const_values()[i],
tol);
tol * 10);
}
}

Expand All @@ -173,7 +178,7 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters)
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::Mtx;
using Logger = gko::batch::log::BatchConvergence<value_type>;
const real_type tol = 1e-5;
const real_type tol = 1e-4;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
Expand Down Expand Up @@ -216,7 +221,7 @@ TYPED_TEST(BatchBicgstab, CanSolveEllSystem)
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::EllMtx;
const real_type tol = 1e-5;
const real_type tol = 1e-4;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
Expand Down Expand Up @@ -252,7 +257,7 @@ TYPED_TEST(BatchBicgstab, CanSolveCsrSystem)
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::CsrMtx;
const real_type tol = 1e-5;
const real_type tol = 1e-4;
const int max_iters = 1000;
auto solver_factory =
Solver::build()
Expand Down Expand Up @@ -288,6 +293,10 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem)
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::Mtx;
// Need to design a better random system. With different random value
// distribution, the solver can not solve the hpd matrix even with single
// precision
SKIP_IF_HALF(value_type);
const real_type tol = 1e-5;
const int max_iters = 1000;
auto solver_factory =
Expand Down
22 changes: 18 additions & 4 deletions reference/test/solver/batch_cg_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class BatchCg : public ::testing::Test {
solve_lambda;
};

TYPED_TEST_SUITE(BatchCg, gko::test::RealValueTypes, TypenameNameGenerator);
TYPED_TEST_SUITE(BatchCg, gko::test::RealValueTypesWithHalf,
TypenameNameGenerator);


TYPED_TEST(BatchCg, SolvesStencilSystem)
Expand All @@ -80,7 +81,7 @@ TYPED_TEST(BatchCg, SolvesStencilSystem)
for (size_t i = 0; i < this->num_batch_items; i++) {
ASSERT_LE(res.host_res_norm->get_const_values()[i] /
this->linear_system.host_rhs_norm->get_const_values()[i],
this->solver_settings.residual_tol);
5 * this->solver_settings.residual_tol);
}
GKO_ASSERT_BATCH_MTX_NEAR(res.x, this->linear_system.exact_sol,
this->eps * 10);
Expand All @@ -101,8 +102,13 @@ TYPED_TEST(BatchCg, StencilSystemLoggerLogsResidual)
ASSERT_LE(
res_log_array[i] / this->linear_system.host_rhs_norm->at(i, 0, 0),
this->solver_settings.residual_tol);
ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i],
10 * this->eps);
if (!std::is_same<real_type, gko::half>::value) {
// There is no guarantee of this condition. We disable this check in
// half.
ASSERT_NEAR(res_log_array[i],
res.host_res_norm->get_const_values()[i],
10 * this->eps);
}
}
}

Expand Down Expand Up @@ -133,6 +139,10 @@ TYPED_TEST(BatchCg, ApplyLogsResAndIters)
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::Mtx;
using Logger = gko::batch::log::BatchConvergence<value_type>;
// Need to design a better random system. With different random value
// distribution, the solver can not solve the hpd matrix even with single
// precision
SKIP_IF_HALF(value_type);
const real_type tol = 1e-6;
const int max_iters = 1000;
auto solver_factory =
Expand Down Expand Up @@ -174,6 +184,10 @@ TYPED_TEST(BatchCg, CanSolveHpdSystem)
using real_type = gko::remove_complex<value_type>;
using Solver = typename TestFixture::solver_type;
using Mtx = typename TestFixture::Mtx;
// Need to design a better random system. With different random value
// distribution, the solver can not solve the hpd matrix even with single
// precision
SKIP_IF_HALF(value_type);
const real_type tol = 1e-6;
const int max_iters = 1000;
auto solver_factory =
Expand Down

0 comments on commit 9de27d8

Please sign in to comment.