Skip to content

Commit

Permalink
Merge pull request #278 from thetisproject/fix-ilu-params
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 authored Nov 27, 2021
2 parents 343c7eb + ba32653 commit d534fec
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 9 deletions.
82 changes: 82 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from subprocess import check_call


def pytest_runtest_teardown(item, nextitem):
Expand All @@ -12,3 +13,84 @@ def pytest_runtest_teardown(item, nextitem):
Kernel._cache.clear()
TSFCKernel._cache.clear()
JITModule._cache.clear()


def parallel(item):
"""
Run a test in parallel.
NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
:arg item: The test item to run.
"""
from mpi4py import MPI
if MPI.COMM_WORLD.size > 1:
raise RuntimeError("Parallel test can't be run within parallel environment")
marker = item.get_closest_marker("parallel")
if marker is None:
raise RuntimeError("Parallel test doesn't have parallel marker")
nprocs = marker.kwargs.get("nprocs", 3)
if nprocs < 2:
raise RuntimeError("Need at least two processes to run parallel test")

# Only spew tracebacks on rank 0.
# Run xfailing tests to ensure that errors are reported to calling process
call = [
"mpiexec", "-n", "1", "python", "-m", "pytest", "--runxfail", "-s", "-q",
"%s::%s" % (item.fspath, item.name)
]
call.extend([
":", "-n", "%d" % (nprocs - 1), "python", "-m", "pytest", "--runxfail", "--tb=no", "-q",
"%s::%s" % (item.fspath, item.name)
])
check_call(call)


def pytest_configure(config):
"""
Register an additional marker.
NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
"""
config.addinivalue_line(
"markers",
"parallel(nprocs): mark test to run in parallel on nprocs processors")


def pytest_runtest_setup(item):
"""
Special setup for parallel tests.
NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
"""
if item.get_closest_marker("parallel"):
from mpi4py import MPI
if MPI.COMM_WORLD.size > 1:
# Turn on source hash checking
from firedrake import parameters
from functools import partial

def _reset(check):
parameters["pyop2_options"]["check_src_hashes"] = check

# Reset to current value when test is cleaned up
item.addfinalizer(partial(_reset,
parameters["pyop2_options"]["check_src_hashes"]))

parameters["pyop2_options"]["check_src_hashes"] = True
else:
# Blow away function arg in "master" process, to ensure
# this test isn't run on only one process.
item.obj = lambda *args, **kwargs: True


def pytest_runtest_call(item):
"""
Special call for parallel tests.
NOTE: Copied from firedrake/src/firedrake/tests/conftest.py
"""
from mpi4py import MPI
if item.get_closest_marker("parallel") and MPI.COMM_WORLD.size == 1:
# Spawn parallel processes to run test
parallel(item)
5 changes: 5 additions & 0 deletions test/solver3d/test_baroclinic_mms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
from thetis import *
from scipy import stats
import pytest


field_metadata['uv_full'] = dict(field_metadata['uv_3d'])

Expand Down Expand Up @@ -341,6 +343,8 @@ def run(setup, refinement, polynomial_degree, do_export=True, **options):

def run_convergence(setup, ref_list, saveplot=False, **options):
"""Runs test for a list of refinements and computes error convergence rate"""
if saveplot and COMM_WORLD.size > 1:
raise Exception('Cannot use matplotlib in parallel')
polynomial_degree = options.get('polynomial_degree', 1)
space_str = options.get('element_family')
l2_err = []
Expand Down Expand Up @@ -417,6 +421,7 @@ def expected_rate(v):
plt.savefig(imgfile, dpi=200, bbox_inches='tight')


@pytest.mark.parallel(nprocs=2)
def test_baroclinic_mms():
run_convergence(setup5, [1, 2, 4],
polynomial_degree=1, element_family='dg-dg',
Expand Down
6 changes: 5 additions & 1 deletion test/solver3d/test_barotropic_mes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from thetis import *
from scipy import stats
import pytest


def run(refinement=1, ncycles=2, **kwargs):
Expand Down Expand Up @@ -98,6 +99,8 @@ def run(refinement=1, ncycles=2, **kwargs):

def run_convergence(ref_list, saveplot=False, **options):
"""Runs test for a list of refinements and computes error convergence rate"""
if saveplot and COMM_WORLD.size > 1:
raise Exception('Cannot use matplotlib in parallel')
polynomial_degree = options.get('polynomial_degree', 1)
space_str = options.get('element_family')
l2_err = []
Expand Down Expand Up @@ -161,6 +164,7 @@ def check_convergence(x_log, y_log, expected_slope, field_str, saveplot, ax):
plt.savefig(imgfile, dpi=200, bbox_inches='tight')


@pytest.mark.parallel(nprocs=2)
def test_standing_wave():
run_convergence([1, 2, 4, 8],
polynomial_degree=1, element_family='dg-dg',
Expand All @@ -170,4 +174,4 @@ def test_standing_wave():
if __name__ == '__main__':
run_convergence([1, 2, 4, 6, 8, 10],
polynomial_degree=1, element_family='dg-dg',
saveplot=True, no_exports=True)
saveplot=False, no_exports=True)
2 changes: 2 additions & 0 deletions test/swe2d/test_steady_state_channel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from thetis import *
import math
import pytest


@pytest.mark.parallel(nprocs=2)
def test_steady_state_channel(do_export=False):

lx = 5e3
Expand Down
5 changes: 3 additions & 2 deletions thetis/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,9 @@ def compute_dt_2d(self, u_scale):
l = inner(test, csize / u) * dx
sp = {
"snes_type": "ksponly",
"ksp_type": "gmres",
"pc_type": "ilu",
"ksp_type": "cg",
"pc_type": "bjacobi",
"sub_pc_type": "ilu",
}
solve(a == l, solution, solver_parameters=sp)
dt = float(solution.dat.data.min())
Expand Down
15 changes: 9 additions & 6 deletions thetis/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,9 @@ def get_horizontal_elem_size_2d(sol2d):
l = inner(test, sqrt(CellVolume(mesh))) * dx
sp = {
"snes_type": "ksponly",
"ksp_type": "gmres",
"pc_type": "ilu",
"ksp_type": "cg",
"pc_type": "bjacobi",
"sub_pc_type": "ilu",
}
solve(a == l, sol2d, solver_parameters=sp)

Expand Down Expand Up @@ -919,8 +920,9 @@ def __init__(self, vorticity_2d, solver_obj, **kwargs):
# Setup vorticity solver
prob = LinearVariationalProblem(a, L, vorticity_2d)
kwargs.setdefault('solver_parameters', {
"ksp_type": "gmres",
"pc_type": "ilu",
"ksp_type": "cg",
"pc_type": "bjacobi",
"sub_pc_type": "ilu",
})
self.solver = LinearVariationalSolver(prob, **kwargs)

Expand Down Expand Up @@ -1024,8 +1026,9 @@ def __init__(self, q_2d, uv_2d, w_2d, elev_2d, depth, dt, bnd_functions=None, so
l_w = dot(self.w_2d + self.dt/rho_0*(self.q_2d/h_star), test_w)*dx
prob_w = LinearVariationalProblem(a_w, l_w, self.w_2d)
sp = {
"ksp_type": "gmres",
"pc_type": "ilu",
"ksp_type": "cg",
"pc_type": "bjacobi",
"sub_pc_type": "ilu",
}
self.solver_w = LinearVariationalSolver(prob_w, solver_parameters=sp)

Expand Down

0 comments on commit d534fec

Please sign in to comment.