Skip to content

Commit

Permalink
Clean-up MPI handling in fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenhater committed Jan 17, 2025
1 parent 1d0dfd1 commit 1b322fb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 29 deletions.
41 changes: 17 additions & 24 deletions python/test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@
from functools import lru_cache as cache
from pathlib import Path
import subprocess
import atexit
import inspect

_mpi_enabled = A.__config__["mpi"]
_mpi4py_enabled = A.__config__["mpi4py"]

# The API of `functools`'s caches went through a bunch of breaking changes from
# 3.6 to 3.9. Patch them up in a local `cache` function.
try:
Expand Down Expand Up @@ -72,33 +68,30 @@ def repo_path():
return Path(__file__).parent.parent.parent


def _finalize_mpi():
if _mpi4py_enabled:
from mpi4py import MPI

MPI.Finalize()
else:
A.mpi_finalize()
def get_mpi_comm_world():
"""
Obtain MPI_COMM_WORLD as --- in order ---
1. MPI4PY.MPI.COMM_WORLD
2. Arbor MPI
3. None
"""
if A.config()["mpi"]:
if A.config()["mpi4py"]:
from mpi4py import MPI
return MPI.COMM_WORLD
else:
if not A.mpi_is_initialized():
A.mpi_init()
return A.mpi_comm()
return None


@_fixture
def context():
"""
Fixture that produces an MPI sensitive `A.context`
"""
if _mpi_enabled:
if _mpi4py_enabled:
from mpi4py import MPI

if not MPI.Is_initialized():
MPI.Initialize()
atexit.register(_finalize_mpi)
return A.context(A.proc_allocation(), mpi=MPI.COMM_WORLD)
elif not A.mpi_is_initialized():
A.mpi_init()
atexit.register(_finalize_mpi)
return A.context(A.proc_allocation(), mpi=A.mpi_comm())
return A.context(A.proc_allocation())
return A.context(mpi=get_mpi_comm_world())


class _BuildCatError(Exception):
Expand Down
6 changes: 1 addition & 5 deletions python/test/unit/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,7 @@ def __init__(self, args):
self.runtime = 5.00 * U.ms # runtime of the whole simulation in ms
self.dt = 0.01 * U.ms # duration of one timestep in ms
self.dev = 0.01 # accepted relative deviation for `assertAlmostEqual`
mpi = None
if A.config()["mpi"]:
from mpi4py import MPI

mpi = MPI.COMM_WORLD
mpi = fixtures.get_mpi_comm_world()
gpu_id = None
if A.config()["gpu"]:
if mpi:
Expand Down

0 comments on commit 1b322fb

Please sign in to comment.