Skip to content

Commit

Permalink
_set_timestep and get_timesteps for in CheckpointFile (#3310)
Browse files Browse the repository at this point in the history
* adding set_timestep and get_timestep for checkpointing
Co-authored-by: Connor Ward <[email protected]>
Co-authored-by: ksagiyam <[email protected]>
  • Loading branch information
sghelichkhani authored Jan 31, 2024
1 parent 7d5ff0d commit d8bf62c
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 5 deletions.
127 changes: 122 additions & 5 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
PREFIX_EMBEDDED = "_".join([PREFIX, "embedded"])
r"""The prefix attached to the DG function resulting from projecting the original function to the embedding DG space."""

PREFIX_TIMESTEPPING = "_".join([PREFIX, "timestepping"])
r"""The prefix attached to the attributes associated with timestepping."""

PREFIX_TIMESTEPPING_HISTORY = "_".join([PREFIX_TIMESTEPPING, "history"])
r"""The prefix attached to the attributes associated with timestepping history."""

# This is the distribution_parameters and reorder that one must use when
# distribution and permutation are loaded.
Expand Down Expand Up @@ -707,6 +712,82 @@ def _save_mesh_topology(self, tmesh):
perm_is.setName(None)
self.viewer.popGroup()

@PETSc.Log.EventDecorator("GetTimesteps")
def get_timestepping_history(self, mesh, name):
"""
Retrieve the timestepping history and indices for a specified function within a mesh.
This method is primarily used in checkpointing scenarios during timestepping simulations.
It returns the indices associated with each function stored in the timestepping mode,
along with any additional timestepping-related information (like time or timestep values) if available.
If the specified function has not been stored in timestepping mode, it returns an empty dictionary.
Parameters
----------
mesh : firedrake.mesh.MeshGeometry
The mesh containing the function to be queried.
name : str
The name of the function whose timestepping history is to be retrieved.
Returns
-------
dict
- Returns an empty dictionary if the function `name` has not been stored in timestepping mode.
- If the function `name` is stored in timestepping mode, returns a dictionary with the following contents:
- 'indices': A list of all stored indices for the function.
- Additional key-value pairs representing timestepping information, if available.
Raises
------
RuntimeError
If the function `name` is not found within the given `mesh` in the current file.
See Also
--------
CheckpointFile.save_function : Describes how timestepping information should be provided.
Notes
-----
The function internally checks whether the specified function is mixed or exists in the file.
It then retrieves the appropriate data paths and extracts the timestepping information
as specified in the checkpoint file.
"""

# check if the function is mixed, or even exists in file
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
V_name = self._get_mixed_function_name_mixed_function_space_name_map(mesh.name)[name]
base_path = self._path_to_mixed_function(mesh.name, V_name, name)
path = os.path.join(base_path, str(0)) # path to subfunction 0
fsub_name = self.get_attr(path, PREFIX + "_function")
return self.get_timestepping_history(mesh, fsub_name)
elif name in self._get_function_name_function_space_name_map(self._get_mesh_name_topology_name_map()[mesh.name], mesh.name):
tmesh_name = self._get_mesh_name_topology_name_map()[mesh.name]
V_name = self._get_function_name_function_space_name_map(tmesh_name, mesh.name)[name]
V = self._load_function_space(mesh, V_name)
tV = V.topological
path = self._path_to_function(tmesh_name, mesh.name, V_name, name)
if PREFIX_EMBEDDED in self.h5pyfile[path]:
path = self._path_to_function_embedded(tmesh_name, mesh.name, V_name, name)
_name = self.get_attr(path, PREFIX_EMBEDDED + "_function")
return self.get_timestepping_history(mesh, _name)
else:
tf_name = self.get_attr(path, PREFIX + "_vec")
dm_name = self._get_dm_name_for_checkpointing(tV.mesh(), tV.ufl_element())
timestepping_info = {}
tpath = self._path_to_vec_timestepping(tV.mesh().name, dm_name, tf_name)
path = self._path_to_function_timestepping(tmesh_name, mesh.name, V_name, name)
if tpath in self.h5pyfile:
assert path in self.h5pyfile
timestepping_info["index"] = self.get_attr(tpath, PREFIX_TIMESTEPPING_HISTORY + "_index")
for key in self.h5pyfile[path].attrs.keys():
if key.startswith(PREFIX_TIMESTEPPING_HISTORY):
key_ = key.replace(PREFIX_TIMESTEPPING_HISTORY + "_", "", 1)
timestepping_info[key_] = self.get_attr(path, key)
return timestepping_info
else:
raise RuntimeError(
f"""Function ({name}) not found in {self.filename}""")

@PETSc.Log.EventDecorator("SaveFunctionSpace")
def _save_function_space(self, V):
mesh = V.mesh()
Expand Down Expand Up @@ -770,7 +851,7 @@ def _save_function_space_topology(self, tV):
topology_dm.setName(base_tmesh_name)

@PETSc.Log.EventDecorator("SaveFunction")
def save_function(self, f, idx=None, name=None):
def save_function(self, f, idx=None, name=None, timestepping_info={}):
r"""Save a :class:`~.Function`.
:arg f: the :class:`~.Function` to save.
Expand All @@ -780,12 +861,15 @@ def save_function(self, f, idx=None, name=None):
this method must always be called with the idx parameter
set or never be called with the idx parameter set.
:kwarg name: optional alternative name to save the function under.
:kwarg timestepping_info: optional (requires idx) additional information
such as time, timestepping that can be stored along a function for
each index.
"""
V = f.function_space()
mesh = V.mesh()
if name:
g = Function(V, val=f.dat, name=name)
return self.save_function(g, idx=idx)
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
# -- Save function space --
self._save_function_space(V)
# -- Save function --
Expand All @@ -797,7 +881,7 @@ def save_function(self, f, idx=None, name=None):
path = os.path.join(base_path, str(i))
self.require_group(path)
self.set_attr(path, PREFIX + "_function", fsub.name())
self.save_function(fsub, idx=idx)
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info)
self._update_mixed_function_name_mixed_function_space_name_map(mesh.name, {f.name(): V_name})
else:
tf = f.topological
Expand All @@ -815,14 +899,34 @@ def save_function(self, f, idx=None, name=None):
_name = "_".join([PREFIX_EMBEDDED, f.name()])
_f = Function(_V, name=_name)
self._project_function_for_checkpointing(_f, f, method)
self.save_function(_f, idx=idx)
self.save_function(_f, idx=idx, timestepping_info=timestepping_info)
self.set_attr(path, PREFIX_EMBEDDED + "_function", _name)
else:
# -- Save function topology --
path = self._path_to_function(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
self.set_attr(path, PREFIX + "_vec", tf.name())
self._save_function_topology(tf, idx=idx)
# store timstepping_info only if in timestepping mode
if idx is not None:
path = self._path_to_function_timestepping(tmesh.name, mesh.name, V_name, f.name())
new = path not in self.h5pyfile
self.require_group(path)
# We make sure the provided timestepping_info is consistent all along timestepping.
if not new:
existing_keys = {key.replace(PREFIX_TIMESTEPPING_HISTORY + "_", "", 1)
for key in self.h5pyfile[path].attrs.keys()
if key.startswith(PREFIX_TIMESTEPPING_HISTORY)}
if timestepping_info.keys() != existing_keys:
raise RuntimeError(
r"Provided keys in timestepping_info must remain consistent")
# store items in timestepping_info accordingly
for ts_info_key, ts_info_value in timestepping_info.items():
if not isinstance(ts_info_value, float):
raise NotImplementedError(f"timestepping_info must have float values: got {type(ts_info_value)}")
old_items = [] if new else self.get_attr(path, PREFIX_TIMESTEPPING_HISTORY + f"_{ts_info_key}")
items = np.concatenate((old_items, [ts_info_value]))
self.set_attr(path, PREFIX_TIMESTEPPING_HISTORY + f"_{ts_info_key}", items)

@PETSc.Log.EventDecorator("SaveFunctionTopology")
def _save_function_topology(self, tf, idx=None):
Expand All @@ -844,7 +948,8 @@ def _save_function_topology(self, tf, idx=None):
else:
topology_dm = tmesh.topology_dm
dm = self._get_dm_for_checkpointing(tV)
path = self._path_to_vec(tmesh.name, dm.name, tf.name())
dm_name = dm.name
path = self._path_to_vec(tmesh.name, dm_name, tf.name())
if path in self.h5pyfile:
try:
timestepping = self.get_attr(os.path.join(path, tf.name()), "timestepping")
Expand All @@ -863,6 +968,12 @@ def _save_function_topology(self, tf, idx=None):
topology_dm.setName(base_tmesh_name)
if idx is not None:
self.viewer.popTimestepping()
path = self._path_to_vec_timestepping(tmesh.name, dm_name, tf.name())
new = path not in self.h5pyfile
self.require_group(path)
old_indices = [] if new else self.get_attr(path, PREFIX_TIMESTEPPING_HISTORY + "_index")
indices = np.concatenate((old_indices, [idx]))
self.set_attr(path, PREFIX_TIMESTEPPING_HISTORY + "_index", indices)

@PETSc.Log.EventDecorator("LoadMesh")
def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameters=None, topology=None):
Expand Down Expand Up @@ -1333,6 +1444,9 @@ def _path_to_vecs(self, tmesh_name, dm_name):
def _path_to_vec(self, tmesh_name, dm_name, tf_name):
return os.path.join(self._path_to_vecs(tmesh_name, dm_name), tf_name)

