Skip to content

Commit

Permalink
fixups, make sure taping is disabled after each test
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Jan 17, 2025
1 parent 2a6c83a commit 5a56d8a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tests/firedrake/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,13 @@ def pytest_collection_modifyitems(session, config, items):
@pytest.fixture(scope="module", autouse=True)
def check_empty_tape(request):
"""Check that the tape is empty at the end of each module"""
from pyadjoint.tape import get_working_tape
from pyadjoint.tape import annotate_tape, get_working_tape

def fin():
# make sure taping is switched off
assert not annotate_tape()

# make sure the tape is empty
tape = get_working_tape()
if tape is not None:
assert len(tape.get_blocks()) == 0
Expand Down
2 changes: 2 additions & 0 deletions tests/firedrake/demos/test_demos_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def test_serial_demo(demo, env, monkeypatch, tmpdir):
_exec_file(py_file)

if "adjoint" in demo.requirements:
pyadjoint.pause_annotation()
pyadjoint.get_working_tape().clear_tape()


Expand All @@ -164,4 +165,5 @@ def test_parallel_demo(demo, env, monkeypatch, tmpdir):
_exec_file(py_file)

if "adjoint" in demo.requirements:
pyadjoint.pause_annotation()
pyadjoint.get_working_tape().clear_tape()
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_restricted_function_space_extrusion_stokes(ncells):
sol_res = Function(W_res)
solve(a_res == L_res, sol_res, bcs=[bc_res])
# Compare.
assert assemble(inner(sol_res - sol, sol_res - sol) * dx)**0.5 < 1.e-15
assert assemble(inner(sol_res - sol, sol_res - sol) * dx)**0.5 < 1.e-14
# -- Actually, the ordering is the same.
assert np.allclose(sol_res.subfunctions[0].dat.data_ro_with_halos, sol.subfunctions[0].dat.data_ro_with_halos)
assert np.allclose(sol_res.subfunctions[1].dat.data_ro_with_halos, sol.subfunctions[1].dat.data_ro_with_halos)

0 comments on commit 5a56d8a

Please sign in to comment.