Skip to content

Commit

Permalink
io: increment version to 3.0.0.
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Feb 2, 2024
1 parent 09a0e8f commit 0aa8541
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
41 changes: 32 additions & 9 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,9 @@ class CheckpointFile(object):
One can also use different number of processes for saving and for loading.
"""

latest_version = '3.0.0'

def __init__(self, filename, mode, comm=COMM_WORLD):
self.viewer = ViewerHDF5()
self.filename = filename
Expand All @@ -530,7 +533,14 @@ def __init__(self, filename, mode, comm=COMM_WORLD):
assert self.commkey != MPI.COMM_NULL.py2f()
self._function_spaces = {}
self._function_load_utils = {}
self.opts = OptionsManager({"dm_plex_view_hdf5_storage_version": "2.1.0"}, "")
if mode in [PETSc.Viewer.FileMode.APPEND, PETSc.Viewer.FileMode.A, 'a']:
if self.has_attr("/", "dmplex_storage_version"):
version = self.get_attr("/", "dmplex_storage_version").decode()
else:
raise RuntimeError(f"Only files generated with CheckpointFile are supported: got an invalid file ({filename})")
else:
version = CheckpointFile.latest_version
self.opts = OptionsManager({"dm_plex_view_hdf5_storage_version": version}, "")
r"""DMPlex HDF5 version options."""

def __enter__(self):
Expand Down Expand Up @@ -657,12 +667,30 @@ def _save_mesh_topology(self, tmesh):
perm_is = tmesh._dm_renumbering
permutation_name = tmesh._permutation_name
if tmesh_name in self.require_group(self._path_to_topologies()):
version_str = self.opts.parameters['dm_plex_view_hdf5_storage_version']
version_major, version_minor, version_patch = tuple(int(ver) for ver in version_str.split('.'))
# Check if the global number of DMPlex points and
# the global sum of DMPlex cone sizes are consistent.
cell_dim = topology_dm.getDimension()
if version_major < 3:
path = os.path.join(self._path_to_topology(tmesh_name), "topology", "cells")
else:
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
cell_dim1 = self.get_attr(path, "cell_dim")
if cell_dim1 != cell_dim:
raise ValueError(f"Mesh ({tmesh_name}) already exists in {self.filename}, but the topological dimension is inconsistent: {cell_dim1} ({self.filename}) != {cell_dim} ({tmesh_name})")
order_array_size, ornt_array_size = dmcommon.compute_point_cone_global_sizes(topology_dm)
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
order_array_size1 = self.h5pyfile[path]["order"].size
ornt_array_size1 = self.h5pyfile[path]["orientation"].size
if version_major < 3:
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
order_array_size1 = self.h5pyfile[path]["order"].size
ornt_array_size1 = self.h5pyfile[path]["orientation"].size
else:
order_array_size1 = 0
ornt_array_size1 = 0
for d in range(cell_dim + 1):
path = os.path.join(self._path_to_topology(tmesh_name), "topology", "strata", str(d))
order_array_size1 += self.h5pyfile[path]["cone_sizes"].size
ornt_array_size1 += self.h5pyfile[path]["cones"].size
if order_array_size1 != order_array_size:
raise ValueError(f"Mesh ({tmesh_name}) already exists in {self.filename}, but the global number of DMPlex points is inconsistent: {order_array_size1} ({self.filename}) != {order_array_size} ({tmesh_name})")
if ornt_array_size1 != ornt_array_size:
Expand Down Expand Up @@ -1143,11 +1171,6 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
plex.setName(tmesh_name)
# Check format
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"),
os.path.join(path, "cones"),
os.path.join(path, "order"),
os.path.join(path, "orientation")]):
raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}")
format = ViewerHDF5.Format.HDF5_PETSC
self.viewer.pushFormat(format=format)
plex.distributionSetName(distribution_name)
Expand Down
32 changes: 32 additions & 0 deletions tests/output/test_io_backward_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from firedrake import *
from firedrake.mesh import make_mesh_from_coordinates
from firedrake.utils import IntType
import shutil


test_version = "2024_01_27"
Expand Down Expand Up @@ -276,3 +277,34 @@ def test_io_backward_compat_timestepping_load(version):
f_ = Function(V_)
_initialise_function(f_, _get_expr_timestepping(V_, i))
assert assemble(inner(f - f_, f - f_) * dx) < 1.e-16


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('version', ["2024_01_27"])
def test_io_backward_compat_timestepping_append(version, tmpdir):
filename = join(filedir, "_".join([basename, version, "timestepping" + ".h5"]))
copyname = join(str(tmpdir), "test_io_backward_compat_timestepping_append_dump.h5")
copyname = COMM_WORLD.bcast(copyname, root=0)
shutil.copyfile(filename, copyname)
with CheckpointFile(copyname, "r") as afile:
version = afile.opts.parameters['dm_plex_view_hdf5_storage_version']
assert version == CheckpointFile.latest_version
mesh = afile.load_mesh(mesh_name)
f = afile.load_function(mesh, func_name, idx=0)
V = f.function_space()
with CheckpointFile(copyname, 'a') as afile:
version = afile.opts.parameters['dm_plex_view_hdf5_storage_version']
assert version == '2.1.0'
for i in range(5, 10):
_initialise_function(f, _get_expr_timestepping(V, i))
afile.save_function(f, idx=i)
with CheckpointFile(copyname, "r") as afile:
version = afile.opts.parameters['dm_plex_view_hdf5_storage_version']
assert version == CheckpointFile.latest_version
for i in range(0, 10):
f = afile.load_function(mesh, func_name, idx=i)
V_ = f.function_space()
f_ = Function(V_)
_initialise_function(f_, _get_expr_timestepping(V_, i))
assert assemble(inner(f - f_, f - f_) * dx) < 1.e-16

0 comments on commit 0aa8541

Please sign in to comment.