def _path_to_vec_timestepping(self, tmesh_name, dm_name, tf_name):
return os.path.join(self._path_to_vec(tmesh_name, dm_name, tf_name), PREFIX_TIMESTEPPING)

def _path_to_meshes(self, tmesh_name):
return os.path.join(self._path_to_topology(tmesh_name), PREFIX + "_meshes")

Expand All @@ -1354,6 +1468,9 @@ def _path_to_functions(self, tmesh_name, mesh_name, V_name):
def _path_to_function(self, tmesh_name, mesh_name, V_name, function_name):
return os.path.join(self._path_to_functions(tmesh_name, mesh_name, V_name), function_name)

def _path_to_function_timestepping(self, tmesh_name, mesh_name, V_name, function_name):
return os.path.join(self._path_to_function(tmesh_name, mesh_name, V_name, function_name), PREFIX_TIMESTEPPING)

def _path_to_function_embedded(self, tmesh_name, mesh_name, V_name, function_name):
return os.path.join(self._path_to_function(tmesh_name, mesh_name, V_name, function_name), PREFIX_EMBEDDED)

Expand Down
41 changes: 41 additions & 0 deletions tests/output/test_io_timestepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ufl
import finat.ufl
import os
import numpy as np

cwd = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -74,3 +75,43 @@ def test_io_timestepping(element, tmpdir):
g = Function(V)
_project(g, _get_expr(V, i), method)
assert assemble(inner(g - f, g - f) * dx) < 1.e-16


def test_io_timestepping_setting_time(tmpdir):
filename = os.path.join(
str(tmpdir), "test_io_timestepping_setting_time_dump.h5")
filename = COMM_WORLD.bcast(filename, root=0)
mesh = UnitSquareMesh(5, 5)
RT2_space = VectorFunctionSpace(mesh, "RT", 2)
cg1_space = FunctionSpace(mesh, "CG", 1)
mixed_space = MixedFunctionSpace([RT2_space, cg1_space])
z = Function(mixed_space, name="z")
u, v = z.subfunctions
u.rename("u")
v.rename("v")

indices = range(0, 10, 2)
ts = np.linspace(0, 1.0, len(indices))*2*np.pi
timesteps = np.linspace(0, 1.0, len(indices))

with CheckpointFile(filename, mode="w") as f:
f.save_mesh(mesh)
for idx, t, timestep in zip(indices, ts, timesteps):
u.assign(t)
v.interpolate((cos(Constant(t)/pi)))
f.save_function(z, idx=idx, timestepping_info={"time": t, "timestep": timestep})

with CheckpointFile(filename, mode="r") as f:
mesh = f.load_mesh(name="firedrake_default")
timestepping_history = f.get_timestepping_history(mesh, name="u")
timestepping_history_z = f.get_timestepping_history(mesh, name="z")
loaded_v = f.load_function(mesh, "v", idx=timestepping_history.get("index")[-2])

for timesteppng_hist in [timestepping_history, timestepping_history_z]:
assert (indices == timestepping_history.get("index")).all()
assert (ts == timestepping_history.get("time")).all()
assert (timesteps == timestepping_history.get("timestep")).all()

# checking if the function is exactly what we think
v_answer = Function(loaded_v.function_space()).interpolate(cos(Constant(timestepping_history.get("time")[-2])/pi))
assert assemble((loaded_v - v_answer)**2 * dx) < 1.0e-16

0 comments on commit d8bf62c

Please sign in to comment.