Skip to content

Commit

Permalink
Disk checkpointing (#3812)
Browse files Browse the repository at this point in the history
* Disk checkpointing managing start, pause and continue checkpointing.

* AdjointDiskCheckpoint

---------

Co-authored-by: David A. Ham <[email protected]>
  • Loading branch information
Ig-dolci and dham authored Nov 21, 2024
1 parent 51bc8d6 commit 76d8daa
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
3 changes: 2 additions & 1 deletion firedrake/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
pause_annotation, continue_annotation, \
stop_annotating, annotate_tape # noqa F401
from pyadjoint.reduced_functional import ReducedFunctional # noqa F401
from pyadjoint.checkpointing import disk_checkpointing_callback # noqa F401
from firedrake.adjoint_utils.checkpointing import \
enable_disk_checkpointing, pause_disk_checkpointing, \
continue_disk_checkpointing, stop_disk_checkpointing, \
checkpointable_mesh # noqa F401
checkpointable_mesh # noqa F401
from firedrake.adjoint_utils import get_solve_blocks # noqa F401

from pyadjoint.verification import taylor_test, taylor_to_dict # noqa F401
Expand Down
15 changes: 13 additions & 2 deletions firedrake/adjoint_utils/checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""A module providing support for disk checkpointing of the adjoint tape."""
from pyadjoint import get_working_tape, OverloadedType
from pyadjoint import get_working_tape, OverloadedType, disk_checkpointing_callback
from pyadjoint.tape import TapePackageData
from pyop2.mpi import COMM_WORLD
import tempfile
Expand All @@ -10,6 +10,8 @@
from numbers import Number
_enable_disk_checkpoint = False
_checkpoint_init_data = False
disk_checkpointing_callback["firedrake"] = "Please call enable_disk_checkpointing() "\
"before checkpointing on the disk."

__all__ = ["enable_disk_checkpointing", "disk_checkpointing",
"pause_disk_checkpointing", "continue_disk_checkpointing",
Expand Down Expand Up @@ -204,6 +206,12 @@ def restore_from_checkpoint(self, state):
self.init_checkpoint_file = state["init"]
self.current_checkpoint_file = state["current"]

def continue_checkpointing(self):
continue_disk_checkpointing()

def pause_checkpointing(self):
pause_disk_checkpointing()


def checkpointable_mesh(mesh):
"""Write a mesh to disk and read it back.
Expand Down Expand Up @@ -251,7 +259,7 @@ def restore(self):
pass


class CheckpointFunction(CheckpointBase):
class CheckpointFunction(CheckpointBase, OverloadedType):
"""Metadata for a Function checkpointed to disk.
An object of this class replaces the :class:`~firedrake.Function` stored as
Expand Down Expand Up @@ -304,6 +312,9 @@ def restore(self):
return type(function)(function.function_space(),
function.dat, name=self.name(), count=self.count)

def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint.restore()


def maybe_disk_checkpoint(function):
"""Checkpoint a Function to disk if disk checkpointing is active."""
Expand Down

0 comments on commit 76d8daa

Please sign in to comment.