From 76d8daa0132e5067bae17d05985401b64f443acd Mon Sep 17 00:00:00 2001 From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:58:32 +0000 Subject: [PATCH] Disk checkpointing (#3812) * Disk checkpointing managing start, pause and continue checkpointing. * AdjointDiskCheckpoint --------- Co-authored-by: David A. Ham --- firedrake/adjoint/__init__.py | 3 ++- firedrake/adjoint_utils/checkpointing.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index 8373a4285a..c48b990420 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -21,10 +21,11 @@ pause_annotation, continue_annotation, \ stop_annotating, annotate_tape # noqa F401 from pyadjoint.reduced_functional import ReducedFunctional # noqa F401 +from pyadjoint.checkpointing import disk_checkpointing_callback # noqa F401 from firedrake.adjoint_utils.checkpointing import \ enable_disk_checkpointing, pause_disk_checkpointing, \ continue_disk_checkpointing, stop_disk_checkpointing, \ - checkpointable_mesh # noqa F401 + checkpointable_mesh # noqa F401 from firedrake.adjoint_utils import get_solve_blocks # noqa F401 from pyadjoint.verification import taylor_test, taylor_to_dict # noqa F401 diff --git a/firedrake/adjoint_utils/checkpointing.py b/firedrake/adjoint_utils/checkpointing.py index 583d7155eb..0ab337e17b 100644 --- a/firedrake/adjoint_utils/checkpointing.py +++ b/firedrake/adjoint_utils/checkpointing.py @@ -1,5 +1,5 @@ """A module providing support for disk checkpointing of the adjoint tape.""" -from pyadjoint import get_working_tape, OverloadedType +from pyadjoint import get_working_tape, OverloadedType, disk_checkpointing_callback from pyadjoint.tape import TapePackageData from pyop2.mpi import COMM_WORLD import tempfile @@ -10,6 +10,8 @@ from numbers import Number _enable_disk_checkpoint = False _checkpoint_init_data = False +disk_checkpointing_callback["firedrake"] = "Please call enable_disk_checkpointing() "\ + "before checkpointing on the disk." __all__ = ["enable_disk_checkpointing", "disk_checkpointing", "pause_disk_checkpointing", "continue_disk_checkpointing", @@ -204,6 +206,12 @@ def restore_from_checkpoint(self, state): self.init_checkpoint_file = state["init"] self.current_checkpoint_file = state["current"] + def continue_checkpointing(self): + continue_disk_checkpointing() + + def pause_checkpointing(self): + pause_disk_checkpointing() + def checkpointable_mesh(mesh): """Write a mesh to disk and read it back. @@ -251,7 +259,7 @@ def restore(self): pass -class CheckpointFunction(CheckpointBase): +class CheckpointFunction(CheckpointBase, OverloadedType): """Metadata for a Function checkpointed to disk. An object of this class replaces the :class:`~firedrake.Function` stored as @@ -304,6 +312,9 @@ def restore(self): return type(function)(function.function_space(), function.dat, name=self.name(), count=self.count) + def _ad_restore_at_checkpoint(self, checkpoint): + return checkpoint.restore() + def maybe_disk_checkpoint(function): """Checkpoint a Function to disk if disk checkpointing is active."""