diff --git a/benchmark/sparse_blas/operations.cpp b/benchmark/sparse_blas/operations.cpp index 30f3b5a80fe..d967e2ccab9 100644 --- a/benchmark/sparse_blas/operations.cpp +++ b/benchmark/sparse_blas/operations.cpp @@ -11,6 +11,7 @@ #include "core/base/array_access.hpp" #include "core/factorization/elimination_forest.hpp" +#include "core/factorization/factorization_kernels.hpp" #include "core/factorization/symbolic.hpp" #include "core/matrix/csr_kernels.hpp" #include "core/matrix/csr_lookup.hpp" @@ -20,6 +21,7 @@ namespace { +GKO_REGISTER_OPERATION(symbolic_validate, factorization::symbolic_validate); GKO_REGISTER_OPERATION(build_lookup_offsets, csr::build_lookup_offsets); GKO_REGISTER_OPERATION(build_lookup, csr::build_lookup); GKO_REGISTER_OPERATION(benchmark_lookup, csr::benchmark_lookup); @@ -541,47 +543,11 @@ class LookupOperation : public BenchmarkOperation { bool validate_symbolic_factorization(const Mtx* input, const Mtx* factors) { - const auto host_exec = input->get_executor()->get_master(); - const auto host_input = gko::make_temporary_clone(host_exec, input); - const auto host_factors = gko::make_temporary_clone(host_exec, factors); - const auto num_rows = input->get_size()[0]; - const auto in_row_ptrs = host_input->get_const_row_ptrs(); - const auto in_cols = host_input->get_const_col_idxs(); - const auto factor_row_ptrs = host_factors->get_const_row_ptrs(); - const auto factor_cols = host_factors->get_const_col_idxs(); - std::unordered_set columns; - for (itype row = 0; row < num_rows; row++) { - const auto in_begin = in_cols + in_row_ptrs[row]; - const auto in_end = in_cols + in_row_ptrs[row + 1]; - const auto factor_begin = factor_cols + factor_row_ptrs[row]; - const auto factor_end = factor_cols + factor_row_ptrs[row + 1]; - columns.clear(); - // the factor needs to contain the original matrix - // plus the diagonal if that was missing - columns.insert(in_begin, in_end); - columns.insert(row); - for (auto col_it = factor_begin; col_it < factor_end; ++col_it) { - const auto col = *col_it; - if (col >= row) { - break; - } - const auto dep_begin = factor_cols + factor_row_ptrs[col]; - const auto dep_end = factor_cols + factor_row_ptrs[col + 1]; - // insert the upper triangular part of the row - const auto dep_diag = std::find(dep_begin, dep_end, col); - columns.insert(dep_diag, dep_end); - } - // the factor should contain exactly these columns, no more - if (factor_end - factor_begin != columns.size()) { - return false; - } - for (auto col_it = factor_begin; col_it < factor_end; ++col_it) { - if (columns.find(*col_it) == columns.end()) { - return false; - } - } - } - return true; + const auto exec = factors->get_executor(); + bool valid = false; + exec->run(make_symbolic_validate( + input, factors, gko::matrix::csr::build_lookup(factors), valid)); + return valid; }