diff --git a/test/conftest.py b/test/conftest.py
index 46229d581..0b817e341 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,3 +1,4 @@
+from subprocess import check_call


 def pytest_runtest_teardown(item, nextitem):
@@ -12,3 +13,84 @@ def pytest_runtest_teardown(item, nextitem):
     Kernel._cache.clear()
     TSFCKernel._cache.clear()
     JITModule._cache.clear()
+
+
+def parallel(item):
+    """
+    Run a test in parallel.
+
+    NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
+
+    :arg item: The test item to run.
+    """
+    from mpi4py import MPI
+    if MPI.COMM_WORLD.size > 1:
+        raise RuntimeError("Parallel test can't be run within parallel environment")
+    marker = item.get_closest_marker("parallel")
+    if marker is None:
+        raise RuntimeError("Parallel test doesn't have parallel marker")
+    nprocs = marker.kwargs.get("nprocs", 3)
+    if nprocs < 2:
+        raise RuntimeError("Need at least two processes to run parallel test")
+
+    # Only spew tracebacks on rank 0.
+    # Run xfailing tests to ensure that errors are reported to calling process
+    call = [
+        "mpiexec", "-n", "1", "python", "-m", "pytest", "--runxfail", "-s", "-q",
+        "%s::%s" % (item.fspath, item.name)
+    ]
+    call.extend([
+        ":", "-n", "%d" % (nprocs - 1), "python", "-m", "pytest", "--runxfail", "--tb=no", "-q",
+        "%s::%s" % (item.fspath, item.name)
+    ])
+    check_call(call)
+
+
+def pytest_configure(config):
+    """
+    Register an additional marker.
+
+    NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
+    """
+    config.addinivalue_line(
+        "markers",
+        "parallel(nprocs): mark test to run in parallel on nprocs processors")
+
+
+def pytest_runtest_setup(item):
+    """
+    Special setup for parallel tests.
+
+    NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
+    """
+    if item.get_closest_marker("parallel"):
+        from mpi4py import MPI
+        if MPI.COMM_WORLD.size > 1:
+            # Turn on source hash checking
+            from firedrake import parameters
+            from functools import partial
+
+            def _reset(check):
+                parameters["pyop2_options"]["check_src_hashes"] = check
+
+            # Reset to current value when test is cleaned up
+            item.addfinalizer(partial(_reset,
+                                      parameters["pyop2_options"]["check_src_hashes"]))
+
+            parameters["pyop2_options"]["check_src_hashes"] = True
+        else:
+            # Blow away function arg in "master" process, to ensure
+            # this test isn't run on only one process.
+            item.obj = lambda *args, **kwargs: True
+
+
+def pytest_runtest_call(item):
+    """
+    Special call for parallel tests.
+
+    NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
+    """
+    from mpi4py import MPI
+    if item.get_closest_marker("parallel") and MPI.COMM_WORLD.size == 1:
+        # Spawn parallel processes to run test
+        parallel(item)
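Taken together, the hooks above make any test carrying the `parallel` marker respawn itself under `mpiexec` when pytest is started on a single process. A minimal sketch of how a test opts in (the test name and body are hypothetical; only the marker and its `nprocs` keyword come from the conftest above):

```python
import pytest
from mpi4py import MPI


@pytest.mark.parallel(nprocs=2)
def test_runs_on_two_ranks():
    # parallel() launches both pytest instances inside a single MPMD
    # mpiexec job, so they share one COMM_WORLD of size nprocs.
    assert MPI.COMM_WORLD.size == 2
```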
diff --git a/test/solver3d/test_baroclinic_mms.py b/test/solver3d/test_baroclinic_mms.py
index e78e76382..96332a87f 100644
--- a/test/solver3d/test_baroclinic_mms.py
+++ b/test/solver3d/test_baroclinic_mms.py
@@ -3,6 +3,8 @@
 """
 from thetis import *
 from scipy import stats
+import pytest
+

 field_metadata['uv_full'] = dict(field_metadata['uv_3d'])
@@ -341,6 +343,8 @@ def run(setup, refinement, polynomial_degree, do_export=True, **options):

 def run_convergence(setup, ref_list, saveplot=False, **options):
     """Runs test for a list of refinements and computes error convergence rate"""
+    if saveplot and COMM_WORLD.size > 1:
+        raise Exception('Cannot use matplotlib in parallel')
     polynomial_degree = options.get('polynomial_degree', 1)
     space_str = options.get('element_family')
     l2_err = []
@@ -417,6 +421,7 @@ def expected_rate(v):
             plt.savefig(imgfile, dpi=200, bbox_inches='tight')


+@pytest.mark.parallel(nprocs=2)
 def test_baroclinic_mms():
     run_convergence(setup5, [1, 2, 4], polynomial_degree=1,
                     element_family='dg-dg',
diff --git a/test/solver3d/test_barotropic_mes.py b/test/solver3d/test_barotropic_mes.py
index dc277c980..dec648dca 100644
--- a/test/solver3d/test_barotropic_mes.py
+++ b/test/solver3d/test_barotropic_mes.py
@@ -5,6 +5,7 @@
 """
 from thetis import *
 from scipy import stats
+import pytest


 def run(refinement=1, ncycles=2, **kwargs):
@@ -98,6 +99,8 @@ def run(refinement=1, ncycles=2, **kwargs):

 def run_convergence(ref_list, saveplot=False, **options):
     """Runs test for a list of refinements and computes error convergence rate"""
+    if saveplot and COMM_WORLD.size > 1:
+        raise Exception('Cannot use matplotlib in parallel')
     polynomial_degree = options.get('polynomial_degree', 1)
     space_str = options.get('element_family')
     l2_err = []
@@ -161,6 +164,7 @@ def check_convergence(x_log, y_log, expected_slope, field_str, saveplot, ax):
         plt.savefig(imgfile, dpi=200, bbox_inches='tight')


+@pytest.mark.parallel(nprocs=2)
 def test_standing_wave():
     run_convergence([1, 2, 4, 8], polynomial_degree=1,
                     element_family='dg-dg',
@@ -170,4 +174,4 @@ def test_standing_wave():
 if __name__ == '__main__':
     run_convergence([1, 2, 4, 6, 8, 10], polynomial_degree=1,
                     element_family='dg-dg',
-                    saveplot=True, no_exports=True)
+                    saveplot=False, no_exports=True)
diff --git a/test/swe2d/test_steady_state_channel.py b/test/swe2d/test_steady_state_channel.py
index 3f1e5c812..6a1144f03 100644
--- a/test/swe2d/test_steady_state_channel.py
+++ b/test/swe2d/test_steady_state_channel.py
@@ -1,7 +1,9 @@
 from thetis import *
 import math
+import pytest


+@pytest.mark.parallel(nprocs=2)
 def test_steady_state_channel(do_export=False):
     lx = 5e3
diff --git a/thetis/solver.py b/thetis/solver.py
index d254f6bdc..ea0218f66 100644
--- a/thetis/solver.py
+++ b/thetis/solver.py
@@ -217,8 +217,9 @@ def compute_dt_2d(self, u_scale):
         l = inner(test, csize / u) * dx
         sp = {
             "snes_type": "ksponly",
-            "ksp_type": "gmres",
-            "pc_type": "ilu",
+            "ksp_type": "cg",
+            "pc_type": "bjacobi",
+            "sub_pc_type": "ilu",
         }
         solve(a == l, solution, solver_parameters=sp)
         dt = float(solution.dat.data.min())
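The solver-parameter change above in `compute_dt_2d` (repeated below in `thetis/utility.py`) swaps a combination that breaks in parallel for a standard parallel-safe one: PETSc's plain `ilu` preconditioner is serial-only, so each MPI rank now applies ILU to its own block inside block Jacobi, and because these auxiliary systems are mass-matrix-like (symmetric positive definite), `cg` can replace `gmres`. A self-contained Firedrake sketch of the same pattern (the mesh, space, and right-hand side are illustrative, not Thetis code):

```python
from firedrake import *

mesh = UnitSquareMesh(8, 8)
V = FunctionSpace(mesh, "CG", 1)
u, v = TrialFunction(V), TestFunction(V)
a = inner(u, v) * dx              # mass matrix: symmetric positive definite
L = inner(Constant(1.0), v) * dx
f = Function(V)

sp = {
    "snes_type": "ksponly",   # linear problem: one Krylov solve, no Newton
    "ksp_type": "cg",         # CG suffices for an SPD operator
    "pc_type": "bjacobi",     # one preconditioner block per MPI rank
    "sub_pc_type": "ilu",     # ILU factorisation of each rank-local block
}
solve(a == L, f, solver_parameters=sp)
```

Run serially or under e.g. `mpiexec -n 2`, these parameters work unchanged, which is exactly what the gmres/ilu pair did not.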
"snes_type": "ksponly", - "ksp_type": "gmres", - "pc_type": "ilu", + "ksp_type": "cg", + "pc_type": "bjacobi", + "sub_pc_type": "ilu", } solve(a == l, sol2d, solver_parameters=sp) @@ -919,8 +920,9 @@ def __init__(self, vorticity_2d, solver_obj, **kwargs): # Setup vorticity solver prob = LinearVariationalProblem(a, L, vorticity_2d) kwargs.setdefault('solver_parameters', { - "ksp_type": "gmres", - "pc_type": "ilu", + "ksp_type": "cg", + "pc_type": "bjacobi", + "sub_pc_type": "ilu", }) self.solver = LinearVariationalSolver(prob, **kwargs) @@ -1024,8 +1026,9 @@ def __init__(self, q_2d, uv_2d, w_2d, elev_2d, depth, dt, bnd_functions=None, so l_w = dot(self.w_2d + self.dt/rho_0*(self.q_2d/h_star), test_w)*dx prob_w = LinearVariationalProblem(a_w, l_w, self.w_2d) sp = { - "ksp_type": "gmres", - "pc_type": "ilu", + "ksp_type": "cg", + "pc_type": "bjacobi", + "sub_pc_type": "ilu", } self.solver_w = LinearVariationalSolver(prob_w, solver_parameters=sp)