diff --git a/python/test/fixtures.py b/python/test/fixtures.py index cca2f33ce..6682cb1e5 100644 --- a/python/test/fixtures.py +++ b/python/test/fixtures.py @@ -4,12 +4,8 @@ from functools import lru_cache as cache from pathlib import Path import subprocess -import atexit import inspect -_mpi_enabled = A.__config__["mpi"] -_mpi4py_enabled = A.__config__["mpi4py"] - # The API of `functools`'s caches went through a bunch of breaking changes from # 3.6 to 3.9. Patch them up in a local `cache` function. try: @@ -72,13 +68,22 @@ def repo_path(): return Path(__file__).parent.parent.parent -def _finalize_mpi(): - if _mpi4py_enabled: - from mpi4py import MPI - - MPI.Finalize() - else: - A.mpi_finalize() +def get_mpi_comm_world(): + """ + Obtain MPI_COMM_WORLD as --- in order --- + 1. MPI4PY.MPI.COMM_WORLD + 2. Arbor MPI + 3. None + """ + if A.config()["mpi"]: + if A.config()["mpi4py"]: + from mpi4py import MPI + return MPI.COMM_WORLD + else: + if not A.mpi_is_initialized(): + A.mpi_init() + return A.mpi_comm() + return None @_fixture @@ -86,19 +91,7 @@ def context(): """ Fixture that produces an MPI sensitive `A.context` """ - if _mpi_enabled: - if _mpi4py_enabled: - from mpi4py import MPI - - if not MPI.Is_initialized(): - MPI.Initialize() - atexit.register(_finalize_mpi) - return A.context(A.proc_allocation(), mpi=MPI.COMM_WORLD) - elif not A.mpi_is_initialized(): - A.mpi_init() - atexit.register(_finalize_mpi) - return A.context(A.proc_allocation(), mpi=A.mpi_comm()) - return A.context(A.proc_allocation()) + return A.context(mpi=get_mpi_comm_world()) class _BuildCatError(Exception): diff --git a/python/test/unit/test_diffusion.py b/python/test/unit/test_diffusion.py index 7e5405390..afbeffffc 100644 --- a/python/test/unit/test_diffusion.py +++ b/python/test/unit/test_diffusion.py @@ -77,11 +77,7 @@ def __init__(self, args): self.runtime = 5.00 * U.ms # runtime of the whole simulation in ms self.dt = 0.01 * U.ms # duration of one timestep in ms self.dev = 0.01 # accepted relative deviation for `assertAlmostEqual` - mpi = None - if A.config()["mpi"]: - from mpi4py import MPI - - mpi = MPI.COMM_WORLD + mpi = fixtures.get_mpi_comm_world() gpu_id = None if A.config()["gpu"]: if mpi